Coverage for mlprodict/onnx_tools/onnx_tools.py: 89%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Functions to manipulate ONNX file.
4"""
5from onnx import helper
def find_node_name(model, name):
    """
    Looks up a node of an ONNX model by its name.

    :param model: onnx model (any object exposing ``graph.node``)
    :param name: name of the node to look for
    :return: the first node whose ``name`` equals *name*,
        None if no node matches
    :raises TypeError: if *model* has no ``graph`` attribute
    """
    if hasattr(model, "graph"):
        # first match wins, None when the generator is exhausted
        return next((candidate for candidate in model.graph.node
                     if candidate.name == name), None)
    raise TypeError(  # pragma: no cover
        "Parameter model is not an ONNX model but "
        "{}".format(type(model)))
def find_node_input_name(node, name):
    """
    Finds a node input by its name.

    :param node: onnx node (its ``input`` attribute is a sequence
        of result names)
    :param name: input name
    :return: index of the first input named *name*, -1 if not found
    """
    # node.input is a repeated string field: iterate over the names
    # directly (the previous code used a non-existent `input.node`).
    for i, input_name in enumerate(node.input):
        if input_name == name:
            return i
    return -1
def insert_node(model, op_type, node, input_index=0, new_name=None, **attrs):
    """
    Inserts a node before one node input.

    :param model: onnx model (``ModelProto``)
    :param op_type: operator type of the inserted node
    :param node: node or node name
    :param input_index: input index or input name
    :param new_name: name of the result produced by the inserted node,
        ``op_type.lower()`` if None; a suffix ``_<i>`` is appended
        until the name is not already used in the graph
    :param attrs: node attributes
    :return: updated model (a new ``ModelProto``)
    :raises RuntimeError: if the node or the input cannot be found

    The target node is modified in place: its input *input_index*
    is renamed to the output of the inserted node.
    """
    if isinstance(node, str):
        inode = find_node_name(model, node)
        if inode is None:
            raise RuntimeError(  # pragma: no cover
                "Unable to find node %r in the graph." % node)
    else:
        inode = node
    if isinstance(input_index, str):
        # Resolve the input on `inode`, not on `node` which may be
        # a plain string.
        input_names = list(inode.input)
        if input_index not in input_names:
            raise RuntimeError(  # pragma: no cover
                "Unable to find input_index %r in node %r." % (
                    input_index, inode.name))
        input_index = input_names.index(input_index)

    # guess a new name which collides with no existing result,
    # including graph inputs, outputs and initializers
    names = set()
    for n in model.graph.node:
        names.update(n.input)
        names.update(n.output)
    names.update(i.name for i in model.graph.input)
    names.update(o.name for o in model.graph.output)
    names.update(init.name for init in model.graph.initializer)
    if new_name is None:
        new_name = op_type.lower()
    root_name = new_name
    i = 0
    while new_name in names:
        new_name = "%s_%d" % (root_name, i)
        i += 1

    new_node = helper.make_node(
        op_type, [inode.input[input_index]], [new_name], **attrs)
    # rewire the target node so that it consumes the new result
    inode.input[input_index] = new_name
    keep_nodes = list(model.graph.node)
    keep_nodes.append(new_node)
    keep_nodes = ensure_topological_order(
        model.graph.input, model.graph.initializer, keep_nodes)

    graph = helper.make_graph(
        keep_nodes, model.graph.name, model.graph.input,
        model.graph.output, model.graph.initializer)
    onnx_model = helper.make_model(graph)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:
        values = {p.key: p.value for p in model.metadata_props}
        helper.set_model_props(onnx_model, values)

    # copy the opsets of the original model instead of the defaults
    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    if len(onnx_model.graph.input) != len(model.graph.input):  # pylint: disable=E1101
        raise RuntimeError(  # pragma: no cover
            "Input mismatch {} != {}".format(
                len(onnx_model.graph.input),  # pylint: disable=E1101
                len(model.graph.input)))
    return onnx_model
def ensure_topological_order(inputs, initializers, nodes):
    """
    Ensures and modifies the order of nodes to have
    a topological order (every node in the list
    can only be an input for a node later in this list).
    The function raises an exception if a cycle is detected.

    :param inputs: graph inputs (objects with a ``name`` attribute)
    :param initializers: graph initializers (objects with a ``name``
        attribute)
    :param nodes: graph nodes (objects with ``input``, ``output``,
        ``name`` and ``op_type`` attributes)
    :return: list of ordered nodes, nodes at the same depth keep
        their relative order from *nodes*
    :raises RuntimeError: if an output is produced twice (cycle) or
        if some node inputs can never be resolved (disconnected graph)

    Empty names are ignored: ONNX uses them for omitted optional
    inputs and outputs.
    """
    # maps a result name or id(node) to its depth in the graph
    order = {}
    for inp in inputs:
        order[inp.name] = 0
    for inp in initializers:
        order[inp.name] = 0

    n_iter = 0
    missing_names = set()
    missing_ops = []  # defined even when `nodes` is empty
    while n_iter < len(nodes) * 2:
        n_iter += 1
        missing_names = set()
        missing_ops = []
        progress = False
        for node in nodes:
            maxi = 0
            for name in node.input:
                if not name:
                    # omitted optional input
                    continue
                if name in order:
                    maxi = max(maxi, order[name])
                else:
                    maxi = None
                    missing_names.add(name)
                    break
            if maxi is None:
                missing_ops.append(node)
                continue
            key = id(node)
            if key in order:
                continue
            progress = True
            maxi += 1
            order[key] = maxi
            maxi += 1
            for name in node.output:
                if not name:
                    # omitted optional output
                    continue
                if name in order:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to sort a node (cycle). An output was "
                        "already ordered %r (iteration=%r)." % (
                            name, n_iter))
                order[name] = maxi
        if not missing_names:
            # every node is ordered, another pass is useless
            break
        if not progress:
            # nothing new was ordered: the next pass would be identical
            break

    if len(missing_ops) > 0:  # pragma: no cover
        def nstr(name):
            if name in order:
                return "%s#%d" % (name, order[name])
            return name
        rows = ["%s(%s) -> [%s]" % (
            n.name or n.op_type,
            ', '.join(map(nstr, n.input)),
            ', '.join(n.output))
            for n in missing_ops]
        rows.insert(0, "")
        rows.append("--")
        rows.append("--all-nodes--")
        rows.append("--")
        rows.extend("%s(%s) -> [%s]" % (
            n.name or n.op_type,
            ', '.join(map(nstr, n.input)),
            ', '.join(n.output))
            for n in nodes)
        raise RuntimeError(
            "After %d iterations for %d nodes, still unable "
            "to sort names %r. The graph may be disconnected. "
            "List of operators: %s" % (
                n_iter, len(nodes), missing_names,
                "\n".join(rows)))

    # Stable sort by depth: ties keep their original relative order
    # (the former tie-break on str(id(node)) depended on memory layout).
    return sorted(nodes, key=lambda node: order[id(node)])