Coverage for mlprodict/onnx_tools/optim/onnx_optimisation_redundant.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Optimisation of :epkg:`ONNX` graphs.
4"""
5import copy
6import hashlib
7from onnx import FunctionProto
8from onnx.helper import make_graph, make_function
9from ._onnx_optimisation_common import ( # pylint: disable=E0611
10 _rename_node_input,
11 _rename_node_output,
12 _apply_optimisation_on_graph,
13 _apply_remove_node_fct_node)
def _hash_obj_content(obj, max_size=1000):
    """
    Hashes the content of an object.

    * For a node (any object with an ``op_type`` attribute), the hash covers
      the operator type, its domain (``'ai.onnx'`` is treated as the default
      domain ``''``), the number of outputs, the input names in order and
      every attribute (its name and a hash of its content). Output names are
      not hashed so that two nodes computing the same thing from the same
      inputs share the same hash.
    * For any other object (an initializer, an attribute), the hash covers
      ``SerializeToString()`` of a deep copy whose ``name`` and
      ``doc_string`` were blanked; *obj* itself is not modified.

    Every string is length-prefixed before being hashed so that different
    sequences of names (``['ab', 'c']`` and ``['a', 'bc']``) cannot produce
    the same byte stream. Names are encoded in UTF-8 (identical to ASCII
    for ASCII names).

    @param      obj         node, initializer or attribute
    @param      max_size    maximum number of bytes returned
    @return                 digest (bytes), truncated to *max_size* bytes
    """
    m = hashlib.sha256()

    def _update_text(text):
        # Length prefix makes the concatenation of several names unambiguous.
        data = text.encode('utf-8')
        m.update(len(data).to_bytes(8, byteorder='big'))
        m.update(data)

    if hasattr(obj, 'op_type'):
        # An operator.
        _update_text(obj.op_type)
        domain = getattr(obj, 'domain', '') or ''
        if domain == 'ai.onnx':
            # Both spellings designate the default ONNX domain.
            domain = ''
        _update_text(domain)
        m.update(len(obj.output).to_bytes(8, byteorder='big'))
        m.update(len(obj.input).to_bytes(8, byteorder='big'))
        for name in obj.input:
            _update_text(name)
        for att in getattr(obj, 'attribute', []):
            _update_text(att.name)
            # Fixed-size digest of the attribute content.
            m.update(_hash_obj_content(att))
    else:
        # An initializer or an attribute: its name and documentation
        # do not change the value it holds.
        obj = copy.deepcopy(obj)
        obj.name = ""
        obj.doc_string = ""
        m.update(obj.SerializeToString())

    return m.digest()[:max_size]
# Operators producing random values: two such nodes with identical inputs
# and attributes must not be merged, otherwise a single draw would be
# reused where the model expects independent draws.
_RANDOM_OPS = frozenset([
    'Bernoulli', 'Multinomial', 'RandomNormal', 'RandomNormalLike',
    'RandomUniform', 'RandomUniformLike'])


def _subgraph_references(nodes):
    """
    Collects every name used inside the subgraphs of *nodes* (attributes
    holding a graph or a list of graphs), recursively: inputs of the inner
    nodes and outputs of the subgraphs. Such names may refer implicitly to
    the outer scope. The renaming done in this module only rewrites the
    inputs of top-level nodes, so these names must not disappear.

    @param      nodes       list of nodes (``None`` entries are ignored)
    @return                 set of names
    """
    names = set()
    stack = [node for node in nodes if node is not None]
    while stack:
        node = stack.pop()
        for att in node.attribute:
            subgraphs = list(att.graphs)
            if att.HasField('g'):
                subgraphs.append(att.g)
            for sub in subgraphs:
                names.update(o.name for o in sub.output)
                for sub_node in sub.node:
                    names.update(sub_node.input)
                    stack.append(sub_node)
    return names


def _rename_nodes_inputs(nodes, rename):
    """
    Applies the renaming *rename* to the inputs of every node.

    @param      nodes       list of nodes, ``None`` entries are kept as is
    @param      rename      dictionary ``{old_name: new_name}``
    @return                 new list of nodes, set of indices of the
                            nodes whose inputs were renamed
    """
    keys = rename.keys()
    new_nodes = []
    modified = set()
    for i, node in enumerate(nodes):
        if node is not None and not keys.isdisjoint(node.input):
            node = _rename_node_input(node, rename)
            modified.add(i)
        new_nodes.append(node)
    return new_nodes, modified


def _remove_duplicated_initializers(graph, protected, max_hash_size):
    """
    Detects initializers holding the same content.

    @param      graph           GraphProto
    @param      protected       names which must neither be removed nor used
                                as a replacement (graph inputs may be
                                overridden by the user, graph outputs must
                                keep their name, names referenced inside
                                subgraphs cannot be renamed)
    @param      max_hash_size   see @see fn _hash_obj_content
    @return                     initializers to keep, dictionary mapping
                                every removed initializer name to the name
                                of the kept initializer
    """
    hashes = {}
    rename = {}
    kept = set()
    for init in graph.initializer:
        if init.name in protected:
            kept.add(init.name)
            continue
        hs = _hash_obj_content(init, max_size=max_hash_size)
        if hs in hashes:
            # Already seen: redirect to the first initializer.
            if hashes[hs] != init.name:
                rename[init.name] = hashes[hs]
        else:
            # New.
            hashes[hs] = init.name
            kept.add(init.name)
    new_inits = [init for init in graph.initializer if init.name in kept]
    return new_inits, rename


def _plan_node_merge(dropped, kept, graph_outputs, frozen):
    """
    Computes how to redirect the outputs of node *dropped* to the outputs
    of the identical node *kept*.

    @param      dropped         node to remove
    @param      kept            identical node to keep
    @param      graph_outputs   names of the graph outputs
    @param      frozen          names which cannot be renamed
    @return                     ``None`` if the merge is unsafe, otherwise
                                ``(renames, output_renames)`` where
                                *renames* maps every disappearing name to its
                                replacement and *output_renames* lists the
                                ``(old, new)`` renamings to apply to the
                                outputs of *kept*
    """
    renames = {}
    output_renames = []
    for old, new in zip(dropped.output, kept.output):
        if old == '':
            # Optional output not computed by the dropped node.
            continue
        if new == '':
            # The kept node does not compute this output.
            return None
        if old in graph_outputs:
            if new in graph_outputs:
                # Both names must survive; one node cannot produce both.
                return None
            # The graph output keeps its name, the kept node produces it.
            renames[new] = old
            output_renames.append((new, old))
        else:
            renames[old] = new
    if any(name in frozen for name in renames):
        return None
    return renames, output_renames


def _merge_duplicated_nodes(nodes, rename, graph_outputs, frozen,
                            max_hash_size):
    """
    Removes every node identical to another one (same type, domain,
    inputs, attributes and number of outputs, see
    @see fn _hash_obj_content) and redirects its consumers to the outputs
    of the node kept. The earliest node is always kept so that the node
    order stays topological. Iterates until no more merge happens since a
    merge may make other nodes identical.

    @param      nodes           list of nodes
    @param      rename          renaming dictionary, updated in place
    @param      graph_outputs   names of the graph outputs
    @param      frozen          names which cannot be renamed
    @param      max_hash_size   see @see fn _hash_obj_content
    @return                     new list of nodes, removed nodes are ``None``
    """
    node_hashes = {}
    removed = set()
    changed = True
    while changed:
        changed = False
        for i in range(len(nodes)):
            if i in removed or nodes[i].op_type in _RANDOM_OPS:
                continue
            node_hash = _hash_obj_content(nodes[i], max_size=max_hash_size)
            ni = node_hashes.get(node_hash)
            if ni is None:
                node_hashes[node_hash] = i
                continue
            if ni == i:
                continue

            keep, drop = min(i, ni), max(i, ni)
            plan = _plan_node_merge(
                nodes[drop], nodes[keep], graph_outputs, frozen)
            if plan is None:
                continue
            renames, output_renames = plan
            for old, new in output_renames:
                nodes[keep] = _rename_node_output(nodes[keep], old, new)
            rename.update(renames)
            nodes[drop] = None
            removed.add(drop)
            changed = True

            # Renames inputs; hashes of the modified nodes are stale.
            nodes, modified = _rename_nodes_inputs(nodes, rename)
            stale = [k for k, v in node_hashes.items()
                     if v in modified or v == drop]
            for k in stale:
                del node_hashes[k]
            if keep not in modified:
                node_hashes[node_hash] = keep
    return nodes


def onnx_remove_node_redundant(onnx_model, recursive=True, debug_info=None,
                               max_hash_size=1000, **options):
    """
    Removes redundant part of the graph. A redundant part is
    a set of nodes which takes the same inputs and produces
    the same outputs. It first starts by looking into duplicated
    initializers, then looks into nodes taking the same inputs
    and sharing the same type, domain and parameters.

    Initializers which are graph inputs, graph outputs or referenced inside
    subgraphs are left untouched. Nodes producing random values are never
    merged. Two nodes are not merged when this would remove a graph output,
    a name referenced inside a subgraph, or an output only one of them
    computes.

    @param      onnx_model      onnx model
    @param      recursive       looks into subgraphs
    @param      debug_info      debug information (private)
    @param      max_hash_size   limit the size of a hash used to detect
                                identical subgraphs (hash equality is
                                treated as equality, a small value increases
                                the risk of collision)
    @param      options         additional options (unused)
    @return                     new onnx _model
    """
    model_type = str(type(onnx_model)).rsplit(
        '.', maxsplit=1)[-1].strip("'>")
    debug_info = ([model_type] if debug_info is None
                  else debug_info + [model_type])

    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node_redundant, onnx_model,
            recursive=recursive, debug_info=debug_info,
            max_hash_size=max_hash_size, **options)

    graph = onnx_model
    is_function = isinstance(graph, FunctionProto)
    nodes = list(graph.node)
    if is_function:
        # FunctionProto stores plain names and has no initializer.
        graph_inputs = set(graph.input)
        graph_outputs = set(graph.output)
    else:
        graph_inputs = {i.name for i in graph.input}
        graph_outputs = {o.name for o in graph.output}
    frozen = _subgraph_references(nodes)

    # Detects duplicated initializers.
    if is_function:
        new_inits, rename = [], {}
    else:
        new_inits, rename = _remove_duplicated_initializers(
            graph, graph_inputs | graph_outputs | frozen, max_hash_size)

    # Renames node inputs.
    nodes, _ = _rename_nodes_inputs(nodes, rename)

    # Detects duplicated operators.
    nodes = _merge_duplicated_nodes(
        nodes, rename, graph_outputs, frozen, max_hash_size)

    if recursive:
        # Handles subgraphs.
        for i, node in enumerate(nodes):
            if node is None or not node.attribute:
                continue
            nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_redundant,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = [node for node in nodes if node is not None]
    if is_function:
        return make_function(
            onnx_model.domain, onnx_model.name,
            onnx_model.input, onnx_model.output, nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string)

    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       new_inits, doc_string=onnx_model.doc_string,
                       sparse_initializer=onnx_model.sparse_initializer)
    # Names renamed away no longer exist in the graph.
    graph.value_info.extend(  # pylint: disable=E1101
        vi for vi in onnx_model.value_info if vi.name not in rename)
    return graph