Coverage for mlprodict/onnxrt/onnx_inference.py: 96%

# pylint: disable=C0302,R0912
"""
@file
@brief Implements a class able to compute the predictions
from an :epkg:`ONNX` model.
"""
from collections import OrderedDict
from io import BytesIO
from time import perf_counter
import warnings
import textwrap
import pprint
from keyword import iskeyword
import numpy
from scipy.sparse import coo_matrix
from onnx import load, load_model, checker, shape_inference
from onnx import onnx_pb as onnx_proto
from onnx.helper import make_model
from ..tools.code_helper import make_callable, print_code
from ..onnx_tools.onnx2py_helper import (
    _var_as_dict, numpy_min, numpy_max, guess_numpy_type_from_string)
from ..onnx_tools.onnx_manipulations import (
    select_model_inputs_outputs, enumerate_model_node_outputs,
    overwrite_opset, insert_results_into_onnx)
from ..onnx_tools.optim import onnx_remove_node_unused
from .onnx_inference_node import OnnxInferenceNode
from .onnx_inference_exports import OnnxInferenceExport
from .shape_object import ShapeObject
from .type_object import SequenceType


class OnnxInference:
    """
    Loads an :epkg:`ONNX` file or object or stream.
    Computes the output of the :epkg:`ONNX` graph.
    Several runtimes are available.

    * ``'python'``: the runtime implements every onnx operator
      needed to run a :epkg:`scikit-learn` model by using :epkg:`numpy`
      or C++ code.
    * ``'python_compiled'``: it is the same runtime as the previous
      one except that every operator is called from a compiled function
      (@see me _build_compile_run) instead of a method going through
      the list of operators.
    * ``'onnxruntime1'``: uses :epkg:`onnxruntime` (or `onnxruntime1-cuda`, ...)
    * ``'onnxruntime2'``: this mode is mostly used to debug as
      python handles calling every operator but :epkg:`onnxruntime`
      is called for each of them. This process may fail due to
      wrong inference types, especially if the graph includes
      custom nodes; in that case, it is better to compute the output
      of intermediate nodes. It is much slower as every node is
      computed for every output, but it is more robust.

    :param onnx_or_bytes_or_stream: :epkg:`onnx` object,
        bytes, or filename or stream
    :param runtime: runtime options
    :param skip_run: do not build the runtime
    :param inplace: use inplace computation as much as possible
    :param input_inplace: the computation is allowed
        to overwrite the input, see :meth:`_guess_inplace
        <mlprodict.onnxrt.onnx_inference.OnnxInference._guess_inplace>`
    :param ir_version: if not None, overwrite the default version
    :param target_opset: used to overwrite *target_opset*
    :param runtime_options: specific options for the runtime
    :param inside_loop: tells the runtime the graph is meant to
        be repeated multiple times (in that case, inputs and
        outputs may share the same name)
    :param static_inputs: Loop can use static variables,
        variables from the graph which runs the loop
        (enumerate of strings)
    :param new_outputs: if the loading fails, it might be worth
        cutting the graph; if not None, the graph will
        be cut to have these new_outputs as the final outputs
    :param new_opset: overwrite the main opset and replace it
        by this new one
    :param existing_functions: a model may contain several local functions,
        this parameter is used when a local function is calling another
        local function previously defined.

    Among the possible runtime_options, there are:

    * *enable_profiling*: enables profiling for :epkg:`onnxruntime`
    * *session_options*: an instance of *SessionOptions* from
      :epkg:`onnxruntime`
    * *ir_version*: change ir_version

    .. versionchanged:: 0.7
        Parameters *new_outputs*, *new_opset* were added.

    .. versionchanged:: 0.8
        Parameters *static_inputs*, *device* were added.

    .. versionchanged:: 0.9
        Parameter *existing_functions* was added.
        Removes *device* parameter. See runtime.
        Runtime `onnxruntime1-cuda` was added.
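
    A minimal usage sketch (the file name ``model.onnx`` and the input
    name ``X`` are assumptions for the example, not part of the library)::

        import numpy
        from mlprodict.onnxrt import OnnxInference

        oinf = OnnxInference('model.onnx', runtime='python')
        res = oinf.run({'X': numpy.random.rand(2, 4).astype(numpy.float32)})
        print(res)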
96 """

    def __init__(self, onnx_or_bytes_or_stream, runtime=None,
                 skip_run=False, inplace=True,
                 input_inplace=False, ir_version=None,
                 target_opset=None, runtime_options=None,
                 session_options=None, inside_loop=False,
                 static_inputs=None, new_outputs=None, new_opset=None,
                 existing_functions=None):
        if isinstance(onnx_or_bytes_or_stream, bytes):
            self.obj = load_model(BytesIO(onnx_or_bytes_or_stream))
        elif isinstance(onnx_or_bytes_or_stream, BytesIO):
            self.obj = load_model(onnx_or_bytes_or_stream)
        elif isinstance(onnx_or_bytes_or_stream, str):
            self.obj = load(onnx_or_bytes_or_stream)
        elif hasattr(onnx_or_bytes_or_stream, 'graph'):
            self.obj = onnx_or_bytes_or_stream
        elif isinstance(onnx_or_bytes_or_stream, onnx_proto.GraphProto):
            self.obj = make_model(onnx_or_bytes_or_stream,
                                  producer_name='mlprodict')
        elif isinstance(onnx_or_bytes_or_stream, onnx_proto.FunctionProto):
            self.obj = onnx_or_bytes_or_stream
        else:
            raise TypeError("Unable to handle type {}.".format(  # pragma: no cover
                type(onnx_or_bytes_or_stream)))
        if ir_version is not None:
            self.obj.ir_version = ir_version
        if new_outputs is not None:
            self.obj = select_model_inputs_outputs(
                self.obj, outputs=new_outputs, infer_shapes=True)
        if new_opset is not None:
            self.obj = overwrite_opset(self.obj, new_opset)

        self.runtime = runtime
        self.skip_run = skip_run
        self.input_inplace = input_inplace
        self.inplace = inplace
        self.force_target_opset = target_opset
        self.runtime_options = runtime_options
        self.inside_loop = inside_loop
        self.static_inputs = static_inputs
        self._init(existing_functions)

    def __getstate__(self):
        """
        To pickle the object.
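
        A minimal sketch of a pickle round trip (``model_def`` stands for
        any in-memory ONNX model, an assumption for the example)::

            import pickle
            from mlprodict.onnxrt import OnnxInference

            oinf = OnnxInference(model_def)
            restored = pickle.loads(pickle.dumps(oinf))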
142 """
143 return {'onnx': self.obj.SerializeToString(),
144 'runtime': self.runtime,
145 'runtime_options': self.runtime_options,
146 'skip_run': self.skip_run,
147 'input_inplace': self.input_inplace,
148 'inplace': self.inplace,
149 'force_target_opset': self.force_target_opset,
150 'static_inputs': self.static_inputs,
151 'inside_loop': self.inside_loop}

    def __setstate__(self, state):
        """
        To unpickle the object.
        """
        onx = state['onnx']
        self.obj = load_model(BytesIO(onx))
        self.runtime = state['runtime']
        self.runtime_options = state['runtime_options']
        self.skip_run = state['skip_run']
        self.input_inplace = state['input_inplace']
        self.inplace = state['inplace']
        self.force_target_opset = state['force_target_opset']
        self.static_inputs = state['static_inputs']
        self.inside_loop = state['inside_loop']
        self._init()

    def _init(self, existing_functions=None):
        """
        Prepares the instance to deliver predictions.
        """
        self.graph_ = self.to_sequence(existing_functions)
        if len(self.graph_['sequence']) == 0:
            raise RuntimeError(  # pragma: no cover
                "No runnable nodes were found in the ONNX graph.")
        self.outputs_ = self.graph_['outputs']
        self.inputs_ = self.graph_['inputs']
        is_function_proto = isinstance(self.obj, onnx_proto.FunctionProto)
        if is_function_proto:
            obj_graph = self.obj
        else:
            obj_graph = self.obj.graph

        for ino in [obj_graph.input, obj_graph.output]:
            for xy in ino:
                if isinstance(xy, str):
                    shape = None
                else:
                    shape = xy.type.tensor_type.shape
                    for d in shape.dim:
                        if (d.dim_value == 0 and "0" in str(d) and
                                'dim_param' not in str(d)):
                            if len(shape.dim) <= 1:
                                shape = None
                                break
                            # d.dim_value returns 0 whether it is 0 or empty.
                            # it may be a parameter as well
                            raise RuntimeError(  # pragma: no cover
                                "Wrong ONNX file, one input or output has "
                                "an empty shape: {}.".format(xy))

        self.target_opset_ = self.graph_['targets']
        if self.force_target_opset is not None:
            if isinstance(self.force_target_opset, dict):
                self.target_opset_ = self.force_target_opset  # pragma: no cover
            else:
                self.target_opset_ = {'': self.force_target_opset}
        self.ir_version_ = self.graph_['ir_version']

        if not self.skip_run:
            if self.runtime is not None and self.runtime.startswith('onnxruntime1'):
                # Loads the onnx with onnxruntime as a single file.
                del self.graph_
                from .ops_whole.session import OnnxWholeSession
                self._whole = OnnxWholeSession(
                    self.obj, self.runtime, self.runtime_options)
                self._run = self._run_whole_runtime
            else:
                self.sequence_ = self.graph_['sequence']
                self.inits_ = self.graph_['inits']
                self.statics_ = self.graph_['statics']
                dtype = self._guess_input_dtype()
                variables = self.inits_.copy()
                for node in self.sequence_:
                    domain = node.onnx_node.domain
                    target_opset = self.target_opset_.get(domain, None)
                    keyf = domain, node.onnx_node.op_type
                    if keyf in self.graph_['functions']:
                        node.setup_runtime(self.graph_['functions'][keyf])
                    elif self.runtime in ('onnxruntime2', 'empty'):
                        node.setup_runtime(
                            self.runtime, variables, self.__class__,
                            target_opset=target_opset, dtype=dtype,
                            domain=domain, ir_version=self.ir_version_,
                            runtime_options=self.runtime_options,
                            build_inference_node_function=lambda fct:
                                OnnxInference(
                                    fct, runtime=self.runtime,
                                    skip_run=self.skip_run,
                                    inplace=self.inplace,
                                    runtime_options=self.runtime_options,
                                    inside_loop=self.inside_loop,
                                    static_inputs=self.static_inputs))
                    else:
                        node.setup_runtime(
                            self.runtime, variables, self.__class__,
                            target_opset=target_opset, domain=domain,
                            ir_version=self.ir_version_,
                            runtime_options=self.runtime_options,
                            build_inference_node_function=lambda fct:
                                OnnxInference(
                                    fct, runtime=self.runtime,
                                    skip_run=self.skip_run,
                                    inplace=self.inplace,
                                    runtime_options=self.runtime_options,
                                    inside_loop=self.inside_loop,
                                    static_inputs=self.static_inputs))
                    if hasattr(node, 'ops_') and hasattr(node.ops_, 'typed_outputs_'):
                        for k, v in node.ops_.typed_outputs_:
                            variables[k] = v
                self._run = self._run_sequence_runtime

        if not self.skip_run and self.runtime in ('python', None):
            if is_function_proto:
                self.shapes_ = None
            else:
                self.shapes_ = self._set_shape_inference_runtime()
            if self.inplace:
                self.inplaces_ = self._guess_inplace(self.input_inplace)
        self.exporters_ = OnnxInferenceExport(self)
        self.to_json = self.exporters_.to_json
        self.to_dot = self.exporters_.to_dot
        self.to_python = self.exporters_.to_python
        self.to_text = self.exporters_.to_text
        self.to_onnx_code = self.exporters_.to_onnx_code

        if self.runtime in ('python_compiled', 'python_compiled_debug'):
            # switch the inference method to the compiled one
            _, fct, code = self._build_compile_run('debug' in self.runtime)
            setattr(self, '_run_compiled', fct)
            setattr(self, '_run_compiled_code', code)
            self._run = self._run_sequence_runtime_compiled

    def _run_sequence_runtime_compiled(
            self, inputs, clean_right_away=False, intermediate=False,
            verbose=0, node_time=False, yield_ops=None, fLOG=None):
        """
        Executes a compiled version of @see me _run_sequence_runtime,
        compiled with method @see me _build_compile_run.
        Every parameter with a default value is ignored.
        Switch to ``runtime='python'`` to enable those.
        """
        try:
            return self._run_compiled(  # pylint: disable=E1101
                inputs, yield_ops=yield_ops)
        except NameError as e:
            raise RuntimeError(  # pragma: no cover
                "Unable to compute prediction due to %r. Code:\n%s"
                "" % (e, print_code(
                    self._run_compiled_code))) from e  # pylint: disable=E1101

    def _guess_input_dtype(self):
        for _, v in self.graph_['inputs'].items():
            if 'type' not in v:
                continue  # pragma: no cover
            t = v['type']
            if 'elem' not in t:
                continue
            if t['elem'] == 'double':
                return numpy.float64
        return numpy.float32

    def __str__(self):
        """
        usual
        """
        rows = ['OnnxInference(...)']
        if hasattr(self, '_run_compiled_code'):
            rows.append(
                textwrap.indent(
                    self._run_compiled_code, '    '))  # pylint: disable=E1101
        else:
            rows.append(textwrap.indent(str(self.obj), '    '))
        return "\n".join(rows)

    def __repr__(self):
        """
        usual
        """
        return "OnnxInference(...)"  # pragma: no cover

    def check_model(self):
        """
        Checks that the model follows :epkg:`ONNX` conventions.
        """
        checker.check_model(self.obj)

    def shape_inference(self):
        """
        Infers the shape of the outputs
        with :epkg:`onnx` package.

        @return     A new :epkg:`ONNX` graph with inferred shapes for its outputs.
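
        A short sketch (``oinf`` stands for any *OnnxInference* instance)::

            inferred = oinf.shape_inference()
            print(inferred.graph.value_info)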
345 """
346 return shape_inference.infer_shapes(self.obj)

    @property
    def input_names(self):
        """
        Returns the names of all inputs.
        It does not include the optional inputs.

        .. versionchanged:: 0.6
            The list does not include optional inputs anymore.
        """
        if hasattr(self.obj, 'graph'):
            inits = set(_.name for _ in self.obj.graph.initializer)
            return [_.name for _ in self.obj.graph.input if _.name not in inits]
        return list(self.obj.input)

    @property
    def input_names_shapes(self):
        """
        Returns the names and shapes of all inputs.
        This method assumes all inputs are tensors.
        It does not include the optional inputs.

        .. versionchanged:: 0.6
            The list does not include optional inputs anymore.
        """
        names = set(self.input_names)
        return [(_.name, _var_as_dict(_)['type']['shape'])
                for _ in self.obj.graph.input if _.name in names]

    @staticmethod
    def _get_type_property(info, prop):
        if prop in info:
            return info[prop]
        if 'kind' in info and info['kind'] == 'sequence':
            if prop == 'shape':
                return ('?', )
        raise NotImplementedError(  # pragma: no cover
            "Unable to retrieve property %r from %r."
            "" % (prop, info))

    @property
    def input_names_shapes_types(self):
        """
        Returns the names, shapes, types of all inputs.
        This method assumes all inputs are tensors.
        It does not include the optional inputs.

        .. versionchanged:: 0.6
            The list does not include optional inputs anymore.
        """
        f = OnnxInference._get_type_property
        names = set(self.input_names)
        if isinstance(self.obj, onnx_proto.FunctionProto):
            return [(_.name, f(_var_as_dict(_)['type'], 'shape'),
                     'tensor(%s)' % f(_var_as_dict(_)['type'], 'elem'))
                    for _ in self.obj.input if _.name in names]
        return [(_.name, f(_var_as_dict(_)['type'], 'shape'),
                 'tensor(%s)' % f(_var_as_dict(_)['type'], 'elem'))
                for _ in self.obj.graph.input if _.name in names]

    @property
    def output_names(self):
        """
        Returns the names of all outputs.
        """
        if isinstance(self.obj, onnx_proto.FunctionProto):
            return [_ for _ in self.obj.output]
        return [_.name for _ in self.obj.graph.output]

    @property
    def output_names_shapes(self):
        """
        Returns the names and shapes of all outputs.
        This method assumes all outputs are tensors.
        """
        f = OnnxInference._get_type_property
        if isinstance(self.obj, onnx_proto.FunctionProto):
            return [(_, None) for _ in self.obj.output]
        return [(_.name, f(_var_as_dict(_)['type'], 'shape'))
                for _ in self.obj.graph.output]

    @property
    def output_names_shapes_types(self):
        """
        Returns the names, shapes, types of all outputs.
        This method assumes all outputs are tensors.
        It does not include the optional outputs.

        .. versionadded:: 0.7
        """
        names = set(self.output_names)
        f = OnnxInference._get_type_property
        if isinstance(self.obj, onnx_proto.FunctionProto):
            return [(_, None) for _ in self.obj.output if _ in names]
        return [(_.name, f(_var_as_dict(_)['type'], 'shape'),
                 'tensor(%s)' % f(_var_as_dict(_)['type'], 'elem'))
                for _ in self.obj.graph.output if _.name in names]

    def global_index(self, name):
        """
        Maps every name to one integer to avoid using dictionaries
        when running the predictions.

        @param      name        result name
        @return                 integer
        """
        if not hasattr(self, '_global_index'):
            self._global_index = {}
        if name in self._global_index:
            return self._global_index[name]
        self._global_index[name] = len(self._global_index)
        return self._global_index[name]

    def to_sequence(self, existing_functions=None):
        """
        Produces a graph to facilitate the execution.

        One example:

        .. exref::
            :title: Convert ONNX into graph

            An example on how to convert an :epkg:`ONNX`
            model into an execution graph.

            .. runpython::
                :showcode:
                :warningout: DeprecationWarning

                import pprint
                import numpy
                from mlprodict.npy.xop import loadop
                from mlprodict.onnxrt import OnnxInference

                OnnxAiOnnxMlLinearRegressor = loadop(
                    ('ai.onnx.ml', 'LinearRegressor'))

                pars = dict(coefficients=numpy.array([1., 2.]),
                            intercepts=numpy.array([1.]),
                            post_transform='NONE')
                onx = OnnxAiOnnxMlLinearRegressor(
                    'X', output_names=['Y'], **pars)
                model_def = onx.to_onnx(
                    {'X': pars['coefficients'].astype(numpy.float32)},
                    outputs={'Y': numpy.float32},
                    target_opset=12)
                oinf = OnnxInference(model_def)
                pprint.pprint(oinf.to_sequence())

        See an example of representation in notebook
        :ref:`onnxvisualizationrst`.
        """
        inits = {}
        variables = {}
        outputs = {}
        nodes = {}
        statics = {}
        targets = {}
        functions = {}
        if existing_functions is not None:
            functions.update(existing_functions)
        is_function_proto = isinstance(self.obj, onnx_proto.FunctionProto)

        for o in self.obj.opset_import:
            targets[o.domain] = o.version

        if (hasattr(self.obj, 'functions') and len(self.obj.functions) > 0 and
                (self.runtime is None or not
                 self.runtime.startswith('onnxruntime1'))):
            for fct in self.obj.functions:
                functions[fct.domain, fct.name] = OnnxInference(
                    fct, runtime=self.runtime,
                    skip_run=self.skip_run,
                    inplace=self.inplace,
                    runtime_options=self.runtime_options,
                    inside_loop=self.inside_loop,
                    static_inputs=self.static_inputs,
                    existing_functions=functions)

        # static variables
        if self.static_inputs is not None:
            for n in self.static_inputs:
                statics[n] = {'name': n}
                self.global_index(n)

        obj_graph = (
            self.obj if isinstance(self.obj, onnx_proto.FunctionProto)
            else self.obj.graph)

        # inputs
        for obj in obj_graph.input:
            if is_function_proto:
                variables[obj] = {'name': obj}
                self.global_index(obj)
            else:
                variables[obj.name] = _var_as_dict(obj)
                self.global_index(obj.name)

        # outputs
        for obj in obj_graph.output:
            if is_function_proto:
                outputs[obj] = {'name': obj}
                self.global_index(obj)
            else:
                if hasattr(obj, 'type') and str(obj.type) != '':
                    outputs[obj.name] = _var_as_dict(obj)
                else:
                    outputs[obj.name] = {'name': obj.name}
                self.global_index(obj.name)

        # initializer
        if not is_function_proto:
            for obj in obj_graph.initializer:
                init_obj = _var_as_dict(obj)
                if init_obj is None:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to convert an initializer\n{}".format(obj))
                inits[obj.name] = init_obj
                self.global_index(obj.name)
                if 'value' not in inits[obj.name]:
                    raise RuntimeError(  # pragma: no cover
                        "One initializer has no value: '{}'\n{}\n{}".format(
                            obj.name, inits[obj.name], obj))

        # nodes
        for node in obj_graph.node:
            dobj = _var_as_dict(node)
            if dobj is None:
                raise RuntimeError(  # pragma: no cover
                    "Unable to convert a node\n{}".format(node))
            if 'atts' in dobj:
                atts = dobj['atts']
                for k, v in atts.items():
                    if not isinstance(v, dict) or 'value' not in v:
                        raise RuntimeError(  # pragma: no cover
                            "A parameter has no (sparse) value '{}' "
                            "for node '{}'\nv={}\ndobj=[{}]".format(
                                k, node.name, v, node))
            if node.name in nodes:  # pragma: no cover
                i = 2
                while True:
                    new_name = "%s_n%i" % (node.name, i)
                    if new_name not in nodes:
                        break
                    i += 1
            else:
                new_name = node.name
            nodes[new_name] = OnnxInferenceNode(node, dobj, self.global_index)

        # names
        names = {}
        for k, v in statics.items():
            if (k, 0) in names:
                raise RuntimeError(  # pragma: no cover
                    "Static variable '{}' already exists (tag='{}').".format(
                        k, names[k, 0][0]))
            names[k, 0] = ('S', v)
        for k, v in inits.items():
            if (k, 0) in names:
                raise RuntimeError(  # pragma: no cover
                    "Initializer '{}' already exists (tag='{}').".format(
                        k, names[k, 0][0]))
            names[k, 0] = ('C', v)
        for k, v in variables.items():
            if (k, 0) in names:
                if k in inits:
                    # Kind of default value for an input
                    continue
                raise RuntimeError(  # pragma: no cover
                    "Variable '{}' already exists (tag='{}').".format(
                        k, names[k, 0][0]))
            names[k, 0] = ('I', v)
        for k, v in outputs.items():
            if (k, 0) in names and self.runtime != 'empty':
                if not self.inside_loop or names[k, 0][0] != 'I':
                    raise RuntimeError(  # pragma: no cover
                        "Output '{}' already exists (tag='{}').".format(
                            k, names[k, 0][0]))
                else:
                    # For an input and an output sharing the same name,
                    # the name is marked as an input.
                    continue
            names[k, 0] = ('O', v)
        for k, v in nodes.items():
            if (k, 1) in names:
                raise RuntimeError(  # pragma: no cover
                    "Node '{}' already exists (tag='{}'). "
                    "Use inside_loop=True to bypass this exception.".format(
                        k, names[k, 0][0]))
            names[k, 1] = ('N', v)

        # ordering
        order = {}
        modif = 1
        intermediate = {}
        while modif > 0:
            modif = 0
            for (k, _), v in names.items():
                if (k, 1) in order:
                    # The operator node is already processed.
                    continue
                if v[0] in {'I', 'C', 'S'}:
                    if (k, 0) not in order:
                        order[k, 0] = len(order)  # A data node.
                        modif += 1
                    continue
                if v[0] == 'O':
                    continue
                if all((inp, 0) in order for inp in v[1].inputs if inp != ''):
                    # If all inputs are available,
                    # we mark the operator node as processed.
                    order[k, 1] = len(order)
                    modif += 1
                    for o in v[1].outputs:
                        if (o, 0) in order:
                            raise RuntimeError(  # pragma: no cover
                                "Two nodes share the same output '{}' "
                                "or an operator and an output "
                                "share the same name. "
                                "(node: {}).".format(o, v[1]))
                        # We add a data node.
                        order[o, 0] = len(order)
                        intermediate[o] = None
                        modif += 1

        # compute
        rev = [(v, k[0], k[1]) for k, v in order.items()]
        rev.sort()
        sequence = []
        for _, name, node_kind in rev:
            if name not in nodes:
                continue
            if node_kind == 0:
                # It is an output which shares the same name
                # as a node.
                continue
            node = nodes[name]
            node.set_order(len(sequence))
            sequence.append(node)

        if len(sequence) == 0:
            from mlprodict.plotting.text_plot import onnx_simple_text_plot
            raise RuntimeError(  # pragma: no cover
                "No runnable nodes were found in the ONNX graph"
                "\n--rev--\n{}"
                "\n--order--\n{}"
                "\n--nodes--\n{}"
                "\n--ONNX--\n{}\n---\n".format(
                    "\n".join([str(_) for _ in names.items()]),
                    "\n".join([str(_) for _ in order.items()]),
                    "\n".join([str(_) for _ in nodes.items()]),
                    onnx_simple_text_plot(self.obj, recursive=True)))

        # defines where an intermediate output is no longer needed
        last_used = {}
        for node in sequence:
            for inp in node.inputs:
                last_used[inp] = node.order
        for k, ord in last_used.items():
            sequence[ord].add_variable_to_clean(k)

        results = dict(inits=inits, inputs=variables, outputs=outputs,
                       nodes=nodes, sequence=sequence,
                       functions=functions,
                       intermediate=intermediate,
                       targets=targets,
                       ir_version=(
                           None if is_function_proto
                           else self.obj.ir_version),
                       statics=statics)
        if len(sequence) < len(nodes):
            # Not all nodes will be executed.
            raise RuntimeError(  # pragma: no cover
                "Unable to run all nodes.\n--Nodes--\n%s\n--Sequence--\n%s"
                "\n--Inputs--\n%s\n--Inits--\n%s\n--Statics\n%s"
                "" % (pprint.pformat(nodes), pprint.pformat(sequence),
                      pprint.pformat(list(variables)),
                      pprint.pformat(list(inits)),
                      pprint.pformat(list(statics))))
        return results

    def run(self, inputs, clean_right_away=False,
            intermediate=False, verbose=0, node_time=False,
            overwrite_types=None, yield_ops=None, fLOG=None):
        """
        Computes the predictions for this :epkg:`onnx` graph.

        :param inputs: inputs as dictionary or a dataframe
        :param clean_right_away: clean the intermediate outputs
            as soon as they are not needed
        :param intermediate: returns a dictionary of intermediate
            variables instead of the results only
        :param verbose: display information while predicting
        :param node_time: measure time of each node
        :param overwrite_types: shape inference does not work all the time,
            this allows to force types when building intermediate
            results, see @see fn select_model_inputs_outputs
        :param yield_ops: dictionary to overwrite the output of
            operator *YieldOp*
        :param fLOG: logging function if *verbose > 0*
        :return: outputs as dictionary
            and a second dictionary of the time spent
            in each node if *node_time* is True

        .. exref::
            :title: Computes predictions with any runtime

            The following example compares predictions
            between :epkg:`scikit-learn` and this runtime
            for the python runtime.

            .. runpython::
                :showcode:
                :warningout: DeprecationWarning

                import numpy
                from sklearn.linear_model import LinearRegression
                from sklearn.datasets import load_iris
                from sklearn.model_selection import train_test_split
                from mlprodict.onnxrt import OnnxInference
                from mlprodict.onnx_conv import to_onnx

                iris = load_iris()
                X, y = iris.data, iris.target
                X_train, X_test, y_train, _ = train_test_split(X, y)
                clr = LinearRegression()
                clr.fit(X_train, y_train)

                exp = clr.predict(X_test[:5])
                print(exp)

                model_def = to_onnx(clr, X_train.astype(numpy.float32),
                                    target_opset=12)
                oinf = OnnxInference(model_def)
                y = oinf.run({'X': X_test[:5]})
                print(y)

        The function returns all intermediate outputs
        if *intermediate* is True. With runtime
        *onnxruntime1*, if *intermediate* is True,
        the class builds one :epkg:`ONNX` graph cut at every
        intermediate output and converts each of them into an
        *OnnxInference* (see @see me build_intermediate).

        .. versionchanged:: 0.8
            Parameter *yield_ops* was added.
        """
        def retype(col_array):
            if (hasattr(col_array, 'categories') and
                    hasattr(col_array, 'from_codes')):
                # isinstance(col_array, pandas.Categorical):
                return col_array.astype(numpy.int64)
            return col_array

        if hasattr(inputs, 'columns') and hasattr(inputs, 'iloc'):
            # == isinstance(inputs, pandas.DataFrame)
            inputs = OrderedDict((
                name, retype(numpy.expand_dims(inputs[name].values, axis=1)))
                for name in inputs.columns)
        if intermediate:
            if self.inplace:
                raise RuntimeError(  # pragma: no cover
                    "inplace must be False if intermediate is True, a container "
                    "might be used by several nodes.")
            return self._run(inputs, clean_right_away=False,
                             intermediate=intermediate,
                             verbose=verbose, node_time=node_time,
                             overwrite_types=overwrite_types,
                             yield_ops=yield_ops, fLOG=fLOG)
        if overwrite_types is not None:
            raise RuntimeError(  # pragma: no cover
                "overwrite_types is not used if intermediate is False.")
        return self._run(inputs, clean_right_away=False,
                         intermediate=intermediate,
                         verbose=verbose, node_time=node_time,
                         yield_ops=yield_ops, fLOG=fLOG)

    def run2onnx(self, inputs, verbose=0, fLOG=None,
                 as_parameter=True, suffix='_DBG',
                 param_name=None, node_type='DEBUG',
                 domain='DEBUG', domain_opset=1):
        """
        Executes the graph with the given inputs, then adds the intermediate
        results into ONNX nodes in the original graph. Once saved, it can be
        looked at with a tool such as :epkg:`netron`.

        :param inputs: inputs as dictionary or a dataframe
        :param verbose: display information while predicting
        :param fLOG: logging function if *verbose > 0*
        :param as_parameter: add new nodes with results as one parameter
            (True) or as initializer (False)
        :param suffix: suffix to add to new results
        :param param_name: name of the parameter to add
            (by default the result name), it can be a function
            `param_name(result_name) -> parameter_name`
        :param node_type: type of the new node
        :param domain: domain of the new node
        :param domain_opset: opset for *domain*
        :return: outputs as dictionary
            and the onnx graph with new nodes

        The following example shows how to use it.

        .. gdot::
            :script: DOT-SECTION

            from sklearn.linear_model import LinearRegression
            from sklearn.datasets import load_iris
            from mlprodict.onnxrt import OnnxInference
            import numpy

            iris = load_iris()
            X = iris.data[:, :2]
            y = iris.target
            lr = LinearRegression()
            lr.fit(X, y)

            from mlprodict.onnx_conv import to_onnx
            model_onnx = to_onnx(lr, X.astype(numpy.float32))
            oinf = OnnxInference(model_onnx, inplace=False)

            model_onnx_debug = oinf.run2onnx({'X': X[:3].astype(numpy.float32)})
            oinf_debug = OnnxInference(model_onnx_debug[1])

            print("DOT-SECTION", oinf_debug.to_dot())

        .. versionadded:: 0.7
        """
        intermediate = self.run(inputs, verbose=verbose, fLOG=fLOG,
                                intermediate=True)
        for name in self.input_names:
            del intermediate[name]
        new_onx = insert_results_into_onnx(
            self.obj, intermediate, as_parameter=as_parameter,
            suffix=suffix, param_name=param_name, node_type=node_type,
            domain=domain, domain_opset=domain_opset)
        return intermediate, new_onx

    def display_sequence(self, verbose=1):
        """
        Shows the sequence of nodes to run if ``runtime=='python'``.
        """
        rows = []
        rows.append("#node: {}".format(len(self.sequence_)))
        for i, node in enumerate(self.sequence_):
            if verbose >= 1:
                rows.append("{}: {}".format(i, str(node)))
        return "\n".join(rows)

    def _run_sequence_runtime(self, inputs, clean_right_away=False,
                              intermediate=False, verbose=0, node_time=False,
                              overwrite_types=None, yield_ops=None,
                              fLOG=None):
        if overwrite_types is not None:
            raise NotImplementedError(  # pragma: no cover
                "overwrite_types != None not implemented.")
        if clean_right_away:
            raise NotImplementedError(  # pragma: no cover
                "clean_right_away=true not implemented.")

        if node_time:
            mtime = []
        if verbose >= 1 and fLOG is not None:
            printed = set()

        if hasattr(self, "_values_init"):
            values = self._values_init.copy()  # pylint: disable=E0203
        else:
            values = [None] * len(self._global_index)
            if verbose >= 1 and fLOG is not None:
                for k, v in self.inits_.items():
                    values[self._global_index[k]] = v['value']
                    if verbose < 3:
                        fLOG("+ki='{}': {} (dtype={} min={} max={})".format(
                            k, v['value'].shape, v['value'].dtype,
                            numpy_min(v['value']), numpy_max(v['value'])))
                    else:
                        fLOG("+ki='{}': {} (dtype={} min={} max={}\n{}".format(
                            k, v['value'].shape, v['value'].dtype,
                            numpy_min(v['value']), numpy_max(v['value']),
                            v['value']))
                    printed.add(k)
            else:
                for k, v in self.inits_.items():
                    values[self._global_index[k]] = v['value']
            # stores the array to skip initializing it a second time
            if verbose == 0 or fLOG is None:
                self._values_init = values.copy()

        for name, value in inputs.items():
            values[self._global_index[name]] = value

        if verbose == 0 or fLOG is None:
            if node_time:
                for i, node in enumerate(self.sequence_):
                    if yield_ops is not None and node.onnx_node.op_type == 'YieldOp':
                        out = node.onnx_node.output[0]
                        if out in yield_ops:
                            values[node.outputs_indices[0]] = yield_ops[out]
                            continue
                        raise RuntimeError(  # pragma: no cover
                            "YieldOp output %r could not be found in "
                            "yield_ops: %r (node=%r)." % (
                                out, list(sorted(yield_ops)), node.onnx_node))
                    t = perf_counter()
                    node.run(values)
                    t2 = perf_counter()
                    mtime.append(dict(i=i, name=node.onnx_node.name,
                                      op_type=node.onnx_node.op_type,
                                      time=t2 - t))
            else:
                for node in self.sequence_:
                    node.run(values)
        else:
            def dispsimple(arr):
                if hasattr(arr, 'shape'):
                    if len(arr.shape) <= 1:
                        threshold = 8
                    else:
                        threshold = min(
                            50, min(50 // max(arr.shape[1], 1), 8) * arr.shape[1])
                    if hasattr(arr, 'todense'):
                        fLOG(  # pragma: no cover
                            numpy.array2string(arr.todense(), max_line_width=120,
                                               suppress_small=True, threshold=threshold))
                    else:
                        fLOG(numpy.array2string(arr, max_line_width=120,
                                                suppress_small=True,
                                                threshold=threshold))
                else:  # pragma: no cover
                    s = str(arr)
                    if len(s) > 50:
                        s = s[:50] + "..."
                    fLOG(s)

            if verbose >= 2:
                for k in sorted(self._global_index):
                    if values[self._global_index[k]] is None:
                        continue
                    obj = values[self._global_index[k]]
                    if k not in printed:
                        printed.add(k)
                        if hasattr(obj, 'shape'):
                            fLOG("-kv='{}' shape={} dtype={} min={} max={}{}".format(
                                k, obj.shape, obj.dtype, numpy_min(obj),
                                numpy_max(obj),
                                ' (sparse)' if isinstance(obj, coo_matrix) else ''))
                        elif (isinstance(obj, list) and len(obj) > 0 and
                                not isinstance(obj[0], dict)):  # pragma: no cover
                            fLOG("-kv='{}' list len={}".format(k, len(obj)))
                            if verbose >= 3 and len(obj) > 0:
                                fLOG("first={} last={}".format(
                                    obj[0], obj[-1]))
                        else:  # pragma: no cover
                            fLOG("-kv='{}' type={}".format(k, type(obj)))

            keys = set(k for k in range(len(values)) if values[k] is not None)
            if verbose >= 1:
                fLOG("-- OnnxInference: run {} nodes".format(len(self.sequence_)))
            for i, node in enumerate(self.sequence_):
                if verbose >= 1:
                    fLOG(node)
                if yield_ops is not None and node.onnx_node.op_type == 'YieldOp':
                    out = node.onnx_node.output[0]
                    if out in yield_ops:
                        fLOG("+yo=%r" % out)
                        values[node.outputs_indices[0]] = yield_ops[out]
                    else:
                        raise RuntimeError(  # pragma: no cover
                            "YieldOp output %r could not be found in "
                            "yield_ops: %r (node=%r)." % (
                                out, list(sorted(yield_ops)), node.onnx_node))
                elif node_time:
                    t = perf_counter()
                    node.run(values)
                    t2 = perf_counter()
                    mtime.append(dict(i=i, name=node.onnx_node.name,
                                      op_type=node.onnx_node.op_type,
                                      time=t2 - t))
                else:
                    node.run(values)
                added = 0
                for k in range(len(values)):  # pylint: disable=C0200
                    if values[k] is None:
                        continue
                    if k not in keys and k not in printed:
                        added += 1
                        printed.add(k)
                        name = list(
                            name for name in self._global_index  # pylint: disable=C0206
                            if self._global_index[name] == k)
                        if isinstance(values[k], (numpy.ndarray, coo_matrix)):
                            name = name[0]
                            mini = numpy_min(values[k])
                            maxi = numpy_max(values[k])
                            fLOG("+kr{}'{}': {} (dtype={} min={} max={}{})".format(
                                "=" if len(values[k].shape) == 0 or min(
                                    values[k].shape) > 0 else "*",
                                name, values[k].shape, values[k].dtype,
                                mini, maxi,
                                ' sparse' if isinstance(values[k], coo_matrix) else ''))
                            if verbose >= 3:
                                dispsimple(values[k])
                        else:
                            fLOG("+kr='{}': {}".format(
                                name, type(values[k])))
                            if verbose >= 3:  # pragma: no cover
                                dispsimple(values[k])
                if added == 0:
                    fLOG("? no new result")  # pragma: no cover

        if intermediate:
            values = [(v, k, values[v]) for k, v in self._global_index.items()]
            values.sort()
            values = OrderedDict((k, v) for _, k, v in values)
            return (values, mtime) if node_time else values

        try:
            res = {k: values[self._global_index[k]] for k in self.outputs_}
        except KeyError as e:  # pragma: no cover
            raise RuntimeError("Unable to find one output [{}]\n in [{}]"
                               ".".format(", ".join(sorted(self.outputs_)),
                                          ", ".join(sorted(values)))) from e
        return (res, mtime) if node_time else res

    def build_intermediate(self, outputs=None, verbose=0, overwrite_types=None,
                           fLOG=None):
        """
        Builds every possible :epkg:`ONNX` file
        which computes one specific intermediate output
        from the inputs.

        :param outputs: subset of outputs to get,
            None to get all outputs
        :param overwrite_types: shape inference does not work all the time,
            this allows to force types when building intermediate
            results, see @see fn select_model_inputs_outputs
        :param verbose: displays intermediate information
        :param fLOG: logging function
        :return: :epkg:`*py:collections:OrderedDict`
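
        A short sketch (``oinf`` stands for any *OnnxInference* instance,
        an assumption for the example)::

            for name, sub_oinf in oinf.build_intermediate().items():
                print(name, sub_oinf.output_names)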

        .. versionchanged:: 0.6
        """
        if verbose > 0:
            fLOG('[build_intermediate] BEGIN.')  # pragma: no cover
        if outputs is not None:
            if isinstance(outputs, str):
                outputs = [outputs]
            if not isinstance(outputs, set):
                outputs = set(outputs)
        ord = OrderedDict()
        for output in enumerate_model_node_outputs(self.obj, order=True):
            if outputs is not None and output not in outputs:
                continue
            subonx = select_model_inputs_outputs(
                self.obj, outputs=output, infer_shapes=True,
                overwrite=overwrite_types)
            subonx = onnx_remove_node_unused(subonx)
            if verbose > 0:
                fLOG(  # pragma: no cover
                    '[build_intermediate] + {}'.format(output))
            ord[output] = OnnxInference(subonx, runtime=self.runtime,
                                        skip_run=self.skip_run,
                                        runtime_options=self.runtime_options,
                                        inplace=self.inplace,
                                        input_inplace=self.input_inplace)
        if verbose > 0:
            fLOG(  # pragma: no cover
                '[build_intermediate] END.')
        return ord

    def _run_whole_runtime(self, inputs, clean_right_away=False,
                           intermediate=False, verbose=0, node_time=False,
                           overwrite_types=None, yield_ops=None, fLOG=None):
        # node_time is unused
        if clean_right_away:
            raise RuntimeError(  # pragma: no cover
                "clean_right_away=true does not work with this runtime.")
        if intermediate:
            if hasattr(self, "intermediate_onnx_inference_"):
                inter_run = self.intermediate_onnx_inference_  # pylint: disable=E0203
            else:
                if verbose > 0:
                    fLOG(  # pragma: no cover
                        "-- OnnxInference: build intermediate")
                inter_run = self.build_intermediate(
                    verbose=verbose, fLOG=fLOG, overwrite_types=overwrite_types)
                self.intermediate_onnx_inference_ = inter_run
                graph = self.to_sequence()
                self.inits_ = graph['inits']

            if verbose >= 1:
                fLOG(  # pragma: no cover
                    "-- OnnxInference: run {} nodes".format(
                        len(self.intermediate_onnx_inference_)))
            values = OrderedDict(inputs)
            for k, v in self.inits_.items():
                values[k] = v['value']
            if verbose >= 2:  # pragma: no cover
                for k in sorted(values):
                    fLOG("-k='{}' shape={} dtype={}".format(
                        k, values[k].shape, values[k].dtype))
            for node, oinf in self.intermediate_onnx_inference_.items():
                if verbose >= 4:  # pragma: no cover
                    fLOG('[intermediate] %r' % node)
                    if verbose >= 5:  # pragma: no cover
                        fLOG(oinf.obj)
                if yield_ops is not None and node.onnx_node.op_type == 'YieldOp':
                    out = node.onnx_node.output[0]
                    if out in yield_ops:
                        values[out] = yield_ops[out]
                        continue
                    raise RuntimeError(  # pragma: no cover
                        "YieldOp output %r could not be found in "
                        "yield_ops: %r (node=%r)." % (
                            out, list(sorted(yield_ops)), node.onnx_node))
                output = oinf.run(inputs)[node]
                values[node] = output
                if verbose >= 1:
                    if verbose >= 4:  # pragma: no cover
                        for k, v in inputs.items():
                            if isinstance(output, numpy.ndarray):
                                fLOG("-i='{}': {} (dtype={}) {}".format(
                                    k, v.shape, v.dtype, v.ravel().tolist()))
                            else:
                                fLOG("-i='{}': {} (dtype={}) - ?".format(
                                    k, v.shape, v.dtype))
                    if isinstance(output, numpy.ndarray):
                        fLOG("+k='{}': {} (dtype={})".format(  # pragma: no cover
                            node, output.shape, output.dtype))
                        if verbose >= 2:  # pragma: no cover
                            fLOG(output)
                    else:
                        fLOG("+k='{}': {}".format(  # pragma: no cover
                            node, type(output)))
                        if verbose >= 2:  # pragma: no cover
                            fLOG(output)
            return values

        if verbose != 0:
            warnings.warn(
                "verbose option not implemented if runtime is 'onnxruntime1'")
        res = self._whole.run(inputs)
        return {k: v for k, v in zip(self.outputs_, res)}

    def __getitem__(self, item):
        """
        Returns an ONNX node or one of its attributes.
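
        A short sketch (the node and attribute names are assumptions
        taken for the example)::

            node = oinf['LinearRegressor']
            att = oinf['LinearRegressor', 'coefficients']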
1195 """
1196 if isinstance(item, tuple):
1197 node_name, att_name = item
1198 else:
1199 node_name = item
1200 att_name = None
1202 node_ = None
1203 for node in self.obj.graph.node:
1204 if node.name == node_name:
1205 node_ = node
1206 break
1208 if node_ is None:
1209 raise IndexError( # pragma: no cover
1210 "Unable to get node name '{}'.\n{}".format(
1211 node_name, "\n".join(node.name for node in self.obj.graph.node)))
1213 if att_name is None:
1214 return node_
1216 for att in node_.attribute:
1217 if att.name == att_name:
1218 return att
1220 raise IndexError( # pragma: no cover
1221 "Unable to find attribute '{}' from node "
1222 "'{}'.".format(att_name, node_name))

    def switch_initializers_dtype(self, model=None,
                                  dtype_in=numpy.float32,
                                  dtype_out=numpy.float64):
        """
        Switches all initializers to ``numpy.float64``. If *model*
        is None, a simple cast is done. Otherwise, the function assumes
        the model is a :epkg:`scikit-learn` pipeline.
        This only works if the runtime is ``'python'``.

        @param      model       :epkg:`scikit-learn` model or None
        @param      dtype_in    previous type
        @param      dtype_out   next type
        @return                 done operations
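
        A minimal sketch (``model_def`` stands for any float32 ONNX model,
        an assumption for the example)::

            oinf = OnnxInference(model_def, runtime='python')
            done = oinf.switch_initializers_dtype()
            # the python runtime now computes with float64 initializers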
1237 """
1238 from ..onnx_tools.optim.sklearn_helper import enumerate_fitted_arrays, pairwise_array_distances
1240 if self.runtime != 'python': # pragma: no cover
1241 raise RuntimeError("Initializers can be casted only if the "
1242 "runtime is 'python' not '{}'.".format(self.runtime))
1244 if hasattr(self, '_values_init'):
1245 del self._values_init
1247 # first pass: simple cast
1248 done = []
1249 initializer = self.inits_
1250 for k, v in initializer.items():
1251 if isinstance(v['value'], numpy.ndarray):
1252 if v['value'].dtype == dtype_in:
1253 v['value'] = v['value'].astype(dtype_out)
1254 done.append(("pass1", "+", "init", k, v['value']))
1255 else:
1256 done.append(("pass1", "-", "init", k,
1257 v['value'])) # pragma: no cover
1258 for k, v in self.graph_['nodes'].items():
1259 res = v.switch_initializers_dtype(dtype_in=dtype_in,
1260 dtype_out=dtype_out)
1261 for r in res:
1262 done.append(("pass1", "node", k) + r)
1263 for k, v in self.graph_['intermediate'].items():
1264 if v is None:
1265 continue
1266 res = v.switch_initializers_dtype(dtype_in=dtype_in,
1267 dtype_out=dtype_out)
1268 for r in res:
1269 done.append(("pass1", "sub", k) + r)
1271 if model is not None:
1272 # Second pass, we compare all arrays from the model
1273 # to the arrays in the converted models.
1274 def dist(a):
1275 cast = a.astype(dtype_in).astype(dtype_out)
1276 d = pairwise_array_distances([cast], [a])[0, 0]
1277 return d
1279 done_ = [(c, c[-1]) for c in done]
1280 moda_ = [(a, a[-2][-1]) for a in enumerate_fitted_arrays(model)
1281 if dist(a[-2][-1]) > 0]
1282 aconv = [_[-1] for _ in done_]
1283 amoda = [_[-1] for _ in moda_]
1284 distances = pairwise_array_distances(aconv, amoda)
1286 for i in range(distances.shape[0]):
1287 j = numpy.argmin(distances[i])
1288 d = distances[i, j]
1289 if d < 0.1:
1290 numpy.copyto(aconv[i], amoda[j])
1291 done.append(("pass2", d) + done_[i][0])
1293 return done

    def _set_shape_inference_runtime(self):
        """
        Set shapes based on shape inference
        relying on the runtime.
        The values are stored in every node.
        """
        if not hasattr(self, 'sequence_') or not hasattr(self, 'inputs_'):
            raise RuntimeError(  # pragma: no cover
                "This method only works if the runtime is 'python' not "
                "'{}'.".format(self.runtime))
        values = OrderedDict()
        impossible = False
        for k, v in self.inputs_.items():
            # The function assumes the first dimension is unknown
            # and is the batch size.
            try:
                values[k] = ShapeObject(v, use_n1=True, name=k)
            except TypeError as e:  # pragma: no cover
                if v['type']['elem'] == 'unk':
                    impossible = True
                    values[k] = None
                    continue
                raise TypeError(
                    "Unable to guess shape for %r (shape=%r)." % (
                        k, v)) from e

        for k, v in self.statics_.items():
            # static inputs should be known.
            if k not in values:
                try:
                    values[k] = ShapeObject(v)
                except TypeError:
                    # default value is wrong
                    impossible = True
                    values[k] = None

        for k, v in self.inits_.items():
            values[k] = ShapeObject(v['value'], name=k)
        last = None
        for i, node in enumerate(self.sequence_):
            try:
                s = node._set_shape_inference_runtime(values)
                last = s
            except (IndexError, TypeError, KeyError,
                    AttributeError) as e:  # pragma: no cover
                rows = []
                if last is not None:
                    for k, v in last.items():
                        rows.append("{}: {}".format(k, v))
                for k in range(i + 1):
                    rows.append("{} --> {}".format(k, self.sequence_[k]))
                if not impossible:
                    raise RuntimeError("Unable to infer shape of node {}\n{}".format(
                        i, '\n'.join(rows))) from e
        return values

    def infer_shapes(self):
        """
        Computes expected shapes.

        :return: dictionary of shapes
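
        A short sketch (``oinf`` stands for any *OnnxInference* instance
        with ``runtime='python'``)::

            for name, shape in oinf.infer_shapes().items():
                print(name, shape)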
1356 """
1357 return self._set_shape_inference_runtime()

    def _set_type_inference_runtime(self, inputs=None):
        """
        Set types based on type inference
        relying on the runtime.
        The values are stored in every node.
        """
        if not hasattr(self, 'sequence_') or not hasattr(self, 'inputs_'):
            raise RuntimeError(  # pragma: no cover
                "This method only works if the runtime is 'python' not "
                "'{}'.".format(self.runtime))

        values = OrderedDict()
        for k, v in self.statics_.items():
            values[k] = None

        if inputs is None:
            for k, v in self.inputs_.items():
                # The function assumes the first dimension is unknown
                # and is the batch size.
                if isinstance(v['type']['elem'], dict):
                    # sequence
                    values[k] = SequenceType()
                else:
                    values[k] = guess_numpy_type_from_string(v['type']['elem'])
        else:
            for name, dtype in zip(self.input_names, inputs):
                values[name] = dtype

        for k, v in self.inits_.items():
            values[k] = v['value'].dtype

        last = None
        for i, node in enumerate(self.sequence_):
            try:
                s = node._set_type_inference_runtime(values)
                last = s
            except IndexError as e:  # pragma: no cover
                rows = []
                if last is not None:
                    for k, v in last.items():
                        rows.append("{}: {}".format(k, v))
                for k in range(i + 1):
                    rows.append("{} --> {}".format(k, self.sequence_[k]))
                raise RuntimeError("Unable to infer type of node {}\n{}".format(
                    i, '\n'.join(rows))) from e
        return values

    def infer_types(self, inputs=None):
        """
        Computes expected types.

        :param inputs: needed when this class hosts a function and not a graph
        :return: dictionary of types
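
        A short sketch (``oinf`` stands for any *OnnxInference* instance
        with ``runtime='python'``)::

            for name, dtype in oinf.infer_types().items():
                print(name, dtype)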
1412 """
1413 return self._set_type_inference_runtime(inputs)

    def _set_size_inference_runtime(self, inputs, context=None):
        """
        Set sizes allocated during inference
        relying on the runtime.
        The values are stored in every node.
        """
        if not hasattr(self, 'sequence_') or not hasattr(self, 'inputs_'):
            raise RuntimeError(  # pragma: no cover
                "This method only works if the runtime is 'python' not "
                "'{}'.".format(self.runtime))
        values = OrderedDict()
        for k, v in self.statics_.items():
            if context is None:
                raise RuntimeError(  # pragma: no cover
                    "static variable but context is None.")
            values[k] = context[k]
        for k, v in self.inits_.items():
            values[k] = v['value']
        for k, v in self.inputs_.items():
            if k in inputs:
                values[k] = inputs[k]

        last = None
        for i, node in enumerate(self.sequence_):
            try:
                s = node._set_size_inference_runtime(values)
                last = s
            except IndexError as e:  # pragma: no cover
                rows = []
                if last is not None:
                    for k, v in last.items():
                        rows.append("{}: {}".format(k, v))
                for k in range(i + 1):
                    rows.append("{} --> {}".format(k, self.sequence_[k]))
                raise RuntimeError("Unable to infer size of node {}\n{}".format(
                    i, '\n'.join(rows))) from e
        return values

    def infer_sizes(self, inputs, context=None):
        """
        Computes expected sizes.

        :param inputs: inputs as a dictionary
        :param context: values for the static variables if the graph uses any
        :return: dictionary of dictionaries of sizes
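
        A short sketch (``oinf`` stands for any *OnnxInference* instance
        with ``runtime='python'``, ``x`` for a matching numpy input)::

            sizes = oinf.infer_sizes({'X': x})
            print(sizes)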
1459 """
1460 res = self._set_size_inference_runtime(inputs, context=context)
1461 return {k: v for k, v in res.items() if k.startswith('#')}

    def _guess_inplace(self, input_inplace=False):
        """
        Looks into every node of the graph to see
        if there is a way to do the computation
        inplace. By default (*input_inplace=False*),
        the function assumes inputs cannot be modified
        so the first node cannot do inplace computation.
        This function only works with the python runtime.

        @param      input_inplace       the computation is allowed
                                        to overwrite the input

        This function checks that a result is used only
        once so it can be modified by the next node.
        Nodes `A`, `C` can be overwritten by the computation.
        Node `B` cannot as it is used by two nodes.

        .. blockdiag::

            diagram {
                A -> B -> C -> E;
                B -> D;
            }

        It does not handle the specific case of node `B` being
        overwritten by node `C` without changing its shape
        while node `D` only needs the shape of `B`. Then `B` could
        be overwritten as well.
        """
        forbid = {}
        values = OrderedDict()
        for k in self.statics_:
            values[k] = dict(inplace=False, to=[], fr=[])
        for k in self.inputs_:
            values[k] = dict(inplace=input_inplace, to=[], fr=[])
        for k in self.inits_:
            values[k] = dict(inplace=False, to=[], fr=[])
        for node in self.sequence_:
            for n in node.inputs:
                if n == '':
                    continue
                values[n]['to'].append(node)
            for n in node.outputs:
                if node.op_type == 'Constant':
                    # We cannot modify constant.
                    forbid[n] = node
                if n not in values:
                    values[n] = dict(inplace=None, to=[], fr=[])
                values[n]['fr'].append(node)

        # checks the number of outputs
        outputs = set(self.output_names)
        modif = 1
        while modif > 0:
            modif = 0
            for n, v in values.items():
                if v['inplace'] is not None:
                    continue
                if n in forbid:
                    continue
                if len(v['to']) == 1:
                    v['inplace'] = True
                    modif += 1

        # convey the information to every node
        inplaces = {}
        for n, v in values.items():
            if v['inplace']:
                inplaces[n] = v
                for node in v['to']:
                    if n in outputs:
                        continue
                    node.enable_inplace_compute(n)

        return inplaces

    def _build_compile_run(self, debug=False):
        """
        Rewrites the run function in python,
        compiles it, and adds it as a method.

        @param      debug       insert debugging code
        @return                 method name, callable object

        .. exref::
            :title: Run a model with runtime 'python_compiled'

            The following code trains a model and computes
            the predictions with runtime ``'python_compiled'``.
            It converts the onnx graph into a python function
            which calls every operator. Its code is printed
            below.

            .. runpython::
                :showcode:
                :warningout: DeprecationWarning

                import numpy
                from sklearn.datasets import load_iris
                from sklearn.model_selection import train_test_split
                from sklearn.ensemble import AdaBoostClassifier
                from sklearn.tree import DecisionTreeClassifier
                from mlprodict.onnx_conv import to_onnx
                from mlprodict.onnxrt import OnnxInference

                iris = load_iris()
                X, y = iris.data, iris.target
                X_train, X_test, y_train, __ = train_test_split(X, y, random_state=11)
                y_train = y_train.astype(numpy.float32)
                clr = AdaBoostClassifier(
                    base_estimator=DecisionTreeClassifier(max_depth=3),
                    n_estimators=3)
                clr.fit(X_train, y_train)

                model_def = to_onnx(clr, X_train.astype(numpy.float32),
                                    target_opset=12)

                oinf2 = OnnxInference(model_def, runtime='python_compiled')
                print(oinf2.run({'X': X_test[:5]}))

                # prints out the python function equivalent
                # to the onnx graph
                print(oinf2)
        """

        def clean_name(name):
            res = name.replace(":", "_").replace('.', '_').replace('/', '_')
            if iskeyword(res):
                res += '_'
            return res

        # inits
        inputs = self.input_names
        code = ['def compiled_run(dict_inputs, yield_ops=None):']
        code.append("    if yield_ops is not None:")
        code.append(
            "        raise NotImplementedError('yield_ops should be None.')")
        if debug:
            code.append("    printed = {}")

        context = {}

        # static variables
        for k in sorted(self.statics_):
            code.append("    # static: {0}".format(k))
            code.append("    {0} = dict_inputs['{1}']".format(
                clean_name(k), k))
            if debug:
                code.append(
                    "    debug_print('i.{0}', {1}, printed)".format(
                        clean_name(k), k))

        # initializers
        for k, v in sorted(self.inits_.items()):
            if k.startswith("_OPT_"):
                raise RuntimeError(  # pragma: no cover
                    "The runtime cannot handle any constant name "
                    "starting with '_OPT_': '{}'.".format(k))
            if k in inputs:
                context["_OPT_" + clean_name(k)] = v['value']
                code.append("    # init: _OPT_{0} ({1})".format(
                    clean_name(k), k))
                if debug:
                    code.append(
                        "    debug_print('c.[_OPT_{0}]', _OPT_{1}, printed)".format(
                            clean_name(k), k))
            else:
                context[clean_name(k)] = v['value']
                code.append("    # init: {0} ({1})".format(
                    clean_name(k), k))
                if debug:
                    code.append(
                        "    debug_print('c.[{0}]', {1}, printed)".format(
                            clean_name(k), k))

        # method signature
        code.append("    # inputs")
        for inp in inputs:
            if '_OPT_' + inp in context:
                # optional inputs
                code.append(
                    "    {0} = dict_inputs.get('{1}', _OPT_{0})".format(
                        clean_name(inp), inp))
            else:
                code.append("    {0} = dict_inputs['{1}']".format(
                    clean_name(inp), inp))
            if debug:
                code.append(
                    "    debug_print('i.{0}', {1}, printed)".format(
                        clean_name(inp), inp))

        # code
        for i, node in enumerate(self.sequence_):
            name = "n{}_{}".format(i, node.ops_.__class__.__name__.lower())
            if node.ops_ is None:
                context[name] = node.function_
                # The code of the function should be added but only once.
                raise NotImplementedError(
                    "Not implemented for models including functions.")
            else:
                context[name] = node.ops_._run
                if (node.ops_.__class__.__name__ == 'Loop' and
                        node.ops_.need_context()):
                    # Adding context.
                    ctx = "{%s}" % ", ".join(
                        "'%s': %s" % (n, n) for n in node.ops_.additional_inputs)
                    code.append('    ({1}, ) = {2}({0}, context={3})'.format(
                        ', '.join(map(clean_name, node.inputs)),
                        ', '.join(map(clean_name, node.outputs)),
                        name, ctx))
                else:
                    code.append('    ({1}, ) = {2}({0})'.format(
                        ', '.join(map(clean_name, node.inputs)),
                        ', '.join(map(clean_name, node.outputs)),
                        name))
            if debug:
                code.append("    print('''# {}''')".format(code[-1][4:]))
                for o in node.outputs:
                    code.append(
                        "    debug_print('o.{0}', {1}, printed)".format(
                            clean_name(o), o))

        # return
        code.append('    return {')
        for out in self.output_names:
            code.append("        '{1}': {0},".format(
                clean_name(out), out))
        code.append('    }')
        final_code = '\n'.join(code)

        # compile the outcome
        context['self'] = self
        try:
            obj = compile(final_code, "<string>", 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError(
                "Unable to compile\n#####\n{}".format(final_code)) from e
        fcts_obj = [_ for _ in obj.co_consts
                    if _ is not None and not isinstance(_, (bool, str, int))]
        fct = make_callable(
            "compiled_run", fcts_obj[0], final_code, context, debug)

        # end
        return "compiled_run", fct, final_code

    def reduce_size(self, pickable=False):
        """
        Reduces the memory footprint as much as possible.

        @param      pickable        keeps the ONNX object so the
                                    instance can still be pickled
        """
        import gc
        del self.graph_
        if not pickable:
            del self.obj
        if self.runtime in ('python_compiled', 'python_compiled_debug'):
            del self.sequence_
        gc.collect()

    def get_profiling(self, as_df=False):
        """
        Returns the profiling after a couple of executions.

        :param as_df: return the results as a dataframe (True)
        :return: dataframe or list of dictionaries
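
        A sketch of how profiling is enabled (``model_def`` and ``x``
        stand for any ONNX model and a matching input, assumptions
        taken for the example)::

            oinf = OnnxInference(
                model_def, runtime='onnxruntime1',
                runtime_options={'enable_profiling': True})
            oinf.run({'X': x})
            df = oinf.get_profiling(as_df=True)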

        .. versionadded:: 0.6
        """
        if (self.runtime_options is None or
                not self.runtime_options.get('enable_profiling', False)):
            raise RuntimeError(
                "Profiling is available if option 'enable_profiling' "
                "is set to true in 'runtime_options' but is %r." % self.runtime_options)
        prof = None
        if hasattr(self, '_whole'):
            prof = self._whole.get_profiling()
        if prof is None:
            raise NotImplementedError(  # pragma: no cover
                "profiling is only implemented for runtime 'onnxruntime1'.")
        if as_df:
            import pandas
            return pandas.DataFrame(prof)
        return prof

    def get_execution_order(self):
        """
        This function returns a dictionary `{(kind, name): (order, op)}`,
        *name* can be a node name or a result name. In the latter case,
        it gets the execution order of the node which created it.
        The function returns None if the order is not available
        (the selected runtime does not return it). *kind* is either
        `'node'` or `'res'`. If two nodes have the same name,
        the returned order is the last one. Initializers get an execution
        order equal to -1, inputs to 0, all other results are >= 1.
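
        A short sketch (``oinf`` stands for any *OnnxInference* instance
        with ``runtime='python'``)::

            for (kind, name), (order, _) in oinf.get_execution_order().items():
                print(kind, name, order)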

        .. versionadded:: 0.7
        """
        if not hasattr(self, "sequence_"):
            return None

        res = {}
        for k, v in self.inits_.items():
            res['res', k] = (-1, v)
        for name, shape in self.input_names_shapes:
            res['res', name] = (0, shape)

        for i, node in enumerate(self.sequence_):
            key = ('node', node.onnx_node.name)
            res[key] = (i + 1, node)
            for out in node.onnx_node.output:
                key = ('res', out)
                if key in res:
                    raise RuntimeError(  # pragma: no cover
                        "Output %r of node name %r already registered."
                        "" % (out, node.onnx_node.name))
                res[key] = (i + 1, None)

        return res