Coverage for mlprodict/onnx_tools/optim/onnx_optimisation_identity.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

82 statements  

1""" 

2@file 

3@brief Optimisation of :epkg:`ONNX` graphs. 

4""" 

5import logging 

6from onnx import FunctionProto 

7from onnx.helper import make_graph, make_function 

8from ._onnx_optimisation_common import ( # pylint: disable=E0611 

9 _rename_node_input, 

10 _rename_node_output, 

11 _apply_optimisation_on_graph, 

12 _apply_remove_node_fct_node) 

13 

14 

# Module logger; onnx_remove_node_identity reports every input/output
# renaming it performs through ``logger.debug``.
logger = logging.getLogger(name='onnx:optim')

16 

17 

def onnx_remove_node_identity(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes as many *Identity* nodes as possible.
    The function looks into every node, and into subgraphs if
    *recursive* is True, for identity nodes. An identity node is
    kept only when removing it would require renaming a protected
    name, that is when its input is a graph input or initializer
    (a function input for a :class:`FunctionProto`) and its output
    is a graph output, or when both its input and its output are
    graph outputs. Every other identity node is removed and the
    other nodes get their inputs or outputs renamed accordingly.

    :param onnx_model: onnx model: an object exposing ``graph``
        (such as a model), a graph, or a :class:`FunctionProto`
    :param recursive: looks into subgraphs
    :param debug_info: debug information (private), a list of
        names extended with the type of *onnx_model*
    :param options: additional options (unused here, forwarded
        when *onnx_model* has a ``graph`` attribute)
    :return: new onnx _model (a new function for a FunctionProto,
        a new graph for a graph)
    """
    # Short type name, e.g. "'onnx.onnx_ml_pb2.GraphProto'>" -> "GraphProto".
    type_name = str(type(onnx_model)).rsplit('.', maxsplit=1)[-1].strip("'>")
    debug_info = [type_name] if debug_info is None else debug_info + [type_name]

    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node_identity, onnx_model,
            recursive=recursive, debug_info=debug_info, **options)

    graph = onnx_model
    is_function = isinstance(graph, FunctionProto)

    # Names which must not be renamed: *inputs_inits* cannot be produced
    # by a node, *outputs* are visible to the caller.
    if is_function:
        # A function lists its input/output names directly. It was
        # previously left without *inputs_inits*, which raised NameError
        # as soon as the function contained an Identity node.
        inputs_inits = set(graph.input)
        outputs = set(graph.output)
    else:
        inputs_inits = {i.name for i in graph.input}
        inputs_inits.update(i.name for i in graph.initializer)
        outputs = {o.name for o in graph.output}

    def retrieve_idnodes(existing_nodes):
        # (index, input name, output name) of every remaining Identity node.
        return [(i, node.input[0], node.output[0])
                for i, node in enumerate(existing_nodes)
                if node is not None and node.op_type == 'Identity']

    nodes = list(graph.node)
    rem = 1
    while rem > 0:
        rem = 0
        # The tuples returned by retrieve_idnodes become stale as soon as
        # another Identity node is renamed; *restart* then forces a new scan.
        restart = False
        for i, inp, out in retrieve_idnodes(nodes):
            if restart:
                break  # pragma: no cover
            if nodes[i] is None:
                # Already removed.
                continue  # pragma: no cover
            if inp in inputs_inits and out in outputs:
                # Cannot be removed.
                continue
            if out not in outputs:
                # *out* is an intermediate result (an output name cannot
                # change): every consumer of *out* reads *inp* instead.
                for j, node in enumerate(nodes):
                    if node is None or out not in node.input:
                        continue
                    logger.debug('onnx_remove_node_identity:'
                                 '_rename_node_input:%s:%r->%r:'
                                 'out=%r:inp=%r',
                                 node.op_type, node.input,
                                 node.output, out, inp)
                    nodes[j] = _rename_node_input(node, out, inp)
                    rem += 1
                    if nodes[j].op_type == 'Identity':
                        restart = True  # pragma: no cover
                nodes[i] = None
                rem += 1
                continue
            if inp not in inputs_inits and inp not in outputs:
                # *out* is an output and *inp* an intermediate result (an
                # input name cannot change either): the producer of *inp*
                # writes *out* directly and every consumer of *inp* reads *out*.
                for j, node in enumerate(nodes):
                    if node is None or j == i:
                        # The identity node itself is dropped below; renaming
                        # it would only trigger a useless restart.
                        continue
                    if inp in node.output:
                        logger.debug('onnx_remove_node_identity:'
                                     '_rename_node_output:%s:%r->%r:'
                                     'inp=%r:out=%r',
                                     node.op_type, node.input,
                                     node.output, inp, out)
                        node = _rename_node_output(node, inp, out)
                        nodes[j] = node
                        rem += 1
                        if node.op_type == 'Identity':
                            restart = True  # pragma: no cover
                    if inp in node.input:
                        logger.debug('onnx_remove_node_identity:'
                                     '_rename_node_input:%s:%r->%r:'
                                     'inp=%r:out=%r',
                                     node.op_type, node.input,
                                     node.output, inp, out)
                        node = _rename_node_input(node, inp, out)
                        nodes[j] = node
                        rem += 1
                        if node.op_type == 'Identity':
                            restart = True
                nodes[i] = None
                rem += 1

    if recursive:
        # Handles subgraphs (nodes carrying attributes, e.g. If or Loop bodies).
        # NOTE(review): names a subgraph reads implicitly from this scope are
        # only updated if _rename_node_input handles it; confirm.
        for i, node in enumerate(nodes):
            if node is None or not node.attribute:
                continue
            nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_identity,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = [node for node in nodes if node is not None]
    if is_function:
        return make_function(
            onnx_model.domain, onnx_model.name,
            onnx_model.input, onnx_model.output, nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string)

    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       onnx_model.initializer)

    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    return graph