Coverage for mlprodict/onnxrt/validate/validate_benchmark_replay.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Measures time processing for ONNX models.
4"""
5import pickle
6import os
7import sklearn
8from ...tools.ort_wrapper import InferenceSession
9from .. import OnnxInference
10from .validate_helper import default_time_kwargs, measure_time, _multiply_time_kwargs
11from .validate_benchmark import make_n_rows
class SimplifiedOnnxInference:
    """
    Thin wrapper around :class:`InferenceSession
    <mlprodict.tools.ort_wrapper.InferenceSession>` mimicking the small part
    of the @see cl OnnxInference API needed to replay a benchmark: the
    property *input_names* and the method *run*. It is meant to run with
    *CPUExecutionProvider* only (this relies on the providers selected by
    :class:`InferenceSession`).

    :param ort: model handed over to :class:`InferenceSession`
        (the benchmark replay passes the serialized ONNX bytes)
    :param runtime: see :class:`InferenceSession
        <mlprodict.tools.ort_wrapper.InferenceSession>`
    """

    def __init__(self, ort, runtime='onnxruntime'):
        session = InferenceSession(ort, runtime=runtime)
        self.sess = session

    @property
    def input_names(self):
        "Returns the names of the inputs declared by the wrapped session."
        declared = self.sess.get_inputs()
        return [arg.name for arg in declared]

    def run(self, input):  # pylint: disable=W0622
        """
        Forwards the call to ``InferenceSession.run``.

        :param input: dictionary mapping input names to values
        :return: whatever ``InferenceSession.run`` returns when ``None``
            is given as the list of requested output names
        """
        requested_outputs = None
        return self.sess.run(requested_outputs, input)
def enumerate_benchmark_replay(folder, runtime='python', time_kwargs=None,
                               skip_long_test=True, time_kwargs_fact=None,
                               time_limit=4, verbose=1, fLOG=None):
    """
    Replays a benchmark stored with function
    :func:`enumerate_validated_operator_opsets
    <mlprodict.onnxrt.validate.validate.enumerate_validated_operator_opsets>`
    or command line :ref:`validate_runtime <l-cmd-validate_runtime>`.
    Enumerates the results.

    Every pickled file must hold a dictionary with keys ``'X_test'``,
    ``'Xort_test'``, ``'onnx_bytes'``, ``'skl_model'`` and ``'obs_op'``.
    For every entry of the time settings, the scikit-learn method and every
    requested runtime are timed, and one dictionary per file is yielded.

    @param      folder              folder where to find pickled files, all files must have
                                    *pkl* or *pickle* extension; files whose name
                                    contains ``ERROR`` are skipped
    @param      runtime             runtime or runtimes (list or comma-separated string)
    @param      time_kwargs         to define a more precise way to measure a model,
                                    ``default_time_kwargs()`` is used if None or empty
    @param      skip_long_test      skips tests for high values of N if they seem too long
                                    (NOTE: not read by the current implementation)
    @param      time_kwargs_fact    see :func:`_multiply_time_kwargs <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>`
    @param      time_limit          to skip the rest of the test after this limit (in second)
                                    (NOTE: not read by the current implementation)
    @param      verbose             if >= 1, uses :epkg:`tqdm`
    @param      fLOG                logging function
    @return                         iterator on results
    @raises     FileNotFoundError   if *folder* contains no pickled file

    .. warning::
        Files are loaded with :mod:`pickle`, which can execute arbitrary
        code: only replay folders you trust.
    """
    # Both documented extensions are accepted (the previous filter only
    # matched names ending with "_.pickle"). Sorting makes the replay order
    # deterministic since os.listdir returns entries in arbitrary order.
    files = sorted(name for name in os.listdir(folder)
                   if name.endswith((".pkl", ".pickle")))
    if not files:
        raise FileNotFoundError(
            "Unable to find any file in folder '{}'.".format(folder))

    # Imported after the folder check so an empty folder is reported as
    # such even where onnxruntime is not installed.
    from onnxruntime.capi._pybind_state import Fail as OrtFail  # pylint: disable=E0611

    if time_kwargs in (None, ''):
        time_kwargs = default_time_kwargs()

    if isinstance(runtime, str):
        runtime = runtime.split(",")

    loop = files
    if verbose >= 1:
        try:
            from tqdm import tqdm
            loop = tqdm(files)
        except ImportError:  # pragma: no cover
            pass

    for pkl in loop:
        if "ERROR" in pkl:
            # The file name flags an error: there is nothing to replay.
            if verbose >= 2 and fLOG is not None:  # pragma: no cover
                fLOG(  # pragma: no cover
                    "[enumerate_benchmark_replay] skip '{}'.".format(pkl))
            continue  # pragma: no cover
        if verbose >= 2 and fLOG is not None:
            fLOG("[enumerate_benchmark_replay] process '{}'.".format(pkl))

        # NOTE(security): pickle.load may execute arbitrary code.
        with open(os.path.join(folder, pkl), 'rb') as f:
            obj = pickle.load(f)
        X_test = obj['X_test']
        ort_test = obj['Xort_test']
        onx = obj['onnx_bytes']
        model = obj['skl_model']
        tkw = _multiply_time_kwargs(time_kwargs, time_kwargs_fact, model)

        row = {
            'folder': folder,
            'filename': pkl,
            # overwritten just below by the value stored in obs_op
            'n_features': X_test.shape[1],
        }
        for key in ['assume_finite', 'conv_options',
                    'init_types', 'idtype', 'method_name', 'n_features',
                    'name', 'optim', 'opset', 'predict_kwargs',
                    'output_index', 'problem', 'scenario']:
            row[key] = obj['obs_op'][key]

        # One inference object per runtime; None when the model cannot be
        # loaded, in which case the error is kept in row['ERROR'].
        oinfs = {}
        for rt in runtime:
            try:
                if rt == 'onnxruntime':
                    oinfs[rt] = SimplifiedOnnxInference(onx)
                else:
                    oinfs[rt] = OnnxInference(
                        onx, runtime=rt,
                        runtime_options=dict(log_severity_level=3))
            except (OrtFail, RuntimeError) as e:  # pragma: no cover
                row['ERROR'] = str(e)
                oinfs[rt] = None

        for k, v in sorted(tkw.items()):
            if verbose >= 3 and fLOG is not None:
                fLOG(  # pragma: no cover
                    "[enumerate_benchmark_replay] process n_rows={} - {}".format(k, v))
            xt = make_n_rows(X_test, k)
            number = v['number']
            repeat = v['repeat']

            meth = getattr(model, row['method_name'])
            with sklearn.config_context(assume_finite=row['assume_finite']):
                skl = measure_time(lambda x: meth(x), xt,
                                   number=number, repeat=repeat,
                                   div_by_number=True)
            if verbose >= 4 and fLOG is not None:
                fLOG(  # pragma: no cover
                    "[enumerate_benchmark_replay] skl={}".format(skl))
            row['%d-skl-details' % k] = skl
            row['%d-skl' % k] = skl['average']

            xto = make_n_rows(ort_test, k)
            for rt in runtime:
                oinf = oinfs[rt]
                if oinf is None:
                    continue  # pragma: no cover
                if len(oinf.input_names) != 1:
                    raise NotImplementedError(  # pragma: no cover
                        "This function only allows one input not {}".format(
                            len(oinf.input_names)))
                name = oinf.input_names[0]
                # The lambda is consumed by measure_time before the loop
                # moves on, so capturing oinf/name by closure is safe here.
                ort = measure_time(lambda x: oinf.run({name: x}), xto,
                                   number=number, repeat=repeat,
                                   div_by_number=True)
                if verbose >= 4 and fLOG is not None:
                    fLOG(  # pragma: no cover
                        "[enumerate_benchmark_replay] {}={}".format(rt, ort))
                # NOTE: suffix is '-detail' here but '-details' for skl;
                # kept unchanged because consumers may rely on these keys.
                row['%d-%s-detail' % (k, rt)] = ort
                row['%d-%s' % (k, rt)] = ort['average']
        yield row