Coverage for onnxcustom/training/_base_estimator.py: 98%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

63 statements  

1""" 

2@file 

3@brief Optimizer with :epkg:`onnxruntime-training`. 

4""" 

5import inspect 

6from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 

7 OrtDevice as C_OrtDevice) 

8from ..utils.onnxruntime_helper import ( 

9 get_ort_device, ort_device_to_string) 

10from ..utils.onnx_helper import replace_initializers_into_onnx 

11from ._base import BaseOnnxClass 

12from ._base_onnx_function import BaseLearningOnnx 

13from .sgd_learning_rate import BaseLearningRate 

14 

15 

class BaseEstimator(BaseOnnxClass):
    """
    Base class for optimizers.
    Implements common methods such as `__repr__`, `get_params`,
    `set_params` and pickling support.

    :param model_onnx: onnx graph to train
    :param learning_rate: learning rate class,
        see module :mod:`onnxcustom.training.sgd_learning_rate`
    :param device: device as :epkg:`C_OrtDevice` or a string
        representing this device
    """

    def __init__(self, model_onnx, learning_rate, device):
        self.model_onnx = model_onnx
        # The user-supplied value is resolved through BaseLearningRate.select.
        self.learning_rate = BaseLearningRate.select(learning_rate)
        # Stored in its device-object form; strings are converted here.
        self.device = get_ort_device(device)

    @classmethod
    def _get_param_names(cls):
        """
        Extracts all parameters to serialize from the signature of
        the constructor.

        :return: list of ``(name, default)`` tuples; *default* is
            ``inspect.Parameter.empty`` when the parameter has none
        """
        # Look through a deprecation wrapper if one stored the original
        # constructor in a `deprecated_original` attribute.
        init = getattr(cls.__init__, "deprecated_original", cls.__init__)
        init_signature = inspect.signature(init)
        parameters = [
            p for p in init_signature.parameters.values()
            if p.name != "self" and p.kind != p.VAR_KEYWORD]
        return [(p.name, p.default) for p in parameters]

    def get_params(self, deep=False):
        """
        Returns the constructor parameters and their current values.
        Parameter *deep* is unused.

        :param deep: unused, kept for API compatibility
        :return: dictionary ``{name: value}``; the device is
            returned as a string
        """
        ps = set(p[0] for p in self._get_param_names())
        res = {att: getattr(self, att)
               for att in dir(self)
               if not att.endswith('_') and att in ps}
        if 'device' in res and not isinstance(res['device'], str):
            res['device'] = ort_device_to_string(res['device'])
        return res

    def set_params(self, **params):
        """
        Updates parameters and rebuilds the onnx functions.

        :param params: parameters to set, as keyword arguments;
            a ``device`` given as a string is converted into a device
            object before being stored
        :return: self
        """
        for k, v in params.items():
            if k == 'device' and isinstance(v, str):
                v = get_ort_device(v)
            setattr(self, k, v)
        # build_onnx_function is not defined in this class; it is
        # expected to be provided by subclasses (hence the pylint disable).
        self.build_onnx_function()  # pylint: disable=E1101
        return self

    def __repr__(self):
        """
        Returns ``ClassName(param=value, ...)``, built from the
        constructor parameters stored in the instance. A value whose
        representation is longer than 50 characters or spans several
        lines (such as an ONNX graph) is truncated and shown as a
        quoted string ending with ``...``.
        """
        ps = []
        for k, _ in self._get_param_names():
            if k not in self.__dict__:
                continue  # pragma: no cover
            ov = getattr(self, k)
            if isinstance(ov, BaseLearningOnnx):
                ps.append("%s=%r" % (k, ov))
            elif isinstance(ov, C_OrtDevice):
                ps.append("%s=%r" % (k, ort_device_to_string(ov)))
            else:
                ro = repr(ov)
                if len(ro) > 50 or "\n" in ro:
                    # Quoted so the reader sees this is an abbreviation.
                    short = ro[:10].replace("\n", " ") + "..."
                    ps.append("%s=%r" % (k, short))
                else:
                    # repr already applied: do not quote it a second time.
                    ps.append("%s=%s" % (k, ro))
        return "%s(%s)" % (self.__class__.__name__, ", ".join(ps))

    def __getstate__(self):
        """
        Returns the state to pickle. Attributes ending with ``_`` are
        skipped, except ``trained_coef_`` when it is an instance
        attribute. The device is stored as a string.
        """
        atts = [k for k in self.__dict__ if not k.endswith('_')]
        if (hasattr(self, 'trained_coef_') and
                not hasattr(self.__class__, 'trained_coef_')):
            atts.append('trained_coef_')
        state = {att: getattr(self, att) for att in atts}
        if 'device' in state and not isinstance(state['device'], str):
            state['device'] = ort_device_to_string(state['device'])
        return state

    def __setstate__(self, state):
        """
        Restores the pickled state and converts the device string
        back into a device object.
        """
        for att, v in state.items():
            setattr(self, att, v)
        if 'device' in state:
            self.device = get_ort_device(self.device)
        return self

    def get_trained_onnx(self):
        """
        Returns the trained onnx graph, the initial graph
        modified by replacing the initializers with the trained
        weights.

        :return: onnx graph
        :raises NotImplementedError: always, subclasses must overload it
        """
        raise NotImplementedError(  # pragma: no cover
            "The method needs to be overloaded.")

    def _get_trained_onnx(self, state, model=None):
        """
        Returns the trained onnx graph, the initial graph
        modified by replacing the initializers with the trained
        weights.

        :param state: trained weights
        :param model: replace the weights in another graph
            than the training graph; ``None`` means the training
            graph (`self.model_onnx`)
        :return: onnx graph
        """
        if model is None:
            model = self.model_onnx
        return replace_initializers_into_onnx(model, state)