Coverage for mlprodict/onnxrt/ops_cpu/op_reshape.py: 97%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRun
10from ..shape_object import ShapeObject
def reshape_reference_implementation(data, shape, allowzero=0):
    """
    Reshapes *data* following the ONNX *Reshape* semantics.

    :param data: input array
    :param shape: requested shape (array or sequence of integers),
        ``-1`` lets numpy infer one dimension
    :param allowzero: when false (default), a ``0`` in *shape* copies the
        dimension at the same index from ``data.shape``; when true, a ``0``
        is kept as a literal zero-size dimension
    :return: reshaped array
    :raises RuntimeError: if a ``0`` in *shape* points past the rank
        of *data* (and *allowzero* is false)
    """
    # Accept plain Python sequences as well: ``shape == 0`` below needs an
    # array to produce an element-wise mask.
    shape = numpy.asarray(shape, dtype=numpy.int64)
    if allowzero or (len(data.shape) == 1 and data.shape[0] == 0):
        # Literal reshape: either zeros are meant as zero-size dimensions
        # (allowzero) or the input is an empty 1-D tensor.
        return numpy.reshape(data, shape)
    new_shape = numpy.copy(shape)
    zeros_index = numpy.where(shape == 0)
    try:
        # Every 0 in the requested shape copies the input dimension.
        new_shape[zeros_index] = numpy.array(data.shape)[zeros_index]
    except IndexError as e:  # pragma: no cover
        raise RuntimeError(
            "Unable to reshape from shape %r to shape %r (or %r)."
            "" % (data.shape, shape, new_shape)) from e
    return numpy.reshape(data, new_shape)
class CommonReshape(OpRun):
    """
    Shared runtime implementation of the ONNX *Reshape* operator,
    opset-specific subclasses only change the expected attributes.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes, **options)

    def _run(self, data, shape):  # pylint: disable=W0221
        # The runtime expects one tuple entry per output.
        reshaped = reshape_reference_implementation(data, shape)
        return (reshaped, )

    def _infer_shapes(self, data, shape):  # pylint: disable=W0221
        # The output shape depends on the *values* held by ``shape``,
        # which are not available here: only the element type is kept.
        out = ShapeObject(None, dtype=data.dtype)
        return (out, )

    def _infer_types(self, data, shape):  # pylint: disable=W0221
        # The input type is returned unchanged for the single output
        # (presumably ``data`` is the element type here -- TODO confirm).
        return (data, )

    def _infer_sizes(self, *args, **kwargs):
        outputs = self.run(*args, **kwargs)
        # Reports ``temp=0`` ahead of the actual outputs (presumably no
        # temporary memory is needed for a reshape -- TODO confirm).
        sizes = dict(temp=0)
        return (sizes, ) + outputs
class Reshape_5(CommonReshape):
    """
    *Reshape* operator as defined before opset 14: it declares no attribute.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None, **options):
        # Forward *expected_attributes* instead of silently discarding it.
        CommonReshape.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)
class Reshape_13(Reshape_5):
    """
    *Reshape* operator for opset 13, it behaves exactly like
    :class:`Reshape_5`.
    """
class Reshape_14(CommonReshape):
    """
    *Reshape* operator from opset 14, which adds the ``allowzero``
    attribute: when it is set, a ``0`` in the requested shape is a
    literal zero-size dimension instead of a copy of the input dimension.
    """

    atts = {'allowzero': 0}

    def __init__(self, onnx_node, desc=None, **options):
        CommonReshape.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Reshape_14.atts, **options)

    def _run(self, data, shape):  # pylint: disable=W0221
        # ``allowzero`` is expected to be set on the instance by OpRun from
        # ``atts`` (NOTE(review): confirm); default to 0 if it is missing.
        if getattr(self, 'allowzero', 0):
            # Keep zeros as-is: no substitution with input dimensions.
            return (numpy.reshape(data, shape), )
        return CommonReshape._run(self, data, shape)
# Expose the newest implementation supported by the installed onnx package.
if onnx_opset_version() < 14:
    Reshape = Reshape_5  # pragma: no cover
else:
    Reshape = Reshape_14