Coverage for mlprodict/cli/validate.py: 95%


251 statements  

1""" 

2@file 

3@brief Command line about validation of prediction runtime. 

4""" 

5import os 

6from io import StringIO 

7from logging import getLogger 

8import warnings 

9import json 

10from multiprocessing import Pool 

11from pandas import DataFrame, read_csv, concat 

12from sklearn.exceptions import ConvergenceWarning 

13 

14 

def benchmark_doc(runtime, black_list=None, white_list=None,
                  out_raw='bench_raw.xlsx', out_summary="bench_summary.xlsx",
                  dump_dir='dump', fLOG=print, verbose=0):
    """
    Runs the benchmarks published in the documentation
    (see :ref:`l-onnx-bench-onnxruntime1` and
    :ref:`l-onnx-bench-python_compiled`).

    :param runtime: runtime (python, python_compiled,
        onnxruntime1, onnxruntime2)
    :param black_list: models to skip, None for none
        (comma separated list)
    :param white_list: models to benchmark, None for all
        (comma separated list)
    :param out_raw: all results are saved in that file
    :param out_summary: all results are summarized in that file
    :param dump_dir: folder where to dump intermediate results
    :param fLOG: logging function
    :param verbose: verbosity
    :return: list of created files
    """
    def _save(df, name):
        ext = os.path.splitext(name)[-1]
        if ext == '.xlsx':
            df.to_excel(name, index=False)
        elif ext == '.csv':
            df.to_csv(name, index=False)
        else:
            raise ValueError(  # pragma: no cover
                "Unexpected extension in %r." % name)
        if verbose > 1:
            fLOG(  # pragma: no cover
                "[mlprodict] wrote '{}'".format(name))

    from pyquickhelper.loghelper import run_cmd
    from pyquickhelper.loghelper.run_cmd import get_interpreter_path
    from tqdm import tqdm
    from ..onnxrt.validate.validate_helper import sklearn_operators
    from ..onnx_conv import register_converters, register_rewritten_operators
    register_converters()
    try:
        register_rewritten_operators()
    except KeyError:  # pragma: no cover
        warnings.warn("converter for HistGradientBoosting* does not exist. "
                      "Upgrade sklearn-onnx.")

    if black_list is None:
        black_list = []
    else:
        black_list = black_list.split(',')
    if white_list is None:
        white_list = []
    else:
        white_list = white_list.split(',')

    filenames = []
    skls = sklearn_operators(extended=True)
    skls = [_['name'] for _ in skls]
    if white_list:
        skls = [_ for _ in skls if _ in white_list]
    skls.sort()
    if verbose > 0:
        pbar = tqdm(skls)
    else:
        pbar = skls
    for op in pbar:
        if black_list is not None and op in black_list:
            continue
        if verbose > 0:
            pbar.set_description(  # pragma: no cover
                "[%s]" % (op + " " * (25 - len(op))))

        loop_out_raw = os.path.join(
            dump_dir, "bench_raw_%s_%s.csv" % (runtime, op))
        loop_out_sum = os.path.join(
            dump_dir, "bench_sum_%s_%s.csv" % (runtime, op))
        cmd = ('{0} -m mlprodict validate_runtime --verbose=0 --out_raw={1} --out_summary={2} '
               '--benchmark=1 --dump_folder={3} --runtime={4} --models={5}'.format(
                   get_interpreter_path(), loop_out_raw, loop_out_sum, dump_dir, runtime, op))
        if verbose > 1:
            fLOG("[mlprodict] cmd '{}'.".format(cmd))  # pragma: no cover
        out, err = run_cmd(cmd, wait=True, fLOG=None)
        if not os.path.exists(loop_out_sum):  # pragma: no cover
            if verbose > 2:
                fLOG("[mlprodict] unable to find '{}'.".format(loop_out_sum))
            if verbose > 1:
                fLOG("[mlprodict] cmd '{}'".format(cmd))
                fLOG("[mlprodict] unable to find '{}'".format(loop_out_sum))
            msg = "Unable to find '{}'\n--CMD--\n{}\n--OUT--\n{}\n--ERR--\n{}".format(
                loop_out_sum, cmd, out, err)
            if verbose > 1:
                fLOG(msg)
            rows = [{'name': op, 'scenario': 'CRASH',
                     'ERROR-msg': msg.replace("\n", " -- ")}]
            df = DataFrame(rows)
            df.to_csv(loop_out_sum, index=False)
        filenames.append((loop_out_raw, loop_out_sum))

    # concatenates the summaries
    dfs_raw = [read_csv(name[0])
               for name in filenames if os.path.exists(name[0])]
    dfs_sum = [read_csv(name[1])
               for name in filenames if os.path.exists(name[1])]
    df_raw = concat(dfs_raw, sort=False)
    piv = concat(dfs_sum, sort=False)

    opset_cols = [(int(oc.replace("opset", "")), oc)
                  for oc in piv.columns if 'opset' in oc]
    opset_cols.sort(reverse=True)
    opset_cols = [oc[1] for oc in opset_cols]
    new_cols = opset_cols[:1]
    bench_cols = ["RT/SKL-N=1", "N=10", "N=100",
                  "N=1000", "N=10000"]
    new_cols.extend(["ERROR-msg", "name", "problem", "scenario", 'optim'])
    new_cols.extend(bench_cols)
    new_cols.extend(opset_cols[1:])
    for c in bench_cols:
        new_cols.append(c + '-min')
        new_cols.append(c + '-max')
    for c in piv.columns:
        if c.startswith("skl_") or c.startswith("onx_"):
            new_cols.append(c)
    new_cols = [_ for _ in new_cols if _ in piv.columns]
    piv = piv[new_cols]

    _save(piv, out_summary)
    _save(df_raw, out_raw)
    return filenames

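# A minimal usage sketch (illustrative, not part of the module): calling
# benchmark_doc from Python rather than through the command line. It assumes
# mlprodict and its optional dependencies (pyquickhelper, tqdm) are installed;
# the parameter values are hypothetical.
#
#     from mlprodict.cli.validate import benchmark_doc
#     files = benchmark_doc("python_compiled",
#                           white_list="LinearRegression,LogisticRegression",
#                           dump_dir="dump", verbose=1)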

def validate_runtime(verbose=1, opset_min=-1, opset_max="",
                     check_runtime=True, runtime='python', debug=False,
                     models=None, out_raw="model_onnx_raw.xlsx",
                     out_summary="model_onnx_summary.xlsx",
                     dump_folder=None, dump_all=False, benchmark=False,
                     catch_warnings=True, assume_finite=True,
                     versions=False, skip_models=None,
                     extended_list=True, separate_process=False,
                     time_kwargs=None, n_features=None, fLOG=print,
                     out_graph=None, force_return=False,
                     dtype=None, skip_long_test=False,
                     number=1, repeat=1, time_kwargs_fact='lin',
                     time_limit=4, n_jobs=0):
    """
    Walks through most of the :epkg:`scikit-learn` operators
    (models, predictors, transformers), tries to convert
    them into :epkg:`ONNX` and computes the predictions
    with a specific runtime.

    :param verbose: integer from 0 (None) to 2 (full verbose)
    :param opset_min: tries every conversion from this minimum opset,
        -1 to get the current opset
    :param opset_max: tries every conversion up to this maximum opset,
        -1 to get the current opset
    :param check_runtime: checks the runtime
        and not only the conversion
    :param runtime: runtime to check, python,
        onnxruntime1 to check :epkg:`onnxruntime`,
        onnxruntime2 to check every *ONNX* node independently
        with onnxruntime, several runtimes can be checked at the same time
        if the value is a comma separated list
    :param models: comma separated list of models to test or empty
        string to test them all
    :param skip_models: models to skip
    :param debug: stops whenever an exception is raised,
        only if *separate_process* is False
    :param out_raw: output raw results into this file (excel format)
    :param out_summary: output an aggregated view into this file (excel format)
    :param dump_folder: folder where to dump information (pickle)
        in case of mismatch
    :param dump_all: dumps all models, not only the failing ones
    :param benchmark: runs the benchmark
    :param catch_warnings: catches warnings
    :param assume_finite: see `config_context
        <https://scikit-learn.org/stable/modules/generated/sklearn.config_context.html>`_,
        if True, validation for finiteness is skipped, saving time but possibly
        leading to crashes; if False, validation for finiteness is performed,
        avoiding errors
    :param versions: adds columns with the versions of the used packages,
        :epkg:`numpy`, :epkg:`scikit-learn`, :epkg:`onnx`, :epkg:`onnxruntime`,
        :epkg:`sklearn-onnx`
    :param extended_list: extends the list of :epkg:`scikit-learn` converters
        with converters implemented in this module
    :param separate_process: runs every model in a separate process,
        this option must be used to run all models in a row
        even if one of them crashes
    :param time_kwargs: a dictionary which defines the number of rows and
        the parameters *number* and *repeat* when benchmarking a model,
        the value must follow the :epkg:`json` format
    :param n_features: changes the default number of features for
        a specific problem, it can also be a comma separated list
    :param force_return: forces the function to return the results,
        used when the results are produced through a separate process
    :param out_graph: image name, to output a graph which summarizes
        a benchmark in case it was run
    :param dtype: '32' or '64' or None for both,
        limits the test to one specific number type
    :param skip_long_test: skips tests for high values of N if
        they seem too long
    :param number: multiplies the *number* values in *time_kwargs*
    :param repeat: multiplies the *repeat* values in *time_kwargs*
    :param time_kwargs_fact: multiplies *number* and *repeat* in
        *time_kwargs* depending on the model
        (see :func:`_multiply_time_kwargs <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>`)
    :param time_limit: stops benchmarking after this limit of time
    :param n_jobs: forces the number of jobs to have this value,
        by default, it is equal to the number of CPUs
    :param fLOG: logging function

    .. cmdref::
        :title: Validates a runtime against scikit-learn
        :cmd: -m mlprodict validate_runtime --help
        :lid: l-cmd-validate_runtime

        The command walks through all scikit-learn operators,
        tries to convert them, checks the predictions,
        and produces a report.

        Example::

            python -m mlprodict validate_runtime --models LogisticRegression,LinearRegression

        The following example benchmarks models
        :epkg:`sklearn:ensemble:RandomForestRegressor` and
        :epkg:`sklearn:tree:DecisionTreeRegressor`; it compares
        :epkg:`onnxruntime` against :epkg:`scikit-learn` for opset 10.

        ::

            python -m mlprodict validate_runtime -v 1 -o 10 -op 10 -c 1 -r onnxruntime1
            -m RandomForestRegressor,DecisionTreeRegressor -out bench_onnxruntime.xlsx -b 1

        Parameter ``--time_kwargs`` may be used to reduce or increase
        the benchmark precision. The following value tells the function
        to benchmark datasets of 1 or 10 rows, running *number*
        predictions in a row and repeating that measure *repeat* times.
        The total time is divided by :math:`number \\times repeat`.
        Parameter ``--time_kwargs_fact`` may be used to increase these
        numbers for some specific models. ``'lin'`` multiplies
        *number* by 10 when the model is linear.

        ::

            -t "{\\"1\\":{\\"number\\":10,\\"repeat\\":10},\\"10\\":{\\"number\\":5,\\"repeat\\":5}}"

        The following example dumps every model in the list:

        ::

            python -m mlprodict validate_runtime --out_raw raw.csv --out_summary sum.csv
            --models LinearRegression,LogisticRegression,DecisionTreeRegressor,DecisionTreeClassifier
            -r python,onnxruntime1 -o 10 -op 10 -v 1 -b 1 -dum 1
            -du model_dump -n 20,100,500 --out_graph benchmark.png --dtype 32

        The command line generates a graph produced by function
        :func:`plot_validate_benchmark
        <mlprodict.onnxrt.validate.validate_graph.plot_validate_benchmark>`.
    """
    if separate_process:
        return _validate_runtime_separate_process(
            verbose=verbose, opset_min=opset_min, opset_max=opset_max,
            check_runtime=check_runtime, runtime=runtime, debug=debug,
            models=models, out_raw=out_raw,
            out_summary=out_summary, dump_all=dump_all,
            dump_folder=dump_folder, benchmark=benchmark,
            catch_warnings=catch_warnings, assume_finite=assume_finite,
            versions=versions, skip_models=skip_models,
            extended_list=extended_list, time_kwargs=time_kwargs,
            n_features=n_features, fLOG=fLOG, force_return=True,
            out_graph=None, dtype=dtype, skip_long_test=skip_long_test,
            time_kwargs_fact=time_kwargs_fact, time_limit=time_limit,
            n_jobs=n_jobs)

    from ..onnxrt.validate import enumerate_validated_operator_opsets  # pylint: disable=E0402

    if not isinstance(models, list):
        models = (None if models in (None, "")
                  else models.strip().split(','))
    if not isinstance(skip_models, list):
        skip_models = ({} if skip_models in (None, "")
                       else skip_models.strip().split(','))
    if verbose <= 1:
        logger = getLogger('skl2onnx')
        logger.disabled = True
    if not dump_folder:
        dump_folder = None
    if dump_folder and not os.path.exists(dump_folder):
        os.mkdir(dump_folder)  # pragma: no cover
    if dump_folder and not os.path.exists(dump_folder):
        raise FileNotFoundError(  # pragma: no cover
            "Cannot find dump_folder '{0}'.".format(
                dump_folder))

    # handling parameters
    if opset_max == "":
        opset_max = None  # pragma: no cover
    if isinstance(opset_min, str):
        opset_min = int(opset_min)  # pragma: no cover
    if isinstance(opset_max, str):
        opset_max = int(opset_max)
    if isinstance(verbose, str):
        verbose = int(verbose)  # pragma: no cover
    if isinstance(extended_list, str):
        extended_list = extended_list in (
            '1', 'True', 'true')  # pragma: no cover
    if time_kwargs in (None, ''):
        time_kwargs = None
    if isinstance(time_kwargs, str):
        time_kwargs = json.loads(time_kwargs)
        # json only allows strings as keys
        time_kwargs = {int(k): v for k, v in time_kwargs.items()}
    if isinstance(n_jobs, str):
        n_jobs = int(n_jobs)
    if n_jobs == 0:
        n_jobs = None
    if time_kwargs is not None and not isinstance(time_kwargs, dict):
        raise ValueError(  # pragma: no cover
            "time_kwargs must be a dictionary not {}\n{}".format(
                type(time_kwargs), time_kwargs))
    if not isinstance(n_features, list):
        if n_features in (None, ""):
            n_features = None
        elif ',' in n_features:
            n_features = list(map(int, n_features.split(',')))
        else:
            n_features = int(n_features)
    if not isinstance(runtime, list) and ',' in runtime:
        runtime = runtime.split(',')

    def fct_filter_exp(m, s):
        cl = m.__name__
        if cl in skip_models:
            return False
        pair = "%s[%s]" % (cl, s)
        if pair in skip_models:
            return False
        return True

    if dtype in ('', None):
        fct_filter = fct_filter_exp
    elif dtype == '32':
        def fct_filter_exp2(m, p):
            return fct_filter_exp(m, p) and '64' not in p
        fct_filter = fct_filter_exp2
    elif dtype == '64':  # pragma: no cover
        def fct_filter_exp3(m, p):
            return fct_filter_exp(m, p) and '64' in p
        fct_filter = fct_filter_exp3
    else:
        raise ValueError(  # pragma: no cover
            "dtype must be empty, 32, 64 not '{}'.".format(dtype))

    # time_kwargs

    if benchmark:
        if time_kwargs is None:
            from ..onnxrt.validate.validate_helper import default_time_kwargs  # pylint: disable=E0402
            time_kwargs = default_time_kwargs()
        for _, v in time_kwargs.items():
            v['number'] *= number
            v['repeat'] *= repeat
        if verbose > 0:
            fLOG("time_kwargs=%r" % time_kwargs)

    # body

    def build_rows(models_):
        rows = list(enumerate_validated_operator_opsets(
            verbose, models=models_, fLOG=fLOG, runtime=runtime, debug=debug,
            dump_folder=dump_folder, opset_min=opset_min, opset_max=opset_max,
            benchmark=benchmark, assume_finite=assume_finite, versions=versions,
            extended_list=extended_list, time_kwargs=time_kwargs, dump_all=dump_all,
            n_features=n_features, filter_exp=fct_filter,
            skip_long_test=skip_long_test, time_limit=time_limit,
            time_kwargs_fact=time_kwargs_fact, n_jobs=n_jobs))
        return rows

    def catch_build_rows(models_):
        if catch_warnings:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore",
                                      (UserWarning, ConvergenceWarning,
                                       RuntimeWarning, FutureWarning))
                rows = build_rows(models_)
        else:
            rows = build_rows(models_)  # pragma: no cover
        return rows

    rows = catch_build_rows(models)
    res = _finalize(rows, out_raw, out_summary,
                    verbose, models, out_graph, fLOG)
    return res if (force_return or verbose >= 2) else None

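# A hedged sketch of the Python equivalent of the command line examples above
# (assuming mlprodict is installed; model names and file names are
# illustrative). The time_kwargs string follows the json format parsed by the
# function, with batch sizes as keys.
#
#     from mlprodict.cli.validate import validate_runtime
#     validate_runtime(models="LinearRegression,DecisionTreeRegressor",
#                      runtime="python", benchmark=True,
#                      time_kwargs='{"1": {"number": 10, "repeat": 10}, '
#                                  '"10": {"number": 5, "repeat": 5}}',
#                      out_raw="raw.xlsx", out_summary="sum.xlsx",
#                      verbose=1)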

def _finalize(rows, out_raw, out_summary, verbose, models, out_graph, fLOG):
    from ..onnxrt.validate import summary_report  # pylint: disable=E0402
    from ..tools.cleaning import clean_error_msg  # pylint: disable=E0402

    # Drops data which cannot be serialized.
    for row in rows:
        keys = []
        for k in row:
            if 'lambda' in k:
                keys.append(k)
        for k in keys:
            del row[k]

    df = DataFrame(rows)

    if out_raw:
        if verbose > 0:
            fLOG("Saving raw_data into '{}'.".format(out_raw))
        if os.path.splitext(out_raw)[-1] == ".xlsx":
            df.to_excel(out_raw, index=False)
        else:
            clean_error_msg(df).to_csv(out_raw, index=False)

    if df.shape[0] == 0:
        raise RuntimeError("No result produced by the benchmark.")
    piv = summary_report(df)
    if 'optim' not in piv:
        raise RuntimeError(  # pragma: no cover
            "Unable to produce a summary. Missing column in \n{}".format(
                piv.columns))

    if out_summary:
        if verbose > 0:
            fLOG("Saving summary into '{}'.".format(out_summary))
        if os.path.splitext(out_summary)[-1] == ".xlsx":
            piv.to_excel(out_summary, index=False)
        else:
            clean_error_msg(piv).to_csv(out_summary, index=False)

    if verbose > 1 and models is not None:
        fLOG(piv.T)
    if out_graph is not None:
        if verbose > 0:
            fLOG("Saving graph into '{}'.".format(out_graph))
        from ..plotting.plotting import plot_validate_benchmark
        fig = plot_validate_benchmark(piv)[0]
        fig.savefig(out_graph)

    return rows


def _validate_runtime_dict(kwargs):
    return validate_runtime(**kwargs)


def _validate_runtime_separate_process(**kwargs):
    models = kwargs['models']
    if models in (None, ""):
        from ..onnxrt.validate.validate_helper import sklearn_operators  # pragma: no cover
        models = [_['name']
                  for _ in sklearn_operators(extended=True)]  # pragma: no cover
    elif not isinstance(models, list):
        models = models.strip().split(',')

    skip_models = kwargs['skip_models']
    skip_models = {} if skip_models in (
        None, "") else skip_models.strip().split(',')

    verbose = kwargs['verbose']
    fLOG = kwargs['fLOG']
    all_rows = []
    skls = [m for m in models if m not in skip_models]
    skls.sort()

    if verbose > 0:
        from tqdm import tqdm
        pbar = tqdm(skls)
    else:
        pbar = skls  # pragma: no cover

    for op in pbar:
        if not isinstance(pbar, list):
            pbar.set_description("[%s]" % (op + " " * (25 - len(op))))

        if kwargs['out_raw']:
            out_raw = os.path.splitext(kwargs['out_raw'])
            out_raw = "".join([out_raw[0], "_", op, out_raw[1]])
        else:
            out_raw = None  # pragma: no cover

        if kwargs['out_summary']:
            out_summary = os.path.splitext(kwargs['out_summary'])
            out_summary = "".join([out_summary[0], "_", op, out_summary[1]])
        else:
            out_summary = None  # pragma: no cover

        new_kwargs = kwargs.copy()
        if 'fLOG' in new_kwargs:
            del new_kwargs['fLOG']
        new_kwargs['out_raw'] = out_raw
        new_kwargs['out_summary'] = out_summary
        new_kwargs['models'] = op
        new_kwargs['verbose'] = 0  # tqdm fails
        new_kwargs['out_graph'] = None

        with Pool(1) as p:
            try:
                result = p.apply_async(_validate_runtime_dict, [new_kwargs])
                lrows = result.get(timeout=150)  # timeout fixed at 150s
                all_rows.extend(lrows)
            except Exception as e:  # pylint: disable=W0703
                all_rows.append({  # pragma: no cover
                    'name': op, 'scenario': 'CRASH',
                    'ERROR-msg': str(e).replace("\n", " -- ")
                })

    return _finalize(all_rows, kwargs['out_raw'], kwargs['out_summary'],
                     verbose, models, kwargs.get('out_graph', None), fLOG)

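# A sketch (assumptions: mlprodict importable, models available): running each
# model in its own process through the public entry point, so one crashing
# model cannot take down the whole run; crashes are reported as CRASH rows.
#
#     from mlprodict.cli.validate import validate_runtime
#     validate_runtime(models="LinearRegression,LogisticRegression",
#                      separate_process=True, out_raw="raw.csv",
#                      out_summary="sum.csv", verbose=1)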

def latency(model, law='normal', size=1, number=10, repeat=10, max_time=0,
            runtime="onnxruntime", device='cpu', fmt=None,
            profiling=None, profile_output='profiling.csv'):
    """
    Measures the latency of a model (python API).

    :param model: ONNX graph
    :param law: random law used to generate fake inputs
    :param size: batch size, it replaces the first dimension
        of every input if it is left unknown
    :param number: number of calls to measure
    :param repeat: number of times to repeat the experiment
    :param max_time: if greater than 0, the experiment is run
        as many times as possible during that period of time
    :param runtime: runtime to use
    :param device: device, `cpu`, `cuda:0` or a list of providers
        `CPUExecutionProvider, CUDAExecutionProvider`
    :param fmt: None or `csv`, it then
        returns a string formatted like a csv file
    :param profiling: if enabled, profiles the execution of every
        node; the profile can be sorted by name or type,
        so the value for this parameter should be in `(None, 'name', 'type')`
    :param profile_output: output name for the profiling
        if profiling is specified

    .. cmdref::
        :title: Measures model latency
        :cmd: -m mlprodict latency --help
        :lid: l-cmd-latency

        The command generates random inputs and calls the model
        many times on these inputs. It returns the processing time
        for one iteration.

        Example::

            python -m mlprodict latency --model "model.onnx"
    """
    from ..onnxrt.validate.validate_latency import latency as _latency  # pylint: disable=E0402

    if not os.path.exists(model):
        raise FileNotFoundError(  # pragma: no cover
            "Unable to find model %r." % model)
    if profiling not in (None, '', 'name', 'type'):
        raise ValueError(  # pragma: no cover
            "Unexpected value for profiling: %r." % profiling)
    size = int(size)
    number = int(number)
    repeat = int(repeat)
    if max_time in (None, 0, ""):
        max_time = None
    else:
        max_time = float(max_time)
        if max_time <= 0:
            max_time = None

    if law != "normal":
        raise ValueError(  # pragma: no cover
            "Only law='normal' is supported, not %r." % law)

    if profiling in ('name', 'type') and profile_output in (None, ''):
        raise ValueError(  # pragma: no cover
            'profiling is enabled but profile_output is wrong (%r).'
            '' % profile_output)

    res = _latency(
        model, law=law, size=size, number=number, repeat=repeat,
        max_time=max_time, runtime=runtime, device=device,
        profiling=profiling)

    if profiling not in (None, ''):
        res, gr = res
        ext = os.path.splitext(profile_output)[-1]
        gr = gr.reset_index(drop=False)
        if ext == '.csv':
            gr.to_csv(profile_output, index=False)
        elif ext == '.xlsx':  # pragma: no cover
            gr.to_excel(profile_output, index=False)
        else:
            raise ValueError(  # pragma: no cover
                "Unexpected extension for profile_output=%r."
                "" % profile_output)

    if fmt == 'csv':
        st = StringIO()
        df = DataFrame([res])
        df.to_csv(st, index=False)
        return st.getvalue()
    if fmt in (None, ''):
        return res
    raise ValueError(  # pragma: no cover
        "Unexpected value for fmt: %r." % fmt)