Coverage for mlprodict/onnx_tools/optim/_onnx_optimisation_common.py: 93%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Common functions to reduce the number of
4nodes of an :epkg:`ONNX` graph.
5"""
6from onnx.helper import make_graph, ValueInfoProto, make_model
7from onnx import AttributeProto, NodeProto
8from onnx.helper import make_attribute
11def _apply_optimisation_on_graph(fct, onnx_model, recursive=True, debug_info=None,
12 **kwargs):
13 """
14 Applies an optimisation function *fct* on a graph
15 and not on the model.
17 @param fct function to optimize like
18 @see fn onnx_remove_node_identity
19 @param onnx_model onnx model
20 @param recursive looks into subgraphs
21 @param debug_info debug information (private)
22 @param kwargs additional parameters
23 @return new onnx _model
24 """
25 if hasattr(onnx_model, 'graph'):
26 if debug_info is None:
27 debug_info = []
28 graph = fct(
29 onnx_model.graph, debug_info=debug_info + ['GRAPH'],
30 **kwargs)
31 new_model = make_model(graph, functions=onnx_model.functions)
32 new_model.ir_version = onnx_model.ir_version
33 new_model.producer_name = onnx_model.producer_name
34 new_model.producer_version = onnx_model.producer_version
35 new_model.domain = onnx_model.domain
36 new_model.model_version = onnx_model.model_version
37 new_model.doc_string = onnx_model.doc_string
38 if hasattr(onnx_model, 'value_info'):
39 graph.value_info.extend(onnx_model.value_info) # pragma: no cover
40 while len(new_model.opset_import) > 0: # pylint: disable=E1101
41 new_model.opset_import.pop() # pylint: disable=E1101
42 for oimp in onnx_model.opset_import:
43 op_set = new_model.opset_import.add() # pylint: disable=E1101
44 op_set.domain = oimp.domain
45 op_set.version = oimp.version
46 return new_model
47 raise TypeError( # pragma: no cover
48 "This function only works on 'ModelProto' anod not not on"
49 " {}.".format(type(onnx_model)))
52def _apply_remove_node_fct_node(fct, node, recursive, debug_info):
53 """
54 Applies an optimizing function on a subgraphs.
56 @param node onnx node
57 @param recursive does it in subgraphs as well
58 @return new node
59 """
60 if not hasattr(node, 'attribute'):
61 return node # pragma: no cover
62 modified = 0
63 new_atts = []
64 for att in node.attribute:
65 if att.name == 'body':
66 new_body = fct(
67 att.g, recursive=recursive,
68 debug_info=debug_info + [att.name])
69 new_atts.append(_make_att_graph(att.name, new_body))
70 modified += 1
71 else:
72 new_atts.append(att)
73 if modified > 0:
74 new_node = _make_node(node.op_type, node.input,
75 node.output, name=node.name,
76 attributes=new_atts)
77 return new_node
78 return node
def _make_node(op_type, inputs, outputs, name=None, doc_string=None,
               domain=None, attributes=None):
    """
    Builds a NodeProto.

    :param op_type: name of the operator
    :param inputs: names of the inputs
    :param outputs: names of the outputs
    :param name: node name, only stored when not empty
    :param doc_string: documentation, only stored when not empty
    :param domain: operator domain, only stored when not None
        (the default domain is the empty string)
    :param attributes: either a dictionary ``{name: value}`` whose items
        are converted with :epkg:`make_attribute` in sorted key order,
        or an iterable of already built attributes appended as they are
    :return: node
    """
    new_node = NodeProto()
    new_node.op_type = op_type
    new_node.input.extend(inputs)  # pylint: disable=E1101
    new_node.output.extend(outputs)  # pylint: disable=E1101
    if name:
        new_node.name = name
    if doc_string:
        new_node.doc_string = doc_string  # pragma: no cover
    if domain is not None:
        new_node.domain = domain
    if not isinstance(attributes, dict):
        if attributes:
            new_node.attribute.extend(attributes)  # pylint: disable=E1101
        return new_node
    if attributes:  # pragma: no cover
        new_node.attribute.extend(  # pylint: disable=E1101
            make_attribute(key, attributes[key])
            for key in sorted(attributes))
    return new_node
119def _replace(name, old_name, new_name):
120 if isinstance(old_name, dict) and new_name is None:
121 return old_name.get(name, name)
122 if name == old_name:
123 return new_name
124 return name
def _rename_node_input(onnx_node, old_name, new_name=None):
    """
    Renames an input of a node and returns a new node.
    If the node holds a subgraph in attribute ``body``, the subgraph
    is modified by *_rename_graph_input*.

    @param      onnx_node       onnx_node
    @param      old_name        old name, or a dictionary ``{old: new}``
    @param      new_name        new name or None if *old_name* is a dictionary
    @return                     new node

    NOTE(review): when *old_name* is a dictionary, the ``body`` branch still
    hands the dictionary and None to *_rename_graph_input*, which then builds
    an Identity node from these non-string names; this presumably fails,
    confirm no caller renames with a dictionary a node holding a ``body``.
    """
    renamed_inputs = [_replace(inp, old_name, new_name)
                      for inp in onnx_node.input]
    kept_outputs = list(onnx_node.output)
    atts = None
    if hasattr(onnx_node, 'attribute'):
        atts = [
            _make_att_graph(
                att.name, _rename_graph_input(att.g, old_name, new_name))
            if att.name == 'body' else att
            for att in onnx_node.attribute]
    return _make_node(
        onnx_node.op_type, renamed_inputs, kept_outputs,
        name=onnx_node.name, domain=onnx_node.domain, attributes=atts)
def _copy_value_info_proto(new_name, obj):
    """
    Copies a ValueInfoProto under a new name.

    @param      new_name    name of the copy
    @param      obj         ValueInfoProto to copy (its type and doc_string)
    @return                 new ValueInfoProto
    """
    value_info = ValueInfoProto()
    value_info.name = new_name
    value_info.type.CopyFrom(obj.type)  # pylint: disable=E1101
    # doc_string is a field of ValueInfoProto; TypeProto has no such
    # field, so reading obj.type.doc_string raises AttributeError.
    if obj.doc_string:
        value_info.doc_string = obj.doc_string
    return value_info
def _rename_graph_output(graph, old_name, new_name):
    """
    Renames an output and adds an *Identity* node
    to connect the dots.

    @param      graph       ONNX graph
    @param      old_name    current name of the output
    @param      new_name    new name of the output
    @return                 modified graph (a new GraphProto)
    """
    outputs = [
        _copy_value_info_proto(new_name, o) if o.name == old_name else o
        for o in graph.output]
    nodes = list(graph.node)
    # The result is still produced under *old_name*,
    # the Identity node exposes it under *new_name*.
    nodes.append(_make_node('Identity', [old_name], [new_name]))
    # doc_string and sparse initializers are not copied by make_graph
    # unless they are given, they would be lost otherwise.
    return make_graph(
        nodes, graph.name, graph.input, outputs, graph.initializer,
        doc_string=graph.doc_string, value_info=graph.value_info,
        sparse_initializer=graph.sparse_initializer)
def _rename_graph_input(graph, old_name, new_name):
    """
    Renames an input and adds an *Identity* node
    to connect the dots.

    @param      graph       ONNX graph
    @param      old_name    current name of the input
    @param      new_name    new name of the input
    @return                 modified graph (a new GraphProto)
    """
    inputs = [
        _copy_value_info_proto(new_name, i) if i.name == old_name else i
        for i in graph.input]
    nodes = list(graph.node)
    # The nodes of the graph still consume *old_name*,
    # the Identity node produces it from *new_name*.
    nodes.append(_make_node('Identity', [new_name], [old_name]))
    # doc_string and sparse initializers are not copied by make_graph
    # unless they are given, they would be lost otherwise.
    return make_graph(
        nodes, graph.name, inputs, graph.output, graph.initializer,
        doc_string=graph.doc_string, value_info=graph.value_info,
        sparse_initializer=graph.sparse_initializer)
def _make_att_graph(name, new_body):
    """
    Wraps a graph into an attribute of type GRAPH.

    @param      name        attribute name
    @param      new_body    GraphProto copied into the attribute
    @return                 AttributeProto
    """
    graph_att = AttributeProto()
    graph_att.type = AttributeProto.GRAPH  # pylint: disable=E1101
    graph_att.name = name
    graph_att.g.CopyFrom(new_body)  # pylint: disable=E1101
    return graph_att
def _rename_node_output(onnx_node, old_name, new_name):
    """
    Renames an output of a node and returns a new node.
    If the node holds a subgraph in attribute ``body``, the subgraph
    is modified by *_rename_graph_output*.

    @param      onnx_node       onnx_node
    @param      old_name        old name
    @param      new_name        new name
    @return                     new node
    """
    kept_inputs = list(onnx_node.input)
    renamed_outputs = [_replace(out, old_name, new_name)
                       for out in onnx_node.output]
    if not hasattr(onnx_node, 'attribute'):
        atts = None  # pragma: no cover
    else:
        atts = []
        for att in onnx_node.attribute:
            if att.name != 'body':
                atts.append(att)
                continue
            renamed_body = _rename_graph_output(att.g, old_name, new_name)
            atts.append(_make_att_graph(att.name, renamed_body))
    return _make_node(
        onnx_node.op_type, kept_inputs, renamed_outputs,
        name=onnx_node.name, domain=onnx_node.domain, attributes=atts)