Coverage for onnxcustom/training/sgd_learning_penalty.py: 100%

# pylint: disable=W0105
"""
@file
@brief Helper for :epkg:`onnxruntime-training`.
"""
from onnxruntime import SessionOptions, InferenceSession, RunOptions
from ..utils.onnx_function import function_onnx_graph
from ..utils.onnxruntime_helper import device_to_providers
from ._base_onnx_function import BaseLearningOnnx


class BaseLearningPenalty(BaseLearningOnnx):
    """
    Class handling the penalty on the coefficients for class
    @see cl OrtGradientForwardBackwardOptimizer.
    """

    def __init__(self):
        BaseLearningOnnx.__init__(self)
        self.ro_ = RunOptions()

    def _call_iobinding(self, sess, bind):
        sess.run_with_iobinding(bind, self.ro_)

    @staticmethod
    def select(class_name, **kwargs):
        """
        Returns an instance of a given class initialized with
        *kwargs*.

        :param class_name: an instance of @see cl BaseLearningPenalty
            or a string among the following class names (see below)
        :return: instance of @see cl BaseLearningPenalty

        Possible values for *class_name*:

        * None or `''`: see @see cl NoLearningPenalty
        * `'elastic'` or `'l1l2'`: see @see cl ElasticLearningPenalty
        """
        if isinstance(class_name, BaseLearningPenalty):
            return class_name
        cls = {NoLearningPenalty: [None, ''],
               ElasticLearningPenalty: ['elastic', 'l1l2']}
        for cl, aliases in cls.items():
            if class_name == cl.__name__ or class_name in aliases:
                return cl(**kwargs)
        raise ValueError(  # pragma: no cover
            "Unexpected class name %r. It should be one of %r." % (
                class_name, list(map(lambda c: c.__name__, cls))))
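
    # Usage sketch (illustrative values):
    #   BaseLearningPenalty.select(None)      -> NoLearningPenalty()
    #   BaseLearningPenalty.select("l1l2", l1=0.01, l2=0.01)
    #                                         -> ElasticLearningPenalty(l1=0.01, l2=0.01)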

    def penalty_loss(self, device, loss, *weights):
        """
        Adds the penalty on the weights to the loss.
        Updates the loss inplace.

        :param device: device where the training takes place
        :param loss: loss without penalty
        :param weights: any weights to be penalized
        :return: penalized loss
        """
        raise NotImplementedError(
            "penalty_loss must be overwritten.")

    def update_weights(self, n_bind, device, statei):
        """
        Applies the penalty to one weight.
        Updates the weight inplace.

        :param n_bind: index of the io_binding to use
        :param device: device where the training takes place
        :param statei: weight tensor to update
        :return: updated weight
        """
        raise NotImplementedError(
            "update_weights must be overwritten.")


class NoLearningPenalty(BaseLearningPenalty):
    """
    No regularization.
    """

    def __init__(self):
        BaseLearningPenalty.__init__(self)

    def build_onnx_function(self, opset, device, n_tensors):
        # Nothing to do.
        pass

    def penalty_loss(self, device, loss, *weights):
        """
        Returns the received loss unchanged since there is no penalty.

        :param device: device where the training takes place
        :param loss: loss without penalty
        :param weights: any weights to be penalized
        :return: loss
        """
        return loss

    def update_weights(self, n_bind, device, statei):
        """
        Returns the received weight unchanged since there is no penalty.

        :param n_bind: index of the io_binding to use
        :param device: device where the training takes place
        :param statei: weight tensor
        :return: weight
        """
        return statei


class ElasticLearningPenalty(BaseLearningPenalty):
    """
    Implements an elastic net regularization,
    a combination of L1 and L2 penalties on the weights.
    """

    def __init__(self, l1=0.5, l2=0.5):
        BaseLearningPenalty.__init__(self)
        self.l1 = l1
        self.l2 = l2

    def build_onnx_function(self, opset, device, n_tensors):
        so = SessionOptions()
        so.log_severity_level = 4

        # penalized loss
        self.penalty_onnx_ = function_onnx_graph(
            "n_penalty_elastic_error", target_opset=opset, n_tensors=n_tensors,
            loss_shape=None, l1_weight=self.l1, l2_weight=self.l2)
        self.penalty_sess_ = InferenceSession(
            self.penalty_onnx_.SerializeToString(), so,
            providers=device_to_providers(device))
        self.penalty_sess_bind_ = (
            self.penalty_sess_.io_binding()._iobinding)
        self.names_ = [i.name for i in self.penalty_onnx_.graph.input]

        # weight updates
        self.penalty_grad_onnx_ = function_onnx_graph(
            "update_penalty_elastic_error", target_opset=opset,
            l1=self.l1, l2=self.l2)
        self.penalty_grad_sess_ = InferenceSession(
            self.penalty_grad_onnx_.SerializeToString(), so,
            providers=device_to_providers(device))
        self.penalty_grad_sess_binds_ = [
            self.penalty_grad_sess_.io_binding()._iobinding
            for n in range(n_tensors)]
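
    # Note: the two graphs above are assumed (not verified here) to implement
    # the elastic-net penalty, i.e. the first one returns
    # loss + sum_i (l1 * |w_i| + l2 * w_i**2) over all weight tensors and
    # the second one applies the matching shrinkage to a single weight.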

    def penalty_loss(self, device, *inputs):
        """
        Computes the penalty associated to every
        weight and adds it up to the loss.

        :param device: device where the training takes place
        :param inputs: loss without penalty and weights
        :return: loss + penalties
        """
        if (not hasattr(self, "penalty_onnx_") or
                not hasattr(self, "penalty_sess_bind_")):
            raise RuntimeError(  # pragma: no cover
                "Attributes 'penalty_sess_bind_' or 'penalty_onnx_' are "
                "missing. Method 'build_onnx_function' has not been called.")
        if len(self.names_) != len(inputs):
            raise RuntimeError(  # pragma: no cover
                "Mismatched number of inputs: %d != %d." % (
                    len(self.names_), len(inputs)))
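
        # Bind the loss and every weight as inputs; the output 'Y' is bound
        # to the same buffer as the incoming loss so the penalized loss
        # overwrites it in place.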
        for name, inp in zip(self.names_, inputs):
            self._bind_input_ortvalue(
                name, self.penalty_sess_bind_, inp, device, cache=True)
        self._bind_output_ortvalue(
            'Y', self.penalty_sess_bind_, inputs[0], cache=True)
        self._call_iobinding(self.penalty_sess_._sess, self.penalty_sess_bind_)
        return self.penalty_sess_bind_.get_outputs()[0]

    def update_weights(self, n_bind, device, statei):
        """
        Applies the penalty to one weight. Updates it inplace.

        :param n_bind: index of the io_binding to use
        :param device: device where the training takes place
        :param statei: weight tensor to update
        :return: updated weight
        """
        if (not hasattr(self, "penalty_grad_onnx_") or
                not hasattr(self, "penalty_grad_sess_binds_")):
            raise RuntimeError(  # pragma: no cover
                "Attributes 'penalty_grad_sess_binds_' or "
                "'penalty_grad_onnx_' are missing. Method "
                "'build_onnx_function' has not been called.")
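
        # Reuse the io_binding reserved for this weight; 'Y' is bound to the
        # same buffer as the state so the update happens in place.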
        bind = self.penalty_grad_sess_binds_[n_bind]
        self._bind_input_ortvalue("X", bind, statei, device, cache=True)
        self._bind_output_ortvalue('Y', bind, statei, cache=True)
        self._call_iobinding(self.penalty_grad_sess_._sess, bind)
        return bind.get_outputs()[0]  # X
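
For reference, here is a minimal NumPy sketch of the elastic-net penalty that
the graph built by "n_penalty_elastic_error" is assumed to compute. The helper
name `elastic_penalty` is illustrative only; the exact formula generated by
`function_onnx_graph` was not verified here.

import numpy

def elastic_penalty(loss, weights, l1=0.5, l2=0.5):
    # loss + sum over all weight tensors of the L1 and squared-L2 penalties,
    # matching ElasticLearningPenalty(l1, l2) under the stated assumption
    penalty = sum(l1 * numpy.abs(w).sum() + l2 * (w ** 2).sum()
                  for w in weights)
    return loss + penalty

loss = numpy.array([0.5], dtype=numpy.float32)
weights = [numpy.array([1., -2.], dtype=numpy.float32)]
print(elastic_penalty(loss, weights, l1=0.01, l2=0.01))  # -> [0.58]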