Coverage for mlprodict/onnxrt/onnx_inference.py: 96%


739 statements  

1# pylint: disable=C0302,R0912 

2""" 

3@file 

4@brief Implements a class able to compute the predictions 

5from an :epkg:`ONNX` model.

6""" 

7from collections import OrderedDict 

8from io import BytesIO 

9from time import perf_counter 

10import warnings 

11import textwrap 

12import pprint 

13from keyword import iskeyword 

14import numpy 

15from scipy.sparse import coo_matrix 

16from onnx import load, load_model, checker, shape_inference 

17from onnx import onnx_pb as onnx_proto 

18from onnx.helper import make_model 

19from ..tools.code_helper import make_callable, print_code 

20from ..onnx_tools.onnx2py_helper import ( 

21 _var_as_dict, numpy_min, numpy_max, guess_numpy_type_from_string) 

22from ..onnx_tools.onnx_manipulations import ( 

23 select_model_inputs_outputs, enumerate_model_node_outputs, 

24 overwrite_opset, insert_results_into_onnx) 

25from ..onnx_tools.optim import onnx_remove_node_unused 

26from .onnx_inference_node import OnnxInferenceNode 

27from .onnx_inference_exports import OnnxInferenceExport 

28from .shape_object import ShapeObject 

29from .type_object import SequenceType 

30 

31 

32class OnnxInference: 

33 """ 

34 Loads an :epkg:`ONNX` file or object or stream. 

35 Computes the output of the :epkg:`ONNX` graph. 

36 Several runtimes are available. 

37 

38 * ``'python'``: the runtime implements every onnx operator 

39 needed to run a :epkg:`scikit-learn` model by using :epkg:`numpy` 

40 or C++ code. 

41 * ``'python_compiled'``: the same runtime as the previous

42 one, except every operator is called from a compiled function

43 (@see me _build_compile_run) instead of a method going through

44 the list of operators

45 * ``'onnxruntime1'``: uses :epkg:`onnxruntime` (or `onnxruntime1-cuda`, ...) 

46 * ``'onnxruntime2'``: this mode is mostly used for debugging:

47 python handles the call to every operator but :epkg:`onnxruntime`

48 is executed for each of them. This process may fail due to

49 wrong type inference, especially if the graph includes

50 custom nodes; in that case, it is better to compute the output

51 of intermediate nodes. It is much slower as, for every output, every

52 node is computed, but it is more robust.

53 

54 :param onnx_or_bytes_or_stream: :epkg:`onnx` object, 

55 bytes, or filename or stream 

56 :param runtime: runtime options 

57 :param skip_run: do not build the runtime 

58 :param inplace: use inplace computation as much as possible 

59 :param input_inplace: the computation is allowed 

60 to overwrite the input, see :meth:`_guess_inplace 

61 <mlprodict.onnxrt.onnx_inference.OnnxInference._guess_inplace>` 

62 :param ir_version: if not None, overwrite the default version 

63 :param target_opset: used to overwrite *target_opset* 

64 :param runtime_options: specific options for the runtime 

65 :param inside_loop: tells the runtime the graph is meant to 

66 be repeated multiple times (in that case, inputs and 

67 outputs may share the same name) 

68 :param static_inputs: Loop can use static variables, 

69 variables from the graph which runs the loop 

70 (sequence of strings)

71 :param new_outputs: if the loading fails, it might be worth

72 cutting the graph; if not None, the graph will

73 be cut to keep these new_outputs as the final outputs

74 :param new_opset: overwrites the main opset and replaces

75 it with this new one

76 :param existing_functions: a model may contain several local functions, 

77 this parameter is used when a local function calls another

78 local function previously defined. 

79 

80 Among the possible runtime_options, there are: 

81 * *enable_profiling*: enables profiling for :epkg:`onnxruntime` 

82 * *session_options*: an instance of *SessionOptions* from 

83 :epkg:`onnxruntime` 

84 * *ir_version*: change ir_version 

85 

86 .. versionchanged:: 0.7 

87 Parameters *new_outputs*, *new_opset* were added. 

88 

89 .. versionchanged:: 0.8 

90 Parameters *static_inputs*, *device* were added. 

91 

92 .. versionchanged:: 0.9 

93 Parameter *existing_functions* was added.

94 Parameter *device* was removed; see *runtime*.

95 Runtime `onnxruntime1-cuda` was added. 
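
    A minimal usage sketch (the file name and the input name 'X'
    are illustrative):

    ::

        import numpy
        from mlprodict.onnxrt import OnnxInference

        oinf = OnnxInference('model.onnx', runtime='python')
        x = numpy.random.rand(2, 4).astype(numpy.float32)
        print(oinf.run({'X': x}))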

96 """ 

97 

98 def __init__(self, onnx_or_bytes_or_stream, runtime=None, 

99 skip_run=False, inplace=True, 

100 input_inplace=False, ir_version=None, 

101 target_opset=None, runtime_options=None, 

102 session_options=None, inside_loop=False, 

103 static_inputs=None, new_outputs=None, new_opset=None, 

104 existing_functions=None): 

105 if isinstance(onnx_or_bytes_or_stream, bytes): 

106 self.obj = load_model(BytesIO(onnx_or_bytes_or_stream)) 

107 elif isinstance(onnx_or_bytes_or_stream, BytesIO): 

108 self.obj = load_model(onnx_or_bytes_or_stream) 

109 elif isinstance(onnx_or_bytes_or_stream, str): 

110 self.obj = load(onnx_or_bytes_or_stream) 

111 elif hasattr(onnx_or_bytes_or_stream, 'graph'): 

112 self.obj = onnx_or_bytes_or_stream 

113 elif isinstance(onnx_or_bytes_or_stream, onnx_proto.GraphProto): 

114 self.obj = make_model(onnx_or_bytes_or_stream, 

115 producer_name='mlprodict') 

116 elif isinstance(onnx_or_bytes_or_stream, onnx_proto.FunctionProto): 

117 self.obj = onnx_or_bytes_or_stream 

118 else: 

119 raise TypeError("Unable to handle type {}.".format( # pragma: no cover 

120 type(onnx_or_bytes_or_stream))) 

121 if ir_version is not None: 

122 self.obj.ir_version = ir_version 

123 if new_outputs is not None: 

124 self.obj = select_model_inputs_outputs( 

125 self.obj, outputs=new_outputs, infer_shapes=True) 

126 if new_opset is not None: 

127 self.obj = overwrite_opset(self.obj, new_opset) 

128 

129 self.runtime = runtime 

130 self.skip_run = skip_run 

131 self.input_inplace = input_inplace 

132 self.inplace = inplace 

133 self.force_target_opset = target_opset 

134 self.runtime_options = runtime_options 

135 self.inside_loop = inside_loop 

136 self.static_inputs = static_inputs 

137 self._init(existing_functions) 

138 

139 def __getstate__(self): 

140 """ 

141 To pickle the object. 
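
        Instances can be pickled directly because the ONNX model is
        serialized to bytes. A minimal sketch (assuming *oinf* is an
        existing *OnnxInference*):

        ::

            import pickle
            restored = pickle.loads(pickle.dumps(oinf))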

142 """ 

143 return {'onnx': self.obj.SerializeToString(), 

144 'runtime': self.runtime, 

145 'runtime_options': self.runtime_options, 

146 'skip_run': self.skip_run, 

147 'input_inplace': self.input_inplace, 

148 'inplace': self.inplace, 

149 'force_target_opset': self.force_target_opset, 

150 'static_inputs': self.static_inputs, 

151 'inside_loop': self.inside_loop} 

152 

153 def __setstate__(self, state): 

154 """ 

155 To unpickle the object. 

156 """ 

157 onx = state['onnx'] 

158 self.obj = load_model(BytesIO(onx)) 

159 self.runtime = state['runtime'] 

160 self.runtime_options = state['runtime_options'] 

161 self.skip_run = state['skip_run'] 

162 self.input_inplace = state['input_inplace'] 

163 self.inplace = state['inplace'] 

164 self.force_target_opset = state['force_target_opset'] 

165 self.static_inputs = state['static_inputs'] 

166 self.inside_loop = state['inside_loop'] 

167 self._init() 

168 

169 def _init(self, existing_functions=None): 

170 """ 

171 Prepares the instance to deliver predictions. 

172 """ 

173 self.graph_ = self.to_sequence(existing_functions) 

174 if len(self.graph_['sequence']) == 0: 

175 raise RuntimeError( # pragma: no cover 

176 "No runnable nodes was found in the ONNX graph.") 

177 self.outputs_ = self.graph_['outputs'] 

178 self.inputs_ = self.graph_['inputs'] 

179 is_function_proto = isinstance(self.obj, onnx_proto.FunctionProto) 

180 if is_function_proto: 

181 obj_graph = self.obj 

182 else: 

183 obj_graph = self.obj.graph 

184 

185 for ino in [obj_graph.input, obj_graph.output]: 

186 for xy in ino: 

187 if isinstance(xy, str): 

188 shape = None 

189 else: 

190 shape = xy.type.tensor_type.shape 

191 for d in shape.dim: 

192 if (d.dim_value == 0 and "0" in str(d) and 

193 'dim_param' not in str(d)): 

194 if len(shape.dim) <= 1: 

195 shape = None 

196 break 

197 # d.dim_value returns 0 whether it is 0 or empty.

198 # it may be a parameter as well 

199 raise RuntimeError( # pragma: no cover 

200 "Wrong ONNX file, one input or output has " 

201 "an empty shape: {}.".format(xy)) 

202 

203 self.target_opset_ = self.graph_['targets'] 

204 if self.force_target_opset is not None: 

205 if isinstance(self.force_target_opset, dict): 

206 self.target_opset_ = self.force_target_opset # pragma: no cover 

207 else: 

208 self.target_opset_ = {'': self.force_target_opset} 

209 self.ir_version_ = self.graph_['ir_version'] 

210 

211 if not self.skip_run: 

212 if self.runtime is not None and self.runtime.startswith('onnxruntime1'): 

213 # Loads the onnx with onnxruntime as a single file. 

214 del self.graph_ 

215 from .ops_whole.session import OnnxWholeSession 

216 self._whole = OnnxWholeSession( 

217 self.obj, self.runtime, self.runtime_options) 

218 self._run = self._run_whole_runtime 

219 else: 

220 self.sequence_ = self.graph_['sequence'] 

221 self.inits_ = self.graph_['inits'] 

222 self.statics_ = self.graph_['statics'] 

223 dtype = self._guess_input_dtype() 

224 variables = self.inits_.copy() 

225 for node in self.sequence_: 

226 domain = node.onnx_node.domain 

227 target_opset = self.target_opset_.get(domain, None) 

228 keyf = domain, node.onnx_node.op_type 

229 if keyf in self.graph_['functions']: 

230 node.setup_runtime(self.graph_['functions'][keyf]) 

231 elif self.runtime in ('onnxruntime2', 'empty'): 

232 node.setup_runtime( 

233 self.runtime, variables, self.__class__, 

234 target_opset=target_opset, dtype=dtype, 

235 domain=domain, ir_version=self.ir_version_, 

236 runtime_options=self.runtime_options, 

237 build_inference_node_function=lambda fct: 

238 OnnxInference( 

239 fct, runtime=self.runtime, 

240 skip_run=self.skip_run, 

241 inplace=self.inplace, 

242 runtime_options=self.runtime_options, 

243 inside_loop=self.inside_loop, 

244 static_inputs=self.static_inputs)) 

245 else: 

246 node.setup_runtime( 

247 self.runtime, variables, self.__class__, 

248 target_opset=target_opset, domain=domain, 

249 ir_version=self.ir_version_, 

250 runtime_options=self.runtime_options, 

251 build_inference_node_function=lambda fct: 

252 OnnxInference( 

253 fct, runtime=self.runtime, 

254 skip_run=self.skip_run, 

255 inplace=self.inplace, 

256 runtime_options=self.runtime_options, 

257 inside_loop=self.inside_loop, 

258 static_inputs=self.static_inputs)) 

259 if hasattr(node, 'ops_') and hasattr(node.ops_, 'typed_outputs_'): 

260 for k, v in node.ops_.typed_outputs_: 

261 variables[k] = v 

262 self._run = self._run_sequence_runtime 

263 

264 if not self.skip_run and self.runtime in ('python', None): 

265 if is_function_proto: 

266 self.shapes_ = None 

267 else: 

268 self.shapes_ = self._set_shape_inference_runtime() 

269 if self.inplace: 

270 self.inplaces_ = self._guess_inplace(self.input_inplace) 

271 self.exporters_ = OnnxInferenceExport(self) 

272 self.to_json = self.exporters_.to_json 

273 self.to_dot = self.exporters_.to_dot 

274 self.to_python = self.exporters_.to_python 

275 self.to_text = self.exporters_.to_text 

276 self.to_onnx_code = self.exporters_.to_onnx_code 

277 

278 if self.runtime in ('python_compiled', 'python_compiled_debug'): 

279 # switch the inference method to the compiled one 

280 _, fct, code = self._build_compile_run('debug' in self.runtime) 

281 setattr(self, '_run_compiled', fct) 

282 setattr(self, '_run_compiled_code', code) 

283 self._run = self._run_sequence_runtime_compiled 

284 

285 def _run_sequence_runtime_compiled( 

286 self, inputs, clean_right_away=False, intermediate=False, 

287 verbose=0, node_time=False, yield_ops=None, fLOG=None): 

288 """ 

289 Executes a compiled version of @see me _run_sequence_runtime, 

290 compiled with method @see me _build_compile_run. 

291 Every parameter with a default value is ignored. 

292 Switch to ``runtime='python'`` to enable those. 

293 """ 

294 try: 

295 return self._run_compiled( # pylint: disable=E1101 

296 inputs, yield_ops=yield_ops) 

297 except NameError as e: 

298 raise RuntimeError( # pragma: no cover 

299 "Unable to compute prediction due to %r. Code:\n%s" 

300 "" % (e, print_code( 

301 self._run_compiled_code))) from e # pylint: disable=E1101 

302 

303 def _guess_input_dtype(self): 

304 for _, v in self.graph_['inputs'].items(): 

305 if 'type' not in v: 

306 continue # pragma: no cover 

307 t = v['type'] 

308 if 'elem' not in t: 

309 continue 

310 if t['elem'] == 'double': 

311 return numpy.float64 

312 return numpy.float32 

313 

314 def __str__(self): 

315 """ 

316 usual 

317 """ 

318 rows = ['OnnxInference(...)'] 

319 if hasattr(self, '_run_compiled_code'): 

320 rows.append( 

321 textwrap.indent( 

322 self._run_compiled_code, ' ')) # pylint: disable=E1101 

323 else: 

324 rows.append(textwrap.indent(str(self.obj), ' ')) 

325 return "\n".join(rows) 

326 

327 def __repr__(self): 

328 """ 

329 usual 

330 """ 

331 return "OnnxInference(...)" # pragma: no cover 

332 

333 def check_model(self): 

334 """ 

335 Checks that the model follows :epkg:`ONNX` conventions.

336 """ 

337 checker.check_model(self.obj) 

338 

339 def shape_inference(self): 

340 """ 

341 Infers the shape of the outputs 

342 with :epkg:`onnx` package. 

343 

344 @return A new :epkg:`ONNX` graph which defines the outputs.

345 """ 

346 return shape_inference.infer_shapes(self.obj) 

347 

348 @property 

349 def input_names(self): 

350 """ 

351 Returns the names of all inputs. 

352 It does not include the optional inputs. 

353 

354 .. versionchanged:: 0.6 

355 The list does not include optional inputs anymore. 

356 """ 

357 if hasattr(self.obj, 'graph'): 

358 inits = set(_.name for _ in self.obj.graph.initializer) 

359 return [_.name for _ in self.obj.graph.input if _.name not in inits] 

360 return list(self.obj.input) 

361 

362 @property 

363 def input_names_shapes(self): 

364 """ 

365 Returns the names and shapes of all inputs. 

366 This method assumes all inputs are tensors. 

367 It does not include the optional inputs. 

368 

369 .. versionchanged:: 0.6 

370 The list does not include optional inputs anymore. 

371 """ 

372 names = set(self.input_names) 

373 return [(_.name, _var_as_dict(_)['type']['shape']) 

374 for _ in self.obj.graph.input if _.name in names] 

375 

376 @staticmethod 

377 def _get_type_property(info, prop): 

378 if prop in info: 

379 return info[prop] 

380 if 'kind' in info and info['kind'] == 'sequence': 

381 if prop == 'shape': 

382 return ('?', ) 

383 raise NotImplementedError( # pragma: no cover 

384 "Unable to retrieve property %r from %r." 

385 "" % (prop, info)) 

386 

387 @property 

388 def input_names_shapes_types(self): 

389 """ 

390 Returns the names, shapes, types of all inputs. 

391 This method assumes all inputs are tensors. 

392 It does not include the optional inputs. 

393 

394 .. versionchanged:: 0.6 

395 The list does not include optional inputs anymore. 

396 """ 

397 f = OnnxInference._get_type_property 

398 names = set(self.input_names) 

399 if isinstance(self.obj, onnx_proto.FunctionProto): 

400 return [(_.name, f(_var_as_dict(_)['type'], 'shape'), 

401 'tensor(%s)' % f(_var_as_dict(_)['type'], 'elem')) 

402 for _ in self.obj.input if _.name in names] 

403 return [(_.name, f(_var_as_dict(_)['type'], 'shape'), 

404 'tensor(%s)' % f(_var_as_dict(_)['type'], 'elem')) 

405 for _ in self.obj.graph.input if _.name in names] 

406 

407 @property 

408 def output_names(self): 

409 """ 

410 Returns the names of all outputs. 

411 """ 

412 if isinstance(self.obj, onnx_proto.FunctionProto): 

413 return [_ for _ in self.obj.output] 

414 return [_.name for _ in self.obj.graph.output] 

415 

416 @property 

417 def output_names_shapes(self): 

418 """ 

419 Returns the names and shapes of all outputs. 

420 This method assumes all inputs are tensors. 

421 """ 

422 f = OnnxInference._get_type_property 

423 if isinstance(self.obj, onnx_proto.FunctionProto): 

424 return [(_, None) for _ in self.obj.output] 

425 return [(_.name, f(_var_as_dict(_)['type'], 'shape')) 

426 for _ in self.obj.graph.output] 

427 

428 @property 

429 def output_names_shapes_types(self): 

430 """ 

431 Returns the names, shapes, types of all outputs. 

432 This method assumes all inputs are tensors. 

433 It does not include the optional outputs. 

434 

435 .. versionadded:: 0.7

436 """ 

437 names = set(self.output_names) 

438 f = OnnxInference._get_type_property 

439 if isinstance(self.obj, onnx_proto.FunctionProto): 

440 return [(_, None) for _ in self.obj.graph.output if _ in names] 

441 return [(_.name, f(_var_as_dict(_)['type'], 'shape'), 

442 'tensor(%s)' % f(_var_as_dict(_)['type'], 'elem')) 

443 for _ in self.obj.graph.output if _.name in names] 

444 

445 def global_index(self, name): 

446 """ 

447 Maps every name to one integer to avoid using dictionaries 

448 when running the predictions. 

449 

450 @param name outputs name 

451 @return integer 
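
        A sketch of the behaviour (names are illustrative):

        ::

            oinf.global_index('X')  # a new name gets the next free index
            oinf.global_index('X')  # a known name always maps to the same index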

452 """ 

453 if not hasattr(self, '_global_index'): 

454 self._global_index = {} 

455 if name in self._global_index: 

456 return self._global_index[name] 

457 self._global_index[name] = len(self._global_index) 

458 return self._global_index[name] 

459 

460 def to_sequence(self, existing_functions=None): 

461 """ 

462 Produces a graph to facilitate the execution. 

463 

464 One example: 

465 

466 .. exref:: 

467 :title: Convert ONNX into graph 

468 

469 An example showing how to convert an :epkg:`ONNX`

470 graph into the internal graph used to run it.

471 

472 .. runpython:: 

473 :showcode: 

474 :warningout: DeprecationWarning 

475 

476 import pprint 

477 import numpy 

478 from mlprodict.npy.xop import loadop 

479 from mlprodict.onnxrt import OnnxInference 

480 

481 OnnxAiOnnxMlLinearRegressor = loadop( 

482 ('ai.onnx.ml', 'LinearRegressor')) 

483 

484 pars = dict(coefficients=numpy.array([1., 2.]), 

485 intercepts=numpy.array([1.]), 

486 post_transform='NONE') 

487 onx = OnnxAiOnnxMlLinearRegressor( 

488 'X', output_names=['Y'], **pars) 

489 model_def = onx.to_onnx( 

490 {'X': pars['coefficients'].astype(numpy.float32)}, 

491 outputs={'Y': numpy.float32}, 

492 target_opset=12) 

493 oinf = OnnxInference(model_def) 

494 pprint.pprint(oinf.to_sequence()) 

495 

496 See an example of representation in notebook 

497 :ref:`onnxvisualizationrst`. 

498 """ 

499 inits = {} 

500 variables = {} 

501 outputs = {} 

502 nodes = {} 

503 statics = {} 

504 targets = {} 

505 functions = {} 

506 if existing_functions is not None: 

507 functions.update(existing_functions) 

508 is_function_proto = isinstance(self.obj, onnx_proto.FunctionProto) 

509 

510 for o in self.obj.opset_import: 

511 targets[o.domain] = o.version 

512 

513 if (hasattr(self.obj, 'functions') and len(self.obj.functions) > 0 and 

514 (self.runtime is None or not 

515 self.runtime.startswith('onnxruntime1'))): 

516 for fct in self.obj.functions: 

517 functions[fct.domain, fct.name] = OnnxInference( 

518 fct, runtime=self.runtime, 

519 skip_run=self.skip_run, 

520 inplace=self.inplace, 

521 runtime_options=self.runtime_options, 

522 inside_loop=self.inside_loop, 

523 static_inputs=self.static_inputs, 

524 existing_functions=functions) 

525 

526 # static variables 

527 if self.static_inputs is not None: 

528 for n in self.static_inputs: 

529 statics[n] = {'name': n} 

530 self.global_index(n) 

531 

532 obj_graph = ( 

533 self.obj if isinstance(self.obj, onnx_proto.FunctionProto) 

534 else self.obj.graph) 

535 

536 # inputs 

537 for obj in obj_graph.input: 

538 if is_function_proto: 

539 variables[obj] = {'name': obj} 

540 self.global_index(obj) 

541 else: 

542 variables[obj.name] = _var_as_dict(obj) 

543 self.global_index(obj.name) 

544 

545 # outputs 

546 for obj in obj_graph.output: 

547 if is_function_proto: 

548 outputs[obj] = {'name': obj} 

549 self.global_index(obj) 

550 else: 

551 if hasattr(obj, 'type') and str(obj.type) != '': 

552 outputs[obj.name] = _var_as_dict(obj) 

553 else: 

554 outputs[obj.name] = {'name': obj.name} 

555 self.global_index(obj.name) 

556 

557 # initializer 

558 if not is_function_proto: 

559 for obj in obj_graph.initializer: 

560 init_obj = _var_as_dict(obj) 

561 if init_obj is None: 

562 raise RuntimeError( # pragma: no cover 

563 "Unable to convert an initializer\n{}".format(obj)) 

564 inits[obj.name] = init_obj 

565 self.global_index(obj.name) 

566 if 'value' not in inits[obj.name]: 

567 raise RuntimeError( # pragma: no cover 

568 "One initializer has no value: '{}'\n{}\n{}".format( 

569 obj.name, inits[obj.name], obj)) 

570 

571 # nodes 

572 for node in obj_graph.node: 

573 dobj = _var_as_dict(node) 

574 if dobj is None: 

575 raise RuntimeError( # pragma: no cover 

576 "Unable to convert a node\n{}".format(node)) 

577 if 'atts' in dobj: 

578 atts = dobj['atts'] 

579 for k, v in atts.items(): 

580 if not isinstance(v, dict) or 'value' not in v: 

581 raise RuntimeError( # pragma: no cover 

582 "A parameter has no (sparse) value '{}' " 

583 "for node '{}'\nv={}\ndobj=[{}]".format( 

584 k, node.name, v, node)) 

585 if node.name in nodes: # pragma: no cover 

586 i = 2 

587 while True: 

588 new_name = "%s_n%i" % (node.name, i) 

589 if new_name not in nodes: 

590 break 

591 i += 1 

592 else: 

593 new_name = node.name 

594 nodes[new_name] = OnnxInferenceNode(node, dobj, self.global_index) 

595 

596 # names 

597 names = {} 

598 for k, v in statics.items(): 

599 if (k, 0) in names: 

600 raise RuntimeError( # pragma: no cover 

601 "Static variables '{}' already exists (tag='{}').".format( 

602 k, names[k, 0][0])) 

603 names[k, 0] = ('S', v) 

604 for k, v in inits.items(): 

605 if (k, 0) in names: 

606 raise RuntimeError( # pragma: no cover 

607 "Initializer '{}' already exists (tag='{}').".format( 

608 k, names[k, 0][0])) 

609 names[k, 0] = ('C', v) 

610 for k, v in variables.items(): 

611 if (k, 0) in names: 

612 if k in inits: 

613 # Kind of default value for an input 

614 continue 

615 raise RuntimeError( # pragma: no cover 

616 "Variable '{}' already exists (tag='{}').".format( 

617 k, names[k, 0][0])) 

618 names[k, 0] = ('I', v) 

619 for k, v in outputs.items(): 

620 if (k, 0) in names and self.runtime != 'empty': 

621 if not self.inside_loop or names[k, 0][0] != 'I': 

622 raise RuntimeError( # pragma: no cover 

623 "Output '{}' already exists (tag='{}').".format( 

624 k, names[k, 0][0])) 

625 else: 

626 # For an input and an output sharing the same name, the name is marked

627 # as an input. 

628 continue 

629 names[k, 0] = ('O', v) 

630 for k, v in nodes.items(): 

631 if (k, 1) in names: 

632 raise RuntimeError( # pragma: no cover 

633 "Node '{}' already exists (tag='{}'). " 

634 "Use inside_loop=True to bypass this exception.".format( 

635 k, names[k, 0][0])) 

636 names[k, 1] = ('N', v) 

637 

638 # ordering 

639 order = {} 

640 modif = 1 

641 intermediate = {} 

642 while modif > 0: 

643 modif = 0 

644 for (k, _), v in names.items(): 

645 if (k, 1) in order: 

646 # The operator node is already processed. 

647 continue 

648 if v[0] in {'I', 'C', 'S'}: 

649 if (k, 0) not in order: 

650 order[k, 0] = len(order) # A data node. 

651 modif += 1 

652 continue 

653 if v[0] == 'O': 

654 continue 

655 if all((inp, 0) in order for inp in v[1].inputs if inp != ''): 

656 # If all inputs are available, 

657 # We tell the operator node is processed. 

658 order[k, 1] = len(order) 

659 modif += 1 

660 for o in v[1].outputs: 

661 if (o, 0) in order: 

662 raise RuntimeError( # pragma: no cover 

663 "Two nodes share the same output '{}' " 

664 "or an operator and an output " 

665 "share the same name. " 

666 "(node: {}).".format(o, v[1])) 

667 # We add a data node. 

668 order[o, 0] = len(order) 

669 intermediate[o] = None 

670 modif += 1 

671 

672 # compute 

673 rev = [(v, k[0], k[1]) for k, v in order.items()] 

674 rev.sort() 

675 sequence = [] 

676 for _, name, node_kind in rev: 

677 if name not in nodes: 

678 continue 

679 if node_kind == 0: 

680 # It is an output which shares the same name 

681 # as a node. 

682 continue 

683 node = nodes[name] 

684 node.set_order(len(sequence)) 

685 sequence.append(node) 

686 

687 if len(sequence) == 0: 

688 from mlprodict.plotting.text_plot import onnx_simple_text_plot 

689 raise RuntimeError( # pragma: no cover 

690 "No runnable nodes was found in the ONNX graph" 

691 "\n--rev--\n{}" 

692 "\n--order--\n{}" 

693 "\n--nodes--\n{}" 

694 "\n--ONNX--\n{}\n---\n".format( 

695 "\n".join([str(_) for _ in names.items()]), 

696 "\n".join([str(_) for _ in order.items()]), 

697 "\n".join([str(_) for _ in nodes.items()]), 

698 onnx_simple_text_plot(self.obj, recursive=True))) 

699 

700 # defines where an intermediate output is no longer needed

701 last_used = {} 

702 for node in sequence: 

703 for inp in node.inputs: 

704 last_used[inp] = node.order 

705 for k, ord in last_used.items(): 

706 sequence[ord].add_variable_to_clean(k) 

707 

708 results = dict(inits=inits, inputs=variables, outputs=outputs, 

709 nodes=nodes, sequence=sequence, 

710 functions=functions, 

711 intermediate=intermediate, 

712 targets=targets, 

713 ir_version=( 

714 None if is_function_proto 

715 else self.obj.ir_version), 

716 statics=statics) 

717 if len(sequence) < len(nodes): 

718 # Not all nodes will be executed.

719 raise RuntimeError( # pragma: no cover 

720 "Unable to run all nodes.\n--Nodes--\n%s\n--Sequence--\n%s" 

721 "\n--Inputs--\n%s\n--Inits--\n%s\n--Statics\n%s" 

722 "" % (pprint.pformat(nodes), pprint.pformat(sequence), 

723 pprint.pformat(list(variables)), 

724 pprint.pformat(list(inits)), 

725 pprint.pformat(list(statics)))) 

726 return results 

727 

728 def run(self, inputs, clean_right_away=False, 

729 intermediate=False, verbose=0, node_time=False, 

730 overwrite_types=None, yield_ops=None, fLOG=None): 

731 """ 

732 Computes the predictions for this :epkg:`onnx` graph. 

733 

734 :param inputs: inputs as dictionary or a dataframe 

735 :param clean_right_away: clean the intermediate outputs 

736 as soon as they are not needed 

737 :param intermediate: returns a dictionary of intermediate 

738 variables instead of the results only 

739 :param verbose: display information while predicting 

740 :param node_time: measure time of each node 

741 :param overwrite_types: shape inference does not work all the time, 

742 this allows to force types when building intermediate 

743 results, see @see fn select_model_inputs_outputs 

744 :param yield_ops: dictionary to overwrite the output of 

745 operator *YieldOp* 

746 :param fLOG: logging function if *verbose > 0* 

747 :return: outputs as dictionary 

748 and a second dictionary of the time spent 

749 in each node if *node_time* is True 

750 

751 .. exref:: 

752 :title: Computes predictions with any runtime 

753 

754 The following example compares predictions 

755 between :epkg:`scikit-learn` and this module

756 using the python runtime.

757 

758 .. runpython:: 

759 :showcode: 

760 :warningout: DeprecationWarning 

761 

762 import numpy 

763 from sklearn.linear_model import LinearRegression 

764 from sklearn.datasets import load_iris 

765 from sklearn.model_selection import train_test_split 

766 from mlprodict.onnxrt import OnnxInference 

767 from mlprodict.onnx_conv import to_onnx 

768 

769 iris = load_iris() 

770 X, y = iris.data, iris.target 

771 X_train, X_test, y_train, _ = train_test_split(X, y) 

772 clr = LinearRegression() 

773 clr.fit(X_train, y_train) 

774 

775 exp = clr.predict(X_test[:5]) 

776 print(exp) 

777 

778 model_def = to_onnx(clr, X_train.astype(numpy.float32), 

779 target_opset=12) 

780 oinf = OnnxInference(model_def) 

781 y = oinf.run({'X': X_test[:5]}) 

782 print(y) 

783 

784 The function returns all intermediate outputs

785 if *intermediate* is True. With runtime

786 *onnxruntime1*, if *intermediate* is True,

787 the class first builds one :epkg:`ONNX` graph

788 cut out for each output and converts each of them

789 into an *OnnxInference*.
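
        A sketch of retrieving every intermediate result
        (*intermediate=True* requires *inplace=False*):

        ::

            oinf = OnnxInference(model_def, inplace=False)
            inter = oinf.run({'X': X_test[:5].astype(numpy.float32)},
                             intermediate=True)
            # inter maps every result name to its computed value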

790 

791 .. versionchanged:: 0.8 

792 Parameter *yield_ops* was added. 

793 """ 

794 def retype(col_array): 

795 if (hasattr(col_array, 'categories') and 

796 hasattr(col_array, 'from_codes')): 

797 # isinstance(col_array, pandas.Categorical): 

798 return col_array.astype(numpy.int64) 

799 return col_array 

800 

801 if hasattr(inputs, 'columns') and hasattr(inputs, 'iloc'): 

802 # == isinstance(inputs, pandas.DataFrame) 

803 inputs = OrderedDict(( 

804 name, retype(numpy.expand_dims(inputs[name].values, axis=1))) 

805 for name in inputs.columns) 

806 if intermediate: 

807 if self.inplace: 

808 raise RuntimeError( # pragma: no cover 

809 "inplace must be False if intermediate is True, a container " 

810 "might be used by several nodes.") 

811 return self._run(inputs, clean_right_away=False, 

812 intermediate=intermediate, 

813 verbose=verbose, node_time=node_time, 

814 overwrite_types=overwrite_types, 

815 yield_ops=yield_ops, fLOG=fLOG) 

816 if overwrite_types is not None: 

817 raise RuntimeError( # pragma: no cover 

818 "overwrite_types is not used if intermediate is False.") 

819 return self._run(inputs, clean_right_away=False, 

820 intermediate=intermediate, 

821 verbose=verbose, node_time=node_time, 

822 yield_ops=yield_ops, fLOG=fLOG) 

823 

824 def run2onnx(self, inputs, verbose=0, fLOG=None, 

825 as_parameter=True, suffix='_DBG', 

826 param_name=None, node_type='DEBUG', 

827 domain='DEBUG', domain_opset=1): 

828 """ 

829 Executes the graph with the given inputs, then adds the intermediate

830 results into ONNX nodes in the original graph. Once saved, it can be 

831 looked with a tool such as :epkg:`netron`. 

832 

833 :param inputs: inputs as dictionary or a dataframe 

834 :param verbose: display information while predicting 

835 :param fLOG: logging function if *verbose > 0* 

836 :param as_parameter: add new nodes with results as one parameter 

837 (True) or as initializer (False) 

838 :param suffix: suffix to add to new results 

839 :param param_name: name of the parameter to add 

840 (by default the result name), it can be a function 

841 `param_name(result_name) -> parameter_name`

842 :param node_type: type of the new node 

843 :param domain: domain of the new node

844 :param domain_opset: opset for *domain* 

845 :return: outputs as dictionary 

846 and the onnx graph with new nodes 

847 

848 The following example shows how to use it. 

849 

850 .. gdot:: 

851 :script: DOT-SECTION 

852 

853 from sklearn.linear_model import LinearRegression 

854 from sklearn.datasets import load_iris 

855 from mlprodict.onnxrt import OnnxInference 

856 import numpy 

857 

858 iris = load_iris() 

859 X = iris.data[:, :2] 

860 y = iris.target 

861 lr = LinearRegression() 

862 lr.fit(X, y) 

863 

864 from mlprodict.onnx_conv import to_onnx 

865 model_onnx = to_onnx(lr, X.astype(numpy.float32)) 

866 oinf = OnnxInference(model_onnx, inplace=False) 

867 

868 model_onnx_debug = oinf.run2onnx({'X': X[:3].astype(numpy.float32)}) 

869 oinf_debug = OnnxInference(model_onnx_debug[1]) 

870 

871 print("DOT-SECTION", oinf_debug.to_dot()) 

872 

873 .. versionadded:: 0.7 

874 """ 

875 intermediate = self.run(inputs, verbose=verbose, fLOG=fLOG, 

876 intermediate=True) 

877 for name in self.input_names: 

878 del intermediate[name] 

879 new_onx = insert_results_into_onnx( 

880 self.obj, intermediate, as_parameter=as_parameter, 

881 suffix=suffix, param_name=param_name, node_type=node_type, 

882 domain=domain, domain_opset=domain_opset) 

883 return intermediate, new_onx 

884 

885 def display_sequence(self, verbose=1): 

886 """ 

887 Shows the sequence of nodes to run if ``runtime=='python'``. 
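
        A one-line usage sketch:

        ::

            print(oinf.display_sequence())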

888 """ 

889 rows = [] 

890 rows.append("#node: {}".format(len(self.sequence_))) 

891 for i, node in enumerate(self.sequence_): 

892 if verbose >= 1: 

893 rows.append("{}: {}".format(i, str(node))) 

894 return "\n".join(rows) 

895 

896 def _run_sequence_runtime(self, inputs, clean_right_away=False, 

897 intermediate=False, verbose=0, node_time=False, 

898 overwrite_types=None, yield_ops=None, 

899 fLOG=None): 

900 if overwrite_types is not None: 

901 raise NotImplementedError( # pragma: no cover 

902 "overwrite_types != None not implemented.") 

903 if clean_right_away: 

904 raise NotImplementedError( # pragma: no cover 

905 "clean_right_away=true not implemented.") 

906 

907 if node_time: 

908 mtime = [] 

909 if verbose >= 1 and fLOG is not None: 

910 printed = set() 

911 

912 if hasattr(self, "_values_init"): 

913 values = self._values_init.copy() # pylint: disable=E0203 

914 else: 

915 values = [None] * len(self._global_index) 

916 if verbose >= 1 and fLOG is not None: 

917 for k, v in self.inits_.items(): 

918 values[self._global_index[k]] = v['value'] 

919 if verbose < 3: 

920 fLOG("+ki='{}': {} (dtype={} min={} max={})".format( 

921 k, v['value'].shape, v['value'].dtype, 

922 numpy_min(v['value']), numpy_max(v['value']))) 

923 else: 

924 fLOG("+ki='{}': {} (dtype={} min={} max={}\n{}".format( 

925 k, v['value'].shape, v['value'].dtype, 

926 numpy_min(v['value']), numpy_max(v['value']), 

927 v['value'])) 

928 printed.add(k) 

929 else: 

930 for k, v in self.inits_.items(): 

931 values[self._global_index[k]] = v['value'] 

932 # stores the array to skip initializing a second time

933 if verbose == 0 or fLOG is None: 

934 self._values_init = values.copy() 

935 

936 for name, value in inputs.items(): 

937 values[self._global_index[name]] = value 

938 

939 if verbose == 0 or fLOG is None: 

940 if node_time: 

941 for i, node in enumerate(self.sequence_): 

942 if yield_ops is not None and node.onnx_node.op_type == 'YieldOp': 

943 out = node.onnx_node.output[0] 

944 if out in yield_ops: 

945 values[node.outputs_indices[0]] = yield_ops[out]

946 continue 

947 raise RuntimeError( # pragma: no cover 

948 "YieldOp output %r could not be found in " 

949 "yield_ops: %r (node=%r)." % ( 

950 out, list(sorted(yield_ops)), node.onnx_node)) 

951 t = perf_counter() 

952 node.run(values) 

953 t2 = perf_counter() 

954 mtime.append(dict(i=i, name=node.onnx_node.name, 

955 op_type=node.onnx_node.op_type, 

956 time=t2 - t)) 

957 else: 

958 for node in self.sequence_: 

959 node.run(values) 

960 else: 

961 def dispsimple(arr): 

962 if hasattr(arr, 'shape'): 

963 if len(arr.shape) <= 1: 

964 threshold = 8 

965 else: 

966 threshold = min( 

967 50, min(50 // max(arr.shape[1], 1), 8) * arr.shape[1]) 

968 if hasattr(arr, 'todense'): 

969 fLOG( # pragma: no cover 

970 numpy.array2string(arr.todense(), max_line_width=120, 

971 suppress_small=True, threshold=threshold)) 

972 else: 

973 fLOG(numpy.array2string(arr, max_line_width=120, 

974 suppress_small=True, 

975 threshold=threshold)) 

976 else: # pragma: no cover 

977 s = str(arr) 

978 if len(s) > 50: 

979 s = s[:50] + "..." 

980 fLOG(s) 

981 

982 if verbose >= 2: 

983 for k in sorted(self._global_index): 

984 if values[self._global_index[k]] is None: 

985 continue 

986 obj = values[self._global_index[k]] 

987 if k not in printed: 

988 printed.add(k) 

989 if hasattr(obj, 'shape'): 

990 fLOG("-kv='{}' shape={} dtype={} min={} max={}{}".format( 

991 k, obj.shape, obj.dtype, numpy_min(obj), 

992 numpy_max(obj), 

993 ' (sparse)' if isinstance(obj, coo_matrix) else '')) 

994 elif (isinstance(obj, list) and len(obj) > 0 and 

995 not isinstance(obj[0], dict)): # pragma: no cover 

996 fLOG("-kv='{}' list len={}".format(k, len(obj))) 

997 if verbose >= 3 and len(obj) > 0: 

998 fLOG("first={} last={}".format( 

999 obj[0], obj[-1])) 

1000 else: # pragma: no cover 

1001 fLOG("-kv='{}' type={}".format(k, type(obj))) 

1002 

1003 keys = set(k for k in range(len(values)) if values[k] is not None) 

1004 if verbose >= 1: 

1005 fLOG("-- OnnxInference: run {} nodes".format(len(self.sequence_))) 

1006 for i, node in enumerate(self.sequence_): 

1007 if verbose >= 1: 

1008 fLOG(node) 

1009 if yield_ops is not None and node.onnx_node.op_type == 'YieldOp': 

1010 out = node.onnx_node.output[0] 

1011 if out in yield_ops: 

1012 fLOG("+yo=%r" % out) 

1013 values[node.outputs_indices[0]] = yield_ops[out] 

1014 else: 

1015 raise RuntimeError( # pragma: no cover 

1016 "YieldOp output %r could not be found in " 

1017 "yield_ops: %r (node=%r)." % ( 

1018 out, list(sorted(yield_ops)), node.onnx_node)) 

1019 elif node_time: 

1020 t = perf_counter() 

1021 node.run(values) 

1022 t2 = perf_counter() 

1023 mtime.append(dict(i=i, name=node.onnx_node.name, 

1024 op_type=node.onnx_node.op_type, 

1025 time=t2 - t)) 

1026 else: 

1027 node.run(values) 

1028 added = 0 

1029 for k in range(len(values)): # pylint: disable=C0200 

1030 if values[k] is None: 

1031 continue 

1032 if k not in keys and k not in printed: 

1033 added += 1 

1034 printed.add(k) 

1035 name = list( 

1036 name for name in self._global_index # pylint: disable=C0206 

1037 if self._global_index[name] == k) 

1038 if isinstance(values[k], (numpy.ndarray, coo_matrix)): 

1039 name = name[0] 

1040 mini = numpy_min(values[k]) 

1041 maxi = numpy_max(values[k]) 

1042 fLOG("+kr{}'{}': {} (dtype={} min={} max={}{})".format( 

1043 "=" if len(values[k].shape) == 0 or min( 

1044 values[k].shape) > 0 else "*", 

1045 name, values[k].shape, values[k].dtype, 

1046 mini, maxi, 

1047 ' sparse' if isinstance(values[k], coo_matrix) else '')) 

1048 if verbose >= 3: 

1049 dispsimple(values[k]) 

1050 else: 

1051 fLOG("+kr='{}': {}".format( 

1052 name, type(values[k]))) 

1053 if verbose >= 3: # pragma: no cover 

1054 dispsimple(values[k]) 

1055 if added == 0: 

1056 fLOG("? no new result") # pragma: no cover 

1057 

1058 if intermediate: 

1059 values = [(v, k, values[v]) for k, v in self._global_index.items()] 

1060 values.sort() 

1061 values = OrderedDict((k, v) for _, k, v in values) 

1062 return (values, mtime) if node_time else values 

1063 

1064 try: 

1065 res = {k: values[self._global_index[k]] for k in self.outputs_} 

1066 except KeyError as e: # pragma: no cover 

1067 raise RuntimeError("Unable to find one output [{}]\n in [{}]" 

1068 ".".format(", ".join(sorted(self.outputs_)), 

1069 ", ".join(sorted(values)))) from e 

1070 return (res, mtime) if node_time else res 

1071 

1072 def build_intermediate(self, outputs=None, verbose=0, overwrite_types=None, 

1073 fLOG=None): 

1074 """ 

1075 Builds every possible :epkg:`ONNX` file 

1076 which computes one specific intermediate output 

1077 from the inputs. 

1078 

1079 :param outputs: subsets of outputs to get, 

1080 None to get all outputs, 

1081 :param overwrite_types: shape inference does not work all the time, 

1082 this allows to force types when building intermediate 

1083 results, see @see fn select_model_inputs_outputs 

1084 :param verbose: displays intermediate information 

1085 :param fLOG: logging function 

1086 :return: :epkg:`*py:collections:OrderedDict` 
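
        A sketch (the result name 'scores' is illustrative):

        ::

            parts = oinf.build_intermediate()
            sub = parts['scores']  # OnnxInference computing only 'scores'
            print(sub.run({'X': x}))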

1087 

1088 .. versionchanged:: 0.6

1089 """ 

1090 if verbose > 0: 

1091 fLOG('[build_intermediate] BEGIN.') # pragma: no cover 

1092 if outputs is not None: 

1093 if isinstance(outputs, str): 

1094 outputs = [outputs] 

1095 if not isinstance(outputs, set): 

1096 outputs = set(outputs) 

1097 ord = OrderedDict() 

1098 for output in enumerate_model_node_outputs(self.obj, order=True): 

1099 if outputs is not None and output not in outputs: 

1100 continue 

1101 subonx = select_model_inputs_outputs( 

1102 self.obj, outputs=output, infer_shapes=True, 

1103 overwrite=overwrite_types) 

1104 subonx = onnx_remove_node_unused(subonx) 

1105 if verbose > 0: 

1106 fLOG( # pragma: no cover 

1107 '[build_intermediate] + {}'.format(output)) 

1108 ord[output] = OnnxInference(subonx, runtime=self.runtime, 

1109 skip_run=self.skip_run, 

1110 runtime_options=self.runtime_options, 

1111 inplace=self.inplace, 

1112 input_inplace=self.input_inplace) 

1113 if verbose > 0: 

1114 fLOG( # pragma: no cover 

1115 '[build_intermediate] END.') 

1116 return ord 

1117 

1118 def _run_whole_runtime(self, inputs, clean_right_away=False, 

1119 intermediate=False, verbose=0, node_time=False, 

1120 overwrite_types=None, yield_ops=None, fLOG=None): 

1121 # node_time is unused 

1122 if clean_right_away: 

1123 raise RuntimeError( # pragma: no cover 

1124 "clean_right_away=true does not work with this runtime.") 

1125 if intermediate: 

1126 if hasattr(self, "intermediate_onnx_inference_"): 

1127 inter_run = self.intermediate_onnx_inference_ # pylint: disable=E0203 

1128 else: 

1129 if verbose > 0: 

1130 fLOG( # pragma: no cover 

1131 "-- OnnxInference: build intermediate") 

1132 inter_run = self.build_intermediate( 

1133 verbose=verbose, fLOG=fLOG, overwrite_types=overwrite_types) 

1134 self.intermediate_onnx_inference_ = inter_run 

1135 graph = self.to_sequence() 

1136 self.inits_ = graph['inits'] 

1137 

1138 if verbose >= 1: 

1139 fLOG( # pragma: no cover 

1140 "-- OnnxInference: run {} nodes".format( 

1141 len(self.intermediate_onnx_inference_))) 

1142 values = OrderedDict(inputs) 

1143 for k, v in self.inits_.items(): 

1144 values[k] = v['value'] 

1145 if verbose >= 2: # pragma: no cover 

1146 for k in sorted(values): 

1147 fLOG("-k='{}' shape={} dtype={}".format( 

1148 k, values[k].shape, values[k].dtype)) 

1149 for node, oinf in self.intermediate_onnx_inference_.items(): 

1150 if verbose >= 4: # pragma: no cover 

1151 fLOG('[intermediate] %r' % node) 

1152 if verbose >= 5: # pragma: no cover 

1153 fLOG(oinf.obj) 

1154 if yield_ops is not None and node.onnx_node.op_type == 'YieldOp': 

1155 out = node.onnx_node.output[0] 

1156 if out in yield_ops: 

1157 values[out] = yield_ops[out] 

1158 continue 

1159 raise RuntimeError( # pragma: no cover 

1160 "YieldOp output %r could not be found in " 

1161 "yield_ops: %r (node=%r)." % ( 

1162 out, list(sorted(yield_ops)), node.onnx_node)) 

1163 output = oinf.run(inputs)[node] 

1164 values[node] = output 

1165 if verbose >= 1: 

1166 if verbose >= 4: # pragma: no cover 

1167 for k, v in inputs.items(): 

1168 if isinstance(v, numpy.ndarray):

1169 fLOG("-i='{}': {} (dtype={}) {}".format( 

1170 k, v.shape, v.dtype, v.ravel().tolist())) 

1171 else: 

1172 fLOG("-i='{}': {} (dtype={}) - ?".format( 

1173 k, v.shape, v.dtype)) 

1174 if isinstance(output, numpy.ndarray): 

1175 fLOG("+k='{}': {} (dtype={})".format( # pragma: no cover 

1176 node, output.shape, output.dtype)) 

1177 if verbose >= 2: # pragma: no cover 

1178 fLOG(output) 

1179 else: 

1180 fLOG("+k='{}': {}".format( # pragma: no cover 

1181 node, type(output))) 

1182 if verbose >= 2: # pragma: no cover 

1183 fLOG(output) 

1184 return values 

1185 

1186 if verbose != 0: 

1187 warnings.warn( 

1188 "verbose option not implemented if runtime is 'onnxruntime1'") 

1189 res = self._whole.run(inputs) 

1190 return {k: v for k, v in zip(self.outputs_, res)} 

1191 

1192 def __getitem__(self, item): 

1193 """ 

1194 Returns an ONNX node or one of its attributes.

1195 """ 

1196 if isinstance(item, tuple): 

1197 node_name, att_name = item 

1198 else: 

1199 node_name = item 

1200 att_name = None 

1201 

1202 node_ = None 

1203 for node in self.obj.graph.node: 

1204 if node.name == node_name: 

1205 node_ = node 

1206 break 

1207 

1208 if node_ is None: 

1209 raise IndexError( # pragma: no cover 

1210 "Unable to get node name '{}'.\n{}".format( 

1211 node_name, "\n".join(node.name for node in self.obj.graph.node))) 

1212 

1213 if att_name is None: 

1214 return node_ 

1215 

1216 for att in node_.attribute: 

1217 if att.name == att_name: 

1218 return att 

1219 

1220 raise IndexError( # pragma: no cover 

1221 "Unable to find attribute '{}' from node " 

1222 "'{}'.".format(att_name, node_name)) 

1223 

1224 def switch_initializers_dtype(self, model=None, 

1225 dtype_in=numpy.float32, 

1226 dtype_out=numpy.float64): 

1227 """ 

1228 Switches all initializers to ``numpy.float64``. If *model* 

1229 is None, a simple cast is done. Otherwise, the function assumes 

1230 the model is a :epkg:`scikit-learn` pipeline. 

1231 This only works if the runtime is ``'python'``. 

1232 

1233 @param model :epkg:`scikit-learn` model or None 

1234 @param dtype_in previous type 

1235 @param dtype_out next type 

1236 @return done operations 
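
        A sketch (assuming *model_def* is a float32 ONNX model loaded
        with the python runtime):

        ::

            oinf = OnnxInference(model_def, runtime='python')
            done = oinf.switch_initializers_dtype()
            # every float32 initializer is now float64
            for op in done:
                print(op[:3])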

1237 """ 

1238 from ..onnx_tools.optim.sklearn_helper import enumerate_fitted_arrays, pairwise_array_distances 

1239 

1240 if self.runtime != 'python': # pragma: no cover 

1241 raise RuntimeError("Initializers can be casted only if the " 

1242 "runtime is 'python' not '{}'.".format(self.runtime)) 

1243 

1244 if hasattr(self, '_values_init'): 

1245 del self._values_init 

1246 

1247 # first pass: simple cast 

1248 done = [] 

1249 initializer = self.inits_ 

1250 for k, v in initializer.items(): 

1251 if isinstance(v['value'], numpy.ndarray): 

1252 if v['value'].dtype == dtype_in: 

1253 v['value'] = v['value'].astype(dtype_out) 

1254 done.append(("pass1", "+", "init", k, v['value'])) 

1255 else: 

1256 done.append(("pass1", "-", "init", k, 

1257 v['value'])) # pragma: no cover 

1258 for k, v in self.graph_['nodes'].items(): 

1259 res = v.switch_initializers_dtype(dtype_in=dtype_in, 

1260 dtype_out=dtype_out) 

1261 for r in res: 

1262 done.append(("pass1", "node", k) + r) 

1263 for k, v in self.graph_['intermediate'].items(): 

1264 if v is None: 

1265 continue 

1266 res = v.switch_initializers_dtype(dtype_in=dtype_in, 

1267 dtype_out=dtype_out) 

1268 for r in res: 

1269 done.append(("pass1", "sub", k) + r) 

1270 

1271 if model is not None: 

1272 # Second pass, we compare all arrays from the model 

1273 # to the arrays in the converted models. 

1274 def dist(a): 

1275 cast = a.astype(dtype_in).astype(dtype_out) 

1276 d = pairwise_array_distances([cast], [a])[0, 0] 

1277 return d 

1278 

1279 done_ = [(c, c[-1]) for c in done] 

1280 moda_ = [(a, a[-2][-1]) for a in enumerate_fitted_arrays(model) 

1281 if dist(a[-2][-1]) > 0] 

1282 aconv = [_[-1] for _ in done_] 

1283 amoda = [_[-1] for _ in moda_] 

1284 distances = pairwise_array_distances(aconv, amoda) 

1285 

1286 for i in range(distances.shape[0]): 

1287 j = numpy.argmin(distances[i]) 

1288 d = distances[i, j] 

1289 if d < 0.1: 

1290 numpy.copyto(aconv[i], amoda[j]) 

1291 done.append(("pass2", d) + done_[i][0]) 

1292 

1293 return done 

1294 

1295 def _set_shape_inference_runtime(self): 

1296 """ 

1297 Set shapes based on shape inference 

1298 relying on the runtime. 

1299 The values are stored in every node. 

1300 """ 

1301 if not hasattr(self, 'sequence_') or not hasattr(self, 'inputs_'): 

1302 raise RuntimeError( # pragma: no cover 

1303 "This method only works if the runtime is 'python' not " 

1304 "'{}'.".format(self.runtime)) 

1305 impossible = False  # may be switched to True while guessing shapes

values = OrderedDict()

1306 for k, v in self.inputs_.items(): 

1307 # The function assumes the first dimension is unknown 

1308 # and is the batch size. 

1309 try: 

1310 values[k] = ShapeObject(v, use_n1=True, name=k) 

1311 except TypeError as e: # pragma: no cover 

1312 if v['type']['elem'] == 'unk': 

1313 impossible = True 

1314 values[k] = None 

1315 continue 

1316 raise TypeError( 

1317 "Unable to guess shape for %r (shape=%r)." % ( 

1318 k, v)) from e 

1319 

1320 # 'impossible' is initialized before the input loop above.

1321 for k, v in self.statics_.items(): 

1322 # static inputs should be known. 

1323 if k not in values: 

1324 try: 

1325 values[k] = ShapeObject(v) 

1326 except TypeError: 

1327 # default value is wrong 

1328 impossible = True 

1329 values[k] = None 

1330 

1331 for k, v in self.inits_.items(): 

1332 values[k] = ShapeObject(v['value'], name=k) 

1333 last = None 

1334 for i, node in enumerate(self.sequence_): 

1335 try: 

1336 s = node._set_shape_inference_runtime(values) 

1337 last = s 

1338 except (IndexError, TypeError, KeyError, 

1339 AttributeError) as e: # pragma: no cover 

1340 rows = [] 

1341 if last is not None: 

1342 for k, v in last.items(): 

1343 rows.append("{}: {}".format(k, v)) 

1344 for k in range(i + 1): 

1345 rows.append("{} --> {}".format(k, self.sequence_[k])) 

1346 if not impossible: 

1347 raise RuntimeError("Unable to infer shape of node {}\n{}".format( 

1348 i, '\n'.join(rows))) from e 

1349 return values 

1350 

1351 def infer_shapes(self): 

1352 """ 

1353 Computes expected shapes. 

1354 

1355 :return: dictionary of shapes 
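
        A sketch (result names depend on the model):

        ::

            shapes = oinf.infer_shapes()
            for name, shape in shapes.items():
                print(name, shape)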

1356 """ 

1357 return self._set_shape_inference_runtime() 

1358 

1359 def _set_type_inference_runtime(self, inputs=None): 

1360 """ 

1361 Set types based on type inference 

1362 relying on the runtime. 

1363 The values are stored in every node. 

1364 """ 

1365 if not hasattr(self, 'sequence_') or not hasattr(self, 'inputs_'): 

1366 raise RuntimeError( # pragma: no cover 

1367 "This method only works if the runtime is 'python' not " 

1368 "'{}'.".format(self.runtime)) 

1369 

1370 values = OrderedDict() 

1371 for k, v in self.statics_.items(): 

1372 values[k] = None 

1373 

1374 if inputs is None: 

1375 for k, v in self.inputs_.items(): 

1376 # The function assumes the first dimension is unknown 

1377 # and is the batch size. 

1378 if isinstance(v['type']['elem'], dict): 

1379 # sequence 

1380 values[k] = SequenceType() 

1381 else: 

1382 values[k] = guess_numpy_type_from_string(v['type']['elem']) 

1383 else: 

1384 for name, dtype in zip(self.input_names, inputs): 

1385 values[name] = dtype 

1386 

1387 for k, v in self.inits_.items(): 

1388 values[k] = v['value'].dtype 

1389 

1390 last = None 

1391 for i, node in enumerate(self.sequence_): 

1392 try: 

1393 s = node._set_type_inference_runtime(values) 

1394 last = s 

1395 except IndexError as e: # pragma: no cover 

1396 rows = [] 

1397 if last is not None: 

1398 for k, v in last.items(): 

1399 rows.append("{}: {}".format(k, v)) 

1400 for k in range(i + 1): 

1401 rows.append("{} --> {}".format(k, self.sequence_[k])) 

1402 raise RuntimeError("Unable to infer type of node {}\n{}".format( 

1403 i, '\n'.join(rows))) from e 

1404 return values 

1405 

1406 def infer_types(self, inputs=None): 

1407 """ 

1408 Computes expected types.

1409 

1410 :param inputs: needed when this class hosts a function and not a graph

1411 :return: dictionary of types 
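
        A sketch mirroring :meth:`infer_shapes`:

        ::

            types = oinf.infer_types()
            for name, dtype in types.items():
                print(name, dtype)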

1412 """ 

1413 return self._set_type_inference_runtime(inputs) 

1414 

1415 def _set_size_inference_runtime(self, inputs, context=None): 

1416 """ 

1417 Set sizes allocated during inference 

1418 relying on the runtime. 

1419 The values are stored in every node. 

1420 """ 

1421 if not hasattr(self, 'sequence_') or not hasattr(self, 'inputs_'): 

1422 raise RuntimeError( # pragma: no cover 

1423 "This method only works if the runtime is 'python' not " 

1424 "'{}'.".format(self.runtime)) 

1425 values = OrderedDict() 

1426 for k, v in self.statics_.items(): 

1427 if context is None: 

1428 raise RuntimeError( # pragma: no cover 

1429 "static variable but context is None.") 

1430 values[k] = context[k] 

1431 for k, v in self.inits_.items(): 

1432 values[k] = v['value'] 

1433 for k, v in self.inputs_.items(): 

1434 if k in inputs: 

1435 values[k] = inputs[k] 

1436 

1437 last = None 

1438 for i, node in enumerate(self.sequence_): 

1439 try: 

1440 s = node._set_size_inference_runtime(values) 

1441 last = s 

1442 except IndexError as e: # pragma: no cover 

1443 rows = [] 

1444 if last is not None: 

1445 for k, v in last.items(): 

1446 rows.append("{}: {}".format(k, v)) 

1447 for k in range(i + 1): 

1448 rows.append("{} --> {}".format(k, self.sequence_[k])) 

1449 raise RuntimeError("Unable to infer size of node {}\n{}".format( 

1450 i, '\n'.join(rows))) from e 

1451 return values 

1452 

1453 def infer_sizes(self, inputs, context=None): 

1454 """ 

1455 Computes expected sizes. 

1456 

1457 :param inputs: inputs as a dictionary 

1458 :return: dictionary of dictionary of sizes 
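
        A sketch (the returned keys start with ``'#'``; the input
        name 'X' is illustrative):

        ::

            sizes = oinf.infer_sizes({'X': x})
            for name, size in sizes.items():
                print(name, size)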

1459 """ 

1460 res = self._set_size_inference_runtime(inputs, context=context) 

1461 return {k: v for k, v in res.items() if k.startswith('#')} 

1462 

1463 def _guess_inplace(self, input_inplace=False): 

1464 """ 

1465 Looks into every node of the graph to see 

1466 if there is a way to do the computation 

1467 inplace. By default (*input_inplace=False*), 

1468 the function assumes inputs cannot be modified 

1469 so the first node cannot do inplace computation. 

1470 This function only works with the python runtime. 

1471 

1472 @param input_inplace the computation is allowed 

1473 to overwrite the input 

1474 

1475 This function checks that one node is used only 

1476 once and then can be modified by the next node. 

1477 Nodes `A`, `C` can be overwritten by the computation. 

1478 Node `B` cannot as it is used by two nodes. 

1479 

1480 .. blockdiag:: 

1481 

1482 diagram { 

1483 A -> B -> C -> E; 

1484 B -> D; 

1485 } 

1486 

1487 It does not handle specific case such node `B` being 

1488 overwritten by node `C` but without changing its shape 

1489 and node `D` only needs the shape of `B`. Then `B` could 

1490 be overwritten as well. 

1491 """ 

1492 forbid = {} 

1493 values = OrderedDict() 

1494 for k in self.statics_: 

1495 values[k] = dict(inplace=False, to=[], fr=[]) 

1496 for k in self.inputs_: 

1497 values[k] = dict(inplace=input_inplace, to=[], fr=[]) 

1498 for k in self.inits_: 

1499 values[k] = dict(inplace=False, to=[], fr=[]) 

1500 for node in self.sequence_: 

1501 for n in node.inputs: 

1502 if n == '': 

1503 continue 

1504 values[n]['to'].append(node) 

1505 for n in node.outputs: 

1506 if node.op_type == 'Constant': 

1507 # We cannot modify constant. 

1508 forbid[n] = node 

1509 if n not in values: 

1510 values[n] = dict(inplace=None, to=[], fr=[]) 

1511 values[n]['fr'].append(node) 

1512 

1513 # checks the number of outputs 

1514 outputs = set(self.output_names) 

1515 modif = 1 

1516 while modif > 0: 

1517 modif = 0 

1518 for n, v in values.items(): 

1519 if v['inplace'] is not None: 

1520 continue 

1521 if n in forbid: 

1522 continue 

1523 if len(v['to']) == 1: 

1524 v['inplace'] = True 

1525 modif += 1 

1526 

1527 # convey the information to every node 

1528 inplaces = {} 

1529 for n, v in values.items(): 

1530 if v['inplace']: 

1531 inplaces[n] = v 

1532 for node in v['to']: 

1533 if n in outputs: 

1534 continue 

1535 node.enable_inplace_compute(n) 

1536 

1537 return inplaces 

1538 

1539 def _build_compile_run(self, debug=False): 

1540 """ 

1541 Rewrites the run function in python,

1542 compiles it, and adds it as a method. 

1543 

1544 @param debug insert debugging code 

1545 @return method name, callable object, code of the function

1546 

1547 .. exref:: 

1548 :title: Run a model with runtime 'python_compiled' 

1549 

1550 The following code trains a model and compute 

1551 the predictions with runtime ``'python_compiled'``. 

1552 It converts the onnx graph into a python function 

1553 which calls every operator. Its code is printed 

1554 below. 

1555 

1556 .. runpython:: 

1557 :showcode: 

1558 :warningout: DeprecationWarning 

1559 

1560 import numpy 

1561 from sklearn.datasets import load_iris 

1562 from sklearn.model_selection import train_test_split 

1563 from sklearn.ensemble import AdaBoostClassifier 

1564 from sklearn.tree import DecisionTreeClassifier 

1565 from mlprodict.onnx_conv import to_onnx 

1566 from mlprodict.onnxrt import OnnxInference 

1567 

1568 iris = load_iris() 

1569 X, y = iris.data, iris.target 

1570 X_train, X_test, y_train, __ = train_test_split(X, y, random_state=11) 

1571 y_train = y_train.astype(numpy.float32) 

1572 clr = AdaBoostClassifier( 

1573 base_estimator=DecisionTreeClassifier(max_depth=3), 

1574 n_estimators=3) 

1575 clr.fit(X_train, y_train) 

1576 

1577 model_def = to_onnx(clr, X_train.astype(numpy.float32), 

1578 target_opset=12) 

1579 

1580 oinf2 = OnnxInference(model_def, runtime='python_compiled') 

1581 print(oinf2.run({'X': X_test[:5]})) 

1582 

1583 # prints out the python function equivalent 

1584 # to the onnx graph 

1585 print(oinf2) 

1586 """ 

1587 

1588 def clean_name(name): 

1589 res = name.replace(":", "_").replace('.', '_').replace('/', '_') 

1590 if iskeyword(res): 

1591 res += '_' 

1592 return res 

1593 

1594 # inits 

1595 inputs = self.input_names 

1596 code = ['def compiled_run(dict_inputs, yield_ops=None):'] 

1597 code.append(" if yield_ops is not None:") 

1598 code.append( 

1599 " raise NotImplementedError('yields_ops should be None.')") 

1600 if debug: 

1601 code.append(" printed = {}") 

1602 

1603 context = {} 

1604 

1605 # static variables 

1606 for k in sorted(self.statics_): 

1607 code.append(" # static: {0}".format(k)) 

1608 code.append(" {0} = dict_inputs['{1}']".format( 

1609 clean_name(k), k)) 

1610 if debug: 

1611 code.append( 

1612 " debug_print('i.{0}', {1}, printed)".format( 

1613 clean_name(k), k)) 

1614 

1615 # initializers 

1616 for k, v in sorted(self.inits_.items()): 

1617 if k.startswith("_OPT_"): 

1618 raise RuntimeError( # pragma: no cover 

1619 "The runtime cannot handle any constant name " 

1620 "starting with '_OPT_': '{}'.".format(k)) 

1621 if k in inputs: 

1622 context["_OPT_" + clean_name(k)] = v['value'] 

1623 code.append(" # init: _OPT_{0} ({1})".format( 

1624 clean_name(k), k)) 

1625 if debug: 

1626 code.append( 

1627 " debug_print('c.[_OPT_{0}]', _OPT_{1}, printed)".format( 

1628 clean_name(k), k)) 

1629 else: 

1630 context[clean_name(k)] = v['value'] 

1631 code.append(" # init: {0} ({1})".format( 

1632 clean_name(k), k)) 

1633 if debug: 

1634 code.append( 

1635 " debug_print('c.[{0}]', {1}, printed)".format( 

1636 clean_name(k), k)) 

1637 

1638 # method signature 

1639 code.append(" # inputs") 

1640 for inp in inputs: 

1641 if '_OPT_' + inp in context: 

1642 # optional inputs 

1643 code.append( 

1644 " {0} = dict_inputs.get('{1}', _OPT_{0})".format( 

1645 clean_name(inp), inp)) 

1646 else: 

1647 code.append(" {0} = dict_inputs['{1}']".format( 

1648 clean_name(inp), inp)) 

1649 if debug: 

1650 code.append( 

1651 " debug_print('i.{0}', {1}, printed)".format( 

1652 clean_name(inp), inp)) 

1653 

1654 # code 

1655 for i, node in enumerate(self.sequence_): 

1656 name = "n{}_{}".format(i, node.ops_.__class__.__name__.lower()) 

1657 if node.ops_ is None: 

1658 context[name] = node.function_ 

1659 # The code of the function should be added but only once. 

1660 raise NotImplementedError( 

1661 "Not implemented for models including functions.") 

1662 else: 

1663 context[name] = node.ops_._run 

1664 if (node.ops_.__class__.__name__ == 'Loop' and 

1665 node.ops_.need_context()): 

1666 # Adding context. 

1667 ctx = "{%s}" % ", ".join( 

1668 "'%s': %s" % (n, n) for n in node.ops_.additional_inputs) 

1669 code.append(' ({1}, ) = {2}({0}, context={3})'.format( 

1670 ', '.join(map(clean_name, node.inputs)), 

1671 ', '.join(map(clean_name, node.outputs)), 

1672 name, ctx)) 

1673 else: 

1674 code.append(' ({1}, ) = {2}({0})'.format( 

1675 ', '.join(map(clean_name, node.inputs)), 

1676 ', '.join(map(clean_name, node.outputs)), 

1677 name)) 

1678 if debug: 

1679 code.append(" print('''# {}''')".format(code[-1][4:])) 

1680 for o in node.outputs: 

1681 code.append( 

1682 " debug_print('o.{0}', {1}, printed)".format( 

1683 clean_name(o), o)) 

1684 

1685 # return 

1686 code.append(' return {') 

1687 for out in self.output_names: 

1688 code.append(" '{1}': {0},".format( 

1689 clean_name(out), out)) 

1690 code.append(' }') 

1691 final_code = '\n'.join(code) 

1692 

1693 # compile the outcome 

1694 context['self'] = self 

1695 try: 

1696 obj = compile(final_code, "<string>", 'exec') 

1697 except SyntaxError as e: # pragma: no cover 

1698 raise SyntaxError( 

1699 "Unable to compile\n#####\n{}".format(final_code)) from e 

1700 fcts_obj = [_ for _ in obj.co_consts 

1701 if _ is not None and not isinstance(_, (bool, str, int))] 

1702 fct = make_callable( 

1703 "compiled_run", fcts_obj[0], final_code, context, debug) 

1704 

1705 # end 

1706 return "compiled_run", fct, final_code 

1707 

1708 def reduce_size(self, pickable=False): 

1709 """ 

1710 Reduces the memory footprint as much as possible. 

1711 

1712 @param pickable keeps a pickle object? 

1713 """ 

1714 import gc 

1715 del self.graph_ 

1716 if not pickable: 

1717 del self.obj 

1718 if self.runtime in ('python_compiled', 'python_compiled_debug'): 

1719 del self.sequence_ 

1720 gc.collect() 

1721 

1722 def get_profiling(self, as_df=False): 

1723 """ 

1724 Returns the profiling after a couple of executions.

1725 

1726 :param as_df: return the results as a dataframe (True) 

1727 :return: dataframe or list of dictionaries 
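
        A sketch (profiling is only implemented for runtime
        'onnxruntime1' and must be enabled through *runtime_options*):

        ::

            oinf = OnnxInference(
                model_def, runtime='onnxruntime1',
                runtime_options={'enable_profiling': True})
            oinf.run({'X': x})
            df = oinf.get_profiling(as_df=True)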

1728 

1729 .. versionadded:: 0.6 

1730 """ 

1731 if (self.runtime_options is None or 

1732 not self.runtime_options.get('enable_profiling', False)): 

1733 raise RuntimeError( 

1734 "Profiling is available if options 'enable_profiling' " 

1735 "is set to true in 'runtime_options' but is %r." % self.runtime_options) 

1736 prof = None 

1737 if hasattr(self, '_whole'): 

1738 prof = self._whole.get_profiling() 

1739 if prof is None: 

1740 raise NotImplementedError( # pragma: no cover 

1741 "profiling is only implemented for runtime 'onnxruntime1'.") 

1742 if as_df: 

1743 import pandas 

1744 return pandas.DataFrame(prof) 

1745 return prof 

1746 

1747 def get_execution_order(self): 

1748 """ 

1749 This function returns a dictionary `{(kind, name): (order, op)}`, 

1750 *name* can be a node name or a result name. In that case,

1751 it gets the execution order of the node which created it.

1752 The function returns None if the order is not available

1753 (the selected runtime does not return it). *kind* is either

1754 `'node'` or `'res'`. If two nodes have the same name, the

1755 returned order is the last one. Initializers get an execution

1756 order equal to -1, inputs 0, all other results are >= 1.
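
        A sketch (the node and result names are hypothetical):

        ::

            order = oinf.get_execution_order()
            print(order['node', 'some_node_name'])
            print(order['res', 'Y'])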

1757 

1758 .. versionadded:: 0.7 

1759 """ 

1760 if not hasattr(self, "sequence_"): 

1761 return None 

1762 

1763 res = {} 

1764 for k, v in self.inits_.items(): 

1765 res['res', k] = (-1, v) 

1766 for name, shape in self.input_names_shapes: 

1767 res['res', name] = (0, shape) 

1768 

1769 for i, node in enumerate(self.sequence_): 

1770 key = ('node', node.onnx_node.name) 

1771 res[key] = (i + 1, node) 

1772 for out in node.onnx_node.output: 

1773 key = ('res', out) 

1774 if key in res: 

1775 raise RuntimeError( # pragma: no cover 

1776 "Output %r of node name %r already registered." 

1777 "" % (out, node.onnx_node.name)) 

1778 res[key] = (i + 1, None) 

1779 

1780 return res