Coverage for mlprodict/testing/test_utils/utils_backend_common_compare.py: 70%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Inspired from sklearn-onnx, handles two backends.
4"""
5import numpy
6import onnx
7import pandas
8from .utils_backend_common import (
9 load_data_and_model, extract_options,
10 ExpectedAssertionError, OnnxBackendAssertionError,
11 OnnxRuntimeMissingNewOnnxOperatorException,
12 _compare_expected, _create_column)
def _onnx_as_text(onx, verbose):
    """
    Returns a textual dump of the ONNX model used to enrich error
    messages. It returns an empty string when *verbose* is False.
    It never raises: it is only a diagnostic helper and must not hide
    the error being reported.

    :param onx: onnx model (filename or :epkg:`ModelProto`)
    :param verbose: only dump the model if True
    :return: string
    """
    if not verbose:
        return ""
    try:
        # onnx.load expects a filename or a file object,
        # a ModelProto can be printed directly.
        model = onx if isinstance(onx, onnx.ModelProto) else onnx.load(onx)
    except Exception as e:  # pylint: disable=W0703
        return "\nJSON ONNX\n<unable to load the model due to {}>".format(e)
    return "\nJSON ONNX\n" + str(model)


def _prepare_inputs(data, sess, onx, use_df):
    """
    Converts the test data into a dictionary ``{input_name: value}``
    which can be given to ``sess.run``.

    :param data: test data, a dictionary, a list, a :epkg:`numpy` array
        or a :epkg:`pandas` DataFrame
    :param sess: inference session, its methods ``get_inputs`` and
        ``get_outputs`` are only called when *data* is not a dictionary
    :param onx: onnx model, only used in error messages
    :param use_df: option *DF*: every column of the DataFrame becomes
        one input of shape *(N, 1)*, float64 columns are cast to float32
    :return: dictionary (not necessarily a copy when *data* is a dictionary)
    :raises OnnxBackendAssertionError: data does not match the model inputs
    """
    if use_df:
        inputs = {}
        for col in data.columns:
            values = data[col].values
            if values.dtype == numpy.float64:
                values = values.astype(numpy.float32)
            inputs[col] = values.reshape((values.shape[0], 1))
        return inputs

    if isinstance(data, dict):
        return data

    if not isinstance(data, (list, numpy.ndarray, pandas.DataFrame)):
        raise OnnxBackendAssertionError(  # pragma no cover
            "Dict or list is expected, not {0}".format(type(data)))

    inp = sess.get_inputs()
    outs = sess.get_outputs()
    if len(outs) == 0:
        raise OnnxBackendAssertionError(  # pragma: no cover
            "Wrong number of outputs, onnx='{0}'".format(onx))
    if len(inp) == len(data):
        return {i.name: v for i, v in zip(inp, data)}
    if len(inp) == 1:
        return {inp[0].name: data}

    if isinstance(data, numpy.ndarray):
        shape = sum(i.shape[1] if len(i.shape) == 2 else i.shape[0]
                    for i in inp)
        if shape == data.shape[1]:
            return {n.name: data[:, i] for i, n in enumerate(inp)}
        raise OnnxBackendAssertionError(  # pragma: no cover
            "Wrong number of inputs onnx {0} != "
            "original shape {1}, onnx='{2}'"
            .format(len(inp), data.shape, onx))

    # Remaining cases: a list of rows or a DataFrame, the columns
    # are split between the inputs following their second dimension.
    try:
        array_input = numpy.array(data)
    except Exception as e:  # pragma no cover
        raise OnnxBackendAssertionError(
            "Wrong number of inputs onnx {0} != "
            "original {1}, onnx='{2}'"
            .format(len(inp), len(data), onx)) from e
    shape = sum(i.shape[1] for i in inp)
    if shape != array_input.shape[1]:
        if isinstance(data, list):
            raise OnnxBackendAssertionError(  # pragma no cover
                "Wrong number of inputs onnx {0} != "
                "original shape {1}, onnx='{2}'*"
                .format(len(inp), array_input.shape, onx))
        raise OnnxBackendAssertionError(  # pragma no cover
            "Wrong number of inputs onnx {0}={1} columns != "
            "original shape {2}, onnx='{3}'*"
            .format(len(inp), shape, array_input.shape, onx))

    inputs = {}
    begin = 0
    for node in inp:
        end = begin + node.shape[1]
        if isinstance(data, list):
            column = [row[begin:end] for row in data]
        else:
            column = data.iloc[:, begin:end]
        inputs[node.name] = _create_column(column, node.type)
        begin = end
    return inputs


def compare_runtime_session(  # pylint: disable=R0912
        cls_session, test, decimal=5, options=None,
        verbose=False, context=None, comparable_outputs=None,
        intermediate_steps=False, classes=None,
        disable_optimisation=False):
    """
    The function compares the expected output (computed with
    the model before being converted to ONNX) and the ONNX output
    produced with module :epkg:`onnxruntime` or :epkg:`mlprodict`.

    :param cls_session: inference session class or factory
        (like @see cl OnnxInference), it is called with
        ``cls_session(onx, runtime_options=...)``
    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options; the dictionary is copied and
        never modified; if None and the model is a filename, the options
        are extracted from the filename with ``extract_options``
    :param context: specifies custom operators
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable optimisation the runtime may do
    :return: tuple (output, lambda function to run the predictions)
    :raises TypeError: *options* is not a dictionary or *cls_session*
        does not have the expected signature
    :raises ExpectedAssertionError: the failure was expected
        (option *CannotLoad*, model filename containing ``-Fail``)
    :raises OnnxRuntimeMissingNewOnnxOperatorException: the runtime
        does not implement an operator
    :raises OnnxBackendAssertionError: the model cannot be loaded or run,
        or the outputs differ from the expected ones
    """
    lambda_onnx = None
    if context is None:
        context = {}
    load = load_data_and_model(test, **context)
    if verbose:  # pragma no cover
        print("[compare_runtime] test '{}' loaded".format(test['onnx']))

    onx = test['onnx']

    if options is None:
        options = extract_options(onx) if isinstance(onx, str) else {}
    elif isinstance(options, dict):
        # Copy: keys such as 'DF' or 'SklCol' are popped below and
        # must not leak back into the caller's dictionary.
        options = dict(options)
    else:
        raise TypeError(  # pragma no cover
            "options must be a dictionary.")

    if verbose:  # pragma no cover
        print("[compare_runtime] InferenceSession('{}')".format(onx))

    runtime_options = dict(disable_optimisation=disable_optimisation)
    try:
        sess = cls_session(onx, runtime_options=runtime_options)
    except TypeError as et:  # pragma: no cover
        raise TypeError(
            "Wrong signature for '{}' ({}).".format(
                getattr(cls_session, '__name__', cls_session), et)) from et
    except ExpectedAssertionError:  # pragma no cover
        raise
    except Exception as e:  # pylint: disable=W0703
        if "CannotLoad" in options:  # pragma no cover
            raise ExpectedAssertionError(
                "Unable to load onnx '{0}' due to\n{1}".format(onx, e)) from e
        smodel = _onnx_as_text(onx, verbose)
        if ("NOT_IMPLEMENTED : Could not find an implementation "
                "for the node" in str(e)):
            # onnxruntime does not implement a specific node yet.
            raise OnnxRuntimeMissingNewOnnxOperatorException(
                "{3} does not implement a new operator "
                "'{0}'\n{1}\nONNX\n{2}".format(
                    onx, e, smodel, cls_session)) from e
        if "NOT_IMPLEMENTED : Failed to find kernel" in str(e):
            # onnxruntime does not implement a specific node yet
            # in the kernel included in onnxruntime.
            raise OnnxBackendAssertionError(
                "{3} misses a kernel for operator "
                "'{0}'\n{1}\nONNX\n{2}".format(
                    onx, e, smodel, cls_session)) from e
        raise OnnxBackendAssertionError(
            "Unable to load onnx '{0}'\nONNX\n{1}\n{2}".format(
                onx, smodel, e)) from e

    data = load["data"]
    inputs = _prepare_inputs(data, sess, onx, options.pop('DF', False))
    # A new dictionary is built so that the loaded data is left untouched.
    inputs = {k: (numpy.array(v) if isinstance(v, list) else v)
              for k, v in inputs.items()}

    options.pop('SklCol', False)  # unused here but in dump_data_and_model

    if verbose:  # pragma no cover
        print("[compare_runtime] type(inputs)={} len={} names={}".format(
            type(data), len(inputs), list(sorted(inputs))))
        run_options = {'verbose': 3 if intermediate_steps else 2,
                       'fLOG': print}
    else:
        run_options = {}

    try:
        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            InvalidArgument as OrtInvalidArgument)
    except ImportError:  # pragma: no cover
        # onnxruntime is optional when cls_session is a python runtime,
        # RuntimeError is already caught below.
        OrtInvalidArgument = RuntimeError

    try:
        try:
            output = sess.run(None, inputs, **run_options)
        except TypeError:  # pragma no cover
            if not run_options:
                raise
            # The runtime does not accept verbose/fLOG, retry without them.
            output = sess.run(None, inputs)
        lambda_onnx = lambda: sess.run(None, inputs)  # noqa
        if verbose:  # pragma no cover
            import pprint
            pprint.pprint(output)
    except ExpectedAssertionError:  # pragma no cover
        raise
    except (RuntimeError, OrtInvalidArgument) as e:  # pragma no cover
        if intermediate_steps:
            try:
                sess.run(None, inputs, verbose=3, fLOG=print)
            except Exception:  # pylint: disable=W0703
                # The run is only replayed to display intermediate
                # results, the original error is reported below.
                pass
        if isinstance(onx, str) and "-Fail" in onx:
            raise ExpectedAssertionError(
                "{1} cannot compute the prediction for '{0}'".
                format(onx, cls_session)) from e
        smodel = _onnx_as_text(onx, verbose)
        import pprint
        raise OnnxBackendAssertionError(
            "{4} cannot compute the predictions"
            " for '{0}' due to {1}{2}\n{3}"
            .format(onx, e, smodel, pprint.pformat(inputs),
                    cls_session)) from e
    except Exception as e:  # pragma no cover
        raise OnnxBackendAssertionError(
            "Unable to run onnx '{0}' due to {1}".format(onx, e)) from e
    if verbose:  # pragma no cover
        print("[compare_runtime] done type={}".format(type(output)))

    output0 = output.copy()

    if comparable_outputs:
        cmp_exp = [load["expected"][o] for o in comparable_outputs]
        cmp_out = [output[o] for o in comparable_outputs]
    else:
        cmp_exp = load["expected"]
        cmp_out = output

    try:
        _compare_expected(cmp_exp, cmp_out, sess, onx,
                          decimal=decimal, verbose=verbose,
                          classes=classes, **options)
    except ExpectedAssertionError:  # pragma no cover
        raise
    except Exception as e:  # pragma no cover
        smodel = _onnx_as_text(onx, verbose)
        raise OnnxBackendAssertionError(
            "Model '{}' has discrepencies with cls='{}'.\n{}: {}{}".format(
                onx, sess.__class__.__name__, type(e), e, smodel)) from e

    return output0, lambda_onnx