Coverage for mlprodict/onnxrt/ops_cpu/op_if.py: 72%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from ...onnx_tools.onnx2py_helper import guess_dtype
8from ..shape_object import ShapeObject
9from ._op import OpRun
class If(OpRun):
    """
    Runtime for operator *If*: executes one of two subgraphs
    depending on a condition.

    Attributes ``then_branch`` and ``else_branch`` must both expose
    a method ``run``. When a branch also exposes ``run_in_scan``,
    that method is preferred to execute the branch.
    """

    atts = {
        'then_branch': None,
        'else_branch': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        @param      onnx_node   forwarded to ``OpRun.__init__``
        @param      desc        forwarded to ``OpRun.__init__``
        @param      options     forwarded to ``OpRun.__init__``

        Raises ``RuntimeError`` if one of the branches has no
        method ``run``.
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=If.atts,
                       **options)
        if not hasattr(self.then_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'then_branch' must have a method 'run', "
                "type {}.".format(type(self.then_branch)))
        if not hasattr(self.else_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'else_branch' must have a method 'run', "
                "type {}.".format(type(self.else_branch)))

        # Pick once the callable used to execute each branch:
        # 'run_in_scan' when the branch offers it, 'run' otherwise.
        self._run_meth_then = (self.then_branch.run_in_scan
                               if hasattr(self.then_branch, 'run_in_scan')
                               else self.then_branch.run)
        self._run_meth_else = (self.else_branch.run_in_scan
                               if hasattr(self.else_branch, 'run_in_scan')
                               else self.else_branch.run)

    @staticmethod
    def _check_branch_inputs(branch, named_inputs):
        """
        Checks that every input name required by *branch* is
        available in *named_inputs*.

        @param      branch          object with attribute ``input_names``
        @param      named_inputs    dictionary of available inputs

        Raises ``RuntimeError`` if *named_inputs* is empty while the
        branch needs inputs, or if one required name is missing.
        """
        if len(branch.input_names) == 0:
            return
        if len(named_inputs) == 0:
            # The message reports the inputs of the branch being checked.
            raise RuntimeError(  # pragma: no cover
                "named_inputs is empty but the graph needs {}.".format(
                    branch.input_names))
        for k in branch.input_names:
            if k not in named_inputs:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find named input '{}' in\n{}.".format(
                        k, "\n".join(sorted(named_inputs))))

    def _run(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Runs the branch selected by *cond*.

        @param      cond            condition, an object with attribute
                                    ``shape``; if the shape is not empty,
                                    all its elements must be true to
                                    select the *then* branch, otherwise
                                    its own truth value is used
        @param      named_inputs    dictionary of inputs handed over to
                                    the selected branch (None is
                                    replaced by an empty dictionary)
        @return                     tuple with the outputs of the
                                    selected branch, ordered as its
                                    ``output_names``

        Both branches are validated against *named_inputs* before
        any of them is executed.
        """
        if named_inputs is None:
            named_inputs = {}
        self._check_branch_inputs(self.then_branch, named_inputs)
        self._check_branch_inputs(self.else_branch, named_inputs)

        if len(cond.shape) > 0:
            use_then = all(cond)
        else:
            use_then = cond

        if use_then:
            outputs = self._run_meth_then(named_inputs)
            names = self.then_branch.output_names
        else:
            outputs = self._run_meth_else(named_inputs)
            names = self.else_branch.output_names
        return tuple([outputs[name] for name in names])

    def _then_branch_elem_type(self, res, name):
        """
        Returns the element type declared for output *name* in the
        graph of the *then* branch.

        @param      res     inferred results, only used in the
                            error message
        @param      name    output name
        @return             ``type.tensor_type.elem_type`` of the output

        Raises ``ValueError`` if the graph declares no such output.
        """
        out = {o.name: o for o in self.then_branch.obj.graph.output}
        if name not in out:
            raise ValueError(
                "Unable to find name=%r in %r or %r." % (
                    name, list(sorted(res)), list(sorted(out))))
        return out[name].type.tensor_type.elem_type

    def _pick_shape(self, res, name):
        """
        Returns the shape inferred for *name* if *res* contains it,
        otherwise a ``ShapeObject`` built with shape ``None`` and the
        type declared by the *then* branch for that output.
        """
        if name in res:
            return res[name]
        dt = self._then_branch_elem_type(res, name)
        return ShapeObject(None, guess_dtype(dt))

    def _infer_shapes(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Infers output shapes from the *then* branch only;
        *cond* and *named_inputs* are not used.

        @return     tuple of shapes, one per output of the *then* branch
        """
        res = self.then_branch._set_shape_inference_runtime()
        return tuple([self._pick_shape(res, name)
                      for name in self.then_branch.output_names])

    def _pick_type(self, res, name):
        """
        Returns the type inferred for *name* if *res* contains it,
        otherwise the type declared by the *then* branch for that
        output, converted with ``guess_dtype``.
        """
        if name in res:
            return res[name]
        dt = self._then_branch_elem_type(res, name)
        return guess_dtype(dt)

    def _infer_types(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Infers output types from the *then* branch only;
        *cond* and *named_inputs* are not used.

        @return     tuple of types, one per output of the *then* branch
        """
        res = self.then_branch._set_type_inference_runtime()
        return tuple([self._pick_type(res, name)
                      for name in self.then_branch.output_names])