Coverage for mlprodict/onnx_conv/operator_converters/parse_lightgbm.py: 99%
"""
@file
@brief Parsers for LightGBM booster.
"""
import numpy
from sklearn.base import ClassifierMixin
from skl2onnx._parse import _parse_sklearn_classifier, _parse_sklearn_simple_model
from skl2onnx.common._apply_operation import apply_concat, apply_cast
from skl2onnx.common.data_types import guess_proto_type
from skl2onnx.proto import onnx_proto


class WrappedLightGbmBooster:
    """
    A booster can be a classifier or a regressor.
    Trick to wrap it in a minimal function.
    """

    def __init__(self, booster):
        self.booster_ = booster
        self.n_features_ = self.booster_.feature_name()
        self.objective_ = self.get_objective()
        if self.objective_.startswith('binary'):
            self.operator_name = 'LgbmClassifier'
            self.classes_ = self._generate_classes(booster)
        elif self.objective_.startswith('multiclass'):
            self.operator_name = 'LgbmClassifier'
            self.classes_ = self._generate_classes(booster)
        elif self.objective_.startswith('regression'):  # pragma: no cover
            self.operator_name = 'LgbmRegressor'
        else:  # pragma: no cover
            raise NotImplementedError(
                'Unsupported LightGbm objective: %r.' % self.objective_)
        average_output = self.booster_.attr('average_output')
        if average_output:
            self.boosting_type = 'rf'
        else:
            # Boosting types other than random forest do not affect the
            # conversion, so 'gbdt' is used as an arbitrary default.
            self.boosting_type = 'gbdt'

    @staticmethod
    def _generate_classes(booster):
        if isinstance(booster, dict):
            num_class = booster['num_class']
        else:
            num_class = booster.attr('num_class')
            if num_class is None:
                dp = booster.dump_model(num_iteration=1)
                num_class = dp['num_class']
        if num_class == 1:
            return numpy.asarray([0, 1])
        return numpy.arange(num_class)

    def get_objective(self):
        "Returns the objective."
        if hasattr(self, 'objective_') and self.objective_ is not None:
            return self.objective_
        objective = self.booster_.attr('objective')
        if objective is not None:
            return objective
        dp = self.booster_.dump_model(num_iteration=1)
        return dp['objective']
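

# Hedged usage sketch (not part of the original module): the helper below shows
# how a trained lightgbm Booster could be wrapped so that skl2onnx sees
# classifier-like attributes. The training call and the names `X`, `y` are
# illustrative assumptions, not code used elsewhere in this file.
def _example_wrap_booster(X, y):
    "Illustrative only: wraps a freshly trained binary booster."
    import lightgbm
    data = lightgbm.Dataset(X, label=y)
    booster = lightgbm.train({'objective': 'binary'}, data, num_boost_round=3)
    wrapped = WrappedLightGbmBooster(booster)
    # For a binary objective, operator_name is 'LgbmClassifier'
    # and classes_ is array([0, 1]).
    return wrapped.operator_name, wrapped.classes_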


class WrappedLightGbmBoosterClassifier(ClassifierMixin):
    """
    Trick to wrap a LGBMClassifier into a class.
    """

    def __init__(self, wrapped):  # pylint: disable=W0231
        for k in {'boosting_type', '_model_dict', '_model_dict_info',
                  'operator_name', 'classes_', 'booster_', 'n_features_',
                  'objective_'}:
            if hasattr(wrapped, k):
                setattr(self, k, getattr(wrapped, k))


class MockWrappedLightGbmBoosterClassifier(WrappedLightGbmBoosterClassifier):
    """
    Mocked lightgbm.
    """

    def __init__(self, tree):  # pylint: disable=W0231
        self.dumped_ = tree
        self.is_mock = True

    def dump_model(self):
        "Mocked *dump_model* method."
        self.visited = True
        return self.dumped_

    def feature_name(self):
        "Returns binary feature names."
        return [0, 1]

    def attr(self, key):
        "Returns default values for common attributes."
        if key == 'objective':
            return "binary"
        if key == 'num_class':
            return 1
        if key == 'average_output':
            return None
        raise KeyError(  # pragma: no cover
            "No response for %r." % key)
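

# Hedged test sketch (an illustrative assumption about how the mock is meant
# to be used): it replays a dictionary shaped like the output of
# Booster.dump_model without training anything. The dictionary below is a
# trivial placeholder, not a real LightGBM dump.
def _example_mock_usage():
    "Illustrative only: exercises the mocked methods."
    tree = {'objective': 'binary', 'num_class': 1, 'tree_info': []}
    mock = MockWrappedLightGbmBoosterClassifier(tree)
    assert mock.attr('objective') == 'binary'
    assert mock.dump_model() is tree and mock.visited
    return mock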


def lightgbm_parser(scope, model, inputs, custom_parsers=None):
    """
    Agnostic parser for LightGBM Booster.
    """
    if hasattr(model, "fit"):
        raise TypeError(  # pragma: no cover
            "This converter does not apply to type '{}'."
            "".format(type(model)))

    if len(inputs) == 1:
        wrapped = WrappedLightGbmBooster(model)
        objective = wrapped.get_objective()
        if objective.startswith('binary'):
            wrapped = WrappedLightGbmBoosterClassifier(wrapped)
            return _parse_sklearn_classifier(
                scope, wrapped, inputs, custom_parsers=custom_parsers)
        if objective.startswith('multiclass'):
            wrapped = WrappedLightGbmBoosterClassifier(wrapped)
            return _parse_sklearn_classifier(
                scope, wrapped, inputs, custom_parsers=custom_parsers)
        if objective.startswith('regression'):  # pragma: no cover
            return _parse_sklearn_simple_model(
                scope, wrapped, inputs, custom_parsers=custom_parsers)
        raise NotImplementedError(  # pragma: no cover
            "Objective '{}' is not implemented yet.".format(objective))

    # Multiple columns: concatenate them into a single variable 'Xlgbm',
    # then parse again with that single input.
    this_operator = scope.declare_local_operator('LightGBMConcat')
    this_operator.raw_operator = model
    this_operator.inputs = inputs
    var = scope.declare_local_variable(
        'Xlgbm', inputs[0].type.__class__([None, None]))
    this_operator.outputs.append(var)
    return lightgbm_parser(scope, model, this_operator.outputs,
                           custom_parsers=custom_parsers)
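

# Hedged registration sketch: skl2onnx only calls lightgbm_parser if it is
# registered for the wrapper class. The call below mirrors the usual
# update_registered_converter pattern and assumes the LightGBM tree converter
# shipped with onnxmltools; the registration actually performed elsewhere in
# mlprodict may differ. The 'LightGBMConcat' operator declared by the parser
# also needs shape_calculator_lightgbm_concat and converter_lightgbm_concat
# registered under that alias.
def _example_register_parser():
    "Illustrative only: registers the wrapper class with skl2onnx."
    from skl2onnx import update_registered_converter
    from skl2onnx.common.shape_calculator import (
        calculate_linear_classifier_output_shapes)
    from onnxmltools.convert.lightgbm.operator_converters.LightGbm import (
        convert_lightgbm)
    update_registered_converter(
        WrappedLightGbmBoosterClassifier, 'LgbmClassifier',
        calculate_linear_classifier_output_shapes, convert_lightgbm,
        parser=lightgbm_parser)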


def shape_calculator_lightgbm_concat(operator):
    """
    Shape calculator for operator *LightGBMConcat*.
    """
    pass


def converter_lightgbm_concat(scope, operator, container):
    """
    Converter for operator *LightGBMConcat*.
    """
    op = operator.raw_operator
    options = container.get_options(op, dict(cast=False))
    proto_dtype = guess_proto_type(operator.inputs[0].type)
    if proto_dtype != onnx_proto.TensorProto.DOUBLE:  # pylint: disable=E1101
        proto_dtype = onnx_proto.TensorProto.FLOAT  # pylint: disable=E1101
    if options['cast']:
        concat_name = scope.get_unique_variable_name('cast_lgbm')
        apply_cast(scope, concat_name, operator.outputs[0].full_name, container,
                   operator_name=scope.get_unique_operator_name('cast_lgbm'),
                   to=proto_dtype)
    else:
        concat_name = operator.outputs[0].full_name
    apply_concat(scope, [_.full_name for _ in operator.inputs],
                 concat_name, container, axis=1)
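

# Hedged sketch of the graph fragment the converter above emits when the
# 'cast' option is enabled, rebuilt standalone with onnx.helper: a Concat over
# the input columns feeding a Cast node. The names 'col0', 'col1' and 'Xlgbm'
# are placeholders for the variables skl2onnx would actually declare.
def _example_concat_cast_nodes():
    "Illustrative only: Concat(axis=1) followed by Cast to float."
    from onnx import helper, TensorProto
    concat = helper.make_node(
        'Concat', ['col0', 'col1'], ['cast_lgbm'], axis=1)
    cast = helper.make_node(
        'Cast', ['cast_lgbm'], ['Xlgbm'], to=TensorProto.FLOAT)
    return concat, cast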