Coverage for mlprodict/onnxrt/ops_cpu/op_svm_classifier.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from collections import OrderedDict
8import numpy
9from ._op_helper import _get_typed_class_attribute
10from ._op import OpRunClassifierProb, RuntimeTypeError
11from ._op_classifier_string import _ClassifierCommon
12from ._new_ops import OperatorSchema
13from .op_svm_classifier_ import ( # pylint: disable=E0611,E0401
14 RuntimeSVMClassifierFloat,
15 RuntimeSVMClassifierDouble,
16)
class SVMClassifierCommon(OpRunClassifierProb, _ClassifierCommon):
    """
    Shared implementation of the *SVMClassifier* runtimes.

    Subclasses define a class attribute ``atts`` (ordered mapping from
    attribute name to default value) and pass the numpy dtype used for the
    computation, ``numpy.float32`` or ``numpy.float64``. The prediction is
    delegated to a C++ runtime (``RuntimeSVMClassifierFloat`` or
    ``RuntimeSVMClassifierDouble``).
    """

    def __init__(self, dtype, onnx_node, desc=None,
                 expected_attributes=None, **options):
        """
        :param dtype: ``numpy.float32`` or ``numpy.float64``, selects the
            C++ runtime
        :param onnx_node: the ONNX node, forwarded to
            :class:`OpRunClassifierProb`
        :param desc: forwarded to :class:`OpRunClassifierProb`
        :param expected_attributes: forwarded to
            :class:`OpRunClassifierProb`
        :param options: forwarded to :class:`OpRunClassifierProb`
        """
        OpRunClassifierProb.__init__(self, onnx_node, desc=desc,
                                     expected_attributes=expected_attributes,
                                     **options)
        self._init(dtype=dtype)

    def _get_typed_attributes(self, k):
        """
        Returns attribute *k* typed according to the default value
        declared in the subclass's ``atts`` table.
        """
        return _get_typed_class_attribute(self, k, self.__class__.atts)

    def _find_custom_operator_schema(self, op_name):
        """
        Finds a custom operator defined by this runtime.

        :param op_name: operator name
        :return: an instance of :class:`SVMClassifierDoubleSchema` for
            ``'SVMClassifierDouble'``
        :raises RuntimeError: for any other name
        """
        if op_name == "SVMClassifierDouble":
            return SVMClassifierDoubleSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    def _init(self, dtype):
        """
        Post-processes the label attributes, creates the C++ runtime
        matching *dtype* and initializes it with the node attributes.

        :param dtype: ``numpy.float32`` or ``numpy.float64``
        :raises RuntimeTypeError: for any other dtype
        """
        self._post_process_label_attributes()
        # NOTE(review): the meaning of the constructor argument 20 is
        # defined by the C++ runtime and is not visible from here.
        if dtype == numpy.float32:
            self.rt_ = RuntimeSVMClassifierFloat(20)
        elif dtype == numpy.float64:
            self.rt_ = RuntimeSVMClassifierDouble(20)
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "Unsupported dtype={}.".format(dtype))
        # The C++ ``init`` receives the attributes positionally. Iterate
        # the concrete subclass's own table, the same one that
        # ``_get_typed_attributes`` types against, instead of hard-coding
        # the float32 subclass.
        atts = [self._get_typed_attributes(k) for k in self.__class__.atts]
        self.rt_.init(*atts)

    def _run(self, x):  # pylint: disable=W0221
        """
        This is a C++ implementation coming from
        :epkg:`onnxruntime`.
        `svm_classifier.cc
        <https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/svm_classifier.cc>`_.
        See class :class:`RuntimeSVMClassifier
        <mlprodict.onnxrt.ops_cpu.op_svm_classifier_.RuntimeSVMClassifier>`.
        """
        label, scores = self.rt_.compute(x)
        if scores.shape[0] != label.shape[0]:
            # Scores came back with a different leading size than the
            # labels: reshape them to one row per predicted label.
            scores = scores.reshape(label.shape[0],
                                    scores.shape[0] // label.shape[0])
        return self._post_process_predicted_label(label, scores)
class SVMClassifier(SVMClassifierCommon):
    """
    Runtime for the ONNX *SVMClassifier* operator, computing in float32.

    ``atts`` lists every supported attribute with its default value. The
    order matters: the attributes are handed to the C++ runtime in this
    order.
    """

    atts = OrderedDict((
        ('classlabels_ints', numpy.empty(0, dtype=numpy.int64)),
        ('classlabels_strings', []),
        ('coefficients', numpy.empty(0, dtype=numpy.float32)),
        ('kernel_params', numpy.empty(0, dtype=numpy.float32)),
        ('kernel_type', b'NONE'),
        ('post_transform', b'NONE'),
        ('prob_a', numpy.empty(0, dtype=numpy.float32)),
        ('prob_b', numpy.empty(0, dtype=numpy.float32)),
        ('rho', numpy.empty(0, dtype=numpy.float32)),
        ('support_vectors', numpy.empty(0, dtype=numpy.float32)),
        # NOTE(review): the ONNX spec declares this attribute as INTS;
        # the float32 default is kept as-is pending confirmation.
        ('vectors_per_class', numpy.empty(0, dtype=numpy.float32)),
    ))

    def __init__(self, onnx_node, desc=None, **options):
        expected = SVMClassifier.atts
        SVMClassifierCommon.__init__(
            self, numpy.float32, onnx_node,
            expected_attributes=expected, desc=desc, **options)
class SVMClassifierDouble(SVMClassifierCommon):
    """
    Custom variant of *SVMClassifier* computing in float64.

    Same attributes as the float32 runtime, with every real-valued array
    default declared as float64. The order of ``atts`` is the order in
    which the attributes are handed to the C++ runtime.
    """

    atts = OrderedDict((
        ('classlabels_ints', numpy.empty(0, dtype=numpy.int64)),
        ('classlabels_strings', []),
        ('coefficients', numpy.empty(0, dtype=numpy.float64)),
        ('kernel_params', numpy.empty(0, dtype=numpy.float64)),
        ('kernel_type', b'NONE'),
        ('post_transform', b'NONE'),
        ('prob_a', numpy.empty(0, dtype=numpy.float64)),
        ('prob_b', numpy.empty(0, dtype=numpy.float64)),
        ('rho', numpy.empty(0, dtype=numpy.float64)),
        ('support_vectors', numpy.empty(0, dtype=numpy.float64)),
        # NOTE(review): the ONNX spec declares this attribute as INTS;
        # the float64 default is kept as-is pending confirmation.
        ('vectors_per_class', numpy.empty(0, dtype=numpy.float64)),
    ))

    def __init__(self, onnx_node, desc=None, **options):
        expected = SVMClassifierDouble.atts
        SVMClassifierCommon.__init__(
            self, numpy.float64, onnx_node,
            expected_attributes=expected, desc=desc, **options)
class SVMClassifierDoubleSchema(OperatorSchema):
    """
    Schema of the custom operator *SVMClassifierDouble* added by this
    package (see @see cl SVMClassifierDouble). Its attributes are those
    declared in ``SVMClassifierDouble.atts``.
    """

    def __init__(self):
        op_name = 'SVMClassifierDouble'
        OperatorSchema.__init__(self, op_name)
        # Shares (does not copy) the runtime class's attribute table.
        self.attributes = SVMClassifierDouble.atts