Coverage for mlprodict/testing/test_utils/utils_backend_python.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Inspired from sklearn-onnx, handles two backends.
4"""
5from ...onnxrt import OnnxInference
6from .utils_backend_common_compare import compare_runtime_session
class MockVariableName:
    """
    Lightweight stand-in for the variable descriptions returned by
    :epkg:`onnxruntime`'s ``get_inputs`` / ``get_outputs``: only the
    name is known, asking for the shape or the type raises
    :class:`NotImplementedError`.
    """

    def __init__(self, name):
        self.name = name

    def _unavailable(self, what):
        "Raises the error telling *what* is not known for this variable."
        raise NotImplementedError(  # pragma: no cover
            "No {} for '{}'.".format(what, self.name))

    @property
    def shape(self):
        "returns shape (not available at this level, always raises)"
        return self._unavailable("shape")

    @property
    def type(self):
        "returns type (not available at this level, always raises)"
        return self._unavailable("type")
class MockVariableNameShape(MockVariableName):
    """
    Variable description carrying a name and a shape; the type is still
    unknown, so :attr:`type` keeps raising like the base class.
    """

    def __init__(self, name, sh):
        super().__init__(name)
        # stored as given, no validation or copy
        self._shape = sh

    @property
    def shape(self):
        "returns the shape received by the constructor"
        return self._shape
class MockVariableNameShapeType(MockVariableNameShape):
    """
    Variable description carrying a name, a shape and a type,
    all three stored as received.
    """

    def __init__(self, name, sh, stype):
        super().__init__(name, sh)
        self._stype = stype

    @property
    def type(self):
        "returns the type received by the constructor"
        return self._stype
class OnnxInference2(OnnxInference):
    """
    Exposes :class:`OnnxInference` through an API resembling
    :epkg:`onnxruntime`'s *InferenceSession* (``run``, ``get_inputs``,
    ``get_outputs``), so it can be passed to the same comparison code
    as onnxruntime.
    """

    def run(self, name, inputs, *args, **kwargs):  # pylint: disable=W0221
        """
        onnxruntime API: computes the outputs of the model.

        :param name: ``None`` to get every output (ordered as
            ``self.output_names``), a list or tuple of output names
            (what :epkg:`onnxruntime` expects) to get those outputs in
            that order, or a single output name
        :param inputs: dictionary ``{input_name: value}``
        :param args: ignored, accepted so that onnxruntime's positional
            *run_options* does not break the call
        :param kwargs: forwarded to :meth:`OnnxInference.run`
        :return: a list of results if *name* is ``None`` or a list/tuple,
            the single result if *name* is one output name
        :raises RuntimeError: if a requested output is not produced
        """
        res = OnnxInference.run(self, inputs, **kwargs)
        if name is None:
            return [res[n] for n in self.output_names]
        if isinstance(name, (list, tuple)):
            # onnxruntime passes a list of names and expects a list back;
            # testing ``name in res`` with a list would raise TypeError
            # (unhashable type) instead of returning the outputs.
            missing = [n for n in name if n not in res]
            if missing:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find outputs {}.".format(missing))
            return [res[n] for n in name]
        if name in res:  # pragma: no cover
            return res[name]
        raise RuntimeError(  # pragma: no cover
            "Unable to find output '{}'.".format(name))

    def get_inputs(self):
        """
        onnxruntime API: describes the model inputs. Every entry of
        ``self.input_names_shapes_types`` is unpacked as
        ``(name, shape, type)``.
        """
        return [MockVariableNameShapeType(*n) for n in self.input_names_shapes_types]

    def get_outputs(self):
        """
        onnxruntime API: describes the model outputs. Every entry of
        ``self.output_names_shapes`` is unpacked as ``(name, shape)``;
        the returned objects have no type information.
        """
        return [MockVariableNameShape(*n) for n in self.output_names_shapes]

    def run_in_scan(self, inputs, verbose=0, fLOG=None):
        """
        Instance to run in operator scan: calls :meth:`OnnxInference.run`
        directly and returns its raw result, bypassing the
        onnxruntime-style :meth:`run` defined above.
        """
        return OnnxInference.run(self, inputs, verbose=verbose, fLOG=fLOG)
def compare_runtime(test, decimal=5, options=None,
                    verbose=False, context=None, comparable_outputs=None,
                    intermediate_steps=False, classes=None,
                    disable_optimisation=False):
    """
    The function compares the expected output (computed with
    the model before being converted to ONNX) and the ONNX output
    produced with module :epkg:`mlprodict` (through
    :class:`OnnxInference2`, which mimics the :epkg:`onnxruntime` API).

    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param context: specifies custom operators
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable optimisation the runtime may do
    :return: the result of ``compare_runtime_session``, a tuple
        (output, lambda function to run the predictions)

    An error is raised if the comparison fails.
    """
    # All the work is delegated to the backend-independent comparison,
    # only the runtime class differs from the onnxruntime backend.
    return compare_runtime_session(
        OnnxInference2, test, decimal=decimal, options=options,
        verbose=verbose, context=context,
        comparable_outputs=comparable_outputs,
        intermediate_steps=intermediate_steps,
        classes=classes, disable_optimisation=disable_optimisation)