Coverage for mlprodict/onnxrt/ops_cpu/op_split.py: 94%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from onnx.defs import onnx_opset_version
8from ._op import OpRun
9from ..shape_object import DimensionObject, ShapeObject
class CommonSplit(OpRun):
    """
    Runtime for operator *Split*.

    Shared base of the opset-specific runtimes: it slices an array
    along ``self.axis`` into chunks whose sizes are either given
    explicitly (*split*) or derived from the number of node outputs.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        # Always forward a 'split' option (None when absent).
        # Subclasses such as Split_2 read ``self.split``, which OpRun
        # presumably sets from this option -- TODO confirm in OpRun.
        options.setdefault('split', None)
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)
        # One chunk is produced per declared node output when no
        # explicit split sizes are available.
        self.nb_outputs = len(onnx_node.output)

    def common_run(self, mat, split):  # pylint: disable=W0221
        """
        Splits *mat* along ``self.axis``.

        :param mat: array to split; it must expose ``shape`` and accept
            a tuple of slices as index (e.g. a numpy array)
        :param split: sizes of the chunks along ``self.axis``, or None
            to cut ``self.nb_outputs`` chunks of size
            ``dim // nb_outputs``; the last chunk then also receives
            the remainder when the dimension is not divisible
        :return: tuple with one sliced piece of *mat* per chunk
        """
        if split is None:
            div = mat.shape[self.axis] // self.nb_outputs
            split = [div] * self.nb_outputs
            # Give any leftover elements to the last chunk.
            split[-1] += mat.shape[self.axis] - sum(split)
        # Full slices everywhere; only the split axis changes per chunk.
        index = [slice(None)] * len(mat.shape)
        res = []
        pos = 0
        for spl in split:
            index[self.axis] = slice(pos, pos + spl)
            pos += spl
            res.append(mat[tuple(index)])
        return tuple(res)

    def common_infer_shapes(self, data, split):  # pylint: disable=W0221
        """
        Infers the output shapes.

        :param data: shape of the input (presumably a ShapeObject,
            since ``copy`` and item assignment are used -- confirm)
        :param split: chunk sizes along ``self.axis``, or None
        :return: tuple of shapes, one per output; when *split* is None
            every shape is unknown (``ShapeObject(None, ...)``)
        """
        if split is None:
            return tuple(ShapeObject(None, dtype=data.dtype)
                         for _ in range(self.nb_outputs))
        res = []
        for spl in split:
            shape = data.copy()
            shape[self.axis] = DimensionObject(spl)
            res.append(shape)
        return tuple(res)

    def _infer_types(self, data, split):  # pylint: disable=W0221
        """
        Every output shares the type of the input *data*.

        :return: tuple repeating *data*, once per output (one per
            entry of *split*, or ``self.nb_outputs`` when it is None)
        """
        if split is None:
            return tuple(data for _ in range(self.nb_outputs))
        return tuple(data for _ in split)

    def _infer_sizes(self, *args, **kwargs):  # pylint: disable=W0221
        """
        Runs the operator and prepends a size report.

        The report is ``dict(temp=0)``; presumably ``temp`` counts
        temporary memory, and none is allocated here -- TODO confirm
        against OpRun.
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res
class Split_2(CommonSplit):
    """
    Runtime for operator *Split* where the chunk sizes come from the
    optional node attribute *split* (None by default) and the split
    axis from the attribute *axis* (0 by default).
    """

    atts = {'axis': 0, 'split': None}

    def __init__(self, onnx_node, desc=None, **options):
        CommonSplit.__init__(self, onnx_node, desc=desc,
                             expected_attributes=Split_2.atts, **options)

    def _run(self, mat):  # pylint: disable=W0221
        # Chunk sizes are taken from the attribute, not from an input.
        return self.common_run(mat, self.split)

    def _infer_shapes(self, data):  # pylint: disable=W0221
        return self.common_infer_shapes(data, self.split)

    def _infer_types(self, data):  # pylint: disable=W0221
        # Reuse the shared implementation rather than duplicating it.
        return CommonSplit._infer_types(self, data, self.split)
class Split_11(Split_2):
    """
    Runtime for operator *Split* (opset 11).

    Inherits everything from :class:`Split_2`: chunk sizes still come
    from the *split* attribute.
    """
class Split_13(CommonSplit):
    """
    Runtime for operator *Split* (opset 13).

    *split* is no longer an attribute (only *axis* is, default 0): the
    chunk sizes arrive as an optional second input of the node.
    """

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        CommonSplit.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Split_13.atts, **options)

    def _run(self, mat, split=None):  # pylint: disable=W0221
        # *split* is a runtime input; None means equally sized chunks.
        return self.common_run(mat, split)

    def _infer_shapes(self, data, split=None):  # pylint: disable=W0221
        # The split sizes are an input, so their values are not used
        # here and every output shape is left unknown.
        shapes = []
        for _ in range(self.nb_outputs):
            shapes.append(ShapeObject(None, dtype=data.dtype))
        return tuple(shapes)

    def _infer_types(self, data, split=None):  # pylint: disable=W0221
        # Each output keeps the input type.
        return (data, ) * self.nb_outputs
# Bind the public name ``Split`` to the newest runtime supported by the
# opset of the installed onnx package. The older branches are excluded
# from coverage measurement with ``pragma: no cover``.
if onnx_opset_version() >= 13:
    Split = Split_13
elif onnx_opset_version() >= 11:  # pragma: no cover
    Split = Split_11
else:  # pragma: no cover
    Split = Split_2