Coverage for mlprodict/tools/model_info.py: 93%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Functions to help get more information about the models.
4"""
5import inspect
6from collections import Counter
7import numpy
10def _analyse_tree(tree):
11 """
12 Extract information from a tree.
13 """
14 info = {}
15 if hasattr(tree, 'node_count'):
16 info['node_count'] = tree.node_count
18 n_nodes = tree.node_count
19 children_left = tree.children_left
20 children_right = tree.children_right
21 node_depth = numpy.zeros(shape=n_nodes, dtype=numpy.int64)
22 is_leaves = numpy.zeros(shape=n_nodes, dtype=bool)
23 stack = [(0, -1)]
24 while len(stack) > 0:
25 node_id, parent_depth = stack.pop()
26 node_depth[node_id] = parent_depth + 1
27 if children_left[node_id] != children_right[node_id]:
28 stack.append((children_left[node_id], parent_depth + 1))
29 stack.append((children_right[node_id], parent_depth + 1))
30 else:
31 is_leaves[node_id] = True
33 info['leave_count'] = sum(is_leaves)
34 info['max_depth'] = max(node_depth)
35 return info
38def _analyse_tree_h(tree):
39 """
40 Extract information from a tree in a
41 HistGradientBoosting.
42 """
43 info = {}
44 info['leave_count'] = tree.get_n_leaf_nodes()
45 info['node_count'] = len(tree.nodes)
46 info['max_depth'] = tree.get_max_depth()
47 return info
50def _reduce_infos(infos):
51 """
52 Produces agregates features.
53 """
54 def tof(obj):
55 try:
56 return obj[0]
57 except TypeError: # pragma: no cover
58 return obj
60 if not isinstance(infos, list):
61 raise TypeError( # pragma: no cover
62 "infos must a list not {}.".format(type(infos)))
63 keys = set()
64 for info in infos:
65 if not isinstance(info, dict):
66 raise TypeError( # pragma: no cover
67 "info must a dictionary not {}.".format(type(info)))
68 keys |= set(info)
70 info = {}
71 for k in keys:
72 values = [d.get(k, None) for d in infos]
73 values = [_ for _ in values if _ is not None]
74 if k.endswith('.leave_count') or k.endswith('.node_count'):
75 info['sum|%s' % k] = sum(values)
76 elif k.endswith('.max_depth'):
77 info['max|%s' % k] = max(values)
78 elif k.endswith('.size'):
79 info['sum|%s' % k] = sum(values) # pragma: no cover
80 else:
81 try:
82 un = set(values)
83 except TypeError: # pragma: no cover
84 un = set()
85 if len(un) == 1:
86 info[k] = list(un)[0]
87 continue
88 if k.endswith('.shape'):
89 row = [_[0] for _ in values]
90 col = [_[1] for _ in values if len(_) > 1]
91 if len(col) == 0:
92 info['max|%s' % k] = (max(row), )
93 else:
94 info['max|%s' % k] = (max(row), max(col))
95 continue
96 if k == 'n_classes_':
97 info['n_classes_'] = max(tof(_) for _ in values)
98 continue
99 raise NotImplementedError( # pragma: no cover
100 "Unable to reduce key '{}', values={}.".format(k, values))
101 return info
def _get_info_lgb(model):
    """
    Get informations from a :epkg:`lightgbm` booster:
    objective, number of classes or targets, tree statistics.
    """
    from ..onnx_conv.operator_converters.conv_lightgbm import (
        _parse_tree_structure,
        get_default_tree_classifier_attribute_pairs
    )
    gbm_text = model.dump_model()
    objective = gbm_text['objective']

    info = {'objective': objective}
    if objective.startswith('binary'):
        info['n_classes'] = 1
    elif objective.startswith('multiclass'):
        info['n_classes'] = gbm_text['num_class']
    elif objective.startswith('regression'):
        info['n_targets'] = 1
    else:
        raise NotImplementedError(  # pragma: no cover
            "Unknown objective '{}'.".format(gbm_text['objective']))
    n_classes = info.get('n_classes', info.get('n_targets', -1))

    trees = gbm_text['tree_info']
    info['estimators_.size'] = len(trees)
    attrs = get_default_tree_classifier_attribute_pairs()
    for tree_id, tree in enumerate(trees):
        # Trees are assigned to classes round-robin; learning rate fixed at 1.
        _parse_tree_structure(
            tree_id, tree_id % n_classes, 1., tree['tree_structure'], attrs)

    info['node_count'] = len(attrs['nodes_nodeids'])
    info['ntrees'] = len(set(attrs['nodes_treeids']))
    modes = Counter(attrs['nodes_modes'])
    info['leave_count'] = modes['LEAF']
    info['mode_count'] = len(modes)
    return info
def _get_info_xgb(model):
    """
    Get informations from an :epkg:`xgboost` model:
    objective, tree statistics.

    @param model fitted :epkg:`xgboost` model
    @return dictionary with objective, node, tree and leave counts
    """
    from ..onnx_conv.operator_converters.conv_xgboost import (
        XGBConverter, XGBClassifierConverter)
    objective, _, js_trees = XGBConverter.common_members(model, None)
    attrs = XGBClassifierConverter._get_default_tree_attribute_pairs()
    # One entry per tree; weight 1 for every tree, remap=True.
    XGBConverter.fill_tree_attributes(
        js_trees, attrs, [1 for _ in js_trees], True)
    info = {'objective': objective}
    info['estimators_.size'] = len(js_trees)
    info['node_count'] = len(attrs['nodes_nodeids'])
    info['ntrees'] = len(set(attrs['nodes_treeids']))
    # Distribution of node modes; leaves are tagged 'LEAF'.
    dist = Counter(attrs['nodes_modes'])
    info['leave_count'] = dist['LEAF']
    info['mode_count'] = len(dist)
    return info
def analyze_model(model, simplify=True):
    """
    Returns informations, statistics about a model,
    its number of nodes, its size...

    @param model any model
    @param simplify simplifies the tuple of length 1
    @return dictionary

    .. exref::
        :title: Extract information from a model

        The function @see fn analyze_model extracts global
        figures about a model, whatever it is.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning

            import pprint
            from sklearn.datasets import load_iris
            from sklearn.ensemble import RandomForestClassifier
            from mlprodict.tools.model_info import analyze_model

            data = load_iris()
            X, y = data.data, data.target
            model = RandomForestClassifier().fit(X, y)
            infos = analyze_model(model)
            pprint.pprint(infos)
    """
    if hasattr(model, 'SerializeToString'):
        # ONNX model: delegate to the dedicated statistics function.
        from ..onnx_tools.optim.onnx_helper import onnx_statistics
        return onnx_statistics(model)

    if isinstance(model, numpy.ndarray):
        # Array of estimators: recursively analyse every element which
        # looks like a fitted model and aggregate the results.
        info = {'shape': model.shape}
        infos = []
        for v in model.ravel():
            if hasattr(v, 'fit'):
                ii = analyze_model(v, False)
                infos.append(ii)
        if len(infos) == 0:
            return info  # pragma: no cover
        for k, v in _reduce_infos(infos).items():
            info['.%s' % k] = v
        return info

    # linear model
    info = {}
    for k in model.__dict__:
        if k in ['tree_']:
            # trees are handled separately below by _analyse_tree
            continue
        if k.endswith('_') and not k.startswith('_'):
            # fitted attribute (scikit-learn convention): record its shape
            v = getattr(model, k)
            if isinstance(v, numpy.ndarray):
                info['%s.shape' % k] = v.shape
            elif isinstance(v, numpy.float64):
                info['%s.shape' % k] = 1
        elif k in ('_fit_X', ):
            # training data kept by nearest-neighbours estimators
            v = getattr(model, k)
            info['%s.shape' % k] = v.shape

    # classification
    # NOTE(review): 'n_outputs' has no trailing underscore unlike the
    # scikit-learn attribute 'n_outputs_' -- confirm this is intentional.
    for f in ['n_classes_', 'n_outputs', 'n_features_']:
        if hasattr(model, f):
            info[f] = getattr(model, f)

    # tree (DecisionTree-like models exposing attribute tree_)
    if hasattr(model, 'tree_'):
        for k, v in _analyse_tree(model.tree_).items():
            info['tree_.%s' % k] = v

    # tree (HistGradientBoosting predictors expose get_n_leaf_nodes)
    if hasattr(model, 'get_n_leaf_nodes'):
        for k, v in _analyse_tree_h(model).items():
            info['tree_.%s' % k] = v

    # estimators (ensembles: analyse each sub-estimator, then aggregate)
    if hasattr(model, 'estimators_'):
        info['estimators_.size'] = len(model.estimators_)
        infos = [analyze_model(est, False) for est in model.estimators_]
        for k, v in _reduce_infos(infos).items():
            info['estimators_.%s' % k] = v

    # predictors (HistGradientBoosting: list of lists of predictors)
    if hasattr(model, '_predictors'):
        info['_predictors.size'] = len(model._predictors)
        infos = []
        for est in model._predictors:
            ii = [analyze_model(e, False) for e in est]
            infos.extend(ii)
        for k, v in _reduce_infos(infos).items():
            info['_predictors.%s' % k] = v

    # LGBM
    if hasattr(model, 'booster_'):
        info.update(_get_info_lgb(model.booster_))

    # XGB
    if hasattr(model, 'get_booster'):
        info.update(_get_info_xgb(model))

    # end: optionally replace 1-uples by their single element
    if simplify:
        up = {}
        for k, v in info.items():
            if isinstance(v, tuple) and len(v) == 1:
                up[k] = v[0]
        info.update(up)

    return info
def enumerate_models(model):
    """
    Enumerates models with models.

    @param model :epkg:`scikit-learn` model
    @return enumerate models
    """
    yield model
    # Constructor parameters are the candidate locations for sub-models.
    parameters = inspect.signature(model.__init__).parameters
    for name in parameters:
        attribute = getattr(model, name, None)
        if attribute is None or not hasattr(attribute, 'fit'):
            continue
        yield from enumerate_models(attribute)
def set_random_state(model, value=0):
    """
    Sets all possible parameter *random_state* to 0.

    @param model :epkg:`scikit-learn` model
    @param value new value
    @return model (same one)
    """
    for sub_model in enumerate_models(model):
        params = inspect.signature(sub_model.__init__).parameters
        # Only touch models which declare random_state in their constructor
        # and actually carry the attribute.
        if 'random_state' in params and hasattr(sub_model, 'random_state'):
            sub_model.random_state = value
    return model