Coverage for mlprodict/onnxrt/ops_cpu/op_scan.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from ..shape_object import ShapeObject
class Scan(OpRun):
    """
    Runtime for operator *Scan*: runs the subgraph *body* once per slice
    of the scan inputs, threading the loop state variables from one
    iteration to the next and stacking every per-iteration scan output.
    Only axis 0 and direction 0 are supported for scan inputs and
    scan outputs.
    """

    # Default attribute values given to OpRun as *expected_attributes*.
    atts = {
        'body': None,
        'num_scan_inputs': None,
        'scan_input_axes': [],
        'scan_input_directions': [],
        'scan_output_axes': [],
        'scan_output_directions': []
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: ONNX node, forwarded to OpRun
        :param desc: node description, forwarded to OpRun
        :param options: additional options, forwarded to OpRun
        :raises RuntimeError: if *body* has no method ``run`` or if any
            scan input direction or axis is not 0
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Scan.atts,
                       **options)
        if not hasattr(self.body, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'body' must have a method 'run', "
                "type {}.".format(type(self.body)))
        # Optional attribute lists may be shorter than the number of scan
        # inputs; missing entries mean 0 (first axis, forward direction).
        self.input_directions_ = self._pad_with_zeros(
            self.scan_input_directions, self.num_scan_inputs)
        if any(d != 0 for d in self.input_directions_):
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other output input_direction than 0.")
        self.input_axes_ = self._pad_with_zeros(
            self.scan_input_axes, self.num_scan_inputs)
        # any(...) instead of max(...): a negative axis must be rejected
        # as well and an empty list must not raise.
        if any(a != 0 for a in self.input_axes_):
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other input axes than 0.")
        self.input_names = self.body.input_names
        self.output_names = self.body.output_names
        # Prefer a dedicated entry point when the body provides one.
        self._run_meth = (self.body.run_in_scan
                          if hasattr(self.body, 'run_in_scan')
                          else self.body.run)

    @staticmethod
    def _pad_with_zeros(values, size):
        """
        Returns a list of *size* elements: ``values[i]`` when it exists,
        0 otherwise.
        """
        return [values[i] if i < len(values) else 0 for i in range(size)]

    def _common_run_shape(self, *args):
        """
        Splits the operator inputs and the body signature into loop state
        variables and scan inputs / outputs.

        *args* holds the N initial loop state values followed by the M
        scan inputs (M = *num_scan_inputs*). The body takes N state values
        followed by M scan slices, and returns N updated state values
        followed by K scan output slices. K is not necessarily equal to M.

        :param args: operator inputs (arrays, or shapes for shape inference)
        :return: tuple ``(num_loop_state_vars, num_scan_outputs,
            output_directions, max_dir_out, output_axes, max_axe_out,
            state_names_in, state_names_out, scan_names_in,
            scan_names_out, scan_values, states)``
        :raises RuntimeError: if any scan output direction or axis is not 0
        """
        num_loop_state_vars = len(args) - self.num_scan_inputs
        # Scan outputs are the body outputs following the N final states.
        num_scan_outputs = len(self.output_names) - num_loop_state_vars

        output_directions = self._pad_with_zeros(
            self.scan_output_directions, num_scan_outputs)
        max_dir_out = max(output_directions, default=0)
        if any(d != 0 for d in output_directions):
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other output output_direction than 0.")
        output_axes = self._pad_with_zeros(
            self.scan_output_axes, num_scan_outputs)
        max_axe_out = max(output_axes, default=0)
        if any(a != 0 for a in output_axes):
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other output axes than 0.")

        # The first N body inputs/outputs are the loop state variables,
        # the remaining ones the scan elements.
        state_names_in = self.input_names[:num_loop_state_vars]
        state_names_out = self.output_names[:num_loop_state_vars]
        scan_names_in = self.input_names[num_loop_state_vars:]
        scan_names_out = self.output_names[num_loop_state_vars:]
        scan_values = args[num_loop_state_vars:]

        states = args[:num_loop_state_vars]

        return (num_loop_state_vars, num_scan_outputs, output_directions,
                max_dir_out, output_axes, max_axe_out, state_names_in,
                state_names_out, scan_names_in, scan_names_out,
                scan_values, states)

    def _run(self, *args):  # pylint: disable=W0221
        """
        Executes the body once per slice along axis 0 of the scan inputs.

        :param args: N initial loop state values followed by the M scan
            inputs
        :return: tuple of the N final loop state values followed by the K
            scan outputs, each stacked along a new leading axis of size
            the number of iterations
        :raises TypeError: if the body cannot be called
        :raises ValueError: if there is no iteration while the body
            produces scan outputs (their shape cannot be determined)
        """
        (num_loop_state_vars, _, _, _, _, _, state_names_in,
         state_names_out, scan_names_in, scan_names_out,
         scan_values, states) = self._common_run_shape(*args)

        # The number of iterations is read from the first scan input.
        max_iter = args[num_loop_state_vars].shape[self.input_axes_[0]]
        # A list (not the tuple slice of args) so that scan outputs can be
        # appended after the loop, even if the loop body never runs.
        states = list(states)
        results = [[] for _ in scan_names_out]

        for it in range(max_iter):
            inputs = dict(zip(state_names_in, states))
            for name, value in zip(scan_names_in, scan_values):
                inputs[name] = value[it]

            try:
                outputs = self._run_meth(inputs)
            except TypeError as e:  # pragma: no cover
                raise TypeError(
                    "Unable to call 'run' for type '{}'.".format(
                        type(self.body))) from e

            states = [outputs[name] for name in state_names_out]
            for res, name in zip(results, scan_names_out):
                res.append(outputs[name])

        if max_iter == 0 and results:
            raise ValueError(  # pragma: no cover
                "Scan cannot produce scan outputs without any iteration.")
        # numpy.stack inserts the iteration axis in front of every slice;
        # unlike numpy.vstack it maps 0-d slices to shape (max_iter,).
        states.extend(numpy.stack(res, axis=0) for res in results)
        return tuple(states)

    def _infer_shapes(self, *args):  # pylint: disable=W0221
        """
        Infers output shapes from the input shapes.

        The final loop states keep the shapes of the initial ones. One
        scan output shape is produced per scan input: a copy of that
        input's shape whose first dimension is set to the number of
        iterations, or an unknown shape with the same dtype when the
        first scan input shape is unknown.
        NOTE(review): this mirrors scan inputs; it only matches the body
        when it returns one scan output per scan input with the same
        shape — confirm against callers.

        :param args: shapes of the N loop state values followed by the
            shapes of the M scan inputs
        :return: tuple of output shapes
        """
        (num_loop_state_vars, _, _, _, _, _, _, _, _, _,
         scan_values, states) = self._common_run_shape(*args)

        shapes = list(states)

        shape = args[num_loop_state_vars].shape
        if shape is None:
            shapes.extend(ShapeObject(None, dtype=sout.dtype)
                          for sout in scan_values)
            return tuple(shapes)

        max_iter = shape[self.input_axes_[0]]
        for sout in scan_values:
            sc = sout.copy()
            sc[0] = max_iter
            shapes.append(sc)
        return tuple(shapes)