Coverage for onnxcustom/training/sgd_learning_rate.py: 100%

# pylint: disable=W0105
"""
@file
@brief Helper for :epkg:`onnxruntime-training`.
"""
import numpy
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
from onnxruntime import SessionOptions, InferenceSession, RunOptions
from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
    OrtValue as C_OrtValue)
from ..utils.onnx_function import function_onnx_graph
from ..utils.onnxruntime_helper import device_to_providers
from ._base_onnx_function import BaseLearningOnnx


class BaseLearningRate(BaseLearningOnnx):
    """
    Class handling the learning rate update after every
    iteration of a gradient. Two methods need to be overwritten:
    `init_learning_rate` and `update_learning_rate`. The first one
    starts the loop, the second returns the next value.
    """

    def __init__(self):
        BaseLearningOnnx.__init__(self)
        self.ro_ = RunOptions()

    def _call_iobinding(self, sess, bind):
        sess.run_with_iobinding(bind, self.ro_)

    def init_learning_rate(self):
        """
        Initializes the learning rate at the beginning of the training.
        :return: self
        """
        raise NotImplementedError(
            "This method must be overwritten.")

    def update_learning_rate(self, t):
        """
        Updates the learning rate at the end of an iteration.
        :param t: iteration number
        :return: self
        """
        raise NotImplementedError(
            "This method must be overwritten.")

    @property
    def value(self):
        "Returns the current learning rate."
        raise NotImplementedError(
            "This method must be overwritten.")

    def __repr_extended__(self):
        return (
            ', value=%r' % self.value
            if hasattr(self, 'value_') and self.value_ is not None  # pylint: disable=E1101
            else '')

    @property
    def needs_grad(self):
        """
        Returns True if the gradient update needs to retain
        past gradients.
        """
        raise NotImplementedError(
            "This method must be overwritten.")

    def update_weights(self, n_bind, device, statei, gradienti, batch_size,
                       velocity=None):
        """
        Updates weights based on the algorithm this class
        is setting up.

        :param n_bind: index of the io binding (one per tensor)
        :param device: device
        :param statei: current weight
        :param gradienti: gradient
        :param batch_size: batch size
        :param velocity: same shape as the gradient
        """
        raise NotImplementedError(
            "This method must be overwritten.")

    def loop(self, n=1000):
        """
        Loops over learning rate values, *n* to be precise.
        :param n: number of requested iterations
        :return: iterator
        """
        self.init_learning_rate()
        for i in range(n):
            yield self.value
            self.update_learning_rate(i + 1)
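
    # Illustration only (not part of the original module): `loop` is a
    # generator, so it is typically consumed lazily, one learning rate per
    # training iteration, e.g.
    #
    #     lr = BaseLearningRate.select('SGD', eta0=0.1)
    #     for eta in lr.loop(n=100):
    #         ...  # use eta for one weight update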

    @staticmethod
    def select(class_name, **kwargs):
        """
        Returns an instance of a given class initialized with
        *kwargs*.
        :param class_name: an instance of @see cl BaseLearningRate
            or a string among the following class names (see below),
            it can also be a float and in that case, class
            @see cl LearningRateSGD is used
        :return: instance of @see cl BaseLearningRate

        Possible values for *class_name*:
        * `'SGD'` or `'LearningRateSGD'`: see @see cl LearningRateSGD
        * `'Nesterov'`, `'SGDNesterov'` or `'LearningRateSGDNesterov'`:
          see @see cl LearningRateSGDNesterov
        """
        if isinstance(class_name, BaseLearningRate):
            return class_name
        if isinstance(class_name, float):
            return LearningRateSGD(class_name)
        cls = {LearningRateSGD: ['SGD'],
               LearningRateSGDNesterov: ['SGDNesterov', 'Nesterov']}
        for cl, aliases in cls.items():
            if class_name == cl.__name__ or class_name in aliases:
                return cl(**kwargs)
        raise ValueError(  # pragma: no cover
            "Unexpected class name %r. It should be one of %r." % (
                class_name, list(map(lambda c: c.__name__, cls))))
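

# The following helper is an illustrative sketch (not part of the original
# module): it shows the three ways `BaseLearningRate.select` can be called,
# using only the classes defined in this file.
def _example_select():  # pragma: no cover
    "Hypothetical usage example for BaseLearningRate.select (illustration only)."
    by_alias = BaseLearningRate.select('SGD', eta0=0.05)
    by_float = BaseLearningRate.select(0.01)  # shortcut for LearningRateSGD(0.01)
    passthrough = BaseLearningRate.select(by_alias)  # instances are returned as-is
    return by_alias, by_float, passthrough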


class LearningRateSGD(BaseLearningRate):
    """
    Implements the learning rate update the same way as
    :class:`sklearn.linear_model.SGDRegressor`.

    :param eta0: initial learning rate for the `'constant'`, `'invscaling'`
        or `'adaptive'` schedules.
    :param alpha: constant that multiplies the regularization term,
        the higher the value, the stronger the regularization.
        Also used to compute the learning rate when *learning_rate*
        is set to `'optimal'`.
    :param power_t: exponent for inverse scaling learning rate
    :param learning_rate: learning rate schedule:
        * `'constant'`: `eta = eta0`
        * `'optimal'`: `eta = 1.0 / (alpha * (t + t0))` where *t0* is chosen
          by a heuristic proposed by Leon Bottou, this number is multiplied
          by a constant C to make the first value equal to *eta0*
        * `'invscaling'`: `eta = eta0 / pow(t, power_t)`

    Created attributes:
    * `eta0_`: initial eta0
    * `optimal_init_`: used when `learning_rate=='optimal'`
    * `value_`: value to be returned by property `value`
    """

    def __init__(self, eta0=0.01, alpha=0.0001, power_t=0.25,
                 learning_rate='invscaling'):
        BaseLearningRate.__init__(self)
        if learning_rate not in ('invscaling', 'optimal', 'constant'):
            raise ValueError(
                "Unexpected value for learning_rate=%r." % learning_rate)
        self.eta0 = eta0
        self.alpha = alpha
        self.power_t = power_t
        self.learning_rate = learning_rate.lower()
        self.value_ = None

    @property
    def value(self):
        "Returns the current learning rate."
        if self.value_ is None:
            raise RuntimeError(  # pragma: no cover
                "Method init_learning_rate was never called.")
        return self.value_

    @property
    def needs_grad(self):
        """
        Returns True if the gradient update needs to retain
        past gradients.
        """
        return False

    def init_learning_rate(self):
        """
        Initializes the learning rate at the beginning of the training.
        :return: self
        """
        self.eta0_ = self.eta0
        if self.learning_rate == "optimal":
            typw = numpy.sqrt(1.0 / numpy.sqrt(self.alpha))
            eta0 = typw / max(1.0, (1 + typw) * 2)
            self.optimal_init_ = 1.0 / (eta0 * self.alpha)
            eta = 1. / (self.alpha * self.optimal_init_)
            self.optimal_fact_ = self.eta0 / eta
        self.value_ = self.eta0_
        return self

    def update_learning_rate(self, t):
        """
        Updates the learning rate at the end of an iteration.
        :param t: iteration number
        :return: self
        """
        eta = self.value_
        if self.learning_rate == "optimal":
            eta = self.optimal_fact_ / (self.alpha * (self.optimal_init_ + t))
        elif self.learning_rate == "invscaling":
            eta = self.eta0_ / numpy.power(t + 1, self.power_t)
        self.value_ = eta
        return self
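
    # Worked illustration (added for clarity, not in the original module):
    # with eta0=0.1, power_t=0.25 and learning_rate='invscaling', `loop`
    # yields eta0 first, then eta0 / (t + 1) ** 0.25 for t = 1, 2, 3, ...,
    # i.e. roughly 0.1, 0.0841, 0.0760, 0.0707, ...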

    def build_onnx_function(self, opset, device, n_tensors):
        so = SessionOptions()
        so.log_severity_level = 4

        self.axpy_onnx_ = function_onnx_graph("axpy")
        self.axpy_sess_ = InferenceSession(
            self.axpy_onnx_.SerializeToString(), so,
            providers=device_to_providers(device))
        self.axpy_sess_binds_ = [
            self.axpy_sess_.io_binding()._iobinding
            for i in range(n_tensors)]
        self.alpha_ = numpy.array(
            [0], dtype=TENSOR_TYPE_TO_NP_TYPE[
                self.axpy_onnx_.graph.input[0].type.tensor_type.elem_type])

    def update_weights(self, n_bind, device, statei, gradienti, batch_size,
                       velocity=None):
        if velocity is not None:
            raise RuntimeError(  # pragma: no cover
                "Velocity must be None for this way of updating weights.")
        if (not hasattr(self, "axpy_onnx_") or
                not hasattr(self, "axpy_sess_binds_")):
            raise RuntimeError(  # pragma: no cover
                "Attributes 'axpy_sess_binds_' or "
                "'axpy_onnx_' are missing. Method "
                "'build_onnx_function' has not been called.")
        bind = self.axpy_sess_binds_[n_bind]
        self._bind_input_ortvalue("X1", bind, gradienti, device, cache=True)
        self._bind_input_ortvalue("X2", bind, statei, device, cache=True)
        self.alpha_[0] = - self.value / batch_size  # pylint: disable=E1130
        ort_alpha = C_OrtValue.ortvalue_from_numpy(self.alpha_, device)
        self._bind_input_ortvalue("alpha", bind, ort_alpha, device, cache=True)
        self._bind_output_ortvalue('Y', bind, statei, cache=True)
        self._call_iobinding(self.axpy_sess_._sess, bind)
        new_weights = bind.get_outputs()[0]
        return new_weights
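

# Illustrative NumPy equivalent (added for clarity, not part of the original
# module) of what LearningRateSGD.update_weights computes through the ONNX
# 'axpy' graph: assuming that graph implements Y = alpha * X1 + X2 with
# alpha = -value / batch_size, the update reduces to a plain SGD step.
def _numpy_sgd_update(weights, gradient, value, batch_size):  # pragma: no cover
    "Hypothetical helper mirroring the SGD weight update (illustration only)."
    alpha = -value / batch_size
    return weights + alpha * gradient  # w <- w - (eta / batch_size) * grad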


class LearningRateSGDNesterov(LearningRateSGD):
    """
    Implements the learning rate update the same way as
    :class:`sklearn.linear_model.SGDRegressor`.

    :param eta0: initial learning rate for the `'constant'`, `'invscaling'`
        or `'adaptive'` schedules.
    :param alpha: constant that multiplies the regularization term,
        the higher the value, the stronger the regularization.
        Also used to compute the learning rate when *learning_rate*
        is set to `'optimal'`.
    :param power_t: exponent for inverse scaling learning rate
    :param learning_rate: learning rate schedule:
        * `'constant'`: `eta = eta0`
        * `'optimal'`: `eta = 1.0 / (alpha * (t + t0))` where *t0* is chosen
          by a heuristic proposed by Leon Bottou, this number is multiplied
          by a constant C to make the first value equal to *eta0*
        * `'invscaling'`: `eta = eta0 / pow(t, power_t)`
    :param momentum: float, default=0.9
        Value of momentum used, must be larger than or equal to 0.
    :param nesterov: bool, default=True
        Whether to use Nesterov's momentum or not.
        Not using it is equivalent to class @see cl LearningRateSGD.

    Created attributes:
    * `eta0_`: initial eta0
    * `optimal_init_`: used when `learning_rate=='optimal'`
    * `value_`: value to be returned by property `value`

    ::

        updates = [
            self.momentum * velocity - self.learning_rate * grad
            for velocity, grad in zip(self.velocities, grads)]
        self.velocities = updates

        if self.nesterov:
            updates_nesterov = [
                self.momentum * velocity - self.learning_rate * grad
                for velocity, grad in zip(self.velocities, grads)]
            return updates, updates_nesterov  # --> new gradient and velocities
        else:
            return updates  # --> new gradient
    """

    def __init__(self, eta0=0.01, alpha=0.0001, power_t=0.25,
                 learning_rate='invscaling', momentum=0.9, nesterov=True):
        LearningRateSGD.__init__(
            self, eta0=eta0, alpha=alpha, power_t=power_t,
            learning_rate=learning_rate)
        self.momentum = momentum
        self.nesterov = nesterov

    @property
    def needs_grad(self):
        """
        Returns True if the gradient update needs to retain
        past gradients.
        """
        return True

    def init_learning_rate(self):
        """
        Initializes the learning rate at the beginning of the training.
        :return: self
        """
        return LearningRateSGD.init_learning_rate(self)

    def update_learning_rate(self, t):
        """
        Updates the learning rate at the end of an iteration.
        :param t: iteration number
        :return: self
        """
        return LearningRateSGD.update_learning_rate(self, t)

    def build_onnx_function(self, opset, device, n_tensors):
        so = SessionOptions()
        so.log_severity_level = 4

        # axpyw
        if self.nesterov:
            self.axpyw_onnx_ = function_onnx_graph("axpyw2")
        else:
            self.axpyw_onnx_ = function_onnx_graph("axpyw")
        self.axpyw_sess_ = InferenceSession(
            self.axpyw_onnx_.SerializeToString(), so,
            providers=device_to_providers(device))
        self.axpyw_sess_binds_ = [
            self.axpyw_sess_.io_binding()._iobinding
            for n in range(n_tensors)]

        self.alpha_ = numpy.array(
            [0], dtype=TENSOR_TYPE_TO_NP_TYPE[
                self.axpyw_onnx_.graph.input[0].type.tensor_type.elem_type])
        self.beta_ = numpy.array(
            [0], dtype=TENSOR_TYPE_TO_NP_TYPE[
                self.axpyw_onnx_.graph.input[0].type.tensor_type.elem_type])

    def update_weights(self, n_bind, device, statei, gradienti, batch_size,
                       velocity=None):
        if (not hasattr(self, "axpyw_onnx_") or
                not hasattr(self, "axpyw_sess_binds_")):
            raise RuntimeError(  # pragma: no cover
                "Attributes 'axpyw_sess_binds_' or "
                "'axpyw_onnx_' are missing. Method "
                "'build_onnx_function' has not been called.")
        if velocity is None:
            raise RuntimeError(  # pragma: no cover
                "Velocity must not be None for this way of updating weights.")
        bind = self.axpyw_sess_binds_[n_bind]
        self._bind_input_ortvalue("X1", bind, gradienti, device, cache=True)
        self._bind_input_ortvalue("X2", bind, statei, device, cache=True)
        self._bind_input_ortvalue("G", bind, velocity, device, cache=True)
        self.alpha_[0] = - self.value / batch_size  # pylint: disable=E1130
        self.beta_[0] = self.momentum
        ort_alpha = C_OrtValue.ortvalue_from_numpy(self.alpha_, device)
        ort_beta = C_OrtValue.ortvalue_from_numpy(self.beta_, device)
        self._bind_input_ortvalue("alpha", bind, ort_alpha, device, cache=True)
        self._bind_input_ortvalue("beta", bind, ort_beta, device, cache=True)
        self._bind_output_ortvalue('Y', bind, statei, cache=True)
        self._bind_output_ortvalue('Z', bind, velocity, cache=True)
        self._call_iobinding(self.axpyw_sess_._sess, bind)
        return bind.get_outputs()  # new weights, new velocity
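

# Illustrative NumPy sketch (added for clarity, not part of the original
# module) of the momentum update described in the LearningRateSGDNesterov
# docstring. The coefficient names mirror update_weights above
# (alpha = -eta / batch_size, beta = momentum); the exact ONNX graphs
# ('axpyw', 'axpyw2') are assumed to implement the same arithmetic.
def _numpy_nesterov_update(weights, gradient, velocity, value, batch_size,
                           momentum, nesterov=True):  # pragma: no cover
    "Hypothetical helper mirroring the (Nesterov) momentum weight update."
    alpha = -value / batch_size
    new_velocity = momentum * velocity + alpha * gradient
    if nesterov:
        # look-ahead step: apply the momentum term once more on top of
        # the freshly updated velocity
        new_weights = weights + alpha * gradient + momentum * new_velocity
    else:
        new_weights = weights + new_velocity
    return new_weights, new_velocity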