Coverage for mlprodict/onnxrt/ops_cpu/op_svm_regressor.py: 98%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from collections import OrderedDict
8import numpy
9from ._op_helper import _get_typed_class_attribute
10from ._op import OpRunUnaryNum, RuntimeTypeError
11from ._new_ops import OperatorSchema
12from .op_svm_regressor_ import ( # pylint: disable=E0611,E0401
13 RuntimeSVMRegressorFloat,
14 RuntimeSVMRegressorDouble,
15)
class SVMRegressorCommon(OpRunUnaryNum):
    """
    Shared implementation of the SVM regressor runtime operators.
    A subclass picks the floating point type handed to the constructor
    and declares the attributes it expects in a class member ``atts``.
    """

    def __init__(self, dtype, onnx_node, desc=None,
                 expected_attributes=None, **options):
        """
        @param      dtype                   ``numpy.float32`` or ``numpy.float64``,
                                            selects the compiled runtime
        @param      onnx_node               forwarded to @see cl OpRunUnaryNum
        @param      desc                    forwarded to @see cl OpRunUnaryNum
        @param      expected_attributes     forwarded to @see cl OpRunUnaryNum
        @param      options                 forwarded to @see cl OpRunUnaryNum
        """
        OpRunUnaryNum.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)
        # The compiled runtime can only be built once the parent
        # constructor has run.
        self._init(dtype=dtype)

    def _get_typed_attributes(self, k):
        """
        Returns attribute *k* of this operator through
        ``_get_typed_class_attribute``, using the ``atts`` table
        of the instance's own class.
        """
        declared = self.__class__.atts
        return _get_typed_class_attribute(self, k, declared)

    def _find_custom_operator_schema(self, op_name):
        """
        Finds a custom operator defined by this runtime.
        Only ``'SVMRegressorDouble'`` is known, any other name
        raises a *RuntimeError*.
        """
        if op_name != "SVMRegressorDouble":
            raise RuntimeError(  # pragma: no cover
                "Unable to find a schema for operator '{}'.".format(op_name))
        return SVMRegressorDoubleSchema()

    def _init(self, dtype):
        """
        Creates the compiled runtime matching *dtype*, stores it in
        ``self.rt_`` and initialises it with the operator attributes.
        Raises *RuntimeTypeError* for any dtype other than
        ``numpy.float32`` and ``numpy.float64``.
        """
        if dtype == numpy.float32:
            runtime_class = RuntimeSVMRegressorFloat
        elif dtype == numpy.float64:
            runtime_class = RuntimeSVMRegressorDouble
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "Unsupported dtype={}.".format(dtype))
        # NOTE(review): the meaning of 50 is defined by the compiled
        # extension, it is not visible from this file.
        self.rt_ = runtime_class(50)

        # Attribute names and their order always come from
        # SVMRegressor.atts, whatever the subclass; the values are
        # passed positionally to the compiled ``init``.
        values = []
        for name in SVMRegressor.atts:
            values.append(self._get_typed_attributes(name))
        self.rt_.init(*values)

    def _run(self, x):  # pylint: disable=W0221
        """
        Computes the predictions with the C++ implementation
        taken from :epkg:`onnxruntime`, see
        `svm_regressor.cc
        <https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/svm_regressor.cc>`_
        and class :class:`RuntimeSVMRegressor
        <mlprodict.onnxrt.ops_cpu.op_svm_regressor_.RuntimeSVMRegressor>`.

        @param      x       input array
        @return             tuple with one element, the predictions
        """
        pred = self.rt_.compute(x)
        n_rows = x.shape[0]
        n_values = pred.shape[0]
        if n_values != n_rows:
            # The runtime returned more values than rows:
            # fold them back into one row per observation.
            pred = pred.reshape(n_rows, n_values // n_rows)
        return (pred, )
class SVMRegressor(SVMRegressorCommon):
    """
    SVM regressor runtime operator working with ``numpy.float32``.
    """

    # Expected attributes and their default values. The insertion
    # order matters: SVMRegressorCommon._init walks this table to
    # build the positional arguments of the compiled runtime.
    atts = OrderedDict()
    atts['coefficients'] = numpy.empty((0,), dtype=numpy.float32)
    atts['kernel_params'] = numpy.empty((0,), dtype=numpy.float32)
    atts['kernel_type'] = b'NONE'
    atts['n_supports'] = 0
    atts['one_class'] = 0
    atts['post_transform'] = b'NONE'
    atts['rho'] = numpy.empty((0,), dtype=numpy.float32)
    atts['support_vectors'] = numpy.empty((0,), dtype=numpy.float32)

    def __init__(self, onnx_node, desc=None, **options):
        """
        @param      onnx_node       forwarded to @see cl SVMRegressorCommon
        @param      desc            forwarded to @see cl SVMRegressorCommon
        @param      options         forwarded to @see cl SVMRegressorCommon
        """
        expected = SVMRegressor.atts
        SVMRegressorCommon.__init__(
            self, numpy.float32, onnx_node,
            desc=desc, expected_attributes=expected, **options)
class SVMRegressorDouble(SVMRegressorCommon):
    """
    SVM regressor runtime operator working with ``numpy.float64``.
    """

    # Expected attributes and their default values, the array
    # defaults are empty float64 arrays.
    atts = OrderedDict()
    atts['coefficients'] = numpy.empty((0,), dtype=numpy.float64)
    atts['kernel_params'] = numpy.empty((0,), dtype=numpy.float64)
    atts['kernel_type'] = b'NONE'
    atts['n_supports'] = 0
    atts['one_class'] = 0
    atts['post_transform'] = b'NONE'
    atts['rho'] = numpy.empty((0,), dtype=numpy.float64)
    atts['support_vectors'] = numpy.empty((0,), dtype=numpy.float64)

    def __init__(self, onnx_node, desc=None, **options):
        """
        @param      onnx_node       forwarded to @see cl SVMRegressorCommon
        @param      desc            forwarded to @see cl SVMRegressorCommon
        @param      options         forwarded to @see cl SVMRegressorCommon
        """
        expected = SVMRegressorDouble.atts
        SVMRegressorCommon.__init__(
            self, numpy.float64, onnx_node,
            desc=desc, expected_attributes=expected, **options)
class SVMRegressorDoubleSchema(OperatorSchema):
    """
    Schema of the custom operator *SVMRegressorDouble*,
    one of the operators added in this package,
    see @see cl SVMRegressorDouble.
    """

    def __init__(self):
        schema_name = 'SVMRegressorDouble'
        OperatorSchema.__init__(self, schema_name)
        # The schema refers to the attribute table of the operator
        # class itself, no copy is made.
        self.attributes = SVMRegressorDouble.atts