Coverage for mlprodict/onnxrt/ops_cpu/op_scan.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

66 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from ..shape_object import ShapeObject 

10 

11 

class Scan(OpRun):
    """
    Runtime for operator *Scan*.

    The graph held by attribute *body* is run once per slice of the scan
    inputs (taken along axis 0). Loop state variables are threaded from
    one iteration to the next, and the per-iteration scan outputs are
    stacked along a new leading axis. Only axis 0 and direction 0 are
    supported, for scan inputs as well as scan outputs.
    """

    atts = {
        'body': None,
        'num_scan_inputs': None,
        'scan_input_axes': [],
        'scan_input_directions': [],
        'scan_output_axes': [],
        'scan_output_directions': []
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: ONNX node describing the operator
        :param desc: optional description forwarded to :class:`OpRun`
        :param options: additional options forwarded to :class:`OpRun`
        :raises RuntimeError: if *body* has no method ``run`` or if a scan
            input direction or axis is not 0
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Scan.atts,
                       **options)
        if not hasattr(self.body, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'body' must have a method 'run', "
                "type {}.".format(type(self.body)))
        self.input_directions_ = Scan._pad_attribute(
            self.scan_input_directions, self.num_scan_inputs)
        # any(...) rather than max(...) != 0: rejects negative values too
        # and does not fail on an empty list.
        if any(d != 0 for d in self.input_directions_):
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other output input_direction than 0.")
        self.input_axes_ = Scan._pad_attribute(
            self.scan_input_axes, self.num_scan_inputs)
        if any(a != 0 for a in self.input_axes_):
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other input axes than 0.")
        self.input_names = self.body.input_names
        self.output_names = self.body.output_names
        # A body exposing 'run_in_scan' is preferred over the generic 'run'.
        self._run_meth = (self.body.run_in_scan
                          if hasattr(self.body, 'run_in_scan')
                          else self.body.run)

    @staticmethod
    def _pad_attribute(values, n):
        """
        Returns the first *n* values of an optional list attribute,
        padding with 0 when the attribute holds fewer than *n* values.
        """
        return [values[i] if i < len(values) else 0 for i in range(n)]

    def _common_run_shape(self, *args):
        """
        Splits the operator inputs into loop state variables and scan
        inputs and checks the scan output attributes are supported.

        :param args: operator inputs, loop state variables first, then
            the *num_scan_inputs* scan inputs
        :return: tuple ``(num_loop_state_vars, num_scan_outputs,
            output_directions, max_dir_out, output_axes, max_axe_out,
            state_names_in, state_names_out, scan_names_in,
            scan_names_out, scan_values, states)``
        :raises RuntimeError: if a scan output direction or axis is not 0
        """
        num_loop_state_vars = len(args) - self.num_scan_inputs
        # The body returns the loop state variables first, then one value
        # per scan output: the number of scan outputs comes from the body
        # outputs, not from the number of operator inputs.
        num_scan_outputs = max(
            len(self.output_names) - num_loop_state_vars, 0)

        output_directions = Scan._pad_attribute(
            self.scan_output_directions, num_scan_outputs)
        max_dir_out = max(output_directions, default=0)
        if any(d != 0 for d in output_directions):
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other output output_direction than 0.")
        output_axes = Scan._pad_attribute(
            self.scan_output_axes, num_scan_outputs)
        max_axe_out = max(output_axes, default=0)
        if any(a != 0 for a in output_axes):
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other output axes than 0.")

        # Loop state variables are the first num_loop_state_vars body
        # inputs/outputs; everything after them is scanned.
        state_names_in = self.input_names[:num_loop_state_vars]
        state_names_out = self.output_names[:num_loop_state_vars]
        scan_names_in = self.input_names[num_loop_state_vars:]
        scan_names_out = self.output_names[num_loop_state_vars:]
        scan_values = args[num_loop_state_vars:]

        states = args[:num_loop_state_vars]

        return (num_loop_state_vars, num_scan_outputs, output_directions,
                max_dir_out, output_axes, max_axe_out, state_names_in,
                state_names_out, scan_names_in, scan_names_out,
                scan_values, states)

    def _run(self, *args):  # pylint: disable=W0221
        """
        Executes the scan loop.

        :param args: loop state variables followed by the scan inputs
        :return: tuple made of the final loop state variables followed by
            one array per scan output, stacking every iteration's value
            along a new axis 0
        :raises TypeError: if the body cannot be called with the inputs
        """
        (num_loop_state_vars, _, _, _, _, _, state_names_in,
         state_names_out, scan_names_in, scan_names_out,
         scan_values, states) = self._common_run_shape(*args)

        max_iter = args[num_loop_state_vars].shape[self.input_axes_[0]]
        per_output = [[] for _ in scan_names_out]

        for step in range(max_iter):
            feeds = dict(zip(state_names_in, states))
            for name, value in zip(scan_names_in, scan_values):
                # Only axis 0 is supported (checked in __init__).
                feeds[name] = value[step]

            try:
                outputs = self._run_meth(feeds)
            except TypeError as e:  # pragma: no cover
                raise TypeError(
                    "Unable to call 'run' for type '{}'.".format(
                        type(self.body))) from e

            states = [outputs[name] for name in state_names_out]
            for collected, name in zip(per_output, scan_names_out):
                collected.append(outputs[name])

        # list(...) also covers zero iterations, where 'states' is still
        # the tuple sliced from args. numpy.stack adds the iteration axis
        # in front and, unlike vstack over expand_dims, keeps 0-d
        # per-iteration outputs one-dimensional.
        final = list(states)
        final.extend(numpy.stack(collected, axis=0)
                     for collected in per_output)
        return tuple(final)

    def _infer_shapes(self, *args):  # pylint: disable=W0221
        """
        Infers output shapes from the input shapes.

        Loop state shapes are returned unchanged. One shape is derived
        from each scan input by replacing its first dimension with the
        number of iterations (or an unknown shape when that number is
        unknown).

        NOTE(review): scan input shapes are used as a proxy for scan
        output shapes; this only matches the node outputs when the body
        emits as many scan outputs as there are scan inputs - confirm.
        """
        (num_loop_state_vars, _, _, _, _, _, _, _, _, _,
         scan_values, states) = self._common_run_shape(*args)

        shapes = list(states)

        shape = args[num_loop_state_vars].shape
        if shape is None:
            shapes.extend(ShapeObject(None, dtype=sout.dtype)
                          for sout in scan_values)
        else:
            max_iter = shape[self.input_axes_[0]]
            for sout in scan_values:
                sc = sout.copy()
                sc[0] = max_iter
                shapes.append(sc)

        return tuple(shapes)