Benchmark inference for scikit-learn models
This short example compares the inference time of a couple of runtimes, including onnxruntime. It follows the example Measure ONNX runtime performances, which describes an automated process for comparing the performance of a converted model against scikit-learn. The model benchmarked here is a simple one, picked from the many models scikit-learn implements.
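At its core, the comparison times the same prediction under two implementations and divides one time by the other. The following is a minimal pure-Python sketch of that idea for a linear model, y = Xw + b. The functions `predict_reference` and `predict_candidate` are hypothetical stand-ins for scikit-learn and an alternative runtime; they are not mlprodict or onnxruntime code. Taking the best of several `timeit.repeat` measurements is a choice made for this sketch only.

```python
import random
import timeit


def predict_reference(X, coef, intercept):
    """Linear regression inference, y = X @ coef + intercept, with explicit loops."""
    out = []
    for row in X:
        acc = intercept
        for x, w in zip(row, coef):
            acc += x * w
        out.append(acc)
    return out


def predict_candidate(X, coef, intercept):
    """Same prediction using the built-in sum (hypothetical 'alternative runtime')."""
    return [intercept + sum(x * w for x, w in zip(row, coef)) for row in X]


def relative_time(candidate, reference, args, number=20, repeat=5):
    """Best-of-repeat time of candidate divided by reference (< 1 means candidate is faster)."""
    t_cand = min(timeit.repeat(lambda: candidate(*args), number=number, repeat=repeat))
    t_ref = min(timeit.repeat(lambda: reference(*args), number=number, repeat=repeat))
    return t_cand / t_ref


if __name__ == "__main__":
    rng = random.Random(0)
    n_features = 10
    X = [[rng.random() for _ in range(n_features)] for _ in range(100)]
    coef = [rng.random() for _ in range(n_features)]
    intercept = 0.5
    # Both implementations must agree before their speeds are compared.
    ref = predict_reference(X, coef, intercept)
    cand = predict_candidate(X, coef, intercept)
    assert all(abs(a - b) < 1e-9 for a, b in zip(ref, cand))
    ratio = relative_time(predict_candidate, predict_reference, (X, coef, intercept))
    print("candidate/reference time ratio:", ratio)
```

Checking that both implementations agree before timing them matters: a faster runtime is only useful if it returns the same predictions, up to floating-point tolerance.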
Linear Regression
from pandas import read_csv
from mlprodict.cli import validate_runtime
from mlprodict.plotting.plotting import plot_validate_benchmark

# Benchmark LinearRegression with the 'python' and 'onnxruntime1' runtimes
# against scikit-learn, for 10 and 50 features, float32 inputs and opset 15.
res = validate_runtime(
    verbose=1,
    out_raw="data.csv", out_summary="summary.csv",
    benchmark=True, dump_folder="dump_errors",
    runtime=['python', 'onnxruntime1'],
    models=['LinearRegression'],
    skip_models=['LinearRegression[m-reg]'],  # skip the 'm-reg' problem variant
    n_features=[10, 50], dtype="32",
    out_graph="bench.png",
    opset_min=15, opset_max=15,
    # batch size N -> timeit-style "number" and "repeat" for that batch size
    time_kwargs={
        1: {"number": 50, "repeat": 50},
        10: {"number": 25, "repeat": 25},
        100: {"number": 20, "repeat": 20},
        1000: {"number": 20, "repeat": 20},
        10000: {"number": 10, "repeat": 10},
    }
)

# Load the summary written to summary.csv and display it.
results = read_csv('summary.csv')
results
Out:
time_kwargs={1: {'number': 50, 'repeat': 50}, 10: {'number': 25, 'repeat': 25}, 100: {'number': 20, 'repeat': 20}, 1000: {'number': 20, 'repeat': 20}, 10000: {'number': 10, 'repeat': 10}}
[enumerate_validated_operator_opsets] opset in [15, 15].
0%| | 0/1 [00:00<?, ?it/s]
LinearRegression : 0%| | 0/1 [00:00<?, ?it/s][enumerate_compatible_opset] opset in [15, 15].
LinearRegression : 100%|##########| 1/1 [00:40<00:00, 40.77s/it]
LinearRegression : 100%|##########| 1/1 [00:40<00:00, 40.78s/it]
Saving raw_data into 'data.csv'.
Saving summary into 'summary.csv'.
Saving graph into 'bench.png'.
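The `time_kwargs` dictionary maps a batch size N to the `number` and `repeat` values used when timing that batch size; the values chosen above shrink as N grows, which keeps the total time spent on the larger, slower batches in check. The sketch below assumes these values behave like the parameters of Python's `timeit.repeat` and shows how such a dictionary could drive a per-batch-size benchmark. `predict` and `benchmark_by_batch_size` are hypothetical helpers, not part of mlprodict, and reporting the mean time per call is an assumption of this sketch.

```python
import random
import statistics
import timeit

# Same shape as the time_kwargs argument above (truncated to three batch sizes).
# Treating "number"/"repeat" like timeit.repeat's parameters is an assumption.
TIME_KWARGS = {
    1: {"number": 50, "repeat": 50},
    10: {"number": 25, "repeat": 25},
    100: {"number": 20, "repeat": 20},
}


def predict(X, coef, intercept):
    """Toy linear model standing in for the benchmarked runtime (hypothetical)."""
    return [intercept + sum(x * w for x, w in zip(row, coef)) for row in X]


def benchmark_by_batch_size(fn, coef, intercept, time_kwargs, seed=0):
    """Return {N: mean seconds per call} for random batches of N rows."""
    rng = random.Random(seed)
    results = {}
    for n_rows, kw in sorted(time_kwargs.items()):
        X = [[rng.random() for _ in coef] for _ in range(n_rows)]
        # Each entry of `totals` is the time taken by kw["number"] consecutive calls.
        totals = timeit.repeat(lambda: fn(X, coef, intercept),
                               number=kw["number"], repeat=kw["repeat"])
        results[n_rows] = statistics.mean(t / kw["number"] for t in totals)
    return results


if __name__ == "__main__":
    timings = benchmark_by_batch_size(predict, [0.1] * 10, 0.0, TIME_KWARGS)
    for n, seconds in timings.items():
        print(f"N={n}: {seconds * 1e6:.1f} microseconds per call")
```

Running the same loop for two implementations and dividing their timings for each N gives one ratio per batch size, which is the kind of per-N comparison the graph below displays.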
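Graph. The benchmark results are plotted with one panel per batch size N; the RT/SKL label in the first panel's title indicates that each panel compares a runtime's inference time with scikit-learn's.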
_, ax = plot_validate_benchmark(results)
ax
# import matplotlib.pyplot as plt
# plt.show()
Out:
array([<AxesSubplot:title={'center':'RT/SKL-N=1'}>,
       <AxesSubplot:title={'center':'N=10'}>,
       <AxesSubplot:title={'center':'N=100'}>,
       <AxesSubplot:title={'center':'N=1000'}>,
       <AxesSubplot:title={'center':'N=10000'}>], dtype=object)
Total running time of the script: (0 minutes 57.007 seconds)