Coverage for mlprodict/npy/xop.py: 90%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1182 statements  

1# pylint: disable=E1101,C0302 

2""" 

3@file 

4@brief Xop API to build onnx graphs. Inspired from :epkg:`sklearn-onnx`. 

5 

6.. versionadded:: 0.9 

7""" 

8import os 

9import pprint 

10import logging 

11import hashlib 

12from collections import OrderedDict 

13import numpy 

14from scipy.sparse.coo import coo_matrix 

15import onnx 

16from onnx import GraphProto, TensorProto, ValueInfoProto 

17from onnx.helper import ( 

18 make_node, make_graph, make_model, make_value_info, 

19 make_tensor_value_info, make_function, make_opsetid, 

20 make_tensor_type_proto, make_operatorsetid) 

21from onnx.numpy_helper import from_array, to_array 

22from onnx.shape_inference import infer_shapes 

23from ._cache import cache_folder 

24from .xop_variable import ( 

25 Variable, is_numpy_dtype, numpy_type_prototype, max_supported_opset, 

26 DetectedVariable, InputDetectedVariable, OutputDetectedVariable, 

27 NodeResultName, guess_numpy_type) 

28from .xop_auto import get_rst_doc 

29 

30 

# Module logger named 'xop'; the operator classes below use it to emit
# debug messages.
logger = logging.getLogger('xop')

32 

33 

34def _default_OPSET_TO_IR_VERSION(): 

35 """ 

36 Returns the default mapping between opset and ir_version. 

37 

38 .. runpython:: 

39 :showcode: 

40 

41 import pprint 

42 from mlprodict.npy.xop import _default_OPSET_TO_IR_VERSION 

43 pprint.pprint(_default_OPSET_TO_IR_VERSION()) 

44 """ 

45 return { 

46 1: 3, 2: 3, 3: 3, 4: 3, 5: 3, 6: 3, 

47 7: 3, 8: 4, 9: 4, 10: 5, 11: 6, 12: 7, 

48 13: 7, 14: 7, 15: 8, 16: 8} 

49 

50 

51def _domain_to_class_name(domain): 

52 """ 

53 Converts domain into a name. 

54 

55 :param domain: domain name such as `ai.onnx.ml` 

56 :return: string 

57 

58 .. runpython:: 

59 :showcode: 

60 

61 from mlprodict.npy.xop import _domain_to_class_name 

62 print(_domain_to_class_name('ai.onnx.ml')) 

63 """ 

64 if domain == 'ai.onnx': 

65 return '' 

66 dom = domain.split('.') 

67 res = [] 

68 for d in dom: 

69 if len(d) == 0: 

70 res.append(d) 

71 elif len(d) == 1: 

72 res.append(d.upper()) 

73 else: 

74 res.append(d[0].upper() + d[1:]) 

75 return "".join(res) 

76 

77 

78def _populate_schemas(): 

79 """ 

80 Populates all schemas. 

81 """ 

82 res = {} 

83 versions = {} 

84 domains = {} 

85 for schema in onnx.defs.get_all_schemas_with_history(): 

86 if schema.support_level == schema.SupportType.EXPERIMENTAL: 

87 # Skips experimental operators. 

88 continue 

89 # Multiple version can coexist. The last one is kept. 

90 if schema.name in res: 

91 if schema.since_version > res[schema.name].since_version: 

92 # We keep the most recent one. 

93 res[schema.domain, schema.name] = schema 

94 else: 

95 res[schema.domain, schema.name] = schema 

96 full_name = schema.name + '_' + str(schema.since_version) 

97 res[schema.domain, full_name] = schema 

98 key = schema.domain, schema.name 

99 if key not in versions: 

100 versions[key] = set() 

101 if schema.name not in domains: 

102 domains[schema.name] = set() 

103 domains[schema.name].add(schema.domain) 

104 versions[key].add(full_name) 

105 return res, versions, domains 

106 

107 

def _find_operator_domain(name):
    """
    Determines the domain of an operator by looking it up in the
    module dictionary ``_all_domains`` (operator name -> domains).
    Raises an exception if not found or if there is an ambiguity.

    :param name: operator name
    :return: domain
    :raises ValueError: the operator is unknown or belongs to
        several domains
    """
    if name in _all_domains:
        candidates = _all_domains[name]
        if len(candidates) == 1:
            return next(iter(candidates))
        raise ValueError(  # pragma: no cover
            "Unable to guess domain of operator %r, found domains %r." % (
                name, candidates))
    raise ValueError(
        "Unable to guess domain for operator %r. "
        "Not found in %r." % (name, list(_all_domains)))

126 

127 

def ClassFactory(class_name, op_name, inputs, outputs,
                 input_range, output_range,
                 domain, attr_names, doc,
                 deprecated, since_version,
                 past_version):
    """
    Dynamically creates a class for a specific operator.

    :param class_name: class name
    :param op_name: operator type
    :param inputs: expected inputs, list of tuples ``(name, type string)``
    :param outputs: expected outputs, list of tuples ``(name, type string)``
    :param input_range: ``[min, max]`` number of inputs
    :param output_range: ``[min, max]`` number of outputs
    :param domain: domain
    :param attr_names: attributes names
    :param doc: docstring
    :param deprecated: is the operator deprecated
    :param since_version: available since version
    :param past_version: dictionary ``{class name: class}`` of previous
        versions of this operator
    :return: a new subclass of @see cl OnnxOperator
    """
    # Keyword arguments handled by OnnxOperator.__init__ itself: they are
    # not operator attributes and must not be checked against attr_names.
    reserved_kwargs = frozenset({
        'output_names', 'op_version', 'domain', 'ir_version',
        'global_context', 'clear_subgraph_inputs'})

    def __init__(self, *args, **kwargs):

        op_version = kwargs.pop('op_version', None)
        if isinstance(op_version, dict):
            # One opset per domain, only this operator's domain matters.
            op_version = op_version.get(domain, None)

        if op_version is None:
            if len(args) == 0 and input_range[0] == input_range[1]:
                # No input given but the number of inputs is fixed:
                # the expected input names are used instead.
                args = [inp[0] for inp in self.__class__.expected_inputs]
            if not (input_range[0] <= len(args) <= input_range[1]):
                raise RuntimeError(  # pragma: no cover
                    "Unexpected number of inputs, "
                    "got {}, expecting a number in [{}, {}] for operator "
                    "'{}'.".format(
                        len(args), input_range[0], input_range[1], op_name))

        attr_names = self.attr_names
        if '_' in self.__class__.__name__:
            # Versioned class such as OnnxAdd_7: the suffix is the opset.
            op_version_class = int(self.__class__.__name__.split('_')[-1])
            if op_version is None:
                op_version = op_version_class
            try:
                op_version = min(op_version, op_version_class)
            except TypeError as e:  # pragma: no cover
                raise TypeError(
                    "Could not compare versions {} ? {} for "
                    "class '{}' since_version {}. Parameter 'op_version' "
                    "is probably missing when the class "
                    "is instantiated.".format(
                        op_version, op_version_class, class_name,
                        since_version)) from e
        else:
            op_version_class = None

        # By default, the op_version is None.
        # None means the latest available.
        if op_version is None:
            op_version = since_version

        found = None
        if op_version is not None:
            # attr_names refers to the most recent version of
            # this operator. We may need an older one.
            for op in range(op_version, 0, -1):
                name = '{}_{}'.format(self.__class__.__name__, op)
                if name in self.past_version:
                    found = (name, op)
                    attr_names = self.past_version[name].attr_names
                    break
        if (op_version_class is not None and found is not None and
                found[-1] != op_version_class):
            raise RuntimeError(  # pragma: no cover
                "op_version={} does not refer to the same opset as the class "
                "name ('{}').".format(op_version, self.__class__.__name__))
        for key in kwargs:
            if key in reserved_kwargs:
                continue
            if key not in attr_names:
                raise TypeError(  # pragma: no cover
                    "Argument '%s' not valid for '%s' opset=%s."
                    % (key, op_name, op_version))

        if op_version is not None:
            kwargs['op_version'] = op_version
        # This class can only be created by a user. Let's check
        # types are either a variable, an operator or an array.
        for i, a in enumerate(args):
            if isinstance(a, tuple):
                if len(a) != 2:
                    raise TypeError(  # pragma: no cover
                        "Input %r is a tuple for class %r, it must have two "
                        "elements (name, type) not %r." % (i, class_name, a))
                if not isinstance(a[0], str):
                    raise TypeError(  # pragma: no cover
                        "Input %r is a tuple for class %r, it must be a tuple "
                        "(name, type) not %r." % (i, class_name, a))
                continue
            if not isinstance(a, (
                    Variable, OnnxOperator, numpy.ndarray, str,
                    OnnxOperatorItem, coo_matrix)):
                raise TypeError(  # pragma: no cover
                    "Unexpected type %r for input %r of operator %r. "
                    "It must be an instance of Variable (or a string), "
                    "OnnxOperator, OnnxOperatorItem, numpy.ndarray, "
                    "coo_matrix." % (
                        type(a), i, class_name))
        OnnxOperator.__init__(self, *args, **kwargs)

    newclass = type(class_name, (OnnxOperator,),
                    {"__init__": __init__, '__doc__': doc,
                     'expected_inputs': inputs,
                     'expected_outputs': outputs,
                     'operator_name': op_name,
                     'input_range': input_range,
                     'output_range': output_range,
                     'domain': domain,
                     'is_deprecated': deprecated,
                     'since_version': since_version,
                     'past_version': past_version,
                     'attr_names': attr_names,
                     'op_type': op_name,
                     '__module__': __name__})
    return newclass

254 

255 

def _dynamic_class_creation(operator_names=None, cache=False, include_past=False,
                            verbose=0, fLOG=print):
    """
    Automatically generates classes for each of the operators
    module *onnx* defines and described at
    `Operators
    <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
    and `Operators
    <https://github.com/onnx/onnx/blob/master/docs/
    Operators-ml.md>`_.

    :param operator_names: list of operators to request or None for all,
        every element is either an operator name (``'Add'``, ``'Add_7'``)
        or a tuple ``(domain, name)``
    :param cache: if True, extracts the documentation from the onnx
        package and saves it on disk; otherwise a documentation file
        already saved on disk is reused when it exists
    :param include_past: includes past versions if operator_names is None
    :param verbose: display some progress
    :param fLOG: logging function
    :return: list of requested operators as a tuple, sorted by their
        position in *operator_names*
    """
    def _c(obj, label, i):
        # (name, type string) describing one input or output of a schema.
        name = '%s%d' % (obj.name or label, i)
        tys = obj.typeStr or ''
        return (name, tys)

    cache_dir = cache_folder()
    if operator_names is None:
        operator_names = list(_all_schemas_versions)
        if include_past:
            add = []
            for domain, op in operator_names:
                add.extend(
                    [(domain, k)
                     for k in _all_schemas_versions[domain, op]])
            operator_names.extend(add)
        operator_names.sort()

    # type verification
    ops = []
    for name in operator_names:
        if isinstance(name, str):
            if name.startswith('Onnx'):
                raise ValueError(
                    "Operator name cannot start with Onnx: %r." % name)
            domain = _find_operator_domain(name.split('_', maxsplit=1)[0])
            ops.append((domain, name))
        elif isinstance(name, tuple) and len(name) == 2:
            if name[1].startswith('Onnx'):
                raise ValueError(  # pragma: no cover
                    "Operator name cannot start with Onnx: %r." % name)
            ops.append(name)
        else:
            raise ValueError(  # pragma: no cover
                "Operator to fetch must be a string or a "
                "`tuple(domain, name)` not %r." % (name))
    operator_names = ops

    # versions
    res = _all_schemas
    cls = {}
    # (domain, name) -> position in the request, -1 means the class
    # is only created because a versioned class needs its main class.
    set_names = dict()
    # main operators implied by a versioned request (only displayed
    # when verbose > 1)
    set_skip = set()
    for pos, (op_domain, op_name) in enumerate(operator_names):
        if op_domain == 'ai.onnx':
            op_domain = ''
        set_names[op_domain, op_name] = pos
        if '_' in op_name and not include_past:
            n = op_name.split('_')[0]
            set_skip.add((op_domain, n))
            # The key is a tuple: checking the bare name would always
            # succeed and overwrite the position of an operator the user
            # explicitly requested (e.g. loadop('Add', 'Add_7')).
            if (op_domain, n) not in set_names:
                set_names[op_domain, n] = -1

    if verbose > 1 and fLOG is not None:
        fLOG(  # pragma: no cover
            "[_dynamic_class_creation] set_names=%r" % set_names)
        fLOG(  # pragma: no cover
            "[_dynamic_class_creation] set_skip=%r" % set_skip)

    returned_classes = []
    positions = {}

    for (op_domain, op_name), position in set_names.items():
        cl_name = 'Onnx' + _domain_to_class_name(op_domain) + op_name
        if verbose > 3 and fLOG is not None:
            fLOG(  # pragma: no cover
                '[_dynamic_class_creation] cl_name=%r op_domain=%r op_name=%r (in=%d)' % (
                    cl_name, op_domain, op_name, 1 if cl_name in _all_classes else 0))
        if cl_name in _all_classes:
            # Already created by a previous call, only requested
            # classes (position >= 0) are returned.
            if position >= 0:
                returned_classes.append((position, _all_classes[cl_name]))
            continue

        # operator name without domain
        if '_' in op_name:
            names = [op_name]
        else:
            try:
                names = _all_schemas_versions[op_domain, op_name].copy()
            except KeyError as e:  # pragma: no cover
                raise ValueError(
                    "Operator %r (domain=%r) does not exists." % (
                        op_name, op_domain)) from e
            names.add(op_name)

        if verbose > 0 and fLOG is not None:
            fLOG(  # pragma: no cover
                "[_dynamic_class_creation] op_domain=%r op_name=%r, cl_name=%r names=%r"
                "" % (op_domain, op_name, cl_name, names))

        for name in names:
            try:
                schema = res[op_domain, name]
            except KeyError as e:
                raise ValueError(
                    "Operator (%r, %r) does not exists (available=%r)" % (
                        op_domain, name, pprint.pformat(list(res)))) from e
            inputs = [_c(o, 'I', i) for i, o in enumerate(schema.inputs)]
            outputs = [_c(o, 'O', i) for i, o in enumerate(schema.outputs)]
            args = list(schema.attributes)

            if '_' in name:
                class_name = "Onnx" + _domain_to_class_name(op_domain) + name
            else:
                class_name = (
                    "Onnx" + _domain_to_class_name(op_domain) + schema.name)

            if verbose > 0 and fLOG is not None:
                fLOG(  # pragma: no cover
                    "[_dynamic_class_creation] op_name=%r, cl_name=%r cache=%r"
                    "" % (op_name, class_name, cache))

            filename = os.path.join(
                cache_dir,
                schema.name + '_' + str(schema.since_version) + ".rst")
            if not cache and os.path.exists(filename):
                with open(filename, "r", encoding="utf-8") as f:  # pragma: no cover
                    doc = f.read()
            else:
                doc = get_rst_doc(schema)
                if cache:  # pragma: no cover
                    with open(filename, 'w', encoding='utf-8') as f:
                        f.write(doc)

            cl = ClassFactory(class_name, schema.name, inputs, outputs,
                              [schema.min_input, schema.max_input],
                              [schema.min_output, schema.max_output],
                              schema.domain, args,
                              "**Version**" + doc.split('**Version**')[-1],
                              getattr(schema, 'deprecated', False),
                              schema.since_version, {})
            cls[class_name] = cl
            if name == op_name:
                positions[class_name] = position

    # Retrieves past classes: OnnxAdd_7 is registered in
    # OnnxAdd.past_version.
    for name in cls:  # pylint: disable=C0206
        if '_' not in name:
            continue
        main = name.rsplit('_', 1)[0]
        if main in cls:  # pylint: disable=R1715
            last = cls[main]
        else:
            last = _all_classes[main]
        last.past_version[name] = cls[name]

    # final
    _all_classes.update(cls)
    for cl_name, v in cls.items():
        # Classes only created to support a versioned request have
        # no position (or -1) and are not returned.
        if positions.get(cl_name, -1) >= 0:
            returned_classes.append((positions[cl_name], v))

    returned_classes.sort()
    return tuple(e[1] for e in returned_classes)

429 

430 

def loadop(*names, cache=False, verbose=0, fLOG=print):
    """
    Dynamically creates a class for every operator type in
    the given list.

    :param names: operator names (``'Add'``, ``'Add_7'``) or tuples
        ``(domain, name)``
    :param cache: forwarded to the class generator, see
        @see fn _dynamic_class_creation
    :param verbose: display some progress
    :param fLOG: logging function
    :return: the class itself when exactly one class is produced,
        a tuple of classes otherwise
    """
    classes = _dynamic_class_creation(
        names, cache=cache, verbose=verbose, fLOG=fLOG)
    return classes[0] if len(classes) == 1 else classes

441 

442 

class OnnxLoadFactory:
    """
    Automatically creating all operators from onnx packages
    takes time. That's why function @see cl loadop only creates
    classes for the requested operators. This class does the same
    when an attribute is requested.

    ::

        cl = OnnxLoadFactory()
        x = cl.Add(...)

    It is equivalent to:

    ::

        OnnxAdd = loadop('Add')
        x = OnnxAdd(...)
    """

    def __init__(self):
        # Cache: requested name and generated class name -> class.
        self._loaded_classes = {}

    def __getattr__(self, name):
        """
        Enables expressions such as:

        ::

            ops = OnnxLoadFactory()
            op = ops.Abs('X')

        Names starting with an underscore are not treated as operator
        names and raise :class:`AttributeError`. This keeps
        :func:`hasattr`, :mod:`copy` and :mod:`pickle` working (they probe
        special names such as ``__setstate__``). It also avoids an infinite
        recursion when ``_loaded_classes`` is missing because ``__init__``
        was not called.
        """
        if name.startswith('_'):
            raise AttributeError(
                "%r object has no attribute %r." % (
                    type(self).__name__, name))
        loaded = self._loaded_classes
        if name in loaded:
            return loaded[name]
        cl = loadop(name)
        loaded[name] = cl
        loaded[cl.__name__] = cl
        return cl

483 

484 

class OnnxOperatorBase:
    """
    Base class for @see cl OnnxOperator, @see cl OnnxOperatorItem,
    @see cl OnnxOperatorTuple.

    Subclasses override :meth:`add_to`, :attr:`output_names`,
    :meth:`find_named_inputs` and :meth:`f`; the implementations
    below only raise :class:`NotImplementedError`.
    """

    def __init__(self):
        "Nothing to initialize at this level."

    def add_to(self, builder):
        "This method should be overwritten."
        message = "Not overwritten for class %r." % type(self)
        raise NotImplementedError(message)  # pragma: no cover

    @property
    def output_names(self):
        "This method should be overwritten."
        message = "Not overwritten for class %r." % type(self)
        raise NotImplementedError(message)  # pragma: no cover

    def find_named_inputs(self):
        """
        Returns all inputs to the graph.
        """
        message = (
            "Method 'find_named_inputs' must be overloaded for type %s."
            % type(self))
        raise NotImplementedError(message)  # pragma: no cover

    def f(self, *args, **kwargs):
        """
        Evaluates this node.
        """
        message = "Method 'f' must be overloaded for type %s." % type(self)
        raise NotImplementedError(message)  # pragma: no cover

519 

520 

class OnnxOperatorItem(OnnxOperatorBase):
    """
    Accessor to one of the output returned by a @see cl OnnxOperator.

    :param onx_op: @see cl OnnxOperator
    :param index: integer, position of the output
    :param op_version: defines the opset version
    """

    def __init__(self, onx_op, index, op_version=None):
        OnnxOperatorBase.__init__(self)
        if not isinstance(index, int):
            raise TypeError(  # pragma: no cover
                "index must be an integer not %r." % type(index))
        logger.debug("OnnxOperatorItem(%r, %d, op_version=%r)",
                     onx_op, index, op_version)
        if not isinstance(onx_op, OnnxOperatorBase):
            raise TypeError(  # pragma: no cover
                "onx_op must be an OnnxOperator not %r." % type(onx_op))
        self.onx_op, self.index, self.op_version = onx_op, index, op_version

    @property
    def output_names(self):
        "Returns None."
        return None

    @property
    def inputs(self):
        "Returns the only inputs in a list."
        single = NodeResultName(self.onx_op, self.index)
        return [single]

    def add_to(self, builder):
        """
        Adds to graph builder.
        Does nothing because the original node is already added.

        :param builder: instance of @see cl _GraphBuilder,
            it must have a method `add_node`
        """

    def __str__(self):
        "usual"
        return "%s[%d]" % (self.onx_op, self.index)

    def __repr__(self):
        "usual"
        owner = self.onx_op.__class__.__name__
        return "%s(%s[%d])" % (self.__class__.__name__, owner, self.index)

    def get_output_result(self, i=0):
        """
        Returns the output name at position *i*,
        only position 0 is valid for an item.
        """
        if i == 0:
            return self.onx_op.get_output_result(self.index)
        raise IndexError(  # pragma: no cover
            "Can only return the first item.")

    def find_named_inputs(self):
        """
        Returns all inputs to the graph.
        """
        return self.onx_op.find_named_inputs()

    def f(self, *inputs, verbose=0, fLOG=None,  # pylint: disable=W0221
          clear_cache=False, runtime=None):
        """
        Computes the predictions for this node.
        Similar to an eager evaluation.

        :param inputs: inputs as dictionary or a list of inputs
            (see below)
        :param verbose: display information while predicting
        :param fLOG: logging function if *verbose > 0*
        :param clear_cache: onnx graph is created once unless
            this parameter is True
        :param runtime: runtime to use for the evaluation,
            see @see cl OnnxInference
        :return: outputs as a dictionary if the input were given as a
            dictionary or a single result or a tuple otherwise

        The inputs refer to the inputs of the graph.
        The method walks through all inputs and finds inputs defined as
        string. It replaces them by the value found in the dictionary.
        If the inputs are specified in a list, the function retrieves the
        list of inputs defined as a string and assigns them a value.
        Logging function can be used to get more insight about it.
        During the evaluation every node is independently converted
        into ONNX. The ONNX graph is cached in the class itself.
        """
        outputs = self.onx_op.f(*inputs, verbose=verbose, fLOG=fLOG,
                                clear_cache=clear_cache, runtime=runtime)
        if not isinstance(outputs, dict):
            return outputs[self.index]
        declared = self.onx_op.output_names
        if declared is None:
            # No user-defined names: falls back on the schema outputs
            # which are tuples (name, type).
            key = self.onx_op.expected_outputs[self.index][0]
        else:
            key = declared[self.index]
        return {key: outputs[key]}

627 

628 

class OnnxOperatorTuple(OnnxOperatorBase):
    """
    Class used to return multiple @see cl OnnxVar
    at the same time.

    It stores either one object in attribute *unique* (built with
    a single argument) or a tuple of objects in attribute *values*
    (built with several arguments), the other attribute is None.
    """

    def __init__(self, first, *args):
        OnnxOperatorBase.__init__(self)
        logger.debug("%s([%r], %d in)",
                     self.__class__.__name__, type(first), len(args))
        if isinstance(first, (list, tuple)):
            raise TypeError(  # pragma: no cover
                "Unexpected type for first %r." % type(first))
        if len(args) > 0:
            self.values = (first,) + args
            self.unique = None
        else:
            self.values = None
            self.unique = first
        # Only reachable when first is None and no other value is given.
        if self.values is None and self.unique is None:
            raise RuntimeError(  # pragma: no cover
                "Unexpected configuration. One member (values or unique) must be "
                "not null.")

    def __repr__(self):
        "usual"
        if self.values is None:
            return "%s(%r)" % (self.__class__.__name__, type(self.unique))
        return "%s(%s)" % (self.__class__.__name__, ", ".join(
            "%r" % type(v) for v in self.values))

    @property
    def inputs(self):
        "Returns the only inputs in a list."
        if self.values is None:
            return [self.unique]
        raise NotImplementedError(  # pragma: no cover
            "OnnxOperatorTuple.inputs is missing.")

    def add_to(self, builder):
        """
        Adds to graph builder.
        Does nothing because the original node is already added.

        :param builder: instance of @see cl _GraphBuilder,
            it must have a method `add_node`
        """
        pass

    def __len__(self):
        "usual"
        if self.values is None:
            raise NotImplementedError(  # pragma: no cover
                "Not yet implemented in this case unique=%r, "
                "values=%r." % (self.unique, self.values))
        return len(self.values)

    def __iter__(self):
        "Iterates on the outputs."
        if self.values is None:
            raise NotImplementedError(  # pragma: no cover
                "Not yet implemented in this case.")
        yield from self.values

    def __getitem__(self, i):
        "usual"
        if self.values is None:
            return self.unique[i]
        return self.values[i]

    @property
    def outputs(self):
        "Returns 'outputs' of attribute 'unique'."
        if self.values is None:
            if hasattr(self.unique, 'to_onnx'):
                return self.unique.outputs
        raise NotImplementedError(  # pragma: no cover
            "Not implemented yet unique=%r values=%r." % (
                self.unique, self.values))

    @property
    def output_names(self):
        "Returns 'output_names' of attribute 'unique'."
        if self.values is None:
            if hasattr(self.unique, 'to_onnx'):
                return self.unique.output_names
        raise NotImplementedError(  # pragma: no cover
            "Not implemented yet unique=%r values=%r." % (
                self.unique, self.values))

    @output_names.setter
    def output_names(self, value):
        """
        Updates 'output_names' of attribute 'unique'
        or every output name of attribute 'values'.
        When 'unique' is set and several names are given, every output
        of 'unique' is wrapped into an Identity node and the result is
        stored in 'values'.
        """
        logger.debug("OnnxOperatorTuple:output_names:set(%r)", value)
        OnnxIdentity = loadop('Identity')
        if self.values is None:
            if (hasattr(self.unique, 'to_onnx') or
                    hasattr(self.unique, 'add_to')):
                if len(value) > 1:
                    self.values = tuple(
                        OnnxIdentity(
                            self.unique[i], output_names=value[i:i + 1],
                            op_version=self.unique.op_version)
                        for i in range(len(value)))
                    self.unique = None
                    return
                self.unique.output_names = [Variable(v) for v in value]
                return
            raise NotImplementedError(  # pragma: no cover
                "Not implemented yet, value=%r, unique=%r values=%r." % (
                    value, self.unique, self.values))
        if self.values is not None and len(self.values) == len(value):
            for name, v in zip(value, self.values):
                v.output_names = [Variable(name)]
            return
        raise NotImplementedError(  # pragma: no cover
            "Not implemented yet, value=%r, unique=%r values=%r." % (
                value, self.unique, self.values))

    def to_onnx(self, inputs=None, outputs=None,
                other_outputs=None, target_opset=None,
                optim=True, verbose=0, run_shape=True):
        """
        Converts this operator into an ONNX graph.
        It follows the same signature as :meth:`OnnxOperator.to_onnx
        <mlprodict.npy.xop.OnnxOperator.to_onnx>` and calls this
        method of the unique input object or the first one
        if there are several. In that case, other inputs in
        attribute `values` are moved into container
        `other_outputs`.
        """
        if self.values is None:
            return self.unique.to_onnx(
                inputs=inputs, outputs=outputs, other_outputs=other_outputs,
                target_opset=target_opset, optim=optim, verbose=verbose,
                run_shape=run_shape)
        new_other_outputs = self.values[1:]
        if other_outputs is not None:
            # self.values is a tuple, tuples have no method extend:
            # concatenation builds the merged container.
            new_other_outputs = new_other_outputs + tuple(other_outputs)
        return self.values[0].to_onnx(
            inputs=inputs, outputs=outputs, other_outputs=new_other_outputs,
            target_opset=target_opset, optim=optim, verbose=verbose,
            run_shape=run_shape)

781 

782 

783class OnnxOperator(OnnxOperatorBase): 

784 """ 

785 Ancestor to every *ONNX* operator exposed in 

786 :mod:`mlprodict.npy.xops` and :mod:`mlprodict.npy.xops_ml`. 

787 

788 :param inputs: list of inputs expected by the operator 

789 :param op_version: to select a specific version of the operator 

790 :param output_names: used defined names for the outputs 

791 :param domain: to overwrite the default domain 

792 :param global_context: operator *If* executes one subgraph 

793 whose nodes may use one existing output in the current 

794 context. If not used in the main graph, these operators 

795 are not linked to the output and cannot be retrieved. 

796 *global_context* is a dictionary mapped the subgraph input 

797 names to these operators. 

798 :param kwargs: additional parameters of the operator 

799 

800 .. versionadd:: 0.9 

801 """ 

802 @classmethod 

803 def __class_getitem__(cls, opset): 

804 """ 

805 Enables expression `cls[opset]`. It returns the appropriate class 

806 `cls_opset`. Parameter *op_version* should be specified. 

807 """ 

808 if not isinstance(opset, int): 

809 raise ValueError( 

810 "opset must an integer not %r." % type(opset)) 

811 best = None 

812 for _, v in cls.past_version.items(): 

813 if v.since_version == opset: 

814 return lambda *args, **kwargs: v( 

815 *args, op_version=opset, **kwargs) 

816 if v.since_version <= opset and ( 

817 best is None or best.since_version < v.since_version): 

818 best = v 

819 if best is None: 

820 raise ValueError( 

821 "Unable to find a version of operator %r and opset %r." % ( 

822 cls.__name__, opset)) 

823 return lambda *args, **kwargs: best( 

824 *args, op_version=opset, **kwargs) 

825 

826 def __init__(self, *inputs, op_version=None, output_names=None, 

827 domain=None, global_context=None, **kwargs): 

828 

829 OnnxOperatorBase.__init__(self) 

830 logger.debug("%s(%d in, op_version=%r, output_names=%r)", 

831 self.__class__.__name__, len(inputs), op_version, 

832 output_names) 

833 if (output_names is None and 

834 self.__class__.__name__.startswith("OnnxScan")): 

835 raise NotImplementedError( 

836 "The class cannot infer the number of variables " 

837 "for node '{}' yet. output_names must be specified" 

838 ".".format(self.__class__.__name__)) 

839 if isinstance(output_names, (str, Variable)): 

840 output_names = [output_names] 

841 if isinstance(output_names[0], str): 

842 output_names[0] = Variable(output_names[0]) 

843 elif isinstance(output_names, list): 

844 if len(output_names) == 0: 

845 raise ValueError( # pragma: no cover 

846 "output_names cannot be empty (operator %r)." 

847 "" % self.__class__.__name__) 

848 output_names = output_names.copy() 

849 for i in range(len(output_names)): # pylint: disable=C0200 

850 if isinstance(output_names[i], str): 

851 output_names[i] = Variable(output_names[i]) 

852 elif output_names is not None: 

853 raise TypeError( # pragma: no cover 

854 "output_names must be a string or a list not %r." 

855 "" % type(output_names)) 

856 

857 if op_version is None: 

858 if domain == '': 

859 self.op_version = max_supported_opset() 

860 else: 

861 self.op_version = None 

862 else: 

863 self.op_version = op_version 

864 self.since_version = self.__class__.since_version 

865 

866 if (self.op_version is not None and 

867 self.op_version < self.since_version): 

868 schema = self.find_schema(self.op_version) 

869 self.since_version = schema.since_version 

870 self.expected_inputs = schema.expected_inputs.copy() 

871 self.expected_outputs = schema.expected_outputs.copy() 

872 self.input_range = schema.input_range 

873 self.output_range = schema.output_range 

874 else: 

875 self.expected_inputs = ( 

876 None if self.__class__.expected_inputs is None 

877 else self.__class__.expected_inputs.copy()) 

878 self.expected_outputs = ( 

879 None if self.__class__.expected_outputs is None 

880 else self.__class__.expected_outputs.copy()) 

881 self.input_range = self.__class__.input_range 

882 self.output_range = self.__class__.output_range 

883 if self.__class__.__name__ not in { 

884 'OnnxScan', 'OnnxLoop', 'OnnxIf'}: 

885 # The minimum opset depends on embedded graph 

886 # by default, it takes the given op_version but the 

887 # optimal value could be lower. 

888 self.op_version = self.since_version 

889 if self.op_version is None: 

890 self.op_version = self.since_version 

891 

892 if (self.op_version is not None and 

893 self.op_version < self.since_version): 

894 raise RuntimeError( # pragma: no cover 

895 "Operator '{}': requested version {} < " 

896 "{} schema version.".format( 

897 self.__class__.__name__, 

898 self.op_version, self.since_version)) 

899 

900 self.state = None 

901 self.domain = domain 

902 self.kwargs = kwargs 

903 self.max_item_ = None 

904 

905 # check inputs 

906 self.inputs = [] 

907 if len(inputs) > 0: 

908 for inp in inputs: 

909 if isinstance(inp, str): 

910 self.inputs.append(Variable(inp)) 

911 elif isinstance(inp, tuple): 

912 if len(inp) != 2: 

913 raise RuntimeError( # pragma: no cover 

914 "Unexpected tuple %r." % (inp, )) 

915 self.inputs.append( 

916 Variable(inp[0], dtype=guess_numpy_type(inp[1]), 

917 shape=inp[1].shape)) 

918 elif isinstance(inp, (OnnxOperatorBase, Variable)): 

919 self.inputs.append(inp) 

920 elif isinstance(inp, (numpy.ndarray, coo_matrix, TensorProto)): 

921 self.inputs.append(inp) 

922 elif isinstance(inp, ValueInfoProto): 

923 self.inputs.append(inp.type.tensor_type) 

924 else: 

925 raise TypeError( # pragma: no cover 

926 "Unable to interpret the input name for type {} in " 

927 "operator '{}' (value={}).".format( 

928 type(inp), self.__class__.__name__, inp)) 

929 

930 if (self.inputs is not None and 

931 (len(self.inputs) < self.input_range[0] or 

932 len(self.inputs) > self.input_range[1])): 

933 raise RuntimeError( # pragma: no cover 

934 "Operator '{}' expects a number of inputs in [{}, {}] not {} " 

935 "(expected opset={}, class opset={})".format( 

936 getattr(self, 'operator_name', '?'), *self.input_range, 

937 len(self.inputs), op_version, self.op_version)) 

938 # global context 

939 if global_context is None: 

940 self.global_context = None 

941 else: 

942 if not isinstance(global_context, dict): 

943 raise TypeError( # pragma: no cover 

944 "global_context must be a dictionary not %r." 

945 "" % type(global_context)) 

946 for k, v in global_context.items(): 

947 if not isinstance(v, OnnxOperatorBase): 

948 raise TypeError( # pragma: no cover 

949 "Value %r in must be an OnnxOperatorBase not %r." 

950 "" % (k, type(v))) 

951 self.global_context = global_context 

952 

953 # check output 

954 self.output_names_ = output_names 

955 self.output_variables = None 

956 

957 if self.output_names is not None: 

958 if len(self.output_names) == 0: 

959 raise ValueError( # pragma: no cover 

960 "output_names can be None but cannot be empty for " 

961 "operator %r." % self) 

962 if self.output_variables is None: 

963 self.output_variables = [None for o in self.output_names] 

964 for i in range(len(self.output_names)): # pylint: disable=C0200 

965 name = self.output_names[i] 

966 if isinstance(name, Variable): 

967 self.output_variables[i] = name 

968 else: 

969 raise TypeError( # pragma: no cover 

970 "output_names must be a list of strings " 

971 "and element %r is %r (%r)" % ( 

972 i, type(name), name)) 

973 if all(map(lambda x: x is None, self.output_variables)): 

974 self.output_variables = None 

975 

976 if (self.output_names is not None and ( 

977 self.expected_outputs is None or 

978 len(self.output_names) > len(self.expected_outputs))): 

979 if self.expected_outputs is None: 

980 self.expected_outputs = [] 

981 for i in range(len(self.expected_outputs), 

982 len(self.output_names)): 

983 self.expected_outputs.append((self.output_names[i], None)) 

984 

985 if (self.expected_inputs is None or 

986 len(self.inputs) > len(self.expected_inputs)): 

987 if self.expected_inputs is None: 

988 self.expected_inputs = [] 

989 for i in range(len(self.expected_inputs), 

990 len(self.inputs)): 

991 inp = self.inputs[i] 

992 if isinstance(inp, str): 

993 inp = (inp, None) 

994 elif hasattr(inp, 'add_to'): 

995 # OnnxOperator 

996 existing = set(_[0] for _ in self.expected_inputs) 

997 i = 10 

998 name = "input%d" % (10 + i) 

999 while name in existing: 

1000 i += 1 

1001 name = "input%d" % (10 + i) 

1002 inp = (name, None) 

1003 self.expected_inputs.append(inp) 

1004 

1005 self._post_process_attributes() 

1006 self._check() 

1007 

@property
def output_names(self):
    """
    Returns `self.output_names_`, the output variables of this node
    (None when they were not specified).
    """
    return self.output_names_

1012 

@output_names.setter
def output_names(self, value):
    """
    Replaces `self.output_names_` by *value* after logging the
    new value at debug level with the module logger.
    """
    logger.debug("OnnxOperator:output_names:set(%r)", value)
    self.output_names_ = value

1017 

def _check(self):
    """
    Verifies the types of the stored inputs and output names.

    :raises TypeError: if one input is not a @see cl Variable,
        an @see cl OnnxOperatorBase, a numpy array or a TensorProto,
        or if one output name is not a @see cl Variable
    """
    allowed = (Variable, OnnxOperatorBase, numpy.ndarray, TensorProto)
    if not all(isinstance(item, allowed) for item in self.inputs):
        raise TypeError(  # pragma: no cover
            "Wrong type for inputs %r." % (
                self.inputs, ))
    names = self.output_names
    if names is not None and not all(
            isinstance(item, Variable) for item in names):
        raise TypeError(  # pragma: no cover
            "Wrong type for output_names %r." % (
                self.output_names, ))

1032 

def _post_process_attributes(self):
    """
    Walks through attributes and replaces them by ONNX values.

    * An attribute given as a tuple ``(GraphProto, OnnxOperator)`` is
      replaced by the GraphProto. The operator it comes from is kept
      in ``self.graph_algebra`` under the same attribute name.
    * For ``OnnxConstantOfShape``, attribute ``value`` may be given as
      a numpy array of one element, a 0-d array or a numpy scalar. It
      is converted into a TensorProto holding a one-element array.
    * For ``OnnxCast``, attribute ``to`` given as something other than
      an integer is converted with ``numpy_type_prototype``.

    :raises RuntimeError: ``value`` has more than one element
        (``OnnxConstantOfShape``)
    :raises TypeError: ``value`` has an unexpected type
        (``OnnxConstantOfShape``)
    :raises ValueError: ``to`` cannot be converted (``OnnxCast``)
    """
    # Looks into attributes if there is any tuple
    # (GraphProto, OnnxOperator). In that case, the function
    # replaces the tuple by the graph proto and keeps
    # in attributes graph_algebra the OnnxOperator
    # which is the source of it.
    updates = {}
    graph_algebra = {}
    for k, v in self.kwargs.items():
        # The length check avoids an IndexError on short tuples such as
        # ``axes=()`` which are regular attribute values.
        if (isinstance(v, tuple) and len(v) >= 2 and
                isinstance(v[0], GraphProto)):
            updates[k] = v[0]
            graph_algebra[k] = v[1]

    if len(graph_algebra) > 0:
        self.kwargs.update(updates)
        self.graph_algebra = graph_algebra

    if self.__class__.__name__ == "OnnxConstantOfShape":
        if "value" in self.kwargs:
            value = self.kwargs['value']
            if isinstance(value, (numpy.ndarray, numpy.generic)):
                if value.shape == (1, ):
                    val = value[0]
                elif len(value.shape) == 0:
                    val = value
                else:
                    # value.shape must be wrapped into a tuple, otherwise
                    # a multi-dimensional shape breaks the formatting.
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected shape %r for value, it must be "
                        "an array of one element." % (value.shape, ))
                self.kwargs['value'] = from_array(
                    numpy.array([val], dtype=value.dtype))
                return
            if isinstance(value, TensorProto):
                return
            raise TypeError(  # pragma: no cover
                "Unexpected type %r for value. It should be an array "
                "of one element." % type(value))
        return

    if self.__class__.__name__ == "OnnxCast":
        if "to" in self.kwargs:
            value = self.kwargs['to']
            if not isinstance(value, int):
                try:
                    to = numpy_type_prototype(value)
                except ValueError as e:  # pragma: no cover
                    raise ValueError(
                        "Unable to convert argument to in operator cast, "
                        "type is %r, value is %r." % (type(value), value)) from e
                self.kwargs['to'] = to
        return

1087 

def update_max_item(self, index):
    """
    Some operators return an undefined number of outputs.

    The method is called when one of them is requested (with
    `__getitem__`). It keeps the greatest requested index, assuming
    the node does not output any result beyond that index. It then
    pads `expected_outputs` with placeholders
    ``(("NEWOUTPUT", position), None)`` up to that index.

    :param index: requested index
    """
    if self.max_item_ is None:
        highest = index
    else:
        highest = max(self.max_item_, index)
    self.max_item_ = highest
    if self.expected_outputs is None:
        self.expected_outputs = []
    known = self.expected_outputs
    known.extend(
        (("NEWOUTPUT", position), None)
        for position in range(len(known), highest + 1))

1106 

def find_schema(self, op_version):
    """
    Looks among `past_version` for the most recent schema whose
    `since_version` does not exceed *op_version*.

    :param op_version: requested version
    :return: schema (one of the values of `past_version`)
    :raises RuntimeError: if the class has no attribute `past_version`
        or if no schema is old enough
    """
    if not hasattr(self.__class__, 'past_version'):
        raise RuntimeError(  # pragma: no cover
            "Missing attribute 'past_version', there is "
            "no other available schema.")
    eligible = [schema for schema in self.past_version.values()
                if schema.since_version <= op_version]
    if not eligible:
        raise RuntimeError(  # pragma: no cover
            "Operator '{}': requested version {} < "
            "{} schema version (past_version {}).".format(
                self.__class__.__name__,
                op_version, self.since_version,
                [v.since_version for v in self.past_version.values()]))
    # max keeps the first schema among equal versions, like the
    # strict comparison of a manual scan would.
    return max(eligible, key=lambda schema: schema.since_version)

1132 

def __repr__(self):
    """
    usual, displays the class name, the number of inputs and
    the output names (``?`` if they are unknown)
    """
    n_inputs = 0 if self.inputs is None else len(self.inputs)
    if self.output_names is None:
        outs = "?"
    else:
        outs = [str(o) for o in self.output_names]
    return "{}({} in) -> {}".format(self.__class__.__name__, n_inputs, outs)

1142 

def get_output_result(self, i=0):
    """
    Returns the output name at position *i*,
    as a @see cl NodeResultName bound to this node.
    """
    result = NodeResultName(self, i)
    return result

1148 

def __getitem__(self, index):
    """
    Returns an accessor (@see cl OnnxOperatorItem) to one of the
    outputs of this node. The greatest requested index is recorded
    with `update_max_item`.
    """
    self.update_max_item(index)
    item = OnnxOperatorItem(self, index, self.op_version)
    return item

1156 

def __iter__(self):
    """
    Allows expressions such as ``a, b = OnnxTopK(...)``.

    The number of outputs is the number of output names if defined,
    otherwise the fixed size of `output_range` when it is positive.
    It is extended up to the greatest index requested so far
    with `__getitem__`.
    """
    if self.output_names is not None:
        count = len(self.output_names)
    elif (self.output_range[0] == self.output_range[1] and
            self.output_range[0] > 0):
        count = self.output_range[0]
    else:
        count = None
    if self.max_item_ is not None:
        upper = self.max_item_ + 1
        count = upper if count is None else max(count, upper)
    if count is None:
        raise RuntimeError(  # pragma: no cover
            "Unable to guess the number of outputs of node type %r. "
            "Uses operator [] to select a specific output." %
            self.__class__.__name__)
    for position in range(count):
        yield self[position]

1179 

def add_to(self, builder):
    """
    Adds to graph builder.

    The number of outputs comes from `output_names`, then
    `expected_outputs`, then the lower bound of `output_range`.

    :param builder: instance of @see cl _GraphBuilder,
        it must have a method `add_node`
    """
    logger.debug("%s.add_to(builder)", self.__class__.__name__)
    input_names = builder.get_input_names(self, self.inputs)
    if self.output_names is not None:
        count = len(self.output_names)
    elif self.expected_outputs is not None:
        count = len(self.expected_outputs)
    else:
        count = self.output_range[0]
    output_names = [
        builder.get_unique_output_name(NodeResultName(self, position))
        for position in range(count)]
    node_name = builder.get_unique_name(
        '_' + self.operator_name.lower(), reserved=False)
    builder.add_node(
        self.operator_name, node_name, input_names, output_names,
        domain=self.domain, opset=self.op_version, **self.kwargs)

1203 

@staticmethod
def _node_to_graph_preprocess_list(inputs):
    """
    Converts a list of inputs into an ordered dictionary
    ``{name: Variable}``.

    :param inputs: list of strings, @see cl Variable or
        ``(name, type)`` pairs (sklearn-onnx style)
    :return: OrderedDict
    :raises TypeError: for any other element
    """
    converted = OrderedDict()
    for item in inputs:
        if isinstance(item, str):
            converted[item] = Variable(item)
            continue
        if isinstance(item, Variable):
            converted[item.name] = item
            continue
        if not (isinstance(item, tuple) and len(item) == 2):
            raise TypeError(  # pragma: no cover
                "Unable to handle input type %r (%r)." % (type(item), item))
        # sklearn-onnx
        name, proto = item
        converted[name] = Variable(
            name, guess_numpy_type(proto), proto.shape)
    return converted

1220 

@staticmethod
def _node_to_graph_process_input(inputs, set_inputs, node, inp,
                                 new_inputs, new_stack, inputs_dtype,
                                 as_function=False):
    """
    Processes one input *inp* of *node* while walking the graph.

    Operators are pushed onto *new_stack* to be visited later.
    Named variables are registered once in *set_inputs* and appended
    to *new_inputs* as @see cl InputDetectedVariable, with a type
    merged from *inputs* or *inputs_dtype* when available.
    Numpy arrays are ignored.

    :param inputs: None, a dictionary ``{name: type}`` or a
        @see cl Variable describing the graph inputs
    :param set_inputs: names of the inputs already detected,
        updated in place
    :param node: node consuming *inp*
    :param inp: input to process
    :param new_inputs: list of detected inputs, updated in place
    :param new_stack: nodes to visit next, updated in place
    :param inputs_dtype: dtype given to every input if not None
    :param as_function: allows *inputs* and *inputs_dtype* to be
        both None
    :raises RuntimeError: missing type information or unexpected *inputs*
    :raises ValueError: *inp* is missing from dictionary *inputs*
    :raises TypeError: unexpected type for *inp*
    :raises NotImplementedError: *inp* is an @see cl OnnxOperatorTuple
    """
    if not as_function and inputs is None and inputs_dtype is None:
        raise RuntimeError(  # pragma: no cover
            "Both inputs and inputs_dtype cannot be None at the same time "
            "for inp=%r." % (inp, ))
    if isinstance(inp, OnnxOperator):
        new_stack.append(inp)
        return
    if isinstance(inp, OnnxOperatorItem):
        # both the item and the operator it comes from are visited
        new_stack.append(inp)
        new_stack.append(inp.onx_op)
        return
    if isinstance(inp, OnnxOperatorTuple):
        raise NotImplementedError(  # pragma: no cover
            "Unable to guess inputs when one input is OnnxOperatorTuple.")
    if not isinstance(inp, Variable):
        if isinstance(inp, numpy.ndarray):
            return
        raise TypeError(  # pragma: no cover
            "Unexpected input type %r in node type %r." % (
                type(inp), type(node)))

    if inp.name in set_inputs:
        return
    set_inputs.add(inp.name)
    if inputs is None and inputs_dtype is None:
        detected = inp
    elif isinstance(inputs, dict):
        if inp.name not in inputs:
            raise ValueError(  # pragma: no cover
                "Unable to find input %r in %r." % (
                    inp, inputs))
        detected = inp.copy_merge(inputs[inp.name])
    elif inputs_dtype is not None:
        detected = inp.copy_add(inputs_dtype)
    elif isinstance(inputs, Variable):
        detected = (inp.copy_merge(inputs) if inp.name == inputs.name
                    else inp)
    else:
        raise RuntimeError(  # pragma: no cover
            "Unable to handle inputs=%r." % inputs)
    new_inputs.append(InputDetectedVariable(node, detected))

1273 

@staticmethod
def _node_to_graph_get_type(node, name=None, outputs=None,
                            outputs_dtype=None):
    """
    Retrieves the type of one output from the information given
    by the user.

    :param node: node producing the output (not used)
    :param name: output name, a string or a @see cl Variable
    :param outputs: None, a @see cl Variable, a dictionary
        ``{name: type}`` or a numpy dtype
    :param outputs_dtype: default type
    :return: type or None if it cannot be determined
    :raises RuntimeError: unexpected combination of arguments
    :raises NotImplementedError: *outputs* is a list
    """
    if outputs is None:
        return outputs_dtype
    if isinstance(outputs, Variable):
        if name is None:
            return outputs.dtype or outputs_dtype
        if not isinstance(name, Variable):
            raise RuntimeError(  # pragma: no cover
                "Unable to handle outputs=%r." % outputs)
        return outputs.dtype or name.dtype or outputs_dtype
    if isinstance(outputs, dict):
        if name is None:
            raise RuntimeError(  # pragma: no cover
                "Unable to get type among %r, name=None." % (
                    outputs, ))
        key = name.name if isinstance(name, Variable) else name
        return outputs[key] if key in outputs else None
    if isinstance(outputs, list):
        raise NotImplementedError(  # pragma: no cover
            "Unexpected type for name=%r, outputs=%r." % (
                name, outputs))
    if is_numpy_dtype(outputs):
        return outputs
    raise RuntimeError(  # pragma: no cover
        "Unable to handle outputs=%r." % outputs)

1307 

@staticmethod
def _node_to_graph_reorder_by_name(new_inputs, inputs):
    """
    Reorders *new_inputs* so that elements whose name appears in
    *inputs* come first, in the order of *inputs*. The other elements
    follow in their original order.

    :param new_inputs: objects with an attribute `name`
    :param inputs: objects with an attribute `name` giving the order
    :return: reordered list
    """
    by_name = OrderedDict((item.name, item) for item in new_inputs)
    requested = [inp.name for inp in inputs]
    ordered = [by_name[n] for n in requested if n in by_name]
    placed = set(requested)
    ordered.extend(item for n, item in by_name.items() if n not in placed)
    return ordered

1322 

def _node_to_graph(self, other_outputs=None, inputs=None, outputs=None,
                   as_function=False):
    """
    Builds a graph as a list of nodes to walk through in that order.

    :param other_outputs: additional nodes (besides *self*) whose
        outputs become graph outputs
    :param inputs: graph inputs, a list (of strings, @see cl Variable
        or ``(name, type)`` pairs), a dictionary, a single
        @see cl Variable or a numpy dtype
    :param outputs: graph outputs, same accepted types as *inputs*
    :param as_function: relaxes the check requiring type information
        for the inputs (the graph is meant to become a function)
    :return: tuple ``(nodes, new_inputs, new_outputs, run_shape)``,
        *run_shape* is True if at least one output type could not be
        determined
    :raises TypeError: unexpected type for *inputs* or *outputs*
    :raises RuntimeError: duplicated output names or no output
    """

    node_outputs = [self]
    if other_outputs is not None:
        node_outputs += other_outputs

    logger.debug("%s._node_to_graph:inputs=%r",
                 self.__class__.__name__, inputs)
    logger.debug("%s._node_to_graph:outputs=%r",
                 self.__class__.__name__, outputs)

    # preprocess inputs, outputs
    _keep_inputs = None
    inputs_dtype = None
    if isinstance(inputs, list):
        _keep_inputs = inputs
        inputs_dict = self._node_to_graph_preprocess_list(inputs)
    elif isinstance(inputs, dict):
        inputs_dict = inputs
    elif isinstance(inputs, Variable):
        inputs = [inputs]
        inputs_dict = self._node_to_graph_preprocess_list(inputs)
    elif is_numpy_dtype(inputs):
        inputs_dtype = inputs
        inputs_dict = None
    else:
        raise TypeError(  # pragma: no cover
            "Unexpected type %r for inputs." % type(inputs))

    _keep_outputs = None
    outputs_dtype = None
    if isinstance(outputs, list):
        _keep_outputs = outputs
        outputs_dict = self._node_to_graph_preprocess_list(outputs)
    elif isinstance(outputs, dict):
        outputs_dict = outputs
    elif isinstance(outputs, Variable):
        outputs = [outputs]
        outputs_dict = self._node_to_graph_preprocess_list(outputs)
    elif is_numpy_dtype(outputs):
        outputs_dtype = outputs
        outputs_dict = None
    else:
        raise TypeError(  # pragma: no cover
            "Unexpected type %r for outputs." % type(outputs))

    logger.debug("%s._node_to_graph:inputs=%r",
                 self.__class__.__name__, inputs)
    logger.debug("%s._node_to_graph:outputs=%r",
                 self.__class__.__name__, outputs)
    logger.debug("%s._node_to_graph:inputs_dict=%r",
                 self.__class__.__name__, inputs_dict)
    logger.debug("%s._node_to_graph:outputs_dict=%r",
                 self.__class__.__name__, outputs_dict)
    logger.debug("%s._node_to_graph:inputs_dtype=%r",
                 self.__class__.__name__, inputs_dtype)
    logger.debug("%s._node_to_graph:outputs_dtype=%r",
                 self.__class__.__name__, outputs_dtype)

    # walk through graph
    stack = list(node_outputs)
    new_inputs = []
    set_inputs = set()
    memo = []
    while len(stack) > 0:
        memo.extend(stack)
        new_stack = []
        for obj in stack:
            if isinstance(obj, OnnxOperatorItem):
                # nothing to do, OnnxOperatorItem is created
                # by OnnxOperator.__getitem__.
                pass
            elif isinstance(obj, (OnnxOperator, OnnxOperatorTuple)):
                for inp in obj.inputs:
                    self._node_to_graph_process_input(
                        inputs_dict, set_inputs, obj, inp, new_inputs,
                        new_stack, inputs_dtype, as_function=as_function)
            else:
                raise TypeError(  # pragma: no cover
                    "Unexpected type %r." % type(obj))
        stack = new_stack

    # Reorders new_inputs to follow the initial order of inputs.
    # The names come from inputs_dict: the initial list may contain
    # strings or (name, type) tuples, which have no attribute name.
    # A duplicated name is considered only once.
    if _keep_inputs is not None:
        new_inputs = self._node_to_graph_reorder_by_name(
            new_inputs, list(inputs_dict.values()))

    logger.debug("%s._node_to_graph:new_inputs=%r",
                 self.__class__.__name__, new_inputs)

    # eliminate duplicates
    done = set()
    nodes = []
    for node in reversed(memo):
        if id(node) in done:
            continue
        done.add(id(node))
        nodes.append(node)

    # outputs
    set_names = set()
    new_outputs = []
    run_shape = False
    for node in node_outputs:
        if node.output_names is None:
            # NOTE(review): self.output_range is used even when node is
            # one of other_outputs, confirm it is intended.
            n = self.output_range[0]
            for i in range(n):
                to = self._node_to_graph_get_type(
                    node, outputs=outputs_dict,
                    outputs_dtype=outputs_dtype)
                if to is None:
                    run_shape = True
                res = '???_%d' % i
                var = Variable(res, added_dtype=to)
                if var.name in set_names:
                    raise RuntimeError(  # pragma: no cover
                        "Duplicated output name var=%r." % var)
                set_names.add(var.name)
                new_outputs.append(OutputDetectedVariable(node, var, i))
        else:
            for i, o in enumerate(node.output_names):
                if isinstance(o, str):
                    raise TypeError(  # pragma: no cover
                        "Output %d - %r (%r) not allowed in node %r." % (
                            i, o, node.output_names, node))
                to = self._node_to_graph_get_type(
                    node, o, outputs=outputs_dict,
                    outputs_dtype=outputs_dtype)
                if to is None:
                    run_shape = True
                var = o.copy_merge(to)
                if var.name in set_names:
                    raise RuntimeError(  # pragma: no cover
                        "Duplicated output name o=%r var=%r." % (o, var))
                set_names.add(var.name)
                new_outputs.append(OutputDetectedVariable(node, var, i))
    if len(new_outputs) == 0:
        raise RuntimeError(  # pragma: no cover
            "No detected outputs inputs=%r outputs=%r." % (
                inputs_dict, outputs_dict))

    # reorder new_outputs to follow outputs initial order,
    # names are taken from outputs_dict for the same reason as inputs
    if _keep_outputs is not None:
        new_outputs = self._node_to_graph_reorder_by_name(
            new_outputs, list(outputs_dict.values()))

    logger.debug("%s._node_to_graph:new_outputs=%r",
                 self.__class__.__name__, new_outputs)

    return nodes, new_inputs, new_outputs, run_shape

1478 

def to_onnx(self, inputs=None, outputs=None,
            other_outputs=None, target_opset=None,
            optim=True, verbose=0, run_shape=True,
            function_name=None, function_domain=None,
            fLOG=print):
    """
    Converts this operator into an ONNX graph.

    :param inputs: information about type, it should not be None
    :param outputs: information about types, if None, the function
        will use shape inference to guess the final output type
        and shape
    :param other_outputs: additional nodes to consider
        as graph outputs but not outputs of this particular
        node
    :param target_opset: dictionary with target opset per domain,
        None for the default one
    :param optim: optimize the model with function
        @see fn onnx_optimisations
    :param run_shape: in case output shapes are not specified,
        the function runs function :epkg:`infer_shapes`
        to guess them, False would disable that
        default behaviour
    :param verbose: prints information
    :param function_name: if not None, returns a :epkg:`FunctionProto`
    :param function_domain: in case of a function, declares the function
        as part of this domain
    :param fLOG: logging function
    :return: ONNX structure
    :raises TypeError: *target_opset* is neither a dictionary,
        an integer nor None
    :raises RuntimeError: *target_opset* is 1 or lower than the
        version of this node, or the node list is empty
    """
    logger.debug(
        "%s.to_onnx(%r, %r, other_outputs=%r, target_opset=%r, as_function=%r)",
        self.__class__.__name__, inputs, outputs,
        other_outputs, target_opset, function_name)
    cls_name = self.__class__.__name__
    default_domain = self.domain in ('', None)

    # opsets
    if isinstance(target_opset, dict):
        target_opset = target_opset.get(self.domain or '', None)
    elif isinstance(target_opset, int):
        if not default_domain:
            # The target_opset is for the domain '' we ignore it.
            target_opset = None
    elif target_opset is not None:
        raise TypeError(  # pragma: no cover
            "target_opset must be a dictionary {domain: "
            "target_opset} not %r for operator %r." % (
                target_opset, cls_name))

    if default_domain and target_opset == 1:
        raise RuntimeError(  # pragma: no cover
            "target_opset cannot be 1.")
    too_recent = (self.op_version is not None and
                  target_opset is not None and
                  self.op_version > target_opset)
    if too_recent:
        raise RuntimeError(  # pragma: no cover
            "target_opset={} is lower than the version={} requested "
            "for this node '{}'.".format(
                target_opset, self.op_version, cls_name))

    # get the graph
    as_function = function_name is not None
    nodes, graph_inputs, graph_outputs, need_shape = self._node_to_graph(
        other_outputs, inputs, outputs, as_function=as_function)
    logger.debug("%s.to_onnx:graph_inputs=%r", cls_name, graph_inputs)
    logger.debug("%s.to_onnx:graph_outputs=%r", cls_name, graph_outputs)
    if not nodes:
        raise RuntimeError(  # pragma: no cover
            "Node list is empty.")
    if verbose > 1:
        for i, n in enumerate(nodes):  # pragma: no cover
            fLOG("nodes[%d]=%r" % (i, n))
        for i, n in enumerate(graph_inputs):  # pragma: no cover
            fLOG("graph_inputs[%d]=%r" % (i, n))

    builder = _GraphBuilder()

    # reserves the names of the variable inputs, walking nodes
    # in reverse order
    input_vars = (var for node in reversed(nodes) for var in node.inputs
                  if isinstance(var, Variable))
    for var in input_vars:
        logger.debug("%s.to_onnx:_add_name(%r)", cls_name, var.name)
        builder._add_name(var.name)

    # reserves output names, walking nodes in reverse order as well
    for node in reversed(nodes):
        builder.reserve_names(node, node.output_names)

    total = len(nodes)
    for index, node in enumerate(nodes):
        logger.debug("%s.to_onnx:node:%d/%d:%r",
                     cls_name, index, total, node)

    # adds every node to the builder
    for node in nodes:
        node.add_to(builder)

    return builder.to_onnx(
        inputs=graph_inputs, outputs=graph_outputs,
        target_opset=target_opset, verbose=verbose,
        optim=optim, run_shape=run_shape and need_shape,
        function_name=function_name, function_domain=function_domain)

1581 

def predecessors(self):
    """
    Returns the list of predecessors, starting with this node and
    walking the inputs level by level. Only inputs which are
    @see cl OnnxOperatorBase are followed. A node reachable through
    several paths appears several times.

    :return: list of @see cl OnnxOperator
    """
    stack = [self]
    start = 0
    while start < len(stack):
        stop = len(stack)
        for node in stack[start:stop]:
            stack.extend(inp for inp in node.inputs
                         if isinstance(inp, OnnxOperatorBase))
        start = stop
    return stack

1601 

def __call__(self, *args, function_name=None, function_domain=None,
             **kwargs):
    """
    Creates an instance of class @see cl OnnxOperatorFunction.
    Equivalent to `OnnxOperatorFunction(proto, *args, **kwargs)`.

    :param args: see @see cl OnnxOperatorFunction
    :param function_name: name to be given to the function, if None,
        it is the concatenation of the class names of all
        predecessors without their prefix `Onnx`
    :param function_domain: function domain, if None,
        it is given a default value
    :param kwargs: see @see cl OnnxOperatorFunction
    :return: instance of type @see cl OnnxOperatorFunction
    """
    if function_name is None:
        prefix = "Onnx"
        parts = []
        for pred in self.predecessors():
            name = pred.__class__.__name__
            parts.append(
                name[len(prefix):] if name.startswith(prefix) else name)
        function_name = "".join(parts)
    proto = self.to_onnx(function_name=function_name,
                         function_domain=function_domain)
    return OnnxOperatorFunction(proto, *args, **kwargs)

1627 

def find_named_inputs(self):
    """
    Retrieves all named inputs in this graph, each name once,
    in the order they are first met (depth first through
    the operator inputs). Numpy arrays are ignored.

    :return: list of names
    :raises RuntimeError: unexpected input type
    """
    seen = set()
    names = []

    def _register(name):
        # keeps the first occurrence only
        if name not in seen:
            names.append(name)
            seen.add(name)

    for inp in self.inputs:
        if isinstance(inp, str):
            _register(inp)
        elif isinstance(inp, Variable):
            _register(inp.name)
        elif isinstance(inp, OnnxOperatorBase):
            for sub_name in inp.find_named_inputs():
                _register(sub_name)
        elif isinstance(inp, numpy.ndarray):
            continue
        else:
            raise RuntimeError(  # pragma: no cover
                "Unexpected input type %r." % type(inp))
    return names

1655 

def to_onnx_this(self, evaluated_inputs):
    """
    Returns a simple ONNX graph holding only this node, with one
    graph input per evaluated input (named ``I0``, ``I1``, ...).

    :param evaluated_inputs: inputs as a list (each with a `dtype`)
    :return: ONNX graph
    :raises NotImplementedError: if neither `output_names` nor
        `expected_outputs` is defined
    """
    input_names = ['I%d' % k for k, _ in enumerate(evaluated_inputs)]
    if self.output_names is not None:
        output_names = [var.name for var in self.output_names]
    elif self.expected_outputs is not None:
        output_names = [pair[0] for pair in self.expected_outputs]
    else:
        raise NotImplementedError(
            "expected_outputs and output_names are not defined.")
    node = make_node(self.op_type, input_names, output_names,
                     domain=self.domain, name="f", **self.kwargs)
    onx_inputs = [Variable(name, value.dtype).make_value_info()
                  for name, value in zip(input_names, evaluated_inputs)]
    # outputs are left untyped (elem_type 0, empty shape)
    onx_outputs = [make_value_info(name, make_tensor_type_proto(0, []))
                   for name in output_names]
    graph = make_graph([node], 'f', onx_inputs, onx_outputs)
    opset = make_operatorsetid(self.domain or '', self.since_version)
    return make_model(graph, opset_imports=[opset])

1682 

def run(self, *inputs, verbose=0, fLOG=None, clear_cache=False, runtime=None):
    """
    Other name for
    `OnnxInference.f <mlprodict.onnxrt.onnx_inference.OnnxInference.f>`_.
    Every argument is forwarded to method `f`.
    """
    options = dict(verbose=verbose, fLOG=fLOG,
                   clear_cache=clear_cache, runtime=runtime)
    return self.f(*inputs, **options)

1690 

def f(self, *inputs, verbose=0, fLOG=None,  # pylint: disable=W0221
      clear_cache=False, runtime=None):
    """
    Computes the predictions for this node.
    Similar to an eager evaluation.

    :param inputs: inputs as dictionary or a list of inputs
        (see below)
    :param verbose: display information while predicting
    :param fLOG: logging function if *verbose > 0*,
        :func:`print` is used if *verbose > 0* and *fLOG* is None
    :param clear_cache: onnx graph is created once unless
        this parameter is True
    :param runtime: runtime to use for the evaluation,
        see @see cl OnnxInference
    :return: outputs as a dictionary if the input were given as a
        dictionary, a single result if there is only one output,
        a list otherwise

    The inputs refer to the inputs of the graph.
    The method walks through all inputs and finds inputs defined as
    string. It replaces them by the value found in the dictionary.
    If the inputs are specified in a list, the function retrieves the
    list of inputs defined as a string and assigns them a value.
    Logging function can be used to get more insight about it.
    During the evaluation every node is independently converted
    into ONNX. The ONNX graph is cached in the instance, once per
    runtime and per list of (dtype, shape) of the evaluated inputs.
    """
    if verbose > 0 and fLOG is None:
        # avoids calling None when verbosity is requested
        fLOG = print

    # input evaluation
    if len(inputs) == 1 and isinstance(inputs[0], dict):
        dict_inputs = inputs[0]
        as_dict = True
    elif not isinstance(inputs, (tuple, list)):
        raise TypeError(  # pragma: no cover
            "inputs must be a list not %r." % type(inputs))
    elif len(inputs) > 0 and isinstance(inputs[0], OnnxOperator):
        raise TypeError(  # pragma: no cover
            "Unexpected type for inputs[0]: %r." % type(inputs[0]))
    else:
        as_dict = False
        if verbose > 0:
            fLOG(  # pragma: no cover
                "[OnnxOperator.f] retrieves named inputs")
        if hasattr(self, "feval_named_inputs_"):
            named_inputs = self.feval_named_inputs_  # pylint: disable=E0203
        else:
            named_inputs = self.find_named_inputs()
            self.feval_named_inputs_ = named_inputs
        if len(named_inputs) != len(inputs):
            raise RuntimeError(
                "Mismatch between the number of found inputs (%d) and "
                "the number of given inputs (%d) (found %r)."
                "" % (
                    len(named_inputs), len(inputs), named_inputs))
        dict_inputs = {
            name: value for name, value in zip(named_inputs, inputs)}
        if verbose > 0:
            fLOG(  # pragma: no cover
                "[OnnxOperator.f] found inputs: %r" % (named_inputs, ))

    # conversion
    evaluated_inputs = []
    for i, inp in enumerate(self.inputs):
        if isinstance(inp, str):
            evaluated_inputs.append(dict_inputs[inp])
        elif isinstance(inp, Variable):
            evaluated_inputs.append(dict_inputs[inp.name])
        elif isinstance(inp, OnnxOperatorBase):
            if verbose > 0:
                fLOG(  # pragma: no cover
                    "[OnnxOperator.f] evaluate input %d (op_type=%r)" % (
                        i, self.__class__.op_type))
            out = inp.f(dict_inputs, verbose=verbose, fLOG=fLOG)
            if isinstance(out, dict):
                if len(out) == 1:
                    evaluated_inputs.append(out.popitem()[1])
                else:
                    raise NotImplementedError(
                        "Not yet implemented in case when there are multiple "
                        "outputs (%r)." % list(out))
            elif isinstance(out, list):
                evaluated_inputs.extend(out)
            else:
                evaluated_inputs.append(out)
        elif isinstance(inp, numpy.ndarray):
            evaluated_inputs.append(inp)
        else:
            raise RuntimeError(  # pragma: no cover
                "Unexpected type %r for input %d." % (type(inp), i))

    # conversion to ONNX
    if not hasattr(self, 'feval_onnx_'):
        self.feval_onnx_ = {}
    signature = tuple((m.dtype, m.shape) for m in evaluated_inputs)
    # The runtime is part of the key: an OnnxInference built for one
    # runtime must not be reused when another runtime is requested.
    key = (runtime, signature)
    if key not in self.feval_onnx_ or clear_cache:
        if verbose > 0:
            fLOG("[OnnxOperator.f] creating node %r, inputs=%r" % (
                self.op_type, signature))
        from ..onnxrt import OnnxInference
        model = self.to_onnx_this(evaluated_inputs)
        oinf = OnnxInference(model, runtime=runtime)
        self.feval_onnx_[key] = oinf
    else:
        oinf = self.feval_onnx_[key]

    # execution
    if verbose > 0:
        fLOG("[OnnxOperator.f] execute node %r" % self.op_type)
    got = oinf.run({k: v for k, v in
                    zip(oinf.input_names, evaluated_inputs)})
    if as_dict:
        return got
    if len(got) == 1:
        return got.popitem()[1]
    return [got[n] for n in oinf.output_names]

1804 

@staticmethod
def _merge_op_version(n1, n2):
    """
    Chooses the opset version of an operator combining *n1* and *n2*.

    If *n2* is an operator, the result is the greatest defined
    `op_version` (None if none is defined). If *n2* is an
    @see cl OnnxOperatorItem, the operator behind the item is used.
    For anything else, the result is `n1.op_version`.

    :param n1: first operand (presumably an @see cl OnnxOperator)
    :param n2: second operand
    :return: opset version or None
    :raises NotImplementedError: *n2* is an @see cl OnnxOperatorTuple
    """
    if isinstance(n2, OnnxOperator):
        defined = [v for v in (n1.op_version, n2.op_version)
                   if v is not None]
        return max(defined) if defined else None
    if isinstance(n2, OnnxOperatorItem):
        return OnnxOperator._merge_op_version(n1, n2.onx_op)
    if isinstance(n2, OnnxOperatorTuple):
        raise NotImplementedError(  # pragma: no cover
            "_merge_op_version is not implemented when n2 "
            "is OnnxOperatorTuple.")
    return n1.op_version

1825 

def __add__(self, ov):
    """
    Inserts operator `OnnxAdd` in the graph, `self + ov`.

    :param ov: onnx node, right operand
    :return: `OnnxAdd(self, ov)`
    """
    add_cls = loadop('Add')
    return add_cls(self, ov, op_version=self._merge_op_version(self, ov))

1836 

def __sub__(self, ov):
    """
    Inserts operator `OnnxSub` in the graph, `self - ov`.

    :param ov: onnx node, right operand
    :return: `OnnxSub(self, ov)`
    """
    sub_cls = loadop('Sub')
    return sub_cls(self, ov, op_version=self._merge_op_version(self, ov))

1847 

def __mul__(self, ov):
    """
    Inserts operator `OnnxMul` in the graph, `self * ov`.

    :param ov: onnx node, right operand
    :return: `OnnxMul(self, ov)`
    """
    mul_cls = loadop('Mul')
    return mul_cls(self, ov, op_version=self._merge_op_version(self, ov))

1858 

def __truediv__(self, ov):
    """
    Inserts operator `OnnxDiv` in the graph, `self / ov`.

    :param ov: onnx node, right operand
    :return: `OnnxDiv(self, ov)`
    """
    div_cls = loadop('Div')
    return div_cls(self, ov, op_version=self._merge_op_version(self, ov))

1869 

def __pow__(self, ov):
    """
    Automatically adds operator `OnnxPow` to the graph, `self ** ov`.

    :param ov: onnx node, the exponent
    :return: `OnnxPow(self, ov)`
    """
    OnnxPow = loadop('Pow')
    opv = self._merge_op_version(self, ov)
    return OnnxPow(self, ov, op_version=opv)

1880 

def __mod__(self, ov):
    """
    Inserts operator `OnnxMod` in the graph, `self % ov`.

    :param ov: onnx node, right operand
    :return: `OnnxMod(self, ov)`
    """
    mod_cls = loadop('Mod')
    return mod_cls(self, ov, op_version=self._merge_op_version(self, ov))

1891 

def __matmul__(self, ov):
    """
    Automatically adds operator `OnnxMatMul` to the graph, `self @ ov`.

    :param ov: onnx node, right operand
    :return: `OnnxMatMul(self, ov)`
    """
    OnnxMatMul = loadop('MatMul')
    opv = self._merge_op_version(self, ov)
    return OnnxMatMul(self, ov, op_version=opv)

1902 

def __gt__(self, ov):
    """
    Inserts operator `OnnxGreater` in the graph, `self > ov`.

    :param ov: onnx node, right operand
    :return: `OnnxGreater(self, ov)`
    """
    greater_cls = loadop('Greater')
    return greater_cls(self, ov, op_version=self._merge_op_version(self, ov))

1913 

def __lt__(self, ov):
    """
    Inserts operator `OnnxLess` in the graph, `self < ov`.

    :param ov: onnx node, right operand
    :return: `OnnxLess(self, ov)`
    """
    less_cls = loadop('Less')
    return less_cls(self, ov, op_version=self._merge_op_version(self, ov))

1924 

def __eq__(self, ov):
    """
    Inserts operator `OnnxEqual` in the graph, `self == ov`.
    It returns a graph node, not a boolean.

    :param ov: onnx node, right operand
    :return: `OnnxEqual(self, ov)`
    """
    equal_cls = loadop('Equal')
    return equal_cls(self, ov, op_version=self._merge_op_version(self, ov))

1935 

def and_(self, ov):
    """
    Inserts operator `OnnxAnd` in the graph (logical and).

    :param ov: onnx node, right operand
    :return: `OnnxAnd(self, ov)`
    """
    and_cls = loadop('And')
    return and_cls(self, ov, op_version=self._merge_op_version(self, ov))

1946 

def or_(self, ov):
    """
    Inserts operator `OnnxOr` in the graph (logical or).

    :param ov: onnx node, right operand
    :return: `OnnxOr(self, ov)`
    """
    or_cls = loadop('Or')
    return or_cls(self, ov, op_version=self._merge_op_version(self, ov))

1957 

def __ne__(self, ov):
    """
    Inserts `OnnxNot(OnnxEqual(self, ov))` in the graph, `self != ov`.
    ONNX has no dedicated *NotEqual* operator, hence the composition.

    :param ov: onnx node, right operand
    :return: `OnnxNot(OnnxEqual(self, ov))`
    """
    not_cls, equal_cls = loadop('Not', 'Equal')
    version = self._merge_op_version(self, ov)
    equality = equal_cls(self, ov, op_version=version)
    return not_cls(equality, op_version=version)

1968 

def __abs__(self):
    """
    Automatically adds operator `OnnxAbs` to the graph, `abs(self)`.
    The new node uses the same `op_version` as *self*.

    :return: `OnnxAbs(self)`
    """
    OnnxAbs = loadop('Abs')
    return OnnxAbs(self, op_version=self.op_version)

1978 

def not_(self):
    """
    Automatically adds operator `OnnxNot` to the graph (logical not).
    The new node uses the same `op_version` as *self*.

    :return: `OnnxNot(self)`
    """
    OnnxNot = loadop('Not')
    return OnnxNot(self, op_version=self.op_version)

1988 

def astype(self, to):
    """
    Automatically adds operator `OnnxCast` to the graph.
    The new node uses the same `op_version` as *self*.

    :param to: value of attribute *to* of operator `Cast`
        (the element type to cast into)
    :return: `OnnxCast(self, to=to)`
    """
    OnnxCast = loadop('Cast')
    return OnnxCast(self, to=to, op_version=self.op_version)

1998 

1999 

class OnnxOperatorFunction(OnnxOperator):
    """
    This operator is used to insert an existing ONNX function into
    the ONNX graph being built.

    :param function_proto: instance of :epkg:`FunctionProto`
    :param inputs: inputs of the function, no more than
        `len(function_proto.input)`
    :param output_names: output names, if not None, it must contain
        exactly `len(function_proto.output)` names
    """

    since_version = 1
    expected_inputs = None
    expected_outputs = None
    input_range = [1, 1e9]
    output_range = [1, 1e9]
    op_type = 'Function'
    domain = 'mlprodict.xop'

    @staticmethod
    def attribute_to_value(att):
        """
        Converts an attribute into a value using python structures.

        :param att: instance of :epkg:`AttributeProto`
        :return: the attribute value, a scalar, bytes, a proto or a list
            of them depending on the attribute type
        :raises NotImplementedError: if *att* is not an AttributeProto
            or if its type is not supported
        """
        if not isinstance(att, onnx.AttributeProto):
            raise NotImplementedError(  # pragma: no cover
                "Unable to copy attribute type %r." % type(att))
        dtype = att.type
        kinds = onnx.AttributeProto
        # Single values.
        if dtype == kinds.FLOAT:
            return att.f
        if dtype == kinds.INT:
            return att.i
        if dtype == kinds.STRING:
            return att.s
        if dtype == kinds.TENSOR:
            return att.t
        if dtype == kinds.GRAPH:
            return att.g
        if dtype == kinds.SPARSE_TENSOR:
            # AttributeProto has no 'double_data' field; a sparse tensor
            # attribute is stored in 'sparse_tensor'.
            return att.sparse_tensor
        # Repeated values, copied into python lists.
        if dtype == kinds.FLOATS:
            return list(att.floats)
        if dtype == kinds.INTS:
            return list(att.ints)
        if dtype == kinds.STRINGS:
            return list(att.strings)
        if dtype == kinds.TENSORS:
            return list(att.tensors)
        if dtype == kinds.GRAPHS:
            return list(att.graphs)
        if dtype == kinds.SPARSE_TENSORS:
            return list(att.sparse_tensors)
        raise NotImplementedError(  # pragma: no cover
            "Unable to copy attribute type %r (%r)." % (
                dtype, att))

    def __init__(self, function_proto, *inputs, output_names=None):
        logger.debug("Function(ONNX, %d in, output_names=%r)",
                     len(inputs), output_names)
        if function_proto is None:
            raise ValueError(
                "function_proto cannot be None.")  # pragma: no cover
        if not isinstance(function_proto, onnx.FunctionProto):
            raise TypeError(  # pragma: no cover
                "function_proto must be of type FunctionProto not %r." %
                type(function_proto))
        if len(inputs) > len(function_proto.input):
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of inputs %r > expected %r." % (
                    len(inputs), len(function_proto.input)))
        if (output_names is not None and
                len(output_names) != len(function_proto.output)):
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of outputs %r != expected %r." % (
                    len(output_names), len(function_proto.output)))
        OnnxOperator.__init__(self, *inputs, output_names=output_names)
        self.model = function_proto

    def __repr__(self):
        "usual"
        atts = {}
        for att in ['output_names']:
            value = getattr(self, att, None)
            if value is not None:
                atts[att] = value
        atts.update(self.kwargs)
        msg = ", ".join("%s=%r" % (k, v) for k, v in atts.items())
        if len(atts) > 0:
            msg = ", " + msg
        return "%s(...%s)" % (
            self.__class__.__name__, msg)

    def add_to(self, builder):
        """
        Adds to graph builder: registers the function in the builder and
        adds one node calling it.

        :param builder: instance of @see cl _GraphBuilder,
            it must have a method `add_node`
        """
        logger.debug("Function.add_to(builder)")
        # linking inputs and naming outputs
        inputs = builder.get_input_names(self, self.inputs)
        n_outputs = len(self.model.output)
        outputs = [builder.get_unique_output_name(NodeResultName(self, i))
                   for i in range(n_outputs)]

        # the function itself, then the node calling it
        builder.add_function(self.model)
        builder.add_node(
            self.model.name, builder.get_unique_name(
                '_fct_' + self.model.name, reserved=False),
            inputs, outputs, domain=self.model.domain)

2102 

2103 

class _GraphBuilder:
    """
    Graph builder. It takes a graph structure made with
    instances of @see cl OnnxOperatorBase.
    The main method is `to_onnx`.

    * `initializer`: list of initializers to add to the ONNX graph
    * `node`: list of nodes to add to the ONNX graph
    * `input`: list of inputs to add to the ONNX graph
    * `output`: list of outputs to add to the ONNX graph
    * `opsets`: opsets of the ONNX graph `{domain: version}`
    * `input_names`: dictionary of input names
      `{name: InputDetectedVariable}`
    * `node_output_names`: memorizes a name for a node output
      when the user did not specify any
      `{(id(node), index): name}`
    * `reserved_names`: dictionary `{ name : (node, index) }`,
      name which should remain unchanged in the ONNX graph
    * `names`: set of names already taken
    * `functions`: dictionary `{ domain, name: function_proto }`
    * `function_hashes`: dictionary `{ domain, name: hash of function_proto }`
    """

    def __init__(self):
        self.initializer = []
        self.node = []
        self.input = []
        self.output = []
        self.opsets = {}
        self.input_names = {}
        self.node_output_names = {}
        self.reserved_names = {}
        self.names = set()
        self.functions = {}
        self.function_hashes = {}

    def _add_name(self, name):
        "Registers *name* as taken."
        self.names.add(name)

    @staticmethod
    def number2alpha(index):
        """
        Converts a number into a string keeping the same
        alphabetical order. Numbers with more than one digit are
        prefixed by a letter giving the number of digits
        (`'b'` for two, `'c'` for three...), so `'9' < 'b10' < 'c100'`.
        """
        dec = str(int(index))
        if len(dec) == 1:
            return dec
        return chr(96 + len(dec)) + dec

    def reserve_names(self, node, output_names):
        """
        Adds names to the list of reserved names.
        All must be unique.

        :param node: node or None for an input
        :param output_names: names of the output (instances of
            @see cl Variable), nothing is done if None
        :raises TypeError: if one name is not a @see cl Variable
        """
        if output_names is None:
            return
        for index, var in enumerate(output_names):
            if not isinstance(var, Variable):
                raise TypeError(  # pragma: no cover
                    "Unexpected type %r for %r." % (type(var), var))
            self.reserve_name(node, var.name, index)

    def reserve_name(self, node, name, index):
        """
        Reserves a name so that it cannot be changed.

        :param node: node or None for an input
        :param name: name
        :param index: input index
        :raises TypeError: if *name* is not a string
        :raises RuntimeError: if *name* is already reserved
        """
        if not isinstance(name, str):
            raise TypeError(  # pragma: no cover
                "Name %r is not a string." % (name, ))
        if name in self.reserved_names:
            raise RuntimeError(  # pragma: no cover
                "Name %r is already reserved from node %r, index=%d." % (
                    name, node, index))
        logger.debug("_GraphBuilder.reserve_name([%s-%d], %r, %r)",
                     node.__class__.__name__, id(node),
                     name, index)
        self.reserved_names[name] = (node, index)
        self._add_name(name)

    def get_unique_output_name(self, result):
        """
        Returns a unique output_name for a NodeResultName.
        The same name is returned if the method is called twice
        for the same node and index.

        :param result: instance of @see cl NodeResultName
        :return: name
        """
        if not isinstance(result, NodeResultName):
            raise TypeError(  # pragma: no cover
                "Result must be of type NodeResultName not %r (%r)." % (
                    type(result), result))
        if result.node is None:
            key = None, result.index
        else:
            key = id(result.node), result.index
        if key in self.node_output_names:
            return self.node_output_names[key]
        name = result.get_name()
        if name in self.reserved_names:
            unique = name
        else:
            unique = self.get_unique_name(name)
        self.node_output_names[key] = unique
        return unique

    def get_unique_name(self, name, reserved=True):
        """
        Returns a unique name to name an output.

        :param name: name
        :param reserved: bypass if the name is a reserved one
        :return: unique name, may be the same if not taken already,
            otherwise the name followed by `_<suffix>`
            (see @see meth number2alpha)
        """
        if not isinstance(name, str):
            raise TypeError(  # pragma: no cover
                "name must be a string not %r." % type(name))
        if reserved and name in self.reserved_names:
            logger.debug(  # pragma: no cover
                "_GraphBuilder.get_unique_name(%r) 1-> %r", name, name)
            return name
        if name not in self.names:
            self._add_name(name)
            logger.debug("_GraphBuilder.get_unique_name(%r) 2-> %r",
                         name, name)
            return name
        i = 1
        new_name = "%s_%s" % (name, self.number2alpha(i))
        while new_name in self.names:
            i += 1
            new_name = "%s_%s" % (name, self.number2alpha(i))
        self._add_name(new_name)
        logger.debug("_GraphBuilder.get_unique_name(%r) 3-> %r",
                     name, new_name)
        return new_name

    def get_input_names(self, node, inputs):
        """
        Returns input names for node *node* and inputs *inputs*.
        Numpy arrays are added to the graph as initializers.

        :param node: node
        :param inputs: inputs
        :return: list of names
        """
        names = []
        for i in inputs:
            if isinstance(i, Variable):
                self._add_name(i.name)
                names.append(i.name)
                self.input_names[i.name] = InputDetectedVariable(None, i)
            elif isinstance(i, OnnxOperator):
                key = id(i), 0
                try:
                    name = self.node_output_names[key]
                except KeyError as e:  # pragma: no cover
                    raise RuntimeError(
                        "Unable to find key %r for input %r in node %r." % (
                            key, i, node)) from e
                names.append(name)
            elif isinstance(i, OnnxOperatorItem):
                if isinstance(i.onx_op, OnnxOperatorTuple):
                    if i.onx_op.values is None:
                        key = id(i.onx_op.unique), i.index
                    else:
                        key = id(i.onx_op[i.index]), 0
                elif isinstance(i.onx_op, OnnxOperator):
                    key = id(i.onx_op), i.index
                else:
                    raise TypeError(  # pragma: no cover
                        "Unexpected type for OnnxOperatorItem: %r." % type(i.onx_op))
                try:
                    name = self.node_output_names[key]
                except KeyError as e:  # pragma: no cover
                    raise RuntimeError(
                        "Unable to find key %r for input %r in node %r." % (
                            key, i, node)) from e
                names.append(name)
            elif isinstance(i, OnnxOperatorTuple):
                raise NotImplementedError(  # pragma: no cover
                    "An OnnxOperatorTuple cannot be used directly as an "
                    "input of node %r, one of its items is expected." % (
                        node, ))
            elif isinstance(i, numpy.ndarray):
                # Adding an initializer
                name = self.get_unique_name('init', reserved=False)
                init = from_array(i, name)
                self.initializer.append(init)
                names.append(name)
            else:
                raise TypeError(  # pragma: no cover
                    "Unexpected type for an input %r." % type(i))
        return names

    def add_initializer(self, name, init):
        """
        Adds an initializer to the graph.

        :param name: initializer name
        :param init: initializer to copy, a :epkg:`TensorProto`
            or a numpy array
        :return: created initializer
        :raises NotImplementedError: for any other type
        """
        if isinstance(init, onnx.TensorProto):
            tensor = to_array(init)
            val = from_array(tensor, name)
            logger.debug("_GraphBuilder.add_initializer:1(%r, %r, %r)",
                         name, tensor.dtype, tensor.shape)
        elif isinstance(init, numpy.ndarray):
            # to_array expects a TensorProto, a numpy array is
            # converted directly.
            val = from_array(init, name)
            logger.debug("_GraphBuilder.add_initializer:2(%r, %r, %r)",
                         name, init.dtype, init.shape)
        else:
            raise NotImplementedError(  # pragma: no cover
                "Unsupported initializer type %r." % type(init))
        self.initializer.append(val)
        return val

    def add_function(self, function_proto,
                     raise_if_exist=False, check_unique=True,
                     opset=1):
        """
        Adds a function to the graph.

        :param function_proto: instance of type :epkg:`FunctionProto`
        :param raise_if_exist: raises an exception if a function of the
            same name was already added
        :param check_unique: checks if a function was added twice,
            it is the same
        :param opset: opset for the domain the function belongs to
        """
        def _hash(p):
            m = hashlib.sha256()
            m.update(p.SerializeToString())
            return m.hexdigest()[:64]

        key = function_proto.domain, function_proto.name
        if key in self.functions:
            if raise_if_exist:
                raise RuntimeError(  # pragma: no cover
                    "Function %r is added for the second time." % (key, ))
            if check_unique:
                hs = _hash(function_proto)
                if hs != self.function_hashes[key]:
                    raise RuntimeError(  # pragma: no cover
                        "Function %r is added for the second time "
                        "and the content is not the same." % (key, ))
            return
        self.functions[key] = function_proto
        self.function_hashes[key] = _hash(function_proto)

        if function_proto.domain not in self.opsets:
            self.opsets[function_proto.domain] = opset
        else:
            self.opsets[function_proto.domain] = max(
                opset, self.opsets[function_proto.domain])

    def add_node(self, op_type, name, inputs, outputs, domain='',
                 opset=None, **attributes):
        """
        Adds a node to the graph.

        :param op_type: operator type
        :param name: node name
        :param inputs: inputs name list
        :param outputs: outputs name list
        :param domain: node domain
        :param opset: node opset, the opset registered for the domain
            becomes the maximum of all opsets seen
        :param attributes: node attributes
        :return: created node
        """
        if domain is None:
            domain = ''
        logger.debug("_GraphBuilder.add_node(%r, %r, "
                     "inputs=%r, outputs=%r, domain=%r, opset=%r)",
                     op_type, name, inputs, outputs, domain, opset)
        if not isinstance(inputs, list):
            raise TypeError(  # pragma: no cover
                "inputs must be a list not %r." % type(inputs))
        if not isinstance(outputs, list):
            raise TypeError(  # pragma: no cover
                "outputs must be a list not %r." % type(outputs))
        if any(not isinstance(x, str) for x in inputs):
            raise TypeError(  # pragma: no cover
                "inputs must be all strings not %r." % inputs)
        if any(not isinstance(x, str) for x in outputs):
            raise TypeError(  # pragma: no cover
                "outputs must be all strings not %r." % outputs)
        if opset is not None:
            if domain not in self.opsets:
                self.opsets[domain] = opset
            else:
                self.opsets[domain] = max(opset, self.opsets[domain])
        node = make_node(op_type, inputs, outputs, name=name,
                         domain=domain, **attributes)
        self.node.append(node)
        return node

    def _process_io(self, inputs, input_names):
        """
        Builds the value infos for the graph inputs or outputs.

        :param inputs: list of @see cl DetectedVariable
        :param input_names: list or dictionary of detected inputs,
            None when processing outputs, names are then retrieved
            from `node_output_names`
        :return: list of :epkg:`ValueInfoProto`
        """
        if inputs is None:
            return [
                make_tensor_value_info(
                    name, TensorProto.FLOAT, None)  # pylint: disable=E1101
                for name in self.input_names]

        if not isinstance(inputs, list):
            if is_numpy_dtype(inputs):
                inputs = [inputs]

        if input_names is None:
            # outputs
            set_names = set()
            input_names = []
            new_inputs = []
            for inp in inputs:
                if isinstance(inp, OutputDetectedVariable):
                    if inp.name in set_names:
                        raise ValueError(  # pragma: no cover
                            "Names already taken %r in %r." % (
                                inp.name, inputs))
                    set_names.add(inp.name)
                    key = id(inp.node), inp.index
                    if key in self.node_output_names:
                        new_name = self.node_output_names[key]
                        new_var = OutputDetectedVariable(
                            inp.node, inp.var.copy_name(new_name), inp.index)
                        input_names.append(new_var)
                        new_inputs.append(new_var)
                    else:
                        raise RuntimeError(  # pragma: no cover
                            "Key %r is ambiguous or defined in "
                            "two nodes %r, id(node)=%d, index=%d." % (
                                key, inp, id(inp.node), inp.index))
                else:
                    raise TypeError(  # pragma: no cover
                        "Unexpected type %r (it should be "
                        "OutputDetectedVariable) in %r." % (inp, inputs))
            inputs = new_inputs
            if len(input_names) == 0:
                raise RuntimeError(  # pragma: no cover
                    "Unable to cross %r and %r (set_names=%r)." % (
                        inputs, self.node_output_names, set_names))
        elif not isinstance(input_names, list):
            raise RuntimeError(  # pragma: no cover
                "Unexpected type for input_names %r." % type(input_names))
        else:
            # inputs
            pass

        # common parts
        if len(input_names) != len(inputs):
            raise RuntimeError(  # pragma: no cover
                "Mismatch between %r and %r." % (
                    input_names, inputs))

        if isinstance(input_names, list):
            d_input_names = {}
            for inp in input_names:
                if inp.name in d_input_names:
                    raise ValueError(  # pragma: no cover
                        "Duplicated name %r in %r." % (inp.name, input_names))
                d_input_names[inp.name] = inp
        elif isinstance(input_names, dict):
            d_input_names = input_names
        else:
            raise TypeError(  # pragma: no cover
                "Unexpected type for input_names %r (%r)." % (
                    type(input_names), input_names))

        # mapping
        res = []
        for inp in inputs:
            if not isinstance(inp, DetectedVariable):
                raise TypeError(  # pragma: no cover
                    "inp not DetectedVariable but %r (%r)"
                    "." % (type(inp), inp))
            if inp.name.startswith('???'):
                raise RuntimeError(  # pragma: no cover
                    "Issue with variable %r." % inp)
            var = d_input_names[inp.name]
            if not isinstance(var, DetectedVariable):
                raise TypeError(  # pragma: no cover
                    "var not Variable but %r (%r)." % (
                        type(var), var))
            # inp: Variable
            # var: str
            if inp.var != var.var:
                raise RuntimeError(  # pragma: no cover
                    "Unexpected %r != %r." % (inp, var))
            res.append(make_tensor_value_info(
                inp.name, inp.var.proto_added_type,
                inp.var.proto_added_shape))

        return res

    def to_onnx(self, inputs=None, outputs=None,
                target_opset=None, run_shape=False,
                optim=True, function_name=None,
                function_domain=None, verbose=0):
        """
        Converts this operator into an ONNX graph.

        :param inputs: list of @see cl InputDetectedVariable
        :param outputs: list of @see cl OutputDetectedVariable
        :param target_opset: dictionary with target opset per domain,
            None for the default one (only logged by this method)
        :param run_shape: run shape inference before returning the model
        :param optim: optimize the model with function
            @see fn onnx_optimisations
        :param function_name: if not None builds a :epkg:`FunctionProto`
            use this name
        :param function_domain: in case of a function, declares the function
            as part of this domain, `'mlprodict'` if None
        :param verbose: prints information (not used by this method)
        :return: onnx graph (:epkg:`ModelProto`), or :epkg:`FunctionProto`
            if *function_name* is specified
        """
        logger.debug("_GraphBuilder.to_onnx(%r, %r, target_opset=%r)",
                     inputs, outputs, target_opset)
        # inputs and outputs
        if not all(isinstance(x, InputDetectedVariable) for x in inputs):
            raise TypeError(  # pragma: no cover
                "One of the input is not InputDetectedVariable.")
        if not all(isinstance(x, OutputDetectedVariable) for x in outputs):
            raise TypeError(  # pragma: no cover
                "One of the outputs is not OutputDetectedVariable.")
        self.input = self._process_io(inputs, list(self.input_names.values()))
        self.output = self._process_io(outputs, None)
        logger.debug("_GraphBuilder.to_onnx:self.input=%r",
                     [i.name for i in self.input])
        logger.debug("_GraphBuilder.to_onnx:self.output=%r",
                     [i.name for i in self.output])
        logger.debug("_GraphBuilder.to_onnx:build:n_inputs=%r n_inits=%r n_nodes=%r "
                     "n_outputs=%r",
                     len(self.input), len(self.initializer), len(self.node),
                     len(self.output))

        if function_name is not None:
            if function_domain is None:
                function_domain = 'mlprodict'
            if len(self.initializer) > 0:
                # A FunctionProto has no initializers, every initializer
                # becomes a Constant node placed before the other nodes.
                nodes = []
                for init in self.initializer:
                    nodes.append(
                        make_node('Constant', [], [init.name], value=init,
                                  name='_init_%s' % init.name))
                nodes.extend(self.node)
            else:
                nodes = self.node
            fct = make_function(
                function_domain, function_name,
                [_.name for _ in self.input],
                [_.name for _ in self.output],
                nodes,
                [make_opsetid(k, v) for k, v in self.opsets.items()])
            if optim:
                from ..onnx_tools.optim import onnx_optimisations
                fct = onnx_optimisations(fct)
            return fct
        else:
            graph = make_graph(
                self.node, 'XOP', self.input, self.output, self.initializer)
            onnx_model = make_model(
                graph, functions=list(self.functions.values()))
            # ir_version follows the opset of the main domain.
            opv = self.opsets.get('', max_supported_opset())
            opset2ir = _default_OPSET_TO_IR_VERSION()
            irv = opset2ir.get(opv, max(opset2ir.values()))
            onnx_model.ir_version = irv

            logger.debug("_GraphBuilder.to_onnx:2onnx:n_inputs=%r n_inits=%r "
                         "n_nodes=%r n_outputs=%r",
                         len(onnx_model.graph.input),
                         len(onnx_model.graph.initializer),
                         len(onnx_model.graph.node),
                         len(onnx_model.graph.output))

            # replaces the opsets added by make_model with the ones collected
            del onnx_model.opset_import[:]  # pylint: disable=E1101
            seen_opset = set()
            for k, v in self.opsets.items():
                if (k or '') in seen_opset:
                    raise RuntimeError(  # pragma: no cover
                        "Duplicated opset (%r, %r)." % (k, v))
                op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
                op_set.domain = k or ''
                op_set.version = v
                seen_opset.add(op_set.domain)

            # optimisation, remove redundant constant, unnecessary
            # identity nodes.
            if optim:
                from ..onnx_tools.optim import onnx_optimisations
                onnx_model = onnx_optimisations(onnx_model)

            logger.debug("_GraphBuilder.to_onnx:optim:n_inputs=%r n_inits=%r "
                         "n_nodes=%r n_outputs=%r",
                         len(onnx_model.graph.input),
                         len(onnx_model.graph.initializer),
                         len(onnx_model.graph.node),
                         len(onnx_model.graph.output))

            if run_shape:
                with_shape = infer_shapes(onnx_model)
                logger.debug("_GraphBuilder.to_onnx:shape:n_inputs=%r "
                             "n_inits=%r n_nodes=%r n_outputs=%r",
                             len(with_shape.graph.input),
                             len(with_shape.graph.initializer),
                             len(with_shape.graph.node),
                             len(with_shape.graph.output))
                return with_shape

            logger.debug("_GraphBuilder.to_onnx() -> done")
            return onnx_model

2617 

2618 

# Module-level state built once at import time from _populate_schemas;
# judging by the names: operator schemas, their versions and domains.
_all_schemas, _all_schemas_versions, _all_domains = _populate_schemas()
# Starts empty; presumably filled with the operator classes created
# on demand (TODO confirm against loadop / OnnxLoadFactory).
_all_classes = dict()
# One single factory instance exposed under two names.
Xop = OnnxLoadFactory()
onnx_load_factory = Xop