Coverage for mlprodict/testing/test_utils/utils_backend_common_compare.py: 70%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

84 statements  

1""" 

2@file 

3@brief Inspired from sklearn-onnx, handles two backends. 

4""" 

5import numpy 

6import onnx 

7import pandas 

8from .utils_backend_common import ( 

9 load_data_and_model, extract_options, 

10 ExpectedAssertionError, OnnxBackendAssertionError, 

11 OnnxRuntimeMissingNewOnnxOperatorException, 

12 _compare_expected, _create_column) 

13 

14 

def compare_runtime_session(  # pylint: disable=R0912
        cls_session, test, decimal=5, options=None,
        verbose=False, context=None, comparable_outputs=None,
        intermediate_steps=False, classes=None,
        disable_optimisation=False):
    """
    The function compares the expected output (computed with
    the model before being converted to ONNX) and the ONNX output
    produced with module :epkg:`onnxruntime` or :epkg:`mlprodict`.

    :param cls_session: inference session class (like @see cl OnnxInference),
        called as ``cls_session(onx, runtime_options=...)``
    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options, if None and *onnx* is a string,
        they are obtained with *extract_options*, the dictionary
        given by the caller is copied and never modified
    :param context: specifies custom operators
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable optimisation the runtime may do
    :return: tuple (output, lambda function to run the predictions)
    :raises TypeError: if *options* is not a dictionary or if
        *cls_session* does not accept the expected signature
    :raises ExpectedAssertionError: if the failure is expected
        (option *CannotLoad* or ``-Fail`` in the model filename)
    :raises OnnxRuntimeMissingNewOnnxOperatorException: if the runtime
        reports it has no implementation for a node
    :raises OnnxBackendAssertionError: if the model cannot be loaded,
        the inputs cannot be mapped, the predictions cannot be computed
        or the comparison failed
    """
    def _model_dump():
        # Text dump of the model appended to error messages,
        # only built in verbose mode.
        if not verbose:
            return ""
        if isinstance(onx, onnx.ModelProto):
            # Already loaded, onnx.load expects a filename or a stream.
            model = onx
        else:
            model = onnx.load(onx)
        return "\nJSON ONNX\n" + str(model)

    lambda_onnx = None
    if context is None:
        context = {}
    load = load_data_and_model(test, **context)
    if verbose:  # pragma no cover
        print("[compare_runtime] test '{}' loaded".format(test['onnx']))

    onx = test['onnx']

    if options is None:
        options = extract_options(onx) if isinstance(onx, str) else {}
    elif not isinstance(options, dict):
        raise TypeError(  # pragma no cover
            "options must be a dictionary.")
    else:
        # Options are popped below, work on a copy to leave
        # the dictionary of the caller untouched.
        options = dict(options)

    if verbose:  # pragma no cover
        print("[compare_runtime] InferenceSession('{}')".format(onx))

    runtime_options = dict(disable_optimisation=disable_optimisation)
    try:
        sess = cls_session(onx, runtime_options=runtime_options)
    except TypeError as et:  # pragma: no cover
        raise TypeError(
            "Wrong signature for '{}' ({}).".format(
                cls_session.__name__, et)) from et
    except ExpectedAssertionError:  # pragma no cover
        raise
    except Exception as e:  # pylint: disable=W0703
        if "CannotLoad" in options:  # pragma no cover
            raise ExpectedAssertionError(
                "Unable to load onnx '{0}' due to\n{1}".format(
                    onx, e)) from e
        smodel = _model_dump()  # pragma no cover
        if ("NOT_IMPLEMENTED : Could not find an implementation "
                "for the node" in str(e)):
            # onnxruntime does not implement a specific node yet.
            raise OnnxRuntimeMissingNewOnnxOperatorException(
                "{3} does not implement a new operator "
                "'{0}'\n{1}\nONNX\n{2}".format(
                    onx, e, smodel, cls_session)) from e
        if "NOT_IMPLEMENTED : Failed to find kernel" in str(e):
            # onnxruntime does not implement a specific node yet
            # in the kernel included in onnxruntime.
            raise OnnxBackendAssertionError(
                "{3} misses a kernel for operator "
                "'{0}'\n{1}\nONNX\n{2}".format(
                    onx, e, smodel, cls_session)) from e
        raise OnnxBackendAssertionError(
            "Unable to load onnx '{0}'\nONNX\n{1}\n{2}".format(
                onx, smodel, e)) from e

    # Builds the dictionary {input name: value} given to the runtime.
    data = load["data"]
    if options.pop('DF', False):
        # One input per column, every column becomes a matrix
        # with a single column, float64 is converted into float32.
        inputs = {c: data[c].values for c in data.columns}
        for k in inputs:
            if inputs[k].dtype == numpy.float64:
                inputs[k] = inputs[k].astype(numpy.float32)
            inputs[k] = inputs[k].reshape((inputs[k].shape[0], 1))
    elif isinstance(data, dict):
        inputs = data
    elif isinstance(data, (list, numpy.ndarray, pandas.DataFrame)):
        inp = sess.get_inputs()
        outs = sess.get_outputs()
        if len(outs) == 0:
            raise OnnxBackendAssertionError(  # pragma: no cover
                "Wrong number of outputs, onnx='{0}'".format(onx))
        if len(inp) == len(data):
            inputs = {i.name: v for i, v in zip(inp, data)}
        elif len(inp) == 1:
            inputs = {inp[0].name: data}
        elif isinstance(data, numpy.ndarray):
            shape = sum(i.shape[1] if len(i.shape) == 2 else i.shape[0]
                        for i in inp)
            if shape == data.shape[1]:
                inputs = {n.name: data[:, i] for i, n in enumerate(inp)}
            else:
                raise OnnxBackendAssertionError(  # pragma: no cover
                    "Wrong number of inputs onnx {0} != "
                    "original shape {1}, onnx='{2}'"
                    .format(len(inp), data.shape, onx))
        elif isinstance(data, list):
            try:
                array_input = numpy.array(data)
            except Exception as e:  # pragma no cover
                raise OnnxBackendAssertionError(
                    "Wrong number of inputs onnx {0} != "
                    "original {1}, onnx='{2}'"
                    .format(len(inp), len(data), onx)) from e
            shape = sum(i.shape[1] for i in inp)
            if shape == array_input.shape[1]:
                # Splits every row into consecutive chunks,
                # one chunk per input.
                inputs = {}
                c = 0
                for n in inp:
                    d = c + n.shape[1]
                    inputs[n.name] = _create_column(
                        [row[c:d] for row in data], n.type)
                    c = d
            else:
                raise OnnxBackendAssertionError(  # pragma no cover
                    "Wrong number of inputs onnx {0} != "
                    "original shape {1}, onnx='{2}'*"
                    .format(len(inp), array_input.shape, onx))
        elif isinstance(data, pandas.DataFrame):
            try:
                array_input = numpy.array(data)
            except Exception as e:  # pragma no cover
                raise OnnxBackendAssertionError(
                    "Wrong number of inputs onnx {0} != "
                    "original {1}, onnx='{2}'"
                    .format(len(inp), len(data), onx)) from e
            shape = sum(i.shape[1] for i in inp)
            if shape == array_input.shape[1]:
                # Splits the columns into consecutive groups,
                # one group per input.
                inputs = {}
                c = 0
                for n in inp:
                    d = c + n.shape[1]
                    inputs[n.name] = _create_column(
                        data.iloc[:, c:d], n.type)
                    c = d
            else:
                raise OnnxBackendAssertionError(  # pragma no cover
                    "Wrong number of inputs onnx {0}={1} columns != "
                    "original shape {2}, onnx='{3}'*"
                    .format(len(inp), shape, array_input.shape, onx))
        else:
            raise OnnxBackendAssertionError(  # pragma no cover
                "Wrong type of inputs onnx {0}, onnx='{1}'".format(
                    type(data), onx))
    else:
        raise OnnxBackendAssertionError(  # pragma no cover
            "Dict or list is expected, not {0}".format(type(data)))

    for k in inputs:
        if isinstance(inputs[k], list):
            inputs[k] = numpy.array(inputs[k])

    options.pop('SklCol', False)  # unused here but in dump_data_and_model

    if verbose:  # pragma no cover
        print("[compare_runtime] type(inputs)={} len={} names={}".format(
            type(data), len(inputs), list(sorted(inputs))))
        run_options = {'verbose': 3 if intermediate_steps else 2,
                       'fLOG': print}
    else:
        run_options = {}

    import pprint
    from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
        InvalidArgument as OrtInvalidArgument)

    try:
        try:
            output = sess.run(None, inputs, **run_options)
        except TypeError:  # pragma no cover
            # The runtime does not support the logging options.
            output = sess.run(None, inputs)
        lambda_onnx = lambda: sess.run(None, inputs)  # noqa
        if verbose:  # pragma no cover
            pprint.pprint(output)
    except ExpectedAssertionError:  # pragma no cover
        raise
    except (RuntimeError, OrtInvalidArgument) as e:  # pragma no cover
        if intermediate_steps:
            sess.run(None, inputs, verbose=3, fLOG=print)
        # onx may be a model and not a filename,
        # the marker can only be found in a filename.
        if isinstance(onx, str) and "-Fail" in onx:
            raise ExpectedAssertionError(
                "{1} cannot compute the prediction for '{0}'".
                format(onx, cls_session)) from e
        raise OnnxBackendAssertionError(
            "{4} cannot compute the predictions"
            " for '{0}' due to {1}{2}\n{3}"
            .format(onx, e, _model_dump(), pprint.pformat(inputs),
                    cls_session)) from e
    except Exception as e:  # pragma no cover
        raise OnnxBackendAssertionError(
            "Unable to run onnx '{0}' due to {1}".format(onx, e)) from e
    if verbose:  # pragma no cover
        print("[compare_runtime] done type={}".format(type(output)))

    # Copy returned to the caller, taken before the comparison.
    output0 = output.copy()

    if comparable_outputs:
        cmp_exp = [load["expected"][o] for o in comparable_outputs]
        cmp_out = [output[o] for o in comparable_outputs]
    else:
        cmp_exp = load["expected"]
        cmp_out = output

    try:
        _compare_expected(cmp_exp, cmp_out, sess, onx,
                          decimal=decimal, verbose=verbose,
                          classes=classes, **options)
    except ExpectedAssertionError:  # pragma no cover
        raise
    except Exception as e:  # pragma no cover
        raise OnnxBackendAssertionError(
            "Model '{}' has discrepencies with cls='{}'.\n{}: {}{}".format(
                onx, sess.__class__.__name__, type(e), e,
                _model_dump())) from e

    return output0, lambda_onnx