Coverage for onnxcustom/training/grad_helper.py: 93%


# pylint: disable=E1101
"""
@file
@brief ONNX and gradient.
"""
from io import BytesIO
from enum import IntFlag
import onnx
from onnx.helper import make_model, make_graph, make_node, make_tensor
from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
    OrtModuleGraphBuilder,
    OrtModuleGraphBuilderConfiguration,
    TrainingGraphTransformerConfiguration)
from mlprodict.onnx_tools.optim.onnx_optimisation import onnx_remove_node
from ..utils.orttraining_helper import get_train_initializer


class DerivativeOptions(IntFlag):
    """
    Options defining how to build the onnx graph of the
    gradients.

    * `Zero`: default option, all options are disabled
    * `KeepYieldOp`: keeps the operator *YieldOp* in the graph,
      see @see fn onnx_derivative
    * `KeepOutputs`: keeps the outputs of the original graph
    * `FillGrad`: does not add an input to feed the gradient of the
      outputs but assumes it is a tensor filled with ones
    * `Loss`: the function assumes the loss was added to the graph
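
    Options can be combined with ``|``; a usage sketch (not taken
    from the original documentation):

    ::

        opts = DerivativeOptions.KeepOutputs | DerivativeOptions.FillGrad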

30 """ 

31 

32 Zero = 0 

33 KeepYieldOp = 1 

34 KeepOutputs = 2 

35 FillGrad = 4 

36 Loss = 5 
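    # Note: ``Loss`` (= 5) overlaps the bits of ``KeepYieldOp | FillGrad``;
    # the code compares it with ``==`` only and never uses it as a mask.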


def onnx_derivative(onx, weights=None, inputs=None,
                    options=DerivativeOptions.Zero,
                    loss=None, label=None, path_name=None):
    """
    Builds the gradient for an onnx graph.

    :param onx: onnx graph
    :param weights: gradient against those weights, None for all real weights
    :param inputs: gradient against inputs, None for all real inputs
    :param options: options of type @see cl DerivativeOptions
    :param loss: loss output in case a loss was added in the graph,
        *options* must be equal to `DerivativeOptions.Loss`
    :param label: if *loss* is specified, then the label must be
        specified as well
    :param path_name: if *options* equals `DerivativeOptions.Loss`,
        the gradient is saved to that path
    :return: onnx graph

    The function calls :epkg:`OrtModuleGraphBuilderConfiguration`
    from :epkg:`onnxruntime-training`. This graph is meant to be used
    with @see cl OrtGradientForwardBackward and includes
    operator `YieldOp`. The graph then looks this way:

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=DerivativeOptions.KeepYieldOp)

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())

    These *YieldOp* operators consume the outputs of the initial graph.
    They must be replaced by the gradient of those outputs to compute
    the gradient of the weights and the inputs. Once replaced, the
    graph looks this way:

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=DerivativeOptions.Zero)

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())

    With option `KeepOutputs`, the user can still compute the original
    outputs:

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=DerivativeOptions.KeepOutputs)

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())

    With option `FillGrad`, instead of adding an input to feed the
    gradient of the outputs, the gradient is filled with a constant
    matrix of ones with the expected shape:

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=(
            DerivativeOptions.KeepOutputs | DerivativeOptions.FillGrad))

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())
    """
    if not isinstance(options, DerivativeOptions):
        raise TypeError(
            "Options must be of type DerivativeOptions, not %r."
            % type(options))

    if options == DerivativeOptions.Loss:
        return _onnx_derivative_loss(onx, weights=weights, inputs=inputs,
                                     options=options, loss=loss, label=label,
                                     path_name=path_name)
    return _onnx_derivative_fw(onx, weights=weights, inputs=inputs,
                               options=options)


def _default_inputs(onx):
    "Guesses default inputs (float ones) if not specified."
    inputs_name = []
    for i in onx.graph.input:
        try:
            elem_type = i.type.tensor_type.elem_type
        except AttributeError:  # pragma: no cover
            # not a vector
            continue
        if elem_type in (
                onnx.TensorProto.FLOAT16,
                onnx.TensorProto.FLOAT,
                onnx.TensorProto.DOUBLE):
            inputs_name.append(i.name)
    return inputs_name
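
# A hypothetical sketch of the helper's behaviour: for a model whose only
# input 'X' is a float tensor, ``_default_inputs(onx)`` returns ``['X']``;
# integer or boolean inputs are skipped.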


def _onnx_derivative_fw(onx, weights, inputs, options):
    """
    Implements a gradient based on class `OrtModuleGraphBuilder`.
    """

    if weights is None:
        inits = get_train_initializer(onx)
        weights = list(inits)
    builder = OrtModuleGraphBuilder()

    config = OrtModuleGraphBuilderConfiguration()
    config.initializer_names = weights
    config.initializer_names_to_train = weights
    if inputs is None:
        inputs_name = _default_inputs(onx)
        if len(inputs_name) > 0:
            config.input_names_require_grad = inputs_name
    config.build_gradient_graph = True

    p = TrainingGraphTransformerConfiguration()
    config.graph_transformer_config = p

    builder.initialize(onx.SerializeToString(), config)
    builder.build()
    train_onnx_model_serialized = builder.get_model()
    # optimized_pre_grad_model = builder.get_inference_optimized_model()
    grad_yield = onnx.load(BytesIO(train_onnx_model_serialized))
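    # *grad_yield* holds the forward and backward graphs; every output of
    # the original graph is consumed by a *YieldOp* node.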

    if options & DerivativeOptions.KeepYieldOp:
        if options != DerivativeOptions.KeepYieldOp:
            raise ValueError(
                "Option KeepYieldOp cannot be combined with any other.")
        return grad_yield

    yields_op = [
        node for node in grad_yield.graph.node
        if node.op_type == 'YieldOp']
    if len(yields_op) == 0:
        raise RuntimeError(  # pragma: no cover
            "No YieldOp was found. The input graph must be wrong.")

    other_nodes = [
        node for node in grad_yield.graph.node
        if node.op_type != 'YieldOp']
    inputs = list(grad_yield.graph.input)
    if options & DerivativeOptions.KeepOutputs:
        outputs = list(grad_yield.graph.output)
    else:
        original = set(i.name for i in onx.graph.output)
        outputs = [o for o in grad_yield.graph.output
                   if o.name not in original]
    map_out = {o.name: o for o in onx.graph.output}
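    # Every YieldOp marks one output of the forward graph. It is replaced
    # either by a new graph input receiving the gradient of that output
    # or, when option FillGrad is set, by a ConstantOfShape node
    # producing a tensor of ones.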

    for yn in yields_op:
        if len(yn.input) != 1 or len(yn.output) != 1:
            raise NotImplementedError(  # pragma: no cover
                "Unexpected configuration for YieldOp node %r." % yn)
        if yn.input[0] not in map_out:
            raise RuntimeError(  # pragma: no cover
                "Unable to find output %r in %r." % (
                    yn.input[0], list(map_out)))
        if not(options & DerivativeOptions.FillGrad):  # pylint: disable=C0325
            out = map_out[yn.input[0]]
            new_input = onnx.ValueInfoProto()
            new_input.name = yn.output[0]
            new_input.doc_string = "from yieldop"
            new_input.type.CopyFrom(out.type)
            inputs.append(new_input)
        else:
            if not(options & DerivativeOptions.KeepOutputs):  # pylint: disable=C0325
                raise ValueError(  # pragma: no cover
                    "FillGrad should be set with KeepOutputs.")
            name = "%s_shape" % yn.input[0]
            node = make_node('Shape', [yn.input[0]], [name])
            other_nodes.append(node)
            out = map_out[yn.input[0]]
            elem_type = out.type.tensor_type.elem_type
            node = make_node(
                'ConstantOfShape', [name], [yn.output[0]],
                value=make_tensor(
                    "value", elem_type, (1, ), [1]))
            other_nodes.append(node)
        if options & DerivativeOptions.KeepOutputs:
            # Keeps the output from the original graph.
            outputs.append(out)

    # Final graph.
    graph = make_graph(
        other_nodes, grad_yield.graph.name, inputs, outputs,
        list(grad_yield.graph.initializer))
    new_model = make_model(graph)
    new_model.ir_version = grad_yield.ir_version
    new_model.producer_name = grad_yield.producer_name
    new_model.producer_version = grad_yield.producer_version
    new_model.domain = grad_yield.domain
    new_model.model_version = grad_yield.model_version
    new_model.doc_string = grad_yield.doc_string
    # ``value_info`` is stored on the graph, not on the model.
    if hasattr(grad_yield.graph, 'value_info'):
        graph.value_info.extend(grad_yield.graph.value_info)
    del new_model.opset_import[:]
    for oimp in grad_yield.opset_import:
        op_set = new_model.opset_import.add()
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    return onnx_remove_node(new_model)


def _onnx_derivative_loss(onx, weights, inputs, options, loss, label,
                          path_name):
    """
    Implements a gradient based on class `GradientGraphBuilder`.
    """
    from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611,C0415
        GradientGraphBuilder)
    if path_name is None:
        raise ValueError(
            "path_name must not be None if options is 'Loss'.")
    if weights is not None:
        raise ValueError(
            "weights must be None if options is 'Loss'.")
    if label is None:
        raise ValueError(
            "label must not be None if options is 'Loss'.")
    if loss is None or not isinstance(loss, str):
        raise ValueError(
            "loss must not be None and must be a string "
            "if options is 'Loss'.")
    if isinstance(label, str):
        label = {label}
    else:
        label = set(label)
    if inputs is None:
        inputs_name = _default_inputs(onx)
        inputs = inputs_name
    if isinstance(inputs, str):
        inputs = {inputs}
    else:
        inputs = set(inputs)
    inputs = set(x for x in inputs if x not in label)
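
    # GradientGraphBuilder computes the gradient of *loss* with respect
    # to *inputs*; the result is saved to *path_name* and reloaded as a
    # ModelProto.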

    builder = GradientGraphBuilder(
        onx.SerializeToString(), label, inputs, loss)
    builder.build()
    builder.save(path_name)
    with open(path_name, "rb") as f:
        return onnx.load(f)
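
# A hedged usage sketch for the Loss mode (output names 'loss' and 'label'
# and the destination file are hypothetical):
#
#   grad = onnx_derivative(
#       onx, weights=None, options=DerivativeOptions.Loss,
#       loss='loss', label='label', path_name='grad_model.onnx')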