Coverage for onnxcustom/training/grad_helper.py: 93%
# pylint: disable=E1101
"""
@file
@brief ONNX and gradient.
"""
from io import BytesIO
from enum import IntFlag
import onnx
from onnx.helper import make_model, make_graph, make_node, make_tensor
from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
    OrtModuleGraphBuilder,
    OrtModuleGraphBuilderConfiguration,
    TrainingGraphTransformerConfiguration)
from mlprodict.onnx_tools.optim.onnx_optimisation import onnx_remove_node
from ..utils.orttraining_helper import get_train_initializer
class DerivativeOptions(IntFlag):
    """
    Options defining how to build the onnx graph of the
    gradients.

    * `Zero`: default option, all options are disabled
    * `KeepYieldOp`: keeps the operator *YieldOp* in the graph,
      see @see fn onnx_derivative
    * `KeepOutputs`: keeps the outputs of the original graph
    * `FillGrad`: does not add any output to specify the gradient
      of the output but assumes it is one
    * `Loss`: the function assumes the loss was added to the graph
    """

    Zero = 0
    KeepYieldOp = 1
    KeepOutputs = 2
    FillGrad = 4
    Loss = 5
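# Editor's note: a minimal sketch of how these flags combine (illustration
# only, not part of the original module). *DerivativeOptions* is an *IntFlag*,
# so options are combined with ``|``; *KeepYieldOp* and *Loss* are meant to be
# used alone, and *FillGrad* requires *KeepOutputs* (see the checks in
# ``_onnx_derivative_fw`` below).
#
#     opts = DerivativeOptions.KeepOutputs | DerivativeOptions.FillGrad
#     assert opts & DerivativeOptions.FillGrad
#     assert opts & DerivativeOptions.KeepOutputs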
def onnx_derivative(onx, weights=None, inputs=None,
                    options=DerivativeOptions.Zero,
                    loss=None, label=None, path_name=None):
    """
    Builds the gradient for an onnx graph.

    :param onx: onnx graph
    :param weights: gradient against those weights, None for all real weights
    :param inputs: gradient against inputs, None for all real inputs
    :param options: options of type @see cl DerivativeOptions
    :param loss: loss output in case a loss was added in the graph,
        *options* must be equal to `DerivativeOptions.Loss`
    :param label: if *loss* is specified, then the label must be
        specified as well
    :param path_name: if *options* equals `DerivativeOptions.Loss`,
        the gradient is saved to that path
    :return: onnx graph

    The function calls :epkg:`OrtModuleGraphBuilderConfiguration`
    from :epkg:`onnxruntime-training`. This graph is meant to be used
    with @see cl OrtGradientForwardBackward and includes
    operator `YieldOp`. That is why the graph looks this way:
    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=DerivativeOptions.KeepYieldOp)

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())
    The *YieldOp* operators mark the outputs of the
    initial graph. They must be replaced by the gradients of these
    outputs to compute the gradients of the weights and the inputs.
    After they are replaced, the graph looks this way:
    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=DerivativeOptions.Zero)

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())
    The user can still compute the outputs.

    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=DerivativeOptions.KeepOutputs)

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())
    With *FillGrad*, the gradient of the output is not added as a new input;
    it is filled with a constant tensor of ones with the expected shape.
    .. gdot::
        :script: DOT-SECTION

        import numpy
        from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
            OnnxAdd, OnnxMul, OnnxIdentity)
        from skl2onnx.common.data_types import FloatTensorType
        from mlprodict.onnxrt import OnnxInference
        from onnxcustom.training.grad_helper import (
            onnx_derivative, DerivativeOptions)
        from onnxcustom import __max_supported_opset__ as opv

        node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
                       op_version=opv, output_names=['Y'])
        onx = node.to_onnx({'X': FloatTensorType([None, 10])},
                           {'Y': FloatTensorType([None, 10])},
                           target_opset=opv)
        new_onx = onnx_derivative(onx, options=(
            DerivativeOptions.KeepOutputs | DerivativeOptions.FillGrad))

        oinf = OnnxInference(new_onx)
        print("DOT-SECTION", oinf.to_dot())
    """
    if not isinstance(options, DerivativeOptions):
        raise TypeError(
            "Options must be of type DerivativeOptions, not %r."
            "" % type(options))
    if options == DerivativeOptions.Loss:
        return _onnx_derivative_loss(onx, weights=weights, inputs=inputs,
                                     options=options, loss=loss, label=label,
                                     path_name=path_name)
    return _onnx_derivative_fw(onx, weights=weights, inputs=inputs,
                               options=options)
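# Editor's note: a hedged sketch of the *Loss* mode (the names 'loss',
# 'label' and 'grad_model.onnx' are placeholders; *loss* and *label* must
# match an output and an input actually present in the graph, and *weights*
# must stay None):
#
#     grad_onx = onnx_derivative(
#         onx_with_loss, options=DerivativeOptions.Loss,
#         loss='loss', label='label', path_name='grad_model.onnx')
#
# The gradient graph is returned and also saved to *path_name*
# (see ``_onnx_derivative_loss`` below).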
def _default_inputs(onx):
    "Guesses default inputs (float ones) if not specified."
    inputs_name = []
    for i in onx.graph.input:
        try:
            elem_type = i.type.tensor_type.elem_type
        except AttributeError:  # pragma: no cover
            # not a vector
            continue
        if elem_type in (
                onnx.TensorProto.FLOAT16,
                onnx.TensorProto.FLOAT,
                onnx.TensorProto.DOUBLE):
            inputs_name.append(i.name)
    return inputs_name
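# Editor's note: for instance (hypothetical model), on a graph with a float
# input 'X' and an int64 input 'I', only 'X' is returned: FLOAT16, FLOAT and
# DOUBLE are the only element types considered differentiable here.
#
#     assert _default_inputs(onx) == ['X']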
def _onnx_derivative_fw(onx, weights, inputs, options):
    """
    Implements a gradient based on class `OrtModuleGraphBuilder`.
    """
    # By default, the gradient is built against every trainable initializer.
    if weights is None:
        inits = get_train_initializer(onx)
        weights = list(inits)
    builder = OrtModuleGraphBuilder()

    config = OrtModuleGraphBuilderConfiguration()
    config.initializer_names = weights
    config.initializer_names_to_train = weights
    if inputs is None:
        inputs_name = _default_inputs(onx)
        if len(inputs_name) > 0:
            config.input_names_require_grad = inputs_name
    config.build_gradient_graph = True

    p = TrainingGraphTransformerConfiguration()
    config.graph_transformer_config = p

    builder.initialize(onx.SerializeToString(), config)
    builder.build()
    train_onnx_model_serialized = builder.get_model()
    # optimized_pre_grad_model = builder.get_inference_optimized_model()
    # The serialized training graph still contains the YieldOp nodes.
    grad_yield = onnx.load(BytesIO(train_onnx_model_serialized))
    if options & DerivativeOptions.KeepYieldOp:
        if options != DerivativeOptions.KeepYieldOp:
            raise ValueError(
                "Option KeepYieldOp cannot be combined with any other.")
        return grad_yield
    yields_op = [
        node for node in grad_yield.graph.node
        if node.op_type == 'YieldOp']
    if len(yields_op) == 0:
        raise RuntimeError(  # pragma: no cover
            "No YieldOp was found. The input graph must be wrong.")

    other_nodes = [
        node for node in grad_yield.graph.node
        if node.op_type != 'YieldOp']
    inputs = list(grad_yield.graph.input)
    if options & DerivativeOptions.KeepOutputs:
        outputs = list(grad_yield.graph.output)
    else:
        original = set(i.name for i in onx.graph.output)
        outputs = [o for o in grad_yield.graph.output
                   if o.name not in original]
    map_out = {o.name: o for o in onx.graph.output}
    for yn in yields_op:
        if len(yn.input) != 1 or len(yn.output) != 1:
            raise NotImplementedError(  # pragma: no cover
                "Unexpected configuration for YieldOp node %r." % yn)
        if yn.input[0] not in map_out:
            raise RuntimeError(  # pragma: no cover
                "Unable to find output %r in %r." % (
                    yn.input[0], list(map_out)))
        if not(options & DerivativeOptions.FillGrad):  # pylint: disable=C0325
            # The gradient of the output becomes a new graph input.
            out = map_out[yn.input[0]]
            new_input = onnx.ValueInfoProto()
            new_input.name = yn.output[0]
            new_input.doc_string = "from yieldop"
            new_input.type.CopyFrom(out.type)
            inputs.append(new_input)
        else:
            if not(options & DerivativeOptions.KeepOutputs):  # pylint: disable=C0325
                raise ValueError(  # pragma: no cover
                    "FillGrad should be set with KeepOutputs.")
            # The gradient of the output is filled with ones
            # (Shape followed by ConstantOfShape).
            name = "%s_shape" % yn.input[0]
            node = make_node('Shape', [yn.input[0]], [name])
            other_nodes.append(node)
            out = map_out[yn.input[0]]
            elem_type = out.type.tensor_type.elem_type
            node = make_node(
                'ConstantOfShape', [name], [yn.output[0]],
                value=make_tensor(
                    "value", elem_type, (1, ), [1]))
            other_nodes.append(node)
        if options & DerivativeOptions.KeepOutputs:
            # Keeps output from the original graph.
            outputs.append(out)
    # Final graph.
    graph = make_graph(
        other_nodes, grad_yield.graph.name, inputs, outputs,
        list(grad_yield.graph.initializer))
    new_model = make_model(graph)
    new_model.ir_version = grad_yield.ir_version
    new_model.producer_name = grad_yield.producer_name
    new_model.producer_version = grad_yield.producer_version
    new_model.domain = grad_yield.domain
    new_model.model_version = grad_yield.model_version
    new_model.doc_string = grad_yield.doc_string
    if hasattr(onx, 'value_info'):
        graph.value_info.extend(grad_yield.value_info)
    del new_model.opset_import[:]
    for oimp in grad_yield.opset_import:
        op_set = new_model.opset_import.add()
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    return onnx_remove_node(new_model)
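# Editor's note: a small check one can run on the result of the default mode
# (illustration, assuming ``onx`` is the Add model built in the docstring
# of ``onnx_derivative`` above):
#
#     new_onx = onnx_derivative(onx)  # DerivativeOptions.Zero
#     # original inputs plus one gradient input per original output,
#     # documented as "from yieldop"
#     print([i.name for i in new_onx.graph.input])
#     # gradients of the weights and inputs only, original outputs removed
#     print([o.name for o in new_onx.graph.output])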
def _onnx_derivative_loss(onx, weights, inputs, options, loss, label,
                          path_name):
    """
    Implements a gradient based on class `PyGradientGraphBuilder`.
    """
    from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611,C0415
        GradientGraphBuilder)
    if path_name is None:
        raise ValueError(
            "path_name must not be None if options is 'Loss'.")
    if weights is not None:
        raise ValueError(
            "weights must be None if options is 'Loss'.")
    if label is None:
        raise ValueError(
            "label must not be None if options is 'Loss'.")
    if loss is None or not isinstance(loss, str):
        raise ValueError(
            "loss must not be None and must be a string if options is 'Loss'.")
    if isinstance(label, str):
        label = {label}
    else:
        label = set(label)
    if inputs is None:
        inputs_name = _default_inputs(onx)
        inputs = inputs_name
    if isinstance(inputs, str):
        inputs = {inputs}
    else:
        inputs = set(inputs)
    inputs = set(x for x in inputs if x not in label)

    builder = GradientGraphBuilder(
        onx.SerializeToString(), label, inputs, loss)
    builder.build()
    builder.save(path_name)
    with open(path_name, "rb") as f:
        return onnx.load(f)