Coverage for mlprodict/onnxrt/ops_cpu/op_tree_ensemble_classifier.py: 94%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from collections import OrderedDict
8import numpy
9from onnx.defs import onnx_opset_version
10from ._op_helper import _get_typed_class_attribute
11from ._op import OpRunClassifierProb, RuntimeTypeError
12from ._op_classifier_string import _ClassifierCommon
13from ._new_ops import OperatorSchema
14from .op_tree_ensemble_classifier_ import ( # pylint: disable=E0611,E0401
15 RuntimeTreeEnsembleClassifierDouble,
16 RuntimeTreeEnsembleClassifierFloat)
17from .op_tree_ensemble_classifier_p_ import ( # pylint: disable=E0611,E0401
18 RuntimeTreeEnsembleClassifierPFloat,
19 RuntimeTreeEnsembleClassifierPDouble)
class TreeEnsembleClassifierCommon(OpRunClassifierProb, _ClassifierCommon):
    """
    Common implementation shared by the tree ensemble classifiers.

    Every subclass declares a class attribute ``atts``, an ordered mapping
    from attribute name to default value, and passes a floating point type
    (or *None*) to the constructor. The attribute values are sent, in the
    order of ``atts``, to a compiled runtime stored in ``self.rt_`` which
    computes the labels and scores.
    """

    def __init__(self, dtype, onnx_node, desc=None,
                 expected_attributes=None,
                 runtime_version=3, **options):
        """
        :param dtype: ``numpy.float32``, ``numpy.float64`` or *None*
            (then inferred from the ``*_as_tensor`` attributes, see `_init`)
        :param onnx_node: node forwarded to the base class
        :param desc: description forwarded to the base class
        :param expected_attributes: attribute defaults forwarded to the
            base class
        :param runtime_version: selects the compiled runtime, see `_init`
        :param options: additional options forwarded to the base class
        """
        OpRunClassifierProb.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)
        self._init(dtype=dtype, version=runtime_version)

    def _get_typed_attributes(self, k):
        """
        Returns the value of attribute *k* as produced by
        ``_get_typed_class_attribute`` given the class defaults ``atts``.
        """
        return _get_typed_class_attribute(self, k, self.__class__.atts)

    def _find_custom_operator_schema(self, op_name):
        """
        Finds a custom operator defined by this runtime.

        :param op_name: operator name
        :return: schema instance
        :raises RuntimeError: if *op_name* is not a custom operator
            of this module
        """
        if op_name == "TreeEnsembleClassifierDouble":
            return TreeEnsembleClassifierDoubleSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    def _init(self, dtype, version):
        """
        Collects the typed attributes and builds the compiled runtime.

        :param dtype: ``numpy.float32``, ``numpy.float64`` or *None*;
            *None* means the type of the first non-empty ``*_as_tensor``
            attribute, or ``numpy.float32`` if there is none
        :param version: 0 selects the ``RuntimeTreeEnsembleClassifier*``
            runtime, 1, 2 or 3 the ``RuntimeTreeEnsembleClassifierP*``
            runtime with different boolean settings
        :raises RuntimeTypeError: if *dtype* is not supported
        :raises ValueError: if *version* is unknown
        """
        self._post_process_label_attributes()

        atts = []
        for k in self.__class__.atts:
            v = self._get_typed_attributes(k)
            if k.endswith('_as_tensor'):
                # A non-empty tensor attribute replaces the attribute
                # declared just before it in ``atts`` (for example
                # ``base_values_as_tensor`` replaces ``base_values``).
                # An empty one is dropped so that the positional
                # arguments given to ``rt_.init`` stay the same.
                if isinstance(v, numpy.ndarray) and v.size > 0:
                    atts[-1] = v
                    if dtype is None:
                        dtype = v.dtype
                continue
            atts.append(v)

        if dtype is None:
            dtype = numpy.float32

        # One (plain, P) pair of compiled runtimes per floating point type.
        if dtype == numpy.float32:
            base_cls = RuntimeTreeEnsembleClassifierFloat
            p_cls = RuntimeTreeEnsembleClassifierPFloat
        elif dtype == numpy.float64:
            base_cls = RuntimeTreeEnsembleClassifierDouble
            p_cls = RuntimeTreeEnsembleClassifierPDouble
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "Unsupported dtype={}.".format(dtype))

        if version == 0:
            self.rt_ = base_cls()
        elif version in (1, 2, 3):
            # Versions 1, 2 and 3 differ only by the two boolean flags:
            # (False, False), (True, False) and (True, True).
            # NOTE(review): 60 and 20 are passed unchanged to every
            # P runtime; presumably thresholds for parallelisation -
            # confirm in the C++ extension.
            self.rt_ = p_cls(60, 20, version >= 2, version == 3)
        else:
            raise ValueError(  # pragma: no cover
                "Unknown version '{}'.".format(version))
        self.rt_.init(*atts)

    def _run(self, x):  # pylint: disable=W0221
        """
        This is a C++ implementation coming from
        :epkg:`onnxruntime`.
        `tree_ensemble_classifier.cc
        <https://github.com/microsoft/onnxruntime/blob/master/
        onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc>`_.
        See class :class:`RuntimeTreeEnsembleClassifier
        <mlprodict.onnxrt.ops_cpu.op_tree_ensemble_classifier_.RuntimeTreeEnsembleClassifier>`.
        """
        label, scores = self.rt_.compute(x)
        n_rows = label.shape[0]
        if scores.shape[0] != n_rows:
            # The runtime may return the scores flattened: reshape them
            # into one row per predicted label.
            scores = scores.reshape(n_rows, scores.shape[0] // n_rows)
        return self._post_process_predicted_label(label, scores)
class TreeEnsembleClassifier_1(TreeEnsembleClassifierCommon):
    """
    Runtime for *TreeEnsembleClassifier* using the first attribute set
    (no ``*_as_tensor`` attributes). Computations are done in float32.
    """

    # Attribute name -> default value. Keep this order: the common base
    # class passes the values positionally to the compiled runtime.
    atts = OrderedDict((
        ('base_values', numpy.empty(0, dtype=numpy.float32)),
        ('class_ids', numpy.empty(0, dtype=numpy.int64)),
        ('class_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('class_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('class_weights', numpy.empty(0, dtype=numpy.float32)),
        ('classlabels_int64s', numpy.empty(0, dtype=numpy.int64)),
        ('classlabels_strings', []),
        ('nodes_falsenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_featureids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_hitrates', numpy.empty(0, dtype=numpy.float32)),
        ('nodes_missing_value_tracks_true', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_modes', []),
        ('nodes_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_truenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_values', numpy.empty(0, dtype=numpy.float32)),
        ('post_transform', b'NONE'),
    ))

    def __init__(self, onnx_node, desc=None, **options):
        expected = TreeEnsembleClassifier_1.atts
        TreeEnsembleClassifierCommon.__init__(
            self, numpy.float32, onnx_node, desc=desc,
            expected_attributes=expected, **options)
class TreeEnsembleClassifier_3(TreeEnsembleClassifierCommon):
    """
    Runtime for *TreeEnsembleClassifier* using the attribute set which
    includes the ``*_as_tensor`` variants. The floating point type is not
    fixed: it is taken from the first non-empty ``*_as_tensor`` attribute,
    float32 when none is given.
    """

    # Attribute name -> default value. Keep this order: every
    # ``*_as_tensor`` entry must follow the attribute it replaces, and the
    # common base class passes the values positionally to the runtime.
    atts = OrderedDict((
        ('base_values', numpy.empty(0, dtype=numpy.float32)),
        ('base_values_as_tensor', []),
        ('class_ids', numpy.empty(0, dtype=numpy.int64)),
        ('class_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('class_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('class_weights', numpy.empty(0, dtype=numpy.float32)),
        ('class_weights_as_tensor', []),
        ('classlabels_int64s', numpy.empty(0, dtype=numpy.int64)),
        ('classlabels_strings', []),
        ('nodes_falsenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_featureids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_hitrates', numpy.empty(0, dtype=numpy.float32)),
        ('nodes_hitrates_as_tensor', []),
        ('nodes_missing_value_tracks_true', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_modes', []),
        ('nodes_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_truenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_values', numpy.empty(0, dtype=numpy.float32)),
        ('nodes_values_as_tensor', []),
        ('post_transform', b'NONE'),
    ))

    def __init__(self, onnx_node, desc=None, **options):
        expected = TreeEnsembleClassifier_3.atts
        # dtype=None lets the base class infer the type from the tensors.
        TreeEnsembleClassifierCommon.__init__(
            self, None, onnx_node, desc=desc,
            expected_attributes=expected, **options)
class TreeEnsembleClassifierDouble(TreeEnsembleClassifierCommon):
    """
    Custom operator *TreeEnsembleClassifierDouble* defined by this
    package: same attributes as the first *TreeEnsembleClassifier*
    attribute set, but computations are done in float64.
    """

    # Attribute name -> default value. Keep this order: the common base
    # class passes the values positionally to the compiled runtime.
    atts = OrderedDict((
        ('base_values', numpy.empty(0, dtype=numpy.float64)),
        ('class_ids', numpy.empty(0, dtype=numpy.int64)),
        ('class_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('class_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('class_weights', numpy.empty(0, dtype=numpy.float64)),
        ('classlabels_int64s', numpy.empty(0, dtype=numpy.int64)),
        ('classlabels_strings', []),
        ('nodes_falsenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_featureids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_hitrates', numpy.empty(0, dtype=numpy.float64)),
        ('nodes_missing_value_tracks_true', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_modes', []),
        ('nodes_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_truenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_values', numpy.empty(0, dtype=numpy.float64)),
        ('post_transform', b'NONE'),
    ))

    def __init__(self, onnx_node, desc=None, **options):
        expected = TreeEnsembleClassifierDouble.atts
        TreeEnsembleClassifierCommon.__init__(
            self, numpy.float64, onnx_node, desc=desc,
            expected_attributes=expected, **options)
class TreeEnsembleClassifierDoubleSchema(OperatorSchema):
    """
    Schema of the custom operator added by this package,
    see @see cl TreeEnsembleClassifierDouble. Its attributes are the
    ones declared by that class.
    """

    def __init__(self):
        op_name = 'TreeEnsembleClassifierDouble'
        OperatorSchema.__init__(self, op_name)
        self.attributes = TreeEnsembleClassifierDouble.atts
# Public alias: the attribute set with the ``*_as_tensor`` variants is
# used when the installed onnx package reports opset 16 or later,
# the first attribute set otherwise.
TreeEnsembleClassifier = (
    TreeEnsembleClassifier_3 if onnx_opset_version() >= 16
    else TreeEnsembleClassifier_1)