Coverage for mlprodict/onnxrt/ops_cpu/_op.py: 98%


316 statements  

# -*- encoding: utf-8 -*-
"""
@file
@brief Shortcut to *ops_cpu*.
"""
import pprint
import numpy
import onnx
import onnx.defs
from ..shape_object import ShapeObject
from ..type_object import SequenceType
from ._new_ops import OperatorSchema


def _build_schemas():
    res = {}
    for schema in onnx.defs.get_all_schemas_with_history():
        # Multiple versions can coexist. The last one is kept.
        if schema.name in res:
            if schema.since_version > res[schema.name].since_version:
                # We keep the most recent one.
                res[schema.name] = schema
        else:
            res[schema.name] = schema
        res[schema.name + '_' + str(schema.since_version)] = schema
    return res


_schemas = _build_schemas()
_at_least_one = {'Constant'}
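
# A minimal usage sketch (not part of the original file): assuming the
# installed onnx package defines the 'Add' operator, the registry built above
# exposes both the latest schema and each versioned one:
#
#     schemas = _build_schemas()
#     schemas['Add']      # most recent schema for the operator
#     schemas['Add_7']    # schema registered for opset version 7
#     assert schemas['Add'].since_version >= schemas['Add_7'].since_version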


class RuntimeTypeError(RuntimeError):
    """
    Raised when the type of a variable is unexpected.
    """
    pass


class DefaultNone:
    """
    Default value for parameters when the parameter is not set
    but the operator has a default behaviour for it.
    """
    pass


class OpRun:
    """
    Ancestor to all operators in this subfolder.
    The runtime for every node can be checked against the
    `ONNX unit tests
    <https://github.com/onnx/onnx/tree/master/onnx/backend/test/case/node>`_.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        """
        @param      onnx_node               :epkg:`onnx` node
        @param      desc                    internal representation
        @param      expected_attributes     expected attributes for this node
        @param      options                 runtime options
        """
        self._provider = 'python'
        self.onnx_node = onnx_node
        self.desc = desc
        self.inplaces = {}

        if onnx_node.op_type in _schemas:
            self._schema = _schemas[onnx_node.op_type]
        else:
            self._schema = self._find_custom_operator_schema(onnx_node.op_type)
        if self._schema is None:
            raise RuntimeError(  # pragma: no cover
                "Unable to find class name '{}' in available schemas: "
                "(onnx.__version__='{}')\n{}".format(
                    self.__class__.__name__,
                    onnx.__version__,
                    "\n".join(sorted(_schemas))))

        if desc is not None:
            if 'atts' in desc:
                for a, b in desc['atts'].items():
                    if not isinstance(b, dict) or 'value' not in b:
                        raise ValueError(  # pragma: no cover
                            "Unexpected value {}.".format(b))
                    options[a] = (b['value_rt'] if 'value_rt' in b
                                  else b['value'])
        if expected_attributes is not None:
            if onnx_node.op_type in _at_least_one:
                done = 0
                for a, b in expected_attributes.items():
                    if a in options:
                        setattr(self, a, b)
                        done += 1
                if done == 0:
                    raise RuntimeError(  # pragma: no cover
                        "All parameters '{}' are missing from operator '{}', "
                        "given {}.".format(
                            a, onnx_node.op_type, list(sorted(options))))
            else:
                for a, b in expected_attributes.items():
                    if a not in options:
                        if b is DefaultNone:
                            setattr(self, a, None)
                        elif b is None:
                            raise RuntimeError(  # pragma: no cover
                                "Parameter '{}' is missing from operator '{}' "
                                "(class='{}'), given {}.".format(
                                    a, onnx_node.op_type,
                                    self.__class__.__name__,
                                    list(sorted(options))))
                        else:
                            setattr(self, a, b)
        for k, v in options.items():
            setattr(self, k, v)

        if onnx_node.op_type not in _at_least_one:
            for k, v in self._schema.attributes.items():
                if not hasattr(self, k) and getattr(v, 'required', True):
                    raise RuntimeError(  # pragma: no cover
                        "Attribute '{}' is expected based on ONNX specifications "
                        "for node '{}' and options {}.".format(
                            k, onnx_node.op_type, pprint.pformat(options)))

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer is `False`.
        """
        return False

    def _find_custom_operator_schema(self, op_name):
        raise NotImplementedError(  # pragma: no cover
            "This method should be overwritten for operator "
            "'{}'.".format(op_name))

    def __str__(self):
        """
        usual
        """
        atts = [self.__class__.__name__ + '(',
                "    op_type={}".format(self.onnx_node.op_type)]
        for k, v in sorted(self.__dict__.items()):
            if k in {'desc', 'onnx_node'}:
                continue
            if 'a' <= k[0] <= 'z' and k[-1] != '_':
                atts.append('    {0}={1},'.format(k, v))
        atts.append(')')
        return "\n".join(atts)

    def _run(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(  # pragma: no cover
            "Method '_run' or 'to_python' should be overwritten for operator %s."
            "" % self.__class__.__name__)

    def run(self, *args, **kwargs):  # pylint: disable=E0202
        """
        Calls method ``_run``.
        """
        try:
            res = self._run(*args, **kwargs)
        except TypeError as e:
            raise TypeError(  # pragma: no cover
                "Issues with types {} (operator {}).".format(
                    ", ".join(str(type(_)) for _ in args),
                    self.__class__.__name__)) from e
        return res

    def switch_initializers_dtype(self, dtype_in=numpy.float32,
                                  dtype_out=numpy.float64):
        """
        Switches all initializers from *dtype_in* to *dtype_out*
        with a simple cast.

        @param      dtype_in    previous type
        @param      dtype_out   next type
        @return                 done operations
        """
        done = []
        for k, v in sorted(self.__dict__.items()):
            if k in {'desc', 'onnx_node'}:
                continue
            if isinstance(v, numpy.ndarray):
                if v.dtype == dtype_in:
                    v = v.astype(dtype_out)
                    setattr(self, k, v)
                    done.append(("+", "att", k, getattr(self, k)))
                else:
                    done.append(("-", "att", k, getattr(self, k)))
        if hasattr(self, '_run_no_checks_') and hasattr(self, 'run'):
            self.run = self._run_no_checks_  # pylint: disable=E0202,E1101
        return done
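
    # Hedged usage sketch (not part of the original module): converting the
    # float32 initializers stored on a node to float64. The returned list
    # records what happened to each numpy attribute:
    #
    #     done = op.switch_initializers_dtype(numpy.float32, numpy.float64)
    #     # ("+", "att", name, value)  -> attribute was cast to float64
    #     # ("-", "att", name, value)  -> numpy attribute left unchanged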

    def infer_shapes(self, *args, **kwargs):
        """
        Infer shapes of the outputs given the shapes
        of the inputs. It works the same way as method *run*.
        """
        try:
            res = self._infer_shapes(*args, **kwargs)
        except TypeError as e:
            raise TypeError(  # pragma: no cover
                "Issues with (operator '{}') and shapes\n{}"
                "\n----args\n{}\n------kwargs\n{}".format(
                    self.__class__.__name__,
                    "\n".join(str(_) for _ in args),
                    pprint.pformat(args),
                    pprint.pformat(kwargs))) from e
        if res is None:
            return res
        if not isinstance(res, tuple):
            raise TypeError(  # pragma: no cover
                "res must be tuple not {} (operator '{}')".format(
                    type(res), self.__class__.__name__))
        for a in res:
            if not isinstance(a, ShapeObject):
                raise TypeError(  # pragma: no cover
                    "One shape is not a ShapeObject but {} (operator '{}')".format(
                        type(a), self.__class__.__name__))
        return res

    def _infer_shapes(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(
            "This method should be overwritten for operator '{}'.".format(
                self.__class__.__name__))  # pragma: no cover
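
    # Illustrative sketch (an assumption, not code from this file) of what a
    # concrete operator returns from _infer_shapes: a tuple with one
    # ShapeObject per output; an identity-like operator could simply write:
    #
    #     def _infer_shapes(self, x):
    #         return (x.copy(name=self.__class__.__name__), )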

    def infer_types(self, *args, **kwargs):
        """
        Infer types of the outputs given the types
        of the inputs. It works the same way as method *run*.
        """
        try:
            res = self._infer_types(*args, **kwargs)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with (operator '{}') and types\n{}"
                "\n----args\n{}\n------kwargs\n{}".format(
                    self.__class__.__name__,
                    "\n".join(str(_) for _ in args),
                    pprint.pformat(args),
                    pprint.pformat(kwargs))) from e
        if not isinstance(res, tuple):
            raise TypeError(  # pragma: no cover
                "res must be tuple not {} (operator '{}')".format(
                    type(res), self.__class__.__name__))
        for a in res:
            if not isinstance(a, (numpy.dtype, SequenceType)) and a not in {
                    numpy.int8, numpy.uint8, numpy.float16, numpy.float32,
                    numpy.float64, numpy.int32, numpy.int64, numpy.int16,
                    numpy.uint16, numpy.uint32, numpy.bool_, numpy.str_,
                    numpy.uint64, bool, str}:
                raise TypeError(  # pragma: no cover
                    "Type ({}, {}) is not a numpy type or a sequence type "
                    "(operator '{}')".format(
                        a, type(a), self.__class__.__name__))
        return res

    def _infer_types(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(
            "This method should be overwritten for operator '{}'.".format(
                self.__class__.__name__))  # pragma: no cover

    def infer_sizes(self, *args, **kwargs):
        """
        Infer sizes required for computation.
        It works the same way as method *run*.
        """
        try:
            res = self._infer_sizes(*args, **kwargs)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with (operator '{}') and types\n{}"
                "\n----args\n{}\n------kwargs\n{}".format(
                    self.__class__.__name__,
                    "\n".join(str(_) for _ in args),
                    pprint.pformat(args),
                    pprint.pformat(kwargs))) from e
        if not isinstance(res, tuple):
            raise TypeError(  # pragma: no cover
                "res must be tuple not {} (operator '{}')".format(
                    type(res), self.__class__.__name__))
        return res

    def _infer_sizes(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(
            "This method should be overwritten for operator '{}'.".format(
                self.__class__.__name__))  # pragma: no cover

    def enable_inplace_compute(self, index):
        """
        Tells the node that one input can be overwritten.

        @param      index       input index
        """
        self.inplaces[index] = True
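
    # Hedged example (not from the original file): the runtime calls this
    # method when it detects that an input is not needed by any other node,
    # so the operator is free to reuse its buffer:
    #
    #     op.enable_inplace_compute(0)
    #     op.inplaces      # {0: True}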

    @property
    def args_default(self):
        """
        Returns the list of arguments as well as
        the list of parameters with the default values
        (close to the signature).
        """
        inps = []
        if hasattr(self, 'atts'):
            for k, v in self.atts.items():  # pylint: disable=E1101
                if isinstance(v, (list, tuple, dict)) and len(v) == 0:
                    v = None
                inps.append('%s=%r' % (k, v))
        return inps

    @property
    def args_default_modified(self):
        """
        Returns the list of modified parameters.
        """
        if not hasattr(self, 'atts'):
            return None

        inps = []
        for k, v in self.atts.items():  # pylint: disable=E1101
            val = getattr(self, k, None)
            if isinstance(val, numpy.ndarray) and isinstance(v, list):
                val = list(val)
            try:
                if val != v:
                    inps.append('%s=%r' % (k, val))
            except ValueError as e:  # pragma: no cover
                raise ValueError(
                    "Unexpected value for v=%r and val=%r." % (v, val)) from e
        return inps

    @property
    def args_optional(self):
        """
        Returns the list of optional arguments.
        """
        inps = []
        if hasattr(self, 'optional_inputs'):
            for k, v in self.optional_inputs.items():  # pylint: disable=E1101
                inps.append('%s=%r' % (k, v))
        return inps

    @property
    def args_mandatory(self):
        """
        Returns the list of mandatory arguments.
        """
        if hasattr(self, 'mandatory_inputs'):
            return self.mandatory_inputs  # pylint: disable=E1101
        return None

    def to_python(self, inputs):
        """
        Returns a python code equivalent to this operator.

        @param      inputs      inputs name
        @return                 imports, python code, both as strings
        """
        raise NotImplementedError(
            "Operator '{}' has no equivalent python code.".format(
                self.__class__.__name__))  # pragma: no cover

    def _to_python_numpy(self, inputs, numpy_name):
        return ("import numpy",
                "return numpy.%s(%s)" % (numpy_name, ", ".join(inputs)))

    @property
    def atts_value(self):
        "Returns all parameters in a dictionary."
        if hasattr(self, 'atts'):
            return {k: getattr(self, k)
                    for k in self.atts}  # pylint: disable=E1101
        return None
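
    # Sketch of the to_python contract (hedged, for illustration): both
    # to_python and _to_python_numpy return a pair of strings
    # (imports, code). For a hypothetical unary operator mapped to numpy.exp:
    #
    #     imports, code = op._to_python_numpy(['X'], 'exp')
    #     # imports == "import numpy"
    #     # code    == "return numpy.exp(X)"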


class OpRunUnary(OpRun):
    """
    Ancestor to all unary operators in this subfolder.
    Checks that input types are the same.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    def run(self, x):  # pylint: disable=E0202,W0221
        """
        Calls method ``_run``.
        """
        try:
            res = self._run(x)
        except TypeError as e:
            raise TypeError(  # pragma: no cover
                "Issues with types {} (unary operator {}).".format(
                    ", ".join(str(type(_)) for _ in [x]),
                    self.__class__.__name__)) from e
        return res

    def infer_shapes(self, x):  # pylint: disable=E0202,W0221
        try:
            return self._infer_shapes(x)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (operator {}).".format(
                    x.dtype, self.__class__.__name__)) from e

    def _infer_shapes(self, x):  # pylint: disable=E0202,W0221
        """
        Returns the same shape by default.
        """
        return (x, )

    def infer_types(self, x):  # pylint: disable=E0202,W0221
        try:
            return self._infer_types(x)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (operator {}).".format(
                    x, self.__class__.__name__)) from e

    def _infer_types(self, x):  # pylint: disable=E0202,W0221
        """
        Returns the same type by default.
        """
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res


class OpRunArg(OpRunUnary):
    """
    Ancestor to all unary operators in this subfolder
    which produce the position of extrema (ArgMax, ...).
    Checks that input types are the same.
    The class must have attributes *axis*, *keepdims*.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=expected_attributes,
                            **options)
        if not hasattr(self, 'keepdims'):
            raise AttributeError(  # pragma: no cover
                "Attribute 'keepdims' is missing.")
        if not hasattr(self, 'axis'):
            raise AttributeError(  # pragma: no cover
                "Attribute 'axis' is missing.")

    def run(self, x):  # pylint: disable=E0202
        """
        Calls method ``_run``.
        """
        res = OpRunUnary.run(self, x)
        if res[0].dtype != numpy.int64:
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: should be '{}' != output '{}' "
                "(operator '{}')".format(
                    numpy.int64, res[0].dtype, self.__class__.__name__))
        return res

    def _infer_shapes(self, x):  # pylint: disable=W0221
        sh = x.reduce(self.axis, self.keepdims,  # pylint: disable=E1101
                      dtype=numpy.int64)  # pylint: disable=E1101
        return (sh, )

    def _infer_types(self, x):  # pylint: disable=W0221
        return (numpy.int64, )

    def _run_no_checks_(self, x):  # pylint: disable=W0221
        return OpRunUnary.run(self, x)


class OpRunUnaryNum(OpRunUnary):
    """
    Ancestor to all unary and numerical operators
    in this subfolder. Checks that input types
    are the same.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=expected_attributes,
                            **options)

    def run(self, x):  # pylint: disable=E0202
        """
        Calls method ``_run``.
        """
        res = OpRunUnary.run(self, x)
        if len(res) == 0 or res[0] is None:
            return res
        if not isinstance(res[0], list) and res[0].dtype != x.dtype:
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: input '{}' != output '{}' "
                "(operator '{}')".format(
                    x.dtype, res[0].dtype, self.__class__.__name__))
        return res

    def _run_no_checks_(self, x):  # pylint: disable=W0221
        return OpRunUnary.run(self, x)


class OpRunClassifierProb(OpRunUnary):
    """
    Ancestor to all classifiers in this subfolder
    returning a label and probabilities.
    Checks that input types are the same.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=expected_attributes,
                            **options)

    def run(self, x):  # pylint: disable=E0202
        """
        Calls method ``_run``.
        """
        res = OpRunUnary.run(self, x)
        if x.dtype in (numpy.float32, numpy.float64) and res[1].dtype != x.dtype:
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: {} != {} (operator '{}')".format(
                    x.dtype, res[1].dtype, self.__class__.__name__))
        return res

    @property
    def nb_classes(self):
        """
        Returns the number of expected classes.
        """
        return max(len(getattr(self, 'classlabels_ints', [])),
                   len(getattr(self, 'classlabels_int64s', [])),
                   len(self.classlabels_strings))  # pylint: disable=E1101

    def _run_no_checks_(self, x):  # pylint: disable=W0221
        return OpRunUnary.run(self, x)

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns the shapes of the labels and the probabilities.
        """
        return (ShapeObject((x[0], ), dtype=numpy.int64,
                            name="{}-0".format(self.__class__.__name__)),
                ShapeObject((x[0], self.nb_classes), dtype=x.dtype,
                            name="{}-1".format(self.__class__.__name__)))

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Returns the type of the labels and the probabilities.
        """
        return (numpy.int64, x.dtype)


class OpRunBinary(OpRun):
    """
    Ancestor to all binary operators in this subfolder.
    Checks that input types are the same.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    def run(self, x, y):  # pylint: disable=E0202,W0221
        """
        Calls method ``_run``.
        """
        if x is None or y is None:
            raise RuntimeError(  # pragma: no cover
                "x and y cannot be None ({}, {}, {}).".format(
                    type(x), type(y), type(self)))
        if x.dtype != y.dtype:
            raise RuntimeTypeError(
                "Input type mismatch: {} != {} (operator '{}', shapes {}, {})".format(
                    x.dtype, y.dtype, self.__class__.__name__,
                    x.shape, y.shape))
        try:
            res = self._run(x, y)
        except (TypeError, ValueError) as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (binary operator {}).".format(
                    ", ".join(str(type(_)) for _ in [x, y]),
                    self.__class__.__name__)) from e
        return res

    def _run_no_checks_(self, x, y):  # pylint: disable=W0221
        """
        Calls method ``_run``.
        """
        try:
            res = self._run(x, y)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (binary operator {}).".format(
                    ", ".join(str(type(_)) for _ in [x, y]),
                    self.__class__.__name__)) from e
        return res

    def _infer_shapes(self, x, y):  # pylint: disable=W0221
        """
        Returns the same shape by default.
        We assume the operator returns the biggest
        shape as the operator could be using broadcasting.
        """
        if x is None or y is None:
            return None
        try:
            res = x.broadcast(y)
            add = "broadcast"
        except RuntimeError:  # pragma: no cover
            # We know x and y have the same number of dimensions.
            # We pick the first one even if it might be wrong.
            res = x
            add = "1"
        if res.name is None:
            return (res.copy(name="{}{}".format(
                self.__class__.__name__, add)), )
        return (res.copy(name="{}-{}{}".format(
            res.name, self.__class__.__name__, add)), )

    def _infer_types(self, x, y):  # pylint: disable=W0221
        """
        Returns the type of the first input by default.
        """
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res


class OpRunBinaryComparison(OpRunBinary):
    """
    Ancestor to all binary operators in this subfolder
    comparing tensors.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunBinary.__init__(self, onnx_node, desc=desc,
                             expected_attributes=expected_attributes,
                             **options)

    def _infer_types(self, x, y):  # pylint: disable=W0221
        return (numpy.bool_, )


class OpRunBinaryNum(OpRunBinary):
    """
    Ancestor to all binary and numerical operators in this subfolder.
    Checks that input types are the same.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunBinary.__init__(self, onnx_node, desc=desc,
                             expected_attributes=expected_attributes,
                             **options)

    def run(self, x, y):  # pylint: disable=E0202
        """
        Calls method ``_run``.
        """
        res = OpRunBinary.run(self, x, y)
        if res[0].dtype != x.dtype:
            raise RuntimeTypeError(
                "Output type mismatch: {} != {} or {} (operator '{}')"
                " type(x)={} type(y)={}".format(
                    x.dtype, res[0].dtype, y.dtype,
                    self.__class__.__name__, type(x), type(y)))
        return res

    def _run_no_checks_(self, x, y):  # pylint: disable=W0221
        """
        Calls method ``_run``.
        """
        return OpRunBinary._run_no_checks_(self, x, y)


class OpRunBinaryNumpy(OpRunBinaryNum):
    """
    Implements the in-place logic.
    *numpy_fct* is a binary numpy function which
    takes two matrices and has an *out* argument
    for in-place operations.
    """

    def __init__(self, numpy_fct, onnx_node, desc=None,
                 expected_attributes=None, **options):
        OpRunBinaryNum.__init__(self, onnx_node, desc=desc,
                                expected_attributes=expected_attributes,
                                **options)
        self.numpy_fct = numpy_fct
        self._cannot_inplace_int = self.numpy_fct in (
            numpy.divide, numpy.true_divide)

    def _run(self, a, b):  # pylint: disable=W0221
        if (self._cannot_inplace_int and
                numpy.issubdtype(a.dtype, numpy.integer)):
            return (self.numpy_fct(a, b), )
        if self.inplaces.get(0, False) and a.size >= b.size:
            if len(a.shape) == 1 and b.shape == (1, 1):
                a = a.reshape(1, a.shape[0])
            try:
                self.numpy_fct(a, b, out=a)
                return (a, )
            except (ValueError, TypeError):
                return (self.numpy_fct(a, b), )
        if self.inplaces.get(1, False) and a.size <= b.size:
            if len(b.shape) == 1 and a.shape == (1, 1):
                b = b.reshape(b.shape[0], 1)
            try:
                self.numpy_fct(a, b, out=b)
                return (b, )
            except (ValueError, TypeError):
                return (self.numpy_fct(a, b), )
        return (self.numpy_fct(a, b), )

    def to_python(self, inputs):
        """
        Returns a python code equivalent to this operator.

        @param      inputs      inputs name
        @return                 imports, python code, both as strings
        """
        lines = [
            "# inplaces not taken into account {}-{}".format(
                self.inplaces.get(0, False), self.inplaces.get(1, False)),
            "return numpy.{0}({1})".format(
                self.numpy_fct.__name__, ', '.join(inputs))
        ]
        return "import numpy", "\n".join(lines)
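
# Hedged illustration (not from the original file) of the in-place logic
# above, with numpy.add standing in for the function a concrete operator
# would pass as *numpy_fct*:
#
#     op.inplaces == {}              ->  numpy.add(a, b) allocates a result
#     op.inplaces == {0: True} and a.size >= b.size
#                                    ->  numpy.add(a, b, out=a) reuses a
#     op.inplaces == {1: True} and a.size <= b.size
#                                    ->  numpy.add(a, b, out=b) reuses b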


class OpRunReduceNumpy(OpRunUnaryNum):
    """
    Implements the reduce logic.
    It must have a parameter *axes*.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        if ('noop_with_empty_axes' not in expected_attributes and
                'axes' not in expected_attributes):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'axes' is expected but not found in {} "
                "from class {}".format(expected_attributes, type(self)))
        if (expected_attributes.get('noop_with_empty_axes', 0) and
                (expected_attributes['axes'] is None or
                 len(expected_attributes['axes']) == 0)):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'axes' cannot be empty as {} (noop_with_empty_axes=1) "
                "from class {}".format(expected_attributes, type(self)))
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=expected_attributes,
                               **options)
        if isinstance(self.axes, numpy.ndarray):  # pylint: disable=E0203
            if (len(self.axes.shape) == 0 or  # pylint: disable=E0203
                    self.axes.shape[0] == 0):  # pylint: disable=E0203
                self.axes = None
            else:
                self.axes = tuple(self.axes)
        elif self.axes in [[], tuple()]:  # pylint: disable=E0203
            self.axes = None
        elif isinstance(self.axes, list):  # pylint: disable=E0203
            self.axes = tuple(self.axes)
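
# The normalization above means *axes* always ends up as either None
# (reduce over all axes) or a tuple of integers, whatever form the ONNX
# attribute took. A few illustrative cases (assumptions, not tests from
# this file):
#
#     []                    ->  None
#     numpy.array([])       ->  None
#     numpy.array([0, 1])   ->  (0, 1)
#     [2]                   ->  (2,)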


class OpRunCustom(OpRun):
    """
    Automates some methods for custom operators defined
    outside *mlprodict*.
    """

    class OpRunCustomSchema(OperatorSchema):
        """
        Custom schema.
        """

        def __init__(self, cls):
            OperatorSchema.__init__(self, cls.__name__)
            self.attributes = cls.atts

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    def _find_custom_operator_schema(self, op_name):
        """
        Finds a custom operator defined by this runtime.
        """
        if (op_name == self.__class__.__name__ or
                (hasattr(self.__class__, 'op_name') and
                 self.__class__.op_name == op_name)):  # pylint: disable=E1101
            return OpRunCustom.OpRunCustomSchema(self.__class__)
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))