Coverage for mlprodict/testing/onnx_backend.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Tests with onnx backend.
4"""
5import os
6import textwrap
7import numpy
8from numpy import object as dtype_object
9from numpy.testing import assert_almost_equal
10import onnx
11from onnx.numpy_helper import to_array, to_list
12from onnx.backend.test import __file__ as backend_folder
def assert_almost_equal_string(expected, value):
    """
    Compares two arrays knowing they contain strings.
    Raises an exception if the test fails.

    When every element of *expected* can be converted into a float,
    both arrays are cast to ``float32`` and compared approximately.
    Otherwise, both arrays must be strictly equal.

    :param expected: expected array
    :param value: value
    """
    def is_float(x):
        # The conversion itself is the check: it has to be attempted,
        # otherwise every string is reported as being a float.
        try:
            float(x)
        except (TypeError, ValueError):
            return False
        return True

    if all(map(is_float, expected.ravel())):
        expected_float = expected.astype(numpy.float32)
        value_float = value.astype(numpy.float32)
        assert_almost_equal(expected_float, value_float)
    else:
        # An approximate comparison subtracts the operands, which is
        # undefined for strings: exact equality is required instead.
        numpy.testing.assert_array_equal(expected, value)
class OnnxBackendTest:
    """
    Definition of a backend test. It starts with a folder,
    in this folder, one onnx file must be there, then a subfolder
    for each test to run with this model.

    :param folder: test folder
    :param onnx_path: onnx file
    :param onnx_model: loaded onnx file
    :param tests: list of test
    """

    @staticmethod
    def _sort(filenames):
        """
        Sorts filenames by the integer which follows the last
        underscore of the name without its extension, so that
        ``input_10.pb`` comes after ``input_2.pb``.

        :param filenames: iterable of filenames
        :return: sorted list of filenames
        """
        temp = []
        for f in filenames:
            name = os.path.splitext(f)[0]
            i = name.split('_')[-1]
            temp.append((int(i), f))
        temp.sort()
        return [_[1] for _ in temp]

    @staticmethod
    def _read_proto_from_file(full):
        """
        Reads a serialized protobuf file. The content is parsed
        as a tensor first, then as a sequence, then as a model.

        :param full: file name
        :return: a numpy array, a list or a loaded model
        :raises FileNotFoundError: if the file does not exist
        :raises RuntimeError: if none of the three parsers succeeds
        """
        if not os.path.exists(full):
            raise FileNotFoundError(  # pragma: no cover
                "File not found: %r." % full)
        with open(full, 'rb') as f:
            serialized = f.read()
        try:
            loaded = to_array(onnx.load_tensor_from_string(serialized))
        except Exception as e:  # pylint: disable=W0703
            seq = onnx.SequenceProto()
            try:
                seq.ParseFromString(serialized)
                loaded = to_list(seq)
            except Exception:  # pylint: disable=W0703
                try:
                    loaded = onnx.load_model_from_string(serialized)
                except Exception:  # pragma: no cover
                    # The first failure (tensor parsing) is the one
                    # reported and chained.
                    raise RuntimeError(
                        "Unable to read %r, error is %s, content is %r." % (
                            full, e, serialized[:100])) from e
        return loaded

    @staticmethod
    def _load(folder, names):
        """
        Loads every file of *names* found in *folder*.

        :param folder: folder containing the files
        :param names: filenames, loaded in the given order
        :return: list of loaded objects, tensors are converted
            into numpy arrays
        :raises RuntimeError: if a file holds an unexpected type
        """
        res = []
        for name in names:
            full = os.path.join(folder, name)
            new_tensor = OnnxBackendTest._read_proto_from_file(full)
            if isinstance(new_tensor, (numpy.ndarray, onnx.ModelProto, list)):
                t = new_tensor
            elif isinstance(new_tensor, onnx.TensorProto):
                t = to_array(new_tensor)
            else:
                raise RuntimeError(  # pragma: no cover
                    "Unexpected type %r for %r." % (type(new_tensor), full))
            res.append(t)
        return res

    def __repr__(self):
        "usual"
        return "%s(%r)" % (self.__class__.__name__, self.folder)

    def __init__(self, folder):
        if not os.path.exists(folder):
            raise FileNotFoundError(  # pragma: no cover
                "Unable to find folder %r." % folder)
        content = os.listdir(folder)
        onx = [c for c in content if os.path.splitext(c)[-1] in {'.onnx'}]
        if len(onx) != 1:
            # The check fails for zero files as well as for many,
            # the message reports the actual count.
            raise ValueError(  # pragma: no cover
                "Expected exactly one onnx file in %r, found %d (%r)." % (
                    folder, len(onx), onx))
        self.folder = folder
        self.onnx_path = os.path.join(folder, onx[0])
        self.onnx_model = onnx.load(self.onnx_path)

        # Every subfolder is one test: files input_<i>.pb and
        # output_<i>.pb ordered by their index.
        self.tests = []
        for sub in content:
            full = os.path.join(folder, sub)
            if os.path.isdir(full):
                pb = [c for c in os.listdir(full)
                      if os.path.splitext(c)[-1] in {'.pb'}]
                inputs = OnnxBackendTest._sort(
                    c for c in pb if c.startswith('input_'))
                outputs = OnnxBackendTest._sort(
                    c for c in pb if c.startswith('output_'))

                t = dict(
                    inputs=OnnxBackendTest._load(full, inputs),
                    outputs=OnnxBackendTest._load(full, outputs))
                self.tests.append(t)

    @property
    def name(self):
        "Returns the test name."
        return os.path.split(self.folder)[-1]

    def __len__(self):
        "Returns the number of tests."
        return len(self.tests)

    def _compare_results(self, index, i, e, o, decimal=None):
        """
        Compares the expected output and the output produced
        by the runtime. Raises an exception if not equal.

        :param index: test index
        :param i: output index
        :param e: expected output
        :param o: output
        :param decimal: precision of the comparison, if None,
            it depends on the expected type: 6 for float32,
            12 for float64, 7 otherwise
        """
        if isinstance(e, numpy.ndarray):
            if isinstance(o, numpy.ndarray):
                if decimal is None:
                    if e.dtype == numpy.float32:
                        deci = 6
                    elif e.dtype == numpy.float64:
                        deci = 12
                    else:
                        deci = 7
                else:
                    deci = decimal
                if e.dtype == dtype_object:
                    try:
                        assert_almost_equal_string(e, o)
                    except AssertionError as ex:
                        raise AssertionError(  # pragma: no cover
                            "Output %d of test %d in folder %r failed." % (
                                i, index, self.folder)) from ex
                else:
                    try:
                        assert_almost_equal(e, o, decimal=deci)
                    except AssertionError as ex:
                        raise AssertionError(
                            "Output %d of test %d in folder %r failed." % (
                                i, index, self.folder)) from ex
            elif hasattr(o, 'is_compatible'):
                # A shape
                if e.dtype != o.dtype:
                    raise AssertionError(
                        "Output %d of test %d in folder %r failed "
                        "(e.dtype=%r, o=%r)." % (
                            i, index, self.folder, e.dtype, o))
                if not o.is_compatible(e.shape):
                    raise AssertionError(  # pragma: no cover
                        "Output %d of test %d in folder %r failed "
                        "(e.shape=%r, o=%r)." % (
                            i, index, self.folder, e.shape, o))
            # NOTE(review): any other type for *o* is not compared
            # and does not raise, confirm this is intended.
        else:
            raise NotImplementedError(
                "Comparison not implemented for type %r." % type(e))

    def is_random(self):
        "Tells if a test is random or not."
        if 'bernoulli' in self.folder:
            return True
        return False

    def run(self, load_fct, run_fct, index=None, decimal=5):
        """
        Executes a tests or all tests if index is None.
        The function crashes if the tests fails.

        :param load_fct: loading function, takes a loaded onnx graph,
            and returns an object
        :param run_fct: running function, takes the result of previous
            function, the inputs, and returns the outputs
        :param index: index of the test to run or all.
        :param decimal: precision of the comparison between
            the expected outputs and the computed outputs
        """
        if index is None:
            for i in range(len(self)):
                self.run(load_fct, run_fct, index=i, decimal=decimal)
            return

        obj = load_fct(self.onnx_model)

        got = run_fct(obj, *self.tests[index]['inputs'])
        expected = self.tests[index]['outputs']
        if len(got) != len(expected):
            raise AssertionError(  # pragma: no cover
                "Unexpected number of output (test %d, folder %r), "
                "got %r, expected %r." % (
                    index, self.folder, len(got), len(expected)))
        for i, (e, o) in enumerate(zip(expected, got)):
            if self.is_random():
                # Values of a random test cannot be compared,
                # only types and shapes are checked.
                if e.dtype != o.dtype:
                    raise AssertionError(
                        "Output %d of test %d in folder %r failed "
                        "(type mismatch %r != %r)." % (
                            i, index, self.folder, e.dtype, o.dtype))
                if e.shape != o.shape:
                    raise AssertionError(
                        "Output %d of test %d in folder %r failed "
                        "(shape mismatch %r != %r)." % (
                            i, index, self.folder, e.shape, o.shape))
            else:
                self._compare_results(index, i, e, o, decimal=decimal)

    def to_python(self):
        """
        Returns a python code equivalent to the ONNX test.

        :return: code
        """
        from ..onnx_tools.onnx_export import export2onnx
        rows = []
        code = export2onnx(self.onnx_model)
        lines = code.split('\n')
        # Removes print instructions and comments from the exported code.
        lines = [line for line in lines
                 if not line.strip().startswith('print') and
                 not line.strip().startswith('# ')]
        rows.append(textwrap.dedent("\n".join(lines)))
        rows.append("oinf = OnnxInference(onnx_model)")
        for test in self.tests:
            rows.append("xs = [")
            for inp in test['inputs']:
                rows.append(textwrap.indent(repr(inp) + ',', ' ' * 2))
            rows.append("]")
            rows.append("ys = [")
            for out in test['outputs']:
                rows.append(textwrap.indent(repr(out) + ',', ' ' * 2))
            rows.append("]")
            rows.append("feeds = {n: x for n, x in zip(oinf.input_names, xs)}")
            rows.append("got = oinf.run(feeds)")
            rows.append("goty = [got[k] for k in oinf.output_names]")
            rows.append("for y, gy in zip(ys, goty):")
            rows.append("    self.assertEqualArray(y, gy)")
            rows.append("")
        code = "\n".join(rows)
        final = "\n".join(["def %s(self):" % self.name,
                           textwrap.indent(code, '    ')])
        # The formatting step is optional, the raw code is returned
        # if pyquickhelper is not installed.
        try:
            from pyquickhelper.pycode.code_helper import remove_extra_spaces_and_pep8
        except ImportError:  # pragma: no cover
            return final
        return remove_extra_spaces_and_pep8(final, aggressive=True)
273def enumerate_onnx_tests(series, fct_filter=None):
274 """
275 Collects test from a sub folder of `onnx/backend/test`.
276 Works as an enumerator to start processing them
277 without waiting or storing too much of them.
279 :param series: which subfolder to load
280 :param fct_filter: function `lambda testname: boolean`
281 to load or skip the test, None for all
282 :return: list of @see cl OnnxBackendTest
283 """
284 root = os.path.dirname(backend_folder)
285 sub = os.path.join(root, 'data', series)
286 if not os.path.exists(sub):
287 raise FileNotFoundError(
288 "Unable to find series of tests in %r, subfolders:\n%s" % (
289 root, "\n".join(os.listdir(root))))
290 tests = os.listdir(sub)
291 for t in tests:
292 if fct_filter is not None and not fct_filter(t):
293 continue
294 folder = os.path.join(sub, t)
295 content = os.listdir(folder)
296 onx = [c for c in content if os.path.splitext(c)[-1] in {'.onnx'}]
297 if len(onx) == 1:
298 yield OnnxBackendTest(folder)