Coverage for onnxcustom/training/optimizers_partial.py: 98%

283 statements  

1""" 

2@file 

3@brief Optimizer with :epkg:`onnxruntime-training` forward backward training. 

4""" 

5import logging 

6import warnings 

7import numpy 

8from onnxruntime import InferenceSession, SessionOptions 

9from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 

10 OrtValue as C_OrtValue) 

11from ..utils.onnx_helper import get_onnx_opset, proto_type_to_dtype 

12from ..utils.onnxruntime_helper import ( 

13 device_to_providers, numpy_to_ort_value) 

14from ..utils.onnx_function import function_onnx_graph 

15from ..utils.print_helper import str_ortvalue 

16from ..utils.orttraining_helper import get_train_initializer 

17from .ortgradient import OrtGradientForwardBackward 

18from ._base_estimator import BaseEstimator 

19from .sgd_learning_loss import BaseLearningLoss 

20from .sgd_learning_penalty import BaseLearningPenalty 

21from .data_loader import OrtDataLoader 

22from .excs import ConvergenceError, ConvergenceWarning 

23 

24 

25 class OrtGradientForwardBackwardOptimizer(BaseEstimator):

26 """ 

27 Implements a simple :epkg:`Stochastic Gradient Descent` 

28 with :epkg:`onnxruntime-training`. It leverages the class

29 @see cl OrtGradientForwardBackward.

30 

31 :param model_onnx: onnx graph to train 

32 :param weights_to_train: names of initializers to be optimized, 

33 if None, function @see fn get_train_initializer returns

34 the list of float initializers

35 :param loss_output_name: name of the loss output 

36 :param max_iter: number of training iterations 

37 :param training_optimizer_name: optimizing algorithm 

38 :param batch_size: batch size (see class *DataLoader*) 

39 :param learning_rate: a name or a learning rate instance or a float, 

40 see module :mod:`onnxcustom.training.sgd_learning_rate` 

41 :param device: device as :epkg:`C_OrtDevice` or a string 

42 representing this device 

43 :param warm_start: when set to True, reuse the solution of the previous 

44 call to fit as initialization, otherwise, just erase the previous 

45 solution. 

46 :param learning_loss: loss function (see below) 

47 :param verbose: use :epkg:`tqdm` to display the training progress 

48 :param validation_every: validation with a test set every 

49 *validation_every* iterations 

50 :param enable_logging: enable logging (mostly for debugging purposes

51 as it slows down the training) 

52 :param weight_name: if not None, the class assumes it is trained 

53 with training weights

54 :param learning_penalty: weight penalty, None, or instance of 

55 @see cl BaseLearningPenalty 

56 :param exc: raise exceptions (about convergence for example) 

57 or keep them silent as much as possible 

58 

59 *learning_rate* can be any instance of @see cl BaseLearningRate or 

60 a nickname among those accepted by

61 :meth:`BaseLearningRate.select 

62 <onnxcustom.training.sgd_learning_rate.BaseLearningRate.select>`.

63 

64 *learning_loss* can be any instance of @see cl BaseLearningLoss or 

65 a nickname among those accepted by

66 :meth:`BaseLearningLoss.select 

67 <onnxcustom.training.sgd_learning_loss.BaseLearningLoss.select>`.
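
A minimal usage sketch; the trainable graph *onx* and the arrays *X*, *y*
are assumptions, not objects defined in this module::

    import numpy
    from onnxcustom.training.optimizers_partial import (
        OrtGradientForwardBackwardOptimizer)

    trainer = OrtGradientForwardBackwardOptimizer(
        onx, max_iter=100, batch_size=10,
        learning_rate='SGD', learning_loss='square_error')
    trainer.fit(X.astype(numpy.float32), y.astype(numpy.float32))
    print(trainer.train_losses_)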

68 """ 

69 

70 def __init__(self, model_onnx, weights_to_train=None, 

71 loss_output_name='loss', max_iter=100, 

72 training_optimizer_name='SGDOptimizer', 

73 batch_size=10, learning_rate='SGD', 

74 device='cpu', warm_start=False, verbose=0, 

75 validation_every=0.1, learning_loss="square_error", 

76 enable_logging=False, weight_name=None, 

77 learning_penalty=None, exc=True): 

78 if weights_to_train is None: 

79 weights_to_train = list(get_train_initializer(model_onnx)) 

80 BaseEstimator.__init__(self, model_onnx, learning_rate, device) 

81 self.batch_size = batch_size 

82 self.weights_to_train = weights_to_train 

83 self.loss_output_name = loss_output_name 

84 self.training_optimizer_name = training_optimizer_name 

85 self.verbose = verbose 

86 self.max_iter = max_iter 

87 self.warm_start = warm_start 

88 self.learning_loss = BaseLearningLoss.select(learning_loss) 

89 self.learning_penalty = BaseLearningPenalty.select(learning_penalty) 

90 self.enable_logging = enable_logging 

91 self.weight_name = weight_name 

92 self.exc = exc 

93 if validation_every < 1: 

94 self.validation_every = int(self.max_iter * validation_every) 

95 else: 

96 self.validation_every = validation_every # pragma: no cover 

97 self.build_onnx_function() 

98 

99 @property 

100 def needs_grad(self): 

101 """ 

102 Returns True if the gradient update needs to retain

103 past gradients. 

104 """ 

105 return self.learning_rate.needs_grad 

106 

107 def __getstate__(self): 

108 "Removes any non pickable attribute." 

109 state = BaseEstimator.__getstate__(self) 

110 for att in ['train_state_', 'train_grad_state_']: 

111 if hasattr(self, att): 

112 train_state = [] 

113 for v in self.get_state(): 

114 if v is None: 

115 train_state.append(v) 

116 else: 

117 train_state.append(v.numpy()) 

118 state[att[:-1]] = train_state 

119 return state 

120 

121 def __setstate__(self, state): 

122 "Restores any non pickable attribute." 

123 popped = {} 

124 for att in ['train_state', 'train_grad_state']: 

125 if att in state: 

126 popped[att] = state.pop(att) 

127 BaseEstimator.__setstate__(self, state) 

128 for k, v in popped.items(): 

129 if k == 'train_state': 

130 self.set_state(v, check_trained=False, kind='weight') 

131 elif k == 'train_grad_state': 

132 self.set_state(v, check_trained=False, kind='grad') 

133 else: 

134 raise ValueError( # pragma: no cover 

135 "Unexpected key state %r." % k) 

136 self.build_onnx_function() 

137 return self 

138 

139 def _get_att_state(self, kind): 

140 if kind == 'weight': 

141 return 'train_state_' 

142 if kind == 'grad': 

143 return 'train_grad_state_' 

144 raise ValueError( # pragma: no cover 

145 "Unexpected kind=%r." % kind) 

146 

147 def get_full_state(self, kind='weight'): 

148 """ 

149 Returns the trained weights and the inputs. 

150 """ 

151 if isinstance(kind, list): 

152 return [self.get_full_state(kind=k) for k in kind] 

153 att = self._get_att_state(kind) 

154 if not hasattr(self, att): 

155 raise AttributeError( # pragma: no cover 

156 "Method fit must be called before.") 

157 return getattr(self, att) 

158 

159 def get_state(self, kind='weight'): 

160 """ 

161 Returns the trained weights. 

162 """ 

163 att = self._get_att_state(kind) 

164 if not hasattr(self, att): 

165 raise AttributeError("Method fit must be called before.") 

166 if getattr(self, att, None) is None: 

167 raise RuntimeError( # pragma: no cover 

168 "No attribute %r available (None)." % att) 

169 if self.weights_to_train is None: 

170 raise RuntimeError( # pragma: no cover 

171 "Unexpected self.weights_to_train (None).") 

172 value = getattr(self, att) 

173 n = len(value) - len(self.weights_to_train) 

174 return value[n:] 

175 

176 @property 

177 def trained_coef_(self): 

178 """ 

179 Returns the trained coefficients as a dictionary.
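
A sketch, assuming *trainer* is a fitted instance of this class::

    coef = trainer.trained_coef_
    # {initializer name: C_OrtValue holding the trained weight}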

180 """ 

181 return dict(zip(self.weights_to_train, self.get_state())) 

182 

183 def get_trained_onnx(self, model=None): 

184 """ 

185 Returns the trained onnx graph, the initial graph 

186 modified by replacing the initializers with the trained 

187 weights. 

188 

189 :param model: replace the weights in another graph 

190 than the training graph 

191 :return: onnx graph 
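
As a sketch, the returned graph can be saved or executed with a plain
*InferenceSession*; *trainer* stands for a fitted instance and the input
name 'X' depends on the trained graph, both are assumptions::

    import numpy
    from onnxruntime import InferenceSession

    onx = trainer.get_trained_onnx()
    sess = InferenceSession(
        onx.SerializeToString(),
        providers=['CPUExecutionProvider'])
    predictions = sess.run(None, {'X': X.astype(numpy.float32)})[0]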

192 """ 

193 state = dict(zip(self.weights_to_train, self.get_state())) 

194 return self._get_trained_onnx(state, model=model) 

195 

196 def set_state(self, state, check_trained=True, kind='weight', zero=False): 

197 """ 

198 Changes the trained weights. 
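
The list must follow the same order as *get_full_state*: every element
is *None*, a *numpy* array or a *C_OrtValue*; numpy arrays are copied
to the training device. A roundtrip sketch, assuming *trainer* was
already fitted::

    state = [None if v is None else v.numpy()
             for v in trainer.get_full_state()]
    trainer.set_state(state)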

199 """ 

200 if check_trained and not hasattr(self, 'train_session_'): 

201 raise AttributeError( # pragma: no cover 

202 "Method fit must be called before.") 

203 state_ = [] 

204 state_numpy_ = [] 

205 for i, v in enumerate(state): 

206 if v is None: 

207 state_.append(None) 

208 state_numpy_.append(None) 

209 elif isinstance(v, numpy.ndarray): 

210 if zero: 

211 v = numpy.zeros(v.shape, dtype=v.dtype) 

212 ortvalue = numpy_to_ort_value(v, self.device) 

213 state_.append(ortvalue) 

214 # The numpy container must be retained as the ortvalue 

215 # just borrows the pointer. 

216 state_numpy_.append(v) 

217 elif isinstance(v, C_OrtValue): 

218 if zero: 

219 v = self.zero_sess_.run_with_ort_values(['Y'], {'X': v}) 

220 state_.append(v) 

221 state_numpy_.append(None) 

222 else: 

223 raise TypeError( # pragma: no cover 

224 "Unexpected type %r for state %r." % ( 

225 type(v), i)) 

226 att = self._get_att_state(kind) 

227 setattr(self, att, state_) 

228 setattr(self, att + "numpy_", state_numpy_) 

229 

230 def build_onnx_function(self): 

231 """ 

232 Creates the ONNX graphs and *InferenceSession* instances needed by

233 any operation applied on *OrtValue*.

234 """ 

235 opset = get_onnx_opset(self.model_onnx) 

236 so = SessionOptions() 

237 so.log_severity_level = 4 

238 

239 n = len(self.weights_to_train) 

240 

241 # loss_grad 

242 self.learning_loss.build_onnx_function( 

243 opset, self.device, self.weight_name) 

244 

245 # weight update 

246 self.learning_rate.build_onnx_function(opset, self.device, n) 

247 

248 # regularization 

249 self.learning_penalty.build_onnx_function(opset, self.device, n) 

250 

251 # zero 

252 self.zero_onnx_ = function_onnx_graph("zero") 

253 self.zero_sess_ = InferenceSession( 

254 self.zero_onnx_.SerializeToString(), so, 

255 providers=device_to_providers(self.device)) 

256 

257 # logging 

258 if self.enable_logging: 

259 self._logger = logging.getLogger("onnxcustom") 

260 else: 

261 self._logger = None 

262 

263 def fit(self, X, y, sample_weight=None, 

264 X_val=None, y_val=None): 

265 """ 

266 Trains the model. 

267 

268 :param X: features 

269 :param y: expected output 

270 :param sample_weight: training weight or None 

271 :param X_val: evaluation dataset 

272 :param y_val: evaluation dataset 

273 :return: self 
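
A sketch, assuming *trainer* is an instance of this class and *X*, *y*,
*X_val*, *y_val* are *float32* numpy arrays::

    trainer.fit(X, y, X_val=X_val, y_val=y_val)
    print(trainer.train_losses_)
    print(trainer.validation_losses_)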

274 """ 

275 if self.training_optimizer_name != 'SGDOptimizer': 

276 raise NotImplementedError( 

277 "Only the SGDOptimizer is implemented not %r." 

278 "" % self.training_optimizer_name) 

279 logger = self._logger 

280 

281 session_function = self._create_training_session( 

282 self.model_onnx, self.weights_to_train, 

283 device=self.device) 

284 self.train_session_ = session_function[0] 

285 self.train_function_ = session_function[1] 

286 

287 self.input_names_ = self.train_session_.cls_type_._grad_input_names 

288 self.output_names_ = self.train_session_.cls_type_._bw_fetches_names 

289 weights_to_train = self.train_session_.weights_to_train 

290 

291 if logger is not None: 

292 logger.info( 

293 "[OrtGradientForwardBackwardOptimizer.fit] " 

294 "input_names=%r", self.input_names_) 

295 logger.info( 

296 "[OrtGradientForwardBackwardOptimizer.fit] " 

297 "output_names=%r", self.output_names_) 

298 logger.info( 

299 "[OrtGradientForwardBackwardOptimizer.fit] " 

300 "weights_to_train=%r", self.weights_to_train) 

301 logger.info( 

302 "[OrtGradientForwardBackwardOptimizer.fit] " 

303 "device=%r|%r", 

304 self.device.device_id(), self.device.device_type()) 

305 if logger is not None: 

306 logger.info( 

307 "[OrtGradientForwardBackwardOptimizer.fit] " 

308 "warm_start=%r", self.warm_start) 

309 

310 if not hasattr(self, 'state_'): 

311 self.set_state([ 

312 self.train_session_.get_initializer(name, exc=False) 

313 for name in self.input_names_]) 

314 if self.needs_grad and not hasattr(self, 'state_grad_'): 

315 self.set_state([ 

316 self.train_session_.get_initializer(name, exc=False) 

317 for name in self.input_names_], 

318 kind='grad', zero=True) 

319 if not self.warm_start: 

320 state = self.get_full_state() 

321 if len(state) != len(self.input_names_): 

322 raise RuntimeError( # pragma: no cover 

323 "Length mismatch %r != %r." % ( 

324 len(state), len(self.input_names_))) 

325 new_state = [] 

326 for iv, v in enumerate(state): 

327 if v is None: 

328 new_state.append(v) 

329 else: 

330 if not isinstance(v, C_OrtValue): 

331 raise RuntimeError( # pragma: no cover 

332 "Unexpected type %r (state[%d])." % ( 

333 type(v), iv)) 

334 dtype = proto_type_to_dtype( 

335 v.proto_type() 

336 if hasattr(v, 'proto_type') 

337 else v.data_type()) 

338 if len(v.shape()) > 0: 

339 new_state.append( 

340 numpy.random.randn(*v.shape()).astype(dtype)) 

341 else: 

342 new_state.append( 

343 numpy.random.randn(1).astype(dtype)) 

344 self.set_state(new_state) 

345 if self.needs_grad: 

346 self.set_state(new_state, kind='grad', zero=True) 

347 

348 data_loader = OrtDataLoader( 

349 X, y, sample_weight, batch_size=self.batch_size, 

350 device=self.device) 

351 if X_val is not None: 

352 data_loader_val = OrtDataLoader( 

353 X_val, y_val, batch_size=X_val.shape[0], device=self.device, 

354 random_iter=False) 

355 else: 

356 data_loader_val = None 

357 

358 self.learning_rate.init_learning_rate() 

359 

360 if self.verbose > 0: # pragma: no cover 

361 from tqdm import tqdm # pylint: disable=C0415 

362 loop = tqdm(range(self.max_iter)) 

363 else: 

364 loop = range(self.max_iter) 

365 

366 self.train_losses_ = [] 

367 val_losses = [] 

368 kinds = ['weight', 'grad'] if self.needs_grad else ['weight'] 

369 for it in loop: 

370 loss = self._iteration( 

371 data_loader, self.get_full_state(kind=kinds), 

372 len(weights_to_train)) 

373 lr = self.learning_rate.update_learning_rate(it).value 

374 if self.verbose > 1: # pragma: no cover 

375 loop.set_description( 

376 "loss=%1.3g lr=%1.3g" % ( # pylint: disable=E1101,E1307 

377 loss, lr)) # pylint: disable=E1101,E1307 

378 if logger is not None: 

379 logger.info( 

380 "[OrtGradientForwardBackwardOptimizer.fit] " 

381 "lr value=%r", lr) 

382 

383 self.train_losses_.append(loss) 

384 if (data_loader_val is not None and 

385 (it + 1) % self.validation_every == 0): 

386 val_losses.append( 

387 self._evaluation(data_loader_val, self.get_full_state())) 

388 self.validation_losses_ = ( 

389 None if data_loader_val is None else val_losses) 

390 

391 if logger is not None: 

392 logger.info( 

393 "[OrtGradientForwardBackwardOptimizer.fit] " 

394 "end loss=%r", self.train_losses_[-1]) 

395 return self 

396 

397 def _iteration(self, data_loader, states, n_weights): 

398 actual_losses = [] 

399 bs = data_loader.batch_size 

400 logger = self._logger 

401 if len(states) == 1: 

402 state = states[0] 

403 grad = None 

404 else: 

405 state, grad = states 

406 

407 if logger is not None: 

408 logger.debug( 

409 "[OrtGradientForwardBackwardOptimizer._iteration] " 

410 "iteration begin learning_rate=%r", 

411 self.learning_rate) 

412 

413 prediction_cache = None 

414 prediction_cache_shape = None 

415 backward_outputs_cache = None 

416 for ib, ito in enumerate(data_loader.iter_ortvalue()): 

417 if len(ito) == 2: 

418 (ortx, orty) = ito 

419 ortw = None 

420 else: 

421 (ortx, orty, ortw) = ito 

422 state[0] = ortx 

423 

424 if logger is not None: 

425 logger.debug( 

426 "[OrtGradientForwardBackwardOptimizer._iteration] " 

427 "batch %d", ib) 

428 

429 ortx_shape = tuple(ortx.shape()) 

430 same_shape = ( 

431 prediction_cache_shape is not None and 

432 ortx_shape == prediction_cache_shape) 

433 

434 if logger is not None: 

435 logger.debug( 

436 "[OrtGradientForwardBackwardOptimizer._iteration] forward") 

437 

438 # forward 

439 if prediction_cache_shape is None or not same_shape:

440 prediction_cache = None 

441 prediction_cache_shape = None 

442 prediction = self.train_function_.forward( 

443 states[0], training=True, 

444 forward_outputs_cache=prediction_cache) 

445 prediction_cache = prediction 

446 prediction_cache_shape = ortx_shape 

447 

448 if logger is not None: 

449 logger.debug( 

450 "[OrtGradientForwardBackwardOptimizer._iteration] " 

451 "loss types=%r,%r", 

452 orty.data_type(), prediction[0].data_type()) 

453 

454 # loss 

455 loss, loss_gradient = self.learning_loss.loss_gradient( 

456 self.device, orty, prediction[0], weight=ortw) 

457 

458 if logger is not None: 

459 logger.debug( 

460 "[OrtGradientForwardBackwardOptimizer._iteration] " 

461 "loss=%g has_weight=%r", 

462 loss.numpy(), ortw is not None) 

463 

464 n = len(state) - n_weights 

465 loss = self.learning_penalty.penalty_loss( 

466 self.device, loss, *state[n:]) 

467 

468 cpu_loss = loss.numpy() 

469 

470 if logger is not None: 

471 logger.debug( 

472 "[OrtGradientForwardBackwardOptimizer._iteration] " 

473 "cpu_loss=%r", cpu_loss) 

474 

475 if numpy.isinf(cpu_loss) or numpy.isnan(cpu_loss): 

476 if self.exc: 

477 raise ConvergenceError( 

478 "Loss is nan, learning_rate=%r, " 

479 "the gradient descent has failed " 

480 "(past losses=%r)." % ( 

481 self.learning_rate, 

482 [float(v) for v in ( 

483 actual_losses if len(actual_losses) < 5 

484 else actual_losses[-5:])])) 

485 warnings.warn( # pragma: no cover 

486 "Loss is nan, learning_rate=%r, " 

487 "the gradient descent has failed " 

488 "(past losses=%r)." % ( 

489 self.learning_rate, 

490 [float(v) for v in ( 

491 actual_losses if len(actual_losses) < 5 

492 else actual_losses[-5:])]), 

493 ConvergenceWarning) 

494 if numpy.isinf(cpu_loss): # pragma: no cover 

495 cpu_loss = numpy.nan 

496 

497 # backward 

498 if not same_shape: 

499 backward_outputs_cache = None 

500 gradient = self.train_function_.backward( 

501 [loss_gradient], backward_outputs_cache=backward_outputs_cache) 

502 backward_outputs_cache = gradient 

503 

504 if len(gradient) != len(state): 

505 raise RuntimeError( # pragma: no cover 

506 "gradient and state should have the same length but " 

507 "%r != %r." % (len(gradient), len(state))) 

508 

509 n = len(state) - n_weights 

510 

511 for i in range(n, len(state)): 

512 self.learning_penalty.update_weights( 

513 i - n, self.device, state[i]) 

514 self.learning_rate.update_weights( 

515 i - n, self.device, state[i], 

516 gradient[i], bs, 

517 None if grad is None else grad[i]) 

518 

519 if logger is not None: 

520 logger.debug( 

521 "[OrtGradientForwardBackwardOptimizer._iteration] " 

522 "loss=%g n_weights=%d", cpu_loss, n) 

523 for i in range(n, len(state)): 

524 logger.debug( 

525 "[OrtGradientForwardBackwardOptimizer._iteration] " 

526 "state[%i]=%s", i, str_ortvalue(state[i])) 

527 

528 actual_losses.append(cpu_loss / bs) 

529 

530 if logger is not None: 

531 logger.debug( 

532 "[OrtGradientForwardBackwardOptimizer._iteration] " 

533 "iteration end") 

534 

535 return numpy.array(actual_losses).mean() 

536 

537 def _evaluation(self, data_loader, state): 

538 logger = self._logger 

539 actual_losses = [] 

540 for ib, (ortx, orty) in enumerate(data_loader.iter_ortvalue()): 

541 state[0] = ortx 

542 

543 if logger is not None: 

544 logger.debug( # pragma: no cover 

545 "[OrtGradientForwardBackwardOptimizer._evaluation] " 

546 "batch %d", ib) 

547 

548 prediction = self.train_function_.forward(state, training=False) 

549 loss, _ = self.learning_loss.loss_gradient( 

550 self.device, orty, prediction[0]) 

551 cpu_loss = loss.numpy() 

552 if numpy.isinf(cpu_loss) or numpy.isnan(cpu_loss): 

553 if self.exc: # pragma: no cover 

554 raise ConvergenceError( 

555 "Loss is nan, " 

556 "the evaluation has failed " 

557 "(past losses=%r)." % 

558 [float(v) for v in ( 

559 actual_losses if len(actual_losses) < 5 

560 else actual_losses[-5:])]) 

561 warnings.warn( 

562 "Loss is nan, learning_rate=%r, " 

563 "the gradient descent has failed " 

564 "(past losses=%r)." % ( 

565 self.learning_rate, 

566 [float(v) for v in ( 

567 actual_losses if len(actual_losses) < 5 

568 else actual_losses[-5:])]), 

569 ConvergenceWarning) 

570 if numpy.isinf(cpu_loss): 

571 cpu_loss = numpy.nan 

572 actual_losses.append(cpu_loss) 

573 

574 return numpy.array(actual_losses).sum() / len(data_loader) 

575 

576 def score(self, X, y, sample_weight=None): 

577 """ 

578 Returns the score over the whole dataset (the negative mean loss).

579 

580 :param X: features 

581 :param y: expected output 

582 :param sample_weight: training weight or None 

583 :return: score 
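
The returned value is the negative mean loss, so that a higher score is
better. A sketch, assuming *trainer* is a fitted instance::

    losses = trainer.losses(X, y)
    score = -losses.sum() / X.shape[0]
    # same value as trainer.score(X, y)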

584 """ 

585 scores = self.losses(X, y, sample_weight=sample_weight) 

586 return -scores.sum() / X.shape[0] 

587 

588 def losses(self, X, y, sample_weight=None): 

589 """ 

590 Returns the losses associated with every observation.

591 

592 :param X: features 

593 :param y: expected output 

594 :param sample_weight: training weight or None 

595 :return: scores 

596 """ 

597 data_loader = OrtDataLoader( 

598 X, y, sample_weight, batch_size=self.batch_size, 

599 device=self.device) 

600 

601 state = self.get_full_state() 

602 scores = numpy.empty((X.shape[0], ), dtype=X.dtype) 

603 pos = 0 

604 for ito in data_loader.iter_ortvalue(): 

605 if len(ito) == 2: 

606 (ortx, orty) = ito 

607 ortw = None 

608 else: 

609 (ortx, orty, ortw) = ito 

610 state[0] = ortx 

611 prediction = self.train_function_.forward(state, training=False) 

612 score = self.learning_loss.loss_scores( 

613 self.device, orty, prediction[0], ortw) 

614 np_score = score.numpy() 

615 # data copy could be avoided by giving a pointer to 

616 # loss score or if we could create an OrtValue from a 

617 # pointer. 

618 end = pos + np_score.shape[0] 

619 if end <= scores.shape[0]: 

620 scores[pos: end] = np_score.ravel() 

621 else: 

622 scores[pos: end] = np_score.ravel()[end - scores.shape[0]:] 

623 pos += np_score.shape[0] 

624 return scores 

625 

626 def _create_training_session( 

627 self, model_onnx, weights_to_train, device): 

628 

629 forback = OrtGradientForwardBackward( 

630 model_onnx, weights_to_train=weights_to_train, 

631 debug=False, enable_logging=False, 

632 providers=device_to_providers(device)) 

633 inst = forback.new_instance() 

634 return (forback, inst)