Coverage for mlprodict/onnxrt/validate/validate_helper.py: 97%
1"""
2@file
3@brief Validates runtime for many :epkg:`scikit-learn` operators.
4The submodule relies on :epkg:`onnxconverter_common`,
5:epkg:`sklearn-onnx`.
6"""
7import math
8import copy
9import os
10import warnings
11from importlib import import_module
12import pickle
13from time import perf_counter
14import numpy
15from cpyquickhelper.numbers import measure_time as _c_measure_time
16from sklearn.base import BaseEstimator
17from sklearn.linear_model._base import LinearModel
18from sklearn.model_selection import train_test_split
19from sklearn import __all__ as sklearn__all__, __version__ as sklearn_version
20from .validate_problems import _problems


class RuntimeBadResultsError(RuntimeError):
    """
    Raised when the results are too different from
    :epkg:`scikit-learn`.
    """

    def __init__(self, msg, obs):
        """
        :param msg: message to display
        :param obs: observations
        """
        RuntimeError.__init__(self, msg)
        self.obs = obs


def _dictionary2str(di):
    # Serializes a dictionary into a compact 'key=value/key=value' string.
    el = []
    for k in sorted(di):
        el.append('{}={}'.format(k, di[k]))
    return '/'.join(el)
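

# Illustrative sketch, not part of the original module: _dictionary2str builds
# a compact 'key=value' string, handy to summarize option dictionaries in logs
# (the option names below are only examples).
def _example_dictionary2str():
    opts = {'optim': 'cdist', 'zipmap': False}
    # keys are sorted, so the result is deterministic
    return _dictionary2str(opts)  # -> 'optim=cdist/zipmap=False'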


def modules_list():
    """
    Returns modules and versions currently used.

    .. runpython::
        :showcode:
        :rst:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.validate.validate_helper import modules_list
        from pyquickhelper.pandashelper import df2rst
        from pandas import DataFrame
        print(df2rst(DataFrame(modules_list())))
    """
    def try_import(name):
        try:
            mod = import_module(name)
        except ImportError:  # pragma: no cover
            return None
        return (dict(name=name, version=mod.__version__)
                if hasattr(mod, '__version__') else dict(name=name))

    rows = []
    for name in sorted(['pandas', 'numpy', 'sklearn', 'mlprodict',
                        'skl2onnx', 'onnxmltools', 'onnx', 'onnxruntime',
                        'scipy']):
        res = try_import(name)
        if res is not None:
            rows.append(res)
    return rows
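

# Illustrative sketch, not part of the original module: each row returned by
# modules_list is a dictionary with a 'name' key and, when available, a
# 'version' key.
def _example_modules_list():
    for row in modules_list():
        print(row['name'], row.get('version', '?'))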


def _dispsimple(arr, fLOG):
    # Prints a compact preview of an array, a tuple or a list with *fLOG*.
    if isinstance(arr, (tuple, list)):
        for i, a in enumerate(arr):
            fLOG("output %d" % i)
            _dispsimple(a, fLOG)
    elif hasattr(arr, 'shape'):
        if len(arr.shape) == 1:
            threshold = 8
        else:
            threshold = min(
                50, min(50 // arr.shape[1], 8) * arr.shape[1])
        fLOG(numpy.array2string(arr, max_line_width=120,
                                suppress_small=True,
                                threshold=threshold))
    else:  # pragma: no cover
        s = str(arr)
        if len(s) > 50:
            s = s[:50] + "..."
        fLOG(s)


def _merge_options(all_conv_options, aoptions):
    # Recursively merges model specific options *aoptions* into the
    # global conversion options *all_conv_options*.
    if aoptions is None:
        return copy.deepcopy(all_conv_options)
    if not isinstance(aoptions, dict):
        return copy.deepcopy(aoptions)  # pragma: no cover
    merged = {}
    for k, v in all_conv_options.items():
        if k in aoptions:
            merged[k] = _merge_options(v, aoptions[k])
        else:
            merged[k] = copy.deepcopy(v)
    for k, v in aoptions.items():
        if k in all_conv_options:
            continue
        merged[k] = copy.deepcopy(v)
    return merged
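

# Illustrative sketch, not part of the original module (option names below are
# only examples): _merge_options keeps global options, recursively overrides
# them with model specific ones and adds keys present only on one side.
def _example_merge_options():
    all_conv_options = {'zipmap': False, 'options': {'raw_scores': True}}
    model_options = {'options': {'decision_path': True}}
    merged = _merge_options(all_conv_options, model_options)
    # merged == {'zipmap': False,
    #            'options': {'raw_scores': True, 'decision_path': True}}
    return merged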


def sklearn_operators(subfolder=None, extended=False,
                      experimental=True):
    """
    Builds the list of operators from :epkg:`scikit-learn`.
    The function goes through the list of submodules
    and gets the list of classes which inherit from
    :epkg:`scikit-learn:base:BaseEstimator`.

    :param subfolder: look into only one subfolder
    :param extended: extends the list to the list of operators
        this package implements a converter for
    :param experimental: includes experimental modules from
        :epkg:`scikit-learn` (see `sklearn.experimental
        <https://github.com/scikit-learn/scikit-learn/
        tree/master/sklearn/experimental>`_)
    :return: the list of found operators
    """
    if experimental:
        from sklearn.experimental import (  # pylint: disable=W0611
            enable_hist_gradient_boosting,
            enable_iterative_imputer)

    subfolders = sklearn__all__ + ['mlprodict.onnx_conv']
    found = []
    for subm in sorted(subfolders):
        if isinstance(subm, list):
            continue  # pragma: no cover
        if subfolder is not None and subm != subfolder:
            continue

        if subm == 'feature_extraction':
            subs = [subm, 'feature_extraction.text']
        else:
            subs = [subm]

        for sub in subs:
            if '.' in sub and sub not in {'feature_extraction.text'}:
                name_sub = sub
            else:
                name_sub = "{0}.{1}".format("sklearn", sub)
            try:
                mod = import_module(name_sub)
            except ModuleNotFoundError:
                continue

            if hasattr(mod, "register_converters"):
                fct = getattr(mod, "register_converters")
                cls = fct()
            else:
                cls = getattr(mod, "__all__", None)
                if cls is None:
                    cls = list(mod.__dict__)
                cls = [mod.__dict__[cl] for cl in cls]

            for cl in cls:
                try:
                    issub = issubclass(cl, BaseEstimator)
                except TypeError:
                    continue
                if cl.__name__ in {'Pipeline', 'ColumnTransformer',
                                   'FeatureUnion', 'BaseEstimator',
                                   'BaseEnsemble', 'BaseDecisionTree'}:
                    continue
                if cl.__name__ in {'CustomScorerTransform'}:
                    continue
                if (sub in {'calibration', 'dummy', 'manifold'} and
                        'Calibrated' not in cl.__name__):
                    continue
                if issub:
                    pack = ("sklearn" if sub in sklearn__all__
                            else cl.__module__.split('.')[0])
                    found.append(
                        dict(name=cl.__name__, subfolder=sub, cl=cl,
                             package=pack))

    if extended:
        from ...onnx_conv import register_converters
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", ResourceWarning)
            models = register_converters(True)

        done = set(_['name'] for _ in found)
        for m in models:
            try:
                name = m.__module__.split('.')
            except AttributeError as e:  # pragma: no cover
                raise AttributeError("Unexpected value, m={}".format(m)) from e
            sub = '.'.join(name[1:])
            pack = name[0]
            if m.__name__ not in done:
                found.append(
                    dict(name=m.__name__, cl=m, package=pack, sub=sub))

    # let's remove models which cannot predict
    all_found = found
    found = []
    for mod in all_found:
        cl = mod['cl']
        if hasattr(cl, 'fit_predict') and not hasattr(cl, 'predict'):
            continue
        if hasattr(cl, 'fit_transform') and not hasattr(cl, 'transform'):
            continue
        if (not hasattr(cl, 'transform') and
                not hasattr(cl, 'predict') and
                not hasattr(cl, 'decision_function')):
            continue
        found.append(mod)
    return found
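

# Illustrative sketch, not part of the original module ('tree' is just one of
# the scikit-learn subfolders): restricting the search to a single submodule
# and listing the estimator classes which were kept.
def _example_sklearn_operators():
    ops = sklearn_operators(subfolder='tree', extended=False)
    return sorted(op['name'] for op in ops)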


def _measure_time(fct, repeat=1, number=1, first_run=True):
    """
    Measures the execution time of a function.

    :param fct: function to measure
    :param repeat: number of times to repeat
    :param number: number of times between two measures
    :param first_run: if True, runs the function once before measuring
    :return: last result, average, values
    """
    res = None
    values = []
    if first_run:
        fct()
    for __ in range(repeat):
        begin = perf_counter()
        for _ in range(number):
            res = fct()
        end = perf_counter()
        values.append(end - begin)
    if repeat * number == 1:
        return res, values[0], values
    return res, sum(values) / (repeat * number), values  # pragma: no cover
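

# Illustrative sketch, not part of the original module: timing a small numpy
# computation with _measure_time; the optional first run warms the function up
# before any measure is recorded.
def _example_measure_time_helper():
    X = numpy.random.rand(1000, 10)
    res, average, values = _measure_time(
        lambda: X @ X.T, repeat=5, number=10, first_run=True)
    return average, values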


def _shape_exc(obj):
    # Returns the shape of *obj* or a string with its length,
    # used to build error messages.
    if hasattr(obj, 'shape'):
        return obj.shape
    if isinstance(obj, (list, dict, tuple)):
        return "[{%d}]" % len(obj)
    return None


def dump_into_folder(dump_folder, obs_op=None, is_error=True,
                     **kwargs):
    """
    Dumps information when an error was detected
    using :epkg:`*py:pickle`.

    :param dump_folder: destination folder
    :param obs_op: observations about the failing case
    :param is_error: is it an error or not?
    :param kwargs: additional parameters
    :return: name of the created file
    """
    if dump_folder is None:
        raise ValueError("dump_folder cannot be None.")
    optim = obs_op.get('optim', '')
    optim = str(optim)
    # removes characters unfriendly to file names
    for pattern in ("<class 'sklearn.", "<class '", " ", ">", "=", "{",
                    "}", ":", "'", "/", "\\"):
        optim = optim.replace(pattern, "")
    parts = (obs_op['runtime'], obs_op['name'], obs_op['scenario'],
             obs_op['problem'], optim,
             "op" + str(obs_op.get('opset', '-')),
             "nf" + str(obs_op.get('n_features', '-')))
    name = "dump-{}-{}.pkl".format(
        "ERROR" if is_error else "i",
        "-".join(map(str, parts)))
    name = os.path.join(dump_folder, name)
    obs_op = obs_op.copy()
    fcts = [k for k in obs_op if k.startswith('lambda')]
    for fct in fcts:
        del obs_op[fct]
    kwargs.update({'obs_op': obs_op})
    with open(name, "wb") as f:
        pickle.dump(kwargs, f)
    return name
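

# Illustrative sketch, not part of the original module (the values below are
# placeholders): *obs_op* is expected to contain the keys used above to build
# the file name, any extra keyword argument is pickled along with it.
def _example_dump_into_folder(tmp_folder):
    obs = dict(runtime='python', name='LinearRegression', scenario='default',
               problem='num-tr', optim='', opset=15, n_features=4)
    return dump_into_folder(tmp_folder, obs_op=obs, is_error=False,
                            X_test=numpy.zeros((2, 4)))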


def default_time_kwargs():
    """
    Returns default values *number* and *repeat* to measure
    the execution of a function.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.validate.validate_helper import default_time_kwargs
        import pprint
        pprint.pprint(default_time_kwargs())

    Keys define the number of rows,
    values define *number* and *repeat*.
    """
    return {
        1: dict(number=15, repeat=20),
        10: dict(number=10, repeat=20),
        100: dict(number=4, repeat=10),
        1000: dict(number=4, repeat=4),
        10000: dict(number=2, repeat=2),
    }
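

# Illustrative sketch, not part of the original module: the returned dictionary
# maps a batch size (number of rows) to the *number* and *repeat* values used
# when benchmarking that batch size.
def _example_default_time_kwargs():
    for n_rows, kw in sorted(default_time_kwargs().items()):
        print("%d rows: number=%d repeat=%d" % (
            n_rows, kw['number'], kw['repeat']))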


def measure_time(stmt, x, repeat=10, number=50, div_by_number=False,
                 first_run=True, max_time=None):
    """
    Measures a statement and returns the results as a dictionary.

    :param stmt: callable taking *x* as its only argument
    :param x: matrix
    :param repeat: average over *repeat* experiments
    :param number: number of executions in one row
    :param div_by_number: divide by the number of executions
    :param first_run: if True, runs the function once before measuring
    :param max_time: execute the statement until the total goes
        beyond this time (approximately), *repeat* is ignored,
        *div_by_number* must be set to True
    :return: dictionary

    See `Timer.repeat <https://docs.python.org/3/library/timeit.html?timeit.Timer.repeat>`_
    for a better understanding of parameters *repeat* and *number*.
    The function returns a duration corresponding to
    *number* times the execution of the main statement.
    """
    if x is None:
        raise ValueError("x cannot be None")  # pragma: no cover

    def fct():
        stmt(x)

    if first_run:
        try:
            fct()
        except RuntimeError as e:  # pragma: no cover
            raise RuntimeError("{}-{}".format(type(x), x.dtype)) from e

    return _c_measure_time(fct, context={}, repeat=repeat, number=number,
                           div_by_number=div_by_number, max_time=max_time)
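

# Illustrative sketch, not part of the original module: *stmt* can be any
# callable taking the matrix *x*, for instance the predict method of a fitted
# model or an ONNX runtime; the timing dictionary comes from cpyquickhelper.
def _example_measure_time(predict_fct):
    X32 = numpy.random.rand(100, 10).astype(numpy.float32)
    return measure_time(predict_fct, X32, repeat=5, number=10,
                        div_by_number=True)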


def _multiply_time_kwargs(time_kwargs, time_kwargs_fact, inst):
    """
    Multiplies values in *time_kwargs* following strategy
    *time_kwargs_fact* for a given model *inst*.

    :param time_kwargs: see below
    :param time_kwargs_fact: see below
    :param inst: :epkg:`scikit-learn` model
    :return: new *time_kwargs*

    Possible values for *time_kwargs_fact*:

    - an integer: multiplies *number* by this number
    - `'lin'`: multiplies value *number* for linear models depending
      on the number of rows to process (:math:`\\propto 1/\\log_{10}(n)`)

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from pprint import pprint
        from sklearn.linear_model import LinearRegression
        from mlprodict.onnxrt.validate.validate_helper import (
            default_time_kwargs, _multiply_time_kwargs)

        lr = LinearRegression()
        kw = default_time_kwargs()
        pprint(kw)

        kw2 = _multiply_time_kwargs(kw, 'lin', lr)
        pprint(kw2)
    """
    if time_kwargs is None:
        raise ValueError("time_kwargs cannot be None.")  # pragma: no cover
    if time_kwargs_fact in ('', None):
        return time_kwargs
    try:
        vi = int(time_kwargs_fact)
        time_kwargs_fact = vi
    except (TypeError, ValueError):
        pass
    if isinstance(time_kwargs_fact, int):
        time_kwargs_modified = copy.deepcopy(time_kwargs)
        for k in time_kwargs_modified:
            time_kwargs_modified[k]['number'] *= time_kwargs_fact
        return time_kwargs_modified
    if time_kwargs_fact == 'lin':
        if isinstance(inst, LinearModel):
            time_kwargs_modified = copy.deepcopy(time_kwargs)
            for k in time_kwargs_modified:
                kl = max(int(math.log(k) / math.log(10) + 1e-5), 1)
                f = max(int(10 / kl + 0.5), 1)
                time_kwargs_modified[k]['number'] *= f
                time_kwargs_modified[k]['repeat'] *= 1
            return time_kwargs_modified
        return time_kwargs
    raise ValueError(  # pragma: no cover
        "Unable to interpret time_kwargs_fact='{}'.".format(
            time_kwargs_fact))


def _get_problem_data(prob, n_features):
    # Builds the data for problem *prob* from _problems, splits it into
    # train and test parts and returns everything needed for validation.
    data_problem = _problems[prob](n_features=n_features)
    if len(data_problem) == 6:
        X_, y_, init_types, method, output_index, Xort_ = data_problem
        dofit = True
    elif len(data_problem) == 7:
        X_, y_, init_types, method, output_index, Xort_, dofit = data_problem
    else:
        raise RuntimeError(  # pragma: no cover
            "Unable to interpret problem '{}'.".format(prob))
    if (len(X_.shape) == 2 and X_.shape[1] != n_features and
            n_features is not None):
        raise RuntimeError(  # pragma: no cover
            "Problem '{}' with n_features={} returned {} features "
            "(func={}).".format(prob, n_features, X_.shape[1],
                                _problems[prob]))
    if y_ is None:
        (X_train, X_test, Xort_train,  # pylint: disable=W0612
            Xort_test) = train_test_split(
                X_, Xort_, random_state=42)
        y_train, y_test = None, None
    else:
        (X_train, X_test, y_train, y_test,  # pylint: disable=W0612
            Xort_train, Xort_test) = train_test_split(
                X_, y_, Xort_, random_state=42)
    if isinstance(init_types, tuple):
        init_types, conv_options = init_types
    else:
        conv_options = None

    if isinstance(method, tuple):
        method_name, predict_kwargs = method
    else:
        method_name = method
        predict_kwargs = {}

    return (X_train, X_test, y_train,
            y_test, Xort_test,
            init_types, conv_options, method_name,
            output_index, dofit, predict_kwargs)
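

# Illustrative sketch, not part of the original module ('b-cl' only stands for
# some key of _problems and is a placeholder here): unpacking the tuple
# returned by _get_problem_data.
def _example_get_problem_data():
    (X_train, X_test, y_train, y_test, Xort_test,
     init_types, conv_options, method_name,
     output_index, dofit, predict_kwargs) = _get_problem_data('b-cl', 4)
    return X_train.shape, method_name, dofit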