Coverage for mlprodict/onnxrt/onnx_inference_node.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

264 statements  

1""" 

2@file 

3@brief OnnxInferenceNode definition. 

4""" 

5import sys 

6import pprint 

7import numpy 

8from onnx import onnx_pb as onnx_proto 

9from onnx.onnx_cpp2py_export.defs import SchemaError # pylint: disable=E0401,E0611 

10from ..onnx_tools.onnx2py_helper import get_onnx_schema 

11from .excs import MissingOperatorError 

12from .ops import load_op 

13 

14 

class OnnxInferenceNode:
    """
    A node to execute.
    """

    class OnnxInferenceWrapper:
        """
        Wraps @see cl OnnxInference in a wrapper and exposes
        the necessary function.

        :param oinf: instance of @see cl OnnxInference
        """

        def __init__(self, oinf):
            if oinf is None:
                raise ValueError(  # pragma: no cover
                    "oinf cannot be None.")
            self.oinf = oinf

        @property
        def args_default(self):
            "Returns the list of default arguments."
            return []

        @property
        def args_default_modified(self):
            "Returns the list of modified arguments."
            return []

        @property
        def args_mandatory(self):
            "Returns the list of mandatory arguments."
            return self.oinf.input_names

        @property
        def args_optional(self):
            "Returns the list of optional arguments."
            return []

        @property
        def obj(self):
            "Returns the ONNX graph."
            return self.oinf.obj

        def run(self, *args, **kwargs):
            "Calls run."
            return self.oinf.run(*args, **kwargs)

        def to_python(self, inputs, *args, **kwargs):
            """
            Calls to_python on the wrapped runtime and splits the result
            into an import section and a code section ending with a
            ``return`` statement calling the generated class.

            :param inputs: names of the inputs passed to the generated code
            :param args: positional arguments forwarded to ``oinf.to_python``
            :param kwargs: keyword arguments forwarded to ``oinf.to_python``
            :return: imports, python code, both as strings
            :raises NotImplementedError: if the generated code
                spans several files
            """
            res = self.oinf.to_python(*args, **kwargs)
            if len(res) != 1:
                raise NotImplementedError(  # pragma: no cover
                    "Not implemented if the code has multiple files.")
            value = res[next(iter(res))]
            lines = value.split('\n')
            # The code section starts one line before the first function
            # definition; clamped to 0 so that code starting with 'def'
            # keeps its first lines (a negative index would keep only the
            # last line).
            last = 0
            for i, line in enumerate(lines):
                if line.startswith('def '):
                    last = max(i - 1, 0)
                    break
            imports = '\n'.join(
                line for line in lines[:last] if 'import ' in line)
            lines.append('')
            lines.append("return OnnxPythonInference().run(%s)" %
                         ', '.join(inputs))
            code = '\n'.join(lines[last:])
            return imports, code

        def need_context(self):
            "Needs context?"
            return False

        def infer_types(self, *args):
            "Calls infer_types."
            res = self.oinf.infer_types(args)
            names = self.oinf.obj.output
            dtypes = [res[n] for n in names]
            return tuple(dtypes)

        def infer_sizes(self, *args):
            "Calls infer_sizes."
            values = dict(zip(self.oinf.input_names, args))
            res = self.oinf.infer_sizes(values)
            names = self.oinf.obj.output
            sizes = [res.get(n, 0) for n in names]
            return (res['#'], ) + tuple(sizes)

        def enable_inplace_compute(self, index):
            "Not implemented."
            pass

    def __init__(self, onnx_node, desc, global_index):
        """
        @param      onnx_node       onnx_node
        @param      desc            internal description
        @param      global_index    it is a function which returns a unique index
                                    for the output this operator generates
        """
        if desc is None:
            raise ValueError("desc should not be None.")  # pragma: no cover
        self.desc = desc
        self.onnx_node = onnx_node
        self._init(global_index)

    @property
    def name(self):
        """
        Returns the ONNX name: domain and operator type joined by '_',
        with dots replaced by underscores.
        """
        return "_".join(
            [self.desc['domain'], self.onnx_node.op_type]).replace(
                ".", "_").replace('__', '_').strip('_')

    def _init(self, global_index):
        """
        Prepares the node.
        """
        self.op_type = self.onnx_node.op_type
        self.order = -1
        self.variable_to_clean = []
        self.inputs = list(self.onnx_node.input)
        self.outputs = list(self.onnx_node.output)
        self.inplaces = []
        self.inputs_indices = [global_index(name) for name in self.inputs]
        self.outputs_indices = [global_index(name) for name in self.outputs]
        self._global_index = global_index

    def set_order(self, order):
        """
        Defines the order of execution.
        """
        self.order = order

    def add_variable_to_clean(self, name):
        """
        Adds a variable which can be cleaned after the node
        execution.
        """
        self.variable_to_clean.append(name)

    def __str__(self):
        "usual"
        return "Onnx-{}({}) -> {}{}".format(
            self.op_type, ", ".join(self.inputs), ", ".join(self.outputs),
            " (name=%r)" % self.onnx_node.name
            if self.onnx_node.name else "")

    def __repr__(self):
        "usual"
        return self.__str__()

    def _ops_or_function(self):
        """
        Returns the object executing this node: the operator *ops_*
        or, when *ops_* is None, the function wrapper *function_*.
        """
        return self.function_ if self.ops_ is None else self.ops_

    def setup_runtime(self, runtime=None, variables=None, rt_class=None,
                      target_opset=None, dtype=None, domain=None,
                      ir_version=None, runtime_options=None,
                      build_inference_node_function=None):
        """
        Loads runtime.

        :param runtime: runtime options
        :param variables: registered variables created by previous operators
        :param rt_class: runtime class used to compute
            prediction of subgraphs
        :param target_opset: use a specific target opset
        :param dtype: float computational type
        :param domain: node domain
        :param ir_version: if not None, changes the default value
            given by :epkg:`ONNX`
        :param runtime_options: runtime options
        :param build_inference_node_function: function creating an inference
            runtime from an ONNX graph
        :raises AttributeError: if *desc* is None
        :raises MissingOperatorError: if the operator cannot be loaded and
            its schema does not define a function body which could be
            executed instead (or *build_inference_node_function* is None)

        .. versionchanged:: 0.9
            Parameter *build_inference_node_function* was added.
        """
        if self.desc is None:
            raise AttributeError(
                "desc should not be None.")  # pragma: no cover
        if rt_class is None:
            # path used when this operator is a function.
            self.function_ = OnnxInferenceNode.OnnxInferenceWrapper(runtime)
            self.ops_ = None
            return

        self.function_ = None
        self.preprocess_parameters(
            runtime, rt_class, ir_version=ir_version,
            target_opset=target_opset)
        options = {'provider': runtime} if runtime else {}
        if domain is not None:
            options['domain'] = domain
        if target_opset is not None:
            options['target_opset'] = target_opset
        if ir_version is not None:
            options['ir_version'] = ir_version
        if runtime_options is not None:
            options.update({
                k: v for k, v in runtime_options.items()
                if k not in {'log_severity_level'}})
        if runtime in ('python_compiled', 'python_compiled_debug'):
            # the compiled runtimes load operators with the 'python' provider
            options['provider'] = 'python'
        try:
            self.ops_ = load_op(self.onnx_node, desc=self.desc,
                                options=options if options else None,
                                variables=variables, dtype=dtype,
                                runtime=runtime)
        except MissingOperatorError as e:
            # No implementation available: fall back on the function body
            # declared in the operator schema, if there is one.
            try:
                onnx_schema = get_onnx_schema(
                    self.onnx_node.op_type, self.onnx_node.domain,
                    opset=target_opset)
            except SchemaError:
                raise e  # pylint: disable=W0707
            if (onnx_schema is None or not onnx_schema.has_function or
                    build_inference_node_function is None):
                raise e
            self.function_ = OnnxInferenceNode.OnnxInferenceWrapper(
                build_inference_node_function(onnx_schema.function_body))
            self.ops_ = None

    @staticmethod
    def _find_static_inputs(body):
        """
        Determines the loop inputs. It is any defined inputs
        by the subgraphs + any results used as a constant
        in the subgraphs.
        """
        inputs_set = {i.name for i in body.input}
        inputs_set.update(init.name for init in body.initializer)
        for node in body.node:
            inputs_set.update(node.output)
        add_inputs = []
        for node in body.node:
            for i in node.input:
                if i not in inputs_set:
                    # no graph input or output node matches
                    # it must be a constant from the below graph
                    add_inputs.append(i)
                    inputs_set.add(i)
        return add_inputs

    def preprocess_parameters(self, runtime, rt_class, ir_version=None,
                              target_opset=None):
        """
        Preprocesses the parameters, loads *GraphProto*
        (equivalent to :epkg:`ONNX` graph with less metadata).

        @param  runtime         runtime options
        @param  rt_class        runtime class used to compute
                                prediction of subgraphs
        @param  ir_version      if not None, overwrites the default value
        @param  target_opset    use a specific target opset
        @raises RuntimeError    if a subgraph runtime cannot be instantiated
        """
        if 'atts' not in self.desc:
            return  # pragma: no cover
        inside_loop = self.onnx_node.op_type in {'Loop'}
        for v in self.desc['atts'].values():
            if 'value' not in v:
                continue  # pragma: no cover
            value = v['value']
            if isinstance(value, onnx_proto.GraphProto):
                static_inputs = OnnxInferenceNode._find_static_inputs(value)
                try:
                    sess = rt_class(value, runtime=runtime,
                                    ir_version=ir_version,
                                    target_opset=target_opset,
                                    inside_loop=inside_loop,
                                    static_inputs=static_inputs)
                except RuntimeError as e:  # pragma: no cover
                    raise RuntimeError(
                        "Unable to instantiate a node of type %r and name %r."
                        "" % (self.onnx_node.op_type, self.onnx_node.name)) from e
                v['value_rt'] = sess

    def _collect_inputs(self, values):
        """
        Returns the list of this node's inputs taken from *values*,
        indexed by global index if available, by name otherwise.
        """
        if self.inputs_indices is None:
            return [values[k] for k in self.inputs]
        return [values[k] for k in self.inputs_indices]

    def _store_outputs(self, values, res):
        """
        Stores the results *res* of this node into *values*,
        indexed by global index if available, by name otherwise.
        """
        if self.outputs_indices is None:
            for name, value in zip(self.outputs, res):
                values[name] = value
        else:
            for i, r in enumerate(res):
                values[self.outputs_indices[i]] = r

    def run(self, values):
        """
        Runs the node.
        the function updates values with outputs.

        @param  values      list of existing values
        @raises RuntimeError if the operator fails with a TypeError or an
                            OverflowError, returns something other than a
                            tuple, or returns an unexpected number of outputs
        """
        # This code takes time if the graph contains many nodes.
        # Maybe a C++ container would help in that case (to skip GIL).
        args = self._collect_inputs(values)

        if self.ops_ is None:
            # Then a function: this node's inputs are fed by name
            # to the wrapped graph.
            feeds = dict(zip(self.function_.obj.input, args))
            outputs = self.function_.run(feeds)
            res = [outputs[k] for k in self.function_.obj.output]
            self._store_outputs(values, res)
            return

        try:
            if self.ops_.need_context():
                context = {n: values[self._global_index(n)]
                           for n in self.ops_.additional_inputs}
                res = self.ops_.run(*args, context=context)
            else:
                res = self.ops_.run(*args)
        except (TypeError, OverflowError) as e:
            raise RuntimeError(  # pragma: no cover
                "Unable to run operator %r, inputs=%r."
                "" % (type(self.ops_), self.inputs)) from e

        if not isinstance(res, tuple):
            raise RuntimeError(  # pragma: no cover
                "Results of operator %r should be a tuple." % type(self.ops_))
        if len(self.outputs) != len(res):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of outputs got {} for names {}.\n{}".format(
                    len(res), list(sorted(self.outputs)),
                    pprint.pformat(self.desc)))

        self._store_outputs(values, res)

    def switch_initializers_dtype(self, dtype_in=numpy.float32,
                                  dtype_out=numpy.float64):
        """
        Switches all initializers to ``numpy.float64``.
        This only works if the runtime is ``'python'``.

        @param      dtype_in    previous type
        @param      dtype_out   next type
        @return                 done operations, a list of tuples
                                ``('+' or '-', 'desc', name, value)`` and
                                ``('ops_', ...)`` for the operator's own changes
        """
        done = []
        # A node without attributes has nothing to switch
        # (preprocess_parameters tolerates a missing 'atts' as well).
        for k, v in self.desc.get('atts', {}).items():
            if 'value_rt' not in v:
                continue
            if isinstance(v['value_rt'], numpy.ndarray):
                if v['value_rt'].dtype == dtype_in:
                    v['value_rt'] = v['value_rt'].astype(dtype_out)
                    done.append(("+", "desc", k, v['value_rt']))
                else:
                    done.append(("-", "desc", k, v['value_rt']))
        if hasattr(self, 'ops_') and self.ops_ is not None:
            res = self.ops_.switch_initializers_dtype(dtype_in, dtype_out)
            for r in res:
                done.append(("ops_", ) + r)
        return done

    def _set_shape_inference_runtime(self, values):
        """
        Updates *values* with the shapes of the outputs.

        :param values: container for shapes
        :return: *values*
        :raises TypeError: if the operator cannot infer shapes
            from the given arguments
        :raises RuntimeError: if the result is not a tuple
            or has the wrong number of elements
        """
        if self.ops_ is None:
            # A function, unknown types.
            for name in self.outputs:
                values[name] = None
            return values
        args = [values[k] for k in self.inputs if k != '']
        try:
            res = self.ops_.infer_shapes(*args)
        except (TypeError, ValueError, AttributeError) as e:  # pragma: no cover
            raise TypeError(
                "Unable to call infer_shapes with {} arguments for class"
                " '{}' ({})".format(
                    len(args), self.ops_.__class__.__name__,
                    self.ops_.infer_shapes)) from e
        if res is not None:
            if not isinstance(res, tuple):
                raise RuntimeError(  # pragma: no cover
                    "Results of an operator should be a tuple for operator "
                    "'{}'.".format(type(self.ops_)))
            if len(self.outputs) != len(res):
                raise RuntimeError(  # pragma: no cover
                    "Mismatch number of outputs got {} != {} for names {} "
                    "(node='{}').\n{}".format(
                        len(res), len(self.outputs), list(self.outputs),
                        self.ops_.__class__.__name__,
                        pprint.pformat(self.desc, depth=2)))
            for name, value in zip(self.outputs, res):
                values[name] = value
        return values

    def _set_type_inference_runtime(self, values):
        """
        Updates *values* with the types of the outputs.

        :param values: container for types
        :return: *values*
        :raises TypeError: if the types cannot be inferred
            from the given arguments
        :raises RuntimeError: if the result is not a tuple
            or has the wrong number of elements
        """
        args = [values[k] for k in self.inputs]
        impl = self._ops_or_function()
        # Single guarded call: errors are reported as TypeError
        # with the name of the implementation.
        try:
            res = impl.infer_types(*args)
        except (TypeError, ValueError) as e:  # pragma: no cover
            raise TypeError(
                "Unable to call infer_types with {} arguments for class"
                " '{}'".format(
                    len(args), impl.__class__.__name__)) from e
        if not isinstance(res, tuple):
            raise RuntimeError(  # pragma: no cover
                "Results of an operator should be a tuple for operator '{}'"
                ".".format(type(impl)))
        if len(self.outputs) != len(res):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of outputs got {} != {} for names {} (node='{}')."
                "\n{}".format(
                    len(res), len(self.outputs), list(self.outputs),
                    impl.__class__.__name__,
                    pprint.pformat(self.desc, depth=2)))
        for name, value in zip(self.outputs, res):
            values[name] = value
        return values

    def _set_size_inference_runtime(self, values):
        """
        Updates *values* with the sizes of the outputs.
        The first element returned by ``infer_sizes`` is stored
        under key ``'#' + node name``.

        :param values: container for sizes
        :return: *values*
        :raises TypeError: if the sizes cannot be inferred
            from the given arguments
        :raises RuntimeError: if the result is not a tuple
            or does not have ``len(outputs) + 1`` elements
        """
        args = [values[k] for k in self.inputs]
        impl = self._ops_or_function()
        try:
            if impl.need_context():
                context = {n: values[n]
                           for n in impl.additional_inputs}
                res = impl.infer_sizes(*args, context=context)
            else:
                res = impl.infer_sizes(*args)
        except (TypeError, ValueError) as e:  # pragma: no cover
            raise TypeError(
                "Unable to call infer_sizes with {} arguments for class"
                " '{}' ({})".format(len(args), impl.__class__.__name__,
                                    impl.infer_sizes)) from e
        if not isinstance(res, tuple):
            raise RuntimeError(  # pragma: no cover
                "Results of an operator should be a tuple for operator '{}'"
                ".".format(type(impl)))
        if len(self.outputs) + 1 != len(res):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of outputs got {} != {} + 1 for names {} "
                "(node='{}').\n{}".format(
                    len(res), len(self.outputs), list(self.outputs),
                    impl.__class__.__name__,
                    pprint.pformat(self.desc, depth=2)))
        for name, value in zip(self.outputs, res[1:]):
            values[name] = value
        values['#' + self.onnx_node.name] = res[0]
        return values

    def enable_inplace_compute(self, name):
        """
        Let the node know that one input can be overwritten.

        @param      name        input name
        """
        self.inplaces.append(name)
        self._ops_or_function().enable_inplace_compute(
            self.inputs.index(name))

    @property
    def inputs_args(self):
        """
        Returns the list of arguments as well as
        the list of parameters with the default values
        (close to the signature).
        """
        if not hasattr(self, 'ops_'):
            raise AttributeError(
                "Attribute 'ops_' is missing.")  # pragma: no cover
        sigs = []
        ops_or_function = self._ops_or_function()
        mand = ops_or_function.args_mandatory
        if mand is None:
            mand = self.python_inputs
        sigs.extend(mand)
        if len(ops_or_function.args_optional) > 0:
            sigs.extend(ops_or_function.args_optional)
            if sys.version_info[:2] >= (3, 8):
                # positional-only marker
                sigs.append('/')
        sigs.extend(ops_or_function.args_default)
        return sigs

    @property
    def python_inputs(self):
        """
        Returns the python arguments.
        """
        if not hasattr(self, 'ops_'):
            raise AttributeError(
                "Attribute 'ops_' is missing.")  # pragma: no cover
        if hasattr(self.ops_, 'python_inputs'):
            return self.ops_.python_inputs
        return self.inputs

    @property
    def modified_args(self):
        """
        Returns the list of modified parameters.
        """
        if not hasattr(self, 'ops_'):
            raise AttributeError(
                "Attribute 'ops_' is missing.")  # pragma: no cover
        return self._ops_or_function().args_default_modified

    def to_python(self, inputs):
        """
        Returns a python code for this operator.

        @param      inputs      inputs name
        @return                 imports, python code, both as strings
        """
        if not hasattr(self, 'ops_'):
            raise AttributeError(
                "Attribute 'ops_' is missing.")  # pragma: no cover
        return self._ops_or_function().to_python(inputs)