Speed up scikit-learn inference with ONNX
Is it possible to make scikit-learn faster with ONNX? That's the question this example tries to answer. The scenario is the following:
a model is trained
it is converted into ONNX for inference
a runtime is selected to compute the predictions (a minimal sketch of these steps follows the list of runtimes below)
The following runtimes are tested:
python: python runtime for ONNX
onnxruntime1: onnxruntime
numpy: the ONNX graph is converted into numpy code
numba: the numpy code is accelerated with numba.
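To make the scenario concrete, here is a minimal sketch of the three steps written by hand, without the wrappers used in this example. It assumes skl2onnx and onnxruntime are installed; the class OnnxSpeedupTransformer introduced below automates exactly this.

import numpy
from sklearn.datasets import make_regression
from sklearn.decomposition import PCA
from skl2onnx import to_onnx
from onnxruntime import InferenceSession

data, _ = make_regression(1000, n_features=20)
data = data.astype(numpy.float32)

# 1. a model is trained
model = PCA(n_components=10).fit(data)

# 2. it is converted into ONNX for inference
onx = to_onnx(model, data[:1])

# 3. a runtime is selected to compute the predictions
# (providers is required by recent versions of onnxruntime)
sess = InferenceSession(onx.SerializeToString(),
                        providers=['CPUExecutionProvider'])
got = sess.run(None, {sess.get_inputs()[0].name: data})[0]
assert numpy.allclose(got, model.transform(data), atol=1e-4)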
PCA
Let’s look at a very simple model, a PCA.
import numpy
from pandas import DataFrame
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.decomposition import PCA
from pyquickhelper.pycode.profiling import profile
from mlprodict.sklapi import OnnxSpeedupTransformer
from cpyquickhelper.numbers.speed_measure import measure_time
from tqdm import tqdm
Data and models to test.
data, _ = make_regression(1000, n_features=20)
data = data.astype(numpy.float32)
models = [
    ('sklearn', PCA(n_components=10)),
    ('python', OnnxSpeedupTransformer(
        PCA(n_components=10), runtime='python')),
    ('onnxruntime1', OnnxSpeedupTransformer(
        PCA(n_components=10), runtime='onnxruntime1')),
    ('numpy', OnnxSpeedupTransformer(
        PCA(n_components=10), runtime='numpy')),
    ('numba', OnnxSpeedupTransformer(
        PCA(n_components=10), runtime='numba'))]
Training.
for name, model in tqdm(models):
    model.fit(data)
Out:
0%| | 0/5 [00:00<?, ?it/s]
20%|## | 1/5 [00:00<00:00, 7.80it/s]
60%|###### | 3/5 [00:00<00:00, 10.66it/s]
100%|##########| 5/5 [00:07<00:00, 1.83s/it]
100%|##########| 5/5 [00:07<00:00, 1.45s/it]
Profiling of runtime onnxruntime1.
def fct():
    for i in range(1000):
        models[2][1].transform(data)
res = profile(fct, pyinst_format="text")
print(res[1])
Out:
_ ._ __/__ _ _ _ _ _/_ Recorded: 03:15:13 AM Samples: 505
/_//_/// /_\ / //_// / //_'/ // Duration: 0.639 CPU time: 2.550
/ _/ v4.1.1
Program: /var/lib/jenkins/workspace/mlprodict/mlprodict_UT_39_std/_doc/examples/plot_speedup_pca.py
0.639 profile ../pycode/profiling.py:457
`- 0.639 fct plot_speedup_pca.py:67
[13 frames hidden] plot_speedup_pca, mlprodict
0.631 run mlprodict/onnxrt/ops_whole/session.py:98
Profiling of runtime numpy.
def fct():
    for i in range(1000):
        models[3][1].transform(data)
res = profile(fct, pyinst_format="text")
print(res[1])
Out:
_ ._ __/__ _ _ _ _ _/_ Recorded: 03:15:14 AM Samples: 320
/_//_/// /_\ / //_// / //_'/ // Duration: 0.332 CPU time: 0.332
/ _/ v4.1.1
Program: /var/lib/jenkins/workspace/mlprodict/mlprodict_UT_39_std/_doc/examples/plot_speedup_pca.py
0.331 profile ../pycode/profiling.py:457
`- 0.331 fct plot_speedup_pca.py:79
[16 frames hidden] plot_speedup_pca, mlprodict, sklearn,...
0.280 numpy_mlprodict_ONNX_PCA <string>:11
|- 0.141 [self]
|- 0.131 array <built-in>:0
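In both profiles, almost all of the time is concentrated in a single call: onnxruntime's session run on one side, the generated numpy function on the other, where array creation accounts for roughly half the cost.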
The class OnnxSpeedupTransformer converts the PCA into ONNX and then converts the ONNX graph into Python code based on numpy. The code is the following.
print(models[3][1].numpy_code_)
Out:
import numpy
import scipy.special as scipy_special
import scipy.spatial.distance as scipy_distance
from mlprodict.onnx_tools.exports.numpy_helper import (
argmax_use_numpy_select_last_index,
argmin_use_numpy_select_last_index,
array_feature_extrator,
make_slice)
def numpy_mlprodict_ONNX_PCA(X):
    '''
    Numpy function for ``mlprodict_ONNX_PCA``.

    * producer: skl2onnx
    * version: 0
    * description:
    '''
    # initializers
    list_value = [0.03522588685154915, -0.04391317069530487, -0.3019331097602844, -0.08857168257236481, -0.04125336557626724, -0.1372506469488144, -0.15545883774757385, -0.29013821482658386, -0.2170533686876297, -0.009242745116353035, 0.028533577919006348, 0.0418868213891983, 0.09391921758651733, -0.673380434513092, 0.4197906255722046, -0.36630380153656006, -0.21160711348056793, 0.028479784727096558, -0.07911509275436401, 0.05429431051015854, -0.016461068764328957, -0.08221270889043808, 0.3420103192329407, -0.18598555028438568, 0.15762969851493835, 0.46943679451942444, 0.3404766917228699, -0.3100897967815399, -0.18107417225837708, 0.08803647756576538, 0.39513328671455383, 0.03356841951608658, -0.05867309123277664, 0.4098595976829529, 0.26771920919418335, 0.025141030550003052, -0.0803273618221283, -0.28375062346458435, 0.056819431483745575, -0.21815301477909088, -0.31348809599876404, -0.24914886057376862, -0.08464746177196503, 0.04204845428466797, -0.25113722681999207, 0.09550014138221741, -0.4219156503677368, -0.13716453313827515, -0.20961827039718628, 0.06006816029548645, 0.02216215245425701, -0.3065590560436249, 0.23661229014396667, 0.20935629308223724, -0.1726744920015335, -0.36021891236305237, 0.024149619042873383, 0.06666494160890579, 0.1875750720500946, 0.5891869068145752, 0.052102379500865936, 0.1691839098930359, -0.010547593235969543, -0.05830360949039459, -0.2233782708644867, -0.04592195153236389, -0.07198181748390198, -0.3452170491218567, -0.17131690680980682, 0.1980271339416504, -0.1746671050786972, -0.09624757617712021, 0.15323825180530548, 0.44910112023353577, 0.5162782669067383, -0.21690776944160461, -0.05309027433395386, 0.08803343772888184, -0.20859375596046448, -0.06328219920396805, -0.3302276134490967, 0.1708548367023468, -0.06401528418064117, 0.04334565997123718, -0.09085354208946228, -0.3452633023262024, 0.11034107208251953, -0.32488223910331726, -0.18863177299499512, -0.38356223702430725, 0.4194985628128052, 0.07467803359031677, -0.2680917978286743, 0.05319248139858246, -0.011900395154953003, 0.16783447563648224, 0.13117912411689758, 0.010787688195705414, -0.15326648950576782, 0.15091925859451294,
        0.05567193403840065, -0.5821548104286194, -0.005654362961649895, -0.13977614045143127, 0.08097386360168457, 0.20613563060760498, -0.2828896641731262, 0.12731046974658966, 0.06732460111379623, -0.20602290332317352, 0.0016411282122135162, 0.20149409770965576, 0.44979655742645264, 0.11433097720146179, -0.2531198263168335, 0.055678918957710266, -0.27453410625457764, 0.30891719460487366, -0.057148516178131104, -0.352786123752594, 0.2109539806842804, -0.17314523458480835, 0.5185893177986145, -0.07017163932323456, -0.2076755166053772, -0.25260186195373535, 0.20356816053390503, -0.36128267645835876, 0.10138076543807983, -0.16016840934753418, 0.3096393048763275, -0.32698506116867065, 0.02671171724796295, -0.014824658632278442, 0.02823743224143982, -0.10258353501558304, -0.11335937678813934, -0.12684054672718048, 0.0032808110117912292, -0.22259417176246643, 0.030816804617643356, 0.19896170496940613, 0.12073422968387604, -0.0027247369289398193, 0.2638765573501587, -0.11494152247905731, 0.15325583517551422, 0.20857945084571838, -0.06654155254364014, 0.0842096135020256, 0.17692816257476807, -0.0732412189245224, -0.11389152705669403, 0.028294242918491364, 0.1814122498035431, -0.07548843324184418, -0.13339780271053314, -0.17244917154312134, 0.13185881078243256, 0.047141265124082565, -0.4110371172428131, -0.008370316587388515, 0.07260086387395859, 0.15411722660064697, 0.26084113121032715, 0.15544359385967255, -0.1315050572156906, -0.3465263843536377, 0.22994205355644226, 0.21516229212284088, -0.2586739957332611, -0.41148707270622253, -0.19924192130565643, -0.03929010406136513, 0.018542705103754997, -0.02322215959429741, 0.5503065586090088, 0.1206894963979721, -0.04807346314191818, -0.21227966248989105, 0.07046424597501755, -0.09622500091791153, 0.21259407699108124, 0.06374454498291016, 0.049957338720560074, 0.1281668096780777, -0.062413424253463745, 0.05819839984178543, -0.7248937487602234, 0.1750917285680771, -0.08923415094614029, 0.1404252052307129, 0.16485461592674255, -0.11322721838951111, 0.14091645181179047, 0.33257782459259033, -0.13305926322937012, -0.11384911835193634, 0.28106456995010376, -0.06577727943658829]
    B = numpy.array(list_value, dtype=numpy.float32).reshape((20, 10))
    list_value = [-0.011537947691977024, -0.053047001361846924, 0.00043868148350156844, 0.021356476470828056, 0.04350115731358528, -0.020801274105906487, -0.003591367742046714, 0.08736559003591537, -0.003987821284681559, 0.027182267978787422,
        -0.008783889003098011, -0.05426081269979477, -0.02160467393696308, -0.02549884095788002, 0.03422679752111435, 0.007602499797940254, -0.029017562046647072, -0.06731828302145004, -0.015037109144032001, 0.030689967796206474]
    C = numpy.array(list_value, dtype=numpy.float32)
    # nodes
    D = X - C
    variable = D @ B

    return variable
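The generated function is nothing more than the PCA formula: center with C (the mean) and project with B (the transposed components). A quick sanity check, not part of the original example, that the formula matches scikit-learn:

import numpy
from sklearn.datasets import make_regression
from sklearn.decomposition import PCA

data, _ = make_regression(1000, n_features=20)
data = data.astype(numpy.float32)
pca = PCA(n_components=10).fit(data)

# PCA.transform (without whitening) computes (X - mean_) @ components_.T,
# which is exactly D = X - C followed by variable = D @ B above.
expected = (data - pca.mean_) @ pca.components_.T
assert numpy.allclose(pca.transform(data), expected, atol=1e-4)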
Benchmark.
bench = []
for name, model in tqdm(models):
    for size in (1, 10, 100, 1000, 10000, 100000, 200000):
        data, _ = make_regression(size, n_features=20)
        data = data.astype(numpy.float32)
        # We run it a first time (numba compiles
        # the function during the first execution).
        model.transform(data)
        res = measure_time(
            lambda: model.transform(data), div_by_number=True,
            context={'data': data, 'model': model})
        res['name'] = name
        res['size'] = size
        bench.append(res)
df = DataFrame(bench)
piv = df.pivot("size", "name", "average")
piv
Out:
0%| | 0/5 [00:00<?, ?it/s]
20%|## | 1/5 [00:39<02:37, 39.33s/it]
40%|#### | 2/5 [01:07<01:38, 32.96s/it]
60%|###### | 3/5 [01:21<00:48, 24.17s/it]
80%|######## | 4/5 [01:47<00:25, 25.07s/it]
100%|##########| 5/5 [02:19<00:00, 27.40s/it]
100%|##########| 5/5 [02:19<00:00, 27.90s/it]
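Each call to measure_time returns a dictionary of statistics, including the key 'average' used by the pivot, so the list of results converts directly into a DataFrame.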
Graph.
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
piv.plot(title="Speedup PCA with ONNX (lower better)",
         logx=True, logy=True, ax=ax[0])
piv2 = piv.copy()
for c in piv2.columns:
    piv2[c] /= piv['sklearn']
print(piv2)
piv2.plot(title="baseline=scikit-learn (lower better)",
          logx=True, logy=True, ax=ax[1])
plt.show()
Out:
name numba numpy onnxruntime1 python sklearn
size
1 0.094160 0.367328 1.062296 0.696770 1.0
10 0.107015 0.376076 1.021664 0.622485 1.0
100 0.147017 0.397408 1.028905 0.587404 1.0
1000 0.329933 0.486672 0.919924 0.607708 1.0
10000 0.622233 0.659963 0.501087 0.784828 1.0
100000 0.553091 0.688183 0.322086 0.717848 1.0
200000 0.581378 0.658858 0.305190 0.716115 1.0
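Ratios below 1 mean faster than scikit-learn. In this run, numba is the best choice for small batches (about ten times faster on a single observation) while onnxruntime1 takes over on large batches (about three times faster beyond 100000 observations).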
Total running time of the script: (2 minutes 31.004 seconds)