Coverage for mlprodict/testing/onnx_backend.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

161 statements  

1""" 

2@file 

3@brief Tests with onnx backend. 

4""" 

5import os 

6import textwrap 

7import numpy 

8from numpy import object as dtype_object 

9from numpy.testing import assert_almost_equal 

10import onnx 

11from onnx.numpy_helper import to_array, to_list 

12from onnx.backend.test import __file__ as backend_folder 

13 

14 

def assert_almost_equal_string(expected, value):
    """
    Compares two arrays knowing they contain strings.
    Raises an exception if the test fails.

    If every element of *expected* can be converted into a float,
    both arrays are cast to ``float32`` and compared with
    ``assert_almost_equal``. Otherwise, the arrays are compared
    element by element for exact equality.

    :param expected: expected array
    :param value: value
    :raises AssertionError: if the arrays differ
    """
    def is_float(x):
        # The conversion has to be attempted, otherwise every string is
        # considered as a number and the cast below fails with a
        # ValueError on non numerical strings.
        try:
            float(x)
        except (ValueError, TypeError):
            return False
        return True

    if all(map(is_float, expected.ravel())):
        expected_float = expected.astype(numpy.float32)
        value_float = value.astype(numpy.float32)
        assert_almost_equal(expected_float, value_float)
    else:
        # assert_almost_equal subtracts the arrays, which is not defined
        # for strings, an exact comparison is the only option.
        numpy.testing.assert_array_equal(expected, value)

35 

36 

class OnnxBackendTest:
    """
    Definition of a backend test. It starts with a folder,
    in this folder, one onnx file must be there, then a subfolder
    for each test to run with this model.

    :param folder: test folder
    :param onnx_path: onnx file
    :param onnx_model: loaded onnx file
    :param tests: list of test, each of them is a dictionary
        with two keys, ``'inputs'`` and ``'outputs'``
    """

    @staticmethod
    def _sort(filenames):
        """
        Sorts filenames by the integer which follows the last
        underscore of their name without extension
        (``input_2.pb`` comes before ``input_10.pb``).

        :param filenames: iterable of filenames
        :return: sorted list of filenames
        """
        def key(f):
            name = os.path.splitext(f)[0]
            # the filename is kept as a second key to break ties
            return int(name.split('_')[-1]), f

        return sorted(filenames, key=key)

    @staticmethod
    def _read_proto_from_file(full):
        """
        Reads a protobuf file. The content is successively interpreted
        as a tensor, a sequence, a model. The first successful
        interpretation is returned.

        :param full: filename
        :return: result of ``to_array``, ``to_list`` or
            ``onnx.load_model_from_string``
        :raises FileNotFoundError: if the file does not exist
        :raises RuntimeError: if the three interpretations failed
        """
        if not os.path.exists(full):
            raise FileNotFoundError(  # pragma: no cover
                "File not found: %r." % full)
        with open(full, 'rb') as f:
            serialized = f.read()
        try:
            loaded = to_array(onnx.load_tensor_from_string(serialized))
        except Exception as e:  # pylint: disable=W0703
            seq = onnx.SequenceProto()
            try:
                seq.ParseFromString(serialized)
                loaded = to_list(seq)
            except Exception:  # pylint: disable=W0703
                try:
                    loaded = onnx.load_model_from_string(serialized)
                except Exception:  # pragma: no cover
                    # the error of the first attempt (tensor) is reported
                    raise RuntimeError(
                        "Unable to read %r, error is %s, content is %r." % (
                            full, e, serialized[:100])) from e
        return loaded

    @staticmethod
    def _load(folder, names):
        """
        Loads every file listed in *names* from *folder*.

        :param folder: folder
        :param names: filenames relative to *folder*
        :return: list of loaded objects, in the same order as *names*
        :raises RuntimeError: if a file is loaded into an unexpected type
        """
        res = []
        for name in names:
            full = os.path.join(folder, name)
            new_tensor = OnnxBackendTest._read_proto_from_file(full)
            if isinstance(new_tensor, (numpy.ndarray, onnx.ModelProto, list)):
                t = new_tensor
            elif isinstance(new_tensor, onnx.TensorProto):
                t = to_array(new_tensor)
            else:
                raise RuntimeError(  # pragma: no cover
                    "Unexpected type %r for %r." % (type(new_tensor), full))
            res.append(t)
        return res

    def __repr__(self):
        "usual"
        return "%s(%r)" % (self.__class__.__name__, self.folder)

    def __init__(self, folder):
        if not os.path.exists(folder):
            raise FileNotFoundError(  # pragma: no cover
                "Unable to find folder %r." % folder)
        content = os.listdir(folder)
        onx = [c for c in content if os.path.splitext(c)[-1] in {'.onnx'}]
        if len(onx) != 1:
            # zero file is as wrong as two or more
            raise ValueError(  # pragma: no cover
                "Expected exactly one onnx file in %r, found %d (%r)." % (
                    folder, len(onx), onx))
        self.folder = folder
        self.onnx_path = os.path.join(folder, onx[0])
        self.onnx_model = onnx.load(self.onnx_path)

        # every subfolder is a test, files input_<i>.pb and output_<i>.pb
        # are its inputs and expected outputs
        self.tests = []
        for sub in content:
            full = os.path.join(folder, sub)
            if os.path.isdir(full):
                pb = [c for c in os.listdir(full)
                      if os.path.splitext(c)[-1] in {'.pb'}]
                inputs = OnnxBackendTest._sort(
                    c for c in pb if c.startswith('input_'))
                outputs = OnnxBackendTest._sort(
                    c for c in pb if c.startswith('output_'))

                t = dict(
                    inputs=OnnxBackendTest._load(full, inputs),
                    outputs=OnnxBackendTest._load(full, outputs))
                self.tests.append(t)

    @property
    def name(self):
        "Returns the test name."
        # normpath removes a trailing separator which would otherwise
        # produce an empty name
        return os.path.basename(os.path.normpath(self.folder))

    def __len__(self):
        "Returns the number of tests."
        return len(self.tests)

    def _compare_results(self, index, i, e, o, decimal=None):
        """
        Compares the expected output and the output produced
        by the runtime. Raises an exception if not equal.

        :param index: test index
        :param i: output index
        :param e: expected output
        :param o: output
        :param decimal: number of decimals to check, if None,
            it depends on the type of *e*: 6 for float32,
            12 for float64, 7 otherwise, it is ignored for arrays
            of objects (strings)
        :raises AssertionError: if the outputs are different
        :raises NotImplementedError: if *e* is not a numpy array
        """
        if not isinstance(e, numpy.ndarray):
            raise NotImplementedError(
                "Comparison not implemented for type %r." % type(e))

        if isinstance(o, numpy.ndarray):
            if decimal is None:
                if e.dtype == numpy.float32:
                    decimal = 6
                elif e.dtype == numpy.float64:
                    decimal = 12
                else:
                    decimal = 7
            if e.dtype == numpy.object_:
                try:
                    assert_almost_equal_string(e, o)
                except AssertionError as ex:
                    raise AssertionError(  # pragma: no cover
                        "Output %d of test %d in folder %r failed." % (
                            i, index, self.folder)) from ex
            else:
                try:
                    assert_almost_equal(e, o, decimal=decimal)
                except AssertionError as ex:
                    raise AssertionError(
                        "Output %d of test %d in folder %r failed." % (
                            i, index, self.folder)) from ex
        elif hasattr(o, 'is_compatible'):
            # A shape
            if e.dtype != o.dtype:
                raise AssertionError(
                    "Output %d of test %d in folder %r failed "
                    "(e.dtype=%r, o=%r)." % (
                        i, index, self.folder, e.dtype, o))
            if not o.is_compatible(e.shape):
                raise AssertionError(  # pragma: no cover
                    "Output %d of test %d in folder %r failed "
                    "(e.shape=%r, o=%r)." % (
                        i, index, self.folder, e.shape, o))
        # NOTE(review): any other type for *o* is not compared at all,
        # this is the historical behaviour and it is left unchanged.

    def is_random(self):
        "Tells if a test is random or not."
        return 'bernoulli' in self.folder

    def run(self, load_fct, run_fct, index=None, decimal=5):
        """
        Executes a tests or all tests if index is None.
        The function crashes if the tests fails.

        :param load_fct: loading function, takes a loaded onnx graph,
            and returns an object
        :param run_fct: running function, takes the result of previous
            function, the inputs, and returns the outputs
        :param index: index of the test to run or all.
        :param decimal: number of decimals to check when comparing
            numerical outputs, None to pick a value depending
            on the type of the expected output
        :raises AssertionError: if the number of outputs or one output
            is not the expected one
        """
        if index is None:
            for i in range(len(self)):
                self.run(load_fct, run_fct, index=i, decimal=decimal)
            return

        obj = load_fct(self.onnx_model)

        got = run_fct(obj, *self.tests[index]['inputs'])
        expected = self.tests[index]['outputs']
        if len(got) != len(expected):
            raise AssertionError(  # pragma: no cover
                "Unexpected number of output (test %d, folder %r), "
                "got %r, expected %r." % (
                    index, self.folder, len(got), len(expected)))
        for i, (e, o) in enumerate(zip(expected, got)):
            if self.is_random():
                # values cannot be compared, only types and shapes are
                if e.dtype != o.dtype:
                    raise AssertionError(
                        "Output %d of test %d in folder %r failed "
                        "(type mismatch %r != %r)." % (
                            i, index, self.folder, e.dtype, o.dtype))
                if e.shape != o.shape:
                    raise AssertionError(
                        "Output %d of test %d in folder %r failed "
                        "(shape mismatch %r != %r)." % (
                            i, index, self.folder, e.shape, o.shape))
            else:
                self._compare_results(index, i, e, o, decimal=decimal)

    def to_python(self):
        """
        Returns a python code equivalent to the ONNX test.

        :return: code
        """
        from ..onnx_tools.onnx_export import export2onnx
        rows = []
        code = export2onnx(self.onnx_model)
        lines = code.split('\n')
        # printing instructions and comments are removed
        lines = [line for line in lines
                 if not line.strip().startswith('print') and
                 not line.strip().startswith('# ')]
        rows.append(textwrap.dedent("\n".join(lines)))
        rows.append("oinf = OnnxInference(onnx_model)")
        for test in self.tests:
            rows.append("xs = [")
            for inp in test['inputs']:
                rows.append(textwrap.indent(repr(inp) + ',', '    ' * 2))
            rows.append("]")
            rows.append("ys = [")
            for out in test['outputs']:
                rows.append(textwrap.indent(repr(out) + ',', '    ' * 2))
            rows.append("]")
            rows.append("feeds = {n: x for n, x in zip(oinf.input_names, xs)}")
            rows.append("got = oinf.run(feeds)")
            rows.append("goty = [got[k] for k in oinf.output_names]")
            rows.append("for y, gy in zip(ys, goty):")
            rows.append("    self.assertEqualArray(y, gy)")
            rows.append("")
        code = "\n".join(rows)
        final = "\n".join(["def %s(self):" % self.name,
                           textwrap.indent(code, '    ')])
        try:
            from pyquickhelper.pycode.code_helper import remove_extra_spaces_and_pep8
        except ImportError:  # pragma: no cover
            # the code is returned without any cleaning
            return final
        return remove_extra_spaces_and_pep8(final, aggressive=True)

271 

272 

def enumerate_onnx_tests(series, fct_filter=None):
    """
    Collects test from a sub folder of `onnx/backend/test`.
    Works as an enumerator to start processing them
    without waiting or storing too much of them.

    :param series: which subfolder to load
    :param fct_filter: function `lambda testname: boolean`
        to load or skip the test, None for all
    :return: iterator on @see cl OnnxBackendTest
    :raises FileNotFoundError: if the subfolder does not exist,
        the exception is raised when the iteration starts
    """
    root = os.path.dirname(backend_folder)
    sub = os.path.join(root, 'data', series)
    if not os.path.exists(sub):
        raise FileNotFoundError(
            "Unable to find series of tests in %r, subfolders:\n%s" % (
                root, "\n".join(os.listdir(root))))
    tests = os.listdir(sub)
    for t in tests:
        if fct_filter is not None and not fct_filter(t):
            continue
        folder = os.path.join(sub, t)
        if not os.path.isdir(folder):
            # a file cannot be a test, os.listdir would fail on it
            continue
        content = os.listdir(folder)
        onx = [c for c in content if os.path.splitext(c)[-1] in {'.onnx'}]
        # folders without exactly one onnx file are silently skipped
        if len(onx) == 1:
            yield OnnxBackendTest(folder)