Coverage for mlprodict/asv_benchmark/create_asv.py: 98%
1"""
2@file Functions to creates a benchmark based on :epkg:`asv`
3for many regressors and classifiers.
4"""
import os
import sys
import json
import textwrap
import warnings
import re
from pyquickhelper.pycode.code_helper import remove_extra_spaces_and_pep8
try:
    from ._create_asv_helper import (
        default_asv_conf,
        flask_helper,
        pyspy_template,
        _handle_init_files,
        _asv_class_name,
        _read_patterns,
        _select_pattern_problem,
        _display_code_lines,
        add_model_import_init,
        find_missing_sklearn_imports)
except ImportError:  # pragma: no cover
    from mlprodict.asv_benchmark._create_asv_helper import (
        default_asv_conf,
        flask_helper,
        pyspy_template,
        _handle_init_files,
        _asv_class_name,
        _read_patterns,
        _select_pattern_problem,
        _display_code_lines,
        add_model_import_init,
        find_missing_sklearn_imports)

try:
    from .. import __max_supported_opset__
    from ..tools.asv_options_helper import (
        shorten_onnx_options)
    from ..onnxrt.validate.validate_helper import sklearn_operators
    from ..onnxrt.validate.validate import (
        _retrieve_problems_extra, _get_problem_data, _merge_options)
except (ValueError, ImportError):  # pragma: no cover
    from mlprodict import __max_supported_opset__
    from mlprodict.onnxrt.validate.validate_helper import sklearn_operators
    from mlprodict.onnxrt.validate.validate import (
        _retrieve_problems_extra, _get_problem_data, _merge_options)
    from mlprodict.tools.asv_options_helper import shorten_onnx_options
try:
    from ..testing.verify_code import verify_code
except (ValueError, ImportError):  # pragma: no cover
    from mlprodict.testing.verify_code import verify_code

# exec function does not import models but potentially
# requires all specific models used to define scenarios
try:
    from ..onnxrt.validate.validate_scenarios import *  # pylint: disable=W0614,W0401
except (ValueError, ImportError):  # pragma: no cover
    # Skips this step if used in a benchmark.
    pass


def create_asv_benchmark(
        location, opset_min=-1, opset_max=None,
        runtime=('scikit-learn', 'python_compiled'), models=None,
        skip_models=None, extended_list=True,
        dims=(1, 10, 100, 10000),
        n_features=(4, 20), dtype=None,
        verbose=0, fLOG=print, clean=True,
        conf_params=None, filter_exp=None,
        filter_scenario=None, flat=False,
        exc=False, build=None, execute=False,
        add_pyspy=False, env=None,
        matrix=None):
76 """
77 Creates an :epkg:`asv` benchmark in a folder
78 but does not run it.
80 :param n_features: number of features to try
81 :param dims: number of observations to try
82 :param verbose: integer from 0 (None) to 2 (full verbose)
83 :param opset_min: tries every conversion from this minimum opset,
84 -1 to get the current opset defined by module :epkg:`onnx`
85 :param opset_max: tries every conversion up to maximum opset,
86 -1 to get the current opset defined by module :epkg:`onnx`
87 :param runtime: runtime to check, *scikit-learn*, *python*,
88 *python_compiled* compiles the graph structure
89 and is more efficient when the number of observations is
90 small, *onnxruntime1* to check :epkg:`onnxruntime`,
91 *onnxruntime2* to check every ONNX node independently
92 with onnxruntime, many runtime can be checked at the same time
93 if the value is a comma separated list
94 :param models: list of models to test or empty
95 string to test them all
96 :param skip_models: models to skip
97 :param extended_list: extends the list of :epkg:`scikit-learn` converters
98 with converters implemented in this module
99 :param n_features: change the default number of features for
100 a specific problem, it can also be a comma separated list
101 :param dtype: '32' or '64' or None for both,
102 limits the test to one specific number types
103 :param fLOG: logging function
104 :param clean: clean the folder first, otherwise overwrites the content
105 :param conf_params: to overwrite some of the configuration parameters
106 :param filter_exp: function which tells if the experiment must be run,
107 None to run all, takes *model, problem* as an input
108 :param filter_scenario: second function which tells if the experiment must be run,
109 None to run all, takes *model, problem, scenario, extra*
110 as an input
111 :param flat: one folder for all files or subfolders
112 :param exc: if False, raises warnings instead of exceptions
113 whenever possible
114 :param build: where to put the outputs
115 :param execute: execute each script to make sure
116 imports are correct
117 :param add_pyspy: add an extra folder with code to profile
118 each configuration
119 :param env: None to use the default configuration or ``same`` to use
120 the current one
121 :param matrix: specifies versions for a module,
122 example: ``{'onnxruntime': ['1.1.1', '1.1.2']}``,
123 if a package name starts with `'~'`, the package is removed
124 :return: created files
126 The default configuration is the following:
128 .. runpython::
129 :showcode:
130 :warningout: DeprecationWarning
132 import pprint
133 from mlprodict.asv_benchmark.create_asv import default_asv_conf
135 pprint.pprint(default_asv_conf)
137 The benchmark does not seem to work well with setting
138 ``-environment existing:same``. The publishing fails.
139 """
    if opset_min == -1:
        opset_min = __max_supported_opset__
    if opset_max == -1:
        opset_max = __max_supported_opset__  # pragma: no cover
    if verbose > 0 and fLOG is not None:  # pragma: no cover
        fLOG("[create_asv_benchmark] opset in [{}, {}].".format(
            opset_min, opset_max))

    # creates the folder if it does not exist.
    if not os.path.exists(location):
        if verbose > 0 and fLOG is not None:  # pragma: no cover
            fLOG("[create_asv_benchmark] create folder '{}'.".format(location))
        os.makedirs(location)  # pragma: no cover

    location_test = os.path.join(location, 'benches')
    if not os.path.exists(location_test):
        if verbose > 0 and fLOG is not None:
            fLOG("[create_asv_benchmark] create folder '{}'.".format(location_test))
        os.mkdir(location_test)

    # Cleans the content of the folder
    created = []
    if clean:
        for name in os.listdir(location_test):
            full_name = os.path.join(location_test, name)  # pragma: no cover
            if os.path.isfile(full_name):  # pragma: no cover
                os.remove(full_name)

    # configuration
    conf = default_asv_conf.copy()
    if conf_params is not None:
        for k, v in conf_params.items():
            conf[k] = v
    if build is not None:
        for fi in ['env_dir', 'results_dir', 'html_dir']:  # pragma: no cover
            conf[fi] = os.path.join(build, conf[fi])
    if env == 'same':
        if matrix is not None:
            raise ValueError(  # pragma: no cover
                "Parameter matrix must be None if env is 'same'.")
        conf['pythons'] = ['same']
        conf['matrix'] = {}
    elif matrix is not None:
        drop_keys = set(p for p in matrix if p.startswith('~'))
        matrix = {k: v for k, v in matrix.items() if k not in drop_keys}
        conf['matrix'] = {k: v for k, v in conf['matrix'].items()
                          if k not in drop_keys}
        conf['matrix'].update(matrix)
    elif env is not None:
        raise ValueError(  # pragma: no cover
            "Unable to handle env='{}'.".format(env))
    dest = os.path.join(location, "asv.conf.json")
    created.append(dest)
    with open(dest, "w", encoding='utf-8') as f:
        json.dump(conf, f, indent=4)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'asv.conf.json'.")

    # __init__.py
    dest = os.path.join(location, "__init__.py")
    with open(dest, "w", encoding='utf-8') as f:
        pass
    created.append(dest)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create '__init__.py'.")
    dest = os.path.join(location_test, '__init__.py')
    with open(dest, "w", encoding='utf-8') as f:
        pass
    created.append(dest)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'benches/__init__.py'.")

    # flask_server
    tool_dir = os.path.join(location, 'tools')
    if not os.path.exists(tool_dir):
        os.mkdir(tool_dir)
    fl = os.path.join(tool_dir, 'flask_serve.py')
    with open(fl, "w", encoding='utf-8') as f:
        f.write(flask_helper)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'flask_serve.py'.")

    # command line
    if sys.platform.startswith("win"):
        run_bash = os.path.join(tool_dir, 'run_asv.bat')  # pragma: no cover
    else:
        run_bash = os.path.join(tool_dir, 'run_asv.sh')
    with open(run_bash, 'w') as f:
        f.write(textwrap.dedent("""
            echo --BENCHRUN--
            python -m asv run --show-stderr --config ./asv.conf.json
            echo --PUBLISH--
            python -m asv publish --config ./asv.conf.json -o ./html
            echo --CSV--
            python -m mlprodict asv2csv -f ./results -o ./data_bench.csv
            """))

    # pyspy
    if add_pyspy:
        dest_pyspy = os.path.join(location, 'pyspy')
        if not os.path.exists(dest_pyspy):
            os.mkdir(dest_pyspy)
    else:
        dest_pyspy = None

    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create all tests.")

    created.extend(list(_enumerate_asv_benchmark_all_models(
        location_test, opset_min=opset_min, opset_max=opset_max,
        runtime=runtime, models=models,
        skip_models=skip_models, extended_list=extended_list,
        n_features=n_features, dtype=dtype,
        verbose=verbose, filter_exp=filter_exp,
        filter_scenario=filter_scenario,
        dims=dims, exc=exc, flat=flat,
        fLOG=fLOG, execute=execute,
        dest_pyspy=dest_pyspy)))

    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] done.")
    return created


def _enumerate_asv_benchmark_all_models(  # pylint: disable=R0914
        location, opset_min=10, opset_max=None,
        runtime=('scikit-learn', 'python'), models=None,
        skip_models=None, extended_list=True,
        n_features=None, dtype=None,
        verbose=0, filter_exp=None,
        dims=None, filter_scenario=None,
        exc=True, flat=False, execute=False,
        dest_pyspy=None, fLOG=print):
273 """
274 Loops over all possible models and fills a folder
275 with benchmarks following :epkg:`asv` concepts.
277 :param n_features: number of features to try
278 :param dims: number of observations to try
279 :param verbose: integer from 0 (None) to 2 (full verbose)
280 :param opset_min: tries every conversion from this minimum opset
281 :param opset_max: tries every conversion up to maximum opset
282 :param runtime: runtime to check, *scikit-learn*, *python*,
283 *onnxruntime1* to check :epkg:`onnxruntime`,
284 *onnxruntime2* to check every ONNX node independently
285 with onnxruntime, many runtime can be checked at the same time
286 if the value is a comma separated list
287 :param models: list of models to test or empty
288 string to test them all
289 :param skip_models: models to skip
290 :param extended_list: extends the list of :epkg:`scikit-learn` converters
291 with converters implemented in this module
292 :param n_features: change the default number of features for
293 a specific problem, it can also be a comma separated list
294 :param dtype: '32' or '64' or None for both,
295 limits the test to one specific number types
296 :param fLOG: logging function
297 :param filter_exp: function which tells if the experiment must be run,
298 None to run all, takes *model, problem* as an input
299 :param filter_scenario: second function which tells if the experiment must be run,
300 None to run all, takes *model, problem, scenario, extra*
301 as an input
302 :param exc: if False, raises warnings instead of exceptions
303 whenever possible
304 :param flat: one folder for all files or subfolders
305 :param execute: execute each script to make sure
306 imports are correct
307 :param dest_pyspy: add a file to profile the prediction
308 function with :epkg:`pyspy`
309 """

    ops = [_ for _ in sklearn_operators(extended=extended_list)]
    patterns = _read_patterns()

    if models is not None:
        if not all(map(lambda m: isinstance(m, str), models)):
            raise ValueError(
                "models must be a set of strings.")  # pragma: no cover
        ops_ = [_ for _ in ops if _['name'] in models]
        if len(ops_) == 0:
            raise ValueError("Parameter models is wrong: {}\n{}".format(  # pragma: no cover
                models, ops[0]))
        ops = ops_
    if skip_models is not None:
        ops = [m for m in ops if m['name'] not in skip_models]

    if verbose > 0:

        def iterate():
            for i, row in enumerate(ops):  # pragma: no cover
                fLOG("{}/{} - {}".format(i + 1, len(ops), row))
                yield row

        if verbose >= 11:
            verbose -= 10  # pragma: no cover
            loop = iterate()  # pragma: no cover
        else:
            try:
                from tqdm import trange

                def iterate_tqdm():
                    with trange(len(ops)) as t:
                        for i in t:
                            row = ops[i]
                            disp = row['name'] + " " * (28 - len(row['name']))
                            t.set_description("%s" % disp)
                            yield row

                loop = iterate_tqdm()

            except ImportError:  # pragma: no cover
                loop = iterate()
    else:
        loop = ops

    if opset_max is None:
        opset_max = __max_supported_opset__
    opsets = list(range(opset_min, opset_max + 1))
    all_created = set()

    # loop on all models
    for row in loop:

        model = row['cl']

        problems, extras = _retrieve_problems_extra(
            model, verbose, fLOG, extended_list)
        if extras is None or problems is None:
            # Not tested yet.
            continue  # pragma: no cover

        # flat or not flat
        created, location_model, prefix_import, dest_pyspy_model = _handle_init_files(
            model, flat, location, verbose, dest_pyspy, fLOG)
        for init in created:
            yield init

        # loops on problems
        for prob in problems:
            if filter_exp is not None and not filter_exp(model, prob):
                continue

            (X_train, X_test, y_train,
             y_test, Xort_test,
             init_types, conv_options, method_name,
             output_index, dofit, predict_kwargs) = _get_problem_data(prob, None)

            for scenario_extra in extras:
                subset_problems = None
                optimisations = None
                new_conv_options = None

                if len(scenario_extra) > 2:
                    options = scenario_extra[2]
                    if isinstance(options, dict):
                        subset_problems = options.get('subset_problems', None)
                        optimisations = options.get('optim', None)
                        new_conv_options = options.get('conv_options', None)
                    else:
                        subset_problems = options

                if subset_problems and isinstance(subset_problems, (list, set)):
                    if prob not in subset_problems:
                        # Skips unrelated problem for a specific configuration.
                        continue
                elif subset_problems is not None:
                    raise RuntimeError(  # pragma: no cover
                        "subset_problems must be a set or a list not {}.".format(
                            subset_problems))

                scenario, extra = scenario_extra[:2]
                if optimisations is None:
                    optimisations = [None]
                if new_conv_options is None:
                    new_conv_options = [{}]

                if (filter_scenario is not None and
                        not filter_scenario(model, prob, scenario,
                                            extra, new_conv_options)):
                    continue  # pragma: no cover

                if verbose >= 3 and fLOG is not None:
                    fLOG("[create_asv_benchmark] model={} scenario={} optim={} "
                         "extra={} dofit={} (problem={} method_name='{}')".format(
                             model.__name__, scenario, optimisations, extra,
                             dofit, prob, method_name))
                created = _create_asv_benchmark_file(
                    location_model, opsets=opsets,
                    model=model, scenario=scenario, optimisations=optimisations,
                    extra=extra, dofit=dofit, problem=prob,
                    runtime=runtime, new_conv_options=new_conv_options,
                    X_train=X_train, X_test=X_test, y_train=y_train,
                    y_test=y_test, Xort_test=Xort_test,
                    init_types=init_types, conv_options=conv_options,
                    method_name=method_name, dims=dims, n_features=n_features,
                    output_index=output_index, predict_kwargs=predict_kwargs,
                    exc=exc, prefix_import=prefix_import,
                    execute=execute, location_pyspy=dest_pyspy_model,
                    patterns=patterns)
                for cr in created:
                    if cr in all_created:
                        raise RuntimeError(  # pragma: no cover
                            "File '{}' was already created.".format(cr))
                    all_created.add(cr)
                    if verbose > 1 and fLOG is not None:
                        fLOG("[create_asv_benchmark] add '{}'.".format(cr))
                    yield cr


def _create_asv_benchmark_file(  # pylint: disable=R0914
        location, model, scenario, optimisations, new_conv_options,
        extra, dofit, problem, runtime, X_train, X_test, y_train,
        y_test, Xort_test, init_types, conv_options,
        method_name, n_features, dims, opsets,
        output_index, predict_kwargs, prefix_import,
        exc, execute=False, location_pyspy=None, patterns=None):
454 """
455 Creates a benchmark file based in the information received
456 through the argument. It uses one of the templates
457 like @see cl TemplateBenchmarkClassifier or
458 @see cl TemplateBenchmarkRegressor.
459 """
    if patterns is None:
        raise ValueError("Patterns list is empty.")  # pragma: no cover

    def format_conv_options(d_options, class_name):
        if d_options is None:
            return None
        res = {}
        for k, v in d_options.items():
            if isinstance(k, type):
                if "." + class_name + "'" in str(k):
                    res[class_name] = v
                    continue
                raise ValueError(  # pragma: no cover
                    "Class '{}', unable to format options {}".format(
                        class_name, d_options))
            res[k] = v
        return res

    def _nick_name_options(model, opts):
        # Shorten common onnx options, see _CommonAsvSklBenchmark._to_onnx.
        if opts is None:
            return opts  # pragma: no cover
        short_opts = shorten_onnx_options(model, opts)
        if short_opts is not None:
            return short_opts
        res = {}
        for k, v in opts.items():
            if hasattr(k, '__name__'):
                res["####" + k.__name__ + "####"] = v
            else:
                res[k] = v  # pragma: no cover
        return res

    def _make_simple_name(name):
        simple_name = name.replace("bench_", "").replace("_bench", "")
        simple_name = simple_name.replace("bench.", "").replace(".bench", "")
        simple_name = simple_name.replace(".", "-")
        repl = {'_': '', 'solverliblinear': 'liblinear'}
        for k, v in repl.items():
            simple_name = simple_name.replace(k, v)
        return simple_name
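    # Assumed illustration (not taken from the templates):
    # _make_simple_name("bench.LogisticRegression.solverliblinear")
    # returns "LogisticRegression-liblinear".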

    def _optdict2string(opt):
        if isinstance(opt, str):
            return opt
        if isinstance(opt, list):
            raise TypeError(
                "Unable to process type %r." % type(opt))
        reps = {True: 1, False: 0, 'zipmap': 'zm',
                'optim': 'opt'}
        info = []
        for k, v in sorted(opt.items()):
            if isinstance(v, dict):
                v = _optdict2string(v)
            if k.startswith('####'):
                k = ''
            i = '{}{}'.format(reps.get(k, k), reps.get(v, v))
            info.append(i)
        return "-".join(info)
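    # Assumed illustration: _optdict2string({'optim': 'cdist', 'zipmap': False})
    # returns 'optcdist-zm0'.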

    runtimes_abb = {
        'scikit-learn': 'skl',
        'onnxruntime1': 'ort',
        'onnxruntime2': 'ort2',
        'python': 'pyrt',
        'python_compiled': 'pyrtc',
    }
    runtime = [runtimes_abb[k] for k in runtime]

    # Looping over configuration.
    names = []
    for optimisation in optimisations:
        merged_options = [_merge_options(nconv_options, conv_options)
                          for nconv_options in new_conv_options]

        nck_opts = [_nick_name_options(model, opts)
                    for opts in merged_options]
        try:
            name = _asv_class_name(
                model, scenario, optimisation, extra,
                dofit, conv_options, problem,
                shorten=True)
        except ValueError as e:  # pragma: no cover
            if exc:
                raise e
            warnings.warn(str(e))
            continue
        filename = name.replace(".", "_") + ".py"
        try:
            class_content = _select_pattern_problem(problem, patterns)
        except ValueError as e:
            if exc:
                raise e  # pragma: no cover
            warnings.warn(str(e))
            continue
        full_class_name = _asv_class_name(
            model, scenario, optimisation, extra,
            dofit, conv_options, problem,
            shorten=False)
        class_name = name.replace(
            "bench.", "").replace(".", "_") + "_bench"

        # n_features, N, runtimes
        rep = {
            "['skl', 'pyrtc', 'ort'], # values for runtime": str(runtime),
            "[1, 10, 100, 1000, 10000], # values for N": str(dims),
            "[4, 20], # values for nf": str(n_features),
            "[__max_supported_opset__], # values for opset": str(opsets),
            "['float', 'double'], # values for dtype":
                "['float']" if '-64' not in problem else "['double']",
            "[None], # values for optim": "%r" % nck_opts,
        }
        for k, v in rep.items():
            if k not in class_content:
                raise ValueError("Unable to find '{}'\n{}.".format(  # pragma: no cover
                    k, class_content))
            class_content = class_content.replace(k, v + ',')
        class_content = class_content.split(
            "def _create_model(self):")[0].strip("\n ")
        if "####" in class_content:
            class_content = class_content.replace(
                "'####", "").replace("####'", "")
        if "####" in class_content:
            raise RuntimeError(  # pragma: no cover
                "Substring '####' should not be part of the script for '{}'\n{}".format(
                    model.__name__, class_content))

        # Model setup
        class_content, atts = add_model_import_init(
            class_content, model, optimisation,
            extra, merged_options)
        class_content = class_content.replace(
            "class TemplateBenchmark",
            "class {}".format(class_name))

        # dtype, dofit
        atts.append("chk_method_name = %r" % method_name)
        atts.append("par_scenario = %r" % scenario)
        atts.append("par_problem = %r" % problem)
        atts.append("par_optimisation = %r" % optimisation)
        if not dofit:
            atts.append("par_dofit = False")
        if merged_options is not None and len(merged_options) > 0:
            atts.append("par_convopts = %r" % format_conv_options(
                conv_options, model.__name__))
        atts.append("par_full_test_name = %r" % full_class_name)

        simple_name = _make_simple_name(name)
        atts.append("benchmark_name = %r" % simple_name)
        atts.append("pretty_name = %r" % simple_name)

        if atts:
            class_content = class_content.replace(
                "# additional parameters",
                "\n    ".join(atts))
        if prefix_import != '.':
            class_content = class_content.replace(
                " from .", " from .{}".format(prefix_import))

        # Check compilation
        try:
            compile(class_content, filename, 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError("Unable to compile model '{}'\n{}".format(
                model.__name__, class_content)) from e

        # Verifies missing imports.
        to_import, _ = verify_code(class_content, exc=False)
        try:
            miss = find_missing_sklearn_imports(to_import)
        except ValueError as e:  # pragma: no cover
            raise ValueError(
                "Unable to check import in script\n{}".format(
                    class_content)) from e
        class_content = class_content.replace(
            "# __IMPORTS__", "\n".join(miss))
        verify_code(class_content, exc=True)
        class_content = class_content.replace(
            "par_extra = {", "par_extra = {\n")
        class_content = remove_extra_spaces_and_pep8(
            class_content, aggressive=True)

        # Check compilation again
        try:
            obj = compile(class_content, filename, 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError("Unable to compile model '{}'\n{}".format(
                model.__name__,
                _display_code_lines(class_content))) from e

        # executes to check import
        if execute:
            try:
                exec(obj, globals(), locals())  # pylint: disable=W0122
            except Exception as e:  # pragma: no cover
                raise RuntimeError(
                    "Unable to process class '{}' ('{}') a script due to '{}'\n{}".format(
                        model.__name__, filename, str(e),
                        _display_code_lines(class_content))) from e

        # Saves
        fullname = os.path.join(location, filename)
        names.append(fullname)
        with open(fullname, "w", encoding='utf-8') as f:
            f.write(class_content)

        if location_pyspy is not None:
            # adding configuration for pyspy
            class_name = re.compile(
                'class ([A-Za-z_0-9]+)[(]').findall(class_content)[0]
            fullname_pyspy = os.path.splitext(
                os.path.join(location_pyspy, filename))[0]
            pyfold = os.path.splitext(os.path.split(fullname)[-1])[0]

            dtypes = ['float', 'double'] if '-64' in problem else ['float']
            for dim in dims:
                for nf in n_features:
                    for opset in opsets:
                        for dtype in dtypes:
                            for opt in nck_opts:
                                tmpl = pyspy_template.replace(
                                    '__PATH__', location)
                                tmpl = tmpl.replace(
                                    '__CLASSNAME__', class_name)
                                tmpl = tmpl.replace('__PYFOLD__', pyfold)
                                opt = "" if opt == {} else opt

                                first = True
                                for rt in runtime:
                                    if first:
                                        tmpl += textwrap.dedent("""

                                        def profile0_{rt}(iter, cl, N, nf, opset, dtype, optim):
                                            return setup_profile0(iter, cl, '{rt}', N, nf, opset, dtype, optim)
                                        iter = profile0_{rt}(iter, cl, {dim}, {nf}, {opset}, '{dtype}', {opt})
                                        print(datetime.now(), "iter", iter)

                                        """).format(rt=rt, dim=dim, nf=nf, opset=opset,
                                                    dtype=dtype, opt="%r" % opt)
                                        first = False

                                    tmpl += textwrap.dedent("""

                                    def profile_{rt}(iter, cl, N, nf, opset, dtype, optim):
                                        return setup_profile(iter, cl, '{rt}', N, nf, opset, dtype, optim)
                                    profile_{rt}(iter, cl, {dim}, {nf}, {opset}, '{dtype}', {opt})
                                    print(datetime.now(), "iter", iter)

                                    """).format(rt=rt, dim=dim, nf=nf, opset=opset,
                                                dtype=dtype, opt="%r" % opt)

                                thename = "{n}_{dim}_{nf}_{opset}_{dtype}_{opt}.py".format(
                                    n=fullname_pyspy, dim=dim, nf=nf,
                                    opset=opset, dtype=dtype, opt=_optdict2string(opt))
                                with open(thename, 'w', encoding='utf-8') as f:
                                    f.write(tmpl)
                                names.append(thename)

                                ext = '.bat' if sys.platform.startswith(
                                    'win') else '.sh'
                                script = os.path.splitext(thename)[0] + ext
                                short = os.path.splitext(
                                    os.path.split(thename)[-1])[0]
                                with open(script, 'w', encoding='utf-8') as f:
                                    f.write('py-spy record --native --function --rate=10 -o {n}_fct.svg -- {py} {n}.py\n'.format(
                                        py=sys.executable, n=short))
                                    f.write('py-spy record --native --rate=10 -o {n}_line.svg -- {py} {n}.py\n'.format(
                                        py=sys.executable, n=short))

    return names