Speed up scikit-learn inference with ONNX

Is it possible to make scikit-learn faster with ONNX? That’s the question this example tries to answer. The scenario is the following:

  • a model is trained

  • it is converted into ONNX for inference

  • a runtime is chosen to compute the prediction

The following runtimes are tested:

  • python: the Python runtime for ONNX implemented in mlprodict

  • onnxruntime1: the whole ONNX graph is run with onnxruntime

  • numpy: the ONNX graph is converted into numpy code

  • numba: the numpy code is accelerated with numba.
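The scenario above can be pictured with a toy sketch in plain numpy: one exported graph, several engines that evaluate it. This is not how mlprodict, onnxruntime or numba work internally. The graph format (a list of `(operator, inputs, output)` tuples), the helpers `run_numpy` and `run_python`, and the two-node PCA graph are hypothetical and chosen only for illustration. Whatever the engine, it must return the same prediction; only the speed changes.

```python
import numpy as np

# Hypothetical toy graph format (not ONNX): a PCA transform exported as two
# nodes, (operator, input names, output name).
GRAPH = [
    ("Sub", ("X", "mean"), "D"),
    ("MatMul", ("D", "components_T"), "Y"),
]


def run_numpy(graph, inits, X):
    """Evaluate every node with vectorized numpy operators."""
    env = dict(inits, X=X)
    ops = {"Sub": np.subtract, "MatMul": np.matmul}
    for op, inputs, output in graph:
        env[output] = ops[op](*(env[name] for name in inputs))
    return env[graph[-1][2]]


def run_python(graph, inits, X):
    """Evaluate the same nodes with explicit Python loops (slow reference)."""
    env = {k: v.tolist() for k, v in inits.items()}
    env["X"] = X.tolist()

    def sub(a, b):
        # Subtract the 1-D vector b from every row of a.
        return [[x - y for x, y in zip(row, b)] for row in a]

    def matmul(a, b):
        cols = list(zip(*b))
        return [[sum(x * y for x, y in zip(row, col)) for col in cols]
                for row in a]

    ops = {"Sub": sub, "MatMul": matmul}
    for op, inputs, output in graph:
        env[output] = ops[op](*(env[name] for name in inputs))
    return np.array(env[graph[-1][2]])


# "Train" a PCA with an SVD of the centered data and export its parameters.
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 5))
mean = X.mean(axis=0)
_, _, vt = np.linalg.svd(X - mean, full_matrices=False)
inits = {"mean": mean, "components_T": vt[:2].T}  # keep 2 components

# Both engines compute the same prediction from the same graph.
y_fast = run_numpy(GRAPH, inits, X)
y_slow = run_python(GRAPH, inits, X)
print(np.allclose(y_fast, y_slow))  # True
```

The numpy engine dispatches each node to one vectorized call, while the loop-based engine pays Python overhead per element; that gap is the kind of difference the benchmark below measures between real runtimes.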

PCA

Let’s look at a very simple model, a PCA.

import numpy
from pandas import DataFrame
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.decomposition import PCA
from pyquickhelper.pycode.profiling import profile
from mlprodict.sklapi import OnnxSpeedupTransformer
from cpyquickhelper.numbers.speed_measure import measure_time
from tqdm import tqdm

Data and models to test.

data, _ = make_regression(1000, n_features=20)
data = data.astype(numpy.float32)
models = [
    ('sklearn', PCA(n_components=10)),
    ('python', OnnxSpeedupTransformer(
        PCA(n_components=10), runtime='python')),
    ('onnxruntime1', OnnxSpeedupTransformer(
        PCA(n_components=10), runtime='onnxruntime1')),
    ('numpy', OnnxSpeedupTransformer(
        PCA(n_components=10), runtime='numpy')),
    ('numba', OnnxSpeedupTransformer(
        PCA(n_components=10), runtime='numba'))]

Training.

for name, model in tqdm(models):
    model.fit(data)

Out:

  0%|          | 0/5 [00:00<?, ?it/s]
 20%|##        | 1/5 [00:00<00:00,  7.80it/s]
 60%|######    | 3/5 [00:00<00:00, 10.66it/s]
100%|##########| 5/5 [00:07<00:00,  1.83s/it]
100%|##########| 5/5 [00:07<00:00,  1.45s/it]

Profiling of runtime onnxruntime1.

def fct():
    for i in range(1000):
        models[2][1].transform(data)


res = profile(fct, pyinst_format="text")
print(res[1])

Out:

  _     ._   __/__   _ _  _  _ _/_   Recorded: 03:15:13 AM Samples:  505
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.639     CPU time: 2.550
/   _/                      v4.1.1

Program: /var/lib/jenkins/workspace/mlprodict/mlprodict_UT_39_std/_doc/examples/plot_speedup_pca.py

0.639 profile  ../pycode/profiling.py:457
`- 0.639 fct  plot_speedup_pca.py:67
      [13 frames hidden]  plot_speedup_pca, mlprodict
         0.631 run  mlprodict/onnxrt/ops_whole/session.py:98
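The `profile` helper from pyquickhelper relies on pyinstrument when `pyinst_format` is set. If those packages are not available, the standard library's `cProfile` gives a comparable, if more verbose, picture. The sketch below swaps in `cProfile` and, so that it runs on its own, profiles a hypothetical numpy stand-in `transform` instead of `models[2][1].transform`.

```python
import cProfile
import io
import pstats

import numpy as np

rng = np.random.default_rng(0)
data = rng.standard_normal((1000, 20)).astype(np.float32)
mean = data.mean(axis=0)
components_T = rng.standard_normal((20, 10)).astype(np.float32)


def transform(X):
    # Hypothetical stand-in for a fitted model's transform method.
    return (X - mean) @ components_T


def fct():
    for _ in range(1000):
        transform(data)


profiler = cProfile.Profile()
profiler.enable()
fct()
profiler.disable()

# Sort by cumulative time and print the 5 most expensive entries.
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream).sort_stats("cumulative")
stats.print_stats(5)
report = stream.getvalue()
print(report)
```

Unlike pyinstrument, which samples the call stack, `cProfile` records every Python-level call, so its overhead grows with the number of calls.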

Profiling of runtime numpy.

def fct():
    for i in range(1000):
        models[3][1].transform(data)


res = profile(fct, pyinst_format="text")
print(res[1])

Out:

  _     ._   __/__   _ _  _  _ _/_   Recorded: 03:15:14 AM Samples:  320
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.332     CPU time: 0.332
/   _/                      v4.1.1

Program: /var/lib/jenkins/workspace/mlprodict/mlprodict_UT_39_std/_doc/examples/plot_speedup_pca.py

0.331 profile  ../pycode/profiling.py:457
`- 0.331 fct  plot_speedup_pca.py:79
      [16 frames hidden]  plot_speedup_pca, mlprodict, sklearn,...
         0.280 numpy_mlprodict_ONNX_PCA  <string>:11
         |- 0.141 [self]
         |- 0.131 array  <built-in>:0

The class OnnxSpeedupTransformer converts the PCA into ONNX and then translates the ONNX graph into Python code based on numpy. The generated code is the following.

print(models[3][1].numpy_code_)

Out:

import numpy
import scipy.special as scipy_special
import scipy.spatial.distance as scipy_distance
from mlprodict.onnx_tools.exports.numpy_helper import (
    argmax_use_numpy_select_last_index,
    argmin_use_numpy_select_last_index,
    array_feature_extrator,
    make_slice)


def numpy_mlprodict_ONNX_PCA(X):
    '''
    Numpy function for ``mlprodict_ONNX_PCA``.

    * producer: skl2onnx
    * version: 0
    * description:
    '''
    # initializers

    list_value = [0.03522588685154915, -0.04391317069530487, -0.3019331097602844, -0.08857168257236481, -0.04125336557626724, -0.1372506469488144, -0.15545883774757385, -0.29013821482658386, -0.2170533686876297, -0.009242745116353035, 0.028533577919006348, 0.0418868213891983, 0.09391921758651733, -0.673380434513092, 0.4197906255722046, -0.36630380153656006, -0.21160711348056793, 0.028479784727096558, -0.07911509275436401, 0.05429431051015854, -0.016461068764328957, -0.08221270889043808, 0.3420103192329407, -0.18598555028438568, 0.15762969851493835, 0.46943679451942444, 0.3404766917228699, -0.3100897967815399, -0.18107417225837708, 0.08803647756576538, 0.39513328671455383, 0.03356841951608658, -0.05867309123277664, 0.4098595976829529, 0.26771920919418335, 0.025141030550003052, -0.0803273618221283, -0.28375062346458435, 0.056819431483745575, -0.21815301477909088, -0.31348809599876404, -0.24914886057376862, -0.08464746177196503, 0.04204845428466797, -0.25113722681999207, 0.09550014138221741, -0.4219156503677368, -0.13716453313827515, -0.20961827039718628, 0.06006816029548645, 0.02216215245425701, -0.3065590560436249, 0.23661229014396667, 0.20935629308223724, -0.1726744920015335, -0.36021891236305237, 0.024149619042873383, 0.06666494160890579, 0.1875750720500946, 0.5891869068145752, 0.052102379500865936, 0.1691839098930359, -0.010547593235969543, -0.05830360949039459, -0.2233782708644867, -0.04592195153236389, -0.07198181748390198, -0.3452170491218567, -0.17131690680980682, 0.1980271339416504, -0.1746671050786972, -0.09624757617712021, 0.15323825180530548, 0.44910112023353577, 0.5162782669067383, -0.21690776944160461, -0.05309027433395386, 0.08803343772888184, -0.20859375596046448, -0.06328219920396805, -0.3302276134490967, 0.1708548367023468, -0.06401528418064117, 0.04334565997123718, -0.09085354208946228, -0.3452633023262024, 0.11034107208251953, -0.32488223910331726, -0.18863177299499512, -0.38356223702430725, 0.4194985628128052, 0.07467803359031677, 
                  -0.2680917978286743, 0.05319248139858246, -0.011900395154953003, 0.16783447563648224, 0.13117912411689758, 0.010787688195705414, -0.15326648950576782, 0.15091925859451294,
                  0.05567193403840065, -0.5821548104286194, -0.005654362961649895, -0.13977614045143127, 0.08097386360168457, 0.20613563060760498, -0.2828896641731262, 0.12731046974658966, 0.06732460111379623, -0.20602290332317352, 0.0016411282122135162, 0.20149409770965576, 0.44979655742645264, 0.11433097720146179, -0.2531198263168335, 0.055678918957710266, -0.27453410625457764, 0.30891719460487366, -0.057148516178131104, -0.352786123752594, 0.2109539806842804, -0.17314523458480835, 0.5185893177986145, -0.07017163932323456, -0.2076755166053772, -0.25260186195373535, 0.20356816053390503, -0.36128267645835876, 0.10138076543807983, -0.16016840934753418, 0.3096393048763275, -0.32698506116867065, 0.02671171724796295, -0.014824658632278442, 0.02823743224143982, -0.10258353501558304, -0.11335937678813934, -0.12684054672718048, 0.0032808110117912292, -0.22259417176246643, 0.030816804617643356, 0.19896170496940613, 0.12073422968387604, -0.0027247369289398193, 0.2638765573501587, -0.11494152247905731, 0.15325583517551422, 0.20857945084571838, -0.06654155254364014, 0.0842096135020256, 0.17692816257476807, -0.0732412189245224, -0.11389152705669403, 0.028294242918491364, 0.1814122498035431, -0.07548843324184418, -0.13339780271053314, -0.17244917154312134, 0.13185881078243256, 0.047141265124082565, -0.4110371172428131, -0.008370316587388515, 0.07260086387395859, 0.15411722660064697, 0.26084113121032715, 0.15544359385967255, -0.1315050572156906, -0.3465263843536377, 0.22994205355644226, 0.21516229212284088, -0.2586739957332611, -0.41148707270622253, -0.19924192130565643, -0.03929010406136513, 0.018542705103754997, -0.02322215959429741, 0.5503065586090088, 0.1206894963979721, -0.04807346314191818, -0.21227966248989105, 0.07046424597501755, -0.09622500091791153, 0.21259407699108124, 0.06374454498291016, 0.049957338720560074, 0.1281668096780777, -0.062413424253463745, 0.05819839984178543, -0.7248937487602234, 0.1750917285680771, -0.08923415094614029, 0.1404252052307129, 
                  0.16485461592674255, -0.11322721838951111, 0.14091645181179047, 0.33257782459259033, -0.13305926322937012, -0.11384911835193634, 0.28106456995010376, -0.06577727943658829]
    B = numpy.array(list_value, dtype=numpy.float32).reshape((20, 10))

    list_value = [-0.011537947691977024, -0.053047001361846924, 0.00043868148350156844, 0.021356476470828056, 0.04350115731358528, -0.020801274105906487, -0.003591367742046714, 0.08736559003591537, -0.003987821284681559, 0.027182267978787422,
                  -0.008783889003098011, -0.05426081269979477, -0.02160467393696308, -0.02549884095788002, 0.03422679752111435, 0.007602499797940254, -0.029017562046647072, -0.06731828302145004, -0.015037109144032001, 0.030689967796206474]
    C = numpy.array(list_value, dtype=numpy.float32)

    # nodes

    D = X - C
    variable = D @ B

    return variable
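The generated function boils down to two operations: `D = X - C` centers the inputs, and `D @ B` projects them on the principal axes. Read with scikit-learn's attribute names, `C` plays the role of `mean_` and `B` the role of `components_.T` (shape (20, 10)), which matches `PCA.transform` when `whiten=False`. This is an interpretation of the generated code, not something it states. The numpy-only sketch below, built around a hypothetical `fit_pca` helper based on an SVD, checks that this projection reproduces the usual PCA scores.

```python
import numpy as np


def fit_pca(X, n_components):
    """Hypothetical minimal PCA fit returning mean, axes and reference scores."""
    mean = X.mean(axis=0)
    # Rows of vt are the principal axes, sorted by decreasing singular value.
    u, s, vt = np.linalg.svd(X - mean, full_matrices=False)
    scores = u[:, :n_components] * s[:n_components]
    return mean, vt[:n_components], scores


def generated_transform(X, C, B):
    # Same two nodes as numpy_mlprodict_ONNX_PCA: subtraction then matmul.
    D = X - C
    return D @ B


rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 20)).astype(np.float32)
mean, components, scores = fit_pca(X.astype(np.float64), n_components=10)

# C is the mean, B is the transposed component matrix of shape (20, 10).
C = mean.astype(np.float32)
B = components.T.astype(np.float32)
Y = generated_transform(X, C, B)

print(Y.shape)  # (1000, 10)
print(np.allclose(Y, scores, atol=1e-4))  # True
```

Because the whole model is a subtraction and one matrix product, most of the runtime for large batches is spent in the matrix multiplication, which helps explain why the runtimes converge as the batch size grows.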

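Benchmark. Every model is timed on batches from 1 to 200,000 rows; the table reports the average duration of one call to transform, in seconds.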

bench = []
for name, model in tqdm(models):
    for size in (1, 10, 100, 1000, 10000, 100000, 200000):
        data, _ = make_regression(size, n_features=20)
        data = data.astype(numpy.float32)

        # We run it a first time (numba compiles
        # the function during the first execution).
        model.transform(data)
        res = measure_time(
            lambda: model.transform(data), div_by_number=True,
            context={'data': data, 'model': model})
        res['name'] = name
        res['size'] = size
        bench.append(res)

df = DataFrame(bench)
piv = df.pivot(index="size", columns="name", values="average")
piv

Out:

  0%|          | 0/5 [00:00<?, ?it/s]
 20%|##        | 1/5 [00:39<02:37, 39.33s/it]
 40%|####      | 2/5 [01:07<01:38, 32.96s/it]
 60%|######    | 3/5 [01:21<00:48, 24.17s/it]
 80%|########  | 4/5 [01:47<00:25, 25.07s/it]
100%|##########| 5/5 [02:19<00:00, 27.40s/it]
100%|##########| 5/5 [02:19<00:00, 27.90s/it]
name       numba     numpy  onnxruntime1    python   sklearn
size
1       0.000026  0.000103      0.000297  0.000195  0.000280
10      0.000031  0.000109      0.000295  0.000180  0.000289
100     0.000048  0.000129      0.000335  0.000191  0.000326
1000    0.000193  0.000285      0.000539  0.000356  0.000586
10000   0.001769  0.001876      0.001425  0.002232  0.002843
100000  0.013270  0.016511      0.007728  0.017223  0.023993
200000  0.028003  0.031735      0.014700  0.034493  0.048167
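`measure_time` is called with `div_by_number=True`, which, judging by its name, divides each measured duration by the number of calls it covers so that `average` is the time of a single call. A rough stand-in built on the standard library's `timeit.repeat` is sketched below; the helper `average_call_time`, its defaults and the returned keys are assumptions, not cpyquickhelper's implementation. The warm-up call mirrors the one in the benchmark loop, which keeps numba's compilation out of the timings.

```python
import timeit

import numpy as np


def average_call_time(fn, repeat=10, number=50, warmup=True):
    """Hypothetical stand-in for measure_time(..., div_by_number=True).

    Returns the mean and standard deviation of a single call, in seconds.
    """
    if warmup:
        fn()  # excluded from timing: JIT compilers such as numba compile here
    totals = timeit.repeat(fn, repeat=repeat, number=number)
    per_call = np.array(totals) / number  # seconds per single call
    return {"average": float(per_call.mean()),
            "deviation": float(per_call.std()),
            "repeat": repeat, "number": number}


rng = np.random.default_rng(0)
components_T = rng.standard_normal((20, 10)).astype(np.float32)
results = []
for size in (1, 100, 10000):
    data = rng.standard_normal((size, 20)).astype(np.float32)
    res = average_call_time(lambda: data @ components_T)
    res["size"] = size
    results.append(res)

for res in results:
    print(res["size"], f"{res['average']:.2e}")
```

Timing batches of `number` calls rather than single calls keeps the timer's own resolution from dominating the measurement for very fast operations such as a one-row transform.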


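Graph. The left plot shows the raw timings; the right plot divides every timing by the scikit-learn timing for the same batch size, so values below 1 mean the runtime is faster than scikit-learn.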

fig, ax = plt.subplots(1, 2, figsize=(10, 4))
piv.plot(title="Speedup PCA with ONNX (lower better)",
         logx=True, logy=True, ax=ax[0])
piv2 = piv.copy()
for c in piv2.columns:
    piv2[c] /= piv['sklearn']
print(piv2)
piv2.plot(title="baseline=scikit-learn (lower better)",
          logx=True, logy=True, ax=ax[1])
plt.show()
[Figure: two log-log plots, left panel "Speedup PCA with ONNX (lower better)", right panel "baseline=scikit-learn (lower better)"]

Out:

name       numba     numpy  onnxruntime1    python  sklearn
size
1       0.094160  0.367328      1.062296  0.696770      1.0
10      0.107015  0.376076      1.021664  0.622485      1.0
100     0.147017  0.397408      1.028905  0.587404      1.0
1000    0.329933  0.486672      0.919924  0.607708      1.0
10000   0.622233  0.659963      0.501087  0.784828      1.0
100000  0.553091  0.688183      0.322086  0.717848      1.0
200000  0.581378  0.658858      0.305190  0.716115      1.0
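The plotting code normalizes `piv` with an explicit loop over its columns. The same ratio table can be computed in one call with `DataFrame.div(..., axis=0)`, which divides each row by the scikit-learn timing for that size. The sketch below uses made-up timings (not the measurements above) in the same long format as `bench`, and passes `index`, `columns` and `values` to `pivot` as keywords, which pandas 2 requires.

```python
import pandas as pd

# Made-up timings in the same long format as the benchmark records.
bench = [
    {"name": "sklearn", "size": 1, "average": 4.0e-4},
    {"name": "numpy", "size": 1, "average": 1.0e-4},
    {"name": "sklearn", "size": 1000, "average": 8.0e-4},
    {"name": "numpy", "size": 1000, "average": 6.0e-4},
]
df = pd.DataFrame(bench)

# One row per batch size, one column per runtime.
piv = df.pivot(index="size", columns="name", values="average")

# Divide every column by the baseline; axis=0 aligns the divisor on the index.
ratio = piv.div(piv["sklearn"], axis=0)
print(ratio)
```

With these toy numbers the numpy column becomes 0.25 and 0.75, meaning that runtime takes a quarter, then three quarters, of scikit-learn's time, while the sklearn column is 1.0 by construction.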

Total running time of the script: ( 2 minutes 31.004 seconds)

Gallery generated by Sphinx-Gallery