Coverage for onnxcustom/training/_base.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Base class for @see cl BaseEstimator and @see cl BaseOnnxFunction.
4"""
5import os
6import inspect
7import warnings
class BaseOnnxClass:
    """
    Base class with common functions to handle attributes
    in classes owning ONNX graphs.
    """

    @classmethod
    def _get_param_names(cls):
        """
        Extracts all parameters to serialize.

        The constructor signature is inspected. If ``cls.__init__``
        exposes a ``deprecated_original`` attribute, that object is
        inspected instead (presumably the unwrapped constructor when a
        deprecation decorator was applied).

        :return: list of tuples ``(name, default)`` for every
            constructor parameter except ``self`` and ``**kwargs``;
            parameters without a default get
            ``inspect.Parameter.empty``
        """
        init = getattr(cls.__init__, "deprecated_original", cls.__init__)
        init_signature = inspect.signature(init)
        parameters = [
            p for p in init_signature.parameters.values()
            if p.name != "self" and p.kind != p.VAR_KEYWORD]
        return [(p.name, p.default) for p in parameters]

    def save_onnx_graph(self, folder, prefix=None, suffix=None):
        """
        Saves all ONNX files stored in this class.

        :param folder: folder where to save (it must exist; a string or
            any path-like object such as :class:`pathlib.Path`) or
            ``bytes`` if the onnx graphs must be returned as bytes,
            not files
        :param prefix: prefix to add to the name
        :param suffix: suffix to add to the name
        :return: dictionary `{ attribute: filename or bytes or dictionary }`,
            a nested dictionary is returned for every attribute which
            owns ONNX graphs itself
        :raises FileNotFoundError: if *folder* is a path and is not an
            existing directory

        The function raises a warning if a file already exists.
        The function uses class name, attribute names to compose
        file names. It shortens them for frequent classes.

        * 'Learning' -> 'L'
        * 'OrtGradient' -> 'Grad'
        * 'ForwardBackward' -> 'FB'

        .. runpython::
            :showcode:

            import io
            import numpy
            import onnx
            from sklearn.datasets import make_regression
            from sklearn.model_selection import train_test_split
            from sklearn.linear_model import LinearRegression
            from skl2onnx import to_onnx
            from mlprodict.plotting.text_plot import onnx_simple_text_plot
            from onnxcustom.training.optimizers_partial import (
                OrtGradientForwardBackwardOptimizer)
            from onnxcustom.training.sgd_learning_rate import (
                LearningRateSGDNesterov)
            from onnxcustom.training.sgd_learning_penalty import (
                ElasticLearningPenalty)


            def walk_through(obj, prefix="", only_name=True):
                for k, v in obj.items():
                    if isinstance(v, dict):
                        p = prefix + "." + k if prefix else k
                        walk_through(v, prefix=p, only_name=only_name)
                    elif only_name:
                        name = "%s.%s" % (prefix, k) if prefix else k
                        print('+', name)
                    else:
                        name = "%s.%s" % (prefix, k) if prefix else k
                        print('\\n++++++', name)
                        print()
                        bf = io.BytesIO(v)
                        onx = onnx.load(bf)
                        print(onnx_simple_text_plot(onx))


            X, y = make_regression(  # pylint: disable=W0632
                100, n_features=3, bias=2, random_state=0)
            X = X.astype(numpy.float32)
            y = y.astype(numpy.float32)
            X_train, _, y_train, __ = train_test_split(X, y)
            reg = LinearRegression()
            reg.fit(X_train, y_train)
            reg.coef_ = reg.coef_.reshape((1, -1))
            opset = 15
            onx = to_onnx(reg, X_train, target_opset=opset,
                          black_op={'LinearRegressor'})
            inits = ['coef', 'intercept']

            train_session = OrtGradientForwardBackwardOptimizer(
                onx, inits,
                learning_rate=LearningRateSGDNesterov(
                    1e-4, nesterov=False, momentum=0.9),
                learning_penalty=ElasticLearningPenalty(l1=1e-3, l2=1e-4),
                warm_start=False, max_iter=100, batch_size=10)

            onxs = train_session.save_onnx_graph(bytes)

            print("+ all onnx graphs")
            walk_through(onxs, only_name=True)
            walk_through(onxs, only_name=False)
        """
        repls = {'Learning': 'L', 'OrtGradient': 'Grad',
                 'ForwardBackward': 'FB'}
        if prefix is None:
            prefix = ''
        if suffix is None:
            suffix = ''
        # pathlib.Path and other path-like objects must be treated as
        # folders too, otherwise they silently fall into the bytes branch.
        if isinstance(folder, os.PathLike):
            folder = os.fspath(folder)
        to_files = isinstance(folder, str)
        # A regular file is not a valid destination folder either.
        if to_files and not os.path.isdir(folder):
            raise FileNotFoundError(  # pragma: no cover
                "Folder %r does not exist or is not a directory." % folder)
        saved = {}
        for k, v in self.__dict__.items():
            if hasattr(v, "SerializeToString"):
                if to_files:
                    name = "%s%s%s.%s.onnx" % (
                        prefix, self.__class__.__name__, suffix, k)
                    for a, b in repls.items():
                        name = name.replace(a, b)
                    filename = os.path.join(folder, name)
                    if os.path.exists(filename):
                        warnings.warn(  # pragma: no cover
                            "Filename %r already exists." % filename)
                    with open(filename, "wb") as f:
                        f.write(v.SerializeToString())
                    saved[k] = filename
                else:
                    saved[k] = v.SerializeToString()
            elif hasattr(v, "save_onnx_graph"):
                # Nested object owning graphs: recurse and append the
                # attribute name to the suffix so file names stay distinct.
                saved[k] = v.save_onnx_graph(
                    folder, prefix=prefix, suffix="%s.%s" % (suffix, k))
        return saved