Forward backward on a neural network on GPU

This example follows the example Train a linear regression with onnxruntime-training on GPU and trains a neural network from scikit-learn on GPU. It reuses the code introduced in Train a linear regression with forward backward.

A neural network with scikit-learn

import warnings
import numpy
from pandas import DataFrame
from onnxruntime import get_device
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error
from onnxcustom.plotting.plotting_onnx import plot_onnxs
from mlprodict.onnx_conv import to_onnx
from onnxcustom.utils.orttraining_helper import get_train_initializer
from onnxcustom.utils.onnx_helper import onnx_rename_weights
from onnxcustom.training.optimizers_partial import (
    OrtGradientForwardBackwardOptimizer)


X, y = make_regression(1000, n_features=10, bias=2)
X = X.astype(numpy.float32)
y = y.astype(numpy.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)

nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=100,
                  solver='sgd', learning_rate_init=5e-5,
                  n_iter_no_change=1000, batch_size=10, alpha=0,
                  momentum=0, nesterovs_momentum=False)

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)

print(nn.loss_curve_)

Out:

[17544.239088541668, 17445.39939453125, 16851.697962239585, 9936.29657063802, 578.7086664835612, 363.7043079630534, 303.0591586303711, 263.4929034932454, 233.25586934407553, 204.30566955566405, 172.90032079060873, 143.48786010742188, 113.72599021911621, 89.05499539693197, 67.66418999989827, 52.226442693074546, 40.51452351888021, 33.41873360951742, 28.47235949198405, 24.883258193333944, 22.138922170003255, 20.585099875132244, 18.629345563252766, 17.559885835647584, 16.2082879336675, 15.240904502868652, 14.386598463058471, 13.640493542353312, 12.904703311920166, 12.1366836643219, 11.582631352742514, 11.007895731925965, 10.595031744639078, 10.100601154963176, 9.607780857086182, 9.32366268157959, 8.893436075846354, 8.497686149279277, 8.115553169250488, 7.89189037322998, 7.45089237054189, 7.205984074274699, 6.973689598242442, 6.638352518876394, 6.303099693457286, 6.1739739648501075, 5.848528136412303, 5.574154332478841, 5.538568253914515, 5.2339846197764075, 5.105911478996277, 4.804820866584778, 4.785186088085174, 4.551310795942943, 4.446464851697286, 4.360751740137736, 4.181670752763748, 4.098967498540878, 3.8756878646214803, 3.8330387020111085, 3.6485162377357483, 3.5573809655507405, 3.522226726214091, 3.4048656916618345, 3.287863028049469, 3.197446967760722, 3.1613333423932395, 3.0292748788992565, 2.940122969945272, 2.892592170238495, 2.7825973331928253, 2.7162149997552234, 2.6807045571009316, 2.583968596458435, 2.5162945834795636, 2.4940704933802285, 2.4130879537264507, 2.3671645935376486, 2.331347656647364, 2.2333084126313527, 2.234156309366226, 2.1370069853464764, 2.1346942806243896, 2.0659455911318463, 2.052324904600779, 1.9889240940411885, 1.973656435807546, 1.9235015380382539, 1.8849119726816814, 1.854487155477206, 1.8110669531424841, 1.7570327099164327, 1.742891683181127, 1.6999185677369435, 1.6805048727989196, 1.6396250144640605, 1.65585151831309, 1.5909660635391871, 1.6092190595467886, 1.5235266812642416]

Score:

print("mean_squared_error=%r" % mean_squared_error(y_test, nn.predict(X_test)))

Out:

mean_squared_error=4.0102124

Conversion to ONNX

onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)
plot_onnxs(onx)
(Figure: ONNX graph of the converted network)

Out:

<AxesSubplot:>
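
Before training, a quick inference run checks that the converted graph reproduces scikit-learn's predictions. This is a minimal sketch: the input name 'X' is an assumption based on the defaults of to_onnx, and sess.get_inputs() would confirm it.

from onnxruntime import InferenceSession

# Sketch: run the converted model once and compare with scikit-learn.
# The input name 'X' is assumed from to_onnx defaults.
sess = InferenceSession(onx.SerializeToString(),
                        providers=["CPUExecutionProvider"])
got = sess.run(None, {'X': X_test})[0]
print("onnx mean_squared_error=%r" % mean_squared_error(y_test, got.ravel()))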

Initializers to train

weights = list(sorted(get_train_initializer(onx)))
print(weights)

Out:

['coefficient', 'coefficient1', 'coefficient2', 'intercepts', 'intercepts1', 'intercepts2']
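
The same names can be read directly from the graph with the standard onnx API. A minimal sketch, assuming every initializer of this model is trainable (which holds here):

# Sketch: get_train_initializer essentially collects onx.graph.initializer;
# listing the names shows which tensors will receive gradients.
print([init.name for init in onx.graph.initializer])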

Training graph with forward backward

device = "cuda" if get_device().upper() == 'GPU' else 'cpu'

print("device=%r get_device()=%r" % (device, get_device()))

Out:

device='cpu' get_device()='CPU'

The training session. The first instruction fails for an odd reason: class TrainingAgent expects the list of weights to train to be in alphabetical order. That means the list onx.graph.initializer must be sorted by name, otherwise the process could crash unless it is caught earlier with the following exception.
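
Before running it, the order condition is easy to check (a small sketch):

# Sketch: the optimizer requires initializer names in alphabetical order.
names = [init.name for init in onx.graph.initializer]
print("sorted:", names == sorted(names))  # False before renaming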

try:
    train_session = OrtGradientForwardBackwardOptimizer(
        onx, device=device, verbose=1,
        warm_start=False, max_iter=100, batch_size=10)
    train_session.fit(X, y)
except ValueError as e:
    print(e)

Out:

List of weights to train must be sorted but ['coefficient', 'intercepts', 'coefficient1', 'intercepts1', 'coefficient2', 'intercepts2'] is not. You shoud use function onnx_rename_weights to do that before calling this class.

Function onnx_rename_weights does not change the order of the initializers but renames them so that their alphabetical order matches the graph order. Class TrainingAgent then works.
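
A small sketch shows why the renaming helps: with numeric prefixes such as 'I0_', 'I1_', ... (visible in the repr below), alphabetical order coincides with graph order. The names here are illustrative, not read from the model.

# Sketch: illustrative names following the 'I<index>_' pattern;
# sorting them alphabetically keeps the original graph order.
names = ['I0_coefficient', 'I1_intercepts', 'I2_coefficient1',
         'I3_intercepts1', 'I4_coefficient2', 'I5_intercepts2']
print(names == sorted(names))  # True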

onx = onnx_rename_weights(onx)
train_session = OrtGradientForwardBackwardOptimizer(
    onx, device=device, verbose=1,
    learning_rate=5e-5, warm_start=False, max_iter=100, batch_size=10)
train_session.fit(X, y)

Out:

  0%|          | 0/100 [00:00<?, ?it/s]
  1%|1         | 1/100 [00:00<00:16,  6.14it/s]
  ...
 99%|#########9| 99/100 [00:16<00:00,  6.11it/s]
100%|##########| 100/100 [00:16<00:00,  6.09it/s]

OrtGradientForwardBackwardOptimizer(model_onnx='ir_version...', weights_to_train="['I0_coeff...", loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=10, learning_rate=LearningRateSGD(eta0=5e-05, alpha=0.0001, power_t=0.25, learning_rate='invscaling'), value=1.5811388300841898e-05, device='cpu', warm_start=False, verbose=1, validation_every=10, learning_loss=SquareLearningLoss(), enable_logging=False, weight_name=None, learning_penalty=NoLearningPenalty(), exc=True)

Let’s see the weights.

state_tensors = train_session.get_state()
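
get_state returns one tensor per trained initializer. A hedged sketch, assuming the returned values behave like numpy arrays, prints their shapes, which should match the layers of MLPRegressor(10, 10):

# Sketch: inspect the trained weights returned by get_state().
for tensor in state_tensors:
    print(getattr(tensor, 'shape', type(tensor)))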

And the loss.

print(train_session.train_losses_)

df = DataFrame({'ort losses': train_session.train_losses_,
                'skl losses': nn.loss_curve_})
df.plot(title="Train loss against iterations", logy=True)

Out:

[16358.237, 5245.7144, 791.9146, 475.01016, 365.89728, 287.2625, 316.34387, 268.7727, 238.57036, 220.94836, 210.72664, 216.84703, 192.0854, 189.69366, 177.92604, 174.44455, 170.88574, 167.9788, 152.47972, 146.18692, 151.67169, 150.75684, 132.07741, 127.326546, 137.69536, 125.71342, 127.42827, 119.43576, 117.0933, 113.59126, 110.065254, 97.58933, 107.50377, 98.65988, 101.61492, 99.12544, 82.48794, 84.40924, 78.09038, 83.60169, 73.691826, 77.49237, 77.6597, 67.53298, 70.54678, 70.26047, 75.35473, 61.8246, 61.11597, 62.82642, 56.499134, 57.642384, 57.121212, 52.127586, 45.865807, 45.233936, 43.956543, 51.81715, 47.956566, 47.48082, 37.974804, 40.80382, 40.317867, 42.91117, 38.90777, 42.617798, 36.355934, 38.167225, 35.32463, 35.189297, 34.29746, 32.68386, 31.869308, 30.69357, 31.296328, 30.557032, 29.209177, 28.941505, 27.85178, 30.813942, 26.714746, 26.163963, 26.874887, 26.126352, 25.41134, 25.111835, 24.379822, 24.806108, 24.111834, 21.73557, 25.762707, 22.200773, 23.234297, 21.221653, 23.038507, 21.410286, 20.231337, 20.53657, 21.797396, 21.553669]

<AxesSubplot:title={'center':'Train loss against iterations'}>

The convergence rates differ, but the two classes also do not update the learning rate the same way.
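
The repr above shows the ONNX side uses an 'invscaling' schedule with eta0=5e-05 and power_t=0.25, i.e. eta_t = eta0 / t ** power_t (the sklearn-style formula, assumed here), whereas the scikit-learn model kept a constant rate. A small sketch of that schedule; the value at t=100 matches value=1.5811388300841898e-05 in the repr.

# Sketch: invscaling decays the step size over iterations,
# eta_t = eta0 / t ** power_t (assumed sklearn-style formula).
eta0, power_t = 5e-5, 0.25
for t in [1, 10, 100]:
    print("t=%3d eta=%g" % (t, eta0 / t ** power_t))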

# import matplotlib.pyplot as plt
# plt.show()

Total running time of the script: ( 0 minutes 31.850 seconds)
