Coverage for mlprodict/onnxrt/ops_cpu/op_cum_sum.py: 91%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

45 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class CumSum(OpRun):
    """
    Runtime for operator *CumSum*: cumulative sum of input *x* along
    *axis*. Attribute *exclusive* shifts the sums by one position so an
    element does not contribute to its own output (the first output along
    the axis is 0). Attribute *reverse* accumulates from the end of the
    axis towards its start.
    """

    atts = {'exclusive': 0, 'reverse': 0}
    python_inputs = ['x', 'axis=None']

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=CumSum.atts,
                       **options)

    def _run(self, x, *axis):  # pylint: disable=W0221
        """
        Computes the cumulative sum.

        :param x: input array
        :param axis: optional second input, the axis to accumulate along,
            either an integer or an array holding exactly one integer;
            negative values count from the last dimension
        :return: tuple with one array of the same dtype as *x*; when
            *axis* is missing, *x* is flattened first (like
            :func:`numpy.cumsum`)
        :raises RuntimeError: if *axis* holds more than one value
        """
        axis = axis[0] if len(axis) > 0 else None
        inplace = self.inplaces.get(0, False)
        if axis is None:
            # Same convention as numpy.cumsum(x): accumulate over the
            # flattened input. Flattening first also makes out=x valid for
            # in-place runs on inputs of rank > 1, and lets reverse /
            # exclusive share the generic code path below.
            x = x.reshape(-1)
            axis = 0
        elif not isinstance(axis, (int, numpy.integer)):
            axis = numpy.asarray(axis)
            if (len(axis.shape) > 1 or
                    (len(axis.shape) > 0 and axis.shape[0] != 1)):
                raise RuntimeError(  # pragma no cover
                    "axis must be an array of one number not {} "
                    "(shape {})".format(axis, axis.shape))
            axis = axis.reshape(-1)[0]
        axis = int(axis)

        if self.reverse:
            # Reversing the axis turns a forward cumulative sum into a
            # backward one; the same (self-inverse) index restores the
            # original order at the end. Indices must be tuples: numpy
            # rejects lists of slices for multidimensional indexing.
            rev_indices = [slice(None)] * x.ndim
            rev_indices[axis] = slice(None, None, -1)
            rev_indices = tuple(rev_indices)
            x = x[rev_indices]

        if self.exclusive:
            # res[0] = 0 and res[i] = x[0] + ... + x[i - 1] along axis:
            # the cumulative sum of x without its last element is written
            # into res without its first element.
            src = [slice(None)] * x.ndim
            dst = [slice(None)] * x.ndim
            src[axis] = slice(0, -1)
            dst[axis] = slice(1, None)
            res = numpy.zeros(x.shape, dtype=x.dtype)
            numpy.cumsum(x[tuple(src)], axis=axis, dtype=x.dtype,
                         out=res[tuple(dst)])
        elif inplace:
            res = numpy.cumsum(x, axis=axis, dtype=x.dtype, out=x)
        else:
            # dtype=x.dtype keeps the input type; numpy would otherwise
            # promote small integer types to the platform integer.
            res = numpy.cumsum(x, axis=axis, dtype=x.dtype)

        if self.reverse:
            res = res[rev_indices]
        return (res, )

    def _infer_shapes(self, x, *axis):  # pylint: disable=W0221
        """The output has the same shape as input *x*."""
        return (x, )

    def _infer_types(self, x, *axis):  # pylint: disable=W0221
        """The output has the same type as input *x*."""
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends a dictionary ``temp=0``
        (presumably: no temporary buffer is needed — confirm against
        other operators' ``_infer_sizes``).
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res

    def to_python(self, inputs):
        """
        Returns the import statement and the python code equivalent to
        this operator. The generated code supports neither
        ``reverse=1`` nor ``exclusive=1``. *inputs* is not used.
        """
        lines = ['if exclusive or reverse:',
                 '    raise NotImplementedError("reverse=1 or exclusive=1 not implemented")',
                 'if axis is None:',
                 '    return numpy.cumsum(x, dtype=x.dtype)',
                 'return numpy.cumsum(x, axis=axis[0], dtype=x.dtype)']
        return 'import numpy', "\n".join(lines)