Coverage for mlprodict/onnx_conv/operator_converters/conv_xgboost.py: 91%
Shortcuts on this page:
r / m / x : toggle line displays
j / k : next / previous highlighted chunk
0 (zero) : top of page
1 (one) : first highlighted chunk
1"""
2@file
3@brief Modified converter from
4`XGBoost.py <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/
5xgboost/operator_converters/XGBoost.py>`_.
6"""
7import json
8from pprint import pformat
9import numpy
10from xgboost import XGBClassifier
11from skl2onnx.common.data_types import guess_numpy_type # pylint: disable=C0411
12from ..sklconv.tree_converters import _fix_tree_ensemble
class XGBConverter:
    "common methods for converters"

    @staticmethod
    def get_xgb_params(xgb_node):
        """
        Retrieves parameters of a model.

        :param xgb_node: fitted XGBoost model (raw operator)
        :return: dictionary of XGBoost parameters, always including
            ``n_estimators``
        """
        pars = xgb_node.get_xgb_params()
        # xgboost >= 1.0 no longer returns n_estimators through
        # get_xgb_params, read it back from the estimator itself.
        if 'n_estimators' not in pars:
            pars['n_estimators'] = xgb_node.n_estimators
        return pars

    @staticmethod
    def validate(xgb_node):
        "validates the model"
        params = XGBConverter.get_xgb_params(xgb_node)
        try:
            if "objective" not in params:
                raise AttributeError('objective')
        except AttributeError as e:  # pragma: no cover
            raise RuntimeError('Missing attribute in XGBoost model.') from e

    @staticmethod
    def common_members(xgb_node, inputs):
        "common to regresssor and classifier"
        params = XGBConverter.get_xgb_params(xgb_node)
        objective = params["objective"]
        base_score = params["base_score"]
        booster = xgb_node.get_booster()
        # The json format was available in October 2017.
        # XGBoost 0.7 was the first version released with it.
        js_tree_list = booster.get_dump(with_stats=True, dump_format='json')
        js_trees = [json.loads(s) for s in js_tree_list]
        return objective, base_score, js_trees

    @staticmethod
    def _get_default_tree_attribute_pairs(is_classifier):
        # Builds the empty attribute dictionary expected by the ONNX
        # TreeEnsembleClassifier / TreeEnsembleRegressor operators.
        attrs = {}
        for k in {'nodes_treeids', 'nodes_nodeids',
                  'nodes_featureids', 'nodes_modes', 'nodes_values',
                  'nodes_truenodeids', 'nodes_falsenodeids',
                  'nodes_missing_value_tracks_true'}:
            attrs[k] = []
        if is_classifier:
            for k in {'class_treeids', 'class_nodeids', 'class_ids',
                      'class_weights'}:
                attrs[k] = []
        else:
            for k in {'target_treeids', 'target_nodeids', 'target_ids',
                      'target_weights'}:
                attrs[k] = []
        return attrs

    @staticmethod
    def _add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id,
                  feature_id, mode, value, true_child_id, false_child_id,
                  weights, weight_id_bias, missing, hitrate):
        # Appends one tree node (split or leaf) to the flat attribute lists.
        if isinstance(feature_id, str):
            # Something like f0, f1...
            if feature_id[0] == "f":
                try:
                    feature_id = int(feature_id[1:])
                except ValueError as e:  # pragma: no cover
                    raise RuntimeError(
                        "Unable to interpret '{0}'".format(feature_id)) from e
            else:  # pragma: no cover
                # bug fix: the exception was not bound (missing 'as e'),
                # so 'raise ... from e' failed with NameError.
                try:
                    feature_id = int(feature_id)
                except ValueError as e:
                    raise RuntimeError(
                        "Unable to interpret '{0}'".format(feature_id)) from e

        # Split condition for sklearn
        # * if X_ptr[X_sample_stride * i + X_fx_stride * node.feature] <= node.threshold:
        # * https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L946
        # Split condition for xgboost
        # * if (fvalue < split_value)
        # * https://github.com/dmlc/xgboost/blob/master/include/xgboost/tree_model.h#L804

        attr_pairs['nodes_treeids'].append(tree_id)
        attr_pairs['nodes_nodeids'].append(node_id)
        attr_pairs['nodes_featureids'].append(feature_id)
        attr_pairs['nodes_modes'].append(mode)
        attr_pairs['nodes_values'].append(float(value))
        attr_pairs['nodes_truenodeids'].append(true_child_id)
        attr_pairs['nodes_falsenodeids'].append(false_child_id)
        attr_pairs['nodes_missing_value_tracks_true'].append(missing)
        if 'nodes_hitrates' in attr_pairs:
            attr_pairs['nodes_hitrates'].append(hitrate)  # pragma: no cover
        if mode == 'LEAF':
            # A leaf contributes one weight per target/class.
            if is_classifier:
                for i, w in enumerate(weights):
                    attr_pairs['class_treeids'].append(tree_id)
                    attr_pairs['class_nodeids'].append(node_id)
                    attr_pairs['class_ids'].append(i + weight_id_bias)
                    attr_pairs['class_weights'].append(float(tree_weight * w))
            else:
                for i, w in enumerate(weights):
                    attr_pairs['target_treeids'].append(tree_id)
                    attr_pairs['target_nodeids'].append(node_id)
                    attr_pairs['target_ids'].append(i + weight_id_bias)
                    attr_pairs['target_weights'].append(float(tree_weight * w))

    @staticmethod
    def _fill_node_attributes(treeid, tree_weight, jsnode, attr_pairs,
                              is_classifier, remap):
        # Recursively walks one JSON tree and registers every node.
        if 'children' in jsnode:
            XGBConverter._add_node(
                attr_pairs=attr_pairs, is_classifier=is_classifier,
                tree_id=treeid, tree_weight=tree_weight,
                value=jsnode['split_condition'],
                node_id=remap[jsnode['nodeid']],
                feature_id=jsnode['split'],
                mode='BRANCH_LT',  # 'BRANCH_LEQ' --> is for sklearn
                # ['children'][0]['nodeid'],
                true_child_id=remap[jsnode['yes']],
                # ['children'][1]['nodeid'],
                false_child_id=remap[jsnode['no']],
                weights=None, weight_id_bias=None,
                # ['children'][0]['nodeid'],
                missing=jsnode.get('missing', -1) == jsnode['yes'],
                hitrate=jsnode.get('cover', 0))

            for ch in jsnode['children']:
                if 'children' in ch or 'leaf' in ch:
                    XGBConverter._fill_node_attributes(
                        treeid, tree_weight, ch, attr_pairs, is_classifier,
                        remap)
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to convert this node {0}".format(ch))

        else:
            weights = [jsnode['leaf']]
            weights_id_bias = 0
            XGBConverter._add_node(
                attr_pairs=attr_pairs, is_classifier=is_classifier,
                tree_id=treeid, tree_weight=tree_weight,
                value=0., node_id=remap[jsnode['nodeid']],
                feature_id=0, mode='LEAF',
                true_child_id=0, false_child_id=0,
                weights=weights, weight_id_bias=weights_id_bias,
                missing=False, hitrate=jsnode.get('cover', 0))

    @staticmethod
    def _remap_nodeid(jsnode, remap=None):
        # Maps XGBoost node ids onto a dense 0..n-1 numbering
        # (depth-first order) as required by the ONNX operators.
        if remap is None:
            remap = {}
        nid = jsnode['nodeid']
        remap[nid] = len(remap)
        if 'children' in jsnode:
            for ch in jsnode['children']:
                XGBConverter._remap_nodeid(ch, remap)
        return remap

    @staticmethod
    def fill_tree_attributes(js_xgb_node, attr_pairs, tree_weights,
                             is_classifier):
        "fills tree attributes"
        if not isinstance(js_xgb_node, list):
            raise TypeError(  # pragma: no cover
                "js_xgb_node must be a list")
        for treeid, (jstree, w) in enumerate(zip(js_xgb_node, tree_weights)):
            remap = XGBConverter._remap_nodeid(jstree)
            XGBConverter._fill_node_attributes(
                treeid, w, jstree, attr_pairs, is_classifier, remap)
class XGBRegressorConverter(XGBConverter):
    "converter class"

    @staticmethod
    def validate(xgb_node):
        # Delegates to the shared validation (checks the 'objective' key).
        return XGBConverter.validate(xgb_node)

    @staticmethod
    def _get_default_tree_attribute_pairs():  # pylint: disable=W0221
        # Regressor flavour: target_* keys plus regressor-only attributes.
        attrs = XGBConverter._get_default_tree_attribute_pairs(False)
        attrs['post_transform'] = 'NONE'
        attrs['n_targets'] = 1
        return attrs

    @staticmethod
    def convert(scope, operator, container):
        "converter method"
        dtype = guess_numpy_type(operator.inputs[0].type)
        if dtype != numpy.float64:
            dtype = numpy.float32
        # ai.onnx.ml opset decides whether _fix_tree_ensemble must run.
        opsetml = container.target_opset_all.get('ai.onnx.ml', None)
        if opsetml is None:
            opsetml = 3 if container.target_opset >= 16 else 1
        xgb_node = operator.raw_operator
        inputs = operator.inputs
        objective, base_score, js_trees = XGBConverter.common_members(
            xgb_node, inputs)

        if objective in ["reg:gamma", "reg:tweedie"]:
            raise RuntimeError(  # pragma: no cover
                "Objective '{}' not supported.".format(objective))
        if base_score is None:
            # Consistency with XGBClassifierConverter.convert: a None
            # base_score would produce an invalid 'base_values' attribute.
            raise RuntimeError(  # pragma: no cover
                "base_score cannot be None")

        booster = xgb_node.get_booster()
        if booster is None:
            raise RuntimeError(  # pragma: no cover
                "The model was probably not trained.")

        # Honour early stopping: only keep trees up to the best iteration.
        best_ntree_limit = getattr(booster, 'best_ntree_limit', len(js_trees))
        if best_ntree_limit < len(js_trees):
            js_trees = js_trees[:best_ntree_limit]

        attr_pairs = XGBRegressorConverter._get_default_tree_attribute_pairs()
        attr_pairs['base_values'] = [base_score]
        XGBConverter.fill_tree_attributes(
            js_trees, attr_pairs, [1 for _ in js_trees], False)

        # add nodes
        if dtype == numpy.float64 and opsetml < 3:
            # Double-precision variant only exists in the mlprodict domain.
            container.add_node(
                'TreeEnsembleRegressorDouble', operator.input_full_names,
                operator.output_full_names,
                name=scope.get_unique_operator_name(
                    'TreeEnsembleRegressorDouble'),
                op_domain='mlprodict', op_version=1, **attr_pairs)
        else:
            container.add_node(
                'TreeEnsembleRegressor', operator.input_full_names,
                operator.output_full_names,
                name=scope.get_unique_operator_name('TreeEnsembleRegressor'),
                op_domain='ai.onnx.ml', op_version=1, **attr_pairs)
        if opsetml >= 3:
            _fix_tree_ensemble(scope, container, opsetml, dtype)
class XGBClassifierConverter(XGBConverter):
    "converter for XGBClassifier"

    @staticmethod
    def validate(xgb_node):
        # Delegates to the shared validation (checks the 'objective' key).
        return XGBConverter.validate(xgb_node)

    @staticmethod
    def _get_default_tree_attribute_pairs():  # pylint: disable=W0221
        # Classifier flavour of the attribute dictionary (class_* keys).
        attrs = XGBConverter._get_default_tree_attribute_pairs(True)
        # attrs['nodes_hitrates'] = []
        return attrs

    @staticmethod
    def convert(scope, operator, container):
        "convert method"
        # ai.onnx.ml opset decides whether _fix_tree_ensemble must run.
        opsetml = container.target_opset_all.get('ai.onnx.ml', None)
        if opsetml is None:
            opsetml = 3 if container.target_opset >= 16 else 1
        dtype = guess_numpy_type(operator.inputs[0].type)
        if dtype != numpy.float64:
            dtype = numpy.float32
        xgb_node = operator.raw_operator
        inputs = operator.inputs

        objective, base_score, js_trees = XGBConverter.common_members(
            xgb_node, inputs)
        if base_score is None:
            raise RuntimeError(  # pragma: no cover
                "base_score cannot be None")
        params = XGBConverter.get_xgb_params(xgb_node)

        # First fill: only needed to derive the number of classes below.
        attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
        XGBConverter.fill_tree_attributes(
            js_trees, attr_pairs, [1 for _ in js_trees], True)

        # Multi-class boosting grows one tree per class per round, hence the
        # number of classes is trees-per-round.
        ncl = (max(attr_pairs['class_treeids']) + 1) // params['n_estimators']

        bst = xgb_node.get_booster()
        # Honour early stopping (best_ntree_limit counts rounds, so scale
        # by the number of trees per round).
        best_ntree_limit = getattr(
            bst, 'best_ntree_limit', len(js_trees)) * ncl
        if best_ntree_limit < len(js_trees):
            # Trees were truncated: rebuild the attributes from scratch.
            js_trees = js_trees[:best_ntree_limit]
            attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
            XGBConverter.fill_tree_attributes(
                js_trees, attr_pairs, [1 for _ in js_trees], True)

        if len(attr_pairs['class_treeids']) == 0:
            raise RuntimeError(  # pragma: no cover
                "XGBoost model is empty.")
        if 'n_estimators' not in params:
            raise RuntimeError(  # pragma: no cover
                "Parameters not found, existing:\n{}".format(
                    pformat(params)))
        if ncl <= 1:
            # Binary case: one tree per round, sigmoid post-transform.
            ncl = 2
            # See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L23.
            attr_pairs['post_transform'] = "LOGISTIC"
            attr_pairs['class_ids'] = [0 for v in attr_pairs['class_treeids']]
        else:
            # Multi-class: tree index modulo ncl gives the class it scores.
            # See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L35.
            attr_pairs['post_transform'] = "SOFTMAX"
            # attr_pairs['base_values'] = [base_score for n in range(ncl)]
            attr_pairs['class_ids'] = [v % ncl
                                       for v in attr_pairs['class_treeids']]

        # Numeric labels become int64 class labels, anything else is
        # serialized as utf-8 encoded strings.
        classes = xgb_node.classes_
        if (numpy.issubdtype(classes.dtype, numpy.floating) or
                numpy.issubdtype(classes.dtype, numpy.signedinteger)):
            attr_pairs['classlabels_int64s'] = classes.astype('int')
        else:
            classes = numpy.array([s.encode('utf-8') for s in classes])
            attr_pairs['classlabels_strings'] = classes

        if dtype == numpy.float64 and opsetml < 3:
            op_name = "TreeEnsembleClassifierDouble"
        else:
            op_name = "TreeEnsembleClassifier"

        # add nodes
        if objective == "binary:logistic":
            ncl = 2
            container.add_node(op_name, operator.input_full_names,
                               operator.output_full_names,
                               name=scope.get_unique_operator_name(
                                   op_name),
                               op_domain='ai.onnx.ml', **attr_pairs)
        elif objective == "multi:softprob":
            ncl = len(js_trees) // params['n_estimators']
            container.add_node(
                op_name, operator.input_full_names,
                operator.output_full_names,
                name=scope.get_unique_operator_name(op_name),
                op_domain='ai.onnx.ml', op_version=1, **attr_pairs)
        elif objective == "reg:logistic":
            ncl = len(js_trees) // params['n_estimators']
            if ncl == 1:
                ncl = 2
            container.add_node(
                op_name, operator.input_full_names,
                operator.output_full_names,
                name=scope.get_unique_operator_name(op_name),
                op_domain='ai.onnx.ml', op_version=1, **attr_pairs)
        else:
            raise RuntimeError(  # pragma: no cover
                "Unexpected objective: {0}".format(objective))

        if opsetml >= 3:
            _fix_tree_ensemble(scope, container, opsetml, dtype)
def convert_xgboost(scope, operator, container):
    """
    This converters reuses the code from
    `XGBoost.py <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/
    xgboost/operator_converters/XGBoost.py>`_ and makes
    some modifications. It implements converters
    for models in :epkg:`xgboost`.
    """
    model = operator.raw_operator
    # Classifiers get their own converter, everything else is handled
    # as a regressor.
    converter = (XGBClassifierConverter
                 if isinstance(model, XGBClassifier)
                 else XGBRegressorConverter)
    converter.convert(scope, operator, container)