Full Training with OrtGradientOptimizer
Design
onnxruntime was initially designed to speed up inference and deployment, but it can also be used to train a model. It builds a graph equivalent to the gradient function, also based on ONNX operators and specific gradient operators. Initializers are the weights that can be trained. The gradient graph has as many outputs as initializers.
OrtGradientOptimizer wraps class TrainingSession from onnxruntime-training.
It starts with a model converted into an ONNX graph.
A loss must be added to this graph. Class TrainingSession
is then able to compute another ONNX graph equivalent to the gradient
of the loss with respect to the weights defined by initializers.
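To make the relation between weights and gradient outputs concrete, here is a minimal NumPy sketch. It does not use onnxruntime-training; the model (a linear function with weights named `coef` and `intercept`), the squared-error loss, and all function names are assumptions chosen for illustration. It only shows that a gradient function returns one output per trainable weight.

```python
import numpy as np


def f(coef, intercept, X):
    """Model Y = f(W, X) with two trainable weights (the 'initializers')."""
    return X @ coef + intercept


def loss(coef, intercept, X, y):
    """Sum of squared errors (an assumed loss, for illustration only)."""
    return float(np.sum((f(coef, intercept, X) - y) ** 2))


def gradients(coef, intercept, X, y):
    """Returns one gradient per trainable weight, in the order of the weights."""
    residual = f(coef, intercept, X) - y
    grad_coef = 2.0 * X.T @ residual
    grad_intercept = np.array([2.0 * residual.sum()])
    return grad_coef, grad_intercept


rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))
y = rng.normal(size=5)
coef = np.array([0.5, -0.3])
intercept = np.array([0.1])

# Two weights, hence two gradient outputs with matching shapes.
grads = gradients(coef, intercept, X, y)
```

Each gradient has the same shape as the weight it belongs to, which is what allows an optimizer to update every weight in place.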
The first ONNX graph implements a function Y=f(W, X).
Function add_loss_output then adds a loss to this graph,
which defines a new graph computing loss(f(W, X), W, expected_Y).
The same function can add the nodes needed to compute
an L1 loss, an L2 loss, or a combination of both, together with
an L1 regularization, an L2 regularization, or a combination of both.
Assuming the user has already created an ONNX graph, the following
call to add_loss_output adds 0.1 L1 loss + 0.9 L2 loss
and an L2 regularization on the coefficients:
onx_loss = add_loss_output(
    onx, weight_name='weight', score_name='elastic',
    l1_weight=0.1, l2_weight=0.9,
    penalty={'coef': {'l2': 0.01}})
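The quantity such a graph computes can be sketched in NumPy. This is not the code generated by add_loss_output: the function name `elastic_loss`, the use of a sum rather than a mean, and the way sample weights multiply each term are assumptions made for illustration. Only the coefficients 0.1, 0.9 and 0.01 come from the call above.

```python
import numpy as np


def elastic_loss(pred, expected, weight, coef,
                 l1_weight=0.1, l2_weight=0.9, l2_penalty=0.01):
    """Weighted mix of L1 and L2 losses plus an L2 penalty on the coefficients.

    Assumption: terms are summed over samples and scaled by sample weights.
    """
    diff = pred - expected
    l1 = np.sum(weight * np.abs(diff))     # L1 loss
    l2 = np.sum(weight * diff ** 2)        # L2 loss
    reg = l2_penalty * np.sum(coef ** 2)   # L2 regularization on 'coef'
    return l1_weight * l1 + l2_weight * l2 + reg


pred = np.array([1.0, 2.0, 3.0])
expected = np.array([1.5, 2.0, 2.0])
weight = np.ones(3)
coef = np.array([2.0, -1.0])

# 0.1 * 1.5 + 0.9 * 1.25 + 0.01 * 5.0
value = elastic_loss(pred, expected, weight, coef)
```

The penalty depends only on the weights, which is why the loss takes W as a separate argument in addition to the prediction f(W, X).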
An instance of class OrtGradientOptimizer is then initialized:
train_session = OrtGradientOptimizer(
    onx_loss, ['intercept', 'coef'], learning_rate=1e-3)
And then trained:
train_session.fit(X_train, y_train, w_train)
The trained coefficients can be retrieved as follows:
state_tensors = train_session.get_state()
And the training losses:
losses = train_session.train_losses_
This design does not yet allow training with momentum, which requires keeping an accumulator for gradients. The class does not expose all the possibilities implemented in onnxruntime-training. The next examples show how this works in practice.
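For readers unfamiliar with the distinction, the sketch below shows a plain gradient descent update with no momentum, written in NumPy. It is not the implementation used by OrtGradientOptimizer or TrainingSession; the function name `fit_plain_gd`, the full-batch update, and the mean squared error loss are assumptions for illustration. The point is that the only state carried from one iteration to the next is the weights themselves, with no accumulator for past gradients.

```python
import numpy as np


def fit_plain_gd(X, y, learning_rate=0.1, n_epochs=200):
    """Full-batch gradient descent on a linear model, without momentum."""
    coef = np.zeros(X.shape[1])
    intercept = 0.0
    losses = []
    for _ in range(n_epochs):
        residual = X @ coef + intercept - y
        losses.append(float(np.mean(residual ** 2)))
        # Gradients of the mean squared error with respect to each weight.
        grad_coef = 2.0 * X.T @ residual / X.shape[0]
        grad_intercept = 2.0 * residual.mean()
        # Plain update: the step depends on the current gradient only.
        coef = coef - learning_rate * grad_coef
        intercept = intercept - learning_rate * grad_intercept
    return coef, intercept, losses


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
y = X @ np.array([1.0, -2.0]) + 0.5

coef, intercept, losses = fit_plain_gd(X, y)
```

A momentum variant would add one extra array per weight to hold a running combination of past gradients, which is the accumulator mentioned above.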
Examples
The first example compares a linear regression trained with scikit-learn and another one trained with onnxruntime-training.
The next two examples explain in detail how training
with onnxruntime-training works. They dig into class
OrtGradientOptimizer.
It leverages class TrainingSession from onnxruntime-training,
which assumes the loss function is part of the graph to train.
It takes care of updating the weights as well.
The fourth example replicates what was done with the linear regression but with a neural network built by scikit-learn. It trains the network on CPU, or on GPU if one is available. The last example benchmarks the different approaches.
- Train a linear regression with onnxruntime-training
- Train a linear regression with onnxruntime-training in details
- Train a linear regression with onnxruntime-training on GPU in details
- Train a scikit-learn neural network with onnxruntime-training on GPU
- Benchmark, comparison scikit-learn - onnxruntime-training