Coverage for mlprodict/onnxrt/validate/validate_helper.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

224 statements  

1""" 

2@file 

3@brief Validates runtime for many :epkg:`scikit-learn` operators. 

4The submodule relies on :epkg:`onnxconverter_common`, 

5:epkg:`sklearn-onnx`. 

6""" 

7import math 

8import copy 

9import os 

10import warnings 

11from importlib import import_module 

12import pickle 

13from time import perf_counter 

14import numpy 

15from cpyquickhelper.numbers import measure_time as _c_measure_time 

16from sklearn.base import BaseEstimator 

17from sklearn.linear_model._base import LinearModel 

18from sklearn.model_selection import train_test_split 

19from sklearn import __all__ as sklearn__all__, __version__ as sklearn_version 

20from .validate_problems import _problems 

21 

22 

class RuntimeBadResultsError(RuntimeError):
    """
    Error raised when the results differ too much from
    the ones :epkg:`scikit-learn` returns.
    """

    def __init__(self, msg, obs):
        """
        :param msg: message to display
        :param obs: observations, kept in attribute *obs*
        """
        super().__init__(msg)
        self.obs = obs

36 

37 

38def _dictionary2str(di): 

39 el = [] 

40 for k in sorted(di): 

41 el.append('{}={}'.format(k, di[k])) 

42 return '/'.join(el) 

43 

44 

def modules_list():
    """
    Returns modules and versions currently used.

    .. runpython::
        :showcode:
        :rst:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.validate.validate_helper import modules_list
        from pyquickhelper.pandashelper import df2rst
        from pandas import DataFrame
        print(df2rst(DataFrame(modules_list())))
    """
    def try_import(name):
        # Returns None when the module cannot be imported, otherwise
        # a dictionary with its name and, if available, its version.
        try:
            module = import_module(name)
        except ImportError:  # pragma: no cover
            return None
        info = dict(name=name)
        if hasattr(module, '__version__'):
            info['version'] = module.__version__
        return info

    candidates = sorted(['pandas', 'numpy', 'sklearn', 'mlprodict',
                         'skl2onnx', 'onnxmltools', 'onnx', 'onnxruntime',
                         'scipy'])
    described = (try_import(name) for name in candidates)
    return [row for row in described if row is not None]

75 

76 

77def _dispsimple(arr, fLOG): 

78 if isinstance(arr, (tuple, list)): 

79 for i, a in enumerate(arr): 

80 fLOG("output %d" % i) 

81 _dispsimple(a, fLOG) 

82 elif hasattr(arr, 'shape'): 

83 if len(arr.shape) == 1: 

84 threshold = 8 

85 else: 

86 threshold = min( 

87 50, min(50 // arr.shape[1], 8) * arr.shape[1]) 

88 fLOG(numpy.array2string(arr, max_line_width=120, 

89 suppress_small=True, 

90 threshold=threshold)) 

91 else: # pragma: no cover 

92 s = str(arr) 

93 if len(s) > 50: 

94 s = s[:50] + "..." 

95 fLOG(s) 

96 

97 

98def _merge_options(all_conv_options, aoptions): 

99 if aoptions is None: 

100 return copy.deepcopy(all_conv_options) 

101 if not isinstance(aoptions, dict): 

102 return copy.deepcopy(aoptions) # pragma: no cover 

103 merged = {} 

104 for k, v in all_conv_options.items(): 

105 if k in aoptions: 

106 merged[k] = _merge_options(v, aoptions[k]) 

107 else: 

108 merged[k] = copy.deepcopy(v) 

109 for k, v in aoptions.items(): 

110 if k in all_conv_options: 

111 continue 

112 merged[k] = copy.deepcopy(v) 

113 return merged 

114 

115 

def sklearn_operators(subfolder=None, extended=False,
                      experimental=True):
    """
    Builds the list of operators from :epkg:`scikit-learn`.
    The function goes through the list of submodules
    and gets the list of classes which inherit from
    :epkg:`scikit-learn:base:BaseEstimator`.

    :param subfolder: look into only one subfolder
    :param extended: extends the list to the list of operators
        this package implements a converter for
    :param experimental: includes experimental module from
        :epkg:`scikit-learn` (see `sklearn.experimental
        <https://github.com/scikit-learn/scikit-learn/
        tree/master/sklearn/experimental>`_)
    :return: the list of found operators
    """
    if experimental:
        from sklearn.experimental import (  # pylint: disable=W0611
            enable_hist_gradient_boosting,
            enable_iterative_imputer)

    # Classes never returned, whatever module exposes them.
    skipped_names = {'Pipeline', 'ColumnTransformer',
                     'FeatureUnion', 'BaseEstimator',
                     'BaseEnsemble', 'BaseDecisionTree',
                     'CustomScorerTransform'}
    # Submodules for which only classes named '*Calibrated*' are kept.
    restricted_subs = {'calibration', 'dummy', 'manifold'}

    found = []
    for subm in sorted(sklearn__all__ + ['mlprodict.onnx_conv']):
        if isinstance(subm, list):
            continue  # pragma: no cover
        if subfolder is not None and subm != subfolder:
            continue

        subs = [subm]
        if subm == 'feature_extraction':
            subs.append('feature_extraction.text')

        for sub in subs:
            if '.' in sub and sub != 'feature_extraction.text':
                # already a fully qualified module name
                name_sub = sub
            else:
                name_sub = "{0}.{1}".format("sklearn", sub)
            try:
                mod = import_module(name_sub)
            except ModuleNotFoundError:
                continue

            if hasattr(mod, "register_converters"):
                cls = mod.register_converters()
            else:
                exported = getattr(mod, "__all__", None)
                if exported is None:
                    exported = list(mod.__dict__)
                cls = [mod.__dict__[attr] for attr in exported]

            for cl in cls:
                try:
                    issub = issubclass(cl, BaseEstimator)
                except TypeError:
                    # not a class
                    continue
                if not issub or cl.__name__ in skipped_names:
                    continue
                if (sub in restricted_subs and
                        'Calibrated' not in cl.__name__):
                    continue
                if sub in sklearn__all__:
                    pack = "sklearn"
                else:
                    pack = cl.__module__.split('.')[0]
                found.append(
                    dict(name=cl.__name__, subfolder=sub, cl=cl, package=pack))

    if extended:
        from ...onnx_conv import register_converters
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", ResourceWarning)
            models = register_converters(True)

        done = {entry['name'] for entry in found}
        for m in models:
            try:
                name = m.__module__.split('.')
            except AttributeError as e:  # pragma: no cover
                raise AttributeError("Unexpected value, m={}".format(m)) from e
            if m.__name__ in done:
                continue
            found.append(
                dict(name=m.__name__, cl=m, package=name[0],
                     sub='.'.join(name[1:])))

    def _can_predict(cl):
        # Tells if the class exposes a method usable on new data.
        if hasattr(cl, 'fit_predict') and not hasattr(cl, 'predict'):
            return False
        if hasattr(cl, 'fit_transform') and not hasattr(cl, 'transform'):
            return False
        return (hasattr(cl, 'transform') or
                hasattr(cl, 'predict') or
                hasattr(cl, 'decision_function'))

    # let's remove models which cannot predict
    return [entry for entry in found if _can_predict(entry['cl'])]

223 

224 

225def _measure_time(fct, repeat=1, number=1, first_run=True): 

226 """ 

227 Measures the execution time for a function. 

228 

229 :param fct: function to measure 

230 :param repeat: number of times to repeat 

231 :param number: number of times between two measures 

232 :param first_run: if True, runs the function once before measuring 

233 :return: last result, average, values 

234 """ 

235 res = None 

236 values = [] 

237 if first_run: 

238 fct() 

239 for __ in range(repeat): 

240 begin = perf_counter() 

241 for _ in range(number): 

242 res = fct() 

243 end = perf_counter() 

244 values.append(end - begin) 

245 if repeat * number == 1: 

246 return res, values[0], values 

247 return res, sum(values) / (repeat * number), values # pragma: no cover 

248 

249 

250def _shape_exc(obj): 

251 if hasattr(obj, 'shape'): 

252 return obj.shape 

253 if isinstance(obj, (list, dict, tuple)): 

254 return "[{%d}]" % len(obj) 

255 return None 

256 

257 

def dump_into_folder(dump_folder, obs_op=None, is_error=True,
                     **kwargs):
    """
    Dumps information when an error was detected
    using :epkg:`*py:pickle`.

    :param dump_folder: folder where the file is written
    :param obs_op: dictionary of observations, it must contain keys
        ``'runtime'``, ``'name'``, ``'scenario'``, ``'problem'``,
        keys ``'optim'``, ``'opset'``, ``'n_features'`` are optional,
        keys starting with ``'lambda'`` are not dumped
    :param is_error: is it an error or not? (it changes the file name)
    :param kwargs: additional parameters, dumped with *obs_op*
    :return: name of the written file
    :raises ValueError: if *dump_folder* or *obs_op* is None
    """
    if dump_folder is None:
        raise ValueError("dump_folder cannot be None.")
    if obs_op is None:
        # Fails explicitly instead of an AttributeError on ``obs_op.get``.
        raise ValueError("obs_op cannot be None.")
    optim = str(obs_op.get('optim', ''))
    # Removes what should not appear in a file name. The order matters:
    # the longest prefixes must be removed first.
    for token in ("<class 'sklearn.", "<class '", " ", ">", "=",
                  "{", "}", ":", "'", "/", "\\"):
        optim = optim.replace(token, "")
    parts = (obs_op['runtime'], obs_op['name'], obs_op['scenario'],
             obs_op['problem'], optim,
             "op" + str(obs_op.get('opset', '-')),
             "nf" + str(obs_op.get('n_features', '-')))
    name = "dump-{}-{}.pkl".format(
        "ERROR" if is_error else "i",
        "-".join(map(str, parts)))
    name = os.path.join(dump_folder, name)
    # Entries starting with 'lambda' are left out of the dump,
    # the dictionary received from the caller is not modified.
    obs_op = {k: v for k, v in obs_op.items()
              if not k.startswith('lambda')}
    kwargs.update({'obs_op': obs_op})
    with open(name, "wb") as f:
        pickle.dump(kwargs, f)
    return name

301 

302 

def default_time_kwargs():
    """
    Returns default values *number* and *repeat* to measure
    the execution of a function.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.validate.validate_helper import default_time_kwargs
        import pprint
        pprint.pprint(default_time_kwargs())

    keys define the number of rows,
    values defines *number* and *repeat*.
    """
    # (number of rows, number, repeat)
    settings = (
        (1, 15, 20),
        (10, 10, 20),
        (100, 4, 10),
        (1000, 4, 4),
        (10000, 2, 2),
    )
    return {rows: dict(number=number, repeat=repeat)
            for rows, number, repeat in settings}

326 

327 

def measure_time(stmt, x, repeat=10, number=50, div_by_number=False,
                 first_run=True, max_time=None):
    """
    Measures a statement and returns the results as a dictionary.

    :param stmt: function to measure, it is called as ``stmt(x)``
    :param x: matrix, the argument given to *stmt*
    :param repeat: average over *repeat* experiment
    :param number: number of executions in one row
    :param div_by_number: divide by the number of executions
    :param first_run: if True, runs the function once before measuring
    :param max_time: execute the statement until the total goes
        beyond this time (approximatively), *repeat* is ignored,
        *div_by_number* must be set to True
    :return: dictionary
    :raises ValueError: if *x* is None
    :raises RuntimeError: if the first run fails with a RuntimeError,
        the message gives the type of *x* and its dtype if it has one

    See `Timer.repeat <https://docs.python.org/3/library/timeit.html?timeit.Timer.repeat>`_
    for a better understanding of parameter *repeat* and *number*.
    The function returns a duration corresponding to
    *number* times the execution of the main statement.
    """
    if x is None:
        raise ValueError("x cannot be None")  # pragma: no cover

    def fct():
        stmt(x)

    if first_run:
        try:
            fct()
        except RuntimeError as e:  # pragma: no cover
            # x may not have a dtype (a list for example), the original
            # error must not be hidden by an AttributeError.
            raise RuntimeError("{}-{}".format(
                type(x), getattr(x, 'dtype', None))) from e

    return _c_measure_time(fct, context={}, repeat=repeat, number=number,
                           div_by_number=div_by_number, max_time=max_time)

363 

364 

365def _multiply_time_kwargs(time_kwargs, time_kwargs_fact, inst): 

366 """ 

367 Multiplies values in *time_kwargs* following strategy 

368 *time_kwargs_fact* for a given model *inst*. 

369 

370 :param time_kwargs: see below 

371 :param time_kwargs_fact: see below 

372 :param inst: :epkg:`scikit-learn` model 

373 :return: new *time_kwargs* 

374 

375 Possible values for *time_kwargs_fact*: 

376 

377 - a integer: multiplies *number* by this number 

378 - `'lin'`: multiplies value *number* for linear models depending 

379 on the number of rows to process (:math:`\\propto 1/\\log_{10}(n)`) 

380 

381 .. runpython:: 

382 :showcode: 

383 :warningout: DeprecationWarning 

384 

385 from pprint import pprint 

386 from sklearn.linear_model import LinearRegression 

387 from mlprodict.onnxrt.validate.validate_helper import ( 

388 default_time_kwargs, _multiply_time_kwargs) 

389 

390 lr = LinearRegression() 

391 kw = default_time_kwargs() 

392 pprint(kw) 

393 

394 kw2 = _multiply_time_kwargs(kw, 'lin', lr) 

395 pprint(kw2) 

396 """ 

397 if time_kwargs is None: 

398 raise ValueError("time_kwargs cannot be None.") # pragma: no cover 

399 if time_kwargs_fact in ('', None): 

400 return time_kwargs 

401 try: 

402 vi = int(time_kwargs_fact) 

403 time_kwargs_fact = vi 

404 except (TypeError, ValueError): 

405 pass 

406 if isinstance(time_kwargs_fact, int): 

407 time_kwargs_modified = copy.deepcopy(time_kwargs) 

408 for k in time_kwargs_modified: 

409 time_kwargs_modified[k]['number'] *= time_kwargs_fact 

410 return time_kwargs_modified 

411 if time_kwargs_fact == 'lin': 

412 if isinstance(inst, LinearModel): 

413 time_kwargs_modified = copy.deepcopy(time_kwargs) 

414 for k in time_kwargs_modified: 

415 kl = max(int(math.log(k) / math.log(10) + 1e-5), 1) 

416 f = max(int(10 / kl + 0.5), 1) 

417 time_kwargs_modified[k]['number'] *= f 

418 time_kwargs_modified[k]['repeat'] *= 1 

419 return time_kwargs_modified 

420 return time_kwargs 

421 raise ValueError( # pragma: no cover 

422 "Unable to interpret time_kwargs_fact='{}'.".format( 

423 time_kwargs_fact)) 

424 

425 

def _get_problem_data(prob, n_features):
    """
    Calls the function registered in ``_problems`` for problem *prob*,
    splits the data it returns into train and test sets and unpacks
    the other pieces of information.

    :param prob: problem name, a key of ``_problems``
    :param n_features: number of features given to the problem function
    :return: tuple *(X_train, X_test, y_train, y_test, Xort_test,
        init_types, conv_options, method_name, output_index,
        dofit, predict_kwargs)*
    """
    data_problem = _problems[prob](n_features=n_features)
    n_items = len(data_problem)
    if n_items == 6:
        # the flag dofit is optional and is True by default
        fields, dofit = tuple(data_problem), True
    elif n_items == 7:
        *fields, dofit = data_problem
    else:
        raise RuntimeError(  # pragma: no cover
            "Unable to interpret problem '{}'.".format(prob))
    X_, y_, init_types, method, output_index, Xort_ = fields

    if (len(X_.shape) == 2 and X_.shape[1] != n_features and
            n_features is not None):
        raise RuntimeError(  # pragma: no cover
            "Problem '{}' with n_features={} returned {} features"
            "(func={}).".format(prob, n_features, X_.shape[1],
                                _problems[prob]))

    if y_ is None:
        X_train, X_test, _, Xort_test = train_test_split(
            X_, Xort_, random_state=42)
        y_train = y_test = None
    else:
        X_train, X_test, y_train, y_test, _, Xort_test = train_test_split(
            X_, y_, Xort_, random_state=42)

    # init_types may come with conversion options
    init_types, conv_options = (
        init_types if isinstance(init_types, tuple)
        else (init_types, None))
    # method may come with parameters for the prediction function
    method_name, predict_kwargs = (
        method if isinstance(method, tuple) else (method, {}))

    return (X_train, X_test, y_train,
            y_test, Xort_test,
            init_types, conv_options, method_name,
            output_index, dofit, predict_kwargs)