Coverage for mlprodict/onnxrt/ops_cpu/op_cum_sum.py: 91%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
class CumSum(OpRun):
    """
    Runtime implementation of operator *CumSum*: cumulative sum of
    input *x*, optionally along *axis*, honouring attributes
    *exclusive* and *reverse*.

    The result keeps the dtype of *x*: without an explicit dtype,
    :func:`numpy.cumsum` promotes integer types narrower than the
    platform integer (int32 becomes int64 on most 64-bit platforms).
    """

    atts = {'exclusive': 0, 'reverse': 0}
    python_inputs = ['x', 'axis=None']

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=CumSum.atts,
                       **options)

    @staticmethod
    def _axis_as_int(axis):
        """
        Converts input *axis* into a python integer.

        :param axis: python or numpy integer, 0-d array or array
            holding exactly one number
        :return: int
        :raises RuntimeError: if *axis* holds more than one number
        """
        if isinstance(axis, (int, numpy.integer)):
            return int(axis)
        if (len(axis.shape) > 1 or
                (len(axis.shape) > 0 and axis.shape[0] != 1)):
            raise RuntimeError(  # pragma no cover
                "axis must be an array of one number not {} "
                "(shape {})".format(axis, axis.shape))
        if len(axis.shape) > 0:
            axis = axis[0]  # pylint: disable=E1136
        return int(axis)

    def _run(self, x, *axis):  # pylint: disable=W0221
        """
        Computes the cumulative sum.

        :param x: input tensor
        :param axis: optional second input, the axis to accumulate on;
            when missing, *x* is flattened (numpy behaviour)
        :return: tuple with one array of the same dtype as *x*
        :raises NotImplementedError: reverse or exclusive without axis
        """
        axis = None if len(axis) == 0 else axis[0]
        inplace = self.inplaces.get(0, False)

        if axis is None:
            if self.reverse or self.exclusive:
                raise NotImplementedError(  # pragma no cover
                    'reverse=1 or exclusive=1 not implemented')
            # Without an axis numpy returns a flat array, so writing into
            # x is only shape-compatible when x is already 1-D.
            if inplace and len(x.shape) == 1:
                return (numpy.cumsum(x, out=x), )
            return (numpy.cumsum(x, dtype=x.dtype), )

        axis = CumSum._axis_as_int(axis)

        # Indices must be tuples: recent numpy versions reject a list of
        # slices as a multidimensional index.
        rev_indices = None
        if self.reverse:
            rev = [slice(None)] * len(x.shape)
            rev[axis] = slice(None, None, -1)
            rev_indices = tuple(rev)
            x = x[rev_indices]

        if self.exclusive:
            # res[..., 0, ...] stays 0, res[..., i, ...] = sum of the
            # first i elements along axis.
            indices_c = [slice(None)] * len(x.shape)
            indices_d = [slice(None)] * len(x.shape)
            indices_c[axis] = slice(0, -1)
            indices_d[axis] = slice(1, x.shape[axis])
            res = numpy.zeros(x.shape, dtype=x.dtype)
            numpy.cumsum(x[tuple(indices_c)], axis=axis, dtype=x.dtype,
                         out=res[tuple(indices_d)])
        elif inplace:
            res = numpy.cumsum(x, axis=axis, out=x)
        else:
            res = numpy.cumsum(x, axis=axis, dtype=x.dtype)

        if rev_indices is not None:
            res = res[rev_indices]
        return (res, )

    def _infer_shapes(self, x, *axis):  # pylint: disable=W0221
        """Output shape is the shape of *x*."""
        return (x, )

    def _infer_types(self, x, *axis):  # pylint: disable=W0221
        """Output type is the type of *x*."""
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends a dictionary reporting
        no temporary buffer (``temp=0``).
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res

    def to_python(self, inputs):
        """
        Returns the import and the python code equivalent to this
        operator. The generated code does not support attributes
        *reverse* or *exclusive*.
        """
        lines = ['if exclusive or reverse:',
                 '    raise NotImplementedError("reverse=1 or exclusive=1 not implemented")',
                 'if axis is None:',
                 '    return numpy.cumsum(x)',
                 'return numpy.cumsum(x, axis=axis[0])']
        return 'import numpy', "\n".join(lines)