Coverage for mlprodict/onnxrt/ops_cpu/op_tree_ensemble_classifier.py: 94%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

77 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from collections import OrderedDict 

8import numpy 

9from onnx.defs import onnx_opset_version 

10from ._op_helper import _get_typed_class_attribute 

11from ._op import OpRunClassifierProb, RuntimeTypeError 

12from ._op_classifier_string import _ClassifierCommon 

13from ._new_ops import OperatorSchema 

14from .op_tree_ensemble_classifier_ import ( # pylint: disable=E0611,E0401 

15 RuntimeTreeEnsembleClassifierDouble, 

16 RuntimeTreeEnsembleClassifierFloat) 

17from .op_tree_ensemble_classifier_p_ import ( # pylint: disable=E0611,E0401 

18 RuntimeTreeEnsembleClassifierPFloat, 

19 RuntimeTreeEnsembleClassifierPDouble) 

20 

21 

class TreeEnsembleClassifierCommon(OpRunClassifierProb, _ClassifierCommon):
    """
    Logic shared by the *TreeEnsembleClassifier* runtime operators.

    Every subclass declares its attributes in a class-level ``OrderedDict``
    named ``atts``. The typed attribute values are handed to the C++
    runtime's ``init`` method positionally, in that declaration order.
    """

    def __init__(self, dtype, onnx_node, desc=None,
                 expected_attributes=None,
                 runtime_version=3, **options):
        OpRunClassifierProb.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)
        # The C++ runtime is created only after the base class has
        # processed the node.
        self._init(dtype=dtype, version=runtime_version)

    def _get_typed_attributes(self, k):
        """
        Returns attribute *k* converted by ``_get_typed_class_attribute``
        against the class-level ``atts`` declarations.
        """
        declared = self.__class__.atts
        return _get_typed_class_attribute(self, k, declared)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of a custom operator defined by this runtime.

        :param op_name: operator name; only
            ``'TreeEnsembleClassifierDouble'`` is known
        :return: a new ``TreeEnsembleClassifierDoubleSchema`` instance
        :raises RuntimeError: for any other operator name
        """
        if op_name != "TreeEnsembleClassifierDouble":
            raise RuntimeError(  # pragma: no cover
                "Unable to find a schema for operator '{}'.".format(op_name))
        return TreeEnsembleClassifierDoubleSchema()

    def _init(self, dtype, version):
        """
        Collects the typed attributes and builds the C++ runtime ``self.rt_``.

        :param dtype: ``numpy.float32`` or ``numpy.float64``. ``None`` means
            the dtype of the first non-empty ``*_as_tensor`` attribute is
            used; if there is none, ``numpy.float32`` is used.
        :param version: ``0`` builds ``RuntimeTreeEnsembleClassifier<T>``;
            ``1``, ``2`` and ``3`` build ``RuntimeTreeEnsembleClassifierP<T>``
            with different values for its two trailing boolean arguments.
        :raises RuntimeTypeError: if *dtype* is neither float32 nor float64
        :raises ValueError: if *version* is not 0, 1, 2 or 3
        """
        self._post_process_label_attributes()

        values = []
        for name in self.__class__.atts:
            value = self._get_typed_attributes(name)
            if not name.endswith('_as_tensor'):
                values.append(value)
                continue
            # A ``X_as_tensor`` attribute never adds a positional argument.
            # When it holds a non-empty array it overrides the value of the
            # attribute declared just before it (``X``).
            if isinstance(value, numpy.ndarray) and value.size > 0:
                values[-1] = value
                if dtype is None:
                    dtype = value.dtype

        if dtype is None:
            dtype = numpy.float32

        # Select the pair of runtime classes matching the numeric type.
        if dtype == numpy.float32:
            plain_cls = RuntimeTreeEnsembleClassifierFloat
            p_cls = RuntimeTreeEnsembleClassifierPFloat
        elif dtype == numpy.float64:
            plain_cls = RuntimeTreeEnsembleClassifierDouble
            p_cls = RuntimeTreeEnsembleClassifierPDouble
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "Unsupported dtype={}.".format(dtype))

        if version == 0:
            self.rt_ = plain_cls()
        else:
            # Versions 1..3 share the constructor; only the two trailing
            # boolean flags differ. NOTE(review): the meaning of 60, 20 and
            # of the flags is defined by the C++ extension, not visible here.
            for known, flags in ((1, (False, False)),
                                 (2, (True, False)),
                                 (3, (True, True))):
                if version == known:
                    self.rt_ = p_cls(60, 20, *flags)
                    break
            else:
                raise ValueError("Unknown version '{}'.".format(version))
        self.rt_.init(*values)

    def _run(self, x):  # pylint: disable=W0221
        """
        Predicts labels and scores with the C++ runtime ``self.rt_``.

        The runtime is a C++ implementation derived from :epkg:`onnxruntime`,
        `tree_ensemble_classifier.cc
        <https://github.com/microsoft/onnxruntime/blob/master/
        onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc>`_.
        See class :class:`RuntimeTreeEnsembleClassifier
        <mlprodict.onnxrt.ops_cpu.op_tree_ensemble_classifier_.RuntimeTreeEnsembleClassifier>`.

        If the first dimension of the scores differs from the number of
        labels, the scores are reshaped to
        ``(n_labels, n_scores // n_labels)`` before post-processing.
        """
        label, scores = self.rt_.compute(x)
        n_labels = label.shape[0]
        n_scores = scores.shape[0]
        if n_scores != n_labels:
            # Presumably a flat score vector; fold it into one row per label.
            scores = scores.reshape(n_labels, n_scores // n_labels)
        return self._post_process_predicted_label(label, scores)

112 

113 

class TreeEnsembleClassifier_1(TreeEnsembleClassifierCommon):
    """
    *TreeEnsembleClassifier* runtime without the ``*_as_tensor``
    attributes. It always builds a float32 runtime.
    """

    # Declaration order matters: ``_init`` forwards the attribute values
    # positionally to the C++ ``init`` method in exactly this order.
    atts = OrderedDict(
        base_values=numpy.empty(0, dtype=numpy.float32),
        class_ids=numpy.empty(0, dtype=numpy.int64),
        class_nodeids=numpy.empty(0, dtype=numpy.int64),
        class_treeids=numpy.empty(0, dtype=numpy.int64),
        class_weights=numpy.empty(0, dtype=numpy.float32),
        classlabels_int64s=numpy.empty(0, dtype=numpy.int64),
        classlabels_strings=[],
        nodes_falsenodeids=numpy.empty(0, dtype=numpy.int64),
        nodes_featureids=numpy.empty(0, dtype=numpy.int64),
        nodes_hitrates=numpy.empty(0, dtype=numpy.float32),
        nodes_missing_value_tracks_true=numpy.empty(0, dtype=numpy.int64),
        nodes_modes=[],
        nodes_nodeids=numpy.empty(0, dtype=numpy.int64),
        nodes_treeids=numpy.empty(0, dtype=numpy.int64),
        nodes_truenodeids=numpy.empty(0, dtype=numpy.int64),
        nodes_values=numpy.empty(0, dtype=numpy.float32),
        post_transform=b'NONE',
    )

    def __init__(self, onnx_node, desc=None, **options):
        declared = TreeEnsembleClassifier_1.atts
        TreeEnsembleClassifierCommon.__init__(
            self, numpy.float32, onnx_node, desc=desc,
            expected_attributes=declared, **options)

140 

141 

class TreeEnsembleClassifier_3(TreeEnsembleClassifierCommon):
    """
    *TreeEnsembleClassifier* runtime that also accepts the ``*_as_tensor``
    attributes. The runtime dtype is not fixed here: ``_init`` infers it
    from the first non-empty ``*_as_tensor`` attribute, or falls back to
    float32.
    """

    # Declaration order matters: ``_init`` forwards the values positionally
    # to the C++ ``init`` method. Every ``X_as_tensor`` must directly follow
    # ``X``, because it replaces the previously collected value.
    atts = OrderedDict(
        base_values=numpy.empty(0, dtype=numpy.float32),
        base_values_as_tensor=[],
        class_ids=numpy.empty(0, dtype=numpy.int64),
        class_nodeids=numpy.empty(0, dtype=numpy.int64),
        class_treeids=numpy.empty(0, dtype=numpy.int64),
        class_weights=numpy.empty(0, dtype=numpy.float32),
        class_weights_as_tensor=[],
        classlabels_int64s=numpy.empty(0, dtype=numpy.int64),
        classlabels_strings=[],
        nodes_falsenodeids=numpy.empty(0, dtype=numpy.int64),
        nodes_featureids=numpy.empty(0, dtype=numpy.int64),
        nodes_hitrates=numpy.empty(0, dtype=numpy.float32),
        nodes_hitrates_as_tensor=[],
        nodes_missing_value_tracks_true=numpy.empty(0, dtype=numpy.int64),
        nodes_modes=[],
        nodes_nodeids=numpy.empty(0, dtype=numpy.int64),
        nodes_treeids=numpy.empty(0, dtype=numpy.int64),
        nodes_truenodeids=numpy.empty(0, dtype=numpy.int64),
        nodes_values=numpy.empty(0, dtype=numpy.float32),
        nodes_values_as_tensor=[],
        post_transform=b'NONE',
    )

    def __init__(self, onnx_node, desc=None, **options):
        declared = TreeEnsembleClassifier_3.atts
        # dtype=None lets the common initializer infer the numeric type.
        TreeEnsembleClassifierCommon.__init__(
            self, None, onnx_node, desc=desc,
            expected_attributes=declared, **options)

172 

173 

class TreeEnsembleClassifierDouble(TreeEnsembleClassifierCommon):
    """
    Custom operator ``TreeEnsembleClassifierDouble``. It has the same
    attributes as the classic operator, with float64 defaults, and always
    builds a float64 runtime.
    """

    # Declaration order matters: ``_init`` forwards the attribute values
    # positionally to the C++ ``init`` method in exactly this order.
    atts = OrderedDict(
        base_values=numpy.empty(0, dtype=numpy.float64),
        class_ids=numpy.empty(0, dtype=numpy.int64),
        class_nodeids=numpy.empty(0, dtype=numpy.int64),
        class_treeids=numpy.empty(0, dtype=numpy.int64),
        class_weights=numpy.empty(0, dtype=numpy.float64),
        classlabels_int64s=numpy.empty(0, dtype=numpy.int64),
        classlabels_strings=[],
        nodes_falsenodeids=numpy.empty(0, dtype=numpy.int64),
        nodes_featureids=numpy.empty(0, dtype=numpy.int64),
        nodes_hitrates=numpy.empty(0, dtype=numpy.float64),
        nodes_missing_value_tracks_true=numpy.empty(0, dtype=numpy.int64),
        nodes_modes=[],
        nodes_nodeids=numpy.empty(0, dtype=numpy.int64),
        nodes_treeids=numpy.empty(0, dtype=numpy.int64),
        nodes_truenodeids=numpy.empty(0, dtype=numpy.int64),
        nodes_values=numpy.empty(0, dtype=numpy.float64),
        post_transform=b'NONE',
    )

    def __init__(self, onnx_node, desc=None, **options):
        declared = TreeEnsembleClassifierDouble.atts
        TreeEnsembleClassifierCommon.__init__(
            self, numpy.float64, onnx_node, desc=desc,
            expected_attributes=declared, **options)

200 

201 

class TreeEnsembleClassifierDoubleSchema(OperatorSchema):
    """
    Schema of the custom operator ``TreeEnsembleClassifierDouble`` added
    by this package (see @see cl TreeEnsembleClassifierDouble).
    """

    def __init__(self):
        op_name = 'TreeEnsembleClassifierDouble'
        OperatorSchema.__init__(self, op_name)
        # Share the operator's ordered attribute declarations. This is set
        # after the base initializer so it is not overwritten.
        self.attributes = TreeEnsembleClassifierDouble.atts

211 

212 

# Public alias. When the installed onnx package reports a default opset of
# 16 or more, the variant accepting the ``*_as_tensor`` attributes is used;
# otherwise the older attribute set is used.
TreeEnsembleClassifier = (
    TreeEnsembleClassifier_3 if onnx_opset_version() >= 16
    else TreeEnsembleClassifier_1)