Coverage for mlprodict/onnxrt/ops_cpu/op_unsqueeze.py: 87%
Shortcuts on this page:
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
Shortcuts on this page:
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ..shape_object import ShapeObject
10from ._op import OpRunUnaryNum, OpRun
class Unsqueeze_1(OpRunUnaryNum):
    """
    Runtime for ONNX operator *Unsqueeze* in the version where *axes*
    is a node attribute (selected when the onnx opset is below 11,
    and reused as is by ``Unsqueeze_11``). It inserts a dimension of
    size 1 at every position listed in *axes*.
    """

    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Unsqueeze_1.atts,
                               **options)
        # Normalize the attribute into a tuple. An empty list/tuple
        # means the attribute was not given: it is stored as None so
        # that _run can report it.
        if isinstance(self.axes, numpy.ndarray):
            self.axes = tuple(self.axes)
        elif self.axes in [[], tuple()]:
            self.axes = None
        elif isinstance(self.axes, list):
            self.axes = tuple(self.axes)

    def _run(self, data):  # pylint: disable=W0221
        """
        Inserts dimensions of size 1 into *data*.

        :param data: input tensor
        :return: tuple with one element, the expanded tensor
        :raises RuntimeError: if attribute *axes* was not specified
        """
        if not isinstance(self.axes, (tuple, list)):
            raise RuntimeError(  # pragma: no cover
                "axes cannot be None for operator Unsqueeze (Unsqueeze_1).")
        if len(self.axes) == 0:
            # Nothing to insert (e.g. an empty ndarray attribute):
            # the input is returned unchanged.
            return (data, )
        # The axes refer to positions in the *output* tensor. Expanding
        # them all at once lets numpy normalize negative values against
        # the output rank and makes the result independent of the order
        # of *axes*. Expanding one axis at a time produced wrong shapes
        # (axes=[1, 0] on shape (5,) gave (1, 5, 1) instead of
        # (1, 1, 5); axes=[-2, -1] on (3, 4) gave (3, 1, 4, 1) instead
        # of (3, 4, 1, 1)) or raised for valid but unsorted axes.
        sq = numpy.expand_dims(data, axis=tuple(self.axes))
        return (sq, )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """Delegates shape inference to ``ShapeObject.unsqueeze``."""
        return (x.unsqueeze(axes=self.axes), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """The output has the same type as the input."""
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends a dictionary ``temp=0``
        (presumably no temporary buffer is needed; confirm against
        the OpRun conventions) to the results.
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res
class Unsqueeze_11(Unsqueeze_1):
    """
    *Unsqueeze* selected when the onnx opset is 11 or 12. The behaviour
    is inherited unchanged from ``Unsqueeze_1``: *axes* is still a node
    attribute and nothing is overridden.
    """
class Unsqueeze_13(OpRun):
    """
    Runtime for ONNX operator *Unsqueeze* in the version where *axes*
    is a second input instead of an attribute (selected when the onnx
    opset is 13 or above).
    """

    atts = {'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Unsqueeze_13.atts,
                       **options)
        # axes is not an attribute in this version, it is received by _run.
        self.axes = None

    def _run(self, data, axes=None):  # pylint: disable=W0221
        """
        Inserts dimensions of size 1 into *data*.

        :param data: input tensor
        :param axes: positions of the new dimensions in the output
            tensor: a scalar, a 0-d or 1-d array, a list or a tuple
        :return: tuple with one element, the expanded tensor
        :raises RuntimeError: if *axes* is None
        """
        if axes is None:
            raise RuntimeError(  # pragma: no cover
                "axes cannot be None for operator Unsqueeze (Unsqueeze_13).")
        # numpy.atleast_1d accepts scalars, 0-d/1-d arrays, lists and
        # tuples alike. The previous check read ``axes.shape`` and
        # failed with AttributeError on plain lists or tuples.
        # tolist() yields Python ints, which numpy.expand_dims
        # normalizes against the output rank (negative axes included).
        axes_tuple = tuple(numpy.atleast_1d(axes).tolist())
        sq = numpy.expand_dims(data, axis=axes_tuple)
        return (sq, )

    def _infer_shapes(self, x, axes=None):  # pylint: disable=W0221
        """
        The output shape is left unknown (None); only the dtype of
        *x* is propagated.
        """
        return (ShapeObject(None, dtype=x.dtype), )

    def _infer_types(self, x, axes=None):  # pylint: disable=W0221
        """The output has the same type as the input."""
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends a dictionary ``temp=0``
        (presumably no temporary buffer is needed; confirm against
        the OpRun conventions) to the results.
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res
# Choose the implementation according to the opset version reported by
# onnx.defs.onnx_opset_version(): from opset 13 on, axes is an input.
Unsqueeze = (
    Unsqueeze_13 if onnx_opset_version() >= 13
    else Unsqueeze_11 if onnx_opset_version() >= 11
    else Unsqueeze_1)