Coverage for mlprodict/onnxrt/validate/validate_python.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Helpers to validate python code.
4"""
5import pickle
6import pprint
7import numpy
8from numpy.linalg import det as npy_det # pylint: disable=E0611
9from scipy.spatial.distance import cdist # pylint: disable=E0611
10from scipy.special import expit, erf # pylint: disable=E0611
11from scipy.linalg import solve # pylint: disable=E0611
12from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
13from ...tools.code_helper import make_callable
16def _make_callable(fct, obj, code, gl, debug):
17 """
18 Same function as @see fn make_callable but deals with
19 function which an undefined number of arguments.
20 """
21 def pyrt_Concat_(*inputs, axis=0):
22 return numpy.concatenate(inputs, axis=axis)
24 if fct == "pyrt_Concat":
25 return pyrt_Concat_
26 return make_callable(fct, obj, code, gl, debug)
def _is_nan(value):
    """
    Tells if *value* is a floating or complex scalar holding NaN.
    Numpy scalars such as ``numpy.float32`` are not instances of
    :class:`float`, so they are checked through ``numpy.floating``.
    """
    return (isinstance(value, (float, complex, numpy.floating,
                               numpy.complexfloating)) and
            bool(numpy.isnan(value)))


def _element_discrepancy(a, b):
    """
    Returns the absolute discrepancy between two scalars
    already known to be different (``a == b`` is false).

    * both NaN: no discrepancy (0),
    * only one NaN, or a string/bytes value: infinite discrepancy,
    * integers or booleans: computed with python integers, which
      avoids the wrap-around of unsigned types and the error numpy
      raises when subtracting booleans,
    * otherwise ``abs(a - b)``.
    """
    a_nan = _is_nan(a)
    b_nan = _is_nan(b)
    if a_nan and b_nan:
        return 0
    if a_nan or b_nan:
        return numpy.inf
    if isinstance(a, (str, bytes)) or isinstance(b, (str, bytes)):
        return numpy.inf
    ints = (int, numpy.integer, numpy.bool_)
    if isinstance(a, ints) and isinstance(b, ints):
        return abs(int(a) - int(b))
    return abs(a - b)


def _max_discrepancy(expected, got):
    """
    Returns the maximum element-wise discrepancy between two
    arrays of the same shape (see @see fn _element_discrepancy).
    """
    diff = 0
    for a, b in zip(expected.ravel(), got.ravel()):
        if a == b:
            continue
        diff = max(diff, _element_discrepancy(a, b))
    return diff


def validate_python_inference(oinf, inputs, tolerance=0.):
    """
    Validates the code produced by method :meth:`to_python
    <mlprodict.onnxrt.onnx_inference_exports.OnnxInferenceExport.to_python>`.
    The function compiles and executes the code
    given as an argument and compares the results to
    what *oinf* returns. This function is mostly used for
    unit testing purpose but it is not robust enough
    to handle all cases.

    @param      oinf        @see cl OnnxInference
    @param      inputs      inputs as dictionary, it is not modified
    @param      tolerance   discrepancies must be below or equal to
                            this threshold, ``'random'`` disables the
                            comparison of values (shapes are still checked)

    The function fails if the expected outputs are not the same.
    Two NaN at the same position are considered equal, a NaN facing
    any other value, or two different strings, count as an infinite
    discrepancy.

    Exceptions raised: *SyntaxError* if the generated code does not
    compile, *RuntimeError* if it fails to run or if output names or
    counts differ, *TypeError* if a result is not a dictionary,
    *ValueError* if shapes or values differ, *NotImplementedError*
    if an expected output is not a numpy array.
    """
    from ..ops_cpu.op_argmax import _argmax
    from ..ops_cpu.op_argmin import _argmin
    from ..ops_cpu.op_celu import _vcelu1
    from ..ops_cpu.op_leaky_relu import _leaky_relu

    cd = oinf.to_python()
    code = cd['onnx_pyrt_main.py']

    exp = oinf.run(inputs)
    if not isinstance(exp, dict):
        raise TypeError(  # pragma: no cover
            "exp is not a dictionary but '{}'.".format(type(exp)))
    if len(exp) == 0:
        raise ValueError(  # pragma: no cover
            "No result to compare.")
    inps = ['{0}={0}'.format(k) for k in sorted(inputs)]
    code += "\n".join(['', '', 'opi = OnnxPythonInference()',
                       'res = opi.run(%s)' % ', '.join(inps)])

    try:
        cp = compile(code, "<string>", mode='exec')
    except SyntaxError as e:
        raise SyntaxError(
            "Error %s in code\n%s" % (str(e), code)) from e

    gl = {'numpy': numpy, 'pickle': pickle,
          'expit': expit, 'erf': erf, 'cdist': cdist,
          '_argmax': _argmax, '_argmin': _argmin,
          '_vcelu1': _vcelu1, 'solve': solve,
          'fft': numpy.fft.fft, 'rfft': numpy.fft.rfft,
          'fft2': numpy.fft.fft2,
          'npy_det': npy_det, 'ndarray': numpy.ndarray,
          '_leaky_relu': _leaky_relu,
          'nan': numpy.nan,
          'TENSOR_TYPE_TO_NP_TYPE': TENSOR_TYPE_TO_NP_TYPE}

    # Names starting with 'pyrt_' used by the generated code are rebuilt
    # from the non-string constants of the compiled module and exposed
    # as globals. The string form of each constant is computed once.
    pyrt_fcts = [_ for _ in cp.co_names if _.startswith("pyrt_")]
    code_objects = [(obj, str(obj)) for obj in cp.co_consts
                    if not isinstance(obj, str)]
    fcts_local = {}
    for fct in pyrt_fcts:
        for obj, sobj in code_objects:
            if '<string>' in sobj and fct in sobj:
                fcts_local[fct] = _make_callable(fct, obj, code, gl, False)
    gl.update(fcts_local)

    # exec stores every name the generated code defines ('opi', 'res',
    # classes...) into the locals mapping: work on a copy so that the
    # caller's dictionary is left untouched.
    loc = dict(inputs)
    try:
        exec(cp, gl, loc)  # pylint: disable=W0122
    except (NameError, TypeError, SyntaxError,  # pragma: no cover
            IndexError, ValueError) as e:
        raise RuntimeError(
            "Unable to execute code.\n{}\n-----\n{}".format(e, code)) from e

    got = loc['res']
    keys = sorted(exp)
    if isinstance(got, numpy.ndarray) and len(keys) == 1:
        got = {keys[0]: got}

    if not isinstance(got, dict):
        raise TypeError(  # pragma: no cover
            "got is not a dictionary but '{}'\n--\n{}\n---\n{}\n--code--\n{}".format(
                type(got), dir(got), pprint.pformat(str(loc)), code))
    if len(got) != len(exp):
        raise RuntimeError(  # pragma: no cover
            "Different number of results.\nexp: {}\ngot: {}\n--code--\n{}".format(
                ", ".join(sorted(exp)), ", ".join(sorted(got)), code))

    if keys != list(sorted(got)):
        raise RuntimeError(  # pragma: no cover
            "Different result names.\nexp: {}\ngot: {}\n--code--\n{}".format(
                ", ".join(sorted(exp)), ", ".join(sorted(got)), code))

    for k in keys:
        e = exp[k]
        g = got[k]
        if not isinstance(e, numpy.ndarray):
            raise NotImplementedError(  # pragma: no cover
                "Unable to compare values of type '{}'.".format(type(e)))
        if e.shape != g.shape:
            raise ValueError(  # pragma: no cover
                "Shapes are different {} != {}\n---\n{}\n{}.".format(
                    e.shape, g.shape, e, g))
        if tolerance == 'random':
            # Values are random, only shapes can be compared.
            continue
        diff = _max_discrepancy(e, g)
        if diff > tolerance:
            raise ValueError(  # pragma: no cover
                "Values are different (max diff={}>{})\n--EXP--\n{}\n--GOT--"
                "\n{}\n--\n{}".format(diff, tolerance, e, g, code))