Coverage for mlprodict/sklapi/onnx_speed_up.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

231 statements  

1# coding: utf-8 

2""" 

3@file 

4@brief Speeding up :epkg:`scikit-learn` with :epkg:`onnx`. 

5 

6.. versionadded:: 0.7 

7""" 

8import collections 

9import inspect 

10import io 

11from contextlib import redirect_stdout, redirect_stderr 

12import numpy 

13from numpy.testing import assert_almost_equal 

14import scipy.special as scipy_special 

15import scipy.spatial.distance as scipy_distance 

16from onnx import helper, load 

17from sklearn.base import ( 

18 BaseEstimator, clone, 

19 TransformerMixin, RegressorMixin, ClassifierMixin, 

20 ClusterMixin) 

21from sklearn.preprocessing import FunctionTransformer 

22from skl2onnx.algebra.onnx_operator_mixin import OnnxOperatorMixin 

23from ..tools.code_helper import print_code 

24from .. import __max_supported_opset__ 

25from ..onnx_tools.onnx_export import export2numpy 

26from ..onnx_tools.onnx2py_helper import ( 

27 onnx_model_opsets, _var_as_dict, to_skl2onnx_type) 

28from ..onnx_tools.exports.numpy_helper import ( 

29 array_feature_extrator, 

30 argmax_use_numpy_select_last_index, 

31 argmin_use_numpy_select_last_index, 

32 make_slice) 

33from ..onnx_tools.exports.skl2onnx_helper import add_onnx_graph 

34from ..onnx_conv import to_onnx 

35from .onnx_transformer import OnnxTransformer 

36 

37 

class _OnnxPipelineStepSpeedup(BaseEstimator, OnnxOperatorMixin):
    """
    Speeds up inference by replacing methods *transform* or
    *predict* by a runtime for :epkg:`ONNX`.

    :param estimator: estimator to train
    :param runtime: string, defines the runtime to use
        as described in @see cl OnnxInference; `'numpy'` and `'numba'`
        export the ONNX graph into python code instead
    :param enforce_float32: boolean
        :epkg:`onnxruntime` only supports *float32*,
        :epkg:`scikit-learn` usually uses double floats, this parameter
        ensures that every array of double floats is converted into
        single floats
    :param target_opset: targeted ONNX opset
    :param conv_options: options for conversions, see @see fn to_onnx
    :param nopython: used by :epkg:`numba` jitter

    Attributes created by method *fit*:

    * `estimator_`: cloned and trained version of *estimator*
    * `onnx_`: serialized ONNX graph (bytes) of the trained estimator
    * `onnxrt_`: object of type @see cl OnnxTransformer, or
      :epkg:`sklearn:preprocessing:FunctionTransformer` if the runtime
      is `'numpy'` or `'numba'`
    * `numpy_code_`: python code equivalent to the inference
      method if the runtime is `'numpy'` or `'numba'`
    * `onnx_io_names_`: dictionary, additional information
      if the runtime is `'numpy'` or `'numba'`

    .. versionadded:: 0.7
    """

    def __init__(self, estimator, runtime='python', enforce_float32=True,
                 target_opset=None, conv_options=None, nopython=True):
        BaseEstimator.__init__(self)
        self.estimator = estimator
        self.runtime = runtime
        self.enforce_float32 = enforce_float32
        self.target_opset = target_opset
        self.conv_options = conv_options
        self.nopython = nopython

    def _check_fitted_(self):
        """
        Checks method *fit* was called.

        :raises AttributeError: if attribute `onnxrt_` does not exist
        """
        if not hasattr(self, 'onnxrt_'):
            raise AttributeError(  # pragma: no cover
                "Object must be fit.")

    def _to_onnx(self, fitted_estimator, inputs):
        """
        Converts an estimator inference into :epkg:`ONNX`.

        :param fitted_estimator: trained estimator following
            :epkg:`scikit-learn` API
        :param inputs: example of inputs
        :return: ONNX model (result of @see fn to_onnx)
        """
        return to_onnx(
            fitted_estimator, inputs, target_opset=self.target_opset,
            options=self.conv_options)

    def _build_onnx_runtime(self, onx):
        """
        Returns an object which executes the ONNX graph.

        :param onx: serialized ONNX graph (bytes)
        :return: fitted instance of @see cl OnnxTransformer, or of
            :epkg:`sklearn:preprocessing:FunctionTransformer` when
            the runtime is `'numpy'` or `'numba'`
        """
        if self.runtime in ('numpy', 'numba'):
            return self._build_onnx_runtime_numpy(onx)
        tr = OnnxTransformer(
            onx, runtime=self.runtime,
            enforce_float32=self.enforce_float32)
        tr.fit()
        return tr

    def _build_onnx_runtime_numpy(self, onx):
        """
        Builds a runtime based on numpy.
        Exports the ONNX graph into python code
        based on numpy and then dynamically compiles
        it with method @see me _build_onnx_runtime_numpy_compile.

        Sets attributes `onnx_io_names_` and `numpy_code_`.

        :param onx: serialized ONNX graph (bytes)
        :return: instance of
            :epkg:`sklearn:preprocessing:FunctionTransformer`
        """
        model_onnx = load(io.BytesIO(onx))
        graph = model_onnx.graph  # pylint: disable=E1101
        self.onnx_io_names_ = {'inputs': [], 'outputs': []}
        for inp in graph.input:
            d = _var_as_dict(inp)
            self.onnx_io_names_['inputs'].append((d['name'], d['type']))
        for out in graph.output:
            d = _var_as_dict(out)
            self.onnx_io_names_['outputs'].append((d['name'], d['type']))
        # Same information expressed as skl2onnx types:
        # 'skl2onnx_outputs' is read by onnx_parser and
        # onnx_shape_calculator.
        for key in ('inputs', 'outputs'):
            self.onnx_io_names_['skl2onnx_' + key] = [
                to_skl2onnx_type(name, typ['elem'], typ['shape'])
                for name, typ in self.onnx_io_names_[key]]
        self.numpy_code_ = export2numpy(model_onnx, rename=True)
        opsets = onnx_model_opsets(model_onnx)
        return self._build_onnx_runtime_numpy_compile(opsets)

    def _build_onnx_runtime_numpy_compile(self, opsets):
        """
        Second part of @see me _build_onnx_runtime_numpy.
        Compiles and executes `numpy_code_`, retrieves the only
        function whose name starts with `numpy_`, jits it with
        :epkg:`numba` if the runtime is `'numba'`, and wraps it into a
        :epkg:`sklearn:preprocessing:FunctionTransformer`.

        :param opsets: opsets of the model (presumably a dictionary
            domain -> version, key `''` is the main domain)
        :return: FunctionTransformer with an extra attribute *op_version*
        :raises AssertionError: if the code cannot be compiled or executed
        :raises RuntimeError: if the code does not define exactly one
            function named `numpy_*`
        """
        try:
            compiled_code = compile(
                self.numpy_code_, '<string>', 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise AssertionError(
                "Unable to compile a script due to %r. "
                "\n--CODE--\n%s"
                "" % (e, print_code(self.numpy_code_))) from e

        # Functions defined by the generated code use a copy of this
        # module's globals (numpy, scipy, helpers...); the definitions
        # themselves land in `loc`.
        glo = globals().copy()
        loc = {
            'numpy': numpy, 'dict': dict, 'list': list,
            'print': print, 'sorted': sorted,
            'collections': collections, 'inspect': inspect,
            'helper': helper, 'scipy_special': scipy_special,
            'scipy_distance': scipy_distance,
            'array_feature_extrator': array_feature_extrator,
            'argmin_use_numpy_select_last_index':
                argmin_use_numpy_select_last_index,
            'argmax_use_numpy_select_last_index':
                argmax_use_numpy_select_last_index,
            'make_slice': make_slice}
        out = io.StringIO()
        err = io.StringIO()
        with redirect_stdout(out), redirect_stderr(err):
            try:
                exec(compiled_code, glo, loc)  # pylint: disable=W0122
            except Exception as e:  # pragma: no cover
                raise AssertionError(
                    "Unable to execute a script due to %r. "
                    "\n--OUT--\n%s\n--ERR--\n%s\n--CODE--\n%s"
                    "" % (e, out.getvalue(), err.getvalue(),
                          print_code(self.numpy_code_))) from e
        names = sorted(k for k in loc if k.startswith('numpy_'))
        if len(names) != 1:
            raise RuntimeError(  # pragma: no cover
                "Unable to guess which function is the one, names=%r."
                "" % (names, ))
        fct = loc[names[0]]
        if self.runtime == 'numba':
            from numba import jit
            jitter = jit(nopython=self.nopython)
            fct = jitter(fct)
        cl = FunctionTransformer(fct, accept_sparse=True)
        # Read back through property op_version.
        cl.op_version = opsets.get('', __max_supported_opset__)
        return cl

    def __getstate__(self):
        """
        :epkg:`pickle` does not support functions.
        This method removes any link to function
        when the runtime is `'numpy'` or `'numba'`;
        @see me __setstate__ rebuilds it from `numpy_code_`.
        """
        # On recent Python/scikit-learn versions, BaseEstimator.__getstate__
        # may return the instance __dict__ itself: copy it so that
        # removing 'onnxrt_' does not alter the live object.
        state = dict(BaseEstimator.__getstate__(self))
        if 'numpy_code_' in state:
            del state['onnxrt_']
        return state

    def __setstate__(self, state):
        """
        :epkg:`pickle` does not support functions.
        This method restores the function created when
        the runtime is `'numpy'` or `'numba'`.
        """
        BaseEstimator.__setstate__(self, state)
        if 'numpy_code_' in state:
            model_onnx = load(io.BytesIO(state['onnx_']))
            opsets = onnx_model_opsets(model_onnx)
            self.onnxrt_ = self._build_onnx_runtime_numpy_compile(opsets)

    def fit(self, X, y=None, sample_weight=None, **kwargs):
        """
        Fits the estimator, converts to ONNX.

        :param X: features
        :param y: targets, omitted from the call to the estimator
            if None
        :param sample_weight: sample weights, forwarded only if not None
        :param kwargs: other fitting options given to the estimator
        :return: self
        """
        # NOTE(review): a second call to fit reuses the existing
        # `estimator_` instead of cloning `estimator` again -- confirm
        # this is intended.
        if not hasattr(self, 'estimator_'):
            self.estimator_ = clone(self.estimator)
        fit_params = dict(kwargs)
        if sample_weight is not None:
            fit_params['sample_weight'] = sample_weight
        if y is None:
            self.estimator_.fit(X, **fit_params)
        else:
            self.estimator_.fit(X, y, **fit_params)

        if self.enforce_float32:
            X = X.astype(numpy.float32)
        self.onnx_ = self._to_onnx(self.estimator_, X).SerializeToString()
        self.onnxrt_ = self._build_onnx_runtime(self.onnx_)
        return self

    @property
    def op_version(self):
        """
        Returns the opset version of the runtime `onnxrt_`.

        :raises AttributeError: if the model is not fitted
        """
        self._check_fitted_()
        return self.onnxrt_.op_version

    def onnx_parser(self):
        """
        Returns a parser for this model.
        """
        self._check_fitted_()
        if isinstance(self.onnxrt_, FunctionTransformer):
            def parser():
                # Types should be included as well.
                return [r[0] for r in self.onnx_io_names_['skl2onnx_outputs']]
            return parser
        return self.onnxrt_.onnx_parser()

    def onnx_shape_calculator(self):
        """
        Returns a shape calculator for this transform.
        """
        self._check_fitted_()

        if isinstance(self.onnxrt_, FunctionTransformer):
            def fct_shape_calculator(operator):
                # Types should be included as well.
                outputs = self.onnx_io_names_['skl2onnx_outputs']
                if len(operator.outputs) != len(outputs):
                    raise RuntimeError(  # pragma: no cover
                        "Mismatch between parser and shape calculator, "
                        "%r != %r." % (outputs, operator.outputs))
                # Each element of `outputs` is indexed as (name, type).
                for a, b in zip(operator.outputs, outputs):
                    a.type = b[1]
            return fct_shape_calculator

        calc = self.onnxrt_.onnx_shape_calculator()

        def shape_calculator(operator):
            return calc(operator)

        return shape_calculator

    def onnx_converter(self):
        """
        Returns a converter for this transform.
        """
        self._check_fitted_()

        if isinstance(self.onnxrt_, FunctionTransformer):

            def fct_converter(scope, operator, container):
                # Inserts the stored serialized graph `onnx_`.
                op = operator.raw_operator
                onnx_model = load(io.BytesIO(op.onnx_))
                add_onnx_graph(scope, operator, container, onnx_model)

            return fct_converter

        conv = self.onnxrt_.onnx_converter()

        def converter(scope, operator, container):
            op = operator.raw_operator
            # presumably the ONNX model held by the OnnxInference
            # inside the OnnxTransformer -- TODO confirm
            onnx_model = op.onnxrt_.onnxrt_.obj
            conv(scope, operator, container, onnx_model=onnx_model)

        return converter

308 

309 

class OnnxSpeedupTransformer(TransformerMixin,
                             _OnnxPipelineStepSpeedup):
    """
    Trains with :epkg:`scikit-learn`, transforms with :epkg:`ONNX`.

    :param estimator: estimator to train
    :param enforce_float32: boolean
        :epkg:`onnxruntime` only supports *float32*,
        :epkg:`scikit-learn` usually uses double floats, this parameter
        ensures that every array of double floats is converted into
        single floats
    :param runtime: string, defines the runtime to use
        as described in @see cl OnnxInference.
    :param target_opset: targeted ONNX opset
    :param conv_options: conversion options, see @see fn to_onnx
    :param nopython: used by :epkg:`numba` jitter

    Attributes created by method *fit*:

    * `estimator_`: cloned and trained version of *estimator*
    * `onnxrt_`: object executing the ONNX graph, @see cl OnnxTransformer
      or :epkg:`sklearn:preprocessing:FunctionTransformer`
    * `numpy_code_`: python code equivalent to the inference
      method if the runtime is `'numpy'` or `'numba'`
    * `onnx_io_names_`: dictionary, additional information
      if the runtime is `'numpy'` or `'numba'`

    .. versionadded:: 0.7
    """

    def __init__(self, estimator, runtime='python', enforce_float32=True,
                 target_opset=None, conv_options=None, nopython=True):
        _OnnxPipelineStepSpeedup.__init__(
            self, estimator,
            runtime=runtime,
            enforce_float32=enforce_float32,
            target_opset=target_opset,
            conv_options=conv_options,
            nopython=nopython)

    def fit(self, X, y=None, sample_weight=None):  # pylint: disable=W0221
        """
        Trains the wrapped estimator and builds its ONNX runtime.

        :param X: features
        :param y: optional targets
        :param sample_weight: optional weights, forwarded only when given
        :return: self
        """
        extra = {}
        if sample_weight is not None:
            extra['sample_weight'] = sample_weight
        _OnnxPipelineStepSpeedup.fit(self, X, y, **extra)
        return self

    def transform(self, X):
        """
        Computes the transformation with the *ONNX* runtime.

        :param X: features
        :return: transformed features
        """
        runtime = self.onnxrt_
        return runtime.transform(X)

    def raw_transform(self, X):
        """
        Computes the transformation with the fitted *scikit-learn* model.

        :param X: features
        :return: transformed features
        """
        fitted = self.estimator_
        return fitted.transform(X)

    def assert_almost_equal(self, X, **kwargs):
        """
        Asserts *ONNX* and *scikit-learn* produce nearly equal
        transformed features.

        :param X: features
        :param kwargs: forwarded to :epkg:`numpy` `assert_almost_equal`
        """
        reference = self.raw_transform(X)
        candidate = self.transform(X)
        assert_almost_equal(reference, candidate, **kwargs)

384 

385 

class OnnxSpeedupRegressor(RegressorMixin,
                           _OnnxPipelineStepSpeedup):
    """
    Trains with :epkg:`scikit-learn`, predicts with :epkg:`ONNX`.

    :param estimator: estimator to train
    :param enforce_float32: boolean
        :epkg:`onnxruntime` only supports *float32*,
        :epkg:`scikit-learn` usually uses double floats, this parameter
        ensures that every array of double floats is converted into
        single floats
    :param runtime: string, defines the runtime to use
        as described in @see cl OnnxInference.
    :param target_opset: targeted ONNX opset
    :param conv_options: conversion options, see @see fn to_onnx
    :param nopython: used by :epkg:`numba` jitter

    Attributes created by method *fit*:

    * `estimator_`: cloned and trained version of *estimator*
    * `onnxrt_`: object executing the ONNX graph, @see cl OnnxTransformer
      or :epkg:`sklearn:preprocessing:FunctionTransformer`
    * `numpy_code_`: python code equivalent to the inference
      method if the runtime is `'numpy'` or `'numba'`
    * `onnx_io_names_`: dictionary, additional information
      if the runtime is `'numpy'` or `'numba'`

    .. versionadded:: 0.7
    """

    def __init__(self, estimator, runtime='python', enforce_float32=True,
                 target_opset=None, conv_options=None, nopython=True):
        _OnnxPipelineStepSpeedup.__init__(
            self, estimator,
            runtime=runtime,
            enforce_float32=enforce_float32,
            target_opset=target_opset,
            conv_options=conv_options,
            nopython=nopython)

    def fit(self, X, y, sample_weight=None):  # pylint: disable=W0221
        """
        Trains the wrapped regressor and builds its ONNX runtime.

        :param X: features
        :param y: targets
        :param sample_weight: optional weights, forwarded only when given
        :return: self
        """
        extra = {}
        if sample_weight is not None:
            extra['sample_weight'] = sample_weight
        _OnnxPipelineStepSpeedup.fit(self, X, y, **extra)
        return self

    def predict(self, X):
        """
        Predicts with the *ONNX* runtime.

        :param X: features
        :return: predictions
        """
        runtime = self.onnxrt_
        return runtime.transform(X)

    def raw_predict(self, X):
        """
        Predicts with the fitted *scikit-learn* regressor.

        :param X: features
        :return: predictions
        """
        fitted = self.estimator_
        return fitted.predict(X)

    def assert_almost_equal(self, X, **kwargs):
        """
        Asserts *ONNX* and *scikit-learn* predictions are nearly equal,
        after removing dimensions of size one on both sides.

        :param X: features
        :param kwargs: forwarded to :epkg:`numpy` `assert_almost_equal`
        """
        reference, candidate = (
            numpy.squeeze(self.raw_predict(X)),
            numpy.squeeze(self.predict(X)))
        assert_almost_equal(reference, candidate, **kwargs)

460 

461 

class OnnxSpeedupClassifier(ClassifierMixin,
                            _OnnxPipelineStepSpeedup):
    """
    Trains with :epkg:`scikit-learn`, predicts with :epkg:`ONNX`.

    :param estimator: estimator to train
    :param enforce_float32: boolean
        :epkg:`onnxruntime` only supports *float32*,
        :epkg:`scikit-learn` usually uses double floats, this parameter
        ensures that every array of double floats is converted into
        single floats
    :param runtime: string, defines the runtime to use
        as described in @see cl OnnxInference.
    :param target_opset: targeted ONNX opset
    :param conv_options: conversion options, see @see fn to_onnx,
        `{'zipmap': False}` is used when it is None
    :param nopython: used by :epkg:`numba` jitter

    Attributes created by method *fit*:

    * `estimator_`: cloned and trained version of *estimator*
    * `onnxrt_`: object executing the ONNX graph, @see cl OnnxTransformer
      or :epkg:`sklearn:preprocessing:FunctionTransformer`
    * `numpy_code_`: python code equivalent to the inference
      method if the runtime is `'numpy'` or `'numba'`
    * `onnx_io_names_`: dictionary, additional information
      if the runtime is `'numpy'` or `'numba'`

    .. versionadded:: 0.7
    """

    def __init__(self, estimator, runtime='python', enforce_float32=True,
                 target_opset=None, conv_options=None, nopython=True):
        # A fresh dictionary for each instance.
        options = {'zipmap': False} if conv_options is None else conv_options
        _OnnxPipelineStepSpeedup.__init__(
            self, estimator,
            runtime=runtime,
            enforce_float32=enforce_float32,
            target_opset=target_opset,
            conv_options=options,
            nopython=nopython)

    def fit(self, X, y, sample_weight=None):  # pylint: disable=W0221
        """
        Trains the wrapped classifier and builds its ONNX runtime.

        :param X: features
        :param y: labels
        :param sample_weight: optional weights, forwarded only when given
        :return: self
        """
        extra = {}
        if sample_weight is not None:
            extra['sample_weight'] = sample_weight
        _OnnxPipelineStepSpeedup.fit(self, X, y, **extra)
        return self

    def predict(self, X):
        """
        Predicts labels with the *ONNX* runtime: first element of
        a tuple output, first column of a dataframe output otherwise.

        :param X: features
        :return: predicted labels
        """
        outputs = self.onnxrt_.transform(X)
        return outputs[0] if isinstance(outputs, tuple) else \
            outputs.iloc[:, 0].values

    def predict_proba(self, X):
        """
        Predicts probabilities with the *ONNX* runtime: second element
        of a tuple output, every column but the first of a dataframe
        output otherwise.

        :param X: features
        :return: predicted probabilities
        """
        outputs = self.onnxrt_.transform(X)
        return outputs[1] if isinstance(outputs, tuple) else \
            outputs.iloc[:, 1:].values

    def raw_predict(self, X):
        """
        Predicts labels with the fitted *scikit-learn* classifier.

        :param X: features
        :return: predicted labels
        """
        fitted = self.estimator_
        return fitted.predict(X)

    def raw_predict_proba(self, X):
        """
        Predicts probabilities with the fitted *scikit-learn* classifier.

        :param X: features
        :return: predicted probabilities
        """
        fitted = self.estimator_
        return fitted.predict_proba(X)

    def assert_almost_equal(self, X, **kwargs):
        """
        Asserts *ONNX* and *scikit-learn* produce nearly equal
        probabilities, then nearly equal labels (both squeezed).

        :param X: features
        :param kwargs: forwarded to :epkg:`numpy` `assert_almost_equal`
        """
        pairs = ((self.raw_predict_proba, self.predict_proba),
                 (self.raw_predict, self.predict))
        for reference_fct, onnx_fct in pairs:
            reference = numpy.squeeze(reference_fct(X))
            candidate = numpy.squeeze(onnx_fct(X))
            assert_almost_equal(reference, candidate, **kwargs)

562 expected = numpy.squeeze(self.raw_predict(X)) 

563 got = numpy.squeeze(self.predict(X)) 

564 assert_almost_equal(expected, got, **kwargs) 

565 

566 

class OnnxSpeedupCluster(ClusterMixin,
                         _OnnxPipelineStepSpeedup):
    """
    Trains with :epkg:`scikit-learn`, predicts and transforms
    with :epkg:`ONNX`.

    :param estimator: estimator to train
    :param enforce_float32: boolean
        :epkg:`onnxruntime` only supports *float32*,
        :epkg:`scikit-learn` usually uses double floats, this parameter
        ensures that every array of double floats is converted into
        single floats
    :param runtime: string, defines the runtime to use
        as described in @see cl OnnxInference.
    :param target_opset: targeted ONNX opset
    :param conv_options: conversion options, see @see fn to_onnx
    :param nopython: used by :epkg:`numba` jitter

    Attributes created by method *fit*:

    * `estimator_`: cloned and trained version of *estimator*
    * `onnxrt_`: object executing the ONNX graph, @see cl OnnxTransformer
      or :epkg:`sklearn:preprocessing:FunctionTransformer`
    * `numpy_code_`: python code equivalent to the inference
      method if the runtime is `'numpy'` or `'numba'`
    * `onnx_io_names_`: dictionary, additional information
      if the runtime is `'numpy'` or `'numba'`

    .. versionadded:: 0.7
    """

    def __init__(self, estimator, runtime='python', enforce_float32=True,
                 target_opset=None, conv_options=None, nopython=True):
        _OnnxPipelineStepSpeedup.__init__(
            self, estimator, runtime=runtime, enforce_float32=enforce_float32,
            target_opset=target_opset, conv_options=conv_options,
            nopython=nopython)

    def fit(self, X, y=None, sample_weight=None):  # pylint: disable=W0221
        """
        Trains the wrapped estimator and builds its ONNX runtime.

        :param X: features
        :param y: optional targets, clustering usually ignores them
            (the wrapped estimator is called without *y* when None)
        :param sample_weight: optional weights, forwarded only when given
        :return: self
        """
        if sample_weight is None:
            _OnnxPipelineStepSpeedup.fit(self, X, y)
        else:
            _OnnxPipelineStepSpeedup.fit(
                self, X, y, sample_weight=sample_weight)
        return self

    def predict(self, X):
        """
        Predicts with *ONNX*: first element of a tuple output,
        first column of a dataframe output otherwise.

        :param X: features
        :return: predictions (presumably cluster labels)
        """
        pred = self.onnxrt_.transform(X)
        if isinstance(pred, tuple):
            return pred[0]
        return pred.iloc[:, 0].values

    def transform(self, X):
        """
        Transforms with *ONNX*: second element of a tuple output,
        every column but the first of a dataframe output otherwise.

        :param X: features
        :return: transformed features (presumably distances to the
            clusters -- TODO confirm)
        """
        pred = self.onnxrt_.transform(X)
        if isinstance(pred, tuple):
            return pred[1]
        return pred.iloc[:, 1:].values

    def raw_predict(self, X):
        """
        Predicts with *scikit-learn*.

        :param X: features
        :return: predictions
        """
        return self.estimator_.predict(X)

    def raw_transform(self, X):
        """
        Transforms with *scikit-learn*.

        :param X: features
        :return: transformed features
        """
        return self.estimator_.transform(X)

    def assert_almost_equal(self, X, **kwargs):
        """
        Checks that ONNX and scikit-learn produce the same
        outputs for *transform* and then for *predict*
        (both sides are squeezed).

        :param X: features
        :param kwargs: forwarded to :epkg:`numpy` `assert_almost_equal`
        """
        expected = numpy.squeeze(self.raw_transform(X))
        got = numpy.squeeze(self.transform(X))
        assert_almost_equal(expected, got, **kwargs)
        expected = numpy.squeeze(self.raw_predict(X))
        got = numpy.squeeze(self.predict(X))
        assert_almost_equal(expected, got, **kwargs)