Coverage for mlprodict/cli/validate.py: 95%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Command line about validation of prediction runtime.
4"""
5import os
6from io import StringIO
7from logging import getLogger
8import warnings
9import json
10from multiprocessing import Pool
11from pandas import DataFrame, read_csv, concat
12from sklearn.exceptions import ConvergenceWarning
def benchmark_doc(runtime, black_list=None, white_list=None,
                  out_raw='bench_raw.xlsx', out_summary="bench_summary.xlsx",
                  dump_dir='dump', fLOG=print, verbose=0):
    """
    Runs the benchmark published into the documentation
    (see :ref:`l-onnx-bench-onnxruntime1` and
    :ref:`l-onnx-bench-python_compiled`).

    Every model is benchmarked by a separate command line
    (``python -m mlprodict validate_runtime ...``) writing
    two csv files into *dump_dir*, a model whose summary file is
    missing afterwards is reported as a ``CRASH`` row. All csv files
    are then concatenated into *out_raw* and *out_summary*.

    :param runtime: runtime (python, python_compiled,
        onnxruntime1, onnxruntime2)
    :param black_list: models to skip, None for none
        (comma separated list or list of names)
    :param white_list: models to benchmark, None for all
        (comma separated list or list of names)
    :param out_raw: all results are saved in that file
        (extension ``.xlsx`` or ``.csv``)
    :param out_summary: all results are summarized in that file
        (extension ``.xlsx`` or ``.csv``)
    :param dump_dir: folder where to dump intermediate results,
        it is created if it does not exist
    :param fLOG: logging function
    :param verbose: verbosity
    :return: list of created files, one tuple
        ``(raw csv, summary csv)`` per benchmarked model
    :raises ValueError: if an output file has an unexpected extension
        or if no model is left to benchmark
    """
    def _check_ext(name):
        # Returns the extension of *name*, only .xlsx and .csv are supported.
        ext = os.path.splitext(name)[-1]
        if ext not in ('.xlsx', '.csv'):
            raise ValueError(  # pragma: no cover
                "Unexpected extension in %r." % name)
        return ext

    def _save(df, name):
        # Writes *df* with the writer matching the extension of *name*.
        if _check_ext(name) == '.xlsx':
            df.to_excel(name, index=False)
        else:
            df.to_csv(name, index=False)
        if verbose > 1:
            fLOG(  # pragma: no cover
                "[mlprodict] wrote '{}'".format(name))

    def _to_list(value):
        # None or "" -> [], "a,b" -> ['a', 'b'], any iterable -> list.
        if value is None or value == "":
            return []
        if isinstance(value, str):
            return value.split(',')
        return list(value)

    # Fails before a possibly very long benchmark, not after it.
    _check_ext(out_raw)
    _check_ext(out_summary)

    from pyquickhelper.loghelper import run_cmd
    from pyquickhelper.loghelper.run_cmd import get_interpreter_path
    from tqdm import tqdm
    from ..onnxrt.validate.validate_helper import sklearn_operators
    from ..onnx_conv import register_converters, register_rewritten_operators
    register_converters()
    try:
        register_rewritten_operators()
    except KeyError:  # pragma: no cover
        warnings.warn("converter for HistGradientBoosting* does not exist. "
                      "Upgrade sklearn-onnx")

    black_list = _to_list(black_list)
    white_list = _to_list(white_list)

    skls = sorted(_['name'] for _ in sklearn_operators(extended=True))
    if white_list:
        skls = [_ for _ in skls if _ in white_list]
    # Filtering before the loop keeps the progress bar length exact.
    skls = [_ for _ in skls if _ not in black_list]
    if not skls:
        raise ValueError(
            "No model to benchmark (white_list=%r, black_list=%r)."
            "" % (white_list, black_list))

    # The crash branch below writes into dump_dir, it must exist.
    if dump_dir:
        os.makedirs(dump_dir, exist_ok=True)

    filenames = []
    pbar = tqdm(skls) if verbose > 0 else skls
    for op in pbar:
        if verbose > 0:
            pbar.set_description(  # pragma: no cover
                "[%s]" % (op + " " * (25 - len(op))))

        loop_out_raw = os.path.join(
            dump_dir, "bench_raw_%s_%s.csv" % (runtime, op))
        loop_out_sum = os.path.join(
            dump_dir, "bench_sum_%s_%s.csv" % (runtime, op))
        # The existence of the summary file is used to detect a crash,
        # files left by a previous run would hide it.
        for name in (loop_out_raw, loop_out_sum):
            if os.path.exists(name):
                os.remove(name)

        cmd = ('{0} -m mlprodict validate_runtime --verbose=0 --out_raw={1} --out_summary={2} '
               '--benchmark=1 --dump_folder={3} --runtime={4} --models={5}'.format(
                   get_interpreter_path(), loop_out_raw, loop_out_sum, dump_dir, runtime, op))
        if verbose > 1:
            fLOG("[mlprodict] cmd '{}'.".format(cmd))  # pragma: no cover
        out, err = run_cmd(cmd, wait=True, fLOG=None)
        if not os.path.exists(loop_out_sum):  # pragma: no cover
            if verbose > 2:
                fLOG("[mlprodict] unable to find '{}'.".format(loop_out_sum))
            if verbose > 1:
                fLOG("[mlprodict] cmd '{}'".format(cmd))
                fLOG("[mlprodict] unable to find '{}'".format(loop_out_sum))
            msg = "Unable to find '{}'\n--CMD--\n{}\n--OUT--\n{}\n--ERR--\n{}".format(
                loop_out_sum, cmd, out, err)
            if verbose > 1:
                fLOG(msg)
            # Replaces the missing summary by a single CRASH row.
            rows = [{'name': op, 'scenario': 'CRASH',
                     'ERROR-msg': msg.replace("\n", " -- ")}]
            df = DataFrame(rows)
            df.to_csv(loop_out_sum, index=False)
        filenames.append((loop_out_raw, loop_out_sum))

    # concatenate summaries
    dfs_raw = [read_csv(name[0])
               for name in filenames if os.path.exists(name[0])]
    dfs_sum = [read_csv(name[1])
               for name in filenames if os.path.exists(name[1])]
    # concat fails on an empty list: no raw file exists if every model crashed.
    df_raw = concat(dfs_raw, sort=False) if dfs_raw else DataFrame()
    piv = concat(dfs_sum, sort=False)

    # Most recent opset first, the other opsets are moved after the benchmark columns.
    opset_cols = [(int(oc.replace("opset", "")), oc)
                  for oc in piv.columns if 'opset' in oc]
    opset_cols.sort(reverse=True)
    opset_cols = [oc[1] for oc in opset_cols]
    new_cols = opset_cols[:1]
    bench_cols = ["RT/SKL-N=1", "N=10", "N=100",
                  "N=1000", "N=10000"]
    new_cols.extend(["ERROR-msg", "name", "problem", "scenario", 'optim'])
    new_cols.extend(bench_cols)
    new_cols.extend(opset_cols[1:])
    for c in bench_cols:
        new_cols.append(c + '-min')
        new_cols.append(c + '-max')
    for c in piv.columns:
        if c.startswith("skl_") or c.startswith("onx_"):
            new_cols.append(c)
    new_cols = [_ for _ in new_cols if _ in piv.columns]
    piv = piv[new_cols]

    _save(piv, out_summary)
    _save(df_raw, out_raw)
    return filenames
def validate_runtime(verbose=1, opset_min=-1, opset_max="",
                     check_runtime=True, runtime='python', debug=False,
                     models=None, out_raw="model_onnx_raw.xlsx",
                     out_summary="model_onnx_summary.xlsx",
                     dump_folder=None, dump_all=False, benchmark=False,
                     catch_warnings=True, assume_finite=True,
                     versions=False, skip_models=None,
                     extended_list=True, separate_process=False,
                     time_kwargs=None, n_features=None, fLOG=print,
                     out_graph=None, force_return=False,
                     dtype=None, skip_long_test=False,
                     number=1, repeat=1, time_kwargs_fact='lin',
                     time_limit=4, n_jobs=0):
    """
    Walks through most of :epkg:`scikit-learn` operators
    or model or predictor or transformer, tries to convert
    them into :epkg:`ONNX` and computes the predictions
    with a specific runtime.

    :param verbose: integer from 0 (None) to 2 (full verbose)
    :param opset_min: tries every conversion from this minimum opset,
        -1 to get the current opset
    :param opset_max: tries every conversion up to maximum opset,
        -1 to get the current opset
    :param check_runtime: to check the runtime
        and not only the conversion
    :param runtime: runtime to check, python,
        onnxruntime1 to check :epkg:`onnxruntime`,
        onnxruntime2 to check every *ONNX* node independently
        with onnxruntime, many runtime can be checked at the same time
        if the value is a comma separated list
    :param models: comma separated list of models to test or empty
        string to test them all
    :param skip_models: models to skip
    :param debug: stops whenever an exception is raised,
        only if *separate_process* is False
    :param out_raw: output raw results into this file (excel format)
    :param out_summary: output an aggregated view into this file (excel format)
    :param dump_folder: folder where to dump information (pickle)
        in case of mismatch, it is created if it does not exist
    :param dump_all: dumps all models, not only the failing ones
    :param benchmark: run benchmark
    :param catch_warnings: catch warnings
    :param assume_finite: See `config_context
        <https://scikit-learn.org/stable/modules/generated/sklearn.config_context.html>`_,
        If True, validation for finiteness will be skipped, saving time, but leading
        to potential crashes. If False, validation for finiteness will be performed,
        avoiding error.
    :param versions: add columns with versions of used packages,
        :epkg:`numpy`, :epkg:`scikit-learn`, :epkg:`onnx`, :epkg:`onnxruntime`,
        :epkg:`sklearn-onnx`
    :param extended_list: extends the list of :epkg:`scikit-learn` converters
        with converters implemented in this module
    :param separate_process: run every model in a separate process,
        this option must be used to run all models in one row
        even if one of them is crashing
    :param time_kwargs: a dictionary which defines the number of rows and
        the parameter *number* and *repeat* when benchmarking a model,
        the value must follow :epkg:`json` format, the dictionary
        given by the caller is not modified
    :param n_features: change the default number of features for
        a specific problem, it can also be a comma separated list
    :param force_return: forces the function to return the results,
        used when the results are produced through a separate process
    :param out_graph: image name, to output a graph which summarizes
        a benchmark in case it was run
    :param dtype: '32' or '64' or None for both,
        limits the test to one specific number types
    :param skip_long_test: skips tests for high values of N if
        they seem too long
    :param number: to multiply number values in *time_kwargs*
    :param repeat: to multiply repeat values in *time_kwargs*
    :param time_kwargs_fact: to multiply number and repeat in
        *time_kwargs* depending on the model
        (see :func:`_multiply_time_kwargs <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>`)
    :param time_limit: to stop benchmarking after this limit of time
    :param n_jobs: force the number of jobs to have this value,
        by default, it is equal to the number of CPU
    :param fLOG: logging function
    :return: the list of rows if *force_return* is True or
        *verbose* >= 2, None otherwise, the rows are always
        returned when *separate_process* is True

    .. cmdref::
        :title: Validates a runtime against scikit-learn
        :cmd: -m mlprodict validate_runtime --help
        :lid: l-cmd-validate_runtime

        The command walks through all scikit-learn operators,
        tries to convert them, checks the predictions,
        and produces a report.

        Example::

            python -m mlprodict validate_runtime --models LogisticRegression,LinearRegression

        Following example benchmarks models
        :epkg:`sklearn:ensemble:RandomForestRegressor`,
        :epkg:`sklearn:tree:DecisionTreeRegressor`, it compares
        :epkg:`onnxruntime` against :epkg:`scikit-learn` for opset 10.

        ::

            python -m mlprodict validate_runtime -v 1 -o 10 -op 10 -c 1 -r onnxruntime1
                -m RandomForestRegressor,DecisionTreeRegressor -out bench_onnxruntime.xlsx -b 1

        Parameter ``--time_kwargs`` may be used to reduce or increase
        benchmark precision. The following value tells the function
        to run a benchmarks with datasets of 1 or 10 number, to repeat
        a given number of time *number* predictions in one row.
        The total time is divided by :math:`number \\times repeat`.
        Parameter ``--time_kwargs_fact`` may be used to increase these
        number for some specific models. ``'lin'`` multiplies
        by 10 number when the model is linear.

        ::

            -t "{\\"1\\":{\\"number\\":10,\\"repeat\\":10},\\"10\\":{\\"number\\":5,\\"repeat\\":5}}"

        The following example dumps every model in the list:

        ::

            python -m mlprodict validate_runtime --out_raw raw.csv --out_summary sum.csv
                --models LinearRegression,LogisticRegression,DecisionTreeRegressor,DecisionTreeClassifier
                -r python,onnxruntime1 -o 10 -op 10 -v 1 -b 1 -dum 1
                -du model_dump -n 20,100,500 --out_graph benchmark.png --dtype 32

        The command line generates a graph produced by function
        :func:`plot_validate_benchmark
        <mlprodict.onnxrt.validate.validate_graph.plot_validate_benchmark>`.
    """
    # Values may come from the command line as strings, verbose must
    # be an integer before it is compared or forwarded.
    if isinstance(verbose, str):
        verbose = int(verbose)  # pragma: no cover

    if separate_process:
        # number, repeat and out_graph must be forwarded as well,
        # otherwise they are silently ignored in this mode.
        return _validate_runtime_separate_process(
            verbose=verbose, opset_min=opset_min, opset_max=opset_max,
            check_runtime=check_runtime, runtime=runtime, debug=debug,
            models=models, out_raw=out_raw,
            out_summary=out_summary, dump_all=dump_all,
            dump_folder=dump_folder, benchmark=benchmark,
            catch_warnings=catch_warnings, assume_finite=assume_finite,
            versions=versions, skip_models=skip_models,
            extended_list=extended_list, time_kwargs=time_kwargs,
            n_features=n_features, fLOG=fLOG, force_return=True,
            out_graph=out_graph, dtype=dtype, skip_long_test=skip_long_test,
            number=number, repeat=repeat,
            time_kwargs_fact=time_kwargs_fact, time_limit=time_limit,
            n_jobs=n_jobs)

    from ..onnxrt.validate import enumerate_validated_operator_opsets  # pylint: disable=E0402

    if not isinstance(models, list):
        models = (None if models in (None, "")
                  else models.strip().split(','))
    if not isinstance(skip_models, list):
        skip_models = ([] if skip_models in (None, "")
                       else skip_models.strip().split(','))
    if verbose <= 1:
        logger = getLogger('skl2onnx')
        logger.disabled = True
    if not dump_folder:
        dump_folder = None
    if dump_folder and not os.path.exists(dump_folder):
        os.makedirs(dump_folder)  # pragma: no cover
    if dump_folder and not os.path.exists(dump_folder):
        raise FileNotFoundError(  # pragma: no cover
            "Cannot find dump_folder '{0}'.".format(
                dump_folder))

    # handling parameters
    if opset_max == "":
        opset_max = None  # pragma: no cover
    if isinstance(opset_min, str):
        opset_min = int(opset_min)  # pragma: no cover
    if isinstance(opset_max, str):
        opset_max = int(opset_max)
    if isinstance(extended_list, str):
        extended_list = extended_list in (
            '1', 'True', 'true')  # pragma: no cover
    # int * str would repeat a string instead of multiplying a number.
    if isinstance(number, str):
        number = int(number)
    if isinstance(repeat, str):
        repeat = int(repeat)
    if isinstance(time_limit, str):
        time_limit = float(time_limit)
    if time_kwargs in (None, ''):
        time_kwargs = None
    if isinstance(time_kwargs, str):
        time_kwargs = json.loads(time_kwargs)
        # json only allows string as keys
        time_kwargs = {int(k): v for k, v in time_kwargs.items()}
    if isinstance(n_jobs, str):
        n_jobs = int(n_jobs)
    if n_jobs == 0:
        n_jobs = None
    if time_kwargs is not None and not isinstance(time_kwargs, dict):
        raise ValueError(  # pragma: no cover
            "time_kwargs must be a dictionary not {}\n{}".format(
                type(time_kwargs), time_kwargs))
    if not isinstance(n_features, list):
        if n_features in (None, ""):
            n_features = None
        elif ',' in n_features:
            n_features = list(map(int, n_features.split(',')))
        else:
            n_features = int(n_features)
    if not isinstance(runtime, list) and ',' in runtime:
        runtime = runtime.split(',')

    def fct_filter_exp(m, s):
        # Skips a model by class name or by "name[scenario]".
        cl = m.__name__
        if cl in skip_models:
            return False
        pair = "%s[%s]" % (cl, s)
        if pair in skip_models:
            return False
        return True

    if dtype in ('', None):
        fct_filter = fct_filter_exp
    elif dtype == '32':
        def fct_filter_exp2(m, p):
            return fct_filter_exp(m, p) and '64' not in p
        fct_filter = fct_filter_exp2
    elif dtype == '64':  # pragma: no cover
        def fct_filter_exp3(m, p):
            return fct_filter_exp(m, p) and '64' in p
        fct_filter = fct_filter_exp3
    else:
        raise ValueError(  # pragma: no cover
            "dtype must be empty, 32, 64 not '{}'.".format(dtype))

    # time_kwargs

    if benchmark:
        if time_kwargs is None:
            from ..onnxrt.validate.validate_helper import default_time_kwargs  # pylint: disable=E0402
            time_kwargs = default_time_kwargs()
        # Multiplies a copy: the caller's dictionary must not be
        # modified (a second call would multiply the values twice).
        time_kwargs = {k: dict(v) for k, v in time_kwargs.items()}
        for v in time_kwargs.values():
            v['number'] *= number
            v['repeat'] *= repeat
        if verbose > 0:
            fLOG("time_kwargs=%r" % time_kwargs)

    # body

    def build_rows(models_):
        rows = list(enumerate_validated_operator_opsets(
            verbose, models=models_, fLOG=fLOG, runtime=runtime, debug=debug,
            dump_folder=dump_folder, opset_min=opset_min, opset_max=opset_max,
            benchmark=benchmark, assume_finite=assume_finite, versions=versions,
            extended_list=extended_list, time_kwargs=time_kwargs, dump_all=dump_all,
            n_features=n_features, filter_exp=fct_filter,
            skip_long_test=skip_long_test, time_limit=time_limit,
            time_kwargs_fact=time_kwargs_fact, n_jobs=n_jobs))
        return rows

    def catch_build_rows(models_):
        if catch_warnings:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore",
                                      (UserWarning, ConvergenceWarning,
                                       RuntimeWarning, FutureWarning))
                rows = build_rows(models_)
        else:
            rows = build_rows(models_)  # pragma: no cover
        return rows

    rows = catch_build_rows(models)
    res = _finalize(rows, out_raw, out_summary,
                    verbose, models, out_graph, fLOG)
    return res if (force_return or verbose >= 2) else None
def _finalize(rows, out_raw, out_summary, verbose, models, out_graph, fLOG):
    """
    Stores the raw results and their summary, plots a graph if requested.

    :param rows: list of dictionaries (one per validation), keys
        containing ``'lambda'`` are removed inplace
    :param out_raw: destination of the raw results, excel if the
        extension is ``.xlsx``, csv otherwise, skipped if empty or None
    :param out_summary: destination of the summary, same formats,
        skipped if empty or None
    :param verbose: verbosity
    :param models: list of models or None, the transposed summary
        is logged when *verbose* > 1 and *models* is not None
    :param out_graph: image name, None to skip the graph
    :param fLOG: logging function
    :return: *rows* (without the removed keys)
    :raises RuntimeError: if *rows* is empty or if the summary
        has no ``optim`` column
    """
    from ..onnxrt.validate import summary_report  # pylint: disable=E0402
    from ..tools.cleaning import clean_error_msg  # pylint: disable=E0402

    def _store(frame, filename, message):
        # Excel for .xlsx files, csv with cleaned error messages otherwise.
        if verbose > 0:
            fLOG(message.format(filename))
        if os.path.splitext(filename)[-1] == ".xlsx":
            frame.to_excel(filename, index=False)
        else:
            clean_error_msg(frame).to_csv(filename, index=False)

    # Drops data which cannot be serialized.
    for row in rows:
        for name in [key for key in row if 'lambda' in key]:
            del row[name]

    raw = DataFrame(rows)
    if out_raw:
        _store(raw, out_raw, "Saving raw_data into '{}'.")

    if raw.shape[0] == 0:
        raise RuntimeError("No result produced by the benchmark.")
    summary = summary_report(raw)
    if 'optim' not in summary:
        raise RuntimeError(  # pragma: no cover
            "Unable to produce a summary. Missing column in \n{}".format(
                summary.columns))

    if out_summary:
        _store(summary, out_summary, "Saving summary into '{}'.")

    if verbose > 1 and models is not None:
        fLOG(summary.T)
    if out_graph is None:
        return rows

    if verbose > 0:
        fLOG("Saving graph into '{}'.".format(out_graph))
    from ..plotting.plotting import plot_validate_benchmark
    figure = plot_validate_benchmark(summary)[0]
    figure.savefig(out_graph)
    return rows
def _validate_runtime_dict(kwargs):
    """
    Calls :func:`validate_runtime` with parameters stored in one
    dictionary, the single-argument form given to
    ``Pool.apply_async`` by :func:`_validate_runtime_separate_process`.

    :param kwargs: keyword arguments for :func:`validate_runtime`
    :return: what :func:`validate_runtime` returns
    """
    result = validate_runtime(**kwargs)
    return result
def _validate_runtime_separate_process(**kwargs):
    """
    Runs :func:`validate_runtime` for every model in a dedicated
    process (a pool of one worker per model) so that a crash or a run
    longer than 150 seconds only produces a ``CRASH`` row for that model.
    Every model writes its own files (suffixed with the model name),
    all rows are then finalized together.

    :param kwargs: parameters of :func:`validate_runtime`, keys
        ``models``, ``skip_models``, ``verbose``, ``fLOG``,
        ``out_raw``, ``out_summary`` must be present
    :return: list of rows returned by :func:`_finalize`
    """
    def _suffixed(filename, op):
        # "name.ext" -> "name_<op>.ext", None or "" -> None
        if not filename:
            return None  # pragma: no cover
        root, ext = os.path.splitext(filename)
        return "".join([root, "_", op, ext])

    models = kwargs['models']
    if models in (None, ""):
        from ..onnxrt.validate.validate_helper import sklearn_operators  # pragma: no cover
        models = [_['name']
                  for _ in sklearn_operators(extended=True)]  # pragma: no cover
    elif not isinstance(models, list):
        models = models.strip().split(',')

    # skip_models may be a list (accepted by validate_runtime) or a string.
    skip_models = kwargs['skip_models']
    if skip_models in (None, ""):
        skip_models = []
    elif not isinstance(skip_models, list):
        skip_models = skip_models.strip().split(',')

    verbose = kwargs['verbose']
    if isinstance(verbose, str):
        verbose = int(verbose)  # pragma: no cover
    fLOG = kwargs['fLOG']
    all_rows = []
    skls = sorted(m for m in models if m not in skip_models)

    if verbose > 0:
        from tqdm import tqdm
        pbar = tqdm(skls)
    else:
        pbar = skls  # pragma: no cover

    for op in pbar:
        if not isinstance(pbar, list):
            pbar.set_description("[%s]" % (op + " " * (25 - len(op))))

        new_kwargs = kwargs.copy()
        new_kwargs.pop('fLOG', None)
        new_kwargs['out_raw'] = _suffixed(kwargs['out_raw'], op)
        new_kwargs['out_summary'] = _suffixed(kwargs['out_summary'], op)
        new_kwargs['models'] = op
        new_kwargs['verbose'] = 0  # tqdm fails
        new_kwargs['out_graph'] = None
        # The rows are needed here and the child must run in-process.
        new_kwargs['force_return'] = True
        new_kwargs['separate_process'] = False

        with Pool(1) as p:
            try:
                result = p.apply_async(_validate_runtime_dict, [new_kwargs])
                lrows = result.get(timeout=150)  # timeout fixed to 150s
                all_rows.extend(lrows)
            except Exception as e:  # pylint: disable=W0703
                all_rows.append({  # pragma: no cover
                    'name': op, 'scenario': 'CRASH',
                    'ERROR-msg': str(e).replace("\n", " -- ")
                })

    return _finalize(all_rows, kwargs['out_raw'], kwargs['out_summary'],
                     verbose, models, kwargs.get('out_graph', None), fLOG)
def latency(model, law='normal', size=1, number=10, repeat=10, max_time=0,
            runtime="onnxruntime", device='cpu', fmt=None,
            profiling=None, profile_output='profiling.csv'):
    """
    Measures the latency of a model (python API).

    :param model: ONNX graph (a filename)
    :param law: random law used to generate fake inputs
        (only ``'normal'`` is supported)
    :param size: batch size, it replaces the first dimension
        of every input if it is left unknown
    :param number: number of calls to measure
    :param repeat: number of times to repeat the experiment
    :param max_time: if it is > 0, it runs as many time during
        that period of time
    :param runtime: available runtime
    :param device: device, `cpu`, `cuda:0` or a list of providers
        `CPUExecutionProvider, CUDAExecutionProvider`
    :param fmt: None or `csv`, it then
        returns a string formatted like a csv file
    :param profiling: if True, profile the execution of every
        node, it can be sorted by name or type,
        the value for this parameter should be in `(None, 'name', 'type')`
    :param profile_output: output name for the profiling
        if profiling is specified (extension ``.csv`` or ``.xlsx``)
    :return: the measures, or a csv string if *fmt* is ``'csv'``
    :raises FileNotFoundError: if *model* does not exist
    :raises ValueError: for any invalid parameter, all parameters are
        checked before the (possibly long) measurement starts

    .. cmdref::
        :title: Measures model latency
        :cmd: -m mlprodict latency --help
        :lid: l-cmd-latency

        The command generates random inputs and call many times the
        model on these inputs. It returns the processing time for one
        iteration.

        Example::

            python -m mlprodict latency --model "model.onnx"
    """
    if not os.path.exists(model):
        raise FileNotFoundError(  # pragma: no cover
            "Unable to find model %r." % model)
    if profiling not in (None, '', 'name', 'type'):
        raise ValueError(  # pragma: no cover
            "Unexpected value for profiling: %r." % profiling)
    # Checked upfront, the original check happened after the measurement.
    if fmt not in (None, '', 'csv'):
        raise ValueError(  # pragma: no cover
            "Unexpected value for fmt: %r." % fmt)
    size = int(size)
    number = int(number)
    repeat = int(repeat)
    if max_time in (None, 0, ""):
        max_time = None
    else:
        max_time = float(max_time)
        if max_time <= 0:
            max_time = None

    if law != "normal":
        raise ValueError(  # pragma: no cover
            "Only law='normal' is supported, not %r." % law)

    ext = None
    if profiling in ('name', 'type'):
        if profile_output in (None, ''):
            raise ValueError(  # pragma: no cover
                'profiling is enabled but profile_output is wrong (%r).'
                '' % profile_output)
        ext = os.path.splitext(profile_output)[-1]
        if ext not in ('.csv', '.xlsx'):
            raise ValueError(  # pragma: no cover
                "Unexpected extension for profile_output=%r."
                "" % profile_output)

    from ..onnxrt.validate.validate_latency import latency as _latency  # pylint: disable=E0402

    res = _latency(
        model, law=law, size=size, number=number, repeat=repeat,
        max_time=max_time, runtime=runtime, device=device,
        profiling=profiling)

    if profiling in ('name', 'type'):
        # The runtime returns the measures and the profiling table.
        res, gr = res
        gr = gr.reset_index(drop=False)
        if ext == '.csv':
            gr.to_csv(profile_output, index=False)
        else:  # pragma: no cover
            gr.to_excel(profile_output, index=False)

    if fmt == 'csv':
        st = StringIO()
        df = DataFrame([res])
        df.to_csv(st, index=False)
        return st.getvalue()
    return res