Coverage for mlprodict/onnx_tools/optim/onnx_optimisation_redundant.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

105 statements  

1""" 

2@file 

3@brief Optimisation of :epkg:`ONNX` graphs. 

4""" 

5import copy 

6import hashlib 

7from onnx import FunctionProto 

8from onnx.helper import make_graph, make_function 

9from ._onnx_optimisation_common import ( # pylint: disable=E0611 

10 _rename_node_input, 

11 _rename_node_output, 

12 _apply_optimisation_on_graph, 

13 _apply_remove_node_fct_node) 

14 

15 

16def _hash_obj_content(obj, max_size=1000): 

17 """ 

18 Hash the content of an object. 

19 """ 

20 m = hashlib.sha256() 

21 if hasattr(obj, 'op_type'): 

22 # An operator. 

23 m.update(obj.op_type.encode('ascii')) 

24 m.update(len(obj.output).to_bytes(8, byteorder='big')) 

25 for i in obj.input: 

26 m.update(i.encode('ascii')) 

27 if hasattr(obj, 'attribute'): 

28 for att in obj.attribute: 

29 m.update(att.name.encode('ascii')) 

30 m.update(_hash_obj_content(att)) 

31 else: 

32 # An initializer. 

33 obj = copy.deepcopy(obj) 

34 obj.name = "" 

35 obj.doc_string = "" 

36 m.update(obj.SerializeToString()) 

37 

38 content = m.digest() 

39 if len(content) > max_size: 

40 content = content[:max_size] 

41 return content 

42 

43 

def onnx_remove_node_redundant(onnx_model, recursive=True, debug_info=None,
                               max_hash_size=1000, **options):
    """
    Removes redundant part of the graph. A redundant part is
    a set of nodes which takes the same inputs and produces
    the same outputs. It first starts by looking into duplicated
    initializers, then looks into nodes taking the same inputs
    and sharing the same type and parameters.

    @param onnx_model onnx model
    @param recursive looks into subgraphs
    @param debug_info debug information (private)
    @param max_hash_size limit the size of a hash used to detect
                            identical subgraphs
    @param options additional options (unused)
    @return new onnx _model
    """
    if debug_info is None:
        debug_info = [str(type(onnx_model)).rsplit(
            '.', maxsplit=1)[-1].strip("'>")]
    else:
        debug_info = (debug_info +
                      [str(type(onnx_model)).rsplit('.', maxsplit=1)[-1].strip("'>")])

    # A full model: optimise its main graph and rebuild the model around it.
    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node_redundant, onnx_model,
            recursive=recursive, debug_info=debug_info,
            max_hash_size=max_hash_size, **options)

    def _enumerate_rename_list_nodes_inputs(nodes, rename):
        # Yields (renamed, index, node) for every node; nodes whose inputs
        # intersect *rename* are rewritten, removed nodes stay None.
        for i, node in enumerate(nodes):
            if node is None:
                yield False, i, None
                continue
            if any(set(node.input) & set(rename)):
                yield True, i, _rename_node_input(node, rename)
                continue
            yield False, i, node

    graph = onnx_model
    is_function = isinstance(graph, FunctionProto)

    # Detects duplicated initializers (functions have none).
    hashes = {}
    names = []
    rename = {}
    if is_function:
        new_inits = []
    else:
        for init in graph.initializer:
            hs = _hash_obj_content(init, max_size=max_hash_size)
            if hs in hashes:
                # Already seen: map this name onto the first occurrence.
                rename[init.name] = hashes[hs]  # pragma: no cover
            else:
                # New.
                hashes[hs] = init.name
                names.append(init.name)
        new_inits = [init for init in graph.initializer
                     if init.name in set(names)]

    # Renames node inputs that referenced a duplicated initializer.
    new_nodes = list(graph.node)
    new_nodes = list(
        _[2] for _ in _enumerate_rename_list_nodes_inputs(new_nodes, rename))

    # Detects duplicated operators: iterate to a fixpoint because removing
    # one duplicate (and renaming its outputs) can reveal new duplicates.
    if is_function:
        graph_outputs = set(graph.output)
    else:
        graph_outputs = set(o.name for o in graph.output)
    node_hashes = {}
    changed = 1
    replace = {}
    while changed > 0:
        changed = 0
        nnodes = len(new_nodes)
        for i in range(nnodes):
            if i in replace:
                # Already removed.
                continue
            node = new_nodes[i]
            # `hsh` instead of `hash`: avoid shadowing the builtin.
            hsh = _hash_obj_content(node, max_size=max_hash_size)
            if hsh in node_hashes:
                ni = node_hashes[hsh]
                if ni == i:
                    continue
                replace[i] = ni
                changed += 1

                # Specifies what to rename.
                # One exception: the output is one of the graph output.
                rep = new_nodes[ni]
                for old, nn in zip(node.output, rep.output):
                    if old in graph_outputs:
                        rename[nn] = old
                        new_nodes[ni] = _rename_node_output(
                            new_nodes[ni], nn, old)
                    else:
                        rename[old] = nn

                # Renames inputs.
                # BUG FIX: the original unpacked into `changed`, clobbering
                # the replacement counter with a per-node boolean and
                # sometimes ending the while-loop before the fixpoint.
                new_new_nodes = []
                renew_index = set()
                for was_renamed, ci, renamed_node in \
                        _enumerate_rename_list_nodes_inputs(new_nodes, rename):
                    if was_renamed:
                        renew_index.add(ci)
                    new_new_nodes.append(renamed_node)
                new_nodes = new_new_nodes

                # Renews hashes: nodes whose inputs changed must be
                # re-fingerprinted on the next sweep.
                renew_hash = set(
                    k for k, v in node_hashes.items() if v in renew_index)
                for hs in renew_hash:
                    del node_hashes[hs]
                new_nodes[i] = None
            else:
                node_hashes[hsh] = i

    if recursive:
        # Handles subgraphs.
        for i in range(len(new_nodes)):  # pylint: disable=C0200
            node = new_nodes[i]
            if node is None or not (node.attribute):  # pylint: disable=C0325
                continue
            new_nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_redundant,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = list(filter(lambda n: n is not None, new_nodes))
    if is_function:
        return make_function(
            onnx_model.domain, onnx_model.name,
            onnx_model.input, onnx_model.output, nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string)

    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       new_inits)

    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    return graph