Coverage for onnxcustom/training/ortgradient.py: 97%


323 statements  

1# pylint: disable=E1101 

2""" 

3@file 

4@brief Gradient with :epkg:`onnxruntime-training` forward backward. 

5""" 

6import os 

7import logging 

8import warnings 

9from io import BytesIO 

10import onnx 

11from onnx.numpy_helper import to_array 

12from onnxruntime import InferenceSession, RunOptions 

13from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 

14 SessionIOBinding, OrtValue as C_OrtValue) 

15from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 

16 TrainingAgent, OrtValueCache, OrtModuleGraphBuilder, 

17 OrtModuleGraphBuilderConfiguration, OrtDevice, 

18 TrainingGraphTransformerConfiguration, OrtValueVector, 

19 PartialGraphExecutionState) 

20from ..utils.orttraining_helper import get_train_initializer 

21 

22 

23class OrtGradientForwardBackward: 

24 """ 

25 Implements forward backward mechanism assuming the function 

26 to train is defined by an ONNX graph. 

27 

28 :param onnx_model: onnx model 

29 :param weights_to_train: names of the weights to train, 

30 if None, all initializers of float type are included in the list

31 :param input_names: input names or None for all 

32 :param output_names: output names or None for all 

33 :param class_name: name to give the class dynamically created 

34 :param sess_options: see :epkg:`SessionOptions` 

35 :param providers: see :epkg:`InferenceSession` 

36 :param provider_options: see :epkg:`InferenceSession` 

37 :param run_options: see :epkg:`RunOptions` 

38 :param graph_builder_config: 

39 see :epkg:`OrtModuleGraphBuilderConfiguration` 

40 :param device_index: used for cuda (0 for `cuda:0`, 

41 `cuda:1`, ...), 0 by default 

42 :param enable_logging: enables logging while setting up the class 

43 :param debug: to run extra verification while training 

44 

45 .. note:: 

46 The current implementation of :epkg:`onnxruntime` forces

47 the weights to train to appear in alphabetical order.

48 The constructor checks that this condition is verified.

49 

50 .. warning:: 

51 This class does not consider subgraphs. 
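
Example (a minimal sketch, assuming ``model.onnx`` holds a trainable
graph whose initializers already appear in alphabetical order,
possibly after a call to function *onnx_rename_weights*)::

    import onnx
    onx = onnx.load("model.onnx")
    fb = OrtGradientForwardBackward(onx, enable_logging=False)
    inst = fb.new_instance()  # exposes *forward* and *backward*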

52 """ 

53 

54 def __init__(self, onnx_model, weights_to_train=None, 

55 input_names=None, output_names=None, class_name=None, 

56 sess_options=None, providers=None, 

57 provider_options=None, run_options=None, 

58 graph_builder_config=None, 

59 device_index=0, enable_logging=False, debug=False): 

60 

61 if weights_to_train is None: 

62 weights_to_train = ( 

63 OrtGradientForwardBackward._select_initializer_names( 

64 onnx_model)) 

65 if len(weights_to_train) == 0: 

66 raise RuntimeError( # pragma: no cover 

67 "Unable to guess the weights to train from initializers: " 

68 "%r." % [i.name for i in onnx_model.graph.initializer]) 

69 

70 self.onnx_model = onnx_model 

71 self.input_names = input_names 

72 self.output_names = output_names 

73 self.weights_to_train = weights_to_train 

74 self.device_index = device_index 

75 self.enable_logging = enable_logging 

76 self.class_name = (class_name if class_name is not None else 

77 "OrtGradientForwardBackwardFunction_%d" % id(self)) 

78 

79 self.provider_options = provider_options 

80 self.sess_options = sess_options 

81 self.providers = providers 

82 self.run_options = run_options 

83 self.graph_builder_config = graph_builder_config 

84 self.debug = debug 

85 

86 # default 

87 if self.weights_to_train is None: 

88 raise ValueError( # pragma: no cover 

89 "weights_to_train must be specified.") 

90 if self.input_names is None: 

91 self.input_names = [obj.name 

92 for obj in self.onnx_model.graph.input] 

93 if self.output_names is None: 

94 self.output_names = [obj.name 

95 for obj in self.onnx_model.graph.output] 

96 if self.class_name is None: 

97 self.class_name = "TorchOrtFunction_%r" % id( 

98 self) # pragma: no cover 
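# Normalize providers: a device object (exposing `type` and `index`)
# or a short name such as 'cuda:0' is expanded into one execution
# provider name per input, with matching (possibly empty) provider options.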

99 if hasattr(self.providers, 'type'): 

100 if self.providers.type != 'cpu': 

101 self.device_index = self.providers.index 

102 self.providers = self.providers.type 

103 if self.providers in (None, 'cpu'): 

104 self.providers = ["CPUExecutionProvider" for i in self.input_names] 

105 if self.provider_options is None: 

106 self.provider_options = [{} for i in self.input_names] 

107 elif self.providers in ('cuda', 'cuda:0', 'gpu'): 

108 self.providers = [ 

109 "CUDAExecutionProvider" for i in self.input_names] 

110 if self.provider_options is None: 

111 self.provider_options = [{} for i in self.input_names] 

112 if self.provider_options is None: 

113 self.provider_options = [{} for i in self.providers] 

114 

115 if list(sorted(self.weights_to_train)) != self.weights_to_train: 

116 raise ValueError( # pragma: no cover 

117 "List of weights to train must be sorted but %r is not. " 

118 "You shoud use function onnx_rename_weights to do that " 

119 "before calling this class." % self.weights_to_train) 

120 set_weights = set(self.weights_to_train) 

121 if len(set_weights) != len(self.weights_to_train): 

122 raise ValueError( # pragma: no cover 

123 "One weight is not unique in %r." % self.weights_to_train) 

124 found = [] 

125 for i in self.onnx_model.graph.initializer: 

126 if i.name not in set_weights: 

127 continue 

128 found.append(i.name) 

129 if len(found) != len(self.weights_to_train): 

130 raise ValueError( 

131 "One weight name in self.weights_to_train was not found in " 

132 "the initializers %r found=%r init names=%r." % ( 

133 self.weights_to_train, found, 

134 [i.name for i in self.onnx_model.graph.initializer])) 

135 if found != self.weights_to_train: 

136 raise ValueError( 

137 "List of weights to train must be sorted and follow the " 

138 "as the initializers in the graph. %r != %r." 

139 "You shoud use function onnx_rename_weights to do that " 

140 "before calling this class." % ( 

141 self.weights_to_train, found)) 

142 

143 if any(map(lambda v: v not in ['CPUExecutionProvider', 

144 'CUDAExecutionProvider'], 

145 self.providers)): 

146 raise ValueError( 

147 "Unexpected providers %r (providers=%r)." % ( 

148 self.providers, providers)) 

149 

150 # complete initialisation 

151 self._init_next() 

152 

153 @staticmethod 

154 def _select_initializer_names(onnx_model): 

155 """ 

156 Selects all initializers with float type. 

157 

158 :param onnx_model: ONNX graph 

159 """ 

160 inits = get_train_initializer(onnx_model) 

161 return list(inits) 

162 

163 def _init_next(self): 

164 if self.enable_logging: 

165 self._logger = logging.getLogger("onnxcustom") 

166 else: 

167 self._logger = None # pragma: no cover 

168 if self.run_options is None: 

169 self.run_options = RunOptions() 

170 self.run_options.training_mode = True 

171 
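# Default configuration: declare the trained initializers to the
# graph builder, mark every graph input as requiring a gradient and
# request the construction of the gradient graph.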

172 if self.graph_builder_config is None: 

173 initializer_names = [ 

174 i.name for i in self.onnx_model.graph.initializer] 

175 input_names = [i.name for i in self.onnx_model.graph.input] 

176 

177 config = OrtModuleGraphBuilderConfiguration() 

178 config.initializer_names = [init for init in initializer_names 

179 if init in self.weights_to_train] 

180 config.initializer_names_to_train = self.weights_to_train 

181 config.input_names_require_grad = input_names 

182 config.build_gradient_graph = True 

183 

184 if (len(config.initializer_names) != # noqa 

185 len(config.initializer_names_to_train)): 

186 raise RuntimeError( # pragma: no cover 

187 "Unable to automatically fill " 

188 "OrtModuleGraphBuilderConfiguration, mismatch between " 

189 "%r and %r (initializer_names=%r)." % ( 

190 config.initializer_names, 

191 config.initializer_names_to_train, 

192 initializer_names)) 

193 

194 p = TrainingGraphTransformerConfiguration() 

195 config.graph_transformer_config = p 

196 

197 # config.enable_caching = True 

198 # config.loglevel = 

199 # config.use_memory_efficient_gradient = True 

200 self.graph_builder_config = config 

201 

202 attributes = self._create_onnx_graphs() 

203 attributes['__doc__'] = ( 

204 "Inherits from @see cl OrtGradientForwardBackwardFunction.") 

205 attributes['__module__'] = ( 

206 OrtGradientForwardBackwardFunction.__module__) 

207 self.cls_type_ = type( 

208 self.class_name, (OrtGradientForwardBackwardFunction,), 

209 attributes) 

210 

211 def new_instance(self): 

212 """ 

213 Creates an instance of class `self.cls_type_`. 

214 It implements methods *forward* and *backward*. 
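
A short sketch (assuming ``fb`` is an instance of this class built
on a trainable graph)::

    inst = fb.new_instance()
    outputs = inst.forward(inputs, training=True)
    grads = inst.backward(grad_outputs)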

215 """ 

216 return self.cls_type_() 

217 

218 def __getstate__(self): 

219 "Removes any non pickable attribute." 

220 atts = [k for k in self.__dict__ if not k.endswith('_') 

221 if k not in {'_logger', 'graph_builder_config', 

222 'run_options'}] 

223 state = {att: getattr(self, att) for att in atts} 

224 state['run_options'] = None 

225 state['graph_builder_config'] = None 

226 return state 

227 

228 def __setstate__(self, state): 

229 "Restores any non pickable attribute." 

230 for att, v in state.items(): 

231 setattr(self, att, v) 

232 self._init_next() 

233 return self 
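# Pickling sketch: __getstate__ drops the non-picklable sessions and
# builder configuration, __setstate__ restores the plain attributes
# and rebuilds the rest through _init_next, so that
# pickle.loads(pickle.dumps(fb)) yields a working copy (assuming the
# remaining attributes are themselves picklable).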

234 

235 def __repr__(self): 

236 "usual" 

237 return "%s(...)" % self.__class__.__name__ 

238 

239 @staticmethod 

240 def _repr_helper_(obj, indent=0): 

241 "used to improve logging messages" 

242 if obj is None: 

243 return 'None' 

244 rows = [] 

245 for c in sorted(dir(obj)): 

246 if c[0] == '_': 

247 continue 

248 try: 

249 value = getattr(obj, c) 

250 except AttributeError: # pragma: no cover 

251 continue 

252 rows.append("%s=%r" % (c, value)) 

253 

254 if indent == 0: 

255 return "%s(%s)" % (obj.__class__.__name__, ", ".join(rows)) 

256 return "%s(\n %s)" % ( 

257 obj.__class__.__name__, 

258 "\n ".join(rows)) 

259 

260 @staticmethod 

261 def _provider_name_to_device_type(provider_name): 

262 if provider_name == 'CPUExecutionProvider': 

263 return OrtDevice.cpu() 

264 if provider_name == 'CUDAExecutionProvider': # pragma: no cover 

265 return OrtDevice.cuda() 

266 raise ValueError( # pragma: no cover 

267 'Unexpected provider name %r.' % provider_name) 

268 

269 def get_initializer(self, name, exc=True): 

270 """ 

271 Returns an initializer as a numpy array.

272 

273 :param name: initializer name 

274 :param exc: raises an exception if not found or return None 

275 :return: the initializer as a numpy array
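
A retrieval sketch (assuming the graph owns an initializer
named ``'coef'``)::

    import numpy
    coef = fb.get_initializer('coef')
    assert isinstance(coef, numpy.ndarray)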

276 """ 

277 for init in self.onnx_model.graph.initializer: 

278 if name == init.name: 

279 return to_array(init) 

280 if exc: 

281 raise RuntimeError( # pragma: no cover 

282 "Unable to find name %r in %r." % ( 

283 name, 

284 list(i.name for i in self.onnx_model.graph.initializer))) 

285 return None 

286 

287 def _create_onnx_graphs(self): 

288 """ 

289 Creates forward and backward ONNX graph. 

290 The new class has the following attributes: 

291 

292 * `__doc__`: doc string 

293 * `__module__`: module name (this file) 

294 * `_run_options`: see :epkg:`RunOptions` 

295 * `_sess`: :epkg:`InferenceSession` with the original graph 

296 * `_sess_eval`: :epkg:`InferenceSession` on the graph 

297 with weights as inputs 

298 * `_training_agent`: :epkg:`TrainingAgent` 

299 * `_cache`: :epkg:`OrtValueCache` 

300 * `_logger`: logger 

301 * `_input_names`: input names 

302 * `_debug`: use debug mode 

303 * `_grad_input_names`: gradient input names 

304 * `_output_names`: output names 

305 * `_weights_to_train`: names of the weights to train 

306 

307 Training attributes 

308 

309 * `_bw_fetches_names`: bw_fetches_names

310 * `_fw_outputs_device_info`: fw_outputs_device_info

311 * `_bw_outputs_device_info`: bw_outputs_device_info

312 * `_fw_no_grad_output_device_info`: fw_no_grad_output_device_info

313 * `_graph_info`: graph_info

314 

315 Additional attributes added if *keep_model* is True: 

316 

317 * `_trained_onnx`: ONNX graph for the gradient 

318 * `_optimized_pre_grad_model`: evaluation ONNX graph taking 

319 weights as inputs 

320 * `_graph_builder`: :epkg:`OrtModuleGraphBuilder` 
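
A quick inspection sketch (assuming ``fb.cls_type_`` was created by
the constructor)::

    cls = fb.cls_type_
    print(cls._input_names, cls._output_names)
    print(cls._weights_to_train)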

321 """ 

322 logger = self._logger 

323 if logger is not None: 

324 logger.info("[OrtGradientForwardBackward] create training onnx") 

325 logger.info("[OrtGradientForwardBackward] input_names=%r", 

326 self.input_names) 

327 logger.info("[OrtGradientForwardBackward] output_names=%r", 

328 self.output_names) 

329 logger.info("[OrtGradientForwardBackward] weights_to_train=%r", 

330 self.weights_to_train) 

331 

332 builder = OrtModuleGraphBuilder() 

333 

334 if logger is not None: 

335 cf = self.graph_builder_config.graph_transformer_config 

336 cfp = cf.propagate_cast_ops_config 

337 logger.info( 

338 "[OrtGradientForwardBackward] " 

339 "OrtModuleGraphBuilder.initialize") 

340 logger.info( 

341 "[OrtGradientForwardBackward] graph_builder_config=%s", 

342 OrtGradientForwardBackward._repr_helper_( 

343 self.graph_builder_config, indent=4)) 

344 logger.info( 

345 "[OrtGradientForwardBackward] graph_builder_config." 

346 "graph_transformer_config=%s", 

347 OrtGradientForwardBackward._repr_helper_(cf, indent=4)) 

348 logger.info( 

349 "[OrtGradientForwardBackward] graph_builder_config." 

350 "graph_transformer_config.propagate_cast_ops_config=%s", 

351 OrtGradientForwardBackward._repr_helper_(cfp, indent=4)) 

352 

353 builder.initialize( 

354 self.onnx_model.SerializeToString(), 

355 self.graph_builder_config) 

356 

357 if logger is not None: 

358 logger.info( 

359 "[OrtGradientForwardBackward] OrtModuleGraphBuilder.build") 

360 builder.build() 

361 

362 if logger is not None: 

363 logger.info( 

364 "[OrtGradientForwardBackward] OrtModuleGraphBuilder.get_model") 

365 

366 train_onnx_model_serialized = builder.get_model() 

367 

368 optimized_pre_grad_model = builder.get_inference_optimized_model() 

369 graph_info = builder.get_graph_info() 

370 

371 if logger is not None: 

372 logger.info("[OrtGradientForwardBackward] graph_info=%s", 

373 OrtGradientForwardBackward._repr_helper_( 

374 graph_info, indent=4)) 

375 logger.info("[OrtGradientForwardBackward] create TrainSession") 

376 logger.info("[OrtGradientForwardBackward] sess_options=%s", 

377 OrtGradientForwardBackward._repr_helper_( 

378 self.sess_options, indent=4)) 

379 logger.info( 

380 "[OrtGradientForwardBackward] providers=%r", self.providers) 

381 

382 sess = InferenceSession( 

383 train_onnx_model_serialized, sess_options=self.sess_options, 

384 provider_options=self.provider_options, providers=self.providers) 

385 

386 if logger is not None: 

387 logger.info("[OrtGradientForwardBackward] create InferenceSession") 

388 

389 sess_eval = InferenceSession( 

390 optimized_pre_grad_model, sess_options=self.sess_options, 

391 provider_options=self.provider_options, providers=self.providers) 

392 

393 if logger is not None: 

394 logger.info("[OrtGradientForwardBackward] create training agent") 

395 

396 grad_input_names = [obj.name for obj in sess.get_inputs()] 

397 bw_fetches_names = [obj.name for obj in sess.get_outputs()] 

398 

399 fw_outputs_device_info = [ 

400 OrtDevice( 

401 OrtGradientForwardBackward._provider_name_to_device_type(i), 

402 OrtDevice.default_memory(), self.device_index) 

403 for i in self.providers] 

404 bw_outputs_device_info = [ 

405 OrtDevice( 

406 OrtGradientForwardBackward._provider_name_to_device_type( 

407 self.providers[0]), 

408 OrtDevice.default_memory(), self.device_index) 

409 for i in bw_fetches_names] 

410 fw_no_grad_output_device_info = [ 

411 OrtDevice( 

412 OrtGradientForwardBackward._provider_name_to_device_type( 

413 self.providers[0]), 

414 OrtDevice.default_memory(), self.device_index) 

415 for i in self.output_names] 

416 
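# The TrainingAgent ties forward and backward execution to the
# training session: run_forward computes the outputs and fills a
# PartialGraphExecutionState, run_backward resumes from that state
# to produce the gradients.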

417 training_agent = TrainingAgent( 

418 sess._sess, 

419 grad_input_names, 

420 fw_outputs_device_info, 

421 bw_fetches_names, 

422 bw_outputs_device_info) 

423 

424 if logger is not None: 

425 logger.info( 

426 "[OrtGradientForwardBackward] instantiate dynamic class %r", 

427 self.class_name) 

428 logger.info( 

429 "[OrtGradientForwardBackward] weights_to_train=%r", 

430 self.weights_to_train) 

431 logger.info( 

432 "[OrtGradientForwardBackward] grad_input_names=%r", 

433 grad_input_names) 

434 logger.info( 

435 "[OrtGradientForwardBackward] bw_fetches_names=%r", 

436 bw_fetches_names) 

437 logger.info( 

438 "[OrtGradientForwardBackward] device_index=%r", 

439 self.device_index) 

440 devices = list(fw_outputs_device_info) 

441 while len(devices) < len(grad_input_names): 

442 devices.append(devices[-1]) 

443 

444 trained_onnx = onnx.load(BytesIO(train_onnx_model_serialized)) 

445 onnx_loss = onnx.load(BytesIO(optimized_pre_grad_model)) 

446 for i, node in enumerate(trained_onnx.graph.node): 

447 if node.name == '': 

448 node.name = "N%d" % i 

449 for i, node in enumerate(onnx_loss.graph.node): 

450 if node.name == '': 

451 node.name = "N%d" % i 

452 

453 kwargs = { 

454 '_run_options': self.run_options, 

455 '_sess': sess, 

456 '_sess_eval': sess_eval, 

457 '_training_agent': training_agent, 

458 '_cache': OrtValueCache(), 

459 '_logger': logger, 

460 '_input_names': self.input_names, 

461 '_grad_input_names': grad_input_names, 

462 '_output_names': self.output_names, 

463 '_bw_fetches_names': bw_fetches_names, 

464 '_fw_outputs_device_info': fw_outputs_device_info, 

465 '_bw_outputs_device_info': bw_outputs_device_info, 

466 '_fw_no_grad_output_device_info': fw_no_grad_output_device_info, 

467 '_weights_to_train': list(sorted( 

468 self.weights_to_train)), 

469 '_graph_info': graph_info, 

470 # 

471 '_trained_onnx': trained_onnx, 

472 '_optimized_pre_grad_model': onnx_loss, 

473 '_graph_builder': builder, 

474 '_devices': devices, 

475 '_debug': self.debug 

476 } 

477 graph = kwargs['_trained_onnx'].graph 

478 kwargs.update({ 

479 '_onx_inp': [o.name for o in graph.input], 

480 '_onx_out': [o.name for o in graph.output] 

481 }) 

482 

483 if len(kwargs['_onx_inp']) != len(kwargs['_onx_out']): 

484 raise RuntimeError( # pragma: no cover 

485 "Gradient input and output are inconsistant: " 

486 "%r != %r" % (kwargs['_onx_inp'], kwargs['_onx_out'])) 

487 return kwargs 

488 

489 

490class OrtGradientForwardBackwardFunction: 

491 """ 

492 Ancestor for a class implementing forward and backward 

493 and dynamically created by @see cl OrtGradientForwardBackward. 

494 

495 Attributes stored in *forward* method: 

496 * `saved_tensors_`: list of tensors to save during forward 

497 and to retrieve during backward 

498 * `states_`: list of :epkg:`PartialGraphExecutionState`, one pushed per *forward* call with ``training=True`` and popped by *backward*
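
A typical sequence (a sketch, assuming ``inst`` was returned by
method *new_instance* of @see cl OrtGradientForwardBackward and the
inputs follow ``_grad_input_names``)::

    outputs = inst.forward(inputs, training=True)
    # ... compute the gradient of the loss, `grad_outputs` ...
    gradients = inst.backward(grad_outputs)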

499 """ 

500 

501 def __init__(self): 

502 self.states_ = [] 

503 self.saved_tensors_ = None 

504 

505 @classmethod 

506 def save_onnx_graph(cls, folder, prefix=None, suffix=None): 

507 """ 

508 Saves the ONNX graphs stored in this class.
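
A usage sketch (assuming the target folder exists and the class
holds ONNX attributes such as ``_trained_onnx``)::

    saved = cls.save_onnx_graph('dump_folder')
    # maps every ONNX attribute name to the file it was written to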

509 """ 

510 if prefix is None: 

511 prefix = '' # pragma: no cover 

512 if suffix is None: 

513 suffix = '' # pragma: no cover 

514 if isinstance(folder, str) and not os.path.exists(folder): 

515 raise FileNotFoundError( # pragma: no cover 

516 "Folder %r does not exist." % folder) 

517 saved = {} 

518 for k, v in cls.__dict__.items(): 

519 if hasattr(v, "SerializeToString"): 

520 if isinstance(folder, str): 

521 name = "%s%s%s.%s.onnx" % ( 

522 prefix, cls.__name__, suffix, k) 

523 filename = os.path.join(folder, name) 

524 if os.path.exists(filename): 

525 warnings.warn( # pragma: no cover 

526 "Filename %r already exists." % filename) 

527 with open(filename, "wb") as f: 

528 f.write(v.SerializeToString()) 

529 saved[k] = filename 

530 else: 

531 saved[k] = v.SerializeToString() 

532 elif hasattr(v, "save_onnx_graph"): 

533 saved[k] = v.save_onnx_graph( 

534 folder, prefix=prefix, suffix="%s.%s" % (suffix, k)) 

535 return saved 

536 

537 @staticmethod 

538 def device_name(device): 

539 """ 

540 Returns the device name of a device. 

541 

542 :param device: OrtDevice 

543 :return: string 

544 """ 

545 if device.device_type() == OrtDevice.cpu(): 

546 return 'Cpu' 

547 if device.device_type() == OrtDevice.cuda(): # pragma: no cover 

548 return 'Gpu' 

549 raise RuntimeError( # pragma: no cover 

550 "Unexpected value for device type %r." % device.device_type()) 

551 

552 @staticmethod 

553 def input_to_ort(tensors, devices, debug): 

554 "Converts a list of tensos into an :epkg:`OrtValueVector`." 

555 def _validate_(tensors): 

556 if any(map( 

557 lambda tu: ( 

558 tu[0].device_name() != 

559 OrtGradientForwardBackwardFunction.device_name( 

560 tu[1])), 

561 zip(tensors, devices))): 

562 raise RuntimeError( # pragma: no cover 

563 "Not all inputs are on the same device %r != %r." % ( 

564 [OrtGradientForwardBackwardFunction.device_name(d)

565 for d in devices], 

566 [x.device_name() for x in tensors])) 

567 

568 if isinstance(tensors, OrtValueVector): 

569 if debug: 

570 _validate_(tensors) 

571 return tensors 

572 if all(map(lambda t: isinstance(t, C_OrtValue), tensors)): 

573 if debug: 

574 _validate_(tensors) 

575 vect = OrtValueVector() 

576 vect.reserve(len(tensors)) 

577 for t in tensors: 

578 if t is None: 

579 raise NotImplementedError( # pragma: no cover 

580 "Empty vector found.") 

581 vect.push_back(t) 

582 return vect 

583 

584 # generic case 

585 vect = OrtValueVector() 

586 vect.reserve(len(tensors)) 

587 for t, dev in zip(tensors, devices): 

588 if t is None: 

589 # if gradient then 

590 # grad_output = torch.zeros(shape, device=device, dtype=dtype) 

591 raise NotImplementedError( # pragma: no cover 

592 "Empty vector found.") 

593 if not t.data.contiguous: 

594 t = t.as_contiguous() # pragma: no cover 

595 vect.push_back(C_OrtValue.ortvalue_from_numpy(t, dev)) 

596 if debug: 

597 if len(vect) != len(tensors): 

598 raise RuntimeError( # pragma: no cover 

599 "Unexpected array length %d != %d (len(devices)=%d)." % ( 

600 len(vect), len(tensors), len(devices))) 

601 _validate_(vect) 

602 return vect 
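# Conversion sketch for the generic branch above (an assumed CPU-only
# setup with a single numpy input):
#
#   import numpy
#   dev = OrtDevice(OrtDevice.cpu(), OrtDevice.default_memory(), 0)
#   vect = OrtGradientForwardBackwardFunction.input_to_ort(
#       [numpy.array([1.], dtype=numpy.float32)], [dev], False)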

603 

604 def save_for_backward(self, inputs): 

605 """ 

606 Saves inputs during forward steps. The list inputs

607 is copied (simple copy, no deep copy). 

608 

609 :param inputs: list of tensors to save. 

610 """ 

611 self.saved_tensors_ = list(inputs) 

612 

613 @property 

614 def saved_tensors(self): 

615 """ 

616 Returns saved tensors during forward step. 

617 """ 

618 if self.saved_tensors_ is None: 

619 raise RuntimeError( # pragma: no cover 

620 "No tensors was saved with save_for_backward.") 

621 return self.saved_tensors_ 

622 

623 def forward(self, inputs, training=False, forward_outputs_cache=None): 

624 """ 

625 Implements forward function. 

626 

627 :param inputs: inputs 

628 :param training: if True, prepares the backward pass, otherwise runs inference only

629 :return: output as :epkg:`OrtValueVector` 
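
A call sketch (assuming *inputs* follows ``_grad_input_names``,
typically the graph inputs followed by the trained weights)::

    outputs = inst.forward(inputs, training=True)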

630 """ 

631 logger = self._logger 

632 cls = self.__class__ 

633 

634 def _log(msg, *args): 

635 logger.debug("[%s.forward] (%dI) " + msg, 

636 cls.__name__, len(inputs), *args) 

637 

638 if logger is not None: 

639 if training: 

640 _log("begin with gradient") 

641 else: 

642 _log("begin") 

643 _log("torch function %r", type(cls)) 

644 _log("ort class %r", cls) 

645 _log("create OrtValueVector (through dlpack)") 

646 

647 forward_inputs = cls.input_to_ort( 

648 inputs, cls._devices, cls._debug) 

649 

650 if training: 

651 forward_outputs = forward_outputs_cache or OrtValueVector() 

652 state = PartialGraphExecutionState() 

653 self.states_.append(state) 

654 if logger is not None: 

655 _log("run_forward") 

656 cls._training_agent.run_forward( 

657 forward_inputs, forward_outputs, state, cls._cache) 

658 

659 self.save_for_backward(inputs) 

660 if logger is not None: 

661 _log("end") 

662 return forward_outputs 

663 else: 

664 # what about bind_input (+ data_ptr) 

665 if len(forward_inputs) != len(cls._grad_input_names): 

666 raise RuntimeError( # pragma: no cover 

667 "Size mismatch len(inputs)=%d, len(onnx inputs)=%d." % ( 

668 len(forward_inputs), len(cls._grad_input_names))) 

669 iobinding = SessionIOBinding(cls._sess_eval._sess) 

670 if logger is not None: 

671 _log("bind inputs %r", cls._grad_input_names) 

672 for name, inp in zip( 

673 cls._grad_input_names, forward_inputs): 

674 iobinding.bind_ortvalue_input(name, inp) 

675 

676 # bind output 

677 if logger is not None: 

678 _log("bind outputs %r", cls._output_names) 

679 for name, dev in zip( 

680 cls._output_names, cls._fw_no_grad_output_device_info): 

681 iobinding.bind_output(name, dev) 

682 

683 # if the shape is known in advance 

684 # iobinding.bind_output( 

685 # output_desc.name, torch_tensor.device.type, 

686 # _utils.get_device_index(target_device), 

687 # _utils.dtype_torch_to_numpy(torch_tensor.dtype), 

688 # list(torch_tensor.size()), torch_tensor.data_ptr()) 

689 

690 if logger is not None: 

691 _log("grad_enabled=False (run_with_iobinding)") 

692 cls._sess_eval._sess.run_with_iobinding( 

693 iobinding, cls._run_options) 

694 if logger is not None: 

695 _log("get_outputs") 

696 ortvalues = iobinding.get_outputs() 

697 if logger is not None: 

698 _log("to torck.tensor (%d)", len(ortvalues)) 

699 _log("end") 

700 return ortvalues 

701 

702 def backward(self, grad_outputs, backward_outputs_cache=None): 

703 """ 

704 Implements backward function. The function returns 

705 an :epkg:`OrtValueVector`. 
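
A call sketch (assuming *forward* was called beforehand with
``training=True`` so that an execution state was stacked)::

    grads = inst.backward([grad_loss])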

706 """ 

707 cls = self.__class__ 

708 logger = cls._logger 

709 

710 def _log(msg, *args): 

711 logger.debug("[%s.backward] (%dI) " + msg, 

712 cls.__name__, len(grad_outputs), *args) 

713 

714 if logger is not None: 

715 _log("begin") 

716 _log("torch function %r", type(cls)) 

717 _log("ort class %r", cls) 

718 _log("saved_tensors") 

719 

720 inputs = self.saved_tensors 

721 if logger is not None: 

722 _log("DEBUG: saved_tensors %r", type(inputs)) 

723 _log("self.state_.pop()") 

724 state = self.states_.pop() 

725 

726 if logger is not None: 

727 _log("create OrtValueVector") 

728 

729 backward_inputs = cls.input_to_ort( 

730 grad_outputs, cls._bw_outputs_device_info, cls._debug) 

731 

732 if logger is not None: 

733 _log("len(grad_outputs)=%d type(grad_outputs)=%r", 

734 len(grad_outputs), type(grad_outputs)) 

735 _log("len(backward_inputs)=%d type(backward_inputs)=%r", 

736 len(backward_inputs), type(backward_inputs)) 

737 for i in range(len(backward_inputs)): # pylint: disable=C0200 

738 _log("backward_inputs[%d].shape=%r", 

739 i, backward_inputs[i].shape()) 

740 _log("run_backward") 

741 backward_outputs = backward_outputs_cache or OrtValueVector() 

742 cls._training_agent.run_backward( 

743 backward_inputs, backward_outputs, state) 

744 if logger is not None: # pragma: no cover 

745 _log("DEBUG") 

746 for i, ov in enumerate(backward_outputs): 

747 _log("BCK-RET: i=%d - shape=%r - ptr=%r", 

748 i, ov.shape(), ov.data_ptr()) 

749 _log("got %r gradients", len(backward_outputs)) 

750 _log("end") 

751 return backward_outputs