Coverage for mlprodict/onnx_conv/sklconv/svm_converters.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Rewrites some of the converters implemented in
4:epkg:`sklearn-onnx`.
5"""
6import numpy
7from skl2onnx.operator_converters.support_vector_machines import (
8 convert_sklearn_svm_regressor,
9 convert_sklearn_svm_classifier)
10from skl2onnx.common.data_types import guess_numpy_type
def _op_type_domain_regressor(dtype):
    """
    Selects the ONNX operator used to convert a SVM regressor.

    @param      dtype       numpy floating type of the input
    @return                 tuple *(op_type, op_domain, 1)*; the converters
                            of this module unpack the last value as the
                            operator version

    ``numpy.float32`` gives ``('SVMRegressor', 'ai.onnx.ml', 1)``,
    ``numpy.float64`` gives ``('SVMRegressorDouble', 'mlprodict', 1)``.
    Any other value raises a ``RuntimeError``.
    """
    # Candidates are compared with ``==`` in declaration order, a dictionary
    # lookup would not treat numpy.dtype instances and scalar types alike.
    candidates = (
        (numpy.float32, ('SVMRegressor', 'ai.onnx.ml', 1)),
        (numpy.float64, ('SVMRegressorDouble', 'mlprodict', 1)),
    )
    for known_type, signature in candidates:
        if dtype == known_type:
            return signature
    raise RuntimeError(  # pragma: no cover
        "Unsupported dtype {}.".format(dtype))
def _op_type_domain_classifier(dtype):
    """
    Selects the ONNX operator used to convert a SVM classifier.

    @param      dtype       numpy floating type of the input
    @return                 tuple *(op_type, op_domain, 1)*; the converters
                            of this module unpack the last value as the
                            operator version

    ``numpy.float32`` gives ``('SVMClassifier', 'ai.onnx.ml', 1)``,
    ``numpy.float64`` gives ``('SVMClassifierDouble', 'mlprodict', 1)``.
    Any other value raises a ``RuntimeError``.
    """
    # Candidates are compared with ``==`` in declaration order, a dictionary
    # lookup would not treat numpy.dtype instances and scalar types alike.
    candidates = (
        (numpy.float32, ('SVMClassifier', 'ai.onnx.ml', 1)),
        (numpy.float64, ('SVMClassifierDouble', 'mlprodict', 1)),
    )
    for known_type, signature in candidates:
        if dtype == known_type:
            return signature
    raise RuntimeError(  # pragma: no cover
        "Unsupported dtype {}.".format(dtype))
def new_convert_sklearn_svm_regressor(scope, operator, container):
    """
    Converter for SVM regressors which delegates to the
    :epkg:`sklearn-onnx` implementation while choosing an
    operator able to handle doubles when the input is double.

    @param      scope       forwarded unchanged to the original converter
    @param      operator    forwarded unchanged; the type of its first
                            input drives the choice of the operator
    @param      container   forwarded unchanged to the original converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than double is handled with the float32 operator.
    resolved = numpy.float32 if guessed != numpy.float64 else guessed
    kind, domain, version = _op_type_domain_regressor(resolved)
    convert_sklearn_svm_regressor(
        scope, operator, container,
        op_type=kind, op_domain=domain, op_version=version)
def new_convert_sklearn_svm_classifier(scope, operator, container):
    """
    Converter for SVM classifiers which delegates to the
    :epkg:`sklearn-onnx` implementation while choosing an
    operator able to handle doubles when the input is double.

    @param      scope       forwarded unchanged to the original converter
    @param      operator    forwarded unchanged; the type of its first
                            input drives the choice of the operator
    @param      container   forwarded unchanged to the original converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than double is handled with the float32 operator.
    resolved = numpy.float32 if guessed != numpy.float64 else guessed
    kind, domain, version = _op_type_domain_classifier(resolved)
    convert_sklearn_svm_classifier(
        scope, operator, container,
        op_type=kind, op_domain=domain, op_version=version)