Coverage for mlprodict/npy/onnx_sklearn_wrapper.py: 98%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

255 statements  

1""" 

2@file 

3@brief Helpers to use numpy API to easily write converters 

4for :epkg:`scikit-learn` classes for :epkg:`onnx`. 

5 

6.. versionadded:: 0.6 

7""" 

8import logging 

9import numpy 

10from sklearn.base import ( 

11 ClassifierMixin, ClusterMixin, 

12 RegressorMixin, TransformerMixin) 

13from .onnx_numpy_wrapper import _created_classes_inst, wrapper_onnxnumpy_np 

14from .onnx_numpy_annotation import NDArraySameType, NDArrayType 

15from .xop import OnnxOperatorTuple 

16from .xop_variable import Variable 

17from .xop import loadop 

18from ..plotting.text_plot import onnx_simple_text_plot 

19 

20 

# Logger of this module. It uses the fixed name 'xop' (not __name__),
# so messages go to the logger registered under that name.
logger = logging.getLogger(name='xop')

22 

23 

# Maps the value of `AttributeProto.type` to a function extracting the
# attribute value: 1=FLOAT, 2=INT, 3=STRING, 4=TENSOR, 6=FLOATS, 7=INTS,
# 8=STRINGS. Any other type (graphs, sparse tensors, ...) is not supported.
_ONNX_ATTRIBUTE_GETTERS = {
    1: lambda att: att.f,
    2: lambda att: att.i,
    3: lambda att: att.s,
    4: lambda att: att.t,
    6: lambda att: list(att.floats),
    7: lambda att: list(att.ints),
    8: lambda att: list(att.strings),
}


def _skl2onnx_add_to_container(onx, scope, container, outputs):
    """
    Adds ONNX graph to :epkg:`skl2onnx` container and scope.

    Every initializer and every intermediate result of *onx* is renamed
    with a unique name obtained from *scope*. Every node is copied into
    *container*, and each graph output is finally linked to the matching
    variable of *outputs* through an `Identity` node.

    :param onx: onnx model, its attributes `graph` and `opset_import`
        are read
    :param scope: scope, used to create unique variable and operator names
    :param container: container receiving the initializers and the nodes
    :param outputs: variables the graph outputs are linked to, in the
        same order as `onx.graph.output`; attribute `onnx_name` of each
        of them is used as the final output name
    :raises RuntimeError: if a node input cannot be resolved or if the
        number of graph outputs differs from `len(outputs)`
    :raises NotImplementedError: if a node holds an attribute whose type
        is not supported
    """
    logger.debug("_skl2onnx_add_to_container:onx=%r outputs=%r",
                 type(onx), outputs)
    # Graph inputs keep their names: the code presumes they already exist
    # in the container under the same names.
    mapped_names = {x.name: x.name for x in onx.graph.input}
    opsets = {op.domain: op.version for op in onx.opset_import}

    # adding initializers
    for init in onx.graph.initializer:
        new_name = scope.get_unique_variable_name(init.name)
        mapped_names[init.name] = new_name
        container.add_initializer(new_name, None, None, init)

    # adding nodes
    for node in onx.graph.node:
        new_inputs = []
        for i in node.input:
            if i == '':
                # An empty name means an omitted optional input in ONNX,
                # it must be kept as is.
                new_inputs.append(i)
                continue
            if i not in mapped_names:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find input %r in %r." % (i, mapped_names))
            new_inputs.append(mapped_names[i])

        new_outputs = []
        for o in node.output:
            if o == '':
                # An empty name means an optional output not requested,
                # renaming it would create a spurious result.
                new_outputs.append(o)
                continue
            new_name = scope.get_unique_variable_name(o)
            mapped_names[o] = new_name
            new_outputs.append(new_name)

        atts = {}
        for att in node.attribute:
            getter = _ONNX_ATTRIBUTE_GETTERS.get(att.type, None)
            if getter is None:
                raise NotImplementedError(  # pragma: no cover
                    "Unable to copy attribute type %r (%r)." % (
                        att.type, att))
            atts[att.name] = getter(att)

        container.add_node(
            node.op_type,
            name=scope.get_unique_operator_name('_sub_' + node.name),
            inputs=new_inputs, outputs=new_outputs, op_domain=node.domain,
            op_version=opsets.get(node.domain, None), **atts)

    # linking outputs
    if len(onx.graph.output) != len(outputs):
        raise RuntimeError(  # pragma: no cover
            "Output size mismatch %r != %r.\n--ONNX--\n%s" % (
                len(onx.graph.output), len(outputs),
                onnx_simple_text_plot(onx)))
    for out, var in zip(onx.graph.output, outputs):
        container.add_node(
            'Identity', name=scope.get_unique_operator_name(
                '_sub_' + out.name),
            inputs=[mapped_names[out.name]], outputs=[var.onnx_name])

98 

99 

100def _common_shape_calculator_t(operator): 

101 if not hasattr(operator, 'onnx_numpy_fct_'): 

102 raise AttributeError( 

103 "operator must have attribute 'onnx_numpy_fct_'.") 

104 X = operator.inputs 

105 if len(X) != 1: 

106 raise RuntimeError( 

107 "This function only supports one input not %r." % len(X)) 

108 if len(operator.outputs) != 1: 

109 raise RuntimeError( 

110 "This function only supports one output not %r." % len( 

111 operator.outputs)) 

112 op = operator.raw_operator 

113 cl = X[0].type.__class__ 

114 dim = [X[0].type.shape[0], getattr(op, 'n_outputs_', None)] 

115 operator.outputs[0].type = cl(dim) 

116 

117 

def _shape_calculator_transformer(operator):
    """
    Default shape calculator for a transformer: one input and one
    output whose type class is the same as the input's. The work is done
    by the shared single-output calculator.

    .. versionadded:: 0.6
    """
    return _common_shape_calculator_t(operator)

126 

127 

def _shape_calculator_regressor(operator):
    """
    Default shape calculator for a regressor: one input and one output
    whose type class is the same as the input's. The work is done by the
    shared single-output calculator.

    .. versionadded:: 0.6
    """
    return _common_shape_calculator_t(operator)

136 

137 

138def _common_shape_calculator_int_t(operator): 

139 if not hasattr(operator, 'onnx_numpy_fct_'): 

140 raise AttributeError( 

141 "operator must have attribute 'onnx_numpy_fct_'.") 

142 X = operator.inputs 

143 if len(X) != 1: 

144 raise RuntimeError( 

145 "This function only supports one input not %r." % len(X)) 

146 if len(operator.outputs) != 2: 

147 raise RuntimeError( 

148 "This function only supports two outputs not %r." % len( 

149 operator.outputs)) 

150 from skl2onnx.common.data_types import Int64TensorType # delayed 

151 op = operator.raw_operator 

152 cl = X[0].type.__class__ 

153 dim = [X[0].type.shape[0], getattr(op, 'n_outputs_', None)] 

154 operator.outputs[0].type = Int64TensorType(dim[:1]) 

155 operator.outputs[1].type = cl(dim) 

156 

157 

def _shape_calculator_classifier(operator):
    """
    Default shape calculator for a classifier: one input and two
    outputs, the label (int64) and the probabilities with the same type
    as the input. The work is done by the shared two-output calculator.

    .. versionadded:: 0.6
    """
    return _common_shape_calculator_int_t(operator)

166 

167 

def _shape_calculator_cluster(operator):
    """
    Default shape calculator for a clustering model: one input and two
    outputs, the label (int64) and the distances with the same type as
    the input. The work is done by the shared two-output calculator.

    .. versionadded:: 0.6
    """
    return _common_shape_calculator_int_t(operator)

176 

177 

178def _common_converter_begin(scope, operator, container, n_outputs): 

179 if not hasattr(operator, 'onnx_numpy_fct_'): 

180 raise AttributeError( 

181 "operator must have attribute 'onnx_numpy_fct_'.") 

182 X = operator.inputs 

183 if len(X) != 1: 

184 raise RuntimeError( 

185 "This function only supports one input not %r." % len(X)) 

186 if len(operator.outputs) != n_outputs: 

187 raise RuntimeError( 

188 "This function only supports %d output not %r." % ( 

189 n_outputs, len(operator.outputs))) 

190 

191 # First conversion of the model to onnx 

192 # Then addition of the onnx graph to the main graph. 

193 from .onnx_variable import OnnxVar 

194 new_var = Variable.from_skl2onnx(X[0]) 

195 xvar = OnnxVar(new_var) 

196 fct_cl = operator.onnx_numpy_fct_ 

197 

198 opv = container.target_opset 

199 logger.debug("_common_converter_begin:xvar=%r op=%s", 

200 xvar, type(operator.raw_operator)) 

201 inst = fct_cl.fct(xvar, op_=operator.raw_operator) 

202 logger.debug("_common_converter_begin:inst=%r opv=%r fct_cl.fct=%r", 

203 type(inst), opv, fct_cl.fct) 

204 onx = inst.to_algebra(op_version=opv) 

205 logger.debug("_common_converter_begin:end:onx=%r", type(onx)) 

206 return new_var, onx 

207 

208 

def _common_converter_t(scope, operator, container):
    """
    Converter shared by operators with one input and one output. It
    converts the numpy-based function into ONNX and renames its result
    after the operator output with an `Identity` node. The resulting
    model is then merged into *container*.
    """
    logger.debug("_common_converter_t:op=%r -> %r",
                 operator.inputs, operator.outputs)
    OnnxIdentity = loadop('Identity')
    target_opset = container.target_opset
    input_var, graph = _common_converter_begin(
        scope, operator, container, 1)
    outputs = operator.outputs
    renamed = OnnxIdentity(graph, op_version=target_opset,
                           output_names=[outputs[0].full_name])
    output_vars = [Variable.from_skl2onnx(o) for o in outputs]
    model = renamed.to_onnx([input_var], output_vars,
                            target_opset=target_opset)
    _skl2onnx_add_to_container(model, scope, container, outputs)
    logger.debug("_common_converter_t:end")

222 

223 

def _converter_transformer(scope, operator, container):
    """
    Default converter for a transformer with one input and one output
    of the same type. *operator* is expected to have the attribute
    *onnx_numpy_fct_*, a function wrapped with decorator
    :func:`onnxsklearn_transformer
    <mlprodict.npy.onnx_sklearn_wrapper.onnxsklearn_transformer>`.

    .. versionadded:: 0.6
    """
    return _common_converter_t(scope, operator, container)

235 

236 

def _converter_regressor(scope, operator, container):
    """
    Default converter for a regressor with one input and one output
    of the same type. *operator* is expected to have the attribute
    *onnx_numpy_fct_*, a function wrapped with decorator
    :func:`onnxsklearn_regressor
    <mlprodict.npy.onnx_sklearn_wrapper.onnxsklearn_regressor>`.

    .. versionadded:: 0.6
    """
    return _common_converter_t(scope, operator, container)

248 

249 

def _common_converter_int_t(scope, operator, container):
    """
    Converter shared by operators with one input and two outputs.

    If the converted function returns an :class:`OnnxOperatorTuple`,
    each of its results is renamed after the operator output at the same
    position. Otherwise the single result is renamed after the first
    operator output only. The resulting model is merged into *container*.
    """
    logger.debug("_common_converter_int_t:op=%r -> %r",
                 operator.inputs, operator.outputs)
    OnnxIdentity = loadop('Identity')
    target_opset = container.target_opset
    input_var, graph = _common_converter_begin(
        scope, operator, container, 2)
    outputs = operator.outputs

    if not isinstance(graph, OnnxOperatorTuple):
        # NOTE(review): only one Identity is created here although two
        # outputs are declared below; the output count check in
        # _skl2onnx_add_to_container may reject this - confirm intent.
        single = OnnxIdentity(graph, op_version=target_opset,
                              output_names=[outputs[0].full_name])
        model = single.to_onnx(
            [input_var],
            [Variable.from_skl2onnx(o) for o in outputs],
            target_opset=target_opset)
        _skl2onnx_add_to_container(model, scope, container, outputs)
        logger.debug("_common_converter_int_t:2:end")
        return

    if len(outputs) != len(graph):
        raise RuntimeError(  # pragma: no cover
            "Mismatched number of outputs expected %d, got %d." % (
                len(outputs), len(graph)))
    renamed = []
    for out, result in zip(outputs, graph):
        if not hasattr(result, 'add_to'):
            raise TypeError(  # pragma: no cover
                "Unexpected type for onnx graph %r, inst=%r." % (
                    type(result), type(operator.raw_operator)))
        renamed.append(OnnxIdentity(result, op_version=target_opset,
                                    output_names=[out.full_name]))

    # The first result drives the export, the others are attached to it.
    model = renamed[0].to_onnx(
        [input_var],
        [Variable.from_skl2onnx(o) for o in outputs],
        target_opset=target_opset, other_outputs=renamed[1:])
    _skl2onnx_add_to_container(model, scope, container, outputs)
    logger.debug("_common_converter_int_t:1:end")

293 

294 

def _converter_classifier(scope, operator, container):
    """
    Default converter for a classifier with one input and two outputs,
    the label and the probabilities with the same type as the input.
    *operator* is expected to have the attribute *onnx_numpy_fct_*,
    a function wrapped with decorator :func:`onnxsklearn_classifier
    <mlprodict.npy.onnx_sklearn_wrapper.onnxsklearn_classifier>`.

    .. versionadded:: 0.6
    """
    return _common_converter_int_t(scope, operator, container)

307 

308 

def _converter_cluster(scope, operator, container):
    """
    Default converter for a clustering model with one input and two
    outputs, the label and the distances with the same type as the
    input. *operator* is expected to have the attribute
    *onnx_numpy_fct_*, a function wrapped with decorator
    :func:`onnxsklearn_cluster
    <mlprodict.npy.onnx_sklearn_wrapper.onnxsklearn_cluster>`.

    .. versionadded:: 0.6
    """
    return _common_converter_int_t(scope, operator, container)

321 

322 

# Default (shape calculator, converter) pair for each supported kind of
# scikit-learn model. The keys are iterated in insertion order when their
# names are listed in an error message of _internal_method_decorator.
_default_cvt = {}
_default_cvt[ClassifierMixin] = (
    _shape_calculator_classifier, _converter_classifier)
_default_cvt[ClusterMixin] = (
    _shape_calculator_cluster, _converter_cluster)
_default_cvt[RegressorMixin] = (
    _shape_calculator_regressor, _converter_regressor)
_default_cvt[TransformerMixin] = (
    _shape_calculator_transformer, _converter_transformer)

329 

330 

def update_registered_converter_npy(
        model, alias, convert_fct, shape_fct=None, overwrite=True,
        parser=None, options=None):
    """
    Registers or updates a converter for a new model so that
    it can be converted when inserted in a *scikit-learn* pipeline.
    This function assumes the converter is written as a function
    decorated with :func:`onnxsklearn_transformer
    <mlprodict.npy.onnx_sklearn_wrapper.onnxsklearn_transformer>`.

    :param model: model class
    :param alias: alias used to register the model
    :param convert_fct: function which converts a model, it must have an
        attribute `compiled` or `signed_compiled`
    :param shape_fct: function which checks or modifies the expected
        outputs, only None is supported for the time being
    :param overwrite: False to raise exception if a converter
        already exists
    :param parser: overwrites the parser as well if not empty
    :param options: registered options for this converter
    :raises AttributeError: if *convert_fct* has neither `compiled`
        nor `signed_compiled`
    :raises NotImplementedError: if *shape_fct* is not None
    :raises KeyError: if *model* derives from none of `TransformerMixin`,
        `RegressorMixin`, `ClassifierMixin`, `ClusterMixin`

    The alias is usually the library name followed by the model name.
    The default shape calculator and converter are chosen from the first
    class among `TransformerMixin`, `RegressorMixin`, `ClassifierMixin`,
    `ClusterMixin` that *model* derives from.

    .. versionadded:: 0.6
    """
    if (hasattr(convert_fct, "compiled") or
            hasattr(convert_fct, 'signed_compiled')):
        # type is wrapper_onnxnumpy or wrapper_onnxnumpy_np
        obj = convert_fct
    else:
        raise AttributeError(  # pragma: no cover
            "Class %r must have attribute 'compiled' or 'signed_compiled' "
            "(object=%r)." % (type(convert_fct), convert_fct))

    def addattr(operator, obj):
        # The default shape calculators and converters read the
        # converted function from this attribute.
        operator.onnx_numpy_fct_ = obj
        return operator

    # The order matters: a class may derive from several mixins.
    defcl = None
    for mixin in (TransformerMixin, RegressorMixin,
                  ClassifierMixin, ClusterMixin):
        if issubclass(model, mixin):
            defcl = mixin
            break

    if shape_fct is not None:
        raise NotImplementedError(  # pragma: no cover
            "Custom shape calculator are not implemented yet.")

    if defcl is None:
        # Previously surfaced as an unexplained KeyError(None), the
        # exception type is kept for backward compatibility.
        raise KeyError(
            "Unable to find a default converter for class %r, it must "
            "derive from one of %s." % (
                model, ", ".join(c.__name__ for c in _default_cvt)))

    shc = _default_cvt[defcl][0]
    local_shape_fct = (
        lambda operator: shc(addattr(operator, obj)))

    cvtc = _default_cvt[defcl][1]
    local_convert_fct = (
        lambda scope, operator, container:
        cvtc(scope, addattr(operator, obj), container))

    from skl2onnx import update_registered_converter  # delayed
    update_registered_converter(
        model, alias, convert_fct=local_convert_fct,
        shape_fct=local_shape_fct, overwrite=overwrite,
        parser=parser, options=options)

399 

400 

def _internal_decorator(fct, op_version=None, runtime=None, signature=None,
                        register_class=None, overwrite=True, options=None):
    """
    Wraps *fct* in an instance of a new class deriving from
    :class:`wrapper_onnxnumpy_np`. The new class is recorded in
    `_created_classes_inst`. When *register_class* is given, the wrapper
    is also registered as its converter under the alias
    `"Sklearn<class name>"`.

    :return: the wrapper instance
    """
    class_name = "_".join([
        "onnxsklearn_parser", fct.__name__, str(op_version), str(runtime)])
    class_attributes = {
        '__doc__': fct.__doc__,
        '__name__': class_name,
        '__getstate__': wrapper_onnxnumpy_np.__getstate__,
        '__setstate__': wrapper_onnxnumpy_np.__setstate__,
    }
    wrapper_class = type(class_name, (wrapper_onnxnumpy_np,),
                         class_attributes)
    _created_classes_inst.append(class_name, wrapper_class)
    wrapped = wrapper_class(
        fct=fct, op_version=op_version, runtime=runtime,
        signature=signature)
    if register_class is None:
        return wrapped

    alias = "Sklearn%s" % getattr(register_class, "__name__", "noname")
    update_registered_converter_npy(
        register_class, alias, wrapped, shape_fct=None,
        overwrite=overwrite, options=options)
    return wrapped

421 

422 

def onnxsklearn_transformer(op_version=None, runtime=None, signature=None,
                            register_class=None, overwrite=True):
    """
    Decorator to declare a converter for a transformer implemented using
    :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: if None, the signature is replaced by the standard
        transformer signature ``NDArraySameType("all")``
    :param register_class: automatically register this converter
        for this class to :epkg:`sklearn-onnx`
    :param overwrite: overwrite existing registered function if any

    .. versionadded:: 0.6
    """
    if signature is None:
        signature = NDArraySameType("all")
    settings = dict(signature=signature, op_version=op_version,
                    runtime=runtime, register_class=register_class,
                    overwrite=overwrite)

    def decorator_fct(fct):
        return _internal_decorator(fct, **settings)
    return decorator_fct

450 

451 

def onnxsklearn_regressor(op_version=None, runtime=None, signature=None,
                          register_class=None, overwrite=True):
    """
    Decorator to declare a converter for a regressor implemented using
    :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: if None, the signature is replaced by the standard
        regressor signature ``NDArraySameType("all")``
    :param register_class: automatically register this converter
        for this class to :epkg:`sklearn-onnx`
    :param overwrite: overwrite existing registered function if any

    .. versionadded:: 0.6
    """
    if signature is None:
        signature = NDArraySameType("all")
    settings = dict(signature=signature, op_version=op_version,
                    runtime=runtime, register_class=register_class,
                    overwrite=overwrite)

    def decorator_fct(fct):
        return _internal_decorator(fct, **settings)
    return decorator_fct

479 

480 

def onnxsklearn_classifier(op_version=None, runtime=None, signature=None,
                           register_class=None, overwrite=True):
    """
    Decorator to declare a converter for a classifier implemented using
    :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: if None, the signature is replaced by the standard
        classifier signature
        ``NDArrayType(("T:all", ), dtypes_out=((numpy.int64, ), 'T'))``,
        meaning labels as int64 followed by a result of the input type
    :param register_class: automatically register this converter
        for this class to :epkg:`sklearn-onnx`, with options
        `zipmap` and `nocl`
    :param overwrite: overwrite existing registered function if any

    .. versionadded:: 0.6
    """
    if signature is None:
        signature = NDArrayType(("T:all", ), dtypes_out=((numpy.int64, ), 'T'))
    settings = dict(signature=signature, op_version=op_version,
                    runtime=runtime, register_class=register_class,
                    overwrite=overwrite)

    def decorator_fct(fct):
        # A fresh dictionary for every call, never shared between
        # registrations.
        registered_options = {'zipmap': [False, True, 'columns'],
                              'nocl': [False, True]}
        return _internal_decorator(fct, options=registered_options,
                                   **settings)
    return decorator_fct

510 

511 

def onnxsklearn_cluster(op_version=None, runtime=None, signature=None,
                        register_class=None, overwrite=True):
    """
    Decorator to declare a converter for a cluster implemented using
    :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: if None, the signature is replaced by the standard
        clustering signature
        ``NDArrayType(("T:all", ), dtypes_out=((numpy.int64, ), 'T'))``,
        meaning labels as int64 followed by a result of the input type
    :param register_class: automatically register this converter
        for this class to :epkg:`sklearn-onnx`
    :param overwrite: overwrite existing registered function if any

    .. versionadded:: 0.6
    """
    if signature is None:
        signature = NDArrayType(("T:all", ), dtypes_out=((numpy.int64, ), 'T'))
    settings = dict(signature=signature, op_version=op_version,
                    runtime=runtime, register_class=register_class,
                    overwrite=overwrite)

    def decorator_fct(fct):
        return _internal_decorator(fct, **settings)
    return decorator_fct

539 

540 

541def _call_validate(self, X): 

542 if hasattr(self, "_validate_onnx_data"): 

543 return self._validate_onnx_data(X) 

544 return X 

545 

546 

def _internal_method_decorator(register_class, method, op_version=None,
                               runtime=None, signature=None,
                               method_names=None, overwrite=True,
                               options=None):
    """
    Wraps *method*, written with the numpy API for ONNX, adds the
    inference methods it implements to *register_class* and registers
    the converter of this class in :epkg:`sklearn-onnx`.

    When not given, signature, method names and options are guessed from
    the first of these classes *register_class* derives from:

    * `TransformerMixin`: ``NDArraySameType("all")``, ``("transform", )``
    * `RegressorMixin`: ``NDArraySameType("all")``, ``("predict", )``
    * `ClassifierMixin`:
      ``NDArrayType(("T:all", ), dtypes_out=((numpy.int64, ), 'T'))``,
      ``("predict", "predict_proba")``, options `zipmap` and `nocl`
    * `ClusterMixin`: same signature as the classifier,
      ``("predict", "transform")``

    With one method name, the new method returns the whole result.
    With several names, `<method.__name__>_` returns the whole result and
    the i-th name returns the i-th output.

    :param register_class: class to update
    :param method: inference function, called with the model instance
        as first argument followed by the inputs
    :param op_version: :epkg:`ONNX` opset version
    :param runtime: runtime used to execute the ONNX graph
    :param signature: signature of the ONNX function
    :param method_names: a method name or a sequence of method names
    :param overwrite: overwrite existing registered function if any
    :param options: registered options for the converter
    :return: the wrapper, instance of a new class deriving from
        :class:`wrapper_onnxnumpy_np`
    :raises RuntimeError: if method names or signature cannot be
        guessed, or if a method to add already exists in the class
    """
    if isinstance(method_names, str):
        method_names = (method_names, )

    if issubclass(register_class, TransformerMixin):
        if signature is None:
            signature = NDArraySameType("all")
        if method_names is None:
            method_names = ("transform", )
    elif issubclass(register_class, RegressorMixin):
        if signature is None:
            signature = NDArraySameType("all")
        if method_names is None:
            method_names = ("predict", )
    elif issubclass(register_class, ClassifierMixin):
        if signature is None:
            signature = NDArrayType(
                ("T:all", ), dtypes_out=((numpy.int64, ), 'T'))
        if method_names is None:
            method_names = ("predict", "predict_proba")
        if options is None:
            options = {'zipmap': [False, True, 'columns'],
                       'nocl': [False, True]}
    elif issubclass(register_class, ClusterMixin):
        if signature is None:
            signature = NDArrayType(
                ("T:all", ), dtypes_out=((numpy.int64, ), 'T'))
        if method_names is None:
            method_names = ("predict", "transform")
    elif method_names is None:  # pragma: no cover
        raise RuntimeError(
            "No obvious API was detected (one among %s), "
            "then 'method_names' must be specified and not left "
            "empty." % (", ".join(map(lambda s: s.__name__, _default_cvt))))

    if method_names is None:
        raise RuntimeError(  # pragma: no cover
            "Methods to overwrite are not known for class %r and "
            "method %r." % (register_class, method))
    if signature is None:
        # The class is none of the known kinds and no signature was
        # given: the previous message wrongly blamed the method names.
        raise RuntimeError(  # pragma: no cover
            "Signature is not known for class %r and method %r, "
            "parameter 'signature' must be specified." % (
                register_class, method))

    class_name = "onnxsklearn_parser_%s_%s_%s" % (
        register_class.__name__, str(op_version), runtime)
    newclass = type(
        class_name, (wrapper_onnxnumpy_np,), {
            '__doc__': method.__doc__,
            '__name__': class_name,
            '__getstate__': wrapper_onnxnumpy_np.__getstate__,
            '__setstate__': wrapper_onnxnumpy_np.__setstate__})
    _created_classes_inst.append(class_name, newclass)

    def _check_(op):
        # op_ is forwarded as the first argument of *method*,
        # strings are rejected.
        if isinstance(op, str):
            raise TypeError(  # pragma: no cover
                "Unexpected type: %r: %r." % (type(op), op))
        return op

    res = newclass(
        fct=lambda *args, op_=None, **kwargs: method(
            _check_(op_), *args, **kwargs),
        op_version=op_version, runtime=runtime, signature=signature,
        fctsig=method)

    if len(method_names) == 1:
        method_name = method_names[0]
        if hasattr(register_class, method_name):
            raise RuntimeError(  # pragma: no cover
                "Cannot overwrite method %r because it already exists in "
                "class %r." % (method_name, register_class))
        m = lambda self, X: res(_call_validate(self, X), op_=self)
        setattr(register_class, method_name, m)
    elif len(method_names) == 0:
        raise RuntimeError("No available method.")  # pragma: no cover
    else:
        # '<method name>_' gives access to every output at once.
        m = lambda self, X: res(_call_validate(self, X), op_=self)
        setattr(register_class, method.__name__ + "_", m)
        for iname, method_name in enumerate(method_names):
            if hasattr(register_class, method_name):
                raise RuntimeError(  # pragma: no cover
                    "Cannot overwrite method %r because it already exists "
                    "in class %r." % (method_name, register_class))
            # index_output is bound as a default value, otherwise every
            # lambda would see the last value of iname (late binding).
            m = (lambda self, X, index_output=iname:
                 res(_call_validate(self, X), op_=self)[index_output])
            setattr(register_class, method_name, m)

    update_registered_converter_npy(
        register_class, "Sklearn%s" % getattr(
            register_class, "__name__", "noname"),
        res, shape_fct=None, overwrite=overwrite,
        options=options)
    return res

644 

645 

def onnxsklearn_class(method_name, op_version=None, runtime=None,
                      signature=None, method_names=None,
                      overwrite=True):
    """
    Decorator to declare a converter for a class derived from
    :epkg:`scikit-learn`, implementing an inference method
    with :epkg:`numpy` syntax but executed with
    :epkg:`ONNX` operators.

    :param method_name: name of the method implementing the
        inference with the :epkg:`numpy` API for ONNX
    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: if None, a standard signature is chosen based on
        the model kind, otherwise it is the signature of the
        ONNX function
    :param method_names: if None, method names are guessed from
        the class kind (transformer, regressor, classifier, clusterer)
    :param overwrite: overwrite existing registered function if any
    :return: a class decorator returning the decorated class itself

    .. versionadded:: 0.6
    """
    def decorator_class(objclass):
        numpy_method = getattr(objclass, method_name)
        _internal_method_decorator(
            objclass, method=numpy_method, signature=signature,
            op_version=op_version, runtime=runtime,
            method_names=method_names, overwrite=overwrite)
        return objclass

    return decorator_class