Coverage for onnxcustom/training/optimizers_partial.py: 98%
1"""
2@file
3@brief Optimizer with :epkg:`onnxruntime-training` forward backward training.
4"""
5import logging
6import warnings
7import numpy
8from onnxruntime import InferenceSession, SessionOptions
9from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
10 OrtValue as C_OrtValue)
11from ..utils.onnx_helper import get_onnx_opset, proto_type_to_dtype
12from ..utils.onnxruntime_helper import (
13 device_to_providers, numpy_to_ort_value)
14from ..utils.onnx_function import function_onnx_graph
15from ..utils.print_helper import str_ortvalue
16from ..utils.orttraining_helper import get_train_initializer
17from .ortgradient import OrtGradientForwardBackward
18from ._base_estimator import BaseEstimator
19from .sgd_learning_loss import BaseLearningLoss
20from .sgd_learning_penalty import BaseLearningPenalty
21from .data_loader import OrtDataLoader
22from .excs import ConvergenceError, ConvergenceWarning


class OrtGradientForwardBackwardOptimizer(BaseEstimator):
    """
    Implements a simple :epkg:`Stochastic Gradient Descent`
    with :epkg:`onnxruntime-training`. It leverages class
    @see cl OrtGradientForwardBackward.

    :param model_onnx: onnx graph to train
    :param weights_to_train: names of initializers to be optimized,
        if None, function @see fn get_train_initializer returns
        the list of float initializers
    :param loss_output_name: name of the loss output
    :param max_iter: number of training iterations
    :param training_optimizer_name: optimizing algorithm
    :param batch_size: batch size (see class *DataLoader*)
    :param learning_rate: a name or a learning rate instance or a float,
        see module :mod:`onnxcustom.training.sgd_learning_rate`
    :param device: device as :epkg:`C_OrtDevice` or a string
        representing this device
    :param warm_start: when set to True, reuse the solution of the previous
        call to fit as initialization, otherwise, just erase the previous
        solution.
    :param learning_loss: loss function (see below)
    :param verbose: use :epkg:`tqdm` to display the training progress
    :param validation_every: validation with a test set every
        *validation_every* iterations
    :param enable_logging: enable logging (mostly for debugging purposes
        as it slows down the training)
    :param weight_name: if not None, the class assumes it is trained
        with training weights
    :param learning_penalty: weight penalty, None, or instance of
        @see cl BaseLearningPenalty
    :param exc: raise exceptions (about convergence for example)
        or keep them silent as much as possible

    *learning_rate* can be any instance of @see cl BaseLearningRate or
    a nickname in the following list as specified in
    :meth:`BaseLearningRate.select
    <onnxcustom.training.sgd_learning_rate.BaseLearningRate.select>`.

    *learning_loss* can be any instance of @see cl BaseLearningLoss or
    a nickname in the following list as specified in
    :meth:`BaseLearningLoss.select
    <onnxcustom.training.sgd_learning_loss.BaseLearningLoss.select>`.
    """

    def __init__(self, model_onnx, weights_to_train=None,
                 loss_output_name='loss', max_iter=100,
                 training_optimizer_name='SGDOptimizer',
                 batch_size=10, learning_rate='SGD',
                 device='cpu', warm_start=False, verbose=0,
                 validation_every=0.1, learning_loss="square_error",
                 enable_logging=False, weight_name=None,
                 learning_penalty=None, exc=True):
        if weights_to_train is None:
            weights_to_train = list(get_train_initializer(model_onnx))
        BaseEstimator.__init__(self, model_onnx, learning_rate, device)
        self.batch_size = batch_size
        self.weights_to_train = weights_to_train
        self.loss_output_name = loss_output_name
        self.training_optimizer_name = training_optimizer_name
        self.verbose = verbose
        self.max_iter = max_iter
        self.warm_start = warm_start
        self.learning_loss = BaseLearningLoss.select(learning_loss)
        self.learning_penalty = BaseLearningPenalty.select(learning_penalty)
        self.enable_logging = enable_logging
        self.weight_name = weight_name
        self.exc = exc
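        # ``validation_every`` may be given as a fraction of ``max_iter``
        # (for instance 0.1 with max_iter=100 validates every 10 iterations)
        # or as an absolute number of iterations when >= 1.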
        if validation_every < 1:
            self.validation_every = int(self.max_iter * validation_every)
        else:
            self.validation_every = validation_every  # pragma: no cover
        self.build_onnx_function()

    @property
    def needs_grad(self):
        """
        Returns True if the gradient update needs to retain
        past gradients.
        """
        return self.learning_rate.needs_grad

    def __getstate__(self):
        "Removes any non-picklable attribute."
        state = BaseEstimator.__getstate__(self)
        for att in ['train_state_', 'train_grad_state_']:
            if hasattr(self, att):
                train_state = []
                for v in self.get_state():
                    if v is None:
                        train_state.append(v)
                    else:
                        train_state.append(v.numpy())
                state[att[:-1]] = train_state
        return state

    def __setstate__(self, state):
        "Restores any non-picklable attribute."
        popped = {}
        for att in ['train_state', 'train_grad_state']:
            if att in state:
                popped[att] = state.pop(att)
        BaseEstimator.__setstate__(self, state)
        for k, v in popped.items():
            if k == 'train_state':
                self.set_state(v, check_trained=False, kind='weight')
            elif k == 'train_grad_state':
                self.set_state(v, check_trained=False, kind='grad')
            else:
                raise ValueError(  # pragma: no cover
                    "Unexpected key state %r." % k)
        self.build_onnx_function()
        return self
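
    # Because __getstate__/__setstate__ convert the OrtValue buffers to numpy
    # arrays and back, a fitted estimator can be pickled. A hedged sketch,
    # assuming ``trainer`` is a fitted instance:
    #
    #   import pickle
    #   restored = pickle.loads(pickle.dumps(trainer))
    #   # restored.trained_coef_ holds the same weights as trainer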

    def _get_att_state(self, kind):
        if kind == 'weight':
            return 'train_state_'
        if kind == 'grad':
            return 'train_grad_state_'
        raise ValueError(  # pragma: no cover
            "Unexpected kind=%r." % kind)

    def get_full_state(self, kind='weight'):
        """
        Returns the trained weights and the inputs.
        """
        if isinstance(kind, list):
            return [self.get_full_state(kind=k) for k in kind]
        att = self._get_att_state(kind)
        if not hasattr(self, att):
            raise AttributeError(  # pragma: no cover
                "Method fit must be called before.")
        return getattr(self, att)

    def get_state(self, kind='weight'):
        """
        Returns the trained weights.
        """
        att = self._get_att_state(kind)
        if not hasattr(self, att):
            raise AttributeError("Method fit must be called before.")
        if getattr(self, att, None) is None:
            raise RuntimeError(  # pragma: no cover
                "No attribute %r available (None)." % att)
        if self.weights_to_train is None:
            raise RuntimeError(  # pragma: no cover
                "Unexpected self.weights_to_train (None).")
        value = getattr(self, att)
        n = len(value) - len(self.weights_to_train)
        return value[n:]

    @property
    def trained_coef_(self):
        """
        Returns the trained coefficients as a dictionary.
        """
        return dict(zip(self.weights_to_train, self.get_state()))
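
    # ``trained_coef_`` maps initializer names to OrtValue objects; call
    # ``.numpy()`` on each value to inspect it, for instance
    # ``{k: v.numpy() for k, v in trainer.trained_coef_.items()}``
    # (``trainer`` being an assumed fitted instance).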

    def get_trained_onnx(self, model=None):
        """
        Returns the trained onnx graph, the initial graph
        modified by replacing the initializers with the trained
        weights.

        :param model: replace the weights in another graph
            than the training graph
        :return: onnx graph
        """
        state = dict(zip(self.weights_to_train, self.get_state()))
        return self._get_trained_onnx(state, model=model)
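
    # Typical follow-up, a hedged sketch assuming ``trainer`` is a fitted
    # instance and the graph has a single float input named 'X' (the actual
    # input name depends on the trained graph):
    #
    #   from onnxruntime import InferenceSession
    #   sess = InferenceSession(
    #       trainer.get_trained_onnx().SerializeToString(),
    #       providers=["CPUExecutionProvider"])
    #   predictions = sess.run(None, {'X': X})[0]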

    def set_state(self, state, check_trained=True, kind='weight', zero=False):
        """
        Changes the trained weights.
        """
        if check_trained and not hasattr(self, 'train_session_'):
            raise AttributeError(  # pragma: no cover
                "Method fit must be called before.")
        state_ = []
        state_numpy_ = []
        for i, v in enumerate(state):
            if v is None:
                state_.append(None)
                state_numpy_.append(None)
            elif isinstance(v, numpy.ndarray):
                if zero:
                    v = numpy.zeros(v.shape, dtype=v.dtype)
                ortvalue = numpy_to_ort_value(v, self.device)
                state_.append(ortvalue)
                # The numpy container must be retained as the ortvalue
                # just borrows the pointer.
                state_numpy_.append(v)
            elif isinstance(v, C_OrtValue):
                if zero:
                    v = self.zero_sess_.run_with_ort_values(['Y'], {'X': v})
                state_.append(v)
                state_numpy_.append(None)
            else:
                raise TypeError(  # pragma: no cover
                    "Unexpected type %r for state %r." % (
                        type(v), i))
        att = self._get_att_state(kind)
        setattr(self, att, state_)
        setattr(self, att + "numpy_", state_numpy_)

    def build_onnx_function(self):
        """
        Creates ONNX graph and *InferenceSession* related to
        any operations applying on *OrtValue*.
        """
        opset = get_onnx_opset(self.model_onnx)
        so = SessionOptions()
        so.log_severity_level = 4

        n = len(self.weights_to_train)

        # loss_grad
        self.learning_loss.build_onnx_function(
            opset, self.device, self.weight_name)

        # weight update
        self.learning_rate.build_onnx_function(opset, self.device, n)

        # regularization
        self.learning_penalty.build_onnx_function(opset, self.device, n)

        # zero
        self.zero_onnx_ = function_onnx_graph("zero")
        self.zero_sess_ = InferenceSession(
            self.zero_onnx_.SerializeToString(), so,
            providers=device_to_providers(self.device))
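        # The "zero" helper graph returns a null tensor shaped like its input;
        # set_state(..., zero=True) runs it to build zero-filled OrtValue
        # buffers (e.g. gradient accumulators) directly on the device.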

        # logging
        if self.enable_logging:
            self._logger = logging.getLogger("onnxcustom")
        else:
            self._logger = None

    def fit(self, X, y, sample_weight=None,
            X_val=None, y_val=None):
        """
        Trains the model.

        :param X: features
        :param y: expected output
        :param sample_weight: training weight or None
        :param X_val: evaluation dataset
        :param y_val: evaluation dataset
        :return: self
        """
        if self.training_optimizer_name != 'SGDOptimizer':
            raise NotImplementedError(
                "Only the SGDOptimizer is implemented not %r."
                "" % self.training_optimizer_name)
        logger = self._logger

        session_function = self._create_training_session(
            self.model_onnx, self.weights_to_train,
            device=self.device)
        self.train_session_ = session_function[0]
        self.train_function_ = session_function[1]

        self.input_names_ = self.train_session_.cls_type_._grad_input_names
        self.output_names_ = self.train_session_.cls_type_._bw_fetches_names
        weights_to_train = self.train_session_.weights_to_train

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "input_names=%r", self.input_names_)
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "output_names=%r", self.output_names_)
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "weights_to_train=%r", self.weights_to_train)
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "device=%r|%r",
                self.device.device_id(), self.device.device_type())
        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "warm_start=%r", self.warm_start)

        if not hasattr(self, 'state_'):
            self.set_state([
                self.train_session_.get_initializer(name, exc=False)
                for name in self.input_names_])
        if self.needs_grad and not hasattr(self, 'state_grad_'):
            self.set_state([
                self.train_session_.get_initializer(name, exc=False)
                for name in self.input_names_],
                kind='grad', zero=True)
        if not self.warm_start:
            state = self.get_full_state()
            if len(state) != len(self.input_names_):
                raise RuntimeError(  # pragma: no cover
                    "Length mismatch %r != %r." % (
                        len(state), len(self.input_names_)))
            new_state = []
            for iv, v in enumerate(state):
                if v is None:
                    new_state.append(v)
                else:
                    if not isinstance(v, C_OrtValue):
                        raise RuntimeError(  # pragma: no cover
                            "Unexpected type %r (state[%d])." % (
                                type(v), iv))
                    dtype = proto_type_to_dtype(
                        v.proto_type()
                        if hasattr(v, 'proto_type')
                        else v.data_type())
                    if len(v.shape()) > 0:
                        new_state.append(
                            numpy.random.randn(*v.shape()).astype(dtype))
                    else:
                        new_state.append(
                            numpy.random.randn(1).astype(dtype))
            self.set_state(new_state)
            if self.needs_grad:
                self.set_state(new_state, kind='grad', zero=True)

        data_loader = OrtDataLoader(
            X, y, sample_weight, batch_size=self.batch_size,
            device=self.device)
        if X_val is not None:
            data_loader_val = OrtDataLoader(
                X_val, y_val, batch_size=X_val.shape[0], device=self.device,
                random_iter=False)
        else:
            data_loader_val = None

        self.learning_rate.init_learning_rate()

        if self.verbose > 0:  # pragma: no cover
            from tqdm import tqdm  # pylint: disable=C0415
            loop = tqdm(range(self.max_iter))
        else:
            loop = range(self.max_iter)

        self.train_losses_ = []
        val_losses = []
        kinds = ['weight', 'grad'] if self.needs_grad else ['weight']
        for it in loop:
            loss = self._iteration(
                data_loader, self.get_full_state(kind=kinds),
                len(weights_to_train))
            lr = self.learning_rate.update_learning_rate(it).value
            if self.verbose > 1:  # pragma: no cover
                loop.set_description(
                    "loss=%1.3g lr=%1.3g" % (  # pylint: disable=E1101,E1307
                        loss, lr))  # pylint: disable=E1101,E1307
            if logger is not None:
                logger.info(
                    "[OrtGradientForwardBackwardOptimizer.fit] "
                    "lr value=%r", lr)

            self.train_losses_.append(loss)
            if (data_loader_val is not None and
                    (it + 1) % self.validation_every == 0):
                val_losses.append(
                    self._evaluation(data_loader_val, self.get_full_state()))
        self.validation_losses_ = (
            None if data_loader_val is None else val_losses)

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "end loss=%r", self.train_losses_[-1])
        return self
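
    # Hedged example of a call to fit with a validation set; ``X``, ``y``,
    # ``X_val``, ``y_val`` are assumed float32 numpy arrays:
    #
    #   trainer.fit(X, y, X_val=X_val, y_val=y_val)
    #   print(trainer.train_losses_[-1])   # mean loss of the last iteration
    #   print(trainer.validation_losses_)  # one entry every validation_every
    #                                      # iterations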

    def _iteration(self, data_loader, states, n_weights):
        actual_losses = []
        bs = data_loader.batch_size
        logger = self._logger
        if len(states) == 1:
            state = states[0]
            grad = None
        else:
            state, grad = states
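
        # ``state`` follows the order of the gradient input names: state[0]
        # receives the current batch while the last ``n_weights`` entries are
        # the trained initializers, updated in place after every batch.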

        if logger is not None:
            logger.debug(
                "[OrtGradientForwardBackwardOptimizer._iteration] "
                "iteration begin learning_rate=%r",
                self.learning_rate)

        prediction_cache = None
        prediction_cache_shape = None
        backward_outputs_cache = None
        for ib, ito in enumerate(data_loader.iter_ortvalue()):
            if len(ito) == 2:
                (ortx, orty) = ito
                ortw = None
            else:
                (ortx, orty, ortw) = ito
            state[0] = ortx

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "batch %d", ib)

            ortx_shape = tuple(ortx.shape())
            same_shape = (
                prediction_cache_shape is not None and
                ortx_shape == prediction_cache_shape)

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] forward")

            # forward
            if prediction_cache_shape is None or same_shape:
                prediction_cache = None
                prediction_cache_shape = None
            prediction = self.train_function_.forward(
                states[0], training=True,
                forward_outputs_cache=prediction_cache)
            prediction_cache = prediction
            prediction_cache_shape = ortx_shape

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "loss types=%r,%r",
                    orty.data_type(), prediction[0].data_type())

            # loss
            loss, loss_gradient = self.learning_loss.loss_gradient(
                self.device, orty, prediction[0], weight=ortw)

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "loss=%g has_weight=%r",
                    loss.numpy(), ortw is not None)

            n = len(state) - n_weights
            loss = self.learning_penalty.penalty_loss(
                self.device, loss, *state[n:])

            cpu_loss = loss.numpy()

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "cpu_loss=%r", cpu_loss)

            if numpy.isinf(cpu_loss) or numpy.isnan(cpu_loss):
                if self.exc:
                    raise ConvergenceError(
                        "Loss is nan, learning_rate=%r, "
                        "the gradient descent has failed "
                        "(past losses=%r)." % (
                            self.learning_rate,
                            [float(v) for v in (
                                actual_losses if len(actual_losses) < 5
                                else actual_losses[-5:])]))
                warnings.warn(  # pragma: no cover
                    "Loss is nan, learning_rate=%r, "
                    "the gradient descent has failed "
                    "(past losses=%r)." % (
                        self.learning_rate,
                        [float(v) for v in (
                            actual_losses if len(actual_losses) < 5
                            else actual_losses[-5:])]),
                    ConvergenceWarning)
                if numpy.isinf(cpu_loss):  # pragma: no cover
                    cpu_loss = numpy.nan

            # backward
            if not same_shape:
                backward_outputs_cache = None
            gradient = self.train_function_.backward(
                [loss_gradient], backward_outputs_cache=backward_outputs_cache)
            backward_outputs_cache = gradient

            if len(gradient) != len(state):
                raise RuntimeError(  # pragma: no cover
                    "gradient and state should have the same length but "
                    "%r != %r." % (len(gradient), len(state)))

            n = len(state) - n_weights

            for i in range(n, len(state)):
                self.learning_penalty.update_weights(
                    i - n, self.device, state[i])
                self.learning_rate.update_weights(
                    i - n, self.device, state[i],
                    gradient[i], bs,
                    None if grad is None else grad[i])

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "loss=%g n_weights=%d", cpu_loss, n)
                for i in range(n, len(state)):
                    logger.debug(
                        "[OrtGradientForwardBackwardOptimizer._iteration] "
                        "state[%i]=%s", i, str_ortvalue(state[i]))

            actual_losses.append(cpu_loss / bs)

        if logger is not None:
            logger.debug(
                "[OrtGradientForwardBackwardOptimizer._iteration] "
                "iteration end")

        return numpy.array(actual_losses).mean()

    def _evaluation(self, data_loader, state):
        logger = self._logger
        actual_losses = []
        for ib, (ortx, orty) in enumerate(data_loader.iter_ortvalue()):
            state[0] = ortx

            if logger is not None:
                logger.debug(  # pragma: no cover
                    "[OrtGradientForwardBackwardOptimizer._evaluation] "
                    "batch %d", ib)

            prediction = self.train_function_.forward(state, training=False)
            loss, _ = self.learning_loss.loss_gradient(
                self.device, orty, prediction[0])
            cpu_loss = loss.numpy()
            if numpy.isinf(cpu_loss) or numpy.isnan(cpu_loss):
                if self.exc:  # pragma: no cover
                    raise ConvergenceError(
                        "Loss is nan, "
                        "the evaluation has failed "
                        "(past losses=%r)." %
                        [float(v) for v in (
                            actual_losses if len(actual_losses) < 5
                            else actual_losses[-5:])])
                warnings.warn(
                    "Loss is nan, learning_rate=%r, "
                    "the gradient descent has failed "
                    "(past losses=%r)." % (
                        self.learning_rate,
                        [float(v) for v in (
                            actual_losses if len(actual_losses) < 5
                            else actual_losses[-5:])]),
                    ConvergenceWarning)
                if numpy.isinf(cpu_loss):
                    cpu_loss = numpy.nan
            actual_losses.append(cpu_loss)

        return numpy.array(actual_losses).sum() / len(data_loader)

    def score(self, X, y, sample_weight=None):
        """
        Returns the score for the whole dataset, the negative of the
        mean loss.

        :param X: features
        :param y: expected output
        :param sample_weight: training weight or None
        :return: score
        """
        scores = self.losses(X, y, sample_weight=sample_weight)
        return -scores.sum() / X.shape[0]

    def losses(self, X, y, sample_weight=None):
        """
        Returns the losses associated to every observation.

        :param X: features
        :param y: expected output
        :param sample_weight: training weight or None
        :return: scores
        """
        data_loader = OrtDataLoader(
            X, y, sample_weight, batch_size=self.batch_size,
            device=self.device)

        state = self.get_full_state()
        scores = numpy.empty((X.shape[0], ), dtype=X.dtype)
        pos = 0
        for ito in data_loader.iter_ortvalue():
            if len(ito) == 2:
                (ortx, orty) = ito
                ortw = None
            else:
                (ortx, orty, ortw) = ito
            state[0] = ortx
            prediction = self.train_function_.forward(state, training=False)
            score = self.learning_loss.loss_scores(
                self.device, orty, prediction[0], ortw)
            np_score = score.numpy()
            # data copy could be avoided by giving a pointer to
            # loss score or if we could create an OrtValue from a
            # pointer.
            end = pos + np_score.shape[0]
            if end <= scores.shape[0]:
                scores[pos: end] = np_score.ravel()
            else:
                scores[pos: end] = np_score.ravel()[end - scores.shape[0]:]
            pos += np_score.shape[0]
        return scores
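
    # ``score`` and ``losses`` follow the scikit-learn convention of
    # "greater is better". A hedged sketch, assuming a fitted ``trainer``
    # and arrays ``X``, ``y``:
    #
    #   per_obs = trainer.losses(X, y)  # one loss value per observation
    #   print(trainer.score(X, y))      # equals -per_obs.sum() / X.shape[0]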

    def _create_training_session(
            self, model_onnx, weights_to_train, device):

        forback = OrtGradientForwardBackward(
            model_onnx, weights_to_train=weights_to_train,
            debug=False, enable_logging=False,
            providers=device_to_providers(device))
        inst = forback.new_instance()
        return (forback, inst)
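

if __name__ == "__main__":
    # Illustrative end-to-end sketch, not part of the original module.
    # It assumes scikit-learn and skl2onnx are installed, converts a small
    # neural network into an ONNX graph and retrains its float initializers
    # with the optimizer defined above; hyper-parameters are arbitrary.
    from sklearn.datasets import make_regression
    from sklearn.neural_network import MLPRegressor
    from skl2onnx import to_onnx

    X, y = make_regression(200, n_features=5, random_state=0)
    X = X.astype(numpy.float32)
    y = y.astype(numpy.float32).reshape((-1, 1))

    # fit then convert the model to obtain a trainable ONNX graph
    nn = MLPRegressor(hidden_layer_sizes=(10,), max_iter=100).fit(X, y.ravel())
    onx = to_onnx(nn, X[:1])

    trainer = OrtGradientForwardBackwardOptimizer(
        onx, max_iter=20, batch_size=20, learning_rate=1e-4)
    trainer.fit(X, y)
    print("last training loss:", trainer.train_losses_[-1])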