Coverage for onnxcustom/training/_base.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

33 statements  

1""" 

2@file 

3@brief Base class for @see cl BaseEstimator and @see cl BaseOnnxFunction. 

4""" 

5import os 

6import inspect 

7import warnings 

8 

9 

class BaseOnnxClass:
    """
    Base class with common functions to handle attributes
    in classes owning ONNX graphs.
    """

    @classmethod
    def _get_param_names(cls):
        """
        Extracts all parameters to serialize.

        The constructor signature is inspected. If `cls.__init__` exposes
        an attribute `deprecated_original` (presumably set by a decorator
        wrapping the constructor), the signature of that original
        function is used instead. Parameter `self` and any `**kwargs`
        parameter are excluded.

        :return: list of tuples `(name, default)`, *default* is
            `inspect.Parameter.empty` for a parameter without default
        """
        init = getattr(cls.__init__, "deprecated_original", cls.__init__)
        init_signature = inspect.signature(init)
        return [(p.name, p.default)
                for p in init_signature.parameters.values()
                if p.name != "self" and p.kind != p.VAR_KEYWORD]

    def save_onnx_graph(self, folder, prefix=None, suffix=None):
        """
        Saves all ONNX files stored in this class.

        :param folder: folder where to save (it must exist), either a
            string or a path-like object (such as :class:`pathlib.Path`),
            or ``bytes`` if the onnx graph must be returned as bytes,
            not files
        :param prefix: prefix to add to the name
        :param suffix: suffix to add to the name
        :return: dictionary `{ attribute: value }` where *value* is
            the saved filename (when *folder* is a path), the serialized
            graph as bytes (otherwise), or a nested dictionary for an
            attribute which itself implements `save_onnx_graph`;
            attributes which are neither are skipped
        :raises FileNotFoundError: if *folder* is a path which does
            not exist

        The function emits a warning if a file already exists.
        The function uses class name, attribute names to compose
        file names. It shortens them for frequent classes
        (the replacement applies to the whole file name).

        * 'Learning' -> 'L'
        * 'OrtGradient' -> 'Grad'
        * 'ForwardBackward' -> 'FB'

        .. runpython::
            :showcode:

            import io
            import numpy
            import onnx
            from sklearn.datasets import make_regression
            from sklearn.model_selection import train_test_split
            from sklearn.linear_model import LinearRegression
            from skl2onnx import to_onnx
            from mlprodict.plotting.text_plot import onnx_simple_text_plot
            from onnxcustom.training.optimizers_partial import (
                OrtGradientForwardBackwardOptimizer)
            from onnxcustom.training.sgd_learning_rate import (
                LearningRateSGDNesterov)
            from onnxcustom.training.sgd_learning_penalty import (
                ElasticLearningPenalty)


            def walk_through(obj, prefix="", only_name=True):
                for k, v in obj.items():
                    if isinstance(v, dict):
                        p = prefix + "." + k if prefix else k
                        walk_through(v, prefix=p, only_name=only_name)
                    elif only_name:
                        name = "%s.%s" % (prefix, k) if prefix else k
                        print('+', name)
                    else:
                        name = "%s.%s" % (prefix, k) if prefix else k
                        print('\\n++++++', name)
                        print()
                        bf = io.BytesIO(v)
                        onx = onnx.load(bf)
                        print(onnx_simple_text_plot(onx))


            X, y = make_regression(  # pylint: disable=W0632
                100, n_features=3, bias=2, random_state=0)
            X = X.astype(numpy.float32)
            y = y.astype(numpy.float32)
            X_train, _, y_train, __ = train_test_split(X, y)
            reg = LinearRegression()
            reg.fit(X_train, y_train)
            reg.coef_ = reg.coef_.reshape((1, -1))
            opset = 15
            onx = to_onnx(reg, X_train, target_opset=opset,
                          black_op={'LinearRegressor'})
            inits = ['coef', 'intercept']

            train_session = OrtGradientForwardBackwardOptimizer(
                onx, inits,
                learning_rate=LearningRateSGDNesterov(
                    1e-4, nesterov=False, momentum=0.9),
                learning_penalty=ElasticLearningPenalty(l1=1e-3, l2=1e-4),
                warm_start=False, max_iter=100, batch_size=10)

            onxs = train_session.save_onnx_graph(bytes)

            print("+ all onnx graphs")
            walk_through(onxs, only_name=True)
            walk_through(onxs, only_name=False)
        """
        repls = {'Learning': 'L', 'OrtGradient': 'Grad',
                 'ForwardBackward': 'FB'}
        if prefix is None:
            prefix = ''
        if suffix is None:
            suffix = ''
        # A pathlib.Path (or any os.PathLike) designates a folder too.
        # Without this normalization it would fall into the "return bytes"
        # branch below and nothing would be written on disk. Normalizing
        # here also means nested objects receive a plain string.
        if isinstance(folder, os.PathLike):
            folder = os.fspath(folder)
        to_disk = isinstance(folder, str)
        if to_disk and not os.path.exists(folder):
            raise FileNotFoundError(  # pragma: no cover
                "Folder %r does not exist." % folder)
        saved = {}
        for k, v in self.__dict__.items():
            if hasattr(v, "SerializeToString"):
                if to_disk:
                    name = "%s%s%s.%s.onnx" % (
                        prefix, self.__class__.__name__, suffix, k)
                    # Shorten frequent class-name fragments to keep
                    # file names readable.
                    for a, b in repls.items():
                        name = name.replace(a, b)
                    filename = os.path.join(folder, name)
                    if os.path.exists(filename):
                        warnings.warn(  # pragma: no cover
                            "Filename %r already exists." % filename,
                            stacklevel=2)
                    with open(filename, "wb") as f:
                        f.write(v.SerializeToString())
                    saved[k] = filename
                else:
                    saved[k] = v.SerializeToString()
            elif hasattr(v, "save_onnx_graph"):
                # Nested object: its files get the attribute name appended
                # to the suffix so that they do not collide with ours.
                saved[k] = v.save_onnx_graph(
                    folder, prefix=prefix, suffix="%s.%s" % (suffix, k))
        return saved