Forward backward on a neural network on GPU (Nesterov) and penalty

This example does the same as Forward backward on a neural network on GPU but updates the weights using Nesterov momentum. The second part also adds an L2 penalty on the weights.

A neural network with scikit-learn

import warnings
import numpy
import onnx
from pandas import DataFrame
from onnxruntime import get_device
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error
from onnxcustom.plotting.plotting_onnx import plot_onnxs
from mlprodict.onnx_conv import to_onnx
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from onnxcustom.utils.orttraining_helper import get_train_initializer
from onnxcustom.utils.onnx_helper import onnx_rename_weights
from onnxcustom.training.optimizers_partial import (
    OrtGradientForwardBackwardOptimizer)
from onnxcustom.training.sgd_learning_rate import LearningRateSGDNesterov
from onnxcustom.training.sgd_learning_penalty import ElasticLearningPenalty


# Synthetic regression problem: 1000 samples, 10 features.
# No random_state is given, so the data (and every number printed below)
# changes from one run to the next.
X, y = make_regression(1000, n_features=10, bias=2)
# Work in float32: the ONNX graphs built below take float32 inputs.
X = X.astype(numpy.float32)
y = y.astype(numpy.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)

# Reference model: SGD with Nesterov momentum and no penalty (alpha=0).
# n_iter_no_change is larger than max_iter so the run is expected to last
# the full 100 epochs (the printed loss curve has 100 entries).
nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=100,
                  solver='sgd', learning_rate_init=5e-5,
                  n_iter_no_change=1000, batch_size=10, alpha=0,
                  momentum=0.9, nesterovs_momentum=True)

with warnings.catch_warnings():
    # Hide warnings raised during fit (presumably the convergence warning
    # emitted once max_iter is reached).
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)

# One training loss value per epoch.
print(nn.loss_curve_)

Out:

[12142.376850585937, 208.52571149190268, 10.635795823733012, 4.1625566299756365, 2.545690587759018, 1.9989628366629282, 1.5532766801118851, 1.253973110516866, 1.0992915083964665, 0.9226496937870979, 0.8337377693255742, 0.7372933038075765, 0.6420993973811467, 0.6062411475181579, 0.5141608527302742, 0.44477628618478776, 0.44327159454425175, 0.4196765069166819, 0.38307343701521557, 0.3546647993723551, 0.31182308614253995, 0.28883073697487516, 0.2859936295946439, 0.251798144976298, 0.23397356274227302, 0.22059242899219195, 0.2115067273378372, 0.1964540594443679, 0.18915611331661542, 0.17627958372235297, 0.17006373253961404, 0.16114117577672005, 0.15046576759467523, 0.15615677376588186, 0.13722270044187704, 0.13354423951978484, 0.1283668660124143, 0.12381672006100417, 0.11649736418078344, 0.10932136800140142, 0.10336004480719567, 0.10255347618212303, 0.09427588980644941, 0.09865071880320708, 0.09097820652027924, 0.09452121757591764, 0.08985137553264698, 0.08327294170856475, 0.0841425225759546, 0.07678079857180516, 0.07622131654371818, 0.07334446835021179, 0.07204858547697465, 0.0708627612516284, 0.06683499027043581, 0.06581077723453442, 0.06399082435294985, 0.06368227748200297, 0.061796805610259374, 0.06035637442022562, 0.05991592526435852, 0.05721309230973323, 0.05518428727673987, 0.05742139232655366, 0.05565398583188653, 0.05081921972955267, 0.05157470373436809, 0.05034799471497536, 0.049603180922567845, 0.050919087901711464, 0.05095796532308062, 0.049858448108037315, 0.045207385222117105, 0.044612640428046385, 0.04721543467914065, 0.04417784827450911, 0.044271371169015764, 0.04416281206222872, 0.044387510418891905, 0.04164197822411855, 0.041849804955224196, 0.040676024689649544, 0.03856285390133659, 0.04186660434119403, 0.03931069366323451, 0.03893013951058189, 0.03758479594563444, 0.03779271937906742, 0.035520137647787726, 0.03582555818681916, 0.03626670505230625, 0.03448854520296057, 0.03782571161165833, 0.03696222162495057, 0.03296316178205112, 
0.03370304711163044, 0.034189842287451026, 0.03165586076055964, 0.03171007619549831, 0.03209064475260675]

Score:

# Mean squared error of the scikit-learn model on the held-out split.
print("mean_squared_error="
      f"{mean_squared_error(y_test, nn.predict(X_test))!r}")

Out:

mean_squared_error=0.17486624

Conversion to ONNX

# Convert the fitted scikit-learn model to ONNX; a one-row float32 sample
# of the training data defines the graph input.
onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)
plot_onnxs(onx)

# Names of the initializers the training session will update.
# sorted() already returns a list, so no extra list() call is needed.
weights = sorted(get_train_initializer(onx))
print(weights)
plot orttraining nn gpu fwbw nesterov

Out:

['coefficient', 'coefficient1', 'coefficient2', 'intercepts', 'intercepts1', 'intercepts2']

Training graph with forward backward

# Train on "cuda" when onnxruntime reports a GPU device, otherwise on CPU.
device = "cuda" if get_device().upper() == 'GPU' else 'cpu'

print("device=%r get_device()=%r" % (device, get_device()))

# Rename the weight initializers; the dumped graphs show the resulting
# names are prefixed with their index (I0_coefficient, I1_intercepts, ...).
onx = onnx_rename_weights(onx)
train_session = OrtGradientForwardBackwardOptimizer(
    onx, device=device, verbose=1,
    learning_rate=LearningRateSGDNesterov(1e-4, nesterov=True, momentum=0.9),
    warm_start=False, max_iter=100, batch_size=10)
# Fit on the same training split as the MLPRegressor above. Fitting on the
# full (X, y) would mix the test rows into training and make the loss
# curves compared below come from different data.
train_session.fit(X_train, y_train)

Out:

device='cpu' get_device()='CPU'

  0%|          | 0/100 [00:00<?, ?it/s]
  1%|1         | 1/100 [00:00<00:20,  4.81it/s]
  2%|2         | 2/100 [00:00<00:20,  4.80it/s]
  3%|3         | 3/100 [00:00<00:20,  4.79it/s]
  4%|4         | 4/100 [00:00<00:19,  4.80it/s]
  5%|5         | 5/100 [00:01<00:19,  4.80it/s]
  6%|6         | 6/100 [00:01<00:19,  4.80it/s]
  7%|7         | 7/100 [00:01<00:19,  4.80it/s]
  8%|8         | 8/100 [00:01<00:19,  4.80it/s]
  9%|9         | 9/100 [00:01<00:18,  4.80it/s]
 10%|#         | 10/100 [00:02<00:18,  4.80it/s]
 11%|#1        | 11/100 [00:02<00:18,  4.80it/s]
 12%|#2        | 12/100 [00:02<00:18,  4.80it/s]
 13%|#3        | 13/100 [00:02<00:18,  4.80it/s]
 14%|#4        | 14/100 [00:02<00:17,  4.80it/s]
 15%|#5        | 15/100 [00:03<00:17,  4.79it/s]
 16%|#6        | 16/100 [00:03<00:17,  4.78it/s]
 17%|#7        | 17/100 [00:03<00:17,  4.78it/s]
 18%|#8        | 18/100 [00:03<00:17,  4.78it/s]
 19%|#9        | 19/100 [00:03<00:16,  4.78it/s]
 20%|##        | 20/100 [00:04<00:16,  4.79it/s]
 21%|##1       | 21/100 [00:04<00:16,  4.79it/s]
 22%|##2       | 22/100 [00:04<00:16,  4.79it/s]
 23%|##3       | 23/100 [00:04<00:16,  4.79it/s]
 24%|##4       | 24/100 [00:05<00:15,  4.80it/s]
 25%|##5       | 25/100 [00:05<00:15,  4.79it/s]
 26%|##6       | 26/100 [00:05<00:15,  4.80it/s]
 27%|##7       | 27/100 [00:05<00:15,  4.80it/s]
 28%|##8       | 28/100 [00:05<00:15,  4.80it/s]
 29%|##9       | 29/100 [00:06<00:14,  4.80it/s]
 30%|###       | 30/100 [00:06<00:14,  4.80it/s]
 31%|###1      | 31/100 [00:06<00:14,  4.80it/s]
 32%|###2      | 32/100 [00:06<00:14,  4.80it/s]
 33%|###3      | 33/100 [00:06<00:13,  4.79it/s]
 34%|###4      | 34/100 [00:07<00:13,  4.79it/s]
 35%|###5      | 35/100 [00:07<00:13,  4.79it/s]
 36%|###6      | 36/100 [00:07<00:13,  4.78it/s]
 37%|###7      | 37/100 [00:07<00:13,  4.78it/s]
 38%|###8      | 38/100 [00:07<00:12,  4.78it/s]
 39%|###9      | 39/100 [00:08<00:12,  4.78it/s]
 40%|####      | 40/100 [00:08<00:12,  4.78it/s]
 41%|####1     | 41/100 [00:08<00:12,  4.78it/s]
 42%|####2     | 42/100 [00:08<00:12,  4.78it/s]
 43%|####3     | 43/100 [00:08<00:11,  4.78it/s]
 44%|####4     | 44/100 [00:09<00:11,  4.78it/s]
 45%|####5     | 45/100 [00:09<00:11,  4.78it/s]
 46%|####6     | 46/100 [00:09<00:11,  4.78it/s]
 47%|####6     | 47/100 [00:09<00:11,  4.79it/s]
 48%|####8     | 48/100 [00:10<00:10,  4.79it/s]
 49%|####9     | 49/100 [00:10<00:10,  4.78it/s]
 50%|#####     | 50/100 [00:10<00:10,  4.78it/s]
 51%|#####1    | 51/100 [00:10<00:10,  4.78it/s]
 52%|#####2    | 52/100 [00:10<00:10,  4.77it/s]
 53%|#####3    | 53/100 [00:11<00:09,  4.77it/s]
 54%|#####4    | 54/100 [00:11<00:09,  4.77it/s]
 55%|#####5    | 55/100 [00:11<00:09,  4.77it/s]
 56%|#####6    | 56/100 [00:11<00:09,  4.78it/s]
 57%|#####6    | 57/100 [00:11<00:09,  4.77it/s]
 58%|#####8    | 58/100 [00:12<00:08,  4.77it/s]
 59%|#####8    | 59/100 [00:12<00:08,  4.77it/s]
 60%|######    | 60/100 [00:12<00:08,  4.77it/s]
 61%|######1   | 61/100 [00:12<00:08,  4.77it/s]
 62%|######2   | 62/100 [00:12<00:07,  4.77it/s]
 63%|######3   | 63/100 [00:13<00:07,  4.77it/s]
 64%|######4   | 64/100 [00:13<00:07,  4.77it/s]
 65%|######5   | 65/100 [00:13<00:07,  4.77it/s]
 66%|######6   | 66/100 [00:13<00:07,  4.77it/s]
 67%|######7   | 67/100 [00:14<00:06,  4.77it/s]
 68%|######8   | 68/100 [00:14<00:06,  4.77it/s]
 69%|######9   | 69/100 [00:14<00:06,  4.77it/s]
 70%|#######   | 70/100 [00:14<00:06,  4.77it/s]
 71%|#######1  | 71/100 [00:14<00:06,  4.77it/s]
 72%|#######2  | 72/100 [00:15<00:05,  4.77it/s]
 73%|#######3  | 73/100 [00:15<00:05,  4.77it/s]
 74%|#######4  | 74/100 [00:15<00:05,  4.77it/s]
 75%|#######5  | 75/100 [00:15<00:05,  4.77it/s]
 76%|#######6  | 76/100 [00:15<00:05,  4.77it/s]
 77%|#######7  | 77/100 [00:16<00:04,  4.76it/s]
 78%|#######8  | 78/100 [00:16<00:04,  4.76it/s]
 79%|#######9  | 79/100 [00:16<00:04,  4.76it/s]
 80%|########  | 80/100 [00:16<00:04,  4.76it/s]
 81%|########1 | 81/100 [00:16<00:03,  4.76it/s]
 82%|########2 | 82/100 [00:17<00:03,  4.76it/s]
 83%|########2 | 83/100 [00:17<00:03,  4.77it/s]
 84%|########4 | 84/100 [00:17<00:03,  4.76it/s]
 85%|########5 | 85/100 [00:17<00:03,  4.77it/s]
 86%|########6 | 86/100 [00:17<00:02,  4.77it/s]
 87%|########7 | 87/100 [00:18<00:02,  4.77it/s]
 88%|########8 | 88/100 [00:18<00:02,  4.77it/s]
 89%|########9 | 89/100 [00:18<00:02,  4.77it/s]
 90%|######### | 90/100 [00:18<00:02,  4.77it/s]
 91%|#########1| 91/100 [00:19<00:01,  4.77it/s]
 92%|#########2| 92/100 [00:19<00:01,  4.77it/s]
 93%|#########3| 93/100 [00:19<00:01,  4.76it/s]
 94%|#########3| 94/100 [00:19<00:01,  4.76it/s]
 95%|#########5| 95/100 [00:19<00:01,  4.76it/s]
 96%|#########6| 96/100 [00:20<00:00,  4.76it/s]
 97%|#########7| 97/100 [00:20<00:00,  4.76it/s]
 98%|#########8| 98/100 [00:20<00:00,  4.76it/s]
 99%|#########9| 99/100 [00:20<00:00,  4.76it/s]
100%|##########| 100/100 [00:20<00:00,  4.76it/s]
100%|##########| 100/100 [00:20<00:00,  4.78it/s]

OrtGradientForwardBackwardOptimizer(model_onnx='ir_version...', weights_to_train="['I0_coeff...", loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=10, learning_rate=LearningRateSGDNesterov(eta0=0.0001, alpha=0.0001, power_t=0.25, learning_rate='invscaling', momentum=0.9, nesterov=True), value=3.1622776601683795e-05, device='cpu', warm_start=False, verbose=1, validation_every=10, learning_loss=SquareLearningLoss(), enable_logging=False, weight_name=None, learning_penalty=NoLearningPenalty(), exc=True)

Let’s see the weights.

# Retrieve the current training state (the weights) from the session.
state_tensors = train_session.get_state()

And the loss.

print(train_session.train_losses_)

# One column per implementation. DataFrame needs equal-length columns;
# both runs use max_iter=100. The label had a stray trailing colon.
df = DataFrame({'ort losses': train_session.train_losses_,
                'skl losses': nn.loss_curve_})
# Log scale: the losses span several orders of magnitude.
df.plot(title="Train loss against iterations (Nesterov)", logy=True)
Train loss against iterations (Nesterov)

Out:

[2258.1738, 26.342966, 9.936302, 5.147777, 3.2953956, 2.8015616, 2.0468345, 1.743872, 1.343944, 1.3889159, 1.1101127, 0.98755026, 0.90764016, 0.7749563, 0.66308135, 0.73497695, 0.6427221, 0.5525322, 0.5237125, 0.62061363, 0.46638852, 0.5113034, 0.40196434, 0.3591859, 0.49574584, 0.34535193, 0.4361959, 0.39837873, 0.38860977, 0.38707742, 0.3307107, 0.28549117, 0.32713687, 0.4017986, 0.3796233, 0.33476147, 0.21781337, 0.30370104, 0.32311404, 0.28576297, 0.24571778, 0.28465894, 0.24558789, 0.21089275, 0.2831703, 0.25513417, 0.23211785, 0.25994626, 0.2551794, 0.25311303, 0.22698905, 0.2027645, 0.23341805, 0.20787859, 0.22074765, 0.20912045, 0.23379944, 0.18521471, 0.18057354, 0.23292427, 0.15141475, 0.19741966, 0.15918921, 0.18584923, 0.14298472, 0.14691725, 0.28181884, 0.1636774, 0.17148694, 0.14742981, 0.11548183, 0.15906177, 0.15372595, 0.12606488, 0.17675255, 0.16583428, 0.16700117, 0.15778151, 0.1406626, 0.1633339, 0.14795175, 0.17505321, 0.19927153, 0.13015547, 0.15346076, 0.11843845, 0.14645901, 0.15687111, 0.124431096, 0.20296392, 0.107760936, 0.13928299, 0.14325543, 0.106979266, 0.09940349, 0.14283879, 0.14309952, 0.1127268, 0.1405759, 0.11075577]

<AxesSubplot:title={'center':'Train loss against iterations (Nesterov)'}>

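The convergence rates differ because the two implementations do not update the learning rate in exactly the same way. For example, the onnxruntime optimizer above uses an invscaling schedule starting at 1e-4, while MLPRegressor keeps a constant learning rate of 5e-5.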

Regularization

By default, MLPRegressor penalizes the weights during training with an L2 term of strength alpha=1e-4.

# Same reference model as before, but with the L2 penalty alpha=1e-4
# instead of alpha=0. Rebinding `nn` replaces the unpenalized model, so
# the loss comparison further below uses this one.
nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=100,
                  solver='sgd', learning_rate_init=5e-5,
                  n_iter_no_change=1000, batch_size=10, alpha=1e-4,
                  momentum=0.9, nesterovs_momentum=True)

with warnings.catch_warnings():
    # Hide warnings raised during fit (presumably the convergence warning
    # emitted once max_iter is reached).
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)

# One training loss value per epoch.
print(nn.loss_curve_)

Out:

[10287.22519014488, 114.30391664686144, 4.428802581263733, 2.2294384981979367, 1.5061151267471942, 1.0660096135118482, 0.9013473236577829, 0.6927708104238349, 0.6072377480045318, 0.5131283271757443, 0.44050787021721993, 0.3971774471186558, 0.3805070497757753, 0.350709212971592, 0.3160200295091788, 0.2866447259530703, 0.2678877657233954, 0.2624037469550847, 0.2646896001451651, 0.22975923892503577, 0.2235425944488446, 0.20868258862121097, 0.21262072222343276, 0.1918363877769391, 0.1817688411685188, 0.18029508261725902, 0.16713803397329252, 0.15096519545116424, 0.15579957813555, 0.14022681212632657, 0.1335865062277913, 0.13066095226773022, 0.12465744209192994, 0.11612236702070239, 0.11650939311578667, 0.10377068048932547, 0.09965911698656081, 0.09569806834505992, 0.09555516442207891, 0.08785860989701748, 0.08413557989854219, 0.08588188749337594, 0.07946141101812917, 0.07406683651229344, 0.07536813826996087, 0.07404707255665657, 0.06886945094286999, 0.06640573891913693, 0.06959184916854301, 0.0610276720985154, 0.06053264404808283, 0.05765451917141079, 0.05497423387708463, 0.05292999011297026, 0.053558166335922495, 0.051363563591797154, 0.05082045451762675, 0.049323814527793206, 0.04954291654862861, 0.04663594919974508, 0.04639400097087621, 0.045110241577425596, 0.04267538147623937, 0.04201486460911432, 0.041236927045960214, 0.040728190061660605, 0.03991805001185438, 0.03878842797073423, 0.03985737019779681, 0.03783477265487363, 0.03624473010996182, 0.035857376220913735, 0.034874257704518234, 0.03527559812203397, 0.034284848736864826, 0.03344557415381272, 0.03267718306214213, 0.0336423159529845, 0.03282535581768255, 0.032049191349957894, 0.03134248187396525, 0.03107644581406613, 0.029861127178698287, 0.029881794314117735, 0.030867609095594294, 0.028551491866825016, 0.029164035577879357, 0.028555548553245765, 0.027525101808546986, 0.02741220280134081, 0.027476729826804994, 0.02733537836576602, 0.025846304840237408, 0.026733439335417743, 0.024720744558956235, 
0.02609910241108065, 0.02529742046883305, 0.025844847096701467, 0.024001264599111175, 0.024748266170668606]

Let’s do the same with onnxruntime.

# Same onnxruntime training as before, plus an elastic penalty with only
# the L2 part enabled (l1=0, l2=1e-4), mirroring alpha=1e-4 above.
# NOTE(review): the dumped penalty graph adds l1 * sum(|W|) + l2 * sum(W**2)
# for each of its weight inputs (presumably every trained initializer,
# intercepts included). scikit-learn may scale and restrict its L2 term
# differently, so the two penalties are not necessarily identical.
# TODO confirm.
train_session = OrtGradientForwardBackwardOptimizer(
    onx, device=device, verbose=1,
    learning_rate=LearningRateSGDNesterov(1e-4, nesterov=True, momentum=0.9),
    learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4),
    warm_start=False, max_iter=100, batch_size=10)
# Fit on the same training split as the MLPRegressor above so that both
# loss curves compared below come from the same data.
train_session.fit(X_train, y_train)

Out:

  0%|          | 0/100 [00:00<?, ?it/s]
  1%|1         | 1/100 [00:00<00:28,  3.45it/s]
  2%|2         | 2/100 [00:00<00:28,  3.43it/s]
  3%|3         | 3/100 [00:00<00:28,  3.42it/s]
  4%|4         | 4/100 [00:01<00:28,  3.42it/s]
  5%|5         | 5/100 [00:01<00:27,  3.42it/s]
  6%|6         | 6/100 [00:01<00:27,  3.41it/s]
  7%|7         | 7/100 [00:02<00:27,  3.41it/s]
  8%|8         | 8/100 [00:02<00:26,  3.41it/s]
  9%|9         | 9/100 [00:02<00:26,  3.41it/s]
 10%|#         | 10/100 [00:02<00:26,  3.41it/s]
 11%|#1        | 11/100 [00:03<00:26,  3.41it/s]
 12%|#2        | 12/100 [00:03<00:25,  3.41it/s]
 13%|#3        | 13/100 [00:03<00:25,  3.42it/s]
 14%|#4        | 14/100 [00:04<00:25,  3.42it/s]
 15%|#5        | 15/100 [00:04<00:24,  3.41it/s]
 16%|#6        | 16/100 [00:04<00:24,  3.42it/s]
 17%|#7        | 17/100 [00:04<00:24,  3.42it/s]
 18%|#8        | 18/100 [00:05<00:24,  3.42it/s]
 19%|#9        | 19/100 [00:05<00:23,  3.42it/s]
 20%|##        | 20/100 [00:05<00:23,  3.41it/s]
 21%|##1       | 21/100 [00:06<00:23,  3.41it/s]
 22%|##2       | 22/100 [00:06<00:22,  3.41it/s]
 23%|##3       | 23/100 [00:06<00:22,  3.41it/s]
 24%|##4       | 24/100 [00:07<00:22,  3.41it/s]
 25%|##5       | 25/100 [00:07<00:21,  3.41it/s]
 26%|##6       | 26/100 [00:07<00:21,  3.41it/s]
 27%|##7       | 27/100 [00:07<00:21,  3.41it/s]
 28%|##8       | 28/100 [00:08<00:21,  3.41it/s]
 29%|##9       | 29/100 [00:08<00:20,  3.42it/s]
 30%|###       | 30/100 [00:08<00:20,  3.42it/s]
 31%|###1      | 31/100 [00:09<00:20,  3.42it/s]
 32%|###2      | 32/100 [00:09<00:19,  3.42it/s]
 33%|###3      | 33/100 [00:09<00:19,  3.42it/s]
 34%|###4      | 34/100 [00:09<00:19,  3.42it/s]
 35%|###5      | 35/100 [00:10<00:18,  3.42it/s]
 36%|###6      | 36/100 [00:10<00:18,  3.42it/s]
 37%|###7      | 37/100 [00:10<00:18,  3.42it/s]
 38%|###8      | 38/100 [00:11<00:18,  3.42it/s]
 39%|###9      | 39/100 [00:11<00:17,  3.42it/s]
 40%|####      | 40/100 [00:11<00:17,  3.42it/s]
 41%|####1     | 41/100 [00:12<00:17,  3.42it/s]
 42%|####2     | 42/100 [00:12<00:16,  3.42it/s]
 43%|####3     | 43/100 [00:12<00:16,  3.42it/s]
 44%|####4     | 44/100 [00:12<00:16,  3.42it/s]
 45%|####5     | 45/100 [00:13<00:16,  3.42it/s]
 46%|####6     | 46/100 [00:13<00:15,  3.42it/s]
 47%|####6     | 47/100 [00:13<00:15,  3.42it/s]
 48%|####8     | 48/100 [00:14<00:15,  3.42it/s]
 49%|####9     | 49/100 [00:14<00:14,  3.42it/s]
 50%|#####     | 50/100 [00:14<00:14,  3.42it/s]
 51%|#####1    | 51/100 [00:14<00:14,  3.42it/s]
 52%|#####2    | 52/100 [00:15<00:14,  3.42it/s]
 53%|#####3    | 53/100 [00:15<00:13,  3.42it/s]
 54%|#####4    | 54/100 [00:15<00:13,  3.42it/s]
 55%|#####5    | 55/100 [00:16<00:13,  3.42it/s]
 56%|#####6    | 56/100 [00:16<00:12,  3.41it/s]
 57%|#####6    | 57/100 [00:16<00:12,  3.41it/s]
 58%|#####8    | 58/100 [00:16<00:12,  3.41it/s]
 59%|#####8    | 59/100 [00:17<00:12,  3.41it/s]
 60%|######    | 60/100 [00:17<00:11,  3.42it/s]
 61%|######1   | 61/100 [00:17<00:11,  3.42it/s]
 62%|######2   | 62/100 [00:18<00:11,  3.42it/s]
 63%|######3   | 63/100 [00:18<00:10,  3.42it/s]
 64%|######4   | 64/100 [00:18<00:10,  3.42it/s]
 65%|######5   | 65/100 [00:19<00:10,  3.41it/s]
 66%|######6   | 66/100 [00:19<00:09,  3.41it/s]
 67%|######7   | 67/100 [00:19<00:09,  3.41it/s]
 68%|######8   | 68/100 [00:19<00:09,  3.41it/s]
 69%|######9   | 69/100 [00:20<00:09,  3.41it/s]
 70%|#######   | 70/100 [00:20<00:08,  3.41it/s]
 71%|#######1  | 71/100 [00:20<00:08,  3.41it/s]
 72%|#######2  | 72/100 [00:21<00:08,  3.41it/s]
 73%|#######3  | 73/100 [00:21<00:07,  3.41it/s]
 74%|#######4  | 74/100 [00:21<00:07,  3.41it/s]
 75%|#######5  | 75/100 [00:21<00:07,  3.41it/s]
 76%|#######6  | 76/100 [00:22<00:07,  3.41it/s]
 77%|#######7  | 77/100 [00:22<00:06,  3.41it/s]
 78%|#######8  | 78/100 [00:22<00:06,  3.41it/s]
 79%|#######9  | 79/100 [00:23<00:06,  3.41it/s]
 80%|########  | 80/100 [00:23<00:05,  3.41it/s]
 81%|########1 | 81/100 [00:23<00:05,  3.41it/s]
 82%|########2 | 82/100 [00:24<00:05,  3.41it/s]
 83%|########2 | 83/100 [00:24<00:04,  3.41it/s]
 84%|########4 | 84/100 [00:24<00:04,  3.41it/s]
 85%|########5 | 85/100 [00:24<00:04,  3.41it/s]
 86%|########6 | 86/100 [00:25<00:04,  3.41it/s]
 87%|########7 | 87/100 [00:25<00:03,  3.41it/s]
 88%|########8 | 88/100 [00:25<00:03,  3.41it/s]
 89%|########9 | 89/100 [00:26<00:03,  3.41it/s]
 90%|######### | 90/100 [00:26<00:02,  3.41it/s]
 91%|#########1| 91/100 [00:26<00:02,  3.42it/s]
 92%|#########2| 92/100 [00:26<00:02,  3.42it/s]
 93%|#########3| 93/100 [00:27<00:02,  3.42it/s]
 94%|#########3| 94/100 [00:27<00:01,  3.42it/s]
 95%|#########5| 95/100 [00:27<00:01,  3.42it/s]
 96%|#########6| 96/100 [00:28<00:01,  3.41it/s]
 97%|#########7| 97/100 [00:28<00:00,  3.41it/s]
 98%|#########8| 98/100 [00:28<00:00,  3.41it/s]
 99%|#########9| 99/100 [00:28<00:00,  3.41it/s]
100%|##########| 100/100 [00:29<00:00,  3.41it/s]
100%|##########| 100/100 [00:29<00:00,  3.41it/s]

OrtGradientForwardBackwardOptimizer(model_onnx='ir_version...', weights_to_train="['I0_coeff...", loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=10, learning_rate=LearningRateSGDNesterov(eta0=0.0001, alpha=0.0001, power_t=0.25, learning_rate='invscaling', momentum=0.9, nesterov=True), value=3.1622776601683795e-05, device='cpu', warm_start=False, verbose=1, validation_every=10, learning_loss=SquareLearningLoss(), enable_logging=False, weight_name=None, learning_penalty=ElasticLearningPenalty(l1=0, l2=0.0001), exc=True)

Let’s see the weights.

# Retrieve the current training state (the weights) from the penalized session.
state_tensors = train_session.get_state()

And the loss.

print(train_session.train_losses_)

# One column per implementation. DataFrame needs equal-length columns;
# both runs use max_iter=100. The label had a stray trailing colon.
df = DataFrame({'ort losses': train_session.train_losses_,
                'skl losses': nn.loss_curve_})
# Log scale: the losses span several orders of magnitude.
df.plot(title="Train loss against iterations (Nesterov + penalty)", logy=True)
Train loss against iterations (Nesterov + penalty)

Out:

[4861.449, 25.046307, 11.008789, 7.74074, 5.1028852, 3.2229476, 3.2113922, 2.9235556, 2.1449852, 2.3559413, 1.8511151, 1.9164665, 1.435327, 1.4132812, 1.2099515, 1.1310208, 1.206428, 0.8012704, 0.8482481, 1.0125045, 0.76332545, 0.8698208, 0.8912954, 0.7267333, 0.83761984, 0.675399, 0.6764955, 0.5540964, 0.58731264, 0.6209444, 0.55011517, 0.4272984, 0.50298077, 0.42416194, 0.49143714, 0.4930766, 0.41217956, 0.4411206, 0.4109722, 0.34561297, 0.42141122, 0.32779393, 0.36398593, 0.37045026, 0.2863156, 0.28487918, 0.262511, 0.29033896, 0.27661306, 0.27709922, 0.22183947, 0.25974074, 0.2349325, 0.31290713, 0.24561325, 0.29334575, 0.23535314, 0.24226215, 0.2292589, 0.23041001, 0.2390392, 0.23469613, 0.2020327, 0.19836996, 0.2151799, 0.19205719, 0.1804088, 0.2004037, 0.1584917, 0.18252498, 0.16962215, 0.14306745, 0.158388, 0.1560865, 0.107742354, 0.15685132, 0.14094299, 0.13749506, 0.15084405, 0.13492058, 0.14775337, 0.12560132, 0.13170843, 0.13348602, 0.1418273, 0.1413577, 0.13540213, 0.14798613, 0.1299366, 0.1215678, 0.14300115, 0.13453592, 0.14070058, 0.13080919, 0.11906525, 0.15257156, 0.13933863, 0.14258417, 0.13619626, 0.13876237]

<AxesSubplot:title={'center':'Train loss against iterations (Nesterov + penalty)'}>

All ONNX graphs

Method save_onnx_graph exports all the ONNX graphs used by the model to disk.

def print_graph(d):
    """Print a text rendering of every ONNX file referenced by *d*.

    *d* maps keys to either a file path or another dictionary of the same
    form; nested dictionaries are walked recursively. Entries are visited
    in sorted order. For each file, a header with its path (backslashes
    displayed as forward slashes) is printed, followed by
    ``onnx_simple_text_plot`` of the loaded model.
    """
    for _key, entry in sorted(d.items()):
        if isinstance(entry, dict):
            print_graph(entry)
            continue
        shown_path = entry.replace("\\", "/")
        print("\n++++++", shown_path, "\n")
        with open(entry, "rb") as stream:
            model = onnx.load(stream)
        print(onnx_simple_text_plot(model))


# Save every ONNX graph used by the training session into the current
# directory, then print each of them. The returned value is presumably a
# (possibly nested) dict whose leaves are file paths, which is the shape
# print_graph walks.
all_files = train_session.save_onnx_graph('.')
print_graph(all_files)


# Uncomment to display the plots interactively.
# import matplotlib.pyplot as plt
# plt.show()

Out:

++++++ ./SquareLLoss.learning_loss.loss_grad_onnx_.onnx

opset: domain='' version=14
input: name='X1' type=dtype('float32') shape=(0, 0)
input: name='X2' type=dtype('float32') shape=(0, 0)
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([0.5], dtype=float32)
init: name='Re_Reshapecst' type=dtype('int64') shape=(1,) -- array([-1])
init: name='Mu_Mulcst1' type=dtype('float32') shape=(1,) -- array([-1.], dtype=float32)
Sub(X1, X2) -> Su_C0
  Mul(Su_C0, Mu_Mulcst1) -> Y_grad
ReduceSumSquare(Su_C0) -> Re_reduced0
  Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
    Reshape(Mu_C0, Re_Reshapecst) -> Y
output: name='Y' type=dtype('float32') shape=()
output: name='Y_grad' type=dtype('float32') shape=()

++++++ ./SquareLLoss.learning_loss.loss_score_onnx_.onnx

opset: domain='' version=14
input: name='X1' type=dtype('float32') shape=(0, 0)
input: name='X2' type=dtype('float32') shape=(0, 0)
Sub(X1, X2) -> Su_C0
  Mul(Su_C0, Su_C0) -> Y
output: name='Y' type=dtype('float32') shape=(0, 1)

++++++ ./ElasticLPenalty.learning_penalty.penalty_grad_onnx_.onnx

opset: domain='' version=14
input: name='X' type=dtype('float32') shape=()
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([0.9998], dtype=float32)
init: name='Mu_Mulcst1' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
Mul(X, Mu_Mulcst) -> Mu_C0
Sign(X) -> Si_output0
  Mul(Si_output0, Mu_Mulcst1) -> Mu_C02
  Sub(Mu_C0, Mu_C02) -> Y
output: name='Y' type=dtype('float32') shape=()

++++++ ./ElasticLPenalty.learning_penalty.penalty_onnx_.onnx

opset: domain='' version=14
input: name='loss' type=dtype('float32') shape=()
input: name='W0' type=dtype('float32') shape=()
input: name='W1' type=dtype('float32') shape=()
input: name='W2' type=dtype('float32') shape=()
input: name='W3' type=dtype('float32') shape=()
input: name='W4' type=dtype('float32') shape=()
input: name='W5' type=dtype('float32') shape=()
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='Mu_Mulcst1' type=dtype('float32') shape=(1,) -- array([1.e-04], dtype=float32)
init: name='Re_Reshapecst' type=dtype('int64') shape=(1,) -- array([-1])
Abs(W0) -> Ab_Y0
  ReduceSum(Ab_Y0) -> Re_reduced0
    Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
Abs(W3) -> Ab_Y04
  ReduceSum(Ab_Y04) -> Re_reduced07
ReduceSumSquare(W5) -> Re_reduced012
ReduceSumSquare(W3) -> Re_reduced08
Identity(Mu_Mulcst) -> Mu_Mulcst6
  Mul(Re_reduced07, Mu_Mulcst6) -> Mu_C07
ReduceSumSquare(W4) -> Re_reduced010
Identity(Mu_Mulcst1) -> Mu_Mulcst7
  Mul(Re_reduced08, Mu_Mulcst7) -> Mu_C08
    Add(Mu_C07, Mu_C08) -> Ad_C09
Abs(W4) -> Ab_Y05
  ReduceSum(Ab_Y05) -> Re_reduced09
Identity(Mu_Mulcst) -> Mu_Mulcst8
  Mul(Re_reduced09, Mu_Mulcst8) -> Mu_C09
Identity(Mu_Mulcst1) -> Mu_Mulcst9
  Mul(Re_reduced010, Mu_Mulcst9) -> Mu_C010
    Add(Mu_C09, Mu_C010) -> Ad_C010
Abs(W5) -> Ab_Y06
  ReduceSum(Ab_Y06) -> Re_reduced011
Identity(Mu_Mulcst) -> Mu_Mulcst10
  Mul(Re_reduced011, Mu_Mulcst10) -> Mu_C011
Identity(Mu_Mulcst1) -> Mu_Mulcst11
  Mul(Re_reduced012, Mu_Mulcst11) -> Mu_C012
    Add(Mu_C011, Mu_C012) -> Ad_C011
Identity(Mu_Mulcst1) -> Mu_Mulcst3
ReduceSumSquare(W0) -> Re_reduced02
  Mul(Re_reduced02, Mu_Mulcst1) -> Mu_C02
    Add(Mu_C0, Mu_C02) -> Ad_C06
Identity(Mu_Mulcst) -> Mu_Mulcst4
Abs(W2) -> Ab_Y03
  ReduceSum(Ab_Y03) -> Re_reduced05
  Mul(Re_reduced05, Mu_Mulcst4) -> Mu_C05
ReduceSumSquare(W1) -> Re_reduced04
  Mul(Re_reduced04, Mu_Mulcst3) -> Mu_C04
ReduceSumSquare(W2) -> Re_reduced06
Identity(Mu_Mulcst1) -> Mu_Mulcst5
  Mul(Re_reduced06, Mu_Mulcst5) -> Mu_C06
    Add(Mu_C05, Mu_C06) -> Ad_C08
Identity(Mu_Mulcst) -> Mu_Mulcst2
Abs(W1) -> Ab_Y02
  ReduceSum(Ab_Y02) -> Re_reduced03
  Mul(Re_reduced03, Mu_Mulcst2) -> Mu_C03
    Add(Mu_C03, Mu_C04) -> Ad_C07
      Add(Ad_C06, Ad_C07) -> Ad_C05
      Add(Ad_C05, Ad_C08) -> Ad_C04
      Add(Ad_C04, Ad_C09) -> Ad_C03
      Add(Ad_C03, Ad_C010) -> Ad_C02
      Add(Ad_C02, Ad_C011) -> Ad_C01
        Add(loss, Ad_C01) -> Ad_C0
          Reshape(Ad_C0, Re_Reshapecst) -> Y
output: name='Y' type=dtype('float32') shape=(0,)

++++++ ./LRateSGDNesterov.learning_rate.axpyw_onnx_.onnx

opset: domain='' version=14
input: name='X1' type=dtype('float32') shape=()
input: name='X2' type=dtype('float32') shape=()
input: name='G' type=dtype('float32') shape=()
input: name='alpha' type=dtype('float32') shape=(1,)
input: name='beta' type=dtype('float32') shape=(1,)
Mul(X1, alpha) -> Mu_C0
Mul(G, beta) -> Mu_C03
  Add(Mu_C0, Mu_C03) -> Z
    Mul(Z, beta) -> Mu_C02
  Add(Mu_C0, Mu_C02) -> Ad_C0
    Add(Ad_C0, X2) -> Y
output: name='Y' type=dtype('float32') shape=()
output: name='Z' type=dtype('float32') shape=()

++++++ ./GradFBOptimizer.model_onnx.onnx

opset: domain='' version=14
input: name='X' type=dtype('float32') shape=(0, 10)
init: name='I0_coefficient' type=dtype('float32') shape=(100,)
init: name='I1_intercepts' type=dtype('float32') shape=(10,)
init: name='I2_coefficient1' type=dtype('float32') shape=(100,)
init: name='I3_intercepts1' type=dtype('float32') shape=(10,)
init: name='I4_coefficient2' type=dtype('float32') shape=(10,)
init: name='I5_intercepts2' type=dtype('float32') shape=(1,) -- array([0.14376946], dtype=float32)
init: name='I6_shape_tensor' type=dtype('int64') shape=(2,) -- array([-1,  1])
Cast(X, to=1) -> r0
  MatMul(r0, I0_coefficient) -> r1
    Add(r1, I1_intercepts) -> r2
      Relu(r2) -> r3
        MatMul(r3, I2_coefficient1) -> r4
          Add(r4, I3_intercepts1) -> r5
            Relu(r5) -> r6
              MatMul(r6, I4_coefficient2) -> r7
                Add(r7, I5_intercepts2) -> r8
                  Reshape(r8, I6_shape_tensor) -> variable
output: name='variable' type=dtype('float32') shape=(0, 1)

++++++ ./OrtGradientForwardBackwardFunction_140036806160976.train_function_._optimized_pre_grad_model.onnx

opset: domain='' version=14
opset: domain='com.microsoft.experimental' version=1
opset: domain='ai.onnx.preview.training' version=1
opset: domain='com.microsoft.nchwc' version=1
opset: domain='com.ms.internal.nhwc' version=1
opset: domain='ai.onnx.training' version=1
opset: domain='ai.onnx.ml' version=2
opset: domain='com.microsoft' version=1
input: name='X' type=dtype('float32') shape=(0, 10)
input: name='I0_coefficient' type=dtype('float32') shape=(10, 10)
input: name='I1_intercepts' type=dtype('float32') shape=(1, 10)
input: name='I2_coefficient1' type=dtype('float32') shape=(10, 10)
input: name='I3_intercepts1' type=dtype('float32') shape=(1, 10)
input: name='I4_coefficient2' type=dtype('float32') shape=(10, 1)
input: name='I5_intercepts2' type=dtype('float32') shape=(1, 1)
init: name='I6_shape_tensor' type=dtype('int64') shape=(2,) -- array([-1,  1])
MatMul(X, I0_coefficient) -> r1
  Add(r1, I1_intercepts) -> r2
    Relu(r2) -> r3
      MatMul(r3, I2_coefficient1) -> r4
        Add(r4, I3_intercepts1) -> r5
          Relu(r5) -> r6
            MatMul(r6, I4_coefficient2) -> r7
              Add(r7, I5_intercepts2) -> r8
                Reshape(r8, I6_shape_tensor, allowzero=0) -> variable
output: name='variable' type=dtype('float32') shape=(0, 1)

++++++ ./OrtGradientForwardBackwardFunction_140036806160976.train_function_._trained_onnx.onnx

opset: domain='' version=14
opset: domain='com.microsoft.experimental' version=1
opset: domain='ai.onnx.preview.training' version=1
opset: domain='com.microsoft.nchwc' version=1
opset: domain='com.ms.internal.nhwc' version=1
opset: domain='ai.onnx.training' version=1
opset: domain='ai.onnx.ml' version=2
opset: domain='com.microsoft' version=1
input: name='X' type=dtype('float32') shape=(0, 10)
input: name='I0_coefficient' type=dtype('float32') shape=(10, 10)
input: name='I1_intercepts' type=dtype('float32') shape=(1, 10)
input: name='I2_coefficient1' type=dtype('float32') shape=(10, 10)
input: name='I3_intercepts1' type=dtype('float32') shape=(1, 10)
input: name='I4_coefficient2' type=dtype('float32') shape=(10, 1)
input: name='I5_intercepts2' type=dtype('float32') shape=(1, 1)
init: name='I6_shape_tensor' type=dtype('int64') shape=(2,) -- array([-1,  1])
init: name='n1_Grad/A_target_shape' type=dtype('int64') shape=(2,) -- array([-1, 10])
init: name='n1_Grad/dY_target_shape' type=dtype('int64') shape=(2,) -- array([-1, 10])
init: name='n4_Grad/A_target_shape' type=dtype('int64') shape=(2,) -- array([-1, 10])
init: name='n4_Grad/dY_target_shape' type=dtype('int64') shape=(2,) -- array([-1, 10])
init: name='n7_Grad/A_target_shape' type=dtype('int64') shape=(2,) -- array([-1, 10])
init: name='n7_Grad/dY_target_shape' type=dtype('int64') shape=(2,) -- array([-1,  1])
MatMul(X, I0_coefficient) -> r1
  Add(r1, I1_intercepts) -> r2
    Relu(r2) -> r3
      MatMul(r3, I2_coefficient1) -> r4
        Add(r4, I3_intercepts1) -> r5
          Relu(r5) -> r6
            MatMul(r6, I4_coefficient2) -> r7
              Add(r7, I5_intercepts2) -> r8
                Reshape(r8, I6_shape_tensor, allowzero=0) -> variable
                  YieldOp(variable) -> variable_grad
                Shape(r8) -> n9_Grad/x_shape
                  Reshape(variable_grad, n9_Grad/x_shape, allowzero=0) -> r8_grad
Shape(I5_intercepts2) -> n8_Grad/Shape_I5_intercepts2
Shape(r7) -> n8_Grad/Shape_r7
  BroadcastGradientArgs(n8_Grad/Shape_r7, n8_Grad/Shape_I5_intercepts2) -> n8_Grad/ReduceAxes_r7, n8_Grad/ReduceAxes_I5_intercepts2
    ReduceSum(r8_grad, n8_Grad/ReduceAxes_I5_intercepts2, noop_with_empty_axes=1, keepdims=1) -> n8_Grad/ReduceSum_r8_grad_for_I5_intercepts2
  Reshape(n8_Grad/ReduceSum_r8_grad_for_I5_intercepts2, n8_Grad/Shape_I5_intercepts2, allowzero=0) -> I5_intercepts2_grad
ReduceSum(r8_grad, n8_Grad/ReduceAxes_r7, noop_with_empty_axes=1, keepdims=1) -> n8_Grad/ReduceSum_r8_grad_for_r7
  Reshape(n8_Grad/ReduceSum_r8_grad_for_r7, n8_Grad/Shape_r7, allowzero=0) -> r7_grad
    Reshape(r7_grad, n7_Grad/dY_target_shape, allowzero=0) -> n7_Grad/dY_reshape_2d
Reshape(r6, n7_Grad/A_target_shape, allowzero=0) -> n7_Grad/A_reshape_2d
  Gemm(n7_Grad/A_reshape_2d, n7_Grad/dY_reshape_2d, beta=1.00, transB=0, transA=1, alpha=1.00) -> I4_coefficient2_grad
FusedMatMul(r7_grad, I4_coefficient2, transB=1, alpha=1.00, transA=0) -> n7_Grad/PreReduceGrad0
  Shape(n7_Grad/PreReduceGrad0) -> n7_Grad/Shape_n7_Grad/PreReduceGrad0
Shape(r6) -> n7_Grad/Shape_r6
  BroadcastGradientArgs(n7_Grad/Shape_r6, n7_Grad/Shape_n7_Grad/PreReduceGrad0) -> n7_Grad/ReduceAxes_r6_for_r6,
  ReduceSum(n7_Grad/PreReduceGrad0, n7_Grad/ReduceAxes_r6_for_r6, noop_with_empty_axes=1, keepdims=1) -> n7_Grad/ReduceSum_n7_Grad/PreReduceGrad0_for_r6
  Reshape(n7_Grad/ReduceSum_n7_Grad/PreReduceGrad0_for_r6, n7_Grad/Shape_r6, allowzero=0) -> r6_grad
    ReluGrad(r6_grad, r6) -> r5_grad
Shape(I3_intercepts1) -> n5_Grad/Shape_I3_intercepts1
Shape(r4) -> n5_Grad/Shape_r4
  BroadcastGradientArgs(n5_Grad/Shape_r4, n5_Grad/Shape_I3_intercepts1) -> n5_Grad/ReduceAxes_r4, n5_Grad/ReduceAxes_I3_intercepts1
    ReduceSum(r5_grad, n5_Grad/ReduceAxes_I3_intercepts1, noop_with_empty_axes=1, keepdims=1) -> n5_Grad/ReduceSum_r5_grad_for_I3_intercepts1
  Reshape(n5_Grad/ReduceSum_r5_grad_for_I3_intercepts1, n5_Grad/Shape_I3_intercepts1, allowzero=0) -> I3_intercepts1_grad
ReduceSum(r5_grad, n5_Grad/ReduceAxes_r4, noop_with_empty_axes=1, keepdims=1) -> n5_Grad/ReduceSum_r5_grad_for_r4
  Reshape(n5_Grad/ReduceSum_r5_grad_for_r4, n5_Grad/Shape_r4, allowzero=0) -> r4_grad
    Reshape(r4_grad, n4_Grad/dY_target_shape, allowzero=0) -> n4_Grad/dY_reshape_2d
Reshape(r3, n4_Grad/A_target_shape, allowzero=0) -> n4_Grad/A_reshape_2d
  Gemm(n4_Grad/A_reshape_2d, n4_Grad/dY_reshape_2d, beta=1.00, transB=0, transA=1, alpha=1.00) -> I2_coefficient1_grad
FusedMatMul(r4_grad, I2_coefficient1, transB=1, alpha=1.00, transA=0) -> n4_Grad/PreReduceGrad0
  Shape(n4_Grad/PreReduceGrad0) -> n4_Grad/Shape_n4_Grad/PreReduceGrad0
Shape(r3) -> n4_Grad/Shape_r3
  BroadcastGradientArgs(n4_Grad/Shape_r3, n4_Grad/Shape_n4_Grad/PreReduceGrad0) -> n4_Grad/ReduceAxes_r3_for_r3,
  ReduceSum(n4_Grad/PreReduceGrad0, n4_Grad/ReduceAxes_r3_for_r3, noop_with_empty_axes=1, keepdims=1) -> n4_Grad/ReduceSum_n4_Grad/PreReduceGrad0_for_r3
  Reshape(n4_Grad/ReduceSum_n4_Grad/PreReduceGrad0_for_r3, n4_Grad/Shape_r3, allowzero=0) -> r3_grad
    ReluGrad(r3_grad, r3) -> r2_grad
Shape(I1_intercepts) -> n2_Grad/Shape_I1_intercepts
Shape(r1) -> n2_Grad/Shape_r1
  BroadcastGradientArgs(n2_Grad/Shape_r1, n2_Grad/Shape_I1_intercepts) -> n2_Grad/ReduceAxes_r1, n2_Grad/ReduceAxes_I1_intercepts
    ReduceSum(r2_grad, n2_Grad/ReduceAxes_I1_intercepts, noop_with_empty_axes=1, keepdims=1) -> n2_Grad/ReduceSum_r2_grad_for_I1_intercepts
  Reshape(n2_Grad/ReduceSum_r2_grad_for_I1_intercepts, n2_Grad/Shape_I1_intercepts, allowzero=0) -> I1_intercepts_grad
ReduceSum(r2_grad, n2_Grad/ReduceAxes_r1, noop_with_empty_axes=1, keepdims=1) -> n2_Grad/ReduceSum_r2_grad_for_r1
  Reshape(n2_Grad/ReduceSum_r2_grad_for_r1, n2_Grad/Shape_r1, allowzero=0) -> r1_grad
    Reshape(r1_grad, n1_Grad/dY_target_shape, allowzero=0) -> n1_Grad/dY_reshape_2d
Reshape(X, n1_Grad/A_target_shape, allowzero=0) -> n1_Grad/A_reshape_2d
  Gemm(n1_Grad/A_reshape_2d, n1_Grad/dY_reshape_2d, beta=1.00, transB=0, transA=1, alpha=1.00) -> I0_coefficient_grad
FusedMatMul(r1_grad, I0_coefficient, transB=1, alpha=1.00, transA=0) -> n1_Grad/PreReduceGrad0
  Shape(n1_Grad/PreReduceGrad0) -> n1_Grad/Shape_n1_Grad/PreReduceGrad0
Shape(X) -> n1_Grad/Shape_X
  BroadcastGradientArgs(n1_Grad/Shape_X, n1_Grad/Shape_n1_Grad/PreReduceGrad0) -> n1_Grad/ReduceAxes_X_for_X,
  ReduceSum(n1_Grad/PreReduceGrad0, n1_Grad/ReduceAxes_X_for_X, noop_with_empty_axes=1, keepdims=1) -> n1_Grad/ReduceSum_n1_Grad/PreReduceGrad0_for_X
  Reshape(n1_Grad/ReduceSum_n1_Grad/PreReduceGrad0_for_X, n1_Grad/Shape_X, allowzero=0) -> X_grad
output: name='X_grad' type=dtype('float32') shape=(0, 10)
output: name='I0_coefficient_grad' type=dtype('float32') shape=(10, 10)
output: name='I1_intercepts_grad' type=dtype('float32') shape=(1, 10)
output: name='I2_coefficient1_grad' type=dtype('float32') shape=(10, 10)
output: name='I3_intercepts1_grad' type=dtype('float32') shape=(1, 10)
output: name='I4_coefficient2_grad' type=dtype('float32') shape=(10, 1)
output: name='I5_intercepts2_grad' type=dtype('float32') shape=(1, 1)

++++++ ./GradFBOptimizer.zero_onnx_.onnx

opset: domain='' version=14
input: name='X' type=dtype('float32') shape=()
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
Mul(X, Mu_Mulcst) -> Y
output: name='Y' type=dtype('float32') shape=()

Total running time of the script: ( 1 minutes 23.585 seconds)

Gallery generated by Sphinx-Gallery