Coverage for onnxcustom/training/ortgradient.py: 97%
# pylint: disable=E1101
"""
@file
@brief Gradient with :epkg:`onnxruntime-training` forward backward.
"""
import os
import logging
import warnings
from io import BytesIO
import onnx
from onnx.numpy_helper import to_array
from onnxruntime import InferenceSession, RunOptions
from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
    SessionIOBinding, OrtValue as C_OrtValue)
from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
    TrainingAgent, OrtValueCache, OrtModuleGraphBuilder,
    OrtModuleGraphBuilderConfiguration, OrtDevice,
    TrainingGraphTransformerConfiguration, OrtValueVector,
    PartialGraphExecutionState)
from ..utils.orttraining_helper import get_train_initializer


class OrtGradientForwardBackward:
    """
    Implements a forward backward mechanism assuming the function
    to train is defined by an ONNX graph.

    :param onnx_model: onnx model
    :param weights_to_train: names of the weights to train,
        if None, all initializers of float type are included in the list
    :param input_names: input names or None for all
    :param output_names: output names or None for all
    :param class_name: name to give the dynamically created class
    :param sess_options: see :epkg:`SessionOptions`
    :param providers: see :epkg:`InferenceSession`
    :param provider_options: see :epkg:`InferenceSession`
    :param run_options: see :epkg:`RunOptions`
    :param graph_builder_config:
        see :epkg:`OrtModuleGraphBuilderConfiguration`
    :param device_index: used for cuda (0 for `cuda:0`,
        1 for `cuda:1`, ...), 0 by default
    :param enable_logging: enables logging while setting up the class
    :param debug: runs extra verifications while training

    .. note::
        The current implementation of :epkg:`onnxruntime` forces
        the weights to train to appear in alphabetical order.
        The constructor checks that this condition is verified.

    .. warning::
        This class does not consider subgraphs.
    """

    def __init__(self, onnx_model, weights_to_train=None,
                 input_names=None, output_names=None, class_name=None,
                 sess_options=None, providers=None,
                 provider_options=None, run_options=None,
                 graph_builder_config=None,
                 device_index=0, enable_logging=False, debug=False):

        if weights_to_train is None:
            weights_to_train = (
                OrtGradientForwardBackward._select_initializer_names(
                    onnx_model))
            if len(weights_to_train) == 0:
                raise RuntimeError(  # pragma: no cover
                    "Unable to guess the weights to train from initializers: "
                    "%r." % [i.name for i in onnx_model.graph.initializer])

        self.onnx_model = onnx_model
        self.input_names = input_names
        self.output_names = output_names
        self.weights_to_train = weights_to_train
        self.device_index = device_index
        self.enable_logging = enable_logging
        self.class_name = (class_name if class_name is not None else
                           "OrtGradientForwardBackwardFunction_%d" % id(self))

        self.provider_options = provider_options
        self.sess_options = sess_options
        self.providers = providers
        self.run_options = run_options
        self.graph_builder_config = graph_builder_config
        self.debug = debug

        # default
        if self.weights_to_train is None:
            raise ValueError(  # pragma: no cover
                "weights_to_train must be specified.")
        if self.input_names is None:
            self.input_names = [obj.name
                                for obj in self.onnx_model.graph.input]
        if self.output_names is None:
            self.output_names = [obj.name
                                 for obj in self.onnx_model.graph.output]
        if self.class_name is None:
            self.class_name = "TorchOrtFunction_%r" % id(
                self)  # pragma: no cover
        if hasattr(self.providers, 'type'):
            if self.providers.type != 'cpu':
                self.device_index = self.providers.index
            self.providers = self.providers.type
        if self.providers in (None, 'cpu'):
            self.providers = ["CPUExecutionProvider" for i in self.input_names]
            if self.provider_options is None:
                self.provider_options = [{} for i in self.input_names]
        elif self.providers in ('cuda', 'cuda:0', 'gpu'):
            self.providers = [
                "CUDAExecutionProvider" for i in self.input_names]
            if self.provider_options is None:
                self.provider_options = [{} for i in self.input_names]
        if self.provider_options is None:
            self.provider_options = [{} for i in self.providers]

        if list(sorted(self.weights_to_train)) != self.weights_to_train:
            raise ValueError(  # pragma: no cover
                "List of weights to train must be sorted but %r is not. "
                "You should use function onnx_rename_weights to do that "
                "before calling this class." % self.weights_to_train)
        set_weights = set(self.weights_to_train)
        if len(set_weights) != len(self.weights_to_train):
            raise ValueError(  # pragma: no cover
                "One weight is not unique in %r." % self.weights_to_train)
        found = []
        for i in self.onnx_model.graph.initializer:
            if i.name not in set_weights:
                continue
            found.append(i.name)
        if len(found) != len(self.weights_to_train):
            raise ValueError(
                "One weight name in self.weights_to_train was not found in "
                "the initializers %r found=%r init names=%r." % (
                    self.weights_to_train, found,
                    [i.name for i in self.onnx_model.graph.initializer]))
        if found != self.weights_to_train:
            raise ValueError(
                "List of weights to train must be sorted and follow the "
                "same order as the initializers in the graph. %r != %r. "
                "You should use function onnx_rename_weights to do that "
                "before calling this class." % (
                    self.weights_to_train, found))

        if any(map(lambda v: v not in ['CPUExecutionProvider',
                                       'CUDAExecutionProvider'],
                   self.providers)):
            raise ValueError(
                "Unexpected providers %r (providers=%r)." % (
                    self.providers, providers))

        # complete initialisation
        self._init_next()

    @staticmethod
    def _select_initializer_names(onnx_model):
        """
        Selects all initializers with float type.

        :param onnx_model: ONNX graph
        """
        inits = get_train_initializer(onnx_model)
        return list(inits)

    def _init_next(self):
        if self.enable_logging:
            self._logger = logging.getLogger("onnxcustom")
        else:
            self._logger = None  # pragma: no cover
        if self.run_options is None:
            self.run_options = RunOptions()
            self.run_options.training_mode = True

        if self.graph_builder_config is None:
            initializer_names = [
                i.name for i in self.onnx_model.graph.initializer]
            input_names = [i.name for i in self.onnx_model.graph.input]

            config = OrtModuleGraphBuilderConfiguration()
            config.initializer_names = [init for init in initializer_names
                                        if init in self.weights_to_train]
            config.initializer_names_to_train = self.weights_to_train
            config.input_names_require_grad = input_names
            config.build_gradient_graph = True

            if (len(config.initializer_names) !=  # noqa
                    len(config.initializer_names_to_train)):
                raise RuntimeError(  # pragma: no cover
                    "Unable to automatically fill "
                    "OrtModuleGraphBuilderConfiguration, mismatch between "
                    "%r and %r (initializer_names=%r)." % (
                        config.initializer_names,
                        config.initializer_names_to_train,
                        initializer_names))

            p = TrainingGraphTransformerConfiguration()
            config.graph_transformer_config = p

            # config.enable_caching = True
            # config.loglevel =
            # config.use_memory_efficient_gradient = True
            self.graph_builder_config = config

        attributes = self._create_onnx_graphs()
        attributes['__doc__'] = (
            "Inherits from @see cl OrtGradientForwardBackwardFunction.")
        attributes['__module__'] = (
            OrtGradientForwardBackwardFunction.__module__)
        self.cls_type_ = type(
            self.class_name, (OrtGradientForwardBackwardFunction,),
            attributes)

    def new_instance(self):
        """
        Creates an instance of class `self.cls_type_`.
        It implements methods *forward* and *backward*.
        """
        return self.cls_type_()
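
    # ``new_instance`` relies on the class created in ``_init_next``,
    # roughly (sketch, attribute dictionary shortened):
    #
    #   self.cls_type_ = type(
    #       self.class_name, (OrtGradientForwardBackwardFunction,),
    #       {'_sess': sess, '_training_agent': training_agent, ...})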

    def __getstate__(self):
        "Removes any non picklable attribute."
        atts = [k for k in self.__dict__ if not k.endswith('_')
                if k not in {'_logger', 'graph_builder_config',
                             'run_options'}]
        state = {att: getattr(self, att) for att in atts}
        state['run_options'] = None
        state['graph_builder_config'] = None
        return state

    def __setstate__(self, state):
        "Restores any non picklable attribute."
        for att, v in state.items():
            setattr(self, att, v)
        self._init_next()
        return self
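
    # Pickling sketch: ``run_options`` and ``graph_builder_config`` are
    # dropped by ``__getstate__`` and rebuilt with default values by
    # ``_init_next`` when the state is restored:
    #
    #   import pickle
    #   fb2 = pickle.loads(pickle.dumps(fb))  # fb: OrtGradientForwardBackward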

    def __repr__(self):
        "usual"
        return "%s(...)" % self.__class__.__name__

    @staticmethod
    def _repr_helper_(obj, indent=0):
        "used to improve logging messages"
        if obj is None:
            return 'None'
        rows = []
        for c in sorted(dir(obj)):
            if c[0] == '_':
                continue
            try:
                value = getattr(obj, c)
            except AttributeError:  # pragma: no cover
                continue
            rows.append("%s=%r" % (c, value))

        if indent == 0:
            return "%s(%s)" % (obj.__class__.__name__, ", ".join(rows))
        return "%s(\n    %s)" % (
            obj.__class__.__name__,
            "\n    ".join(rows))

    @staticmethod
    def _provider_name_to_device_type(provider_name):
        if provider_name == 'CPUExecutionProvider':
            return OrtDevice.cpu()
        if provider_name == 'CUDAExecutionProvider':  # pragma: no cover
            return OrtDevice.cuda()
        raise ValueError(  # pragma: no cover
            'Unexpected provider name %r.' % provider_name)

    def get_initializer(self, name, exc=True):
        """
        Returns an initializer as a numpy array.

        :param name: initializer name
        :param exc: raises an exception if not found or returns None
        :return: the initializer as a numpy array
        """
        for init in self.onnx_model.graph.initializer:
            if name == init.name:
                return to_array(init)
        if exc:
            raise RuntimeError(  # pragma: no cover
                "Unable to find name %r in %r." % (
                    name,
                    list(i.name for i in self.onnx_model.graph.initializer)))
        return None
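
    # For example (hypothetical initializer name):
    #
    #   coef = fb.get_initializer('I0_coef')           # numpy array
    #   missing = fb.get_initializer('X', exc=False)   # None, no exception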

    def _create_onnx_graphs(self):
        """
        Creates the forward and backward ONNX graphs.
        The new class has the following attributes:

        * `__doc__`: doc string
        * `__module__`: module name (this file)
        * `_run_options`: see :epkg:`RunOptions`
        * `_sess`: :epkg:`InferenceSession` with the original graph
        * `_sess_eval`: :epkg:`InferenceSession` on the graph
          with weights as inputs
        * `_training_agent`: :epkg:`TrainingAgent`
        * `_cache`: :epkg:`OrtValueCache`
        * `_logger`: logger
        * `_input_names`: input names
        * `_debug`: use debug mode
        * `_grad_input_names`: gradient input names
        * `_output_names`: output names
        * `_weights_to_train`: names of the weights to train

        Training attributes:

        * `_bw_fetches_names`: bw_fetches_names
        * `_fw_outputs_device_info`: fw_outputs_device_info
        * `_bw_outputs_device_info`: bw_outputs_device_info
        * `_fw_no_grad_output_device_info`: fw_no_grad_output_device_info
        * `_graph_info`: graph_info

        Additional attributes:

        * `_trained_onnx`: ONNX graph for the gradient
        * `_optimized_pre_grad_model`: evaluation ONNX graph taking
          weights as inputs
        * `_graph_builder`: :epkg:`OrtModuleGraphBuilder`
        """
        logger = self._logger
        if logger is not None:
            logger.info("[OrtGradientForwardBackward] create training onnx")
            logger.info("[OrtGradientForwardBackward] input_names=%r",
                        self.input_names)
            logger.info("[OrtGradientForwardBackward] output_names=%r",
                        self.output_names)
            logger.info("[OrtGradientForwardBackward] weights_to_train=%r",
                        self.weights_to_train)

        builder = OrtModuleGraphBuilder()

        if logger is not None:
            cf = self.graph_builder_config.graph_transformer_config
            cfp = cf.propagate_cast_ops_config
            logger.info(
                "[OrtGradientForwardBackward] "
                "OrtModuleGraphBuilder.initialize")
            logger.info(
                "[OrtGradientForwardBackward] graph_builder_config=%s",
                OrtGradientForwardBackward._repr_helper_(
                    self.graph_builder_config, indent=4))
            logger.info(
                "[OrtGradientForwardBackward] graph_builder_config."
                "graph_transformer_config=%s",
                OrtGradientForwardBackward._repr_helper_(cf, indent=4))
            logger.info(
                "[OrtGradientForwardBackward] graph_builder_config."
                "graph_transformer_config.propagate_cast_ops_config=%s",
                OrtGradientForwardBackward._repr_helper_(cfp, indent=4))

        builder.initialize(
            self.onnx_model.SerializeToString(),
            self.graph_builder_config)

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackward] OrtModuleGraphBuilder.build")
        builder.build()

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackward] OrtModuleGraphBuilder.get_model")

        train_onnx_model_serialized = builder.get_model()

        optimized_pre_grad_model = builder.get_inference_optimized_model()
        graph_info = builder.get_graph_info()

        if logger is not None:
            logger.info("[OrtGradientForwardBackward] graph_info=%s",
                        OrtGradientForwardBackward._repr_helper_(
                            graph_info, indent=4))
            logger.info("[OrtGradientForwardBackward] create TrainSession")
            logger.info("[OrtGradientForwardBackward] sess_options=%s",
                        OrtGradientForwardBackward._repr_helper_(
                            self.sess_options, indent=4))
            logger.info(
                "[OrtGradientForwardBackward] providers=%r", self.providers)

        sess = InferenceSession(
            train_onnx_model_serialized, sess_options=self.sess_options,
            provider_options=self.provider_options, providers=self.providers)

        if logger is not None:
            logger.info("[OrtGradientForwardBackward] create InferenceSession")

        sess_eval = InferenceSession(
            optimized_pre_grad_model, sess_options=self.sess_options,
            provider_options=self.provider_options, providers=self.providers)

        if logger is not None:
            logger.info("[OrtGradientForwardBackward] create training agent")

        grad_input_names = [obj.name for obj in sess.get_inputs()]
        bw_fetches_names = [obj.name for obj in sess.get_outputs()]

        fw_outputs_device_info = [
            OrtDevice(
                OrtGradientForwardBackward._provider_name_to_device_type(i),
                OrtDevice.default_memory(), self.device_index)
            for i in self.providers]
        bw_outputs_device_info = [
            OrtDevice(
                OrtGradientForwardBackward._provider_name_to_device_type(
                    self.providers[0]),
                OrtDevice.default_memory(), self.device_index)
            for i in bw_fetches_names]
        fw_no_grad_output_device_info = [
            OrtDevice(
                OrtGradientForwardBackward._provider_name_to_device_type(
                    self.providers[0]),
                OrtDevice.default_memory(), self.device_index)
            for i in self.output_names]

        training_agent = TrainingAgent(
            sess._sess,
            grad_input_names,
            fw_outputs_device_info,
            bw_fetches_names,
            bw_outputs_device_info)

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackward] instantiate dynamic class %r",
                self.class_name)
            logger.info(
                "[OrtGradientForwardBackward] weights_to_train=%r",
                self.weights_to_train)
            logger.info(
                "[OrtGradientForwardBackward] grad_input_names=%r",
                grad_input_names)
            logger.info(
                "[OrtGradientForwardBackward] bw_fetches_names=%r",
                bw_fetches_names)
            logger.info(
                "[OrtGradientForwardBackward] device_index=%r",
                self.device_index)
        devices = list(fw_outputs_device_info)
        while len(devices) < len(grad_input_names):
            devices.append(devices[-1])

        trained_onnx = onnx.load(BytesIO(train_onnx_model_serialized))
        onnx_loss = onnx.load(BytesIO(optimized_pre_grad_model))
        for i, node in enumerate(trained_onnx.graph.node):
            if node.name == '':
                node.name = "N%d" % i
        for i, node in enumerate(onnx_loss.graph.node):
            if node.name == '':
                node.name = "N%d" % i

        kwargs = {
            '_run_options': self.run_options,
            '_sess': sess,
            '_sess_eval': sess_eval,
            '_training_agent': training_agent,
            '_cache': OrtValueCache(),
            '_logger': logger,
            '_input_names': self.input_names,
            '_grad_input_names': grad_input_names,
            '_output_names': self.output_names,
            '_bw_fetches_names': bw_fetches_names,
            '_fw_outputs_device_info': fw_outputs_device_info,
            '_bw_outputs_device_info': bw_outputs_device_info,
            '_fw_no_grad_output_device_info': fw_no_grad_output_device_info,
            '_weights_to_train': list(sorted(
                self.weights_to_train)),
            '_graph_info': graph_info,
            #
            '_trained_onnx': trained_onnx,
            '_optimized_pre_grad_model': onnx_loss,
            '_graph_builder': builder,
            '_devices': devices,
            '_debug': self.debug
        }
        graph = kwargs['_trained_onnx'].graph
        kwargs.update({
            '_onx_inp': [o.name for o in graph.input],
            '_onx_out': [o.name for o in graph.output]
        })

        if len(kwargs['_onx_inp']) != len(kwargs['_onx_out']):
            raise RuntimeError(  # pragma: no cover
                "Gradient input and output are inconsistent: "
                "%r != %r" % (kwargs['_onx_inp'], kwargs['_onx_out']))
        return kwargs


class OrtGradientForwardBackwardFunction:
    """
    Ancestor for a class implementing forward and backward
    and dynamically created by @see cl OrtGradientForwardBackward.

    Attributes stored by the *forward* method:

    * `saved_tensors_`: list of tensors to save during forward
      and to retrieve during backward
    * `states_`: list of :epkg:`PartialGraphExecutionState`,
      one per forward call still waiting for its backward call
    """

    def __init__(self):
        self.states_ = []
        self.saved_tensors_ = None

    @classmethod
    def save_onnx_graph(cls, folder, prefix=None, suffix=None):
        """
        Saves the ONNX graphs stored in this class.
        """
        if prefix is None:
            prefix = ''  # pragma: no cover
        if suffix is None:
            suffix = ''  # pragma: no cover
        if isinstance(folder, str) and not os.path.exists(folder):
            raise FileNotFoundError(  # pragma: no cover
                "Folder %r does not exist." % folder)
        saved = {}
        for k, v in cls.__dict__.items():
            if hasattr(v, "SerializeToString"):
                if isinstance(folder, str):
                    name = "%s%s%s.%s.onnx" % (
                        prefix, cls.__name__, suffix, k)
                    filename = os.path.join(folder, name)
                    if os.path.exists(filename):
                        warnings.warn(  # pragma: no cover
                            "Filename %r already exists." % filename)
                    with open(filename, "wb") as f:
                        f.write(v.SerializeToString())
                    saved[k] = filename
                else:
                    saved[k] = v.SerializeToString()
            elif hasattr(v, "save_onnx_graph"):
                saved[k] = v.save_onnx_graph(
                    folder, prefix=prefix, suffix="%s.%s" % (suffix, k))
        return saved
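
    # For example (sketch, assuming ``TrainedCls`` was created by
    # ``OrtGradientForwardBackward`` and ``dump`` is an existing folder):
    #
    #   saved = TrainedCls.save_onnx_graph('dump')
    #   # -> {'_trained_onnx': 'dump/TrainedCls._trained_onnx.onnx', ...}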

    @staticmethod
    def device_name(device):
        """
        Returns the device name of a device.

        :param device: OrtDevice
        :return: string
        """
        if device.device_type() == OrtDevice.cpu():
            return 'Cpu'
        if device.device_type() == OrtDevice.cuda():  # pragma: no cover
            return 'Gpu'
        raise RuntimeError(  # pragma: no cover
            "Unexpected value for device type %r." % device.device_type())

    @staticmethod
    def input_to_ort(tensors, devices, debug):
        "Converts a list of tensors into an :epkg:`OrtValueVector`."
        def _validate_(tensors):
            if any(map(
                    lambda tu: (
                        tu[0].device_name() !=
                        OrtGradientForwardBackwardFunction.device_name(
                            tu[1])),
                    zip(tensors, devices))):
                raise RuntimeError(  # pragma: no cover
                    "Not all inputs are on the same device %r != %r." % (
                        [OrtGradientForwardBackwardFunction.device_name(d)
                         for d in devices],
                        [x.device_name() for x in tensors]))

        if isinstance(tensors, OrtValueVector):
            if debug:
                _validate_(tensors)
            return tensors
        if all(map(lambda t: isinstance(t, C_OrtValue), tensors)):
            if debug:
                _validate_(tensors)
            vect = OrtValueVector()
            vect.reserve(len(tensors))
            for t in tensors:
                if t is None:
                    raise NotImplementedError(  # pragma: no cover
                        "Empty vector found.")
                vect.push_back(t)
            return vect

        # generic case
        vect = OrtValueVector()
        vect.reserve(len(tensors))
        for t, dev in zip(tensors, devices):
            if t is None:
                # if gradient then
                # grad_output = torch.zeros(shape, device=device, dtype=dtype)
                raise NotImplementedError(  # pragma: no cover
                    "Empty vector found.")
            if not t.data.contiguous:
                t = t.as_contiguous()  # pragma: no cover
            vect.push_back(C_OrtValue.ortvalue_from_numpy(t, dev))
        if debug:
            if len(vect) != len(tensors):
                raise RuntimeError(  # pragma: no cover
                    "Unexpected array length %d != %d (len(devices)=%d)." % (
                        len(vect), len(tensors), len(devices)))
            _validate_(vect)
        return vect
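
    # A minimal sketch of the generic conversion path (hypothetical data,
    # CPU device):
    #
    #   import numpy
    #   x = numpy.array([[0.1, 0.2]], dtype=numpy.float32)
    #   dev = OrtDevice(OrtDevice.cpu(), OrtDevice.default_memory(), 0)
    #   vect = OrtGradientForwardBackwardFunction.input_to_ort(
    #       [x], [dev], debug=False)  # OrtValueVector of length 1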

    def save_for_backward(self, inputs):
        """
        Saves inputs during the forward step. The list of inputs
        is copied (simple copy, no deep copy).

        :param inputs: list of tensors to save.
        """
        self.saved_tensors_ = list(inputs)

    @property
    def saved_tensors(self):
        """
        Returns the tensors saved during the forward step.
        """
        if self.saved_tensors_ is None:
            raise RuntimeError(  # pragma: no cover
                "No tensors were saved with save_for_backward.")
        return self.saved_tensors_

    def forward(self, inputs, training=False, forward_outputs_cache=None):
        """
        Implements the forward function.

        :param inputs: inputs
        :param training: only inference or training as well
        :param forward_outputs_cache: reuses this :epkg:`OrtValueVector`
            to store the outputs if not None
        :return: outputs as :epkg:`OrtValueVector`
        """
        logger = self._logger
        cls = self.__class__

        def _log(msg, *args):
            logger.debug("[%s.forward] (%dI) " + msg,
                         cls.__name__, len(inputs), *args)

        if logger is not None:
            if training:
                _log("begin with gradient")
            else:
                _log("begin")
            _log("torch function %r", type(cls))
            _log("ort class %r", cls)
            _log("create OrtValueVector (through dlpack)")

        forward_inputs = cls.input_to_ort(
            inputs, cls._devices, cls._debug)

        if training:
            forward_outputs = forward_outputs_cache or OrtValueVector()
            state = PartialGraphExecutionState()
            self.states_.append(state)
            if logger is not None:
                _log("run_forward")
            cls._training_agent.run_forward(
                forward_inputs, forward_outputs, state, cls._cache)

            self.save_for_backward(inputs)
            if logger is not None:
                _log("end")
            return forward_outputs
        else:
            # what about bind_input (+ data_ptr)
            if len(forward_inputs) != len(cls._grad_input_names):
                raise RuntimeError(  # pragma: no cover
                    "Size mismatch len(inputs)=%d, len(onnx inputs)=%d." % (
                        len(forward_inputs), len(cls._grad_input_names)))
            iobinding = SessionIOBinding(cls._sess_eval._sess)
            if logger is not None:
                _log("bind inputs %r", cls._grad_input_names)
            for name, inp in zip(
                    cls._grad_input_names, forward_inputs):
                iobinding.bind_ortvalue_input(name, inp)

            # bind output
            if logger is not None:
                _log("bind outputs %r", cls._output_names)
            for name, dev in zip(
                    cls._output_names, cls._fw_no_grad_output_device_info):
                iobinding.bind_output(name, dev)

            # if the shape is known in advance
            # iobinding.bind_output(
            #     output_desc.name, torch_tensor.device.type,
            #     _utils.get_device_index(target_device),
            #     _utils.dtype_torch_to_numpy(torch_tensor.dtype),
            #     list(torch_tensor.size()), torch_tensor.data_ptr())

            if logger is not None:
                _log("grad_enabled=False (run_with_iobinding)")
            cls._sess_eval._sess.run_with_iobinding(
                iobinding, cls._run_options)
            if logger is not None:
                _log("get_outputs")
            ortvalues = iobinding.get_outputs()
            if logger is not None:
                _log("to torch.tensor (%d)", len(ortvalues))
                _log("end")
            return ortvalues

    def backward(self, grad_outputs, backward_outputs_cache=None):
        """
        Implements the backward function. The function returns
        an :epkg:`OrtValueVector`.

        :param grad_outputs: gradients with respect to the forward outputs
        :param backward_outputs_cache: reuses this :epkg:`OrtValueVector`
            to store the gradients if not None
        """
        cls = self.__class__
        logger = cls._logger

        def _log(msg, *args):
            logger.debug("[%s.backward] (%dI) " + msg,
                         cls.__name__, len(grad_outputs), *args)

        if logger is not None:
            _log("begin")
            _log("torch function %r", type(cls))
            _log("ort class %r", cls)
            _log("saved_tensors")

        inputs = self.saved_tensors
        if logger is not None:
            _log("DEBUG: saved_tensors %r", type(inputs))
            _log("self.states_.pop()")
        state = self.states_.pop()

        if logger is not None:
            _log("create OrtValueVector")

        backward_inputs = cls.input_to_ort(
            grad_outputs, cls._bw_outputs_device_info, cls._debug)

        if logger is not None:
            _log("len(grad_outputs)=%d type(grad_outputs)=%r",
                 len(grad_outputs), type(grad_outputs))
            _log("len(backward_inputs)=%d type(backward_inputs)=%r",
                 len(backward_inputs), type(backward_inputs))
            for i in range(len(backward_inputs)):  # pylint: disable=C0200
                _log("backward_inputs[%d].shape=%r",
                     i, backward_inputs[i].shape())
            _log("run_backward")
        backward_outputs = backward_outputs_cache or OrtValueVector()
        cls._training_agent.run_backward(
            backward_inputs, backward_outputs, state)
        if logger is not None:  # pragma: no cover
            _log("DEBUG")
            for i, ov in enumerate(backward_outputs):
                _log("BCK-RET: i=%d - shape=%r - ptr=%r",
                     i, ov.shape(), ov.data_ptr())
            _log("got %r gradients", len(backward_outputs))
            _log("end")
        return backward_outputs
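
# End-to-end sketch of one training iteration (hypothetical names,
# see OrtGradientForwardBackward above):
#
#   inst = OrtGradientForwardBackward(onx).new_instance()
#   outputs = inst.forward(inputs, training=True)   # also stores the state
#   grads = inst.backward(grad_outputs)             # OrtValueVector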