Coverage for mlprodict/onnx_conv/operator_converters/parse_lightgbm.py: 99%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

91 statements  

1""" 

2@file 

3@brief Parsers for LightGBM booster. 

4""" 

5import numpy 

6from sklearn.base import ClassifierMixin 

7from skl2onnx._parse import _parse_sklearn_classifier, _parse_sklearn_simple_model 

8from skl2onnx.common._apply_operation import apply_concat, apply_cast 

9from skl2onnx.common.data_types import guess_proto_type 

10from skl2onnx.proto import onnx_proto 

11 

12 

class WrappedLightGbmBooster:
    """
    A booster can be a classifier, a regressor.
    Trick to wrap it in a minimal function.
    """

    def __init__(self, booster):
        self.booster_ = booster
        # NOTE(review): feature_name() returns the list of feature names,
        # not a count — kept as-is to preserve the existing contract.
        self.n_features_ = self.booster_.feature_name()
        self.objective_ = self.get_objective()
        if self.objective_.startswith(('binary', 'multiclass')):
            self.operator_name = 'LgbmClassifier'
            self.classes_ = self._generate_classes(booster)
        elif self.objective_.startswith('regression'):  # pragma: no cover
            self.operator_name = 'LgbmRegressor'
        else:  # pragma: no cover
            raise NotImplementedError(
                'Unsupported LightGbm objective: %r.' % self.objective_)
        if self.booster_.attr('average_output'):
            self.boosting_type = 'rf'
        else:
            # Other than random forest, other boosting types do not affect
            # later conversion. Here `gbdt` is chosen for no reason.
            self.boosting_type = 'gbdt'

    @staticmethod
    def _generate_classes(booster):
        """Builds the array of class labels from the booster metadata."""
        if isinstance(booster, dict):
            n_classes = booster['num_class']
        else:
            n_classes = booster.attr('num_class')
        if n_classes is None:
            # Fall back on the dumped model when the attribute is absent.
            n_classes = booster.dump_model(num_iteration=1)['num_class']
        if n_classes == 1:
            # Binary problem: lightgbm stores a single class internally.
            return numpy.asarray([0, 1])
        return numpy.arange(n_classes)

    def get_objective(self):
        "Returns the objective."
        objective = getattr(self, 'objective_', None)
        if objective is not None:
            return objective
        objective = self.booster_.attr('objective')
        if objective is not None:
            return objective
        return self.booster_.dump_model(num_iteration=1)['objective']

64 

65 

class WrappedLightGbmBoosterClassifier(ClassifierMixin):
    """
    Trick to wrap a LGBMClassifier into a class.
    """

    def __init__(self, wrapped):  # pylint: disable=W0231
        # Copies every known attribute carried by the wrapped booster
        # onto this instance, skipping the ones it does not define.
        _missing = object()
        for name in ('boosting_type', '_model_dict', '_model_dict_info',
                     'operator_name', 'classes_', 'booster_',
                     'n_features_', 'objective_'):
            value = getattr(wrapped, name, _missing)
            if value is not _missing:
                setattr(self, name, value)

77 

78 

class MockWrappedLightGbmBoosterClassifier(WrappedLightGbmBoosterClassifier):
    """
    Mocked lightgbm.
    """

    def __init__(self, tree):  # pylint: disable=W0231
        # Keeps the dumped tree structure instead of a real booster.
        self.dumped_ = tree
        self.is_mock = True

    def dump_model(self):
        "mock dump_model method"
        # Flag read back by tests to check the method was called.
        self.visited = True
        return self.dumped_

    def feature_name(self):
        "Returns binary features names."
        return [0, 1]

    def attr(self, key):
        "Returns default values for common attributes."
        canned = {'objective': "binary",
                  'num_class': 1,
                  'average_output': None}
        if key in canned:
            return canned[key]
        raise KeyError(  # pragma: no cover
            "No response for %r." % key)

107 

108 

def lightgbm_parser(scope, model, inputs, custom_parsers=None):
    """
    Agnostic parser for LightGBM Booster.
    """
    if hasattr(model, "fit"):
        raise TypeError(  # pragma: no cover
            "This converter does not apply on type '{}'."
            "".format(type(model)))

    if len(inputs) != 1:
        # Multiple columns: declare a concatenation operator first,
        # then parse again with the single merged input.
        concat_op = scope.declare_local_operator('LightGBMConcat')
        concat_op.raw_operator = model
        concat_op.inputs = inputs
        merged = scope.declare_local_variable(
            'Xlgbm', inputs[0].type.__class__([None, None]))
        concat_op.outputs.append(merged)
        return lightgbm_parser(scope, model, concat_op.outputs,
                               custom_parsers=custom_parsers)

    wrapped = WrappedLightGbmBooster(model)
    objective = wrapped.get_objective()
    if objective.startswith(('binary', 'multiclass')):
        # Both binary and multiclass boosters parse as classifiers.
        wrapped = WrappedLightGbmBoosterClassifier(wrapped)
        return _parse_sklearn_classifier(
            scope, wrapped, inputs, custom_parsers=custom_parsers)
    if objective.startswith('regression'):  # pragma: no cover
        return _parse_sklearn_simple_model(
            scope, wrapped, inputs, custom_parsers=custom_parsers)
    raise NotImplementedError(  # pragma: no cover
        "Objective '{}' is not implemented yet.".format(objective))

144 

145 

def shape_calculator_lightgbm_concat(operator):
    """
    Shape calculator for operator *LightGBMConcat*.

    Intentionally a no-op: the output variable is declared with its
    shape where the operator is created, so there is nothing to infer.
    """

151 

152 

def converter_lightgbm_concat(scope, operator, container):
    """
    Converter for operator *LightGBMConcat*.

    Emits one Concat node over all inputs (axis=1); when option
    ``cast`` is enabled, a Cast node is inserted between the Concat
    output and the operator output.
    """
    options = container.get_options(operator.raw_operator, dict(cast=False))
    # Any input type other than double is cast to float.
    proto_dtype = guess_proto_type(operator.inputs[0].type)
    if proto_dtype != onnx_proto.TensorProto.DOUBLE:  # pylint: disable=E1101
        proto_dtype = onnx_proto.TensorProto.FLOAT  # pylint: disable=E1101

    final_name = operator.outputs[0].full_name
    if not options['cast']:
        concat_name = final_name
    else:
        # Concat writes into an intermediate variable which is then
        # cast into the operator output.
        concat_name = scope.get_unique_variable_name('cast_lgbm')
        apply_cast(scope, concat_name, final_name, container,
                   operator_name=scope.get_unique_operator_name('cast_lgmb'),
                   to=proto_dtype)

    input_names = [inp.full_name for inp in operator.inputs]
    apply_concat(scope, input_names, concat_name, container, axis=1)