Coverage for mlprodict/onnx_tools/optim/onnx_optimisation_identity.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Optimisation of :epkg:`ONNX` graphs.
4"""
5import logging
6from onnx import FunctionProto
7from onnx.helper import make_graph, make_function
8from ._onnx_optimisation_common import ( # pylint: disable=E0611
9 _rename_node_input,
10 _rename_node_output,
11 _apply_optimisation_on_graph,
12 _apply_remove_node_fct_node)
# Logger receiving the debug traces emitted while nodes are renamed
# by the optimisation implemented in this module.
logger = logging.getLogger(name='onnx:optim')
def onnx_remove_node_identity(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes as many *Identity* nodes as possible.
    The function looks into every node and, if *recursive* is True,
    into every subgraph. An identity node is removed and every other
    node gets its inputs or outputs renamed accordingly, unless the
    identity output is an output of the graph and its input is
    a graph input, an initializer, another graph output or a name
    no node of the graph produces (for example a name coming from an
    outer scope when a subgraph is processed).

    :param onnx_model: onnx model: an object with a ``graph``
        attribute (the optimisation is then delegated to
        ``_apply_optimisation_on_graph``), a *FunctionProto*
        or a graph
    :param recursive: looks into subgraphs
    :param debug_info: debug information (private), a list of
        names extended at every recursion level
    :param options: additional options (unused)
    :return: new onnx _model
    """
    type_name = str(type(onnx_model)).rsplit('.', maxsplit=1)[-1].strip("'>")
    # A new list is built so the caller's list is never mutated.
    debug_info = [type_name] if debug_info is None else debug_info + [type_name]

    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node_identity, onnx_model,
            recursive=recursive, debug_info=debug_info, **options)

    graph = onnx_model
    is_function = isinstance(graph, FunctionProto)

    if is_function:
        inputs = set(graph.input)
        outputs = set(graph.output)
        # A function has no initializers: its inputs are the only
        # names that cannot be renamed on the producer side.
        inputs_inits = inputs
    else:
        inputs = set(i.name for i in graph.input)
        inits = set(i.name for i in graph.initializer)
        inputs_inits = inputs.union(inits)
        outputs = set(o.name for o in graph.output)

    def retrieve_idnodes(existing_nodes):
        # (position, node, input name, output name) of every remaining
        # Identity node.
        return [(i, node, node.input[0], node.output[0])
                for i, node in enumerate(existing_nodes)
                if node is not None and node.op_type == 'Identity']

    nodes = list(graph.node)
    # *rem* counts the modifications of one pass, the loop stops
    # once a pass changes nothing.
    rem = 1
    while rem > 0:
        rem = 0
        restart = False
        for i, _, inp, out in retrieve_idnodes(nodes):
            if restart:
                # Another Identity node was renamed, the cached
                # (inp, out) pairs may be stale: start a new pass.
                break  # pragma: no cover
            if nodes[i] is None:
                # Already removed.
                continue  # pragma: no cover
            if inp in inputs_inits and out in outputs:
                # Cannot be removed.
                continue

            if out not in outputs:
                # *out* is not a graph output so it can disappear:
                # every consumer of *out* reads *inp* instead.
                for j, node in enumerate(nodes):
                    if node is None or out not in node.input:
                        continue
                    logger.debug('onnx_remove_node_identity:'
                                 '_rename_node_input:%s:%r->%r:'
                                 'out=%r:inp=%r',
                                 node.op_type, node.input,
                                 node.output, out, inp)
                    nodes[j] = _rename_node_input(node, out, inp)
                    rem += 1
                    if nodes[j].op_type == 'Identity':
                        restart = True  # pragma: no cover
                nodes[i] = None
                rem += 1
                continue

            if inp in inputs_inits or inp in outputs:
                # We cannot change an input name or an output name.
                continue

            if not any(node is not None and inp in node.output
                       for node in nodes):
                # No node of this graph produces *inp* (it likely comes
                # from an outer scope): removing the identity would leave
                # the graph output *out* without any producer.
                continue

            # The producer of *inp* now produces *out* directly and every
            # consumer of *inp* reads *out*.
            for j, node in enumerate(nodes):
                if node is None:
                    continue
                if inp in node.output:
                    logger.debug('onnx_remove_node_identity:'
                                 '_rename_node_output:%s:%r->%r:'
                                 'inp=%r:out=%r',
                                 node.op_type, node.input,
                                 node.output, inp, out)
                    node = _rename_node_output(node, inp, out)
                    rem += 1
                    if node.op_type == 'Identity':
                        restart = True  # pragma: no cover
                if inp in node.input:
                    logger.debug('onnx_remove_node_identity:'
                                 '_rename_node_input:%s:%r->%r:'
                                 'inp=%r:out=%r',
                                 node.op_type, node.input,
                                 node.output, inp, out)
                    node = _rename_node_input(node, inp, out)
                    rem += 1
                    if node.op_type == 'Identity':
                        restart = True
                nodes[j] = node
            nodes[i] = None
            rem += 1

    if recursive:
        # Handles subgraphs held by node attributes.
        for i, node in enumerate(nodes):
            if node is None or not node.attribute:
                continue
            nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_identity,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = [node for node in nodes if node is not None]
    if is_function:
        return make_function(
            onnx_model.domain, onnx_model.name,
            onnx_model.input, onnx_model.output, nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string)

    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       onnx_model.initializer,
                       doc_string=onnx_model.doc_string)

    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    return graph