Coverage for onnxcustom/training/_base_onnx_function.py: 95%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

119 statements  

1# pylint: disable=W0105 

2""" 

3@file 

4@brief Helper for :epkg:`onnxruntime-training`. 

5""" 

6import inspect 

7from io import BytesIO 

8import numpy 

9import onnx 

10from onnxruntime import SessionOptions, InferenceSession, RunOptions 

11from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 

12 OrtValue as C_OrtValue) 

13from ..utils.onnxruntime_helper import ort_device_to_string 

14from .excs import ProviderError 

15from ._base import BaseOnnxClass 

16 

17 

class BaseLearningOnnx(BaseOnnxClass):
    """
    Class handling ONNX function to manipulate OrtValue.
    Base class for @see cl BaseLearningRate and
    @see cl BaseLearningLoss.

    Pickling relies on attribute name suffixes
    (see ``__getstate__`` and ``__setstate__``):

    * ``*_onnx_``: an ONNX model, stored as serialized bytes,
    * ``*_sess_``: an :epkg:`InferenceSession`, stored as its list
      of providers and rebuilt from the matching ``*_onnx_`` model,
    * ``*_bind_``: a binding, rebuilt from the attribute whose name is
      obtained by removing the trailing ``bind_``,
    * ``*_binds_``: a list of bindings, rebuilt the same way after
      removing the trailing ``binds_``,
    * ``ro_``: a :epkg:`RunOptions`, rebuilt with default values.
    """

    def __init__(self):
        # Both caches are indexed by id(bind), then by name
        # (see _bio_cache and _cache_in_clear).
        self.cache_in_ = {}
        self.cache_out_ = {}

    def __getstate__(self):
        """
        Overwrites getstate to get rid of InferenceSession.

        Attributes whose name does not end with ``_`` are copied as is.
        Among the others, only the ones following the naming
        conventions of the class are kept, replaced by a picklable
        placeholder. The caches are never saved.

        :return: dictionary
        """
        names = list(self.__dict__)
        state = {k: getattr(self, k) for k in names if not k.endswith('_')}
        if hasattr(self, 'ro_'):
            # placeholder, a new RunOptions is created when restoring
            state['ro_'] = True
        for k in names:
            if k.endswith('_onnx_'):
                state[k] = getattr(self, k).SerializeToString()
        for k in names:
            if k.endswith('_bind_'):
                # placeholder, the binding is recreated from the session
                state[k] = True
        for k in names:
            if k.endswith('_binds_'):
                # only the number of bindings is needed to recreate them
                state[k] = len(getattr(self, k))
        for k in names:
            if k.endswith('_sess_'):
                # the session is rebuilt from the model and its providers
                state[k] = getattr(self, k).get_providers()
        return state

    def __setstate__(self, state):
        """
        Overwrites setstate to restore what @see me __getstate__
        replaced by placeholders: :epkg:`RunOptions`, ONNX models,
        :epkg:`InferenceSession` and bindings. Caches are reset.

        :param state: dictionary produced by @see me __getstate__
        :return: self
        """
        # First pass: plain attributes. Placeholders stored for bindings
        # are assigned here and overwritten in the third pass.
        for k, v in state.items():
            if k == 'ro_':
                self.ro_ = RunOptions()
            elif not k.endswith('_onnx_') and not k.endswith('_sess_'):
                setattr(self, k, v)

        # Second pass: ONNX models and their sessions.
        so = SessionOptions()
        so.log_severity_level = 4
        suffix = '_onnx_'
        for k, v in state.items():
            if not k.endswith(suffix):
                continue
            setattr(self, k, onnx.load(BytesIO(v)))
            # Only the suffix is replaced: 'onnx' may appear elsewhere
            # in the attribute name and must be left untouched.
            k2 = k[:-len(suffix)] + '_sess_'
            prov = state.get(k2)
            if prov is None:
                # the model was saved without any session
                continue
            setattr(self, k2, InferenceSession(
                getattr(self, k).SerializeToString(), so,
                providers=prov))

        # Third pass: bindings, sessions must exist at this point.
        for k, v in state.items():
            if k.endswith('_bind_'):
                sess = getattr(self, k[:-len('bind_')])
                setattr(self, k, sess.io_binding()._iobinding)
            elif k.endswith('_binds_'):
                sess = getattr(self, k[:-len('binds_')])
                # v is the number of bindings saved by __getstate__
                setattr(self, k, [
                    sess.io_binding()._iobinding for _ in range(v)])
        self.cache_in_ = {}
        self.cache_out_ = {}
        return self

    def __repr_extended__(self):
        """
        Returns a string appended to the output of ``__repr__``,
        empty in this class.
        """
        return ''

    def __repr__(self):
        """
        Usual
        """
        param = self._get_param_names()
        ps = []
        for k, v in param:
            if k not in self.__dict__:
                continue  # pragma: no cover
            ov = getattr(self, k)
            # NOTE(review): this condition holds for almost every value,
            # `v is inspect._empty or ov != v` may have been intended
            # (skip values equal to their default) - behaviour kept.
            if v is not inspect._empty or ov != v:
                ro = repr(ov)
                ps.append("%s=%s" % (k, ro))
        return "%s(%s)%s" % (
            self.__class__.__name__, ", ".join(ps), self.__repr_extended__())

    def build_onnx_function(self, opset, device, *args):
        """
        This class updates the weights.
        It assumes it can do operator on *OrtValue*.
        This can be done through ONNX graph.
        This function creates :epkg:`InferenceSession`
        which do that.

        :param opset: opset to use
        :param device: :epkg:`C_OrtDevice`
        :param args: additional arguments
        :raises NotImplementedError: always, in this class
        """
        raise NotImplementedError(
            "This method must be overwritten.")

    @staticmethod
    def _cache_in_clear(cache, name, bind):
        """
        Marks the cached entry for *name* and *bind* as cleared.

        :param cache: dictionary `{ id(bind): { name: value } }`
        :param name: str
        :param bind: python structure, only its `id` is used
        :return: True if nothing needs to be cleared (no cached entry
            or entry already equal to 0), False if the entry was
            just reset to 0
        """
        key = id(bind)
        if key not in cache or name not in cache[key]:
            return True
        if cache[key][name] == 0:
            return True
        cache[key][name] = 0
        return False

    def clear_binding_inputs(self, name, bind, cache=False):
        """
        Clears binding and empty cache.

        :param name: str
        :param bind: python structure
        :param cache: if True, the binding is cleared only when the
            cache holds a non null entry for *name* and *bind*
        """
        if cache and self._cache_in_clear(self.cache_in_, name, bind):
            return
        bind.clear_binding_inputs()

    @staticmethod
    def _bio_cache(cache, name, bind, c_ortvalue, ptr2):
        """
        Tells if *ptr2* is already registered for *name* and *bind*
        and registers it if not.

        :param cache: dictionary `{ id(bind): { name: pointer } }`,
            modified inplace
        :param name: str
        :param bind: python structure, only its `id` is used
        :param c_ortvalue: unused
        :param ptr2: pointer to compare to the cached one
        :return: True if the cache already holds *ptr2*,
            False otherwise (the cache is then updated)
        """
        key = id(bind)
        if key not in cache:
            cache[key] = {name: ptr2}
            return False
        known = cache[key]
        if name in known and known[name] == ptr2:
            return True
        known[name] = ptr2
        return False

    @staticmethod
    def _bio_do_bind_in(name, bind, c_ortvalue):
        """
        Calls `bind.bind_ortvalue_input(name, c_ortvalue)`.
        """
        bind.bind_ortvalue_input(name, c_ortvalue)

    @staticmethod
    def _bio_ptr(c):
        """
        Returns `c.data_ptr()`.
        """
        return c.data_ptr()

    def _bind_input_ortvalue(self, name, bind, c_ortvalue, device,
                             cache=False):
        """
        Binds :epkg:`C_OrtValue` to the structure used by
        :epkg:`InferenceSession` to run inference.

        :param name: str
        :param bind: python structure
        :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`),
            it can be also a numpy array
        :param device: device
        :param cache: avoids binding again if the data pointer did not change,
            the pointer is `c_ortvalue.data_ptr()` for a :epkg:`C_OrtValue`
            and the address of the buffer for a numpy array, the cache is
            equivalent to a dictionary
            `{ id(bind), name: c_ort_value.data_ptr() }`.
        :raises ProviderError: if a numpy array is bound and the device
            is not CPU
        :raises TypeError: if *c_ortvalue* has an unexpected type
        """
        if isinstance(c_ortvalue, C_OrtValue):
            if cache and self._bio_cache(
                    self.cache_in_, name, bind, c_ortvalue,
                    self._bio_ptr(c_ortvalue)):
                return
            self._bio_do_bind_in(name, bind, c_ortvalue)
        elif isinstance(c_ortvalue, numpy.ndarray):
            # device_type is not defined in this class, presumably
            # provided by a base class or a subclass - TODO confirm
            if self.device_type() != device.cpu():  # pylint: disable=E1101
                raise ProviderError(  # pragma: no cover
                    "device=%s is not CPU." % ort_device_to_string(
                        device))
            # Only the address of the buffer is given to the binding.
            ptr = c_ortvalue.__array_interface__['data'][0]
            if cache and self._bio_cache(
                    self.cache_in_, name, bind, c_ortvalue, ptr):
                return
            bind.bind_input(
                name, device, c_ortvalue.dtype, c_ortvalue.shape, ptr)
        else:
            raise TypeError(  # pragma: no cover
                "Unable to bind type %r for name %r." % (
                    type(c_ortvalue), name))

    @staticmethod
    def _bio_do_bind_out(name, bind, c_ortvalue):
        """
        Calls `bind.bind_ortvalue_output(name, c_ortvalue)`.
        """
        bind.bind_ortvalue_output(name, c_ortvalue)

    def _bind_output_ortvalue(self, name, bind, c_ortvalue, cache=False):
        """
        Binds :epkg:`C_OrtValue` to the structure used by
        :epkg:`InferenceSession` to run inference.

        :param name: str
        :param bind: python structure
        :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`)
        :param cache: avoids binding again if the data pointer did not change,
            only works when c_ortvalue is of :epkg:`C_OrtValue`, the cache is
            equivalent to a dictionary
            `{ id(bind), name: c_ort_value.data_ptr() }`.
        :raises TypeError: if *c_ortvalue* is not a :epkg:`C_OrtValue`

        This method can be used for inplace computation.
        """
        if not isinstance(c_ortvalue, C_OrtValue):
            raise TypeError(  # pragma: no cover
                "Unable to bind type %r for name %r." % (
                    type(c_ortvalue), name))
        if cache and self._bio_cache(
                self.cache_out_, name, bind, c_ortvalue,
                self._bio_ptr(c_ortvalue)):
            return
        self._bio_do_bind_out(name, bind, c_ortvalue)