Coverage for mlprodict/testing/test_utils/utils_backend_common.py: 79%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Inspired from :epkg:`sklearn-onnx`, handles two backends.
4"""
5import os
6import pickle
7import numpy
8from numpy.testing import assert_array_almost_equal, assert_array_equal
9from scipy.sparse.csr import csr_matrix
10import pandas
11from ...onnxrt.ops_cpu.op_zipmap import ArrayZipMapDictionary
class ExpectedAssertionError(Exception):
    """
    Signals a known, expected discrepancy between expected values
    and runtime outputs (a bug to be fixed later). It deliberately
    does not derive from :class:`AssertionError`.
    """
class OnnxBackendAssertionError(AssertionError):
    """
    Assertion raised when the outputs produced by a runtime
    do not match the expected values or cannot be compared.
    """
class OnnxBackendMissingNewOnnxOperatorException(OnnxBackendAssertionError):
    """
    Signals that :epkg:`onnxruntime` or :epkg:`mlprodict` lacks
    an implementation for an operator introduced in the most
    recent version of onnx.
    """
class OnnxRuntimeMissingNewOnnxOperatorException(OnnxBackendAssertionError):
    """
    Signals that a newly added operator could not be found.
    """
def evaluate_condition(backend, condition):
    """
    Evaluates a condition such as
    ``StrictVersion(onnxruntime.__version__) <= StrictVersion('0.1.3')``.

    :param backend: backend name, only ``'onnxruntime'`` is supported
    :param condition: Python expression evaluated with :func:`eval`,
        it can use the module globals, module ``onnxruntime`` and,
        when :epkg:`distutils` can be imported, ``StrictVersion``
    :return: result of the evaluation
    :raises NotImplementedError: for any other backend

    .. warning::
        The condition goes through :func:`eval`, it must only come
        from trusted test configurations, never from external input.
    """
    if backend == "onnxruntime":  # pragma: no cover
        import onnxruntime  # pylint: disable=W0611,C0415
        # The documented example uses onnxruntime and StrictVersion,
        # neither is a global name of this module: both are added
        # explicitly to the namespace used by eval.
        namespace = dict(globals())
        namespace['onnxruntime'] = onnxruntime
        try:
            from distutils.version import StrictVersion  # pylint: disable=C0415,W0402
        except ImportError:
            # distutils is not available on recent Python versions,
            # conditions using StrictVersion cannot be evaluated then.
            pass
        else:
            namespace['StrictVersion'] = StrictVersion
        return eval(condition, namespace)  # pylint: disable=W0123
    raise NotImplementedError(  # pragma no cover
        "Not implemented for backend '{0}' and "
        "condition '{1}'.".format(backend, condition))
def is_backend_enabled(backend):
    """
    Tells whether *backend* can be used by the unit tests.

    :param backend: ``'onnxruntime'`` or ``'python'``
    :return: always True for ``'python'``; for ``'onnxruntime'``,
        True only if the package can be imported
    :raises NotImplementedError: for any other backend name
    """
    if backend == "python":
        return True
    if backend != "onnxruntime":
        raise NotImplementedError(  # pragma no cover
            "Not implemented for backend '{0}'".format(backend))
    try:
        import onnxruntime  # pylint: disable=W0611,C0415
    except ImportError:  # pragma no cover
        return False
    return True
def load_data_and_model(items_as_dict, **context):  # pylint: disable=W0613
    """
    Loads every file in a dictionary ``{key: filename}``.
    A string value with extension ``.pkl`` is unpickled; any other
    value (a string with another extension or an object which
    was already loaded) is kept as it is.

    :param items_as_dict: dictionary ``{key: filename or object}``
    :param context: unused, kept for backward compatibility
    :return: dictionary ``{key: loaded object}``; a pickled file
        whose name contains ``'.model.'`` and whose loading fails
        with :class:`ImportError` is skipped, its key is then
        missing from the result
    :raises ImportError: if any other pickled file cannot be loaded

    .. warning::
        :mod:`pickle` can execute arbitrary code while loading,
        only trusted files must be given to this function.
    """
    res = {}
    for key, value in items_as_dict.items():
        if isinstance(value, str) and os.path.splitext(value)[-1] == ".pkl":
            with open(value, "rb") as f:  # pragma: no cover
                try:
                    loaded = pickle.load(f)
                except ImportError as e:
                    if '.model.' in value:
                        # ImportError usually means a module referenced
                        # by the pickle is missing: such models are skipped.
                        continue
                    raise ImportError(
                        "Unable to load '{0}' due to {1}".format(
                            value, e)) from e
            res[key] = loaded
        else:
            res[key] = value
    return res
def extract_options(name):
    """
    Extracts comparison options from a filename.
    As example, ``Binarizer-SkipDim1`` means option *SkipDim1*
    is enabled (with this option, shapes ``(1, 2)`` and ``(2,)``
    are considered equal).
    Available options: see :func:`dump_data_and_model`.

    :param name: file name or path, the directory and everything
        after the first dot are ignored
    :return: dictionary ``{option: True}``, empty if there is no option
    :raises NameError: if one option is unknown
    """
    known = {"SkipDim1", "OneOff", "NoProb", "NoProbOpp",
             "Dec4", "Dec3", "Dec2", 'Svm',
             'Out0', 'Reshape', 'SklCol', 'DF', 'OneOffArray'}
    base = name.replace("\\", "/").split("/")[-1].split('.')[0]
    opts = base.split('-')[1:]
    res = {}
    for opt in opts:
        if opt not in known:
            # Reports the faulty option, not only the whole list.
            raise NameError(  # pragma no cover
                "Unable to parse option '{}' in {}".format(opt, opts))
        res[opt] = True
    return res
def compare_outputs(expected, output, verbose=False, **kwargs):
    """
    Compares expected values and output.

    :param expected: expected values, only :class:`numpy.ndarray`
        is supported
    :param output: values produced by the runtime
    :param verbose: forwarded to the :epkg:`numpy` assertion functions
        and, on failure, adds both arrays to the error message
    :param kwargs: options ``SkipDim1``, ``NoProb``, ``NoProbOpp``,
        ``Dec4``, ``Dec3``, ``Dec2``, ``Disc``, ``Mism`` are consumed
        here, remaining keywords (such as ``decimal``) are forwarded to
        :func:`numpy.testing.assert_array_almost_equal`
    :return: None if no error, an exception instance otherwise
        (returned, not raised): :class:`ExpectedAssertionError` when
        the failure is expected (``Disc``, ``Mism``),
        :class:`OnnxBackendAssertionError` otherwise
    :raises NotImplementedError: if *NoProb* or *NoProbOpp* is enabled
        and shapes cannot be reconciled
    """
    SkipDim1 = kwargs.pop("SkipDim1", False)
    NoProb = kwargs.pop("NoProb", False)
    NoProbOpp = kwargs.pop("NoProbOpp", False)
    Dec4 = kwargs.pop("Dec4", False)
    Dec3 = kwargs.pop("Dec3", False)
    Dec2 = kwargs.pop("Dec2", False)
    Disc = kwargs.pop("Disc", False)
    Mism = kwargs.pop("Mism", False)

    # Options DecN lower the precision. When decimal is missing,
    # 6 is the default value used by assert_array_almost_equal.
    for enabled, limit in ((Dec4, 4), (Dec3, 3), (Dec2, 2)):
        if enabled:
            kwargs["decimal"] = min(kwargs.get("decimal", 6), limit)

    if not (isinstance(expected, numpy.ndarray) and
            isinstance(output, numpy.ndarray)):
        return OnnxBackendAssertionError(  # pragma: no cover
            "Unexpected types {0} != {1}".format(
                type(expected), type(output)))

    if SkipDim1:
        # Arrays like (2, 1, 2, 3) becomes (2, 2, 3)
        # as one dimension is useless. output takes the squeezed
        # shape of expected, both must have the same size.
        squeezed = tuple(d for d in expected.shape if d > 1)
        expected = expected.reshape(squeezed)
        output = output.reshape(squeezed)
    if NoProb or NoProbOpp:
        # One vector is (N,) with scores, negative for class 0
        # positive for class 1
        # The other vector is (N, 2) score in two columns.
        if (len(output.shape) == 2 and output.shape[1] == 2 and
                len(expected.shape) == 1):
            output = output[:, 1]
            if NoProbOpp:
                output = -output
        elif len(output.shape) == 1 and len(expected.shape) == 1:
            pass
        elif (len(expected.shape) == 1 and len(output.shape) == 2 and
                expected.shape[0] == output.shape[0] and
                output.shape[1] == 1):
            output = output[:, 0]
            if NoProbOpp:
                output = -output
        elif expected.shape != output.shape:
            raise NotImplementedError("Shape mismatch: {0} != {1}".format(  # pragma no cover
                expected.shape, output.shape))
    if (len(expected.shape) == 1 and len(output.shape) == 2 and
            output.shape[1] == 1):
        output = output.ravel()
    if (len(output.shape) == 3 and output.shape[0] == 1 and
            len(expected.shape) == 2):
        output = output.reshape(output.shape[1:])

    if expected.dtype.kind in ('U', 'S'):
        # Strings of any length are compared exactly,
        # assert_array_almost_equal cannot handle them.
        try:
            assert_array_equal(expected, output, verbose=verbose)
        except Exception as e:  # pylint: disable=W0703
            if Disc:  # pragma no cover
                # Bug to be fixed later.
                return ExpectedAssertionError(str(e))
            return OnnxBackendAssertionError(str(e))  # pragma no cover
        return None

    try:
        assert_array_almost_equal(expected, output, verbose=verbose, **kwargs)
    except (RuntimeError, AssertionError) as e:  # pragma no cover
        longer = "\n--EXPECTED--\n{0}\n--OUTPUT--\n{1}".format(
            expected, output) if verbose else ""
        expected_ = numpy.asarray(expected).ravel()
        output_ = numpy.asarray(output).ravel()
        if len(expected_) == len(output_):
            if numpy.issubdtype(expected_.dtype, numpy.floating):
                diff = numpy.abs(expected_ - output_).max()
            else:
                diff = max((1 if ci != cj else 0)
                           for ci, cj in zip(expected_, output_))
            if diff == 0:
                return None
        elif Mism:
            return ExpectedAssertionError(
                "dimension mismatch={0}, {1}\n{2}{3}".format(
                    expected.shape, output.shape, e, longer))
        else:
            return OnnxBackendAssertionError(
                "dimension mismatch={0}, {1}\n{2}{3}".format(
                    expected.shape, output.shape, e, longer))
        if Disc:
            # Bug to be fixed later.
            return ExpectedAssertionError(
                "max-diff={0}\n--expected--output--\n{1}{2}".format(
                    diff, e, longer))
        return OnnxBackendAssertionError(
            "max-diff={0}\n--expected--output--\n{1}{2}".format(
                diff, e, longer))
    return None
232def _post_process_output(res):
233 """
234 Applies post processings before running the comparison
235 such as changing type from list to arrays.
236 """
237 if isinstance(res, list):
238 if len(res) == 0:
239 return res
240 if len(res) == 1:
241 return _post_process_output(res[0])
242 if isinstance(res[0], numpy.ndarray):
243 return numpy.array(res)
244 if isinstance(res[0], dict):
245 return pandas.DataFrame(res).values
246 ls = [len(r) for r in res]
247 mi = min(ls)
248 if mi != max(ls):
249 raise NotImplementedError( # pragma no cover
250 "Unable to postprocess various number of "
251 "outputs in [{0}, {1}]"
252 .format(min(ls), max(ls)))
253 if mi > 1:
254 output = []
255 for i in range(mi):
256 output.append(_post_process_output([r[i] for r in res]))
257 return output
258 if isinstance(res[0], list):
259 # list of lists
260 if isinstance(res[0][0], list):
261 return numpy.array(res)
262 if len(res[0]) == 1 and isinstance(res[0][0], dict):
263 return _post_process_output([r[0] for r in res])
264 if len(res) == 1:
265 return res
266 if len(res[0]) != 1:
267 raise NotImplementedError( # pragma no cover
268 "Not conversion implemented for {0}".format(res))
269 st = [r[0] for r in res]
270 return numpy.vstack(st)
271 return res
272 return res
275def _create_column(values, dtype):
276 "Creates a column from values with dtype"
277 if str(dtype) == "tensor(int64)":
278 return numpy.array(values, dtype=numpy.int64)
279 if str(dtype) == "tensor(float)":
280 return numpy.array(values, dtype=numpy.float32)
281 if str(dtype) in ("tensor(double)", "tensor(float64)"):
282 return numpy.array(values, dtype=numpy.float64)
283 if str(dtype) in ("tensor(string)", "tensor(str)"):
284 return numpy.array(values, dtype=numpy.str_)
285 raise OnnxBackendAssertionError(
286 "Unable to create one column from dtype '{0}'".format(dtype))
def _compare_expected(expected, output, sess, onnx_model,
                      decimal=5, verbose=False, classes=None,
                      **kwargs):
    """
    Compares the expected output against the runtime outputs.
    This is specific to :epkg:`onnxruntime` or :epkg:`mlprodict`.

    :param expected: expected output, a list, a dictionary,
        a :class:`numpy.ndarray` or a :class:`scipy.sparse.csr_matrix`
    :param output: output produced by the runtime
    :param sess: runtime session, only forwarded to recursive calls
    :param onnx_model: model identifier, only used in error messages
    :param decimal: precision of the comparison, also applied to
        every element when *expected* is a list
    :param verbose: adds more details to error messages
    :param classes: if not None and *expected* holds strings,
        *output* is converted with ``classes[value]`` for every value
    :param kwargs: options forwarded to :func:`compare_outputs`,
        ``Out0`` restricts lists to their first element,
        ``Reshape`` concatenates a list of outputs into one matrix
        with ``len(expected)`` rows
    :raises OnnxBackendAssertionError: if outputs do not match
        or if nothing was compared
    :raises ExpectedAssertionError: if the comparison fails
        but the failure is marked as expected (``Disc``, ``Mism``)
    """
    tested = 0
    if isinstance(expected, list):
        if not isinstance(output, list):
            raise OnnxBackendAssertionError(  # pragma no cover
                "Type mismatch for '{0}', output type is {1}".format(
                    onnx_model, type(output)))
        if 'Out0' in kwargs:
            expected = expected[:1]
            output = output[:1]
            del kwargs['Out0']
        if 'Reshape' in kwargs:
            del kwargs['Reshape']
            output = numpy.hstack(output).ravel()
            output = output.reshape(
                (len(expected), len(output.ravel()) // len(expected)))
        if len(expected) != len(output):
            raise OnnxBackendAssertionError(  # pragma no cover
                "Unexpected number of outputs '{0}', expected={1}, got={2}"
                .format(onnx_model, len(expected), len(output)))
        for exp, out in zip(expected, output):
            # The caller's precision applies to every nested output.
            _compare_expected(exp, out, sess, onnx_model, decimal=decimal,
                              verbose=verbose, classes=classes, **kwargs)
            tested += 1
    elif isinstance(expected, dict):
        if not isinstance(output, dict):
            raise OnnxBackendAssertionError(  # pragma no cover
                "Type mismatch for '{0}'".format(onnx_model))
        for k, v in output.items():
            if k not in expected:
                continue
            msg = compare_outputs(
                expected[k], v, decimal=decimal, verbose=verbose, **kwargs)
            # Expected failures are raised as they are, like in the
            # array case, instead of being turned into real failures.
            if isinstance(msg, ExpectedAssertionError):
                raise msg  # pylint: disable=E0702
            if msg:
                raise OnnxBackendAssertionError(  # pragma no cover
                    "Unexpected output '{0}' in model '{1}'\n{2}".format(
                        k, onnx_model, msg))
            tested += 1
    elif isinstance(expected, numpy.ndarray):
        if isinstance(output, list):
            if expected.shape[0] == len(output) and isinstance(
                    output[0], dict):
                # A list of dictionaries (presumably ZipMap outputs)
                # becomes a matrix whose columns are sorted by key.
                if isinstance(output, ArrayZipMapDictionary):
                    output = pandas.DataFrame(list(output))
                else:
                    output = pandas.DataFrame(output)
                output = output[list(sorted(output.columns))]
                output = output.values
        if isinstance(output, (dict, list)):
            if len(output) != 1:  # pragma: no cover
                ex = str(output)
                if len(ex) > 170:
                    ex = ex[:170] + "..."
                raise OnnxBackendAssertionError(
                    "More than one output when 1 is expected "
                    "for onnx '{0}'\n{1}"
                    .format(onnx_model, ex))
            output = output[-1]
        if not isinstance(output, numpy.ndarray):
            raise OnnxBackendAssertionError(  # pragma no cover
                "output must be an array for onnx '{0}' not {1}".format(
                    onnx_model, type(output)))
        if (classes is not None and (
                expected.dtype == numpy.str_ or expected.dtype.char == 'U')):
            # output holds indices, they are mapped to class labels.
            try:
                output = numpy.array([classes[cl] for cl in output])
            except IndexError as e:  # pragma no cover
                raise RuntimeError('Unable to handle\n{}\n{}\n{}'.format(
                    expected, output, classes)) from e
        msg = compare_outputs(
            expected, output, decimal=decimal, verbose=verbose, **kwargs)
        if isinstance(msg, ExpectedAssertionError):
            raise msg  # pylint: disable=E0702
        if msg:
            raise OnnxBackendAssertionError(  # pragma no cover
                "Unexpected output in model '{0}'\n{1}".format(onnx_model, msg))
        tested += 1
    elif isinstance(expected, csr_matrix):
        # DictVectorizer
        one_array = numpy.array(output)
        dense = numpy.asarray(expected.todense())
        msg = compare_outputs(dense, one_array, decimal=decimal,
                              verbose=verbose, **kwargs)
        if isinstance(msg, ExpectedAssertionError):
            raise msg  # pylint: disable=E0702
        if msg:
            raise OnnxBackendAssertionError(  # pragma no cover
                "Unexpected output in model '{0}'\n{1}".format(onnx_model, msg))
        tested += 1
    else:
        raise OnnxBackendAssertionError(  # pragma no cover
            "Unexpected type for expected output ({1}) and onnx '{0}'".
            format(onnx_model, type(expected)))
    if tested == 0:
        raise OnnxBackendAssertionError(  # pragma no cover
            "No test for onnx '{0}'".format(onnx_model))