Coverage for mlprodict/testing/script_testing.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Utilities to test scripts from :epkg:`scikit-learn` documentation.
4"""
5import os
6from io import StringIO
7from contextlib import redirect_stdout, redirect_stderr
8import pprint
9import numpy
10from sklearn.base import BaseEstimator
11from .verify_code import verify_code
class MissingVariableError(RuntimeError):
    """
    Exception raised by the script-testing helpers when a required
    variable (a fitted model or the data used to fit it) cannot be
    found among the variables available after running a script.
    """
21def _clean_script(content):
22 """
23 Comments out all lines containing ``.show()``.
24 """
25 new_lines = []
26 for line in content.split('\n'):
27 if '.show()' in line or 'sys.exit' in line:
28 new_lines.append("# " + line)
29 else:
30 new_lines.append(line)
31 return "\n".join(new_lines)
34def _enumerate_fit_info(fits):
35 """
36 Extracts the name of the fitted models and the data
37 used to train it.
38 """
39 for fit in fits:
40 chs = fit['children']
41 if len(chs) < 2:
42 # unable to extract the needed information
43 continue # pragma: no cover
44 model = chs[0]['str']
45 if model.endswith('.fit'):
46 model = model[:-4]
47 args = [ch['str'] for ch in chs[1:]]
48 yield model, args
def _try_onnx(loc, model_name, args_name, **options):
    """
    Tries to convert a fitted model into :epkg:`ONNX`.

    @param      loc             available variables (name -> value)
    @param      model_name      model name among these variables
    @param      args_name       names of the arguments given to *fit*,
                                only the first one (the training data) is used
    @param      options         additional options for the conversion, all of
                                them are forwarded to *to_onnx*; ``dtype``
                                (default ``numpy.float32``) is also used to
                                cast the training data
    @return                     tuple ``(onx, args)``: the ONNX model and a
                                dictionary with keys ``onx``, ``model``, ``X``
    @raises     MissingVariableError if the model or the data cannot be found
    """
    if model_name not in loc:
        raise MissingVariableError(  # pragma: no cover
            "Unable to find model '{}' in {}".format(
                model_name, ", ".join(sorted(loc))))
    if not args_name:
        # an empty list would otherwise fail with an IndexError
        raise MissingVariableError(  # pragma: no cover
            "Unable to find data for model '{}', no argument given, "
            "available variables: {}".format(
                model_name, ", ".join(sorted(loc))))
    if args_name[0] not in loc:
        raise MissingVariableError(  # pragma: no cover
            "Unable to find data '{}' in {}".format(
                args_name[0], ", ".join(sorted(loc))))

    # imported only once the inputs are validated, a missing variable
    # must not be hidden behind an import failure
    from ..onnx_conv import to_onnx

    model = loc[model_name]
    X = loc[args_name[0]]
    dtype = options.get('dtype', numpy.float32)
    Xt = X.astype(dtype)
    onx = to_onnx(model, Xt, **options)
    args = dict(onx=onx, model=model, X=Xt)
    return onx, args
def _read_script(file_or_name):
    """
    Returns the script content and the filename used to compile it.

    @param      file_or_name    path to an existing file (without any newline)
                                or the script itself
    @return                     tuple ``(content, filename)``, *filename* is
                                ``"<string>"`` when *file_or_name* is code
    """
    if '\n' not in file_or_name and os.path.exists(file_or_name):
        with open(file_or_name, 'r', encoding='utf-8') as f:
            return f.read(), file_or_name
    return file_or_name, "<string>"  # pragma: no cover


def _run_script(content, filename, existing_loc=None):
    """
    Compiles and executes a script, capturing its standard output and error.

    @param      content         script content
    @param      filename        name given to :func:`compile` (shown in tracebacks)
    @param      existing_loc    variables defined before the script runs
    @return                     tuple ``(loc, glo, stdout, stderr)``: *loc*
                                holds every name the script created or rebound
                                (plus *existing_loc*), *glo* is the namespace
                                the script started from, without ``__builtins__``
    """
    obj = compile(content, filename, 'exec')
    module_ns = globals()
    start_ns = module_ns.copy()
    if existing_loc is not None:
        start_ns.update(existing_loc)  # pragma: no cover

    # One dictionary serves as both globals and locals: with two distinct
    # mappings, exec runs the code as if it were a class body and functions
    # defined by the script could not see the script's top-level names.
    namespace = dict(start_ns)
    out = StringIO()
    err = StringIO()
    with redirect_stdout(out), redirect_stderr(err):
        exec(obj, namespace)  # pylint: disable=W0122

    # names created or rebound by the script, or coming from existing_loc
    loc = {k: v for k, v in namespace.items()
           if k != '__builtins__' and
           (k not in module_ns or module_ns[k] is not v)}
    glo = {k: v for k, v in start_ns.items() if k != '__builtins__'}
    return loc, glo, out.getvalue(), err.getvalue()


def verify_script(file_or_name, try_onnx=True, existing_loc=None,
                  **options):
    """
    Checks that models fitted in an example from :epkg:`scikit-learn`
    documentation can be converted into :epkg:`ONNX`.

    @param      file_or_name    file or string
    @param      try_onnx        try the onnx conversion
    @param      existing_loc    existing local variables
    @param      options         conversion options
    @return                     dictionary with keys ``locals`` (estimators
                                and arrays defined by the script, plus one
                                ``<model>_onnx`` entry per converted model),
                                ``globals``, ``stdout``, ``stderr`` and
                                ``onx_info`` (one dictionary per conversion)
    @raises     MissingVariableError if *try_onnx* is True and no trained
                model is detected or a model/data variable cannot be found
    """
    content, filename = _read_script(file_or_name)

    # neutralizes .show() and sys.exit calls
    content = _clean_script(content)

    # look for fit expressions to know which models to convert
    _, node = verify_code(content, exc=False)
    fits = node._fits
    models_args = list(_enumerate_fit_info(fits))

    # execution
    loc, glo_fil, stdout, stderr = _run_script(
        content, filename, existing_loc)

    # keeps only fitted models and arrays
    cls = (BaseEstimator, numpy.ndarray)
    loc_fil = {k: v for k, v in loc.items() if isinstance(v, cls)}
    onx_info = []

    # onnx
    if try_onnx:
        if not models_args:
            raise MissingVariableError(  # pragma: no cover
                "No detected trained model in '{}'\n{}\n--LOCALS--\n{}".format(
                    filename, content, pprint.pformat(loc_fil)))
        for model_args in models_args:
            try:
                onx, args = _try_onnx(loc_fil, *model_args, **options)
            except MissingVariableError as e:  # pragma: no cover
                raise MissingVariableError("Unable to find variable in '{}'\n{}".format(
                    filename, pprint.pformat(fits))) from e
            loc_fil[model_args[0] + "_onnx"] = onx
            onx_info.append(args)

    # final results
    return dict(locals=loc_fil, globals=glo_fil,
                stdout=stdout,
                stderr=stderr,
                onx_info=onx_info)