Coverage for mlprodict/asv_benchmark/_create_asv_helper.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file Functions to create a benchmark based on :epkg:`asv`
3for many regressors and classifiers.
4"""
5import os
6import textwrap
7import hashlib
8try:
9 from ..onnx_tools.optim.sklearn_helper import set_n_jobs
10except (ValueError, ImportError): # pragma: no cover
11 from mlprodict.onnx_tools.optim.sklearn_helper import set_n_jobs
13# exec function does not import models but potentially
14# requires all specific models used to define scenarios
15try:
16 from ..onnxrt.validate.validate_scenarios import * # pylint: disable=W0614,W0401
17except (ValueError, ImportError): # pragma: no cover
18 # Skips this step if used in a benchmark.
19 pass
# Default settings for the asv benchmark configuration: project and
# repository metadata, install / uninstall / build commands, the branch
# to benchmark, the environment type, the dependency matrix and the
# names of the output folders (benches, env, results, html).
# In the matrix, onnx, onnxruntime, onnxconverter-common, skl2onnx and
# scikit-learn hold the URL of a package index served on localhost:8067,
# every other dependency maps to an empty list.
# NOTE(review): presumably callers copy this dictionary and serialize it
# into the asv configuration file -- confirm against the caller.
22default_asv_conf = {
23 "version": 1,
24 "project": "mlprodict",
25 "project_url": "http://www.xavierdupre.fr/app/mlprodict/helpsphinx/index.html",
26 "repo": "https://github.com/sdpython/mlprodict.git",
27 "repo_subdir": "",
28 "install_command": ["python -mpip install {wheel_file}"],
29 "uninstall_command": ["return-code=any python -mpip uninstall -y {project}"],
30 "build_command": [
31 "python setup.py build",
32 "PIP_NO_BUILD_ISOLATION=false python -mpip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}"
33 ],
34 "branches": ["master"],
35 "environment_type": "virtualenv",
36 "install_timeout": 600,
37 "show_commit_url": "https://github.com/sdpython/mlprodict/commit/",
38 # "pythons": ["__PYVER__"],
39 "matrix": {
40 "cython": [],
41 "jinja2": [],
42 "joblib": [],
43 "lightgbm": [],
44 "mlinsights": [],
45 "numpy": [],
46 "onnx": ["http://localhost:8067/simple/"],
47 "onnxruntime": ["http://localhost:8067/simple/"],
48 "pandas": [],
49 "Pillow": [],
50 "pybind11": [],
51 "pyquickhelper": [],
52 "scipy": [],
53 # "git+https://github.com/xadupre/onnxconverter-common.git@jenkins"],
54 "onnxconverter-common": ["http://localhost:8067/simple/"],
55 # "git+https://github.com/xadupre/sklearn-onnx.git@jenkins"],
56 "skl2onnx": ["http://localhost:8067/simple/"],
57 # "git+https://github.com/scikit-learn/scikit-learn.git"],
58 "scikit-learn": ["http://localhost:8067/simple/"],
59 "xgboost": [],
60 },
61 "benchmark_dir": "benches",
62 "env_dir": "env",
63 "results_dir": "results",
64 "html_dir": "html",
65}
# Source code of a small Flask application, stored as a string.
# The embedded script resolves its root folder as "../html" relative to
# its own location, answers '/' with index.html, answers any other path
# with the matching file (mimetype picked from the extension, text/html
# by default) and starts the server on port 8877 when run as a script.
# NOTE(review): '/' is registered twice (mainpage and get_resource with
# its default path), and get_resource joins the requested path to the
# root folder without any containment check -- confirm this server is
# only meant for local use.
# NOTE(review): how this string is written to disk is not visible here.
67flask_helper = """
68'''
69Local ASV files do no properly render in a browser,
70it needs to be served through a server.
71'''
72import os.path
73from flask import Flask, Response
75app = Flask(__name__)
76app.config.from_object(__name__)
79def root_dir():
80 return os.path.join(os.path.abspath(os.path.dirname(__file__)), "..", "html")
83def get_file(filename): # pragma: no cover
84 try:
85 src = os.path.join(root_dir(), filename)
86 with open(src, "r", encoding="utf-8", errors="ignore") as f:
87 return f.read()
88 except IOError as exc:
89 return str(exc)
92@app.route('/', methods=['GET'])
93def mainpage():
94 content = get_file('index.html')
95 return Response(content, mimetype="text/html")
98@app.route('/', defaults={'path': ''})
99@app.route('/<path:path>')
100def get_resource(path): # pragma: no cover
101 mimetypes = {
102 ".css": "text/css",
103 ".html": "text/html",
104 ".js": "application/javascript",
105 }
106 complete_path = os.path.join(root_dir(), path)
107 ext = os.path.splitext(path)[1]
108 mimetype = mimetypes.get(ext, "text/html")
109 content = get_file(complete_path)
110 return Response(content, mimetype=mimetype)
113if __name__ == '__main__': # pragma: no cover
114 app.run( # ssl_context=('cert.pem', 'key.pem'),
115 port=8877,
116 # host="",
117 )
118"""
# Template (stored as a string) of a script which profiles one benchmark
# class. __PATH__, __PYFOLD__ and __CLASSNAME__ are placeholders,
# presumably substituted by the code which instantiates the template
# (not visible here -- confirm against the caller).
# In the embedded script:
#   * start() builds the benchmark object and calls setup_cache(),
#   * profile0() times 100 calls to time_predict and returns an iteration
#     count scaled to roughly 25 seconds, never below 100,
#   * profile() calls time_predict `iter` times,
#   * the setup_* variants call cl.setup(...) before delegating.
# The template ends after creating `cl` and printing a "begin" timestamp.
120pyspy_template = """
121import sys
122sys.path.append(r"__PATH__")
123from __PYFOLD__ import __CLASSNAME__
124import time
125from datetime import datetime
128def start():
129 cl = __CLASSNAME__()
130 cl.setup_cache()
131 return cl
134def profile0(iter, cl, runtime, N, nf, opset, dtype, optim):
135 begin = time.perf_counter()
136 for i in range(0, 100):
137 cl.time_predict(runtime, N, nf, opset, dtype, optim)
138 duration = time.perf_counter() - begin
139 iter = max(100, int(25 / duration * 100)) # 25 seconds
140 return iter
143def setup_profile0(iter, cl, runtime, N, nf, opset, dtype, optim):
144 cl.setup(runtime, N, nf, opset, dtype, optim)
145 return profile0(iter, cl, runtime, N, nf, opset, dtype, optim)
148def profile(iter, cl, runtime, N, nf, opset, dtype, optim):
149 for i in range(iter):
150 cl.time_predict(runtime, N, nf, opset, dtype, optim)
151 return iter
154def setup_profile(iter, cl, runtime, N, nf, opset, dtype, optim):
155 cl.setup(runtime, N, nf, opset, dtype, optim)
156 return profile(iter, cl, runtime, N, nf, opset, dtype, optim)
159cl = start()
160iter = None
161print(datetime.now(), "begin")
162"""
def _sklearn_subfolder(model):
    """
    Returns the list of subfolders for a model.

    @param      model       model class, only ``__module__`` and
                            ``__name__`` are used
    @return                 list of folder names, the last one is
                            the class name
    @raises     ValueError  if the class has no module, if ``sklearn``
                            is not part of the module name or if no
                            subfolder can be guessed from it
    """
    mod = model.__module__
    if mod is None:
        # Without a module name there is nothing to split below.
        raise ValueError(  # pragma: no cover
            "Unable to guess subfolder for '{}' (no module).".format(model))
    for prefix in ('mlinsights', 'skl2onnx.sklapi'):
        if mod.startswith(prefix):
            return [prefix, model.__name__]  # pragma: no cover
    spl = mod.split('.')
    try:
        pos = spl.index('sklearn')
    except ValueError as e:  # pragma: no cover
        raise ValueError(
            "Unable to find 'sklearn' in '{}'.".format(mod)) from e
    # Components between 'sklearn' and the last submodule.
    res = spl[pos + 1: -1]
    if not res:
        if spl[-1] == 'sklearn':
            res = ['_externals']
        elif spl[0] == 'sklearn':
            res = spl[pos + 1:]
        else:
            # *model* is a class: formatting model.__class__ would only
            # show the metaclass, the class itself is the useful value.
            raise ValueError(  # pragma: no cover
                "Unable to guess subfolder for '{}'.".format(model))
    res.append(model.__name__)
    return res
def _handle_init_files(model, flat, location, verbose, location_pyspy, fLOG):
    """
    Creates the folder receiving the benchmark of a model and the
    missing ``__init__.py`` files around it.

    @param      model           model class
    @param      flat            if True, nothing is created and
                                *location* is used as is
    @param      location        root folder of the benchmarks
    @param      verbose         a message is logged for every created
                                file if greater than 1
    @param      location_pyspy  optional second root folder in which the
                                same subfolders are created, None to skip
    @param      fLOG            logging function or None
    @return                     created, location_model, prefix_import,
                                location_pyspy_model
    """
    if flat:
        return [], location, ".", location_pyspy

    created = []
    # Private subfolders are skipped except '_externals'.
    subf = [_ for _ in _sklearn_subfolder(model)
            if _[0] != '_' or _ == '_externals']
    location_model = os.path.join(location, *subf)
    prefix_import = "." * (len(subf) + 1)
    os.makedirs(location_model, exist_ok=True)

    # The model folder, its parent and its grandparent get an __init__.py.
    parent = os.path.dirname(location_model)
    for fold in (location_model, parent, os.path.dirname(parent)):
        init = os.path.join(fold, '__init__.py')
        if os.path.exists(init):
            continue
        with open(init, 'w'):
            pass
        created.append(init)
        if verbose > 1 and fLOG is not None:
            fLOG("[create_asv_benchmark] create '{}'.".format(init))

    if location_pyspy is None:
        location_pyspy_model = None
    else:
        location_pyspy_model = os.path.join(location_pyspy, *subf)
        os.makedirs(location_pyspy_model, exist_ok=True)

    return created, location_model, prefix_import, location_pyspy_model
def _asv_class_name(model, scenario, optimisation,
                    extra, dofit, conv_options, problem,
                    shorten=True):
    """
    Builds the name of a benchmark from the model class and
    the options of the scenario.

    @param      model           model class (``__name__`` is used)
    @param      scenario        scenario name
    @param      optimisation    optimisation (value or list), may be empty
    @param      extra           constructor parameters, may be empty
    @param      dofit           adds ``nofit`` to the name if False
    @param      conv_options    conversion options, may be empty
    @param      problem         problem name
    @param      shorten         replaces long words by abbreviations and
                                truncates names longer than 70 characters
    @return                     name
    """
    # ',', '-' and line breaks become '_', the third set is removed.
    table = str.maketrans(",-\n", "___", ": =.+()[]{}\"'<>~")

    def clean_str(val):
        text = str(val).translate(table)
        text = text.replace('n_estimators', 'nest')
        return text.replace('max_iter', 'mxit')

    def clean_str_list(val):
        if val is None:
            return ""  # pragma: no cover
        if isinstance(val, list):
            return ".".join(  # pragma: no cover
                clean_str_list(v) for v in val if v)
        return clean_str(val)

    parts = ['bench', model.__name__, scenario, clean_str(problem)]
    if not dofit:
        parts.append('nofit')
    if extra:
        if 'random_state' in extra and extra['random_state'] == 42:
            # the default seed is not part of the name
            trimmed = extra.copy()
            del trimmed['random_state']
            if trimmed:
                parts.append(clean_str(trimmed))
        else:
            parts.append(clean_str(extra))
    if optimisation:
        parts.append(clean_str_list(optimisation))
    if conv_options:
        parts.append(clean_str_list(conv_options))
    res = ".".join(parts).replace("-", "_")

    if not shorten:
        return res

    # Replacements are applied one after the other, in this order.
    abbreviations = (
        ('ConstantKernel', 'Cst'),
        ('DotProduct', 'Dot'),
        ('Exponentiation', 'Exp'),
        ('ExpSineSquared', 'ExpS2'),
        ('GaussianProcess', 'GaussProc'),
        ('GaussianMixture', 'GaussMixt'),
        ('HistGradientBoosting', 'HGB'),
        ('LinearRegression', 'LinReg'),
        ('LogisticRegression', 'LogReg'),
        ('MultiOutput', 'MultOut'),
        ('OrthogonalMatchingPursuit', 'OrthMatchPurs'),
        ('PairWiseKernel', 'PW'),
        ('Product', 'Prod'),
        ('RationalQuadratic', 'RQ'),
        ('WhiteKernel', 'WK'),
        ('length_scale', 'ls'),
        ('periodicity', 'pcy'),
        ('Classifier', 'Clas'),
        ('Regressor', 'Reg'),
        ('KNeighbors', 'KNN'),
        ('NearestNeighbors', 'kNN'),
        ('RadiusNeighbors', 'RadNN'),
    )
    for long_name, short_name in abbreviations:
        res = res.replace(long_name, short_name)

    if len(res) > 70:  # shorten filename
        digest = hashlib.sha256(res.encode('utf-8')).hexdigest()
        res = res[:70] + digest[:6]
    return res
def _read_patterns():
    """
    Reads the benchmark templates stored in subfolder *template*
    next to this file.

    @return     dictionary ``{suffix: code}``, *code* is the text
                which follows the second triple double quote
                of the template
    @raises     FileNotFoundError if a template is missing
    """
    folder = os.path.join(os.path.dirname(__file__), "template")
    suffixes = ('classifier', 'classifier_raw_scores', 'regressor',
                'clustering', 'outlier', 'trainable_transform',
                'transform', 'multi_classifier', 'transform_positive')
    patterns = {}
    for suffix in suffixes:
        template_name = os.path.join(folder, "skl_model_%s.py" % suffix)
        if not os.path.exists(template_name):
            raise FileNotFoundError(  # pragma: no cover
                "Template '{}' was not found.".format(template_name))
        with open(template_name, "r", encoding="utf-8") as f:
            content = f.read()
        # Keeps what follows the second delimiter, an empty string
        # if the template holds fewer than two of them.
        pieces = content.split('"""', 2)
        patterns[suffix] = pieces[2] if len(pieces) == 3 else ""
    return patterns
def _select_pattern_problem(prob, patterns):
    """
    Picks the benchmark template matching a problem name.

    @param      prob        problem name
    @param      patterns    dictionary ``{kind: template}``
    @return                 the template of the first matching kind
    @raises     ValueError  if no kind matches the problem
    """
    # Rules are tried in order: every marker of a rule must be
    # a substring of the problem name.
    rules = (
        (('-reg',), 'regressor'),
        (('-cl', '-dec'), 'classifier_raw_scores'),
        (('-cl',), 'classifier'),
        (('cluster',), 'clustering'),
        (('outlier',), 'outlier'),
        (('num+y-tr',), 'trainable_transform'),
        (('num-tr-pos',), 'transform_positive'),
        (('num-tr',), 'transform'),
        (('m-label',), 'multi_classifier'),
    )
    for markers, kind in rules:
        if all(marker in prob for marker in markers):
            return patterns[kind]
    raise ValueError(  # pragma: no cover
        "Unable to guess the right pattern for '{}'.".format(prob))
def _display_code_lines(code):
    """
    Prefixes every line of *code* with its line number
    (starting at 1, padded with zeros to three digits).

    @param      code    text split on line feeds
    @return             numbered text
    """
    return "\n".join(
        "{:03d} {}".format(number, text)
        for number, text in enumerate(code.split("\n"), start=1))
def _format_dict(opts, indent):
    """
    Formats a dictionary as a list of keyword arguments
    ``key=repr(value)`` sorted by key, wrapped and indented.

    @param      opts    dictionary
    @param      indent  number of spaces added before every line
    @return             code as a string
    """
    pairs = ['%s=%r' % (key, value) for key, value in sorted(opts.items())]
    wrapped_lines = textwrap.wrap(', '.join(pairs))
    return textwrap.indent("\n".join(wrapped_lines), prefix=' ' * indent)
def _additional_imports(model_name):
    """
    Returns the extra import lines an experimental model requires.

    @param      model_name  name of the model class
    @return                 list of import lines or None
                            if the model does not need any
    """
    if model_name != 'IterativeImputer':
        return None
    return ["from sklearn.experimental import enable_iterative_imputer # pylint: disable=W0611"]
# Specializes a benchmark template for one model class: the import of the
# model is inserted after the "# Import specific to this model." marker,
# method `_create_model` is regenerated and, when an optimisation is
# given, a method `_optimize_onnx` is appended.
# The function returns a tuple (new source code, list of `par_*`
# assignment strings), not only the modified template as the return tag
# of the docstring suggests.
# NOTE(review): `conv_options` is accepted but never read in this body.
386def add_model_import_init(
387 class_content, model, optimisation=None,
388 extra=None, conv_options=None):
389 """
390 Modifies a template such as @see cl TemplateBenchmarkClassifier
391 with code associated to the model *model*.
393 @param class_content template (as a string)
394 @param model model class
395 @param optimisation model optimisation
396 @param extra addition parameter to the constructor
397 @param conv_options options for the conversion to ONNX
398 @returm modified template
399 """
400 add_imports = []
401 add_methods = []
    # source lines assigning the class parameters, returned to the caller
402 add_params = ["par_modelname = '%s'" % model.__name__,
403 "par_extra = %r" % extra]
405 # additional methods and imports
406 if optimisation is not None:
407 add_imports.append(
408 'from mlprodict.onnx_tools.optim import onnx_optimisations')
    # 'onnx' generates a call to onnx_optimisations without options,
    # a dictionary is forwarded through the class parameter par_optims,
    # anything else is rejected
409 if optimisation == 'onnx':
410 add_methods.append(textwrap.dedent('''
411 def _optimize_onnx(self, onx):
412 return onnx_optimisations(onx)'''))
413 add_params.append('par_optimonnx = True')
414 elif isinstance(optimisation, dict):
415 add_methods.append(textwrap.dedent('''
416 def _optimize_onnx(self, onx):
417 return onnx_optimisations(onx, self.par_optims)'''))
418 add_params.append('par_optims = {}'.format(
419 _format_dict(optimisation, indent=4)))
420 else:
421 raise ValueError( # pragma: no cover
422 "Unable to interpret optimisation {}.".format(optimisation))
424 # look for import place
425 lines = class_content.split('\n')
426 keep = None
427 for pos, line in enumerate(lines):
428 if "# Import specific to this model." in line:
429 keep = pos
430 break
431 if keep is None:
432 raise RuntimeError( # pragma: no cover
433 "Unable to locate where to insert import in\n{}\n".format(
434 class_content))
436 # imports
    # module to import the model from: the module of the class as is when
    # 'sklearn' is not one of its components; when the module starts with
    # 'sklearn', a last component beginning with '_' is dropped; when
    # 'sklearn' appears further in the name, the last component is dropped
437 loc_class = model.__module__
438 sub = loc_class.split('.')
439 if 'sklearn' not in sub:
440 mod = loc_class
441 else:
442 skl = sub.index('sklearn')
443 if skl == 0:
444 if sub[-1].startswith("_"):
445 mod = '.'.join(sub[skl:-1])
446 else:
447 mod = '.'.join(sub[skl:])
448 else:
449 mod = '.'.join(sub[:-1])
451 exp_imports = _additional_imports(model.__name__)
452 if exp_imports:
453 add_imports.extend(exp_imports)
    # the generated import falls back to None if the model cannot be imported
454 imp_inst = (
455 "try:\n from {0} import {1}\nexcept ImportError:\n {1} = None"
456 "".format(mod, model.__name__))
457 add_imports.append(imp_inst)
458 add_imports.append("# __IMPORTS__")
    # the line following the marker is overwritten by the import block,
    # which itself ends with a '# __IMPORTS__' line
459 lines[keep + 1] = "\n".join(add_imports)
460 content = "\n".join(lines)
462 # _create_model
    # everything from the first 'def _create_model(self):' of the template
    # is dropped and the method is written again below
463 content = content.split('def _create_model(self):',
464 maxsplit=1)[0].strip(' \n')
465 lines = [content, "", " def _create_model(self):"]
466 if extra is not None and len(extra) > 0:
467 lines.append(" return {}(".format(model.__name__))
    # set_n_jobs is imported from sklearn_helper, what it changes
    # in `extra` is not visible here
468 lines.append(_format_dict(set_n_jobs(model, extra), 12))
469 lines.append(" )")
470 else:
471 lines.append(" return {}()".format(model.__name__))
472 lines.append("")
474 # methods
475 for meth in add_methods:
476 lines.append(textwrap.indent(meth, ' '))
477 lines.append('')
479 # end
480 return "\n".join(lines), add_params
def find_missing_sklearn_imports(pieces):
    """
    Builds the import statements for a list of
    :epkg:`scikit-learn` names.

    @param      pieces      list of names in scikit-learn
    @return                 list of import lines, one per module,
                            names sorted inside every line
    """
    by_module = {}
    for piece in pieces:
        by_module.setdefault(find_sklearn_module(piece), []).append(piece)

    return ["from {} import {}".format(module, ", ".join(sorted(names)))
            for module, names in by_module.items()]
def find_sklearn_module(piece):
    """
    Finds the corresponding module for an element of :epkg:`scikit-learn`.

    @param      piece       name to import
    @return                 module name
    @raises     ValueError  if *piece* is not part of the white list

    The implementation is not intelligent and should
    be improved. It is a kind of white list.
    As a side effect, the imported object is stored in the
    globals of this module under the name *piece*.
    """
    import importlib

    # white list: module name -> names it is known to expose
    white_list = {
        "sklearn.linear_model": (
            'LinearRegression', 'LogisticRegression', 'SGDClassifier'),
        "sklearn.tree": (
            'DecisionTreeRegressor', 'DecisionTreeClassifier'),
        "sklearn.gaussian_process.kernels": (
            'ExpSineSquared', 'DotProduct', 'RationalQuadratic', 'RBF'),
        "sklearn.svm": (
            'LinearSVC', 'LinearSVR', 'NuSVR', 'SVR', 'SVC', 'NuSVC'),
        "sklearn.cluster": ('KMeans', ),
        "sklearn.multiclass": (
            'OneVsRestClassifier', 'OneVsOneClassifier'),
    }
    for module_name, names in white_list.items():
        if piece in names:
            # the module is only imported once the name is recognized
            module = importlib.import_module(module_name)
            globals()[piece] = getattr(module, piece)
            return module_name
    raise ValueError(  # pragma: no cover
        "Unable to find module to import for '{}'.".format(piece))