Coverage for mlprodict/onnxrt/ops_cpu/op_concat.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ...onnx_tools.onnx2py_helper import guess_numpy_type_from_dtype
9from ._op import OpRun
10from ..shape_object import ShapeObject
class Concat(OpRun):
    """
    Runtime operator *Concat*: joins its inputs along attribute
    *axis* by calling :epkg:`numpy:concatenate`.
    """

    # Expected attributes and their default values, handed to OpRun.
    atts = {'axis': 0}
    # The operator accepts a variable number of inputs.
    python_inputs = ['*inputs']

    def __init__(self, onnx_node, desc=None, **options):
        """
        Forwards every argument to @see cl OpRun and declares
        ``Concat.atts`` as the expected attributes.

        @param      onnx_node   passed unchanged to OpRun
        @param      desc        passed unchanged to OpRun
        @param      options     passed unchanged to OpRun
        """
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Concat.atts, **options)

    def _preprocess(self, a):
        """
        Makes sure one input has enough dimensions to be
        concatenated along ``self.axis``.

        @param      a       one input (anything exposing ``shape``
                            and ``reshape``)
        @return             *a* itself when ``self.axis`` is lower than
                            its number of dimensions, otherwise *a*
                            reshaped with trailing dimensions of size 1
                            so that axis ``self.axis`` exists

        Raises RuntimeError if the shape of *a* is empty.
        """
        rank = len(a.shape)
        if rank == 0:
            raise RuntimeError(  # pragma: no cover
                "Concat: one input has an empty shape: %r." % a)
        if self.axis < rank:
            # The axis already exists, nothing to change.
            return a
        # Appends as many unit dimensions as needed to reach self.axis.
        missing = (1, ) * (self.axis + 1 - rank)
        return a.reshape(a.shape + missing)

    def _run(self, *args):  # pylint: disable=W0221
        """
        Preprocesses every input then concatenates them along
        ``self.axis``.

        @param      args    inputs to concatenate
        @return             tuple with a single element, the result
        """
        prepared = tuple(map(self._preprocess, args))
        joined = numpy.concatenate(prepared, self.axis)
        return (joined, )

    def _infer_shapes(self, *args):  # pylint: disable=W0221
        """
        Delegates shape inference to method ``concat_columns``
        of the first input shape.

        @param      args    input shapes
        @return             tuple with a single element
        """
        first = args[0]
        others = args[1:]
        return (first.concat_columns(self.axis, *others), )

    def _infer_types(self, *args):  # pylint: disable=W0221
        """
        Converts every input type with
        ``guess_numpy_type_from_dtype`` and merges them with
        ``ShapeObject._infer_merged_type``.

        @param      args    input types
        @return             tuple with a single element, the merged type
        """
        dtypes = list(map(guess_numpy_type_from_dtype, args))
        merged = ShapeObject._infer_merged_type(*dtypes, use_dtype=False)
        return (merged, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends a dictionary reporting
        ``temp=0`` to its outputs.

        @param      args    forwarded to ``self.run``
        @param      kwargs  forwarded to ``self.run``
        @return             tuple ``({'temp': 0}, *outputs)``
        """
        outputs = self.run(*args, **kwargs)
        return ({'temp': 0}, ) + outputs

    def to_python(self, inputs):
        """
        Returns the python code equivalent to this operator.

        @param      inputs  not used by this implementation
        @return             tuple *(import statement, code)*
        """
        header = "import numpy"
        body = "return numpy.concatenate(inputs, axis=axis)"
        return header, body