Coverage for mlprodict/onnxrt/onnx_inference_exports.py: 99%
1"""
2@file
3@brief Extensions to class @see cl OnnxInference.
4"""
5import os
6import json
7import re
8from io import BytesIO
9import pickle
10import textwrap
11from onnx import numpy_helper
12from ..onnx_tools.onnx2py_helper import _var_as_dict, _type_to_string
13from ..tools.graphs import onnx2bigraph
14from ..plotting.text_plot import onnx_simple_text_plot
17class OnnxInferenceExport:
18 """
19 Implements methods to export a instance of
20 @see cl OnnxInference into :epkg:`json`, :epkg:`dot`,
21 *text*, *python*.
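
    A minimal usage sketch (assuming ``model_def`` is an already loaded
    ONNX model; the exporters are usually reached through the wrapping
    @see cl OnnxInference instance)::

        from mlprodict.onnxrt import OnnxInference

        oinf = OnnxInference(model_def)
        dot_text = oinf.to_dot()    # DOT graph
        js_text = oinf.to_json()    # JSON dump
        txt = oinf.to_text()        # text rendering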
22 """
24 def __init__(self, oinf):
25 """
26 @param oinf @see cl OnnxInference
27 """
28 self.oinf = oinf
30 def to_dot(self, recursive=False, prefix='', # pylint: disable=R0914
31 add_rt_shapes=False, use_onnx=False, **params):
32 """
33 Produces a :epkg:`DOT` language string for the graph.
35 :param params: additional params to draw the graph
36 :param recursive: also show subgraphs inside operator like
37 @see cl Scan
38 :param prefix: prefix for every node name
39 :param add_rt_shapes: adds shapes infered from the python runtime
40 :param use_onnx: use :epkg:`onnx` dot format instead of this one
41 :return: string
43 Default options for the graph are:
45 ::
47 options = {
48 'orientation': 'portrait',
49 'ranksep': '0.25',
50 'nodesep': '0.05',
51 'width': '0.5',
52 'height': '0.1',
53 'size': '7',
54 }
56 One example:
58 .. exref::
59 :title: Convert ONNX into DOT
61 An example on how to convert an :epkg:`ONNX`
62 graph into :epkg:`DOT`.
64 .. runpython::
65 :showcode:
66 :warningout: DeprecationWarning
68 import numpy
69 from mlprodict.npy.xop import loadop
70 from mlprodict.onnxrt import OnnxInference
72 OnnxAiOnnxMlLinearRegressor = loadop(
73 ('ai.onnx.ml', 'LinearRegressor'))
75 pars = dict(coefficients=numpy.array([1., 2.]),
76 intercepts=numpy.array([1.]),
77 post_transform='NONE')
78 onx = OnnxAiOnnxMlLinearRegressor(
79 'X', output_names=['Y'], **pars)
80 model_def = onx.to_onnx(
81 {'X': pars['coefficients'].astype(numpy.float32)},
82 outputs={'Y': numpy.float32},
83 target_opset=12)
84 oinf = OnnxInference(model_def)
85 print(oinf.to_dot())
87 See an example of representation in notebook
88 :ref:`onnxvisualizationrst`.
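
        The returned string can be rendered with the *graphviz* package
        (a sketch, not part of this module; it assumes *graphviz* and the
        ``dot`` executable are installed and *oinf* is the instance created
        above)::

            from graphviz import Source

            src = Source(oinf.to_dot())
            src.render('model_graph', format='png')  # writes model_graph.png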
89 """
90 clean_label_reg1 = re.compile("\\\\x\\{[0-9A-F]{1,6}\\}")
91 clean_label_reg2 = re.compile("\\\\p\\{[0-9P]{1,6}\\}")
93 def dot_name(text):
94 return text.replace("/", "_").replace(
95 ":", "__").replace(".", "_")
97 def dot_label(text):
98 for reg in [clean_label_reg1, clean_label_reg2]:
99 fall = reg.findall(text)
100 for f in fall:
101 text = text.replace(f, "_") # pragma: no cover
102 return text
104 options = {
105 'orientation': 'portrait',
106 'ranksep': '0.25',
107 'nodesep': '0.05',
108 'width': '0.5',
109 'height': '0.1',
110 'size': '7',
111 }
112 options.update({k: v for k, v in params.items() if v is not None})
114 if use_onnx:
115 from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
117 pydot_graph = GetPydotGraph(
118 self.oinf.obj.graph, name=self.oinf.obj.graph.name,
119 rankdir=params.get('rankdir', "TB"),
120 node_producer=GetOpNodeProducer(
121 "docstring", fillcolor="orange", style="filled",
122 shape="box"))
123 return pydot_graph.to_string()
125 inter_vars = {}
126 exp = ["digraph{"]
127 for opt in {'orientation', 'pad', 'nodesep', 'ranksep', 'size'}:
128 if opt in options:
129 exp.append(" {}={};".format(opt, options[opt]))
130 fontsize = 10
132 shapes = {}
133 if add_rt_shapes:
134 if not hasattr(self.oinf, 'shapes_'):
135 raise RuntimeError( # pragma: no cover
136 "No information on shapes, check the runtime '{}'.".format(self.oinf.runtime))
137 for name, shape in self.oinf.shapes_.items():
138 va = shape.evaluate().to_string()
139 shapes[name] = va
140 if name in self.oinf.inplaces_:
141 shapes[name] += "\\ninplace"
143 # inputs
144 exp.append("")
145 for obj in self.oinf.obj.graph.input:
146 dobj = _var_as_dict(obj)
147 sh = shapes.get(dobj['name'], '')
148 if sh:
149 sh = "\\nshape={}".format(sh)
150 exp.append(
151 ' {3}{0} [shape=box color=red label="{0}\\n{1}{4}" fontsize={2}];'.format(
152 dot_name(dobj['name']), _type_to_string(dobj['type']),
153 fontsize, prefix, dot_label(sh)))
154 inter_vars[obj.name] = obj
156 # outputs
157 exp.append("")
158 for obj in self.oinf.obj.graph.output:
159 dobj = _var_as_dict(obj)
160 sh = shapes.get(dobj['name'], '')
161 if sh:
162 sh = "\\nshape={}".format(sh)
163 exp.append(
164 ' {3}{0} [shape=box color=green label="{0}\\n{1}{4}" fontsize={2}];'.format(
165 dot_name(dobj['name']), _type_to_string(dobj['type']),
166 fontsize, prefix, dot_label(sh)))
167 inter_vars[obj.name] = obj
169 # initializer
170 exp.append("")
171 for obj in self.oinf.obj.graph.initializer:
172 dobj = _var_as_dict(obj)
173 val = dobj['value']
174 flat = val.flatten()
175 if flat.shape[0] < 9:
176 st = str(val)
177 else:
178 st = str(val)
179 if len(st) > 50:
180 st = st[:50] + '...'
181 st = st.replace('\n', '\\n')
182 kind = ""
183 exp.append(
184 ' {6}{0} [shape=box label="{0}\\n{4}{1}({2})\\n{3}" fontsize={5}];'.format(
185 dot_name(dobj['name']), dobj['value'].dtype,
186 dobj['value'].shape, dot_label(st), kind, fontsize, prefix))
187 inter_vars[obj.name] = obj
189 # nodes
190 fill_names = {}
191 static_inputs = [n.name for n in self.oinf.obj.graph.input]
192 static_inputs.extend(n.name for n in self.oinf.obj.graph.initializer)
193 for node in self.oinf.obj.graph.node:
194 exp.append("")
195 for out in node.output:
196 if len(out) > 0 and out not in inter_vars:
197 inter_vars[out] = out
198 sh = shapes.get(out, '')
199 if sh:
200 sh = "\\nshape={}".format(sh)
201 exp.append(
202 ' {2}{0} [shape=box label="{0}{3}" fontsize={1}];'.format(
203 dot_name(out), fontsize, dot_name(prefix), dot_label(sh)))
204 static_inputs.append(out)
206 dobj = _var_as_dict(node)
207 if dobj['name'].strip() == '': # pragma: no cover
208 name = node.op_type
209 iname = 1
210 while name in fill_names:
211 name = "%s%d" % (name, iname)
212 iname += 1
213 dobj['name'] = name
214 node.name = name
215 fill_names[name] = node
217 atts = []
218 if 'atts' in dobj:
219 for k, v in sorted(dobj['atts'].items()):
220 val = None
221 if 'value' in v:
222 val = str(v['value']).replace(
223 "\n", "\\n").replace('"', "'")
224 sl = max(30 - len(k), 10)
225 if len(val) > sl:
226 val = val[:sl] + "..."
227 if val is not None:
228 atts.append('{}={}'.format(k, val))
229 satts = "" if len(atts) == 0 else ("\\n" + "\\n".join(atts))
231 connects = []
232 if recursive and node.op_type in {'Scan', 'Loop', 'If'}:
233 fields = (['then_branch', 'else_branch']
234 if node.op_type == 'If' else ['body'])
235 for field in fields:
236 if field not in dobj['atts']:
237 continue # pragma: no cover
239 # creates the subgraph
240 body = dobj['atts'][field]['value']
241 oinf = self.oinf.__class__(
242 body, runtime=self.oinf.runtime, skip_run=self.oinf.skip_run,
243 static_inputs=static_inputs)
244 subprefix = prefix + "B_"
245 subdot = oinf.to_dot(recursive=recursive, prefix=subprefix,
246 add_rt_shapes=add_rt_shapes)
247 lines = subdot.split("\n")
248 start = 0
249 for i, line in enumerate(lines):
250 if '[' in line:
251 start = i
252 break
253 subgraph = "\n".join(lines[start:])
255 # connecting the subgraph
256 cluster = "cluster_{}{}_{}".format(
257 node.op_type, id(node), id(field))
258 exp.append(" subgraph {} {{".format(cluster))
259 exp.append(' label="{0}\\n({1}){2}";'.format(
260 dobj['op_type'], dot_name(dobj['name']), satts))
261 exp.append(' fontsize={0};'.format(fontsize))
262 exp.append(' color=black;')
263 exp.append(
264 '\n'.join(map(lambda s: ' ' + s, subgraph.split('\n'))))
266 node0 = body.node[0]
267 connects.append((
268 "{}{}".format(dot_name(subprefix),
269 dot_name(node0.name)),
270 cluster))
272 for inp1, inp2 in zip(node.input, body.input):
273 exp.append(
274 " {0}{1} -> {2}{3};".format(
275 dot_name(prefix), dot_name(inp1),
276 dot_name(subprefix), dot_name(inp2.name)))
277 for out1, out2 in zip(body.output, node.output):
278 if len(out2) == 0:
279 # Empty output, it cannot be used.
280 continue
281 exp.append(
282 " {0}{1} -> {2}{3};".format(
283 dot_name(subprefix), dot_name(out1.name),
284 dot_name(prefix), dot_name(out2)))
285 else:
286 exp.append(' {4}{1} [shape=box style="filled,rounded" color=orange label="{0}\\n({1}){2}" fontsize={3}];'.format(
287 dobj['op_type'], dot_name(dobj['name']), satts, fontsize,
288 dot_name(prefix)))
290 if connects is not None and len(connects) > 0:
291 for name, cluster in connects:
292 exp.append(
293 " {0}{1} -> {2} [lhead={3}];".format(
294 dot_name(prefix), dot_name(node.name),
295 name, cluster))
297 for inp in node.input:
298 exp.append(
299 " {0}{1} -> {0}{2};".format(
300 dot_name(prefix), dot_name(inp), dot_name(node.name)))
301 for out in node.output:
302 if len(out) == 0:
303 # Empty output, it cannot be used.
304 continue
305 exp.append(
306 " {0}{1} -> {0}{2};".format(
307 dot_name(prefix), dot_name(node.name), dot_name(out)))
309 exp.append('}')
310 return "\n".join(exp)

    def to_json(self, indent=2):
        """
        Converts an :epkg:`ONNX` model into :epkg:`JSON`.

        @param indent indentation
        @return string

        .. exref::
            :title: Convert ONNX into JSON

            An example on how to convert an :epkg:`ONNX`
            graph into :epkg:`JSON`.

            .. runpython::
                :showcode:
                :warningout: DeprecationWarning

                import numpy
                from mlprodict.npy.xop import loadop
                from mlprodict.onnxrt import OnnxInference

                OnnxAiOnnxMlLinearRegressor = loadop(
                    ('ai.onnx.ml', 'LinearRegressor'))

                pars = dict(coefficients=numpy.array([1., 2.]),
                            intercepts=numpy.array([1.]),
                            post_transform='NONE')
                onx = OnnxAiOnnxMlLinearRegressor(
                    'X', output_names=['Y'], **pars)
                model_def = onx.to_onnx(
                    {'X': pars['coefficients'].astype(numpy.float32)},
                    outputs={'Y': numpy.float32},
                    target_opset=12)
                oinf = OnnxInference(model_def)
                print(oinf.to_json())
        """

        def _to_json(obj):
            s = str(obj)
            rows = ['{']
            leave = None
            for line in s.split('\n'):
                if line.endswith("{"):
                    rows.append('"%s": {' % line.strip('{ '))
                elif ':' in line:
                    spl = line.strip().split(':')
                    if len(spl) != 2:
                        raise RuntimeError(  # pragma: no cover
                            "Unable to interpret line '{}'.".format(line))

                    if spl[0].strip() in ('type', ):
                        st = spl[1].strip()
                        if st in {'INT', 'INTS', 'FLOAT', 'FLOATS',
                                  'STRING', 'STRINGS', 'TENSOR'}:
                            spl[1] = '"{}"'.format(st)

                    if spl[0] in ('floats', 'ints'):
                        if leave:
                            rows.append("{},".format(spl[1]))
                        else:
                            rows.append('"{}": [{},'.format(
                                spl[0], spl[1].strip()))
                            leave = spl[0]
                    elif leave:
                        rows[-1] = rows[-1].strip(',')
                        rows.append('],')
                        rows.append('"{}": {},'.format(
                            spl[0].strip(), spl[1].strip()))
                        leave = None
                    else:
                        rows.append('"{}": {},'.format(
                            spl[0].strip(), spl[1].strip()))
                elif line.strip() == "}":
                    rows[-1] = rows[-1].rstrip(",")
                    rows.append(line + ",")
                elif line:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to interpret line '{}'.".format(line))
            rows[-1] = rows[-1].rstrip(',')
            rows.append("}")
            js = "\n".join(rows)

            try:
                content = json.loads(js)
            except json.decoder.JSONDecodeError as e:  # pragma: no cover
                js2 = "\n".join("%04d %s" % (i + 1, line)
                                for i, line in enumerate(js.split("\n")))
                raise RuntimeError(
                    "Unable to parse JSON\n{}".format(js2)) from e
            return content

        # meta data
        final_obj = {}
        for k in {'ir_version', 'producer_name', 'producer_version',
                  'domain', 'model_version', 'doc_string'}:
            if hasattr(self.oinf.obj, k):
                final_obj[k] = getattr(self.oinf.obj, k)

        # inputs
        inputs = []
        for obj in self.oinf.obj.graph.input:
            st = _to_json(obj)
            inputs.append(st)
        final_obj['inputs'] = inputs

        # outputs
        outputs = []
        for obj in self.oinf.obj.graph.output:
            st = _to_json(obj)
            outputs.append(st)
        final_obj['outputs'] = outputs

        # init
        inits = {}
        for obj in self.oinf.obj.graph.initializer:
            value = numpy_helper.to_array(obj).tolist()
            inits[obj.name] = value
        final_obj['initializers'] = inits

        # nodes
        nodes = []
        for obj in self.oinf.obj.graph.node:
            node = dict(name=obj.name, op_type=obj.op_type, domain=obj.domain,
                        inputs=[str(_) for _ in obj.input],
                        outputs=[str(_) for _ in obj.output],
                        attributes={})
            for att in obj.attribute:
                st = _to_json(att)
                node['attributes'][st['name']] = st
                del st['name']
            nodes.append(node)
        final_obj['nodes'] = nodes

        return json.dumps(final_obj, indent=indent)

    def to_python(self, prefix="onnx_pyrt_", dest=None, inline=True):
        """
        Converts the ONNX runtime into independent python code.
        The function creates multiple files starting with
        *prefix* and saves them to folder *dest*.

        @param prefix file prefix
        @param dest destination folder
        @param inline constant matrices are put in the python file itself
            as byte arrays
        @return file dictionary

        The function does not work if the chosen runtime
        is not *python*.
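
        When *dest* is given, the generated files are also written to that
        folder (a minimal sketch; *oinf* is an @see cl OnnxInference created
        with ``runtime='python'`` and the folder name ``generated`` is only
        an example that must exist beforehand)::

            files = oinf.to_python(prefix="onnx_pyrt_", dest="generated",
                                   inline=False)
            # *files* maps file names to their content:
            # 'onnx_pyrt_main.py' plus one '.pkl' file per initializer.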

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning

            import numpy
            from mlprodict.npy.xop import loadop
            from mlprodict.onnxrt import OnnxInference

            OnnxAdd = loadop('Add')

            idi = numpy.identity(2).astype(numpy.float32)
            onx = OnnxAdd('X', idi, output_names=['Y'],
                          op_version=12)
            model_def = onx.to_onnx({'X': idi},
                                    target_opset=12)
            X = numpy.array([[1, 2], [3, 4]], dtype=numpy.float32)
            oinf = OnnxInference(model_def, runtime='python')
            res = oinf.to_python()
            print(res['onnx_pyrt_main.py'])
        """
        if not isinstance(prefix, str):
            raise TypeError(  # pragma: no cover
                "prefix must be a string not %r." % type(prefix))

        def clean_args(args):
            new_args = []
            for v in args:
                # remove python keywords
                if v.startswith('min='):
                    av = 'min_=' + v[4:]
                elif v.startswith('max='):
                    av = 'max_=' + v[4:]
                else:
                    av = v
                new_args.append(av)
            return new_args

        if self.oinf.runtime != 'python':
            raise ValueError(
                "The runtime must be 'python' not '{}'.".format(
                    self.oinf.runtime))

        # metadata
        obj = {}
        for k in {'ir_version', 'producer_name', 'producer_version',
                  'domain', 'model_version', 'doc_string'}:
            if hasattr(self.oinf.obj, k):
                obj[k] = getattr(self.oinf.obj, k)
        code_begin = ["# coding: utf-8",
                      "'''",
                      "Python code equivalent to an ONNX graph.",
513 "It was was generated by module *mlprodict*.",
514 "'''"]
515 code_imports = ["from io import BytesIO",
516 "import pickle",
517 "from numpy import array, float32, ndarray"]
518 code_lines = ["class OnnxPythonInference:", "",
519 " def __init__(self):",
520 " self._load_inits()", "",
521 " @property",
522 " def metadata(self):",
523 " return %r" % obj, ""]
525 # inputs
526 if hasattr(self.oinf.obj, 'graph'):
527 inputs = [obj.name for obj in self.oinf.obj.graph.input]
528 outputs = [obj.name for obj in self.oinf.obj.graph.output]
529 else:
530 inputs = list(self.oinf.obj.input)
531 outputs = list(self.oinf.obj.output)
533 code_lines.extend([
534 " @property", " def inputs(self):",
535 " return %r" % inputs,
536 ""
537 ])
539 # outputs
540 code_lines.extend([
541 " @property", " def outputs(self):",
542 " return %r" % outputs,
543 ""
544 ])
546 # init
547 code_lines.extend([" def _load_inits(self):",
548 " self._inits = {}"])
549 file_data = {}
550 if hasattr(self.oinf.obj, 'graph'):
551 for obj in self.oinf.obj.graph.initializer:
552 value = numpy_helper.to_array(obj)
553 bt = BytesIO()
554 pickle.dump(value, bt)
555 name = '{1}{0}.pkl'.format(obj.name, prefix)
556 if inline:
557 code_lines.extend([
558 " iocst = %r" % bt.getvalue(),
559 " self._inits['{0}'] = pickle.loads(iocst)".format(
560 obj.name)
561 ])
562 else:
563 file_data[name] = bt.getvalue()
564 code_lines.append(
565 " self._inits['{0}'] = pickle.loads('{1}')".format(
566 obj.name, name))
567 code_lines.append('')
569 # inputs, outputs
570 inputs = self.oinf.input_names
572 # nodes
573 code_lines.extend([' def run(self, %s):' % ', '.join(inputs)])
574 ops = {}
575 if hasattr(self.oinf.obj, 'graph'):
576 code_lines.append(' # constant')
577 for obj in self.oinf.obj.graph.initializer:
578 code_lines.append(
579 " {0} = self._inits['{0}']".format(obj.name))
580 code_lines.append('')
581 code_lines.append(' # graph code')
582 for node in self.oinf.sequence_:
583 fct = 'pyrt_' + node.name
584 if fct not in ops:
585 ops[fct] = node
586 args = []
587 args.extend(node.inputs)
588 margs = node.modified_args
589 if margs is not None:
590 args.extend(clean_args(margs))
591 code_lines.append(" {0} = {1}({2})".format(
592 ', '.join(node.outputs), fct, ', '.join(args)))
593 code_lines.append('')
594 code_lines.append(' # return')
595 code_lines.append(' return %s' % ', '.join(outputs))
596 code_lines.append('')
598 # operator code
599 code_nodes = []
600 for name, op in ops.items():
601 inputs_args = clean_args(op.inputs_args)
603 code_nodes.append('def {0}({1}):'.format(
604 name, ', '.join(inputs_args)))
605 imps, code = op.to_python(op.python_inputs)
606 if imps is not None:
607 if not isinstance(imps, list):
608 imps = [imps]
609 code_imports.extend(imps)
610 code_nodes.append(textwrap.indent(code, ' '))
611 code_nodes.extend(['', ''])
613 # end
614 code_imports = list(sorted(set(code_imports)))
615 code_imports.extend(['', ''])
616 file_data[prefix + 'main.py'] = "\n".join(
617 code_begin + code_imports + code_nodes + code_lines)
619 # saves as files
620 if dest is not None:
621 for k, v in file_data.items():
622 ext = os.path.splitext(k)[-1]
623 kf = os.path.join(dest, k)
624 if ext == '.py':
625 with open(kf, "w", encoding="utf-8") as f:
626 f.write(v)
627 elif ext == '.pkl': # pragma: no cover
628 with open(kf, "wb") as f:
629 f.write(v)
630 else:
631 raise NotImplementedError( # pragma: no cover
632 "Unknown extension for file '{}'.".format(k))
633 return file_data

    def to_text(self, recursive=False, grid=5, distance=5, kind='bi'):
        """
        Returns the ONNX graph as text. Depending on *kind*, it calls
        function @see fn onnx2bigraph or @see fn onnx_simple_text_plot.

        :param recursive: dig into subgraphs too
        :param grid: align text to this grid
        :param distance: distance to the text
        :param kind: see below
        :return: text

        Possible values for *kind*:

        * `'bi'`: use @see fn onnx2bigraph
        * `'seq'`: use @see fn onnx_simple_text_plot
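
        A minimal sketch (assuming *oinf* is an @see cl OnnxInference
        instance)::

            print(oinf.to_text(kind='seq'))  # compact sequential rendering
            print(oinf.to_text(kind='bi'))   # bi-graph layout on a grid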
649 """
650 if kind == 'bi':
651 bigraph = onnx2bigraph(self.oinf.obj, recursive=recursive)
652 graph = bigraph.display_structure(grid=grid, distance=distance)
653 return graph.to_text()
654 if kind == 'seq':
655 return onnx_simple_text_plot(self.oinf.obj)
656 raise ValueError( # pragma: no cover
657 "Unexpected value for format=%r." % format)

    def to_onnx_code(self):
        """
        Exports the ONNX graph into an :epkg:`onnx` code
        which replicates it.

        :return: string
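
        A minimal sketch (assuming *oinf* is an @see cl OnnxInference
        instance)::

            code = oinf.to_onnx_code()
            print(code)  # python code rebuilding the same ONNX graph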
665 """
666 # Lazy import as it is not a common use.
667 from ..onnx_tools.onnx_export import export2onnx
668 return export2onnx(self.oinf.obj)