Coverage for mlprodict/onnxrt/ops_cpu/op_loop.py: 86%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
7.. versionadded:: 0.7
8"""
9import numpy
10from ._op import OpRun
11from ..shape_object import ShapeObject
class Loop(OpRun):
    """
    Runtime implementation of operator *Loop*: the subgraph stored in
    attribute *body* is executed repeatedly while the condition is true
    and the trip count *M* (if any) is not reached.
    """

    atts = {
        'body': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Loop.atts,
                       **options)
        if not hasattr(self.body, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'body' must have a method 'run', "
                "type {}.".format(type(self.body)))

        # Use the scan-specific entry point when the subgraph runtime
        # offers one, otherwise the regular run method.
        self._run_meth = (self.body.run_in_scan
                          if hasattr(self.body, 'run_in_scan')
                          else self.body.run)
        # Names the subgraph uses without declaring them as inputs;
        # they are resolved from *context* in _run.
        self.additional_inputs = self.body.static_inputs

    def need_context(self):
        """
        The operator Loop needs to know all results produced
        so far as the loop may silently access one of them.
        Some results are not always listed among the inputs
        (a kind of static variables).
        """
        return len(self.additional_inputs) > 0

    def _run(self, M, cond, v_initial, *args, callback=None, context=None):  # pylint: disable=W0221
        """
        Executes the loop.

        :param M: maximum trip count; ``None`` means no limit, the loop
            then stops only when the condition becomes false
        :param cond: initial loop condition (a falsy or ``None`` value
            means no iteration is run)
        :param v_initial: value bound to the third input of the subgraph
        :param args: values bound to the last ``len(args)`` inputs
            of the subgraph
        :param callback: if not None, called after every iteration as
            ``callback(inputs, context=context)``
        :param context: results available outside the subgraph, used to
            resolve its static inputs
        :return: tuple of the subgraph outputs, the condition excluded
        :raises RuntimeError: if static inputs cannot be resolved or if
            the subgraph returns a None condition
        :raises TypeError: if one of the returned values is None
        """
        body = self.body
        loop_inputs = body.input_names
        out_names = body.output_names
        cond_name = out_names[0]

        inputs = {name: None for name in loop_inputs}
        inputs[loop_inputs[2]] = v_initial
        if len(args) > 0:
            begin = len(loop_inputs) - len(args)
            for name, val in zip(loop_inputs[begin:], args):
                inputs[name] = val

        if len(self.additional_inputs) > 0:
            if context is None:
                raise RuntimeError(
                    "Additional inputs %r are missing and context is None."
                    "" % (self.additional_inputs, ))
            for a in self.additional_inputs:
                if a in context:
                    inputs[a] = context[a]
                else:
                    raise RuntimeError(
                        "Additional inputs %r not found in context\n%s." % (
                            a, "\n".join(sorted(map(str, context)))))

        # NOTE(review): a None *cond* makes the loop run zero iterations;
        # confirm this is the intended handling of an omitted condition.
        it = 0
        while cond and (M is None or it < M):
            # The iteration number keeps the dtype of M; int64 is used
            # when there is no trip count (or it carries no dtype).
            inputs[loop_inputs[0]] = numpy.array(
                it, dtype=getattr(M, 'dtype', numpy.int64))
            inputs[loop_inputs[1]] = cond
            outputs = self._run_meth(inputs)
            cond = outputs[cond_name]
            if cond is None:
                raise RuntimeError(
                    "condition %r returned by the subgraph cannot be None."
                    "" % cond_name)
            # Loop-carried values become the next iteration's inputs.
            for i, o in zip(loop_inputs[2:], out_names[1:]):
                inputs[i] = outputs[o]
            if callback is not None:
                callback(inputs, context=context)
            it += 1

        if it == 0:
            # No iteration ran: outputs are the initial values.
            outputs = {cond_name: cond}
            for i, o in zip(loop_inputs[2:], out_names[1:]):
                outputs[o] = inputs[i]
        for o in out_names:
            if o not in outputs:
                outputs[o] = numpy.empty(shape=tuple())
        # NOTE(review): outputs beyond the loop-carried variables (scan
        # outputs) are not accumulated across iterations; only the value
        # from the last iteration is returned. Confirm against the ONNX
        # Loop specification.
        res = tuple(outputs[name] for name in out_names[1:])
        if any(r is None for r in res):
            raise TypeError(  # pragma: no cover
                "Operator Loop produces a None value.")
        return res

    def _infer_shapes(self, M, cond, v_initial, *args):  # pylint: disable=W0221
        """
        Returns the shapes of the outputs (condition excluded).

        Shapes come from the subgraph shape inference when available,
        otherwise from ``body.output_names_shapes_types``.
        """
        res = self.body._set_shape_inference_runtime()
        outputs = {k[0]: k[1:] for k in self.body.output_names_shapes_types}
        ret = []
        for name in self.body.output_names[1:]:
            if name in res:
                ret.append(res[name])
            else:
                find = outputs[name]
                ret.append(ShapeObject(find[0], dtype=find[1]))
        return tuple(ret)

    def _infer_types(self, M, cond, v_initial, *args):  # pylint: disable=W0221
        """
        Returns the types of the outputs (condition excluded), as given
        by the subgraph type inference.
        """
        res = self.body._set_type_inference_runtime()
        return tuple(res[name] for name in self.body.output_names[1:])

    def _infer_sizes(self, M, cond, v_initial, *args, context=None):  # pylint: disable=W0221
        """
        Runs the loop while collecting, after every iteration, the sizes
        reported by ``body.infer_sizes``.

        :return: ``(dict(temp=total), *outputs)`` where *total* sums
            every collected size
        """
        store = []

        def callback_(inputs, context=None):
            res = self.body.infer_sizes(inputs, context=context)
            store.append(res)

        res = self._run(M, cond, v_initial, *args, callback=callback_,
                        context=context)

        # Presumably each infer_sizes result maps names to dicts of
        # sizes; confirm against the subgraph runtime.
        temp = 0
        for v in store:
            for vv in v.values():
                temp += sum(vv.values())
        return (dict(temp=temp), ) + res