Coverage for mlprodict/onnx_tools/optim/onnx_optimisation_unused.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Optimisation of :epkg:`ONNX` graphs.
4"""
5from onnx import FunctionProto
6from onnx.helper import make_graph, make_function
7from ._onnx_optimisation_common import ( # pylint: disable=E0611
8 _apply_optimisation_on_graph, _apply_remove_node_fct_node)
def _subgraph_names(node):
    """
    Returns the set of tensor names referenced by the subgraphs held
    in the attributes of *node* (inputs of every nested node and
    outputs of every nested graph, recursively). Such names may come
    from the outer scope without appearing in ``node.input``.
    The result over-approximates: it also contains names local to
    the subgraphs, which is harmless for dead-node detection since it
    can only keep more nodes, never fewer.

    @param      node        NodeProto
    @return                 set of names
    """
    names = set()
    # Reading ``att.g`` on a non-graph attribute yields an empty default
    # GraphProto (no nodes, no outputs), so no type check is needed.
    pending = []
    for att in node.attribute:
        pending.append(att.g)
        pending.extend(att.graphs)
    while pending:
        sub_graph = pending.pop()
        names.update(out.name for out in sub_graph.output)
        for sub_node in sub_graph.node:
            names.update(sub_node.input)
            for att in sub_node.attribute:
                pending.append(att.g)
                pending.extend(att.graphs)
    return names


def onnx_remove_node_unused(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes unused nodes of the graph. An unused node
    is not involved in the output computation.

    @param      onnx_model      onnx model (an object with a ``graph``
                                attribute, a graph, or a FunctionProto)
    @param      recursive       looks into subgraphs
    @param      debug_info      debug information (private)
    @param      options         unused
    @return                     new onnx _model
    """
    type_name = str(type(onnx_model)).rsplit(
        '.', maxsplit=1)[-1].strip("'>")
    debug_info = ([type_name] if debug_info is None
                  else debug_info + [type_name])

    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node_unused, onnx_model,
            recursive=recursive, debug_info=debug_info,
            **options)

    graph = onnx_model
    is_function = isinstance(graph, FunctionProto)
    all_nodes = list(graph.node)

    # Nodes are identified by their position, not their name: node names
    # are often empty or duplicated and would otherwise collide.
    producers = {}
    for index, node in enumerate(all_nodes):
        for out in node.output:
            # An empty name denotes an omitted optional output.
            if out:
                producers.setdefault(out, []).append(index)

    # Backward traversal from the outputs: a node is kept as soon as one
    # of its outputs is needed, then everything it reads becomes needed.
    needed = set()
    kept = set()
    pending = [out if is_function else out.name for out in graph.output]
    while pending:
        name = pending.pop()
        if name in needed:
            continue
        needed.add(name)
        for index in producers.get(name, ()):
            if index in kept:
                continue
            kept.add(index)
            node = all_nodes[index]
            pending.extend(node.input)
            if node.attribute:
                # Subgraphs may read outer-scope tensors implicitly.
                pending.extend(_subgraph_names(node))

    new_nodes = [n for i, n in enumerate(all_nodes) if i in kept]
    if not is_function:
        new_inits = [n for n in graph.initializer if n.name in needed]

    if recursive:
        # Handles subgraphs.
        for i, node in enumerate(new_nodes):
            if node is None or not node.attribute:
                continue
            new_nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_unused,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = [n for n in new_nodes if n is not None]
    if is_function:
        return make_function(
            onnx_model.domain, onnx_model.name,
            onnx_model.input, onnx_model.output, nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string)
    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       new_inits, doc_string=onnx_model.doc_string)
    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    return graph