Coverage for onnxcustom/training/optimizers.py: 98%
1"""
2@file
3@brief Optimizer with :epkg:`onnxruntime-training`.
4"""
5import numpy
6from onnxruntime import ( # pylint: disable=E0611
7 TrainingParameters, SessionOptions, TrainingSession)
8from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
9 OrtValue as C_OrtValue, SessionIOBinding as C_IOBinding)
10from ..utils.onnxruntime_helper import (
11 numpy_to_ort_value, device_to_providers)
12from .data_loader import OrtDataLoader
13from .excs import ConvergenceError, EvaluationError
14from ._base_estimator import BaseEstimator


class OrtGradientOptimizer(BaseEstimator):
    """
    Implements a simple :epkg:`Stochastic Gradient Descent`
    with :epkg:`onnxruntime-training`.

    :param model_onnx: onnx graph to train
    :param weights_to_train: names of the initializers to be optimized
    :param loss_output_name: name of the loss output
    :param max_iter: number of training iterations
    :param training_optimizer_name: optimizing algorithm
    :param batch_size: batch size (see class *DataLoader*)
    :param learning_rate: a name, a learning rate instance, or a float,
        see module :mod:`onnxcustom.training.sgd_learning_rate`
    :param device: device as :epkg:`C_OrtDevice` or a string
        representing this device
    :param warm_start: when set to True, reuses the solution of the
        previous call to fit as initialization; otherwise, erases the
        previous solution
    :param verbose: uses :epkg:`tqdm` to display the training progress
    :param validation_every: evaluates on a test set every
        *validation_every* iterations
    :param saved_gradient: if not None, a filename,
        the optimizer saves the gradient into it
    :param sample_weight_name: name of the sample weight input

    Once initialized, the class creates the attribute
    `train_session_` which holds an instance of :ref:`l-ort-training-session`.

    See example :ref:`l-orttraining-nn-gpu`.
    """

    def __init__(self, model_onnx, weights_to_train, loss_output_name='loss',
                 max_iter=100, training_optimizer_name='SGDOptimizer',
                 batch_size=10, learning_rate='SGD',
                 device='cpu', warm_start=False, verbose=0,
                 validation_every=0.1, saved_gradient=None,
                 sample_weight_name="weight"):
        BaseEstimator.__init__(self, model_onnx, learning_rate, device)
        self.batch_size = batch_size
        self.weights_to_train = weights_to_train
        self.loss_output_name = loss_output_name
        self.training_optimizer_name = training_optimizer_name
        self.verbose = verbose
        self.max_iter = max_iter
        self.warm_start = warm_start
        self.saved_gradient = saved_gradient
        self.sample_weight_name = sample_weight_name
        # A value below 1 is interpreted as a fraction of *max_iter*.
        if validation_every < 1:
            self.validation_every = int(self.max_iter * validation_every)
        else:
            self.validation_every = validation_every  # pragma: no cover
        if self.learning_rate.needs_grad:
            raise NotImplementedError(
                "Any weight update involving a past gradient is "
                "not implemented (class %r)."
                "" % self.learning_rate.__class__.__name__)

    def fit(self, X, y, sample_weight=None, X_val=None, y_val=None,
            use_numpy=False):
        """
        Trains the model.

        :param X: features
        :param y: expected output
        :param sample_weight: sample weights if any
        :param X_val: evaluation dataset
        :param y_val: evaluation dataset
        :param use_numpy: if True, uses a slow iterator based on numpy;
            otherwise, minimizes copies
        :return: self
        """
        input_names = [i.name for i in self.model_onnx.graph.input]
        if ((len(input_names) == 2 and sample_weight is not None) or
                (len(input_names) == 3 and sample_weight is None)):
            raise RuntimeError(  # pragma: no cover
                "Number of inputs should be 2 if sample_weight is None "
                "or 3 if it is not None, but it is %d." % len(input_names))
        self.train_session_ = self._create_training_session(
            self.model_onnx, self.weights_to_train,
            loss_output_name=self.loss_output_name,
            training_optimizer_name=self.training_optimizer_name,
            device=self.device)

        if not self.warm_start:
            # Erases the previous solution: every trained weight is
            # reinitialized with values drawn from a normal distribution.
            state = self.get_state()
            new_state = {}
            for k, v in state.items():
                if len(v.shape) > 0:
                    new_state[k] = numpy.random.randn(*v.shape).astype(v.dtype)
                else:
                    f = numpy.random.randn(1)
                    f = f.astype(v.dtype)
                    new_state[k] = f
            self.set_state(new_state)

        data_loader = OrtDataLoader(
            X, y, sample_weight=sample_weight,
            batch_size=self.batch_size, device=self.device)
        if X_val is not None:
            data_loader_val = OrtDataLoader(
                X_val, y_val, batch_size=X_val.shape[0], device=self.device,
                random_iter=False)
        else:
            data_loader_val = None

        self.learning_rate.init_learning_rate()
        self.input_names_ = [i.name for i in self.train_session_.get_inputs()]
        self.output_names_ = [
            o.name for o in self.train_session_.get_outputs()]
        self.loss_index_ = self.output_names_.index(self.loss_output_name)

        bind = self.train_session_.io_binding()._iobinding

        if self.verbose > 0:  # pragma: no cover
            from tqdm import tqdm  # pylint: disable=C0415
            loop = tqdm(range(self.max_iter))
        else:
            loop = range(self.max_iter)

        self.train_losses_ = []
        self.validation_losses_ = []
        lr = self.learning_rate.value
        for it in loop:
            # The learning rate is divided by the batch size, likely because
            # the training graph sums the loss over the batch (reported
            # losses below are normalized by the batch size as well).
            lr_alive = numpy.array([lr / self.batch_size], dtype=numpy.float32)
            ort_lr = numpy_to_ort_value(lr_alive, self.device)
            loss = self._iteration(data_loader, ort_lr,
                                   bind, use_numpy=use_numpy,
                                   sample_weight=sample_weight is not None)
            lr = self.learning_rate.update_learning_rate(it).value
            if self.verbose > 1:  # pragma: no cover
                loop.set_description(
                    "loss=%1.3g lr=%1.3g "  # pylint: disable=E1307
                    "lrn=%1.3g" % (
                        loss, lr, lr_alive[0]))
            self.train_losses_.append(loss)
            if (data_loader_val is not None and
                    (it + 1) % self.validation_every == 0):
                self.validation_losses_.append(
                    self._evaluation(data_loader_val, bind))
        self.trained_coef_ = self.train_session_.get_state()
        return self

    def _bind_input_ortvalue(self, name, bind, c_ortvalue):
        """
        Binds :epkg:`C_OrtValue` to the structure used by
        :epkg:`InferenceSession` to run inference.

        :param name: str
        :param bind: python structure
        :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`),
            it can also be a numpy array
        """
        if not isinstance(bind, C_IOBinding):
            raise TypeError(  # pragma: no cover
                "Unexpected type %r." % type(bind))
        if isinstance(c_ortvalue, C_OrtValue):
            bind.bind_ortvalue_input(name, c_ortvalue)
        elif isinstance(c_ortvalue, numpy.ndarray):
            # This fails on linux with int64.
            bind.bind_input(
                name, self.device, c_ortvalue.dtype, c_ortvalue.shape,
                c_ortvalue.__array_interface__['data'][0])
        else:
            raise TypeError(  # pragma: no cover
                "Unable to bind type %r for name %r." % (
                    type(c_ortvalue), name))
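
    # Note: ``bind_input`` with a raw pointer avoids a copy but requires the
    # numpy array to stay alive until ``run_with_iobinding`` returns, which
    # is why callers keep references such as ``lr_alive`` around.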

    def _iteration(self, data_loader, ort_lr, bind, use_numpy, sample_weight):
        actual_losses = []

        bind.bind_output('loss', self.device)
        # Index of the learning rate input, after X, y and
        # optionally the sample weight.
        idx = 3 if sample_weight else 2

        if use_numpy:
            # onnxruntime does not copy the data, so the numpy
            # array must remain alive for the whole iteration.
            lr_alive = ort_lr.numpy()
            self._bind_input_ortvalue(
                self.input_names_[idx], bind, lr_alive)

            # Slow iterations.
            for it in data_loader.iter_numpy():
                if len(it) == 2:
                    data, target = it
                    self._bind_input_ortvalue(
                        self.input_names_[0], bind, data)
                    self._bind_input_ortvalue(
                        self.input_names_[1], bind, target)
                else:
                    data, target, weight = it
                    self._bind_input_ortvalue(
                        self.input_names_[0], bind, data)
                    self._bind_input_ortvalue(
                        self.input_names_[1], bind, target)
                    self._bind_input_ortvalue(
                        self.input_names_[2], bind, weight)

                self.train_session_._sess.run_with_iobinding(bind, None)
                loss = bind.get_outputs()[0].numpy()
                if numpy.isinf(loss) or numpy.isnan(loss):
                    raise ConvergenceError(
                        "Loss is nan or infinite, learning_rate=%r, "
                        "the gradient descent has failed "
                        "(past losses=%r)." % (
                            ort_lr.numpy(),
                            [float(v[0]) for v in (
                                actual_losses if len(actual_losses) < 5
                                else actual_losses[-5:])]))
                actual_losses.append(loss / data.shape[0])
        else:
            self._bind_input_ortvalue(self.input_names_[idx], bind, ort_lr)

            # Fast iterations.
            for batch_size in data_loader.iter_bind(bind, self.input_names_):
                self.train_session_._sess.run_with_iobinding(bind, None)
                # This also copies the predicted output, which is not needed.
                loss = bind.get_outputs()[0].numpy()
                if numpy.isinf(loss) or numpy.isnan(loss):
                    raise ConvergenceError(
                        "Loss is nan or infinite, learning_rate=%r, "
                        "the gradient descent has failed "
                        "(past losses=%r)." % (
                            ort_lr.numpy(),
                            [float(v[0]) for v in (
                                actual_losses if len(actual_losses) < 5
                                else actual_losses[-5:])]))
                actual_losses.append(loss / batch_size)

        return numpy.array(actual_losses).mean()
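
    # The value returned by ``_iteration`` is the mean over batches of the
    # batch loss divided by the batch size, i.e. an average per-sample loss.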

    def _evaluation(self, data_loader, bind):
        # The learning rate is bound to zero so that running the training
        # session on the validation set leaves the weights unchanged.
        lr_alive = numpy.array([0], dtype=numpy.float32)
        self._bind_input_ortvalue(self.input_names_[2], bind, lr_alive)
        bind.bind_output('loss', self.device)

        actual_losses = []
        total = 0
        for batch_size in data_loader.iter_bind(bind, self.input_names_):
            self.train_session_._sess.run_with_iobinding(bind, None)
            outputs = bind.copy_outputs_to_cpu()
            if numpy.isinf(outputs[0]) or numpy.isnan(outputs[0]):
                raise EvaluationError(  # pragma: no cover
                    "Loss is nan or infinite (%r), "
                    "evaluation has failed." % outputs[0])
            actual_losses.append(outputs[0])
            total += batch_size
        return numpy.array(actual_losses).sum() / max(total, 1)

    def _create_training_session(
            self, training_onnx, weights_to_train,
            loss_output_name='loss',
            training_optimizer_name='SGDOptimizer',
            device='cpu'):
        """
        Creates an instance of :epkg:`TrainingSession`.

        :param training_onnx: an ONNX graph with a loss function
        :param weights_to_train: list of initializer names to optimize
        :param loss_output_name: output name for the loss
        :param training_optimizer_name: optimizer name
        :param device: one :epkg:`C_OrtDevice` or a string
        :return: an instance of :epkg:`TrainingSession`
        """
        if training_optimizer_name != 'SGDOptimizer':
            raise NotImplementedError(
                "Only 'SGDOptimizer' is implemented, not %r."
                "" % training_optimizer_name)
        ort_parameters = TrainingParameters()
        ort_parameters.loss_output_name = loss_output_name
        ort_parameters.use_mixed_precision = False
        # ort_parameters.world_rank = -1
        # ort_parameters.world_size = 1
        # ort_parameters.gradient_accumulation_steps = 1
        # ort_parameters.allreduce_post_accumulation = False
        # ort_parameters.deepspeed_zero_stage = 0
        # ort_parameters.enable_grad_norm_clip = False
        # ort_parameters.set_gradients_as_graph_outputs = False
        # ort_parameters.use_memory_efficient_gradient = False
        # ort_parameters.enable_adasum = False
        if self.saved_gradient is not None:
            name = self.saved_gradient
            name2 = name + ".training.onnx"
            ort_parameters.model_with_gradient_graph_path = name
            ort_parameters.model_with_training_graph_path = name2

        output_types = {}
        for output in training_onnx.graph.output:
            output_types[output.name] = output.type.tensor_type

        ort_parameters.weights_to_train = set(weights_to_train)
        ort_parameters.training_optimizer_name = training_optimizer_name
        # ort_parameters.lr_params_feed_name = lr_params_feed_name

        ort_parameters.optimizer_attributes_map = {
            name: {} for name in weights_to_train}
        ort_parameters.optimizer_int_attributes_map = {
            name: {} for name in weights_to_train}

        session_options = SessionOptions()
        session_options.log_severity_level = 4
        session_options.log_verbosity_level = 4
        # session_options.use_deterministic_compute = True

        providers = device_to_providers(self.device)
        session = TrainingSession(
            training_onnx.SerializeToString(), ort_parameters,
            session_options, providers=providers)

        return session
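
    # For reference, the session built above is driven through its low-level
    # API in ``_iteration`` and ``_evaluation``:
    #
    #     bind = session.io_binding()._iobinding
    #     session._sess.run_with_iobinding(bind, None)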

    def get_state(self):
        """
        Returns the trained weights.
        """
        if not hasattr(self, 'train_session_'):
            if hasattr(self, 'trained_coef_'):
                return self.trained_coef_
            raise AttributeError("Method fit must be called before.")
        return self.train_session_.get_state()

    def get_trained_onnx(self, model=None):
        """
        Returns the trained onnx graph, i.e. the initial graph
        with its initializers replaced by the trained weights.
        If *model* is not specified, the method uses the model given
        when the class was instantiated. That graph outputs the loss
        and not the predictions. Parameter *model* can be used to
        replace the weights in the graph as it was before the loss
        was added; the returned graph then produces the predictions.

        :param model: replace the weights in another graph
            than the training graph
        :return: onnx graph
        """
        return self._get_trained_onnx(self.get_state(), model=model)
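
    # A hedged usage sketch (the input name 'X' and the variable
    # ``model_without_loss`` are hypothetical):
    #
    #     from onnxruntime import InferenceSession
    #     onx = trainer.get_trained_onnx(model=model_without_loss)
    #     sess = InferenceSession(onx.SerializeToString(),
    #                             providers=['CPUExecutionProvider'])
    #     predictions = sess.run(None, {'X': X_test})[0]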

    def set_state(self, state):
        """
        Changes the trained weights.
        """
        if not hasattr(self, 'train_session_'):
            raise AttributeError(  # pragma: no cover
                "Method fit must be called before.")
        return self.train_session_.load_state(state)