Coverage for onnxcustom/training/optimizers.py: 98%

141 statements  

1""" 

2@file 

3@brief Optimizer with :epkg:`onnxruntime-training`. 

4""" 

5import numpy 

6from onnxruntime import ( # pylint: disable=E0611 

7 TrainingParameters, SessionOptions, TrainingSession) 

8from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 

9 OrtValue as C_OrtValue, SessionIOBinding as C_IOBinding) 

10from ..utils.onnxruntime_helper import ( 

11 numpy_to_ort_value, device_to_providers) 

12from .data_loader import OrtDataLoader 

13from .excs import ConvergenceError, EvaluationError 

14from ._base_estimator import BaseEstimator 

15 

16 

class OrtGradientOptimizer(BaseEstimator):
    """
    Implements a simple :epkg:`Stochastic Gradient Descent`
    with :epkg:`onnxruntime-training`.

    :param model_onnx: onnx graph to train
    :param weights_to_train: names of initializers to be optimized
    :param loss_output_name: name of the loss output
    :param max_iter: number of training iterations
    :param training_optimizer_name: optimizing algorithm
    :param batch_size: batch size (see class *DataLoader*)
    :param learning_rate: a name, a learning rate instance, or a float,
        see module :mod:`onnxcustom.training.sgd_learning_rate`
    :param device: device as :epkg:`C_OrtDevice` or a string
        representing this device
    :param warm_start: when set to True, reuse the solution of the previous
        call to fit as initialization, otherwise, just erase the previous
        solution
    :param verbose: use :epkg:`tqdm` to display the training progress
    :param validation_every: validation with a test set every
        *validation_every* iterations; a value below 1 is interpreted
        as a fraction of *max_iter* (the default 0.1 means every
        ``int(0.1 * max_iter)`` iterations)
    :param saved_gradient: if not None, a filename,
        the optimizer saves the gradient into it
    :param sample_weight_name: name of the sample weight input

    Once initialized, the class creates the attribute
    `train_session_` which holds an instance of :ref:`l-ort-training-session`.

    See example :ref:`l-orttraining-nn-gpu`.
    """

47 

48 def __init__(self, model_onnx, weights_to_train, loss_output_name='loss', 

49 max_iter=100, training_optimizer_name='SGDOptimizer', 

50 batch_size=10, learning_rate='SGD', 

51 device='cpu', warm_start=False, verbose=0, 

52 validation_every=0.1, saved_gradient=None, 

53 sample_weight_name="weight"): 

54 BaseEstimator.__init__(self, model_onnx, learning_rate, device) 

55 self.batch_size = batch_size 

56 self.weights_to_train = weights_to_train 

57 self.loss_output_name = loss_output_name 

58 self.training_optimizer_name = training_optimizer_name 

59 self.verbose = verbose 

60 self.max_iter = max_iter 

61 self.warm_start = warm_start 

62 self.saved_gradient = saved_gradient 

63 self.sample_weight_name = sample_weight_name 

64 if validation_every < 1: 

65 self.validation_every = int(self.max_iter * validation_every) 

66 else: 

67 self.validation_every = validation_every # pragma: no cover 

68 if self.learning_rate.needs_grad: 

69 raise NotImplementedError( 

70 "Any weight update involving past gradient is " 

71 "not implemented (class %r)." 

72 "" % self.learning_rate.__class__.__name__) 

    def fit(self, X, y, sample_weight=None, X_val=None, y_val=None,
            use_numpy=False):
        """
        Trains the model.

        :param X: features
        :param y: expected output
        :param sample_weight: sample weight, if any
        :param X_val: evaluation dataset
        :param y_val: evaluation dataset
        :param use_numpy: if True, uses a slow iterator relying on
            numpy arrays, otherwise, minimizes copies
        :return: self
        """
        input_names = [i.name for i in self.model_onnx.graph.input]
        if ((len(input_names) == 2 and sample_weight is not None) or
                (len(input_names) == 3 and sample_weight is None)):
            raise RuntimeError(  # pragma: no cover
                "Number of inputs should be 2 if sample_weight is None "
                "or 3 if not None, but it is %d." % len(input_names))
        self.train_session_ = self._create_training_session(
            self.model_onnx, self.weights_to_train,
            loss_output_name=self.loss_output_name,
            training_optimizer_name=self.training_optimizer_name,
            device=self.device)

        if not self.warm_start:
            state = self.get_state()
            new_state = {}
            for k, v in state.items():
                if len(v.shape) > 0:
                    new_state[k] = numpy.random.randn(*v.shape).astype(v.dtype)
                else:
                    f = numpy.random.randn(1)
                    f = f.astype(v.dtype)
                    new_state[k] = f
            self.set_state(new_state)

        data_loader = OrtDataLoader(
            X, y, sample_weight=sample_weight,
            batch_size=self.batch_size, device=self.device)
        if X_val is not None:
            data_loader_val = OrtDataLoader(
                X_val, y_val, batch_size=X_val.shape[0], device=self.device,
                random_iter=False)
        else:
            data_loader_val = None

        self.learning_rate.init_learning_rate()
        self.input_names_ = [i.name for i in self.train_session_.get_inputs()]
        self.output_names_ = [
            o.name for o in self.train_session_.get_outputs()]
        self.loss_index_ = self.output_names_.index(self.loss_output_name)

        bind = self.train_session_.io_binding()._iobinding

        if self.verbose > 0:  # pragma: no cover
            from tqdm import tqdm  # pylint: disable=C0415
            loop = tqdm(range(self.max_iter))
        else:
            loop = range(self.max_iter)

        self.train_losses_ = []
        self.validation_losses_ = []
        lr = self.learning_rate.value
        for it in loop:
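            # The learning rate fed to the session below is divided by the
            # batch size, presumably because the loss output is a sum, not
            # a mean, over the batch (per-batch losses are divided by the
            # batch size in _iteration as well).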

            lr_alive = numpy.array([lr / self.batch_size], dtype=numpy.float32)
            ort_lr = numpy_to_ort_value(lr_alive, self.device)
            loss = self._iteration(data_loader, ort_lr,
                                   bind, use_numpy=use_numpy,
                                   sample_weight=sample_weight is not None)
            lr = self.learning_rate.update_learning_rate(it).value
            if self.verbose > 1:  # pragma: no cover
                loop.set_description(
                    "loss=%1.3g lr=%1.3g "  # pylint: disable=E1307
                    "lrn=%1.3g" % (
                        loss, lr, lr_alive[0]))
            self.train_losses_.append(loss)
            if (data_loader_val is not None and
                    (it + 1) % self.validation_every == 0):
                self.validation_losses_.append(
                    self._evaluation(data_loader_val, bind))
        self.trained_coef_ = self.train_session_.get_state()
        return self
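
    # After fit, the instance exposes train_losses_ (one averaged loss per
    # iteration), validation_losses_ (when a validation set is given) and
    # trained_coef_ (the state returned by the training session).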

    def _bind_input_ortvalue(self, name, bind, c_ortvalue):
        """
        Binds :epkg:`C_OrtValue` to the structure used by
        :epkg:`InferenceSession` to run inference.

        :param name: str
        :param bind: python structure
        :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`),
            it can also be a numpy array
        """
        if not isinstance(bind, C_IOBinding):
            raise TypeError(  # pragma: no cover
                "Unexpected type %r." % type(bind))
        if isinstance(c_ortvalue, C_OrtValue):
            bind.bind_ortvalue_input(name, c_ortvalue)
        elif isinstance(c_ortvalue, numpy.ndarray):
            # This fails on linux with int64.
            bind.bind_input(
                name, self.device, c_ortvalue.dtype, c_ortvalue.shape,
                c_ortvalue.__array_interface__['data'][0])
        else:
            raise TypeError(  # pragma: no cover
                "Unable to bind type %r for name %r." % (
                    type(c_ortvalue), name))

    def _iteration(self, data_loader, ort_lr, bind, use_numpy, sample_weight):
        actual_losses = []

        bind.bind_output('loss', self.device)
        idx = 3 if sample_weight else 2
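        # The training session inputs are ordered (X, y[, weight],
        # learning_rate): the learning rate is the last input, hence
        # index 2 without a sample weight and 3 with one.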

        if use_numpy:
            # onnxruntime does not copy the data, so the numpy
            # array must remain alive all along the iteration
            lr_alive = ort_lr.numpy()
            self._bind_input_ortvalue(
                self.input_names_[idx], bind, lr_alive)

            # Slow iterations.
            for it in data_loader.iter_numpy():
                if len(it) == 2:
                    data, target = it
                    self._bind_input_ortvalue(
                        self.input_names_[0], bind, data)
                    self._bind_input_ortvalue(
                        self.input_names_[1], bind, target)
                else:
                    data, target, weight = it
                    self._bind_input_ortvalue(
                        self.input_names_[0], bind, data)
                    self._bind_input_ortvalue(
                        self.input_names_[1], bind, target)
                    self._bind_input_ortvalue(
                        self.input_names_[2], bind, weight)

                self.train_session_._sess.run_with_iobinding(bind, None)
                loss = bind.get_outputs()[0].numpy()
                if numpy.isinf(loss) or numpy.isnan(loss):
                    raise ConvergenceError(
                        "Loss is nan or infinite, learning_rate=%r, "
                        "the gradient descent has failed "
                        "(past losses=%r)." % (
                            ort_lr.numpy(),
                            [float(v[0]) for v in (
                                actual_losses if len(actual_losses) < 5
                                else actual_losses[-5:])]))
                actual_losses.append(loss / data.shape[0])
        else:
            self._bind_input_ortvalue(self.input_names_[idx], bind, ort_lr)

            # Fast iterations.
            for batch_size in data_loader.iter_bind(bind, self.input_names_):
                self.train_session_._sess.run_with_iobinding(bind, None)
                # The predicted output is copied as well even though
                # it is not needed.
                loss = bind.get_outputs()[0].numpy()
                if numpy.isinf(loss) or numpy.isnan(loss):
                    raise ConvergenceError(
                        "Loss is nan or infinite, learning_rate=%r, "
                        "the gradient descent has failed "
                        "(past losses=%r)." % (
                            ort_lr.numpy(),
                            [float(v[0]) for v in (
                                actual_losses if len(actual_losses) < 5
                                else actual_losses[-5:])]))
                actual_losses.append(loss / batch_size)

        return numpy.array(actual_losses).mean()

    def _evaluation(self, data_loader, bind):
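        # The learning rate input is bound to zero, which makes the SGD
        # update a no-op: the session computes the loss without changing
        # the weights.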

        lr_alive = numpy.array([0], dtype=numpy.float32)
        self._bind_input_ortvalue(self.input_names_[2], bind, lr_alive)
        bind.bind_output('loss', self.device)

        actual_losses = []
        total = 0
        for batch_size in data_loader.iter_bind(bind, self.input_names_):
            self.train_session_._sess.run_with_iobinding(bind, None)
            outputs = bind.copy_outputs_to_cpu()
            if numpy.isinf(outputs[0]) or numpy.isnan(outputs[0]):
                raise EvaluationError(  # pragma: no cover
                    "Loss is nan or infinite (%r), "
                    "evaluation has failed." % outputs[0])
            actual_losses.append(outputs[0])
            total += batch_size
        return numpy.array(actual_losses).sum() / max(total, 1)

    def _create_training_session(
            self, training_onnx, weights_to_train,
            loss_output_name='loss',
            training_optimizer_name='SGDOptimizer',
            device='cpu'):
        """
        Creates an instance of :epkg:`TrainingSession`.

        :param training_onnx: an ONNX graph with a loss function
        :param weights_to_train: list of initializer names to optimize
        :param loss_output_name: output name for the loss
        :param training_optimizer_name: optimizer name
        :param device: one :epkg:`C_OrtDevice` or a string
        :return: an instance of :epkg:`TrainingSession`
        """
        if training_optimizer_name != 'SGDOptimizer':
            raise NotImplementedError(
                "Only the SGDOptimizer is implemented, not %r." %
                training_optimizer_name)
        ort_parameters = TrainingParameters()
        ort_parameters.loss_output_name = loss_output_name
        ort_parameters.use_mixed_precision = False
        # ort_parameters.world_rank = -1
        # ort_parameters.world_size = 1
        # ort_parameters.gradient_accumulation_steps = 1
        # ort_parameters.allreduce_post_accumulation = False
        # ort_parameters.deepspeed_zero_stage = 0
        # ort_parameters.enable_grad_norm_clip = False
        # ort_parameters.set_gradients_as_graph_outputs = False
        # ort_parameters.use_memory_efficient_gradient = False
        # ort_parameters.enable_adasum = False
        if self.saved_gradient is not None:
            name = self.saved_gradient
            name2 = name + ".training.onnx"
            ort_parameters.model_with_gradient_graph_path = name
            ort_parameters.model_with_training_graph_path = name2

        output_types = {}
        for output in training_onnx.graph.output:
            output_types[output.name] = output.type.tensor_type

        ort_parameters.weights_to_train = set(weights_to_train)
        ort_parameters.training_optimizer_name = training_optimizer_name
        # ort_parameters.lr_params_feed_name = lr_params_feed_name

        ort_parameters.optimizer_attributes_map = {
            name: {} for name in weights_to_train}
        ort_parameters.optimizer_int_attributes_map = {
            name: {} for name in weights_to_train}

        session_options = SessionOptions()
        session_options.log_severity_level = 4
        session_options.log_verbosity_level = 4
        # session_options.use_deterministic_compute = True

        providers = device_to_providers(self.device)
        session = TrainingSession(
            training_onnx.SerializeToString(), ort_parameters, session_options,
            providers=providers)

        return session

    def get_state(self):
        """
        Returns the trained weights.
        """
        if not hasattr(self, 'train_session_'):
            if hasattr(self, 'trained_coef_'):
                return self.trained_coef_
            raise AttributeError("Method fit must be called before.")
        return self.train_session_.get_state()

    def get_trained_onnx(self, model=None):
        """
        Returns the trained onnx graph: the initial graph with its
        initializers replaced by the trained weights. If *model* is not
        specified, the method uses the model given to this class, and
        the returned graph then outputs the loss, not the predictions.
        Parameter *model* can be set to the graph as it was before the
        loss was added, in which case the returned graph produces the
        predictions.

        :param model: replace the weights in another graph
            than the training graph
        :return: onnx graph
        """
        return self._get_trained_onnx(self.get_state(), model=model)
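
    # A hypothetical sketch: replacing the trained weights in the original
    # inference graph (the graph before the loss node was added):
    #
    #     onx_pred = trainer.get_trained_onnx(model=base_onnx)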

    def set_state(self, state):
        """
        Changes the trained weights.
        """
        if not hasattr(self, 'train_session_'):
            raise AttributeError(  # pragma: no cover
                "Method fit must be called before.")
        return self.train_session_.load_state(state)