Coverage for onnxcustom/training/_base_estimator.py: 98%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Optimizer with :epkg:`onnxruntime-training`.
4"""
5import inspect
6from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
7 OrtDevice as C_OrtDevice)
8from ..utils.onnxruntime_helper import (
9 get_ort_device, ort_device_to_string)
10from ..utils.onnx_helper import replace_initializers_into_onnx
11from ._base import BaseOnnxClass
12from ._base_onnx_function import BaseLearningOnnx
13from .sgd_learning_rate import BaseLearningRate
class BaseEstimator(BaseOnnxClass):
    """
    Base class for optimizers.
    Implements common methods such as `__repr__`.

    :param model_onnx: onnx graph to train
    :param learning_rate: learning rate class,
        see module :mod:`onnxcustom.training.sgd_learning_rate`
    :param device: device as :epkg:`C_OrtDevice` or a string
        representing this device
    """

    def __init__(self, model_onnx, learning_rate, device):
        self.model_onnx = model_onnx
        # The user value is normalized into a learning rate object.
        self.learning_rate = BaseLearningRate.select(learning_rate)
        # Strings or devices are both converted into a C_OrtDevice.
        self.device = get_ort_device(device)

    @classmethod
    def _get_param_names(cls):
        """
        Extracts all parameters to serialize.

        :return: list of ``(name, default)`` for every argument of
            ``__init__`` except *self* and ``**kwargs``; *default* is
            ``inspect.Parameter.empty`` when the argument has none
        """
        # If ``__init__`` was wrapped and the wrapper kept the original
        # function in ``deprecated_original``, the original signature is used.
        init = getattr(cls.__init__, "deprecated_original", cls.__init__)
        init_signature = inspect.signature(init)
        parameters = [
            p for p in init_signature.parameters.values()
            if p.name != "self" and p.kind != p.VAR_KEYWORD]
        return [(p.name, p.default) for p in parameters]

    def get_params(self, deep=False):
        """
        Returns the parameters of the estimator, restricted to the
        arguments of ``__init__`` which are attributes of the instance.
        The device is returned as a string.

        :param deep: unused
        :return: dictionary ``{name: value}``
        """
        ps = set(p[0] for p in self._get_param_names())
        res = {att: getattr(self, att)
               for att in dir(self)
               if not att.endswith('_') and att in ps}
        if 'device' in res and not isinstance(res['device'], str):
            res['device'] = ort_device_to_string(res['device'])
        return res

    def set_params(self, **params):
        """
        Updates the parameters of the estimator, then rebuilds
        its onnx functions by calling `build_onnx_function`.
        A device given as a string is converted into a device and
        a learning rate is normalized the same way as in the constructor.

        :param params: new values, given by name
        :return: self
        """
        for k, v in params.items():
            if k == 'device' and isinstance(v, str):
                v = get_ort_device(v)
            elif k == 'learning_rate':
                # Same normalization as __init__ so that both ways of
                # setting the learning rate produce the same kind of object.
                v = BaseLearningRate.select(v)
            setattr(self, k, v)
        # build_onnx_function is not defined in this class and must be
        # provided by subclasses.
        self.build_onnx_function()  # pylint: disable=E1101
        return self

    def __repr__(self):
        """
        Usual. Learning rate objects are displayed with their own repr,
        devices as strings, any other value with a repr truncated
        when it is long or spans several lines.
        """
        param = self._get_param_names()
        ps = []
        for k, _ in param:
            if k not in self.__dict__:
                continue  # pragma: no cover
            ov = getattr(self, k)
            if isinstance(ov, BaseLearningOnnx):
                ps.append("%s=%s" % (k, repr(ov)))
            elif isinstance(ov, C_OrtDevice):
                ps.append("%s=%r" % (k, ort_device_to_string(ov)))
            else:
                ro = repr(ov)
                if len(ro) > 50 or "\n" in ro:
                    ro = ro[:10].replace("\n", " ") + "..."
                # NOTE(review): `ro` is already a repr, %r quotes it again
                # (e.g. ``n='10'``); kept to preserve the existing output.
                ps.append("%s=%r" % (k, ro))
        return "%s(%s)" % (self.__class__.__name__, ", ".join(ps))

    def __getstate__(self):
        """
        Removes any non pickable attribute: attributes ending with `_`
        are dropped except `trained_coef_` when set on the instance,
        the device is stored as a string.
        """
        atts = [k for k in self.__dict__ if not k.endswith('_')]
        if (hasattr(self, 'trained_coef_') and
                not hasattr(self.__class__, 'trained_coef_')):
            atts.append('trained_coef_')
        state = {att: getattr(self, att) for att in atts}
        state['device'] = ort_device_to_string(state['device'])
        return state

    def __setstate__(self, state):
        """
        Restores any non pickable attribute: every saved attribute is set
        back and the device is converted from its string representation.
        """
        for att, v in state.items():
            setattr(self, att, v)
        self.device = get_ort_device(self.device)
        return self

    def get_trained_onnx(self):
        """
        Returns the trained onnx graph, the initial graph
        modified by replacing the initializers with the trained
        weights.

        :return: onnx graph
        :raises NotImplementedError: always, subclasses must overload it
        """
        raise NotImplementedError(  # pragma: no cover
            "The method needs to be overloaded.")

    def _get_trained_onnx(self, state, model=None):
        """
        Returns the trained onnx graph, the initial graph
        modified by replacing the initializers with the trained
        weights.

        :param state: trained weights
        :param model: replace the weights in another graph
            than the training graph, `self.model_onnx` is used if None
        :return: onnx graph
        """
        # Explicit None check: the choice must not depend on the
        # truthiness of the graph object.
        if model is None:
            model = self.model_onnx
        return replace_initializers_into_onnx(model, state)