Coverage for mlprodict/onnx_tools/onnx_manipulations.py: 95%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
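"""
@file
@brief Implements functions to manipulate :epkg:`ONNX` models:
enumerate node outputs, select inputs and outputs, overwrite the
main opset, rename results, insert debugging results into a graph.
"""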
6import hashlib
7from onnx import helper, shape_inference
8from .onnx2py_helper import guess_proto_dtype, from_array
9from .optim import onnx_remove_node_unused
def enumerate_model_node_outputs(model, add_node=False, order=False):
    """
    Enumerates the outputs of all the nodes of a model.

    :param model: :epkg:`ONNX` graph
    :param add_node: if False, the function enumerates
        all output names from every node, otherwise, it
        enumerates tuple (output name, node)
    :param order: goes through outputs following the graph order,
        otherwise outputs are enumerated in the order nodes are stored
    :return: enumerator
    :raises TypeError: if *model* has no attribute ``graph``
    :raises RuntimeError: if *order* is True and the ranks never
        stabilize, which means the graph contains a cycle

    Nodes are identified by their position in the graph and not by
    their name: names are optional in ONNX and several nodes may share
    the same (possibly empty) name.
    """
    if not hasattr(model, "graph"):
        raise TypeError(  # pragma: no cover
            "Parameter model is not an ONNX model but "
            "{}".format(type(model)))

    if not order:
        for node in model.graph.node:
            for out in node.output:
                yield (out, node) if add_node else out
        return

    # Rank of every known result (graph inputs start at 0) and of every
    # node (indexed by its position in the graph).
    data_rank = {inp.name: 0 for inp in model.graph.input}
    node_rank = []
    producers = {}
    edges = []
    for index, node in enumerate(model.graph.node):
        node_rank.append(0)
        edges.extend(('in', name, index) for name in node.input)
        for name in node.output:
            edges.append(('out', name, index))
            producers[name] = node
            data_rank[name] = 0

    # Longest-path relaxation: a node comes after all its known inputs,
    # a result comes after the node producing it. An acyclic graph
    # stabilizes in at most one pass per vertex, the bound turns a
    # cyclic graph into an error instead of an endless loop.
    for _ in range(len(data_rank) + len(node_rank) + 1):
        changed = False
        for kind, name, index in edges:
            if kind == 'in':
                if name not in data_rank:
                    # Initializer or unknown name: it does not
                    # constrain the order.
                    continue
                if data_rank[name] + 1 > node_rank[index]:
                    node_rank[index] = data_rank[name] + 1
                    changed = True
            elif node_rank[index] + 1 > data_rank[name]:
                data_rank[name] = node_rank[index] + 1
                changed = True
        if not changed:
            break
    else:
        raise RuntimeError(
            "Unable to order the node outputs, the graph contains a cycle.")

    # Only results produced by a node are enumerated, sorted by rank
    # then by name.
    ranked = sorted((rank, name) for name, rank in data_rank.items()
                    if name in producers)
    for _, name in ranked:
        yield (name, producers[name]) if add_node else name
def _make_selected_value_info(name, known_shapes, overwrite):
    """
    Builds the ``ValueInfoProto`` describing input or output *name*.

    :param name: result name
    :param known_shapes: dictionary `{name: TypeProto}` of known types
    :param overwrite: None or dictionary `{'name': (numpy dtype, shape)}`,
        it takes precedence over *known_shapes*
    :return: ``ValueInfoProto``, untyped if nothing is known about *name*
    """
    if overwrite is not None and name in overwrite:
        dtype, shape = overwrite[name]
        return helper.make_tensor_value_info(
            name, guess_proto_dtype(dtype), shape)

    if name in known_shapes:
        info = known_shapes[name].tensor_type
        proto_dtype = info.elem_type
        if proto_dtype != 0:
            shape = [getattr(d, 'dim_value', None) for d in info.shape.dim]
            if len(shape) == 0:
                shape = None
            else:
                # A dimension equal to 0 is considered as unknown.
                shape = [None if s == 0 else s for s in shape]
            return helper.make_tensor_value_info(name, proto_dtype, shape)

    # Unknown type: only the name is filled.
    value_info = helper.ValueInfoProto()
    value_info.name = name
    return value_info


def select_model_inputs_outputs(model, outputs=None, inputs=None,
                                infer_shapes=False, overwrite=None,
                                remove_unused=True,
                                verbose=0, fLOG=None):
    """
    Takes a model and changes its inputs and outputs.

    :param model: :epkg:`ONNX` model
    :param inputs: new inputs, same ones if None
    :param outputs: new outputs, same ones if None
    :param infer_shapes: infer inputs and outputs shapes
    :param overwrite: overwrite type and shapes for
        inputs or outputs, *overwrite* is a
        dictionary `{'name': (numpy dtype, shape)}`
    :param remove_unused: remove unused nodes from the graph
    :param verbose: display information while converting
    :param fLOG: logging function
    :return: modified model
    :raises ValueError: if one of the requested outputs is neither
        a node output nor one of the inputs

    The function removes unneeded nodes. The remaining nodes keep
    the order they have in the original graph.

    .. exref::
        :title: Change ONNX model inputs

        The following example shows how to change the inputs of a model
        to bypass the first nodes. Shape inference fails to determine
        the new inputs type. They need to be overwritten.
        `verbose=1, fLOG=print` shows the number of deleted nodes.

        ::

            import numpy
            import onnx
            from mlprodict.onnx_tools.onnx_manipulations import select_model_inputs_outputs

            onx = onnx.load(path)
            onx2 = select_model_inputs_outputs(
                onx, inputs=["SentenceTokenizer/SentencepieceTokenizeOp:0",
                             "SentenceTokenizer/SentencepieceTokenizeOp:1"],
                infer_shapes=True, verbose=1, fLOG=print,
                overwrite={'SentenceTokenizer/SentencepieceTokenizeOp:0': (numpy.int32, None),
                           'SentenceTokenizer/SentencepieceTokenizeOp:1': (numpy.int64, None)})
            onnx.save(onx2, path2)

    .. versionchanged:: 0.6
        Supports the case where inputs are changed.

    .. versionchanged:: 0.7
        Parameter *remove_unused* was added. Unused are removed by default.
    """
    if inputs is not None and not isinstance(inputs, list):
        inputs = [inputs]
    if outputs is not None and not isinstance(outputs, list):
        outputs = [outputs]
    if inputs is None:
        inputs = [i.name for i in model.graph.input]
    if outputs is None:
        outputs = [o.name for o in model.graph.output]

    # mark_var[name] == 1 means the result is needed to compute the outputs.
    mark_var = {out: 0 for out in enumerate_model_node_outputs(model)}
    for inp in inputs:
        mark_var[inp] = 0
    for out in outputs:
        if out not in mark_var:
            raise ValueError(  # pragma: no cover
                "Output '{}' not found in model.".format(out))
        mark_var[out] = 1

    # Nodes are tracked by position: names are optional in ONNX and
    # may be empty or duplicated.
    nodes = list(model.graph.node)
    kept = [False] * len(nodes)
    input_names = set(inputs)

    # We mark all the nodes we need to keep. The graph is walked
    # backward so that most dependencies are found in a single pass.
    changed = True
    while changed:
        changed = False
        for index in range(len(nodes) - 1, -1, -1):
            if kept[index]:
                continue
            node = nodes[index]
            if not any(mark_var.get(out, 0) == 1 for out in node.output):
                continue
            kept[index] = True
            changed = True
            for inp in node.input:
                # The new inputs cut the graph: nothing above is needed.
                if inp in input_names:
                    continue
                mark_var[inp] = 1

    # Kept nodes follow the original order to remain topologically sorted.
    keep_nodes = [node for index, node in enumerate(nodes) if kept[index]]

    known_shapes = {}
    if infer_shapes:
        shapes = shape_inference.infer_shapes(model)
        for shape in shapes.graph.value_info:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
        for shape in shapes.graph.input:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
        for shape in shapes.graph.output:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
    else:
        for shape in model.graph.input:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type
        for shape in model.graph.output:  # pylint: disable=E1101
            known_shapes[shape.name] = shape.type

    var_in = [_make_selected_value_info(name, known_shapes, overwrite)
              for name in inputs]
    var_out = [_make_selected_value_info(name, known_shapes, overwrite)
               for name in outputs]

    if verbose > 0 and fLOG is not None:  # pragma: no cover
        fLOG("[select_model_inputs_outputs] nodes %r --> %r" % (
            len(model.graph.node), len(keep_nodes)))
        fLOG("[select_model_inputs_outputs] inputs: %r" % var_in)
        fLOG("[select_model_inputs_outputs] outputs: %r" % var_out)

    graph = helper.make_graph(keep_nodes, model.graph.name, var_in,
                              var_out, model.graph.initializer)
    onnx_model = helper.make_model(graph)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in model.metadata_props}
        helper.set_model_props(onnx_model, values)

    # The new model declares the same opsets as the original one.
    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    # remove unused nodes
    if remove_unused:
        onnx_model = onnx_remove_node_unused(onnx_model, recursive=False)

    return onnx_model
def overwrite_opset(model, new_opset):
    """
    Replaces the version of the main opset (domain ``''``) of an ONNX model.
    Node definitions are left untouched and the opsets
    of the other domains are copied as they are.

    :param model: ONNX model
    :param new_opset: new version of the main opset
    :return: ONNX model
    """
    source = model.graph
    new_graph = helper.make_graph(
        source.node, source.name, source.input,
        source.output, source.initializer)
    result = helper.make_model(new_graph)

    # Model level information is copied from the original model.
    for field in ('ir_version', 'producer_name', 'producer_version',
                  'domain', 'model_version', 'doc_string'):
        setattr(result, field, getattr(model, field))
    if len(model.metadata_props) > 0:  # pragma: no cover
        helper.set_model_props(
            result, {p.key: p.value for p in model.metadata_props})

    # Opsets are rebuilt, only the main domain gets the new version.
    del result.opset_import[:]  # pylint: disable=E1101
    for imported in model.opset_import:
        entry = result.opset_import.add()  # pylint: disable=E1101
        entry.domain = imported.domain
        entry.version = (
            new_opset if imported.domain == '' else imported.version)
    return result
def hash_onnx_object(obj, max_size):
    """
    Hashes the content of an ONNX object with SHA-256.

    :param obj: a node (any object with an attribute ``op_type``)
        or an initializer (any other object exposing ``name``,
        ``doc_string`` and ``SerializeToString``)
    :param max_size: maximum number of characters of the returned hash
    :return: the hexadecimal digest in upper case, truncated to
        *max_size* characters
    :raises RuntimeError: if the initializer cannot be serialized

    The hash of a node depends on its type, its number of inputs and
    outputs and its attributes, not on its name nor on the names of
    its inputs and outputs. The hash of an initializer ignores its
    name and its documentation.
    """
    m = hashlib.sha256()
    if hasattr(obj, 'op_type'):
        # An operator.
        # utf-8 produces the same bytes as ascii for ascii names but
        # does not fail on other characters.
        m.update(obj.op_type.encode('utf-8'))
        m.update(str(len(obj.input)).encode('ascii'))
        m.update(str(len(obj.output)).encode('ascii'))
        for att in getattr(obj, 'attribute', ()):
            m.update(att.name.encode('utf-8'))
            m.update(att.SerializeToString())
    else:
        # An initializer.
        # The name and the documentation are temporarily removed so
        # that two initializers with the same content get the same hash.
        name = obj.name
        docf = obj.doc_string
        obj.name = ''
        obj.doc_string = ''
        try:
            m.update(obj.SerializeToString())
        except AttributeError as e:  # pragma: no cover
            raise RuntimeError(
                "Unable to hash object type %r, value=%r."
                "" % (type(obj), obj)) from e
        finally:
            # The object is always restored, even if the hash failed.
            obj.name = name
            obj.doc_string = docf

    return m.hexdigest()[:max_size].upper()
def onnx_rename_names(model, strategy='simple', recursive=True,
                      verbose=0, fLOG=print,
                      counts=None, replace=None, taken=None):
    """
    Renames all names except the inputs and outputs.

    :param model: onnx model
    :param strategy: two strategies are implemented, see below
    :param recursive: walk through subgraphs
    :param verbose: verbose, if positive, reports on all changed names
    :param fLOG: logging function
    :param counts: used for recursion, dictionary counting the
        initializers, nodes and results already renamed
    :param replace: used for recursion, it can be also used
        to fix some replacements, the dictionary is updated in place
        and maps every old name to its new name
    :param taken: used for recursion, set of names already given
    :return: onnx model (the model is modified in place)
    :raises ValueError: if *strategy* is unknown

    Strategies:

    * `'simple'`: use a letter `n` for node, `r`, `i` for initializer,
      this letter is followed by a number
    * `'type'`: the name depends on the node type and content,
      the hash is kept as small as possible

    Empty names are left unchanged: :epkg:`ONNX` uses them to tell an
    optional input or output is not specified.
    """
    # ``is None`` and not ``or``: empty containers given by the caller
    # (or by the recursion) must be shared, not replaced.
    if counts is None:
        counts = {}
    for key in ('init', 'node', 'result'):
        counts.setdefault(key, 0)
    if replace is None:
        replace = {}
    if taken is None:
        taken = set()
    graph = model.graph if hasattr(model, 'graph') else model

    # Inputs and outputs keep their names.
    for obj in graph.input:
        replace[obj.name] = obj.name
    for obj in graph.output:
        replace[obj.name] = obj.name

    def _check_name_simple(prefix):
        # The candidate is only changed if it is already an existing name.
        if prefix not in replace:
            return prefix
        c = 1
        final = "%s_%d" % (prefix, c)
        while final in taken:
            c += 1
            final = "%s_%d" % (prefix, c)
        taken.add(final)
        return final

    def _check_name_type(obj, prefix):
        # The hash grows by two characters until the name is unique.
        size = 2
        while True:
            digest = hash_onnx_object(obj, size)
            final = "%s_%s" % (prefix, digest)
            if final not in taken:
                break
            if len(digest) < size:
                # The whole hash is already used (many identical
                # objects), a counter makes the name unique.
                base = final
                c = 1
                final = "%s_%d" % (base, c)
                while final in taken:
                    c += 1
                    final = "%s_%d" % (base, c)
                break
            size += 2
        taken.add(final)
        return final

    def _new_name(kind, key, obj, simple_prefix, type_prefix):
        # Creates, registers and logs the new name of *key*,
        # *kind* is one of the keys of *counts*.
        if strategy == 'simple':
            name = _check_name_simple(
                '%s%d' % (simple_prefix, counts[kind]))
        elif strategy == 'type':
            name = _check_name_type(obj, type_prefix)
        else:
            raise ValueError(  # pragma: no cover
                "Unknown strategy %r." % strategy)
        counts[kind] += 1
        replace[key] = name
        if verbose > 0 and fLOG is not None:
            fLOG('[onnx_rename_names] %s: %r -> %r' % (kind, key, name))
        return name

    def get_name_init(init):
        if init.name in replace:
            return replace[init.name]
        return _new_name('init', init.name, init, 'i', 'i')

    def get_name_node(node):
        # Node names may be empty or duplicated, the key includes the id.
        node_name = 'node_%s_%d' % (node.name, id(node))
        if node_name in replace:
            return replace[node_name]
        return _new_name('node', node_name, node, 'n', 'n')

    def get_name_result(node, i, name, suffix):
        if not name:
            # Unspecified optional input or output.
            return name
        if name in replace:
            return replace[name]
        return _new_name('result', name, node, 'r', 'r%s%d' % (suffix, i))

    def get_name_input(node, i):
        return get_name_result(node, i, node.input[i], 'i')

    def get_name_output(node, i):
        return get_name_result(node, i, node.output[i], 'o')

    for init in graph.initializer:
        init.name = get_name_init(init)

    for node in graph.node:
        node.name = get_name_node(node)
        for i in range(len(node.input)):  # pylint: disable=C0200
            node.input[i] = get_name_input(node, i)
        for i in range(len(node.output)):  # pylint: disable=C0200
            node.output[i] = get_name_output(node, i)
        if not recursive or node.op_type not in {'Scan', 'If', 'Loop'}:
            continue
        # recursion, subgraphs share counts, replace, taken because
        # they can refer to names defined in the outer graph.
        # 'then_branch' and 'else_branch' are the attributes of operator
        # If, 'if_branch' is only kept for backward compatibility.
        for att in node.attribute:
            if att.name not in {'then_branch', 'else_branch',
                                'body', 'if_branch'}:
                continue
            onnx_rename_names(
                att.g, strategy=strategy, recursive=recursive,
                fLOG=fLOG, verbose=verbose,
                counts=counts, replace=replace, taken=taken)

    return model
def insert_results_into_onnx(model, results, as_parameter=True, suffix='_DBG',
                             param_name=None, node_type='DEBUG',
                             domain='DEBUG', domain_opset=1):
    """
    Inserts results into an ONNX graph to produce an extended
    ONNX graph. It can be saved and looked into with a tool such as
    :epkg:`netron`.

    :param model: ONNX graph
    :param results: results to be added in a dictionary
        `{result name: value}`
    :param as_parameter: add new nodes with results as one parameter
        (True) or as initializer (False)
    :param suffix: suffix to add to new results
    :param param_name: name of the parameter to add
        (by default the result name), it can be a function
        `param_name(result_name) -> parameter_name`
    :param node_type: type of the new node
    :param domain: domain of the new node
    :param domain_opset: opset for *domain*, ignored if the model
        already imports *domain*
    :return: new ONNX graph
    :raises NotImplementedError: if a result is an input of the graph
    :raises RuntimeError: if a result is not produced by any node

    Every node producing one of the results is copied, the copy outputs
    `<result><suffix>` and a new node of type *node_type* takes it
    as an input and outputs the original result name. Results which
    are initializers are skipped.

    See method :meth:`OnnxInference.run2onnx
    <mlprodict.onnxrt.onnx_inference.OnnxInference.run2onnx>`
    to see a graph this function produces.

    .. image:: debug.png

    .. versionadded:: 0.7
    """
    inputs = list(model.graph.input)
    outputs = list(model.graph.output)
    inits = list(model.graph.initializer)
    # The nodes are retrieved once so that the same python objects
    # (and ids) are used in both dictionaries.
    graph_nodes = list(model.graph.node)
    nodes = {id(n): n for n in graph_nodes}
    order = {id(n): i for i, n in enumerate(graph_nodes)}
    nodes_copy = {}

    names_init = set(init.name for init in inits)
    names_input = set(init.name for init in inputs)
    names_output = {}
    for node in nodes.values():
        for i, o in enumerate(node.output):
            names_output[o] = (i, node)

    for k, v in results.items():
        if k in names_init:
            # initializer are not inserted again
            continue
        if k in names_input:
            # inputs are not supported
            raise NotImplementedError(
                "Unable to add debug information on input %r." % k)

        if k not in names_output:
            raise RuntimeError(
                "Unable to find result %r in the ONNX graph. Available="
                "[%s]." % (k, ", ".join(sorted(names_output))))

        index, node = names_output[k]
        new_name = k + suffix

        # The producing node is copied once even if several of its
        # outputs are in results.
        if id(node) not in nodes_copy:
            new_node = helper.make_node(
                node.op_type, list(node.input), list(node.output),
                domain=node.domain if node.domain else None,
                name=node.name + suffix)
            new_node.attribute.extend(node.attribute)  # pylint: disable=E1101
            nodes_copy[id(node)] = new_node
            order[id(new_node)] = order[id(node)]
        new_node = nodes_copy[id(node)]
        new_node.output[index] = new_name

        pname = k if param_name is None else param_name(k)
        if as_parameter:
            atts = {pname: from_array(v, name=pname)}
            inserted_node = helper.make_node(
                node_type, [new_name], [k], domain=domain,
                **atts)
        else:
            pname += suffix + 'i'
            inserted_node = helper.make_node(
                node_type, [new_name, pname], [k], domain=domain)
            inits.append(from_array(v, name=pname))

        # The inserted node comes right after the node it follows,
        # the fraction depends on the output index to avoid ties.
        order[id(inserted_node)] = order[id(node)] + 1. / (index + 2)
        nodes[id(inserted_node)] = inserted_node

    # Copied nodes replace the original ones.
    new_nodes = [(order[id(n)], n)
                 for n in nodes.values() if id(n) not in nodes_copy]
    new_nodes.extend((order[id(n)], n) for n in nodes_copy.values())
    # Nodes cannot be compared, only the position is used to sort.
    new_nodes.sort(key=lambda pair: pair[0])
    new_nodes = [n[1] for n in new_nodes]

    graph = helper.make_graph(new_nodes, model.graph.name, inputs,
                              outputs, inits)
    onnx_model = helper.make_model(graph)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in model.metadata_props}
        helper.set_model_props(onnx_model, values)

    del onnx_model.opset_import[:]  # pylint: disable=E1101
    known_domains = set()
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version
        known_domains.add(oimp.domain)
    # A domain must not be declared twice (the model may already
    # contain nodes inserted by a previous call).
    if domain not in known_domains:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = domain
        op_set.version = domain_opset
    return onnx_model