Coverage for onnxcustom/training/ 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# pylint: disable=W0105
4@brief Helper for :epkg:`onnxruntime-training`.
6from onnxruntime import SessionOptions, InferenceSession, RunOptions
7from ..utils.onnx_function import function_onnx_graph
8from ..utils.onnxruntime_helper import device_to_providers
9from ..utils.onnx_rewriter import unreduced_onnx_loss
10from ._base_onnx_function import BaseLearningOnnx
class BaseLearningLoss(BaseLearningOnnx):
    """
    Class handling the loss for class
    @see cl OrtGradientForwardBackwardOptimizer.
    All classes inheriting from this one create one ONNX function,
    returning the loss and the gradient of the loss against the
    outputs. Method `loss_gradient` is the main method, it computes
    the loss and the gradient defined by one ONNX graph and
    executed by an instance of :epkg:`InferenceSession`.

    Subclasses are expected to implement `build_onnx_function`,
    which sets attributes `loss_grad_onnx_`, `loss_grad_sess_` and
    `loss_grad_sess_bind_` and then calls `build_onnx_score_function`.
    """

    def __init__(self):
        BaseLearningOnnx.__init__(self)
        # run options given to every call to run_with_iobinding
        self.ro_ = RunOptions()

    def build_onnx_score_function(self, opset, device, weight_name):
        """
        Assuming the loss function was created, this method takes
        the onnx graph stored in `loss_grad_onnx_` and generates
        the onnx graph used by method `loss_scores`. The graph is
        rewritten by `unreduced_onnx_loss` on output *Y* so that
        `loss_scores` returns a score for every observation.

        :param opset: unused here, kept so that the signature matches
            `build_onnx_function`
        :param device: device the scoring session runs on
        :param weight_name: unused here, kept so that the signature
            matches `build_onnx_function`
        :raises RuntimeError: if `build_onnx_function` was not
            called first
        """
        if not hasattr(self, 'loss_grad_onnx_'):
            raise RuntimeError(  # pragma: no cover
                "Missing attribute 'loss_grad_onnx_'. "
                "Method 'build_onnx_function' should be called first.")

        # score
        so = SessionOptions()
        so.log_severity_level = 4
        self.loss_score_onnx_ = unreduced_onnx_loss(
            self.loss_grad_onnx_, 'Y')  # pylint: disable=E1101
        self.loss_score_sess_ = InferenceSession(
            self.loss_score_onnx_.SerializeToString(), so,
            providers=device_to_providers(device))
        # the low-level binding object is kept, presumably to skip the
        # python wrapper overhead on repeated calls
        self.loss_score_sess_bind_ = (
            self.loss_score_sess_.io_binding()._iobinding)

    def _call_iobinding(self, sess, bind):
        """Runs session *sess* with io binding *bind* and `self.ro_`."""
        sess.run_with_iobinding(bind, self.ro_)

    def _bind_loss_inputs(self, bind, device, expected, predicted, weight):
        """
        Binds the inputs shared by the loss graphs on *bind*:
        *weight* (cleared when *weight* is None), then *X1*
        (expected values) and *X2* (predicted values).
        """
        if weight is not None:
            self._bind_input_ortvalue(
                "weight", bind, weight, device, cache=True)
        else:
            # presumably removes a weight bound by a previous call
            self.clear_binding_inputs("weight", bind, cache=True)
        self._bind_input_ortvalue("X1", bind, expected, device, cache=True)
        self._bind_input_ortvalue("X2", bind, predicted, device, cache=True)

    def loss_gradient(  # pylint: disable=E1101
            self, device, expected, predicted, weight=None):
        """
        Returns the loss and the gradient as OrtValue.

        :param device: device where the training takes place
        :param expected: expected value
        :param predicted: predicted value
        :param weight: optional, training weights
            (same dimension as expected and predicted tensors)
        :return: loss and gradient
        :raises RuntimeError: if `build_onnx_function` was not called
        """
        if (not hasattr(self, "loss_grad_sess_") or
                not hasattr(self, "loss_grad_sess_bind_")):
            raise RuntimeError(  # pragma: no cover
                "Attributes 'loss_grad_sess_bind_' or 'loss_grad_sess_' is "
                "missing. Method 'build_onnx_function' has not been called.")
        bind = self.loss_grad_sess_bind_
        self._bind_loss_inputs(bind, device, expected, predicted, weight)
        bind.bind_output('Y', device)
        bind.bind_output('Y_grad', device)
        self._call_iobinding(self.loss_grad_sess_._sess, bind)
        loss, grad = bind.get_outputs()
        return loss, grad

    def loss_scores(  # pylint: disable=E1101
            self, device, expected, predicted, weight=None):
        """
        Returns the weighted loss (or score)
        for every observation as OrtValue.

        :param device: device where the training takes place
        :param expected: expected value
        :param predicted: predicted value
        :param weight: optional, training weights
            (same dimension as expected and predicted tensors)
        :return: a score for every observation
        :raises RuntimeError: if `build_onnx_function` was not called
        """
        if (not hasattr(self, "loss_score_sess_") or
                not hasattr(self, "loss_score_sess_bind_")):
            raise RuntimeError(  # pragma: no cover
                "Attributes 'loss_score_sess_bind_' or 'loss_score_sess_' is "
                "missing. Method 'build_onnx_function' has not been called.")
        bind = self.loss_score_sess_bind_
        self._bind_loss_inputs(bind, device, expected, predicted, weight)
        bind.bind_output('Y', device)
        self._call_iobinding(self.loss_score_sess_._sess, bind)
        score = bind.get_outputs()
        return score[0]

    @staticmethod
    def select(class_name, **kwargs):
        """
        Returns an instance of a given class initialized with
        *kwargs*.

        :param class_name: an instance of @see cl BaseLearningLoss
            (returned as is, *kwargs* are then ignored),
            a class name such as `'SquareLearningLoss'`,
            or one of the aliases below
        :return: instance of @see cl BaseLearningLoss
        :raises ValueError: if *class_name* is not recognized

        Possible values for *class_name*:
        * `'square_error'`, `'square'`: see @see cl SquareLearningLoss
        * `'absolute_error'`, `'absolute'`: see @see cl AbsoluteLearningLoss
        * `'elastic_error'`, `'elastic'`: see @see cl ElasticLearningLoss
        * `'log'`, `'neglog'`, `'logloss'`: see @see cl NegLogLearningLoss
        """
        if isinstance(class_name, BaseLearningLoss):
            return class_name
        cls = {SquareLearningLoss: ['square_error', 'square'],
               AbsoluteLearningLoss: ['absolute_error', 'absolute'],
               ElasticLearningLoss: ['elastic_error', 'elastic'],
               NegLogLearningLoss: ['log', 'neglog', 'logloss']}
        for cl, aliases in cls.items():
            # cl is a class: its name is cl.__name__
            # (cl.__class__.__name__ would always be 'type').
            if class_name == cl.__name__ or class_name in aliases:
                return cl(**kwargs)
        raise ValueError(  # pragma: no cover
            "Unexpected class name %r. It should be one of %r." % (
                class_name, [c.__name__ for c in cls]))
class SquareLearningLoss(BaseLearningLoss):
    """
    Square loss :math:`(Y - Z)^2`, *Y* being the output and
    *Z* the expected output. The ONNX graph computing the loss
    and its gradient is described in
    @see fn _onnx_grad_loss_square_error.
    """

    def __init__(self):
        BaseLearningLoss.__init__(self)

    def build_onnx_function(self, opset, device, weight_name):
        """
        Creates the ONNX graph returning the square loss and its
        gradient, the :epkg:`InferenceSession` running it and its
        io binding, then builds the scoring graph with
        `build_onnx_score_function`.

        :param opset: target opset of the generated graph
        :param device: device the sessions run on
        :param weight_name: forwarded to `function_onnx_graph`
        """
        options = SessionOptions()
        options.log_severity_level = 4

        # graph returning the loss ('Y') and its gradient ('Y_grad')
        graph_kwargs = dict(
            target_opset=opset, weight_name=weight_name, multiply=1)
        onx = function_onnx_graph("grad_loss_square_error", **graph_kwargs)
        self.loss_grad_onnx_ = onx
        session = InferenceSession(
            onx.SerializeToString(), options,
            providers=device_to_providers(device))
        self.loss_grad_sess_ = session
        self.loss_grad_sess_bind_ = session.io_binding()._iobinding

        # per-observation scores derived from the graph above
        self.build_onnx_score_function(opset, device, weight_name)
class AbsoluteLearningLoss(BaseLearningLoss):
    """
    Absolute loss :math:`|Y - Z|`, *Y* being the output and
    *Z* the expected output. The ONNX graph computing the loss
    and its gradient is described in
    @see fn _onnx_grad_loss_absolute_error.
    """

    def __init__(self):
        BaseLearningLoss.__init__(self)

    def build_onnx_function(self, opset, device, weight_name):
        """
        Creates the ONNX graph returning the absolute loss and its
        gradient, the :epkg:`InferenceSession` running it and its
        io binding, then builds the scoring graph with
        `build_onnx_score_function`.

        :param opset: target opset of the generated graph
        :param device: device the sessions run on
        :param weight_name: forwarded to `function_onnx_graph`
        """
        options = SessionOptions()
        options.log_severity_level = 4

        # graph returning the loss ('Y') and its gradient ('Y_grad')
        graph_kwargs = dict(target_opset=opset, weight_name=weight_name)
        onx = function_onnx_graph("grad_loss_absolute_error", **graph_kwargs)
        self.loss_grad_onnx_ = onx
        session = InferenceSession(
            onx.SerializeToString(), options,
            providers=device_to_providers(device))
        self.loss_grad_sess_ = session
        self.loss_grad_sess_bind_ = session.io_binding()._iobinding

        # per-observation scores derived from the graph above
        self.build_onnx_score_function(opset, device, weight_name)
class ElasticLearningLoss(BaseLearningLoss):
    """
    Elastic loss
    :math:`(Y - Z)^2 \\alpha + |Y - Z| * \\beta`,
    *Y* being the output and *Z* the expected output,
    :math:`\\alpha` is *l2_weight* and :math:`\\beta`
    is *l1_weight*.

    :param l1_weight: weight of L1 norm
    :param l2_weight: weight of L2 norm

    The ONNX graph computing the loss and its gradient is
    described in @see fn _onnx_grad_loss_elastic_error.
    """

    def __init__(self, l1_weight=0.5, l2_weight=0.5):
        BaseLearningLoss.__init__(self)
        self.l1_weight, self.l2_weight = l1_weight, l2_weight

    def build_onnx_function(self, opset, device, weight_name):
        """
        Creates the ONNX graph returning the elastic loss and its
        gradient (with weights `l1_weight` and `l2_weight`), the
        :epkg:`InferenceSession` running it and its io binding,
        then builds the scoring graph with
        `build_onnx_score_function`.

        :param opset: target opset of the generated graph
        :param device: device the sessions run on
        :param weight_name: forwarded to `function_onnx_graph`
        """
        options = SessionOptions()
        options.log_severity_level = 4

        # graph returning the loss ('Y') and its gradient ('Y_grad')
        graph_kwargs = dict(
            target_opset=opset, weight_name=weight_name,
            l1_weight=self.l1_weight, l2_weight=self.l2_weight)
        onx = function_onnx_graph("grad_loss_elastic_error", **graph_kwargs)
        self.loss_grad_onnx_ = onx
        session = InferenceSession(
            onx.SerializeToString(), options,
            providers=device_to_providers(device))
        self.loss_grad_sess_ = session
        self.loss_grad_sess_bind_ = session.io_binding()._iobinding

        # per-observation scores derived from the graph above
        self.build_onnx_score_function(opset, device, weight_name)
class NegLogLearningLoss(BaseLearningLoss):
    """
    Negative log loss
    :math:`log(yt, yp) = -(1-yt)\\log(1-yp) - yt\\log(yp)`.
    It only works for a binary classification where *yp* is the
    predicted probability and *yt* the expected probability.
    *yt* is expected to be binary, *yp* is a matrix with two
    columns, the sum on every line is 1.
    However, this loss is usually applied after a function softmax
    and the gradient is directly computed from the loss to the
    raw score before they are processed through the softmax function
    (see class `Log
    <https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/
    linear_model/_sgd_fast.pyx#L236>`_).

    :param eps: clipping value for probabilities,
        avoids computing `log(0)`
    :param probability_function: function to convert
        raw scores into probabilities, default value is `sigmoid`
        for a logistic regression
    """

    def __init__(self, eps=1e-5, probability_function='sigmoid'):
        BaseLearningLoss.__init__(self)
        self.eps, self.probability_function = eps, probability_function

    def build_onnx_function(self, opset, device, weight_name):
        """
        Creates the ONNX graph returning the negative log loss and
        its gradient, the :epkg:`InferenceSession` running it and its
        io binding, then builds the scoring graph with
        `build_onnx_score_function`. The graph name depends on
        `probability_function`.

        :param opset: target opset of the generated graph
        :param device: device the sessions run on
        :param weight_name: forwarded to `function_onnx_graph`
        """
        options = SessionOptions()
        options.log_severity_level = 4

        # graph returning the loss ('Y') and its gradient ('Y_grad')
        graph_name = "grad_%s_neg_log_loss_error" % self.probability_function
        graph_kwargs = dict(
            target_opset=opset, weight_name=weight_name, eps=self.eps)
        onx = function_onnx_graph(graph_name, **graph_kwargs)
        self.loss_grad_onnx_ = onx
        session = InferenceSession(
            onx.SerializeToString(), options,
            providers=device_to_providers(device))
        self.loss_grad_sess_ = session
        self.loss_grad_sess_bind_ = session.io_binding()._iobinding

        # per-observation scores derived from the graph above
        self.build_onnx_score_function(opset, device, weight_name)