Coverage for onnxcustom/utils/onnx_rewriter.py: 91%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

115 statements  

"""
@file
@brief Rewrites operators in an ONNX graph.
"""

5from onnx.helper import ( 

6 make_graph, make_node, make_tensor_value_info, make_model) 

7from onnx import NodeProto 

8from onnx.numpy_helper import to_array, from_array 

9 

10 

11def _unique_name(existing_names, name): 

12 """ 

13 Returns a name different from any name in *existing_names*. 

14 

15 :param existing_names: set of names 

16 :param name: current 

17 :return: unique name 

18 """ 

19 if name not in existing_names: 

20 existing_names.add(name) 

21 return name 

22 name0 = name 

23 i = 2 

24 while name in existing_names: 

25 name = "%s_%d" % (name0, i) 

26 i += 1 

27 existing_names.add(name) 

28 return name 

29 

30 

31def _existing_names(onx): 

32 """ 

33 Makes the list of existing names. 

34 Returns a set of unique names including 

35 intermediate results. 

36 """ 

37 existing_names = set() 

38 graph = onx.graph if hasattr(onx, 'graph') else onx 

39 for node in graph.node: 

40 existing_names.update(node.input) 

41 existing_names.update(node.output) 

42 return existing_names 

43 

44 

def _onnx_rewrite_operator_node(existing_names, node, sub_onx):
    """
    Replaces a node by a subgraph.

    The inputs and outputs of *sub_onx* are mapped by position onto the
    inputs and outputs of *node*. Initializers and intermediate results
    of *sub_onx* are renamed so that they do not collide with any name in
    *existing_names*, and every node of *sub_onx* is rewired accordingly.

    :param existing_names: existing results names, updated in place
    :param node: onnx node to replace
    :param sub_onx: onnx model whose graph is used as a replacement
    :return: new_initializer, new_nodes
    :raises ValueError: if *node* and the graph of *sub_onx* do not have
        the same number of inputs or outputs
    """
    sub_graph = sub_onx.graph
    if len(node.input) != len(sub_graph.input):
        raise ValueError(  # pragma: no cover
            "Mismatch with the number of inputs for operator type %r. "
            "%d != %d." % (
                node.op_type, len(node.input), len(sub_graph.input)))
    if len(node.output) != len(sub_graph.output):
        raise ValueError(  # pragma: no cover
            "Mismatch with the number of outputs for operator type %r. "
            "%d != %d." % (
                node.op_type, len(node.output), len(sub_graph.output)))

    # Maps every result name of the subgraph to its name in the main graph.
    replaces = {}
    for sub_inp, name in zip(sub_graph.input, node.input):
        replaces[sub_inp.name] = name
    for sub_out, name in zip(sub_graph.output, node.output):
        replaces[sub_out.name] = name

    new_inits = []
    for init in sub_graph.initializer:
        name = _unique_name(existing_names, init.name)
        replaces[init.name] = name
        new_inits.append(from_array(to_array(init), name=name))

    new_nodes = []
    for sub_node in sub_graph.node:
        new_node = NodeProto()
        new_node.op_type = sub_node.op_type
        new_node.domain = sub_node.domain
        new_node.attribute.extend(sub_node.attribute)  # pylint: disable=E1101
        # An empty name marks an omitted optional input: keep it as is.
        new_node.input.extend(  # pylint: disable=E1101
            [replaces[i] if i else i for i in sub_node.input])
        new_out = []
        for out in sub_node.output:
            if not out:
                # Omitted optional output, must stay empty.
                new_out.append(out)
                continue
            if out not in replaces:
                # Intermediate result: give it a fresh name and remember it
                # so that the following nodes consuming it are rewired too.
                replaces[out] = _unique_name(existing_names, out)
            new_out.append(replaces[out])
        new_node.output.extend(new_out)  # pylint: disable=E1101
        new_nodes.append(new_node)

    return new_inits, new_nodes

96 

97 

def onnx_rewrite_operator(onx, op_type, sub_onx, recursive=True, debug_info=None):
    """
    Replaces one operator by an onnx graph.

    Every node whose type is *op_type* is replaced by the nodes of
    *sub_onx*. The inputs and outputs of *sub_onx* are mapped by position
    onto the ones of the replaced node.

    :param onx: onnx model or onnx graph
    :param op_type: operator type
    :param sub_onx: onnx model used as a replacement
    :param recursive: looks into subgraphs
    :param debug_info: unused by the rewriting itself, only forwarded
        to mlprodict's helpers when subgraphs are processed
    :return: modified onnx graph, or for a model whatever mlprodict's
        ``_apply_optimisation_on_graph`` returns for it

    .. runpython::
        :showcode:

        import numpy
        from skl2onnx.common.data_types import FloatTensorType
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxReciprocal, OnnxDiv)
        from mlprodict.plotting.text_plot import onnx_simple_text_plot
        from onnxcustom import get_max_opset
        from onnxcustom.utils.onnx_rewriter import onnx_rewrite_operator

        # first graph: it contains the node to replace
        opset = get_max_opset()
        node1 = OnnxReciprocal('X', output_names=['Y'],
                               op_version=opset)
        onx1 = node1.to_onnx(
            inputs={'X': FloatTensorType()},
            outputs={'Y': FloatTensorType()},
            target_opset=opset)

        # second graph: it contains the replacement graph
        node2 = OnnxDiv(numpy.array([1], dtype=numpy.float32),
                        'X', output_names=['Y'],
                        op_version=opset)
        onx2 = node2.to_onnx(
            inputs={'X': FloatTensorType()},
            outputs={'Y': FloatTensorType()},
            target_opset=opset)

        # third graph: the modified graph
        onx3 = onnx_rewrite_operator(onx1, 'Reciprocal', onx2)
        print(onnx_simple_text_plot(onx3))
    """
    from mlprodict.onnx_tools.optim._onnx_optimisation_common import (  # pylint: disable=C0415
        _apply_remove_node_fct_node, _apply_optimisation_on_graph)

    def fct(graph, recursive=recursive, debug_info=None):
        # Binds op_type and sub_onx so that mlprodict's helpers can call
        # this function with a graph only. The default for *recursive*
        # follows the caller's choice, so that a helper which does not
        # forward the flag does not silently skip subgraphs.
        return onnx_rewrite_operator(
            graph, op_type, sub_onx, recursive=recursive,
            debug_info=debug_info)

    if hasattr(onx, 'graph'):
        return _apply_optimisation_on_graph(fct, onx, recursive=recursive)

    # NOTE(review): for a nested subgraph, names only visible in the outer
    # scopes are not part of existing_names; confirm collisions are not
    # an issue there.
    existing_names = _existing_names(onx)
    new_nodes = []
    new_inits = list(onx.initializer)
    for node in onx.node:
        if node.op_type != op_type:
            new_nodes.append(node)
            continue
        inits, replacement = _onnx_rewrite_operator_node(
            existing_names, node, sub_onx)
        new_inits.extend(inits)
        new_nodes.extend(replacement)

    if recursive:
        # Handles subgraphs. Walks new_nodes, the list used to build the
        # graph: it can be longer than onx.node once a node has been
        # replaced by several nodes.
        for i, node in enumerate(new_nodes):
            if not node.attribute:
                continue
            new_nodes[i] = _apply_remove_node_fct_node(
                fct, node, recursive=True,
                debug_info=list(debug_info or []) + [node.name])

    graph = make_graph(
        new_nodes, onx.name, onx.input, onx.output, new_inits)
    return graph

178 

179 

def unreduced_onnx_loss(onx, output_name='score'):
    """
    Every loss function reduces the results to compute a loss.
    The score function needs to get the loss for every observation,
    not the whole loss. This function looks for a reducing node
    and removes it before exposing the output as the only output.

    Only ``ReduceSumSquare`` (replaced by ``Mul(x, x)``) and
    ``ReduceSum`` (replaced by ``Identity``) can be unreduced.

    :param onx: onx graph
    :param output_name: new output name, suffixed if already used
    :return: new onx graph
    :raises RuntimeError: if the graph does not contain exactly one node
        whose type starts with ``Reduce``, or if that node cannot be
        unreduced
    """
    from mlprodict.onnx_tools.onnx_manipulations import (  # pylint: disable=C0415
        select_model_inputs_outputs)

    graph = onx.graph
    found = [n for n in graph.node if n.op_type.startswith('Reduce')]
    if len(found) != 1:
        raise RuntimeError(  # pragma: no cover
            "Unable to find one unique Reducing node but found %d - %r."
            "" % (len(found), [(n.op_type, n.name) for n in found]))
    node = found[0]
    input_name = node.input[0]
    # Keeps only the part of the model computing the reduced input.
    new_onx = select_model_inputs_outputs(
        onx, outputs=[input_name], infer_shapes=True)

    inits = new_onx.graph.initializer
    inputs = new_onx.graph.input  # pylint: disable=E1101
    existing_names = _existing_names(new_onx)
    new_name = _unique_name(existing_names, output_name)
    new_nodes = list(new_onx.graph.node)  # pylint: disable=E1101
    elem = graph.output[0].type.tensor_type.elem_type
    new_output = [make_tensor_value_info(new_name, elem, [None, 1])]

    if node.op_type == "ReduceSumSquare":
        new_nodes.append(
            make_node('Mul', [input_name, input_name], [new_name]))
    elif node.op_type == 'ReduceSum':
        new_nodes.append(make_node('Identity', [input_name], [new_name]))
    else:
        raise RuntimeError(  # pragma: no cover
            "Unable to unreduce node %r." % node.op_type)

    # Keeps the shape information of the original graph for the results
    # the new graph still produces. It is given to make_graph because
    # make_model copies the graph: extending it afterwards would be lost.
    produced = {name for n in new_nodes for name in n.output}
    produced.discard(new_name)
    value_info = [vi for vi in graph.value_info if vi.name in produced]

    new_graph = make_graph(
        new_nodes, graph.name, inputs, new_output, inits,
        value_info=value_info)
    new_model = make_model(new_graph)
    new_model.ir_version = onx.ir_version
    new_model.producer_name = onx.producer_name
    new_model.producer_version = onx.producer_version
    new_model.domain = onx.domain
    new_model.model_version = onx.model_version
    new_model.doc_string = onx.doc_string
    del new_model.opset_import[:]  # pylint: disable=E1101
    for oimp in onx.opset_import:
        op_set = new_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version
    return new_model