Coverage for mlprodict/onnxrt/ops_cpu/op_squeeze.py: 88%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ..shape_object import ShapeObject
10from ._op import OpRunUnaryNum, OpRun
class Squeeze_1(OpRunUnaryNum):
    """
    Runtime for the ONNX *Squeeze* operator when *axes* is an attribute.

    The constructor normalizes attribute *axes*. A numpy array, list or
    tuple becomes a tuple of axes, and an empty one becomes *None*. With
    *None*, :func:`numpy.squeeze` removes every dimension of size 1.
    """

    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Squeeze_1.atts,
                               **options)
        axes = self.axes
        if isinstance(axes, numpy.ndarray):
            # ravel() also accepts a 0-d array, which tuple() would reject
            axes = axes.ravel().tolist()
        if isinstance(axes, (list, tuple)):
            # an empty sequence (including an empty array) means
            # "no axes given": squeeze every dimension of size 1
            axes = tuple(axes) or None
        self.axes = axes

    def _run(self, data):  # pylint: disable=W0221
        """
        Removes the dimensions listed in *self.axes*, or every dimension
        of size 1 when *self.axes* is *None*.
        """
        axes = self.axes
        if isinstance(axes, list):
            axes = tuple(axes)
        # numpy.squeeze removes all the axes of a tuple at once. The result
        # therefore does not depend on the order of *axes*. Squeezing them
        # one by one shifts the indices of the remaining dimensions and
        # breaks with unsorted or mixed positive/negative axes.
        return (numpy.squeeze(data, axis=axes), )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        return (x.squeeze(axis=self.axes), )

    def _infer_types(self, x):  # pylint: disable=W0221
        # squeezing never changes the element type
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        # runs the operator and reports no temporary buffer
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res
class Squeeze_11(Squeeze_1):
    """
    Squeeze operator for opset 11. It reuses the attribute-based
    implementation of :class:`Squeeze_1` unchanged.
    """
class Squeeze_13(OpRun):
    """
    Runtime for the ONNX *Squeeze* operator from opset 13 on. In these
    opsets *axes* is an optional input rather than an attribute.
    """

    atts = {'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Squeeze_13.atts,
                       **options)
        self.axes = None

    def _run(self, data, axes=None):  # pylint: disable=W0221
        """
        Removes the dimensions listed in input *axes*, or every dimension
        of size 1 when *axes* is not given.
        """
        if axes is None:
            return (numpy.squeeze(data), )
        # Flattening makes a python int, a 0-d array and a 1-D array behave
        # the same way. Calling tuple() directly on a 0-d array raises
        # TypeError, even though hasattr(arr, '__iter__') is True for it.
        axes = tuple(numpy.asarray(axes).ravel().tolist())
        return (numpy.squeeze(data, axis=axes), )

    def _infer_shapes(self, x, axes=None):  # pylint: disable=W0221
        # the output shape depends on the runtime value of *axes*
        return (ShapeObject(None, dtype=x.dtype), )

    def _infer_types(self, x, axes=None):  # pylint: disable=W0221
        # squeezing never changes the element type
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        # runs the operator and reports no temporary buffer
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res
# Expose the newest implementation supported by the installed onnx package.
Squeeze = (
    Squeeze_13 if onnx_opset_version() >= 13
    else Squeeze_11 if onnx_opset_version() >= 11  # pragma: no cover
    else Squeeze_1  # pragma: no cover
)