Intermediate results and investigation

There are many reasons why a user may want more than the converted ONNX model itself. Intermediate results may be needed, that is, the output of every node in the graph. The ONNX graph may need to be altered to remove some nodes: transfer learning, for example, usually removes the last layers of a deep neural network. Another reason is debugging. The runtime often fails to compute the predictions because of a shape mismatch, and it is then useful to get the shape of every intermediate result. This example looks into two ways of doing it.

Look into pipeline steps

The first way is a tricky one: it overloads the methods transform, predict and predict_proba to keep a copy of their inputs and outputs. It then goes through every step of the pipeline. If the pipeline has n steps, it converts the pipeline with step 1, then the pipeline with steps 1 and 2, then steps 1, 2 and 3, and so on.
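The capture-and-compare idea can be sketched in plain Python. This is not the code of collect_intermediate_steps: Scale, Shift, capture_transform and run_prefix are hypothetical names used only for illustration. In the real example each truncated pipeline is converted to ONNX and run with onnxruntime; here both sides stay in Python so that only the structure is shown, namely shadowing transform on the instance to cache its input and output, then comparing each cache with the output of the pipeline truncated after that step.

```python
import numpy


class Scale:
    """Hypothetical toy step standing in for a transformer such as StandardScaler."""

    def __init__(self, factor):
        self.factor = factor

    def transform(self, X):
        return X * self.factor


class Shift:
    """Hypothetical toy step: adds a constant."""

    def __init__(self, offset):
        self.offset = offset

    def transform(self, X):
        return X + self.offset


def capture_transform(step):
    """Shadow step.transform with a wrapper keeping copies of its last input and output."""
    original = step.transform  # bound method defined on the class
    step._debug = {}

    def wrapped(X):
        out = original(X)
        step._debug["input"] = numpy.array(X, copy=True)
        step._debug["output"] = numpy.array(out, copy=True)
        return out

    step.transform = wrapped  # the instance attribute takes precedence over the class method
    return step


def run_prefix(steps, n, X):
    """Apply only the first n steps, like a pipeline truncated after step n."""
    for step in steps[:n]:
        X = step.transform(X)
    return X


steps = [capture_transform(Scale(2.0)), capture_transform(Shift(1.0))]
X = numpy.arange(6, dtype=numpy.float64).reshape(3, 2)

# One run of the full pipeline fills every cache.
out = run_prefix(steps, len(steps), X)
cached = [step._debug["output"] for step in steps]

# Compare the cached output of step n with the output of the truncated pipeline.
for n in range(1, len(steps) + 1):
    print(n, numpy.abs(cached[n - 1] - run_prefix(steps, n, X)).max())
```

Shadowing the method on the instance, rather than on the class, keeps other objects of the same class untouched.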

from pyquickhelper.helpgen.graphviz_helper import plot_graphviz
from mlprodict.onnxrt import OnnxInference
import numpy
from onnxruntime import InferenceSession
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from skl2onnx import to_onnx
from skl2onnx.helpers import collect_intermediate_steps
from skl2onnx.common.data_types import FloatTensorType

The pipeline.

data = load_iris()
X = data.data

pipe = Pipeline(steps=[
    ('std', StandardScaler()),
    ('km', KMeans(3))
])
pipe.fit(X)

Out:

Pipeline(steps=[('std', StandardScaler()), ('km', KMeans(n_clusters=3))])

The function collect_intermediate_steps goes through every step, overloads its method transform and returns an ONNX graph for every step.

steps = collect_intermediate_steps(
    pipe, "pipeline",
    [("X", FloatTensorType([None, X.shape[1]]))])

We call method transform to populate the cache that the overloaded transform methods keep.

pipe.transform(X)

Out:

array([[0.21295824, 3.15861505, 4.00404832],
       [0.99604549, 2.72563625, 4.05055769],
       [0.65198444, 3.02188403, 4.22040251],
       [0.9034561 , 2.93043986, 4.22860026],
       [0.40215457, 3.33653691, 4.12353003],
       [1.21154793, 3.52936423, 3.89643029],
       [0.50244932, 3.19234391, 4.2374443 ],
       [0.09132468, 3.03242342, 3.99197553],
       [1.42174651, 2.9795537 , 4.4445734 ],
       [0.78993078, 2.84221713, 4.08705397],
       [0.78999385, 3.3507236 , 3.92610748],
       [0.27618123, 3.09168785, 4.09865843],
       [1.03497888, 2.85719428, 4.19718995],
       [1.33482453, 3.26547013, 4.66454355],
       [1.63865558, 3.90871872, 4.13826195],
       [2.39898792, 4.51414747, 4.47633229],
       [1.20748818, 3.63475229, 4.02762963],
       [0.21618828, 3.09288714, 3.92839122],
       [1.20986655, 3.36736664, 3.72388908],
       [0.86706182, 3.53103908, 4.10521298],
       [0.50401564, 2.8436663 , 3.67990695],
       [0.66826437, 3.30977167, 3.95222508],
       [0.68658071, 3.63034505, 4.52523323],
       [0.47945627, 2.59973228, 3.60185594],
       [0.36345425, 3.00678098, 4.00845791],
       [0.99023912, 2.60615351, 3.91688379],
       [0.22683089, 2.86756462, 3.80966594],
       [0.2947186 , 3.0958103 , 3.90931811],
       [0.25361098, 2.99275191, 3.89828815],
       [0.65019824, 2.92503544, 4.12581898],
       [0.80138328, 2.78137328, 4.04810077],
       [0.52309257, 2.76837135, 3.58928575],
       [1.57658655, 4.14390673, 4.49874494],
       [1.87652483, 4.2489329 , 4.43563509],
       [0.76858489, 2.76272437, 4.008642  ],
       [0.54896332, 2.9103173 , 4.04525625],
       [0.63079314, 3.09001251, 3.81211172],
       [0.45982568, 3.44057865, 4.26421417],
       [1.2336976 , 3.06034971, 4.45456872],
       [0.14580827, 2.99422852, 3.92683189],
       [0.20261743, 3.16142101, 4.02712265],
       [2.67055552, 2.9575648 , 4.69480008],
       [0.90927099, 3.20355969, 4.4496996 ],
       [0.50081008, 2.89721622, 3.71964918],
       [0.92159916, 3.37471011, 3.91143692],
       [1.01946042, 2.70316642, 4.04740147],
       [0.86953764, 3.56280964, 4.14683513],
       [0.72275914, 3.04646993, 4.26327469],
       [0.72324305, 3.37186092, 3.98021229],
       [0.30295342, 2.94518173, 3.99446269],
       [3.43619989, 1.8639233 , 0.9452659 ],
       [2.97232682, 1.38933168, 1.00829443],
       [3.51850037, 1.6428166 , 0.73653572],
       [3.33264308, 1.00264343, 2.76204203],
       [3.35747592, 0.86560047, 1.16604995],
       [2.77550662, 0.3750882 , 1.86711784],
       [3.01808184, 1.56489146, 1.00955989],
       [2.77360088, 1.55619573, 3.3697155 ],
       [3.21148368, 1.08067281, 1.18358725],
       [2.66294828, 0.82637993, 2.48285941],
       [3.62389817, 2.01281316, 3.79967007],
       [2.70011145, 0.76353654, 1.50054672],
       [3.53658932, 1.27727048, 2.80438695],
       [2.98813829, 0.62868121, 1.34023352],
       [2.32311723, 0.77087912, 2.09655735],
       [3.14311522, 1.43272989, 1.00633966],
       [2.68234835, 0.80192   , 1.71909321],
       [2.63954211, 0.60569829, 2.17926627],
       [3.97369206, 1.18764767, 2.40214871],
       [2.87494798, 0.727372  , 2.52511757],
       [3.03853641, 1.31653995, 1.21113562],
       [2.8022861 , 0.52313867, 1.68291281],
       [3.68305664, 0.75211692, 1.71597913],
       [2.96833851, 0.55292557, 1.59856561],
       [2.9760862 , 0.87815407, 1.33753092],
       [3.13002382, 1.19061026, 1.06462905],
       [3.56679427, 1.22441299, 1.13996294],
       [3.5903606 , 1.37258261, 0.5652633 ],
       [2.93839428, 0.56006248, 1.39763754],
       [2.58203512, 0.81289907, 2.49518379],
       [2.99796537, 0.94324481, 2.75025306],
       [2.92597852, 1.03283946, 2.82866407],
       [2.68907313, 0.4343386 , 2.08201734],
       [3.42215998, 0.48873673, 1.48418961],
       [2.62771445, 0.91606802, 1.92943813],
       [2.75915071, 1.69140864, 1.40011111],
       [3.30075052, 1.44311693, 0.79992473],
       [3.73017167, 1.05036852, 2.2708714 ],
       [2.37943811, 0.83618809, 1.91690629],
       [2.98789866, 0.6470029 , 2.47017911],
       [2.89079656, 0.53979211, 2.32571939],
       [2.86642713, 0.81855214, 1.29304411],
       [2.86642575, 0.43194777, 2.17526444],
       [2.96966239, 1.58383257, 3.40973541],
       [2.77003779, 0.3618706 , 2.10849001],
       [2.38255534, 0.83187956, 1.87076527],
       [2.55559903, 0.58147273, 1.85116384],
       [2.8455521 , 0.70529895, 1.44451588],
       [2.56987887, 1.34329146, 3.11774537],
       [2.64007308, 0.41481694, 1.94990512],
       [4.24274589, 2.26819164, 1.04248866],
       [3.57067982, 0.72581017, 1.57935402],
       [4.44150237, 2.09231844, 0.52274684],
       [3.69480186, 1.12321156, 0.83298461],
       [4.11613683, 1.68255837, 0.5678145 ],
       [5.03326801, 2.72592116, 1.1830756 ],
       [3.3503222 , 1.25267619, 2.8024351 ],
       [4.577021  , 2.18852343, 0.93117407],
       [4.363498  , 1.45283591, 1.46246781],
       [4.79334275, 3.18264007, 1.4207266 ],
       [3.62749566, 1.67405555, 0.47962495],
       [3.89360823, 1.04698204, 1.09881086],
       [4.1132966 , 1.75049044, 0.31830999],
       [3.82688169, 0.92293569, 1.98175664],
       [3.91538879, 1.35721732, 1.54698303],
       [3.89835633, 1.86138575, 0.68407345],
       [3.70128288, 1.34561415, 0.52205472],
       [5.18341242, 3.80620352, 2.03678461],
       [5.58136629, 2.90217633, 1.84250874],
       [4.02615768, 1.16636059, 2.43634558],
       [4.31907679, 2.22297775, 0.48150581],
       [3.4288432 , 0.88685031, 1.67578773],
       [5.19031307, 2.72431414, 1.47096547],
       [3.64273089, 0.79101156, 1.22329554],
       [4.00723617, 2.10999425, 0.47109224],
       [4.2637671 , 2.28591141, 0.62558995],
       [3.45930032, 0.74392898, 1.14490402],
       [3.27575645, 0.98053107, 0.99645552],
       [4.05342943, 1.3282425 , 0.90181942],
       [4.1585729 , 1.98849304, 0.76242411],
       [4.71100584, 2.22822113, 1.08628479],
       [5.12224641, 3.84302072, 2.10967488],
       [4.13401784, 1.41836425, 0.93357383],
       [3.39830644, 0.74517066, 1.17526973],
       [3.63719075, 0.76558228, 1.66051938],
       [5.08776655, 2.80545775, 1.23742547],
       [4.00416552, 2.26945032, 1.04697429],
       [3.58815834, 1.42313566, 0.55013293],
       [3.19454679, 0.93290167, 1.12188023],
       [4.09907253, 1.92136662, 0.20983625],
       [4.28416057, 2.02737038, 0.5691276 ],
       [4.17402084, 2.01513279, 0.49810802],
       [3.57067982, 0.72581017, 1.57935402],
       [4.32128686, 2.19577242, 0.50497262],
       [4.3480018 , 2.37699732, 0.81423561],
       [4.1240495 , 1.77340222, 0.55018391],
       [3.97564407, 0.98294137, 1.58648502],
       [3.7539635 , 1.39731191, 0.49931367],
       [3.7969924 , 2.13822884, 1.06536484],
       [3.25638099, 0.96885287, 1.18287527]])

We compute every step and compare ONNX and scikit-learn outputs.

for step in steps:
    print('----------------------------')
    print(step['model'])
    onnx_step = step['onnx_step']
    sess = InferenceSession(onnx_step.SerializeToString(),
                            providers=['CPUExecutionProvider'])
    onnx_outputs = sess.run(None, {'X': X.astype(numpy.float32)})
    onnx_output = onnx_outputs[-1]
    skl_outputs = step['model']._debug.outputs['transform']

    # comparison
    diff = numpy.abs(skl_outputs.ravel() - onnx_output.ravel()).max()
    print("difference", diff)

# That was the first way: dynamically overload every method
# transform or predict in a scikit-learn pipeline to capture
# the input and output of every step, then compare them to
# the outputs produced by truncated ONNX graphs built from
# the full pipeline.

Out:

----------------------------
StandardScaler()
difference 4.799262827148709e-07
----------------------------
KMeans(n_clusters=3)
difference 4.332024853268002e-06
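The loop above flattens both arrays with ravel() before taking the largest absolute difference, so two outputs with the same number of elements but different shapes would still be compared. A minimal sketch of a stricter check follows; compare_outputs and its tolerances are illustrative choices, not part of skl2onnx. It also shows that casting to float32 alone already produces differences of the same order as the ones printed above.

```python
import numpy


def compare_outputs(expected, got, rtol=1e-4, atol=1e-5):
    """Return (max absolute difference, agreement flag) after checking shapes.

    Unlike ravel() followed by a subtraction, this refuses arrays whose
    shapes differ, so a (150, 3) output is never compared with a (3, 150) one.
    """
    expected = numpy.asarray(expected, dtype=numpy.float64)
    got = numpy.asarray(got, dtype=numpy.float64)
    if expected.shape != got.shape:
        raise ValueError(f"shape mismatch: {expected.shape} != {got.shape}")
    diff = float(numpy.abs(expected - got).max())
    ok = bool(numpy.allclose(expected, got, rtol=rtol, atol=atol))
    return diff, ok


# Standardize random data in float64 and in float32 to measure the
# rounding error introduced by the float32 cast alone.
rng = numpy.random.default_rng(0)
X = rng.normal(size=(150, 4))
mean, std = X.mean(axis=0), X.std(axis=0)
scaled64 = (X - mean) / std
scaled32 = (X.astype(numpy.float32) - mean.astype(numpy.float32)) / std.astype(numpy.float32)
diff, ok = compare_outputs(scaled64, scaled32)
print("difference", diff, "within tolerance", ok)
```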

Python runtime to look into every node

The Python runtime of mlprodict (OnnxInference) makes it easy to look into every node of the ONNX graph. This option can be used to find where the computation fails because of nan values or a dimension mismatch.
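The kind of check such a runtime enables can be sketched without any ONNX dependency. In this sketch, run_nodes and the tuple format of nodes are hypothetical simplifications of a graph already sorted topologically: every node is evaluated in order, its output shape is reported, and execution stops at the first node raising a shape error or producing nan or inf.

```python
import numpy


def run_nodes(nodes, feeds, verbose=True):
    """Evaluate nodes one by one, report shapes, stop at the first failure.

    nodes is a hypothetical stand-in for an ONNX graph: a list of
    (name, function, input names, output name) tuples, sorted so that
    every input is computed before it is used.
    """
    values = dict(feeds)
    for name, fct, inputs, output in nodes:
        args = [values[i] for i in inputs]
        try:
            # Silence numpy warnings: nan/inf are checked explicitly below.
            with numpy.errstate(all="ignore"):
                res = numpy.asarray(fct(*args))
        except ValueError as e:  # numpy raises ValueError on incompatible shapes
            shapes = [a.shape for a in args]
            raise RuntimeError(f"node {name!r} failed, input shapes {shapes}: {e}") from e
        if verbose:
            print(f"{name} -> {output}: shape={res.shape} dtype={res.dtype}")
        if not numpy.all(numpy.isfinite(res)):
            raise RuntimeError(f"node {name!r} produced nan or inf in {output!r}")
        values[output] = res
    return values


nodes = [
    ("Gemm", lambda x, c: x @ c.T, ["X", "C"], "Y"),
    ("Sqrt", numpy.sqrt, ["Y"], "Z"),
]
X = numpy.array([[1.0, 2.0], [3.0, 4.0]])

# Identity matrix: every product is positive and every node succeeds.
values = run_nodes(nodes, {"X": X, "C": numpy.eye(2)})

for C in (numpy.array([[1.0, 0.0], [0.0, -1.0]]),  # negative products: nan in Sqrt
          numpy.eye(3)):                             # (2, 2) @ (3, 3): shape mismatch in Gemm
    try:
        run_nodes(nodes, {"X": X, "C": C})
    except RuntimeError as e:
        print("error:", e)
```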

onx = to_onnx(pipe, X[:1].astype(numpy.float32))

oinf = OnnxInference(onx)
oinf.run({'X': X[:2].astype(numpy.float32)},
         verbose=1, fLOG=print)

Out:

+ki='Ad_Addcst': (3,) (dtype=float32 min=1.0065230131149292 max=5.035177230834961)
+ki='Ge_Gemmcst': (3, 4) (dtype=float32 min=-1.3049873113632202 max=1.1674340963363647)
+ki='Mu_Mulcst': (1,) (dtype=float32 min=0.0 max=0.0)
-- OnnxInference: run 8 nodes
Onnx-Scaler(X) -> variable    (name='Scaler')
+kr='variable': (2, 4) (dtype=float32 min=-1.340226411819458 max=1.0190045833587646)
Onnx-ReduceSumSquare(variable) -> Re_reduced0    (name='Re_ReduceSumSquare')
+kr='Re_reduced0': (2, 1) (dtype=float32 min=4.850505828857422 max=5.376197338104248)
Onnx-Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0    (name='Mu_Mul')
+kr='Mu_C0': (2, 1) (dtype=float32 min=0.0 max=0.0)
Onnx-Gemm(variable, Ge_Gemmcst, Mu_C0) -> Ge_Y0    (name='Ge_Gemm')
+kr='Ge_Y0': (2, 3) (dtype=float32 min=-10.366023063659668 max=8.10552978515625)
Onnx-Add(Re_reduced0, Ge_Y0) -> Ad_C01    (name='Ad_Add')
+kr='Ad_C01': (2, 3) (dtype=float32 min=-4.98982572555542 max=12.956035614013672)
Onnx-Add(Ad_Addcst, Ad_C01) -> Ad_C0    (name='Ad_Add1')
+kr='Ad_C0': (2, 3) (dtype=float32 min=0.045351505279541016 max=16.407014846801758)
Onnx-Sqrt(Ad_C0) -> scores    (name='Sq_Sqrt')
+kr='scores': (2, 3) (dtype=float32 min=0.2129589319229126 max=4.0505571365356445)
Onnx-ArgMin(Ad_C0) -> label    (name='Ar_ArgMin')
+kr='label': (2,) (dtype=int64 min=0 max=0)

{'label': array([0, 0]), 'scores': array([[0.21295893, 3.158615  , 4.0040483 ],
       [0.99604493, 2.7256362 , 4.050557  ]], dtype=float32)}
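Reading the node list above, the graph appears to compute the usual expansion of the squared Euclidean distance between a scaled observation x (the result named variable) and every cluster center c_k. The mapping below is inferred from node names and printed values, not from the converter's source: for instance, the squared norm of the first row of Ge_Gemmcst is about 5.035, the first value of Ad_Addcst, and Mu_C0 is a zero bias for Gemm.

```latex
\begin{aligned}
r &= \lVert x \rVert^2 && \text{ReduceSumSquare} \\
g_k &= -2\, x \cdot c_k && \text{Gemm (bias } \texttt{Mu\_C0} = 0\text{)} \\
d_k^2 &= \lVert c_k \rVert^2 + \left(r + g_k\right) = \lVert x - c_k \rVert^2 && \text{the two Add nodes} \\
\text{scores}_k &= \sqrt{d_k^2}, \qquad \text{label} = \arg\min_k d_k^2 && \text{Sqrt, ArgMin}
\end{aligned}
```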

With verbose=3, the runtime also prints the value of every intermediate result.

oinf.run({'X': X[:2].astype(numpy.float32)},
         verbose=3, fLOG=print)

# This way is usually better to investigate issues within
# the runtime's implementation of an operator.

Out:

+ki='Ad_Addcst': (3,) (dtype=float32 min=1.0065230131149292 max=5.035177230834961
[5.035177  1.006523  3.4509795]
+ki='Ge_Gemmcst': (3, 4) (dtype=float32 min=-1.3049873113632202 max=1.1674340963363647
[[-1.0145789   0.85326266 -1.3049873  -1.2548935 ]
 [-0.01139555 -0.87600833  0.37707573  0.3111534 ]
 [ 1.1674341   0.145303    1.0030255   1.0300019 ]]
+ki='Mu_Mulcst': (1,) (dtype=float32 min=0.0 max=0.0
[0.]
-kv='X' shape=(2, 4) dtype=float32 min=0.20000000298023224 max=5.099999904632568
-- OnnxInference: run 8 nodes
Onnx-Scaler(X) -> variable    (name='Scaler')
+kr='variable': (2, 4) (dtype=float32 min=-1.340226411819458 max=1.0190045833587646)
[[-0.9006812   1.0190046  -1.3402264  -1.3154442 ]
 [-1.1430167  -0.13197924 -1.3402264  -1.3154442 ]]
Onnx-ReduceSumSquare(variable) -> Re_reduced0    (name='Re_ReduceSumSquare')
+kr='Re_reduced0': (2, 1) (dtype=float32 min=4.850505828857422 max=5.376197338104248)
[[5.3761973]
 [4.850506 ]]
Onnx-Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0    (name='Mu_Mul')
+kr='Mu_C0': (2, 1) (dtype=float32 min=0.0 max=0.0)
[[0.]
 [0.]]
Onnx-Gemm(variable, Ge_Gemmcst, Mu_C0) -> Ge_Y0    (name='Ge_Gemm')
+kr='Ge_Y0': (2, 3) (dtype=float32 min=-10.366023063659668 max=8.10552978515625)
[[-10.366023    3.5941293   7.205226 ]
 [ -8.893578    1.5720632   8.10553  ]]
Onnx-Add(Re_reduced0, Ge_Y0) -> Ad_C01    (name='Ad_Add')
+kr='Ad_C01': (2, 3) (dtype=float32 min=-4.98982572555542 max=12.956035614013672)
[[-4.9898257  8.970326  12.581423 ]
 [-4.0430717  6.4225693 12.956036 ]]
Onnx-Add(Ad_Addcst, Ad_C01) -> Ad_C0    (name='Ad_Add1')
+kr='Ad_C0': (2, 3) (dtype=float32 min=0.045351505279541016 max=16.407014846801758)
[[ 0.04535151  9.97685    16.032402  ]
 [ 0.9921055   7.4290924  16.407015  ]]
Onnx-Sqrt(Ad_C0) -> scores    (name='Sq_Sqrt')
+kr='scores': (2, 3) (dtype=float32 min=0.2129589319229126 max=4.0505571365356445)
[[0.21295893 3.158615   4.0040483 ]
 [0.99604493 2.7256362  4.050557  ]]
Onnx-ArgMin(Ad_C0) -> label    (name='Ar_ArgMin')
+kr='label': (2,) (dtype=int64 min=0 max=0)
[0 0]

{'label': array([0, 0]), 'scores': array([[0.21295893, 3.158615  , 4.0040483 ],
       [0.99604493, 2.7256362 , 4.050557  ]], dtype=float32)}
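Because verbose=3 prints every intermediate value, the last nodes can be replayed by hand to check a suspicious step. The sketch below copies variable, Ge_Gemmcst and Ad_Addcst from the log above (rounded as printed) and recomputes scores and label with numpy. The Gemm coefficient of -2 and the transposed second input are assumptions inferred from the printed Ge_Y0 values, not read from the model.

```python
import numpy

# Values copied from the verbose=3 log above (rounded as printed).
variable = numpy.array([[-0.9006812, 1.0190046, -1.3402264, -1.3154442],
                        [-1.1430167, -0.13197924, -1.3402264, -1.3154442]],
                       dtype=numpy.float32)
gemm_cst = numpy.array([[-1.0145789, 0.85326266, -1.3049873, -1.2548935],
                        [-0.01139555, -0.87600833, 0.37707573, 0.3111534],
                        [1.1674341, 0.145303, 1.0030255, 1.0300019]],
                       dtype=numpy.float32)
add_cst = numpy.array([5.035177, 1.006523, 3.4509795], dtype=numpy.float32)

re_reduced0 = (variable ** 2).sum(axis=1, keepdims=True)  # ReduceSumSquare -> (2, 1)
ge_y0 = (-2 * variable) @ gemm_cst.T    # Gemm, assuming alpha=-2 and transB=1; zero bias dropped
ad_c0 = add_cst + (re_reduced0 + ge_y0)  # both Add nodes: (2, 1) + (2, 3), then (3,) + (2, 3)
scores = numpy.sqrt(ad_c0)               # Sqrt
label = ad_c0.argmin(axis=1)             # ArgMin along the cluster axis
print(scores, label)
```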

Final graph

ax = plot_graphviz(oinf.to_dot())
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
[Figure: final ONNX graph of the pipeline]

