Coverage for mlprodict/onnxrt/onnx_inference_node.py: 97%
1"""
2@file
3@brief OnnxInferenceNode definition.
4"""
5import sys
6import pprint
7import numpy
8from onnx import onnx_pb as onnx_proto
9from onnx.onnx_cpp2py_export.defs import SchemaError # pylint: disable=E0401,E0611
10from ..onnx_tools.onnx2py_helper import get_onnx_schema
11from .excs import MissingOperatorError
12from .ops import load_op


class OnnxInferenceNode:
    """
    A node to execute.
    """

    class OnnxInferenceWrapper:
        """
        Wraps @see cl OnnxInference and exposes
        the necessary functions.

        :param oinf: instance of @see cl OnnxInference
        """

        def __init__(self, oinf):
            if oinf is None:
                raise ValueError(  # pragma: no cover
                    "oinf cannot be None.")
            self.oinf = oinf

        @property
        def args_default(self):
            "Returns the list of default arguments."
            return []

        @property
        def args_default_modified(self):
            "Returns the list of modified arguments."
            return []

        @property
        def args_mandatory(self):
            "Returns the list of mandatory arguments."
            return self.oinf.input_names

        @property
        def args_optional(self):
            "Returns the list of optional arguments."
            return []

        @property
        def obj(self):
            "Returns the ONNX graph."
            return self.oinf.obj

        def run(self, *args, **kwargs):
            "Calls run."
            return self.oinf.run(*args, **kwargs)

        def to_python(self, inputs, *args, **kwargs):
            "Calls to_python."
            res = self.oinf.to_python(*args, **kwargs)
            if len(res) != 1:
                raise NotImplementedError(  # pragma: no cover
                    "Not implemented if the code has multiple files.")
            keys = list(res)
            value = res[keys[0]]
            lines = value.split('\n')
            last = 0
            for i, line in enumerate(lines):
                if line.startswith('def '):
                    last = i - 1
                    break
            imports = '\n'.join(
                line for line in lines[:last] if 'import ' in line)
            lines.append('')
            lines.append("return OnnxPythonInference().run(%s)" %
                         ', '.join(inputs))
            code = '\n'.join(lines[last:])
            return imports, code
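
        # Illustrative sketch (assumed generated script, not from the class):
        # given "import numpy\n\ndef onnx_run(X):\n    return (X,)", the
        # method returns imports == "import numpy" and a code string starting
        # at the blank line before 'def onnx_run' and ending with the
        # appended "return OnnxPythonInference().run(X)".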

        def need_context(self):
            "Needs context?"
            return False

        def infer_types(self, *args):
            "Calls infer_types."
            res = self.oinf.infer_types(args)
            names = self.oinf.obj.output
            dtypes = [res[n] for n in names]
            return tuple(dtypes)

        def infer_sizes(self, *args):
            "Calls infer_sizes."
            values = {name: value
                      for name, value in zip(self.oinf.input_names, args)}
            res = self.oinf.infer_sizes(values)
            names = self.oinf.obj.output
            sizes = [res.get(n, 0) for n in names]
            return (res['#'], ) + tuple(sizes)

        def enable_inplace_compute(self, index):
            "Not implemented."
            pass

    def __init__(self, onnx_node, desc, global_index):
        """
        @param      onnx_node       onnx_node
        @param      desc            internal description
        @param      global_index    function returning a unique index
                                    for the output this operator generates
        """
        if desc is None:
            raise ValueError("desc should not be None.")  # pragma: no cover
        self.desc = desc
        self.onnx_node = onnx_node
        self._init(global_index)

    @property
    def name(self):
        "Returns the ONNX name."
        return "_".join(
            [self.desc['domain'], self.onnx_node.op_type]).replace(
                ".", "_").replace('__', '_').strip('_')

    def _init(self, global_index):
        """
        Prepares the node.
        """
        self.op_type = self.onnx_node.op_type
        self.order = -1
        self.variable_to_clean = []
        self.inputs = list(self.onnx_node.input)
        self.outputs = list(self.onnx_node.output)
        self.inplaces = []
        self.inputs_indices = [global_index(name) for name in self.inputs]
        self.outputs_indices = [global_index(name) for name in self.outputs]
        self._global_index = global_index

    def set_order(self, order):
        """
        Defines the order of execution.
        """
        self.order = order

    def add_variable_to_clean(self, name):
        """
        Adds a variable which can be cleaned after the node
        execution.
        """
        self.variable_to_clean.append(name)

    def __str__(self):
        "usual"
        return "Onnx-{}({}) -> {}{}".format(
            self.op_type, ", ".join(self.inputs), ", ".join(self.outputs),
            " (name=%r)" % self.onnx_node.name
            if self.onnx_node.name else "")

    def __repr__(self):
        "usual"
        return self.__str__()
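
    # Illustrative rendering (assumed names): a node Add with inputs X, Y,
    # output Z and name 'n0' prints as "Onnx-Add(X, Y) -> Z (name='n0')";
    # without a node name, the trailing "(name=...)" part is omitted.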

    def setup_runtime(self, runtime=None, variables=None, rt_class=None,
                      target_opset=None, dtype=None, domain=None,
                      ir_version=None, runtime_options=None,
                      build_inference_node_function=None):
        """
        Loads runtime.

        :param runtime: runtime options
        :param variables: registered variables created by previous operators
        :param rt_class: runtime class used to compute
            prediction of subgraphs
        :param target_opset: use a specific target opset
        :param dtype: float computational type
        :param domain: node domain
        :param ir_version: if not None, changes the default value
            given by :epkg:`ONNX`
        :param runtime_options: runtime options
        :param build_inference_node_function: function creating an inference
            runtime from an ONNX graph

        .. versionchanged:: 0.9
            Parameter *build_inference_node_function* was added.
        """
        if self.desc is None:
            raise AttributeError(
                "desc should not be None.")  # pragma: no cover
        if rt_class is None:
            # path used when this operator is a function.
            self.function_ = OnnxInferenceNode.OnnxInferenceWrapper(runtime)
            self.ops_ = None
        else:
            self.function_ = None
            self.preprocess_parameters(
                runtime, rt_class, ir_version=ir_version,
                target_opset=target_opset)
            options = {'provider': runtime} if runtime else {}
            if domain is not None:
                options['domain'] = domain
            if target_opset is not None:
                options['target_opset'] = target_opset
            if ir_version is not None:
                options['ir_version'] = ir_version
            if runtime_options is not None:
                options.update({
                    k: v for k, v in runtime_options.items()
                    if k not in {'log_severity_level'}})
            try:
                if runtime is not None and runtime.startswith('onnxruntime2'):
                    self.ops_ = load_op(self.onnx_node, desc=self.desc,
                                        options=options if options else None,
                                        variables=variables, dtype=dtype,
                                        runtime=runtime)
                elif runtime in ('python_compiled', 'python_compiled_debug'):
                    options['provider'] = 'python'
                    self.ops_ = load_op(self.onnx_node, desc=self.desc,
                                        options=options if options else None,
                                        variables=variables, dtype=dtype,
                                        runtime=runtime)
                else:
                    self.ops_ = load_op(self.onnx_node, desc=self.desc,
                                        options=options if options else None,
                                        variables=variables, dtype=dtype,
                                        runtime=runtime)
            except MissingOperatorError as e:
                try:
                    onnx_schema = get_onnx_schema(
                        self.onnx_node.op_type, self.onnx_node.domain,
                        opset=target_opset)
                except SchemaError:
                    raise e  # pylint: disable=W0707
                if onnx_schema is None or not onnx_schema.has_function:
                    raise e
                self.function_ = OnnxInferenceNode.OnnxInferenceWrapper(
                    build_inference_node_function(onnx_schema.function_body))
                self.ops_ = None

    @staticmethod
    def _find_static_inputs(body):
        """
        Determines the loop inputs: any input defined
        by the subgraph plus any result used as a constant
        inside the subgraph.
        """
        inputs_set = set(i.name for i in body.input)
        for init in body.initializer:
            inputs_set.add(init.name)
        for node in body.node:
            for i in node.output:
                inputs_set.add(i)
        add_inputs = []
        for node in body.node:
            for i in node.input:
                if i not in inputs_set:
                    # no graph input, initializer or node output matches:
                    # it must be a constant coming from the enclosing graph
                    add_inputs.append(i)
                    inputs_set.add(i)
        return add_inputs
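
    # Illustrative sketch (assumed names): in a Loop body whose only input is
    # 'iter' and whose only node computes Add(iter, cst) -> out, the name
    # 'cst' is neither a body input, an initializer nor a node output, so the
    # method returns ['cst']: it has to come from the enclosing graph.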

    def preprocess_parameters(self, runtime, rt_class, ir_version=None,
                              target_opset=None):
        """
        Preprocesses the parameters, loads every *GraphProto*
        attribute (an :epkg:`ONNX` graph with less metadata).

        @param      runtime         runtime options
        @param      rt_class        runtime class used to compute
                                    prediction of subgraphs
        @param      ir_version      if not None, overwrites the default value
        @param      target_opset    use a specific target opset
        """
        if 'atts' not in self.desc:
            return  # pragma: no cover
        inside_loop = self.onnx_node.op_type in {'Loop'}
        for _, v in self.desc['atts'].items():
            if 'value' not in v:
                continue  # pragma: no cover
            value = v['value']
            if isinstance(value, onnx_proto.GraphProto):
                static_inputs = OnnxInferenceNode._find_static_inputs(value)
                try:
                    sess = rt_class(v['value'], runtime=runtime,
                                    ir_version=ir_version,
                                    target_opset=target_opset,
                                    inside_loop=inside_loop,
                                    static_inputs=static_inputs)
                except RuntimeError as e:  # pragma: no cover
                    raise RuntimeError(
                        "Unable to instantiate a node of type %r and name %r."
                        "" % (self.onnx_node.op_type,
                              self.onnx_node.name)) from e
                v['value_rt'] = sess

    def run(self, values):
        """
        Runs the node.
        The function updates *values* with the outputs.

        @param      values      list of existing values
        """
        if self.ops_ is None:
            # Then a function.
            feeds = {name: val
                     for name, val in zip(self.function_.obj.input, values)}
            outputs = self.function_.run(feeds)
            res = [outputs[k] for k in self.function_.obj.output]

            if self.outputs_indices is None:
                for name, value in zip(self.outputs, res):
                    values[name] = value
            else:
                for i, r in enumerate(res):
                    values[self.outputs_indices[i]] = r
            return

        # This code takes time if the graph contains many nodes.
        # Maybe a C++ container would help in that case (to skip GIL).
        if self.inputs_indices is None:
            args = list(values[k] for k in self.inputs)
        else:
            args = list(values[k] for k in self.inputs_indices)
        try:
            if self.ops_.need_context():
                context = {n: values[self._global_index(n)]
                           for n in self.ops_.additional_inputs}
                res = self.ops_.run(*args, context=context)
            else:
                res = self.ops_.run(*args)
        except TypeError as e:
            raise RuntimeError(  # pragma: no cover
                "Unable to run operator %r, inputs=%r."
                "" % (type(self.ops_), self.inputs)) from e
        except OverflowError as e:
            raise RuntimeError(  # pragma: no cover
                "Unable to run operator %r, inputs=%r."
                "" % (type(self.ops_), self.inputs)) from e

        if not isinstance(res, tuple):
            raise RuntimeError(  # pragma: no cover
                "Results of operator %r should be a tuple." % type(self.ops_))
        if len(self.outputs) != len(res):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of outputs got {} for names {}.\n{}".format(
                    len(res), list(sorted(self.outputs)),
                    pprint.pformat(self.desc)))

        # This code takes time if the graph contains many nodes.
        # Maybe a C++ container would help in that case (to skip GIL).
        if self.outputs_indices is None:
            for name, value in zip(self.outputs, res):
                values[name] = value
        else:
            for i, r in enumerate(res):
                values[self.outputs_indices[i]] = r
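
    # Illustrative usage sketch (assumed driver code, not from this module):
    # the enclosing OnnxInference typically walks the nodes in execution
    # order and lets each one fill a shared container, e.g.:
    #   values = {'X': numpy.array([[0., 1.]], dtype=numpy.float32)}
    #   for node in sorted_nodes:   # OnnxInferenceNode instances
    #       node.run(values)        # adds this node's outputs to values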

    def switch_initializers_dtype(self, dtype_in=numpy.float32,
                                  dtype_out=numpy.float64):
        """
        Switches all initializers from *dtype_in* to *dtype_out*
        (``numpy.float32`` to ``numpy.float64`` by default).
        This only works if the runtime is ``'python'``.

        @param      dtype_in    previous type
        @param      dtype_out   next type
        @return                 done operations
        """
        done = []
        for k, v in self.desc['atts'].items():
            if 'value_rt' not in v:
                continue
            if isinstance(v['value_rt'], numpy.ndarray):
                if v['value_rt'].dtype == dtype_in:
                    v['value_rt'] = v['value_rt'].astype(dtype_out)
                    done.append(("+", "desc", k, v['value_rt']))
                else:
                    done.append(("-", "desc", k, v['value_rt']))
        if hasattr(self, 'ops_') and self.ops_ is not None:
            res = self.ops_.switch_initializers_dtype(dtype_in, dtype_out)
            for r in res:
                done.append(("ops_", ) + r)
        return done
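
    # Illustrative log entries (assumed attribute name 'coefficients'): a
    # converted initializer yields ("+", "desc", "coefficients", new_array),
    # a skipped one yields ("-", "desc", "coefficients", old_array), and
    # entries reported by the operator itself are prefixed with "ops_".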

    def _set_shape_inference_runtime(self, values):
        """
        Updates *values* with the shapes of the outputs.

        :param values: container for shapes
        """
        if self.ops_ is None:
            # A function, unknown types.
            for name in self.outputs:
                values[name] = None
            return values
        args = [values[k] for k in self.inputs if k != '']
        try:
            res = self.ops_.infer_shapes(*args)
        except (TypeError, ValueError, AttributeError) as e:  # pragma: no cover
            raise TypeError(
                "Unable to call infer_shapes with {} arguments for class"
                " '{}' ({})".format(
                    len(args), self.ops_.__class__.__name__,
                    self.ops_.infer_shapes)) from e
        if res is not None:
            if not isinstance(res, tuple):
                raise RuntimeError(  # pragma: no cover
                    "Results of an operator should be a tuple for operator "
                    "'{}'.".format(type(self.ops_)))
            if len(self.outputs) != len(res):
                raise RuntimeError(  # pragma: no cover
                    "Mismatch number of outputs got {} != {} for names {} "
                    "(node='{}').\n{}".format(
                        len(res), len(self.outputs), list(self.outputs),
                        self.ops_.__class__.__name__,
                        pprint.pformat(self.desc, depth=2)))
            for name, value in zip(self.outputs, res):
                values[name] = value
        return values

    def _set_type_inference_runtime(self, values):
        """
        Updates *values* with the types of the outputs.

        :param values: container for types
        """
        args = [values[k] for k in self.inputs]
        try:
            if self.ops_ is None:
                res = self.function_.infer_types(*args)
            else:
                res = self.ops_.infer_types(*args)
        except (TypeError, ValueError) as e:  # pragma: no cover
            raise TypeError(
                "Unable to call infer_types with {} arguments for class"
                " '{}'".format(
                    len(args), self.ops_.__class__.__name__)) from e
        if not isinstance(res, tuple):
            raise RuntimeError(  # pragma: no cover
                "Results of an operator should be a tuple for operator '{}'"
                ".".format(type(self.ops_)))
        if len(self.outputs) != len(res):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of outputs got {} != {} for names {} "
                "(node='{}').\n{}".format(
                    len(res), len(self.outputs), list(self.outputs),
                    self.ops_.__class__.__name__,
                    pprint.pformat(self.desc, depth=2)))
        for name, value in zip(self.outputs, res):
            values[name] = value
        return values

    def _set_size_inference_runtime(self, values):
        """
        Updates *values* with the sizes of the outputs.

        :param values: container for sizes
        """
        args = [values[k] for k in self.inputs]
        try:
            if (self.ops_ or self.function_).need_context():
                context = {n: values[n]
                           for n in self.ops_.additional_inputs}
                res = self.ops_.infer_sizes(*args, context=context)
            else:
                res = (self.ops_ or self.function_).infer_sizes(*args)
        except (TypeError, ValueError) as e:  # pragma: no cover
            raise TypeError(
                "Unable to call infer_sizes with {} arguments for class"
                " '{}' ({})".format(len(args), self.ops_.__class__.__name__,
                                    self.ops_.infer_sizes)) from e
        if not isinstance(res, tuple):
            raise RuntimeError(  # pragma: no cover
                "Results of an operator should be a tuple for operator '{}'"
                ".".format(type(self.ops_)))
        if len(self.outputs) + 1 != len(res):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of outputs got {} != {} + 1 for names {} "
                "(node='{}').\n{}".format(
                    len(res), len(self.outputs), list(self.outputs),
                    self.ops_.__class__.__name__,
                    pprint.pformat(self.desc, depth=2)))
        for name, value in zip(self.outputs, res[1:]):
            values[name] = value
        values['#' + self.onnx_node.name] = res[0]
        return values
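
    # Illustrative sketch (assumed node name 'n0'): besides one size per
    # output, the method stores the intermediate size reported by the
    # operator under key '#n0' in *values*.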

    def enable_inplace_compute(self, name):
        """
        Lets the node know that one input can be overwritten.

        @param      name    input name
        """
        self.inplaces.append(name)
        (self.ops_ or self.function_).enable_inplace_compute(
            self.inputs.index(name))

    @property
    def inputs_args(self):
        """
        Returns the list of arguments as well as
        the list of parameters with the default values
        (close to the signature).
        """
        if not hasattr(self, 'ops_'):
            raise AttributeError(
                "Attribute 'ops_' is missing.")  # pragma: no cover
        sigs = []
        ops_or_function = self.function_ if self.ops_ is None else self.ops_
        mand = ops_or_function.args_mandatory
        if mand is None:
            mand = self.python_inputs
        sigs.extend(mand)
        if len(ops_or_function.args_optional) > 0:
            sigs.extend(ops_or_function.args_optional)
            if sys.version_info[:2] >= (3, 8):
                sigs.append('/')
        sigs.extend(ops_or_function.args_default)
        return sigs

    @property
    def python_inputs(self):
        """
        Returns the python arguments.
        """
        if not hasattr(self, 'ops_'):
            raise AttributeError(
                "Attribute 'ops_' is missing.")  # pragma: no cover
        if hasattr(self.ops_, 'python_inputs'):
            return self.ops_.python_inputs
        return self.inputs

    @property
    def modified_args(self):
        """
        Returns the list of modified parameters.
        """
        if not hasattr(self, 'ops_'):
            raise AttributeError(
                "Attribute 'ops_' is missing.")  # pragma: no cover
        if self.ops_ is None:
            return self.function_.args_default_modified
        return self.ops_.args_default_modified

    def to_python(self, inputs):
        """
        Returns a python code for this operator.

        @param      inputs      input names
        @return                 imports, python code, both as strings
        """
        if not hasattr(self, 'ops_'):
            raise AttributeError(
                "Attribute 'ops_' is missing.")  # pragma: no cover
        if self.ops_ is None:
            return self.function_.to_python(inputs)
        return self.ops_.to_python(inputs)