Coverage for onnxcustom/utils/onnx_rewriter.py: 91%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Rewrites operator in ONNX graph.
4"""
5from onnx.helper import (
6 make_graph, make_node, make_tensor_value_info, make_model)
7from onnx import NodeProto
8from onnx.numpy_helper import to_array, from_array
11def _unique_name(existing_names, name):
12 """
13 Returns a name different from any name in *existing_names*.
15 :param existing_names: set of names
16 :param name: current
17 :return: unique name
18 """
19 if name not in existing_names:
20 existing_names.add(name)
21 return name
22 name0 = name
23 i = 2
24 while name in existing_names:
25 name = "%s_%d" % (name0, i)
26 i += 1
27 existing_names.add(name)
28 return name
31def _existing_names(onx):
32 """
33 Makes the list of existing names.
34 Returns a set of unique names including
35 intermediate results.
36 """
37 existing_names = set()
38 graph = onx.graph if hasattr(onx, 'graph') else onx
39 for node in graph.node:
40 existing_names.update(node.input)
41 existing_names.update(node.output)
42 return existing_names
45def _onnx_rewrite_operator_node(existing_names, node, sub_onx):
46 """
47 Replaces a node by a subgraph.
49 :param existing_names: existing results names
50 :param node: onnx node to replace
51 :param sub_onx: onnx sub_graph to use as a replacement
52 :return: new_initializer, new_nodes
53 """
54 if len(node.input) != len(sub_onx.graph.input):
55 raise ValueError( # pragma: no cover
56 "Mismatch with the number of inputs for operator type %r. "
57 "%d != %d." % (
58 node.op_type, len(node.input), len(sub_onx.graph.nput)))
59 if len(node.output) != len(sub_onx.graph.output):
60 raise ValueError( # pragma: no cover
61 "Mismatch with the number of outputs for operator type %r. "
62 "%d != %d." % (
63 node.op_type, len(node.output), len(sub_onx.graph.output)))
64 replaces = {}
65 for inp, name in zip(sub_onx.graph.input, node.input):
66 replaces[inp.name] = name
67 for inp, name in zip(sub_onx.graph.output, node.output):
68 replaces[inp.name] = name
70 new_inits = []
71 for init in sub_onx.graph.initializer:
72 name = _unique_name(existing_names, init.name)
73 replaces[init.name] = name
74 tensor = from_array(to_array(init), name=name)
75 new_inits.append(tensor)
77 new_nodes = []
78 for n in sub_onx.graph.node:
79 new_node = NodeProto()
80 new_node.op_type = n.op_type
81 new_node.attribute.extend(n.attribute) # pylint: disable=E1101
82 new_node.input.extend( # pylint: disable=E1101
83 [replaces[i] for i in n.input]) # pylint: disable=E1101
84 new_node.domain = n.domain
85 new_out = []
86 for o in n.output:
87 if o in replaces:
88 new_out.append(replaces[o])
89 else:
90 n = _unique_name(existing_names, o)
91 new_out.append(n)
92 new_node.output.extend(new_out) # pylint: disable=E1101
93 new_nodes.append(new_node)
95 return new_inits, new_nodes
def onnx_rewrite_operator(onx, op_type, sub_onx, recursive=True, debug_info=None):
    """
    Replaces one operator by an onnx graph.

    Every node of type *op_type* is replaced by a copy of the nodes
    and initializers of *sub_onx*. If *recursive* is True, the
    subgraphs held by the nodes which are kept are rewritten as well.

    :param onx: onnx graph
    :param op_type: operator type
    :param sub_onx: onnx graph
    :param recursive: looks into subgraphs
    :param debug_info: not used by this function, only forwarded
        to the helpers handling subgraphs
    :return: modified onnx graph

    .. runpython::
        :showcode:

        import numpy
        from skl2onnx.common.data_types import FloatTensorType
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxReciprocal, OnnxDiv)
        from mlprodict.plotting.text_plot import onnx_simple_text_plot
        from onnxcustom import get_max_opset
        from onnxcustom.utils.onnx_rewriter import onnx_rewrite_operator

        # first graph: it contains the node to replace
        opset = get_max_opset()
        node1 = OnnxReciprocal('X', output_names=['Y'],
                               op_version=opset)
        onx1 = node1.to_onnx(
            inputs={'X': FloatTensorType()},
            outputs={'Y': FloatTensorType()},
            target_opset=opset)

        # second graph: it contains the replacement graph
        node2 = OnnxDiv(numpy.array([1], dtype=numpy.float32),
                        'X', output_names=['Y'],
                        op_version=opset)
        onx2 = node2.to_onnx(
            inputs={'X': FloatTensorType()},
            outputs={'Y': FloatTensorType()},
            target_opset=opset)

        # third graph: the modified graph
        onx3 = onnx_rewrite_operator(onx1, 'Reciprocal', onx2)
        print(onnx_simple_text_plot(onx3))
    """
    from mlprodict.onnx_tools.optim._onnx_optimisation_common import (  # pylint: disable=C0415
        _apply_remove_node_fct_node, _apply_optimisation_on_graph)

    def _rewrite(graph, recursive=False, debug_info=None):
        # Callback handed to the mlprodict helpers: they only provide
        # the graph, *op_type* and *sub_onx* are bound here.
        return onnx_rewrite_operator(
            graph, op_type, sub_onx, recursive=recursive,
            debug_info=debug_info)

    if hasattr(onx, 'graph'):
        # *onx* is a model, the helper applies the rewriting on its graph.
        return _apply_optimisation_on_graph(_rewrite, onx, recursive=recursive)

    # NOTE(review): a list is forwarded instead of None, presumably the
    # helper extends it to build a trail - confirm against mlprodict.
    sub_debug_info = [] if debug_info is None else list(debug_info)

    existing_names = _existing_names(onx)
    new_nodes = []
    new_inits = list(onx.initializer)
    for node in onx.node:
        if node.op_type == op_type:
            inits, newn = _onnx_rewrite_operator_node(
                existing_names, node, sub_onx)
            new_inits.extend(inits)
            new_nodes.extend(newn)
            continue
        if recursive and node.attribute:
            # Handles subgraphs of the nodes which are kept. Nodes coming
            # from *sub_onx* are not visited again, the replacement
            # is applied only once.
            node = _apply_remove_node_fct_node(
                _rewrite, node, recursive=True,
                debug_info=sub_debug_info)
        new_nodes.append(node)

    return make_graph(
        new_nodes, onx.name, onx.input, onx.output, new_inits)
def unreduced_onnx_loss(onx, output_name='score'):
    """
    Builds a model returning the loss of every observation.

    A loss function ends with a node reducing the individual losses
    into a single value. This function looks for that reducing node
    (operator type starting with ``Reduce``), cuts the graph right before
    it and exposes the unreduced result as the only output.
    Only ``ReduceSumSquare`` and ``ReduceSum`` are supported.

    :param onx: onx graph
    :param output_name: new output name, a suffix is added
        if the name is already used
    :return: new onx graph
    :raises RuntimeError: if the graph does not contain exactly one
        reducing node or if its type is not supported
    """
    from mlprodict.onnx_tools.onnx_manipulations import (  # pylint: disable=C0415
        select_model_inputs_outputs)

    original_graph = onx.graph
    reducers = [
        n for n in original_graph.node if n.op_type.startswith('Reduce')]
    if len(reducers) != 1:
        raise RuntimeError(  # pragma: no cover
            "Unable to find one unique Reducing node but found %d - %r."
            "" % (len(reducers), [(n.op_type, n.name) for n in reducers]))
    reducer = reducers[0]
    input_name = reducer.input[0]

    # Keeps only what is needed to compute the input of the reducing node.
    truncated = select_model_inputs_outputs(
        onx, outputs=[input_name], infer_shapes=True)
    truncated_graph = truncated.graph

    new_name = _unique_name(_existing_names(truncated), output_name)
    new_nodes = list(truncated_graph.node)  # pylint: disable=E1101
    elem = original_graph.output[0].type.tensor_type.elem_type
    new_output = [make_tensor_value_info(new_name, elem, [None, 1])]

    # The reduction is replaced by its element-wise counterpart.
    if reducer.op_type == "ReduceSumSquare":
        new_nodes.append(
            make_node('Mul', [input_name, input_name], [new_name]))
    elif reducer.op_type == 'ReduceSum':
        new_nodes.append(make_node('Identity', [input_name], [new_name]))
    else:
        raise RuntimeError(  # pragma: no cover
            "Unable to unreduce node %r." % reducer.op_type)

    new_graph = make_graph(
        new_nodes, original_graph.name,
        truncated_graph.input,  # pylint: disable=E1101
        new_output, truncated_graph.initializer)
    new_model = make_model(new_graph)

    # Copies the metadata of the original model.
    for field in ('ir_version', 'producer_name', 'producer_version',
                  'domain', 'model_version', 'doc_string'):
        setattr(new_model, field, getattr(onx, field))

    # NOTE(review): this branch looks like dead code when *onx* is a
    # model, since value_info is read from the model rather than from
    # its graph, and new_graph is updated after make_model was called.
    # Confirm the intent.
    if hasattr(onx, 'value_info'):
        new_graph.value_info.extend(onx.value_info)  # pylint: disable=E1101

    # The opsets must be the same as the original model.
    del new_model.opset_import[:]  # pylint: disable=E1101
    for original_opset in onx.opset_import:
        copied = new_model.opset_import.add()  # pylint: disable=E1101
        copied.domain = original_opset.domain
        copied.version = original_opset.version
    return new_model