Coverage for mlprodict/onnxrt/ops_cpu/op_reduce_sum.py: 98%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRunReduceNumpy, RuntimeTypeError, OpRun
class ReduceSum_1(OpRunReduceNumpy):
    """
    Runtime for ReduceSum where the reduction is driven by the node
    attributes *axes* and *keepdims* (read as ``self.axes`` and
    ``self.keepdims`` when the operator runs).
    """

    atts = dict(axes=[], keepdims=1)

    def __init__(self, onnx_node, desc=None, **options):
        OpRunReduceNumpy.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ReduceSum_1.atts, **options)

    def _run(self, data):  # pylint: disable=W0221
        # accumulate in the input dtype so the output keeps the same type
        summed = numpy.sum(data, axis=self.axes, keepdims=self.keepdims,
                           dtype=data.dtype)
        return (summed, )
class ReduceSum_11(ReduceSum_1):
    """
    ReduceSum for opset 11. Only the constructor is redefined and it
    delegates to :class:`ReduceSum_1`, so execution is inherited as is.
    """

    def __init__(self, onnx_node, desc=None, **options):
        ReduceSum_1.__init__(
            self, onnx_node,
            desc=desc,
            **options)
class ReduceSum_13(OpRunReduceNumpy):
    """
    ReduceSum for opset 13 and above. *axes* is received as an optional
    input of method ``run``, and attribute *noop_with_empty_axes* decides
    whether an empty or missing *axes* returns the input unchanged
    (non-zero) or reduces over every axis (zero).
    """

    atts = {'axes': [], 'keepdims': 1, 'noop_with_empty_axes': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunReduceNumpy.__init__(self, onnx_node, desc=desc,
                                  expected_attributes=ReduceSum_13.atts,
                                  **options)

    def run(self, data, axes=None):  # pylint: disable=E0202,W0221,W0237
        """
        Calls method ``_run`` and post-processes its result.

        :param data: input array
        :param axes: optional axes to reduce (None, int, sequence or array)
        :return: tuple with one array
        :raises RuntimeTypeError: if the output dtype differs from the
            input dtype
        """
        res = self._run(data, axes=axes)
        if not self.keepdims and not isinstance(res[0], numpy.ndarray):
            # a full reduction without keepdims yields a numpy scalar,
            # it is wrapped into an array of shape (1, )
            res = (numpy.array([res[0]], dtype=res[0].dtype), )
        if res[0].dtype != data.dtype:
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: input '{}' != output '{}' "
                "(operator '{}')".format(
                    data.dtype, res[0].dtype, self.__class__.__name__))
        return res

    def _run_no_checks_(self, x, axes=None):  # pylint: disable=W0221
        return OpRun.run(self, x, axes)

    @staticmethod
    def _normalize_axes(axes):
        """
        Converts the optional *axes* input into a form :func:`numpy.sum`
        accepts.

        :param axes: None, an int, a sequence or an array of axes
        :return: None when no axis is given (None or zero elements),
            otherwise a tuple of Python ints; a scalar (0-d) value
            is a single axis, not an empty list
        """
        if axes is None:
            return None
        arr = numpy.asarray(axes)
        # size (not truthiness) is used: bool() of an empty array is
        # deprecated / an error in recent numpy, and bool() of a 0-d
        # array holding 0 would wrongly look like "no axis"
        if arr.size == 0:
            return None
        return tuple(int(a) for a in arr.ravel())

    def _run(self, data, axes=None):  # pylint: disable=W0221
        norm_axes = self._normalize_axes(axes)
        if norm_axes is None and self.noop_with_empty_axes:
            return (data, )
        try:
            # axis=None reduces over every axis
            return (numpy.sum(data, axis=norm_axes,
                              keepdims=self.keepdims,
                              dtype=data.dtype), )
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Unable to reduce shape %r with axes=%r." % (
                    data.shape, norm_axes)) from e

    def infer_shapes(self, data, axes=None):  # pylint: disable=E0202,W0221,W0237
        return self._infer_shapes(data, axes=axes)

    def _infer_shapes(self, data, axes=None):  # pylint: disable=W0221,W0237
        """
        Computes the output shape by calling ``data.reduce`` with *axes*
        and *keepdims* (*data* is presumably a shape object exposing
        ``reduce`` -- TODO confirm against the base class).
        """
        sh = data.reduce(axes, self.keepdims,  # pylint: disable=E1101
                         dtype=numpy.int64)  # pylint: disable=E1101
        return (sh, )

    def infer_types(self, data, axes=None):  # pylint: disable=E0202,W0221,W0237
        return self._infer_types(data, axes=axes)

    def _infer_types(self, data, axes=None):  # pylint: disable=W0221,W0237
        # the output type is the input type (the sum keeps data.dtype)
        return (data, )

    def _infer_sizes(self, *args, **kwargs):  # pylint: disable=W0221
        """
        Runs the operator and prepends ``dict(temp=0)`` to its outputs
        (presumably reporting no temporary buffer -- confirm).
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res
# ReduceSum is the implementation matching the highest opset reported by
# onnx.defs.onnx_opset_version: opset 13+, then 11+, then the opset 1 one.
ReduceSum = (
    ReduceSum_13 if onnx_opset_version() >= 13
    else ReduceSum_11 if onnx_opset_version() >= 11
    else ReduceSum_1)