Coverage for mlprodict/onnx_conv/helpers/lgbm_helper.py: 93%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Helpers to speed up the conversion of Lightgbm models or transform it.
4"""
5from collections import deque
6import ctypes
7import json
8import re
def restore_lgbm_info(tree):
    """
    Restores speed up information to help
    modifying the structure of the tree.

    :param tree: a dictionary coming from a lightgbm dump, either the
        whole model (it contains key ``'tree_info'``), one tree
        (it contains key ``'tree_structure'``) or a node
    :return: if *tree* contains ``'tree_info'``, a list with one item per
        tree, every item being the list of the nodes of that tree which
        have a child; otherwise the flat list of the nodes of *tree*
        which have a child. Nodes are listed parent first, then the left
        subtree, then the right subtree, and they are the very same
        objects as in *tree* (no copy).
    """

    def decision_nodes(root):
        # Iterative traversal: no recursion limit issue on deep trees.
        found = []
        pending = [root]
        while pending:
            node = pending.pop()
            if 'tree_structure' in node:
                # A tree wrapper, the nodes are below this key.
                pending.append(node['tree_structure'])
                continue
            has_left = 'left_child' in node
            has_right = 'right_child' in node
            if has_left or has_right:
                found.append(node)
            # The right child is pushed first so that the left subtree
            # is popped, hence listed, before the right one.
            if has_right:
                pending.append(node['right_child'])
            if has_left:
                pending.append(node['left_child'])
        return found

    if 'tree_info' in tree:
        # One list per tree. The previous implementation never walked
        # through the trees and always returned ``[[]]``.
        return [decision_nodes(one_tree) for one_tree in tree['tree_info']]
    return decision_nodes(tree)
def dump_booster_model(self, num_iteration=None, start_iteration=0,
                       importance_type='split', verbose=0):
    """
    Dumps Booster to JSON format.

    Parameters
    ----------
    self: booster
    num_iteration : int or None, optional (default=None)
        Index of the iteration that should be dumped.
        If None, if the best iteration exists, it is dumped; otherwise,
        all iterations are dumped.
        If <= 0, all iterations are dumped.
    start_iteration : int, optional (default=0)
        Start index of the iteration that should be dumped.
    importance_type : string, optional (default="split")
        What type of feature importance should be dumped.
        If "split", result contains numbers of times the feature is used in a model.
        If "gain", result contains total gains of splits which use the feature.
    verbose: displays progress (useful for big trees)

    Returns
    -------
    json_repr : dict
        JSON format of Booster.
    info : dict or None
        Dictionary with key ``'decision_nodes'``: one list per tree
        holding every JSON object of that tree which has a
        ``'decision_type'``. It is ``None`` if the booster is a mock
        (attribute ``is_mock`` is true).

    .. note::
        This function is inspired from
        the :epkg:`lightgbm` (`dump_model
        <https://lightgbm.readthedocs.io/en/latest/pythonapi/
        lightgbm.Booster.html#lightgbm.Booster.dump_model>`_).
        It creates intermediate structure to speed up the conversion
        into ONNX of such model. The function hooks into the JSON
        decoder to quickly extract nodes.
    """
    if getattr(self, 'is_mock', False):
        # Mocks implement dump_model themselves, no call to the C API.
        return self.dump_model(), None
    from lightgbm.basic import (
        _LIB, FEATURE_IMPORTANCE_TYPE_MAPPER, _safe_call,
        json_default_with_numpy)
    if num_iteration is None:
        num_iteration = self.best_iteration
    importance_type_int = FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
    tmp_out_len = ctypes.c_int64(0)

    def call_dump(length):
        # Allocates a buffer of *length* bytes and asks the C API to
        # write the dump into it, the needed length is stored
        # into tmp_out_len.
        buf = ctypes.create_string_buffer(length)
        ptr_buf = ctypes.c_char_p(ctypes.addressof(buf))
        _safe_call(_LIB.LGBM_BoosterDumpModel(
            self.handle,
            ctypes.c_int(start_iteration),
            ctypes.c_int(num_iteration),
            ctypes.c_int(importance_type_int),
            ctypes.c_int64(length),
            ctypes.byref(tmp_out_len),
            ptr_buf))
        return buf

    buffer_len = 1 << 20
    if verbose >= 2:
        print(  # pragma: no cover
            "[dump_booster_model] call CAPI: LGBM_BoosterDumpModel")
    string_buffer = call_dump(buffer_len)
    actual_len = tmp_out_len.value
    # if buffer length is not long enough, reallocate a buffer
    if actual_len > buffer_len:
        string_buffer = call_dump(actual_len)

    class Hook(json.JSONDecoder):
        """
        Keep track of the progress, stores a copy of all objects with
        a decision into a different container in order to walk through
        all nodes in a much faster way than going through the architecture.
        """

        def __init__(self, *args, info=None, n_trees=None, verbose=0,
                     **kwargs):
            super().__init__(*args, object_hook=self.hook, **kwargs)
            self.nodes = []    # one list of decision nodes per tree
            self.buffer = []   # decision nodes of the tree being parsed
            self.info = info
            self.n_trees = n_trees
            self.verbose = verbose
            self.stored = 0    # total number of stored decision nodes
            if verbose >= 2 and n_trees is not None:
                from tqdm import tqdm  # pragma: no cover
                self.loop = tqdm(total=n_trees)  # pragma: no cover
                self.loop.set_description("dump_booster")  # pragma: no cover
            else:
                self.loop = None

        def hook(self, obj):
            """
            Hook called everytime a JSON object is created.
            Keep track of the progress, stores a copy of all objects with
            a decision into a different container.
            """
            # Every obj goes through this function from the leaves to the root.
            if 'tree_info' in obj:
                self.info['decision_nodes'] = self.nodes
                if self.n_trees is not None and len(self.nodes) != self.n_trees:
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected number of trees %d (expecting %d)." % (
                            len(self.nodes), self.n_trees))
                self.nodes = []
                if self.loop is not None:
                    self.loop.close()
            if 'tree_structure' in obj:
                self.nodes.append(self.buffer)
                if self.loop is not None:
                    # tqdm.update expects an increment, one tree
                    # was just completed.
                    self.loop.update(1)
                    if len(self.nodes) % 10 == 0:
                        self.loop.set_description(
                            "dump_booster: %d/%d trees, %d nodes" % (
                                len(self.nodes), self.n_trees, self.stored))
                self.buffer = []
            if "decision_type" in obj:
                self.buffer.append(obj)
                self.stored += 1
            return obj

    if verbose >= 2:
        print("[dump_booster_model] to_json")  # pragma: no cover
    info = {}
    # NOTE(review): n_trees is self.num_trees() whereas the dump may be
    # restricted by num_iteration / start_iteration; confirm that both
    # counts agree when only a subset of the iterations is dumped,
    # otherwise the hook raises RuntimeError.
    ret = json.loads(string_buffer.value.decode('utf-8'), cls=Hook,
                     info=info, n_trees=self.num_trees(), verbose=verbose)
    ret['pandas_categorical'] = json.loads(
        json.dumps(self.pandas_categorical,
                   default=json_default_with_numpy))
    if verbose >= 2:
        print("[dump_booster_model] end.")  # pragma: no cover
    return ret, info
def dump_lgbm_booster(booster, verbose=0):
    """
    Converts a Lightgbm booster into its JSON representation.

    :param booster: Lightgbm booster
    :param verbose: verbosity
    :return: tuple *(json, info)*, the JSON dump and a dictionary
        with more information, both as returned by
        @see fn dump_booster_model
    """
    return dump_booster_model(booster, verbose=verbose)
def modify_tree_for_rule_in_set(gbm, use_float=False, verbose=0, count=0,  # pylint: disable=R1710
                                info=None):
    """
    LightGBM produces sometimes a tree with a node set
    to use rule ``==`` to a set of values (= in set),
    the values are separated by ``||``.
    This function unfolds these nodes: a node with rule
    ``== v1||v2||v3`` becomes a chain of nodes ``== v1``, ``== v2``,
    ``== v3`` linked through their right child, all of them sharing
    the left child of the original node. The tree is modified inplace.

    :param gbm: a tree coming from lightgbm dump
    :param use_float: use float otherwise int first
        then float if it does not work
    :param verbose: verbosity, use :epkg:`tqdm` to show progress
    :param count: number of nodes already changed (origin) before this call
    :param info: additional information to speed up this search,
        a dictionary with key ``'decision_nodes'`` (one list of nodes
        per tree) if *gbm* is the whole model, the list of nodes to look
        into if *gbm* is a single tree; if None, the function
        walks through the tree to find the nodes
    :return: number of changed nodes (include *count*)

    A child looks like the following:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from mlprodict.onnx_conv.operator_converters.conv_lightgbm import modify_tree_for_rule_in_set

        tree = {'decision_type': '==',
                'default_left': True,
                'internal_count': 6805,
                'internal_value': 0.117558,
                'left_child': {'leaf_count': 4293,
                               'leaf_index': 18,
                               'leaf_value': 0.003519117642745049},
                'missing_type': 'None',
                'right_child': {'leaf_count': 2512,
                                'leaf_index': 25,
                                'leaf_value': 0.012305307958365394},
                'split_feature': 24,
                'split_gain': 12.233599662780762,
                'split_index': 24,
                'threshold': '10||12||13'}

        modify_tree_for_rule_in_set(tree)

        pprint.pprint(tree)
    """
    if 'tree_info' in gbm:
        dec_nodes = None if info is None else info['decision_nodes']
        trees = gbm['tree_info']
        loop = None
        if verbose >= 2:  # pragma: no cover
            from tqdm import tqdm
            loop = tqdm(trees)
            trees = loop
        for i, tree in enumerate(trees):
            if loop is not None:  # pragma: no cover
                loop.set_description("rules tree %d c=%d" % (i, count))
            count = modify_tree_for_rule_in_set(
                tree, use_float=use_float, count=count,
                info=None if dec_nodes is None else dec_nodes[i])
        return count

    if 'tree_structure' in gbm:
        return modify_tree_for_rule_in_set(
            gbm['tree_structure'], use_float=use_float, count=count,
            info=info)

    if 'decision_type' not in gbm:
        return count

    def str2number(val):
        if use_float:
            return float(val)
        try:
            return int(val)
        except ValueError:  # pragma: no cover
            return float(val)

    # Without info, the function walks through the tree from its root.
    # With info, only the listed nodes are looked into and their
    # children are not followed.
    follow_children = info is None
    stack = [gbm] if follow_children else list(info)

    # The traversal is iterative: a rule with many values or a deep tree
    # cannot hit the recursion limit, and the left child shared by
    # all unfolded nodes is visited only once.
    while stack:
        node = stack.pop()
        if 'decision_type' not in node:
            continue  # leaf

        last = node  # last node of the chain created from *node*
        if node['decision_type'] == '==':
            th = node['threshold']
            while isinstance(th, str) and '||' in th:
                head, _, rest = th.partition('||')
                first = str2number(head)
                if '||' not in rest:
                    rest = str2number(rest)
                last['threshold'] = first
                # The copy is shallow and made before the right child is
                # replaced: the new node keeps both original children.
                new_node = last.copy()
                new_node['threshold'] = rest
                last['right_child'] = new_node
                last = new_node
                th = rest
                count += 1

        if follow_children:
            if 'left_child' in node:
                stack.append(node['left_child'])
            if 'right_child' in last:
                # original right child, now below the last unfolded node
                stack.append(last['right_child'])

    return count