Coverage for onnxcustom/training/_base_onnx_function.py: 95%
Keyboard shortcuts on this page:
r, m, x: toggle line displays
j, k: go to the next/previous highlighted chunk
0 (zero): go to the top of the page
1 (one): go to the first highlighted chunk
1# pylint: disable=W0105
2"""
3@file
4@brief Helper for :epkg:`onnxruntime-training`.
5"""
6import inspect
7from io import BytesIO
8import numpy
9import onnx
10from onnxruntime import SessionOptions, InferenceSession, RunOptions
11from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
12 OrtValue as C_OrtValue)
13from ..utils.onnxruntime_helper import ort_device_to_string
14from .excs import ProviderError
15from ._base import BaseOnnxClass
class BaseLearningOnnx(BaseOnnxClass):
    """
    Class handling ONNX function to manipulate OrtValue.
    Base class for @see cl BaseLearningRate and
    @see cl BaseLearningLoss.

    Attribute naming conventions understood by pickling
    (:meth:`__getstate__` / :meth:`__setstate__`):

    * ``*_onnx_``: an ONNX model, pickled as its serialized bytes,
    * ``*_sess_``: an :epkg:`InferenceSession`, pickled as its list of
      providers and rebuilt from the matching ``*_onnx_`` attribute,
    * ``X`` + ``bind_`` (name ends with ``_bind_``): one IO binding
      created from session attribute ``X``, pickled as ``True``,
    * ``X`` + ``binds_`` (name ends with ``_binds_``): a list of IO
      bindings created from session attribute ``X``, pickled as its length,
    * ``ro_``: :epkg:`RunOptions`, recreated empty when unpickled,
    * any other attribute ending with ``_`` is not pickled.
    """

    def __init__(self):
        # Caches of bound data pointers, one per direction,
        # structured as {id(bind): {name: data_pointer}}.
        self.cache_in_ = {}
        self.cache_out_ = {}

    def __getstate__(self):
        """
        Overwrites getstate to get rid of InferenceSession.
        Models are serialized to bytes, sessions are replaced by their
        providers, bindings by ``True`` or by their count and
        :epkg:`RunOptions` by ``True``. Other attributes ending with
        ``_`` (such as the binding caches) are dropped.
        """
        state = {k: getattr(self, k) for k in self.__dict__
                 if not k.endswith('_')}
        if hasattr(self, 'ro_'):
            # RunOptions cannot be pickled, it is recreated on load.
            state['ro_'] = True
        for k in self.__dict__:
            if k.endswith('_onnx_'):
                state[k] = getattr(self, k).SerializeToString()
            elif k.endswith('_sess_'):
                state[k] = getattr(self, k).get_providers()
            elif k.endswith('_bind_'):
                state[k] = True
            elif k.endswith('_binds_'):
                state[k] = len(getattr(self, k))
        return state

    def __setstate__(self, state):
        """
        Overwrites setstate to rebuild the InferenceSession and the
        IO bindings removed by :meth:`__getstate__`.
        """
        # First pass: plain attributes, RunOptions is recreated empty.
        for k, v in state.items():
            if k == 'ro_':
                self.ro_ = RunOptions()
            elif not k.endswith('_onnx_') and not k.endswith('_sess_'):
                setattr(self, k, v)

        # Second pass: models and their sessions, created with the
        # providers saved by __getstate__.
        so = SessionOptions()
        so.log_severity_level = 4
        for k, v in state.items():
            if k.endswith('_onnx_'):
                setattr(self, k, onnx.load(BytesIO(v)))
                # Only the suffix is renamed, str.replace would also
                # rewrite an 'onnx' substring inside the prefix.
                k2 = k[:-len('onnx_')] + 'sess_'
                prov = state[k2]
                # v already holds the serialized model.
                setattr(self, k2, InferenceSession(v, so, providers=prov))

        # Third pass: bindings depend on the sessions rebuilt above.
        for k, v in state.items():
            if k.endswith('_bind_'):
                sess = getattr(self, k[:-len('bind_')])
                setattr(self, k, sess.io_binding()._iobinding)
            elif k.endswith('_binds_'):
                sess = getattr(self, k[:-len('binds_')])
                setattr(self, k, [
                    sess.io_binding()._iobinding for _ in range(v)])
        self.cache_in_ = {}
        self.cache_out_ = {}
        return self

    def __repr_extended__(self):
        """
        Returns text appended after the parameters in :meth:`__repr__`
        (empty string here).
        """
        return ''

    def __repr__(self):
        """
        Usual representation: the class name followed by the constructor
        parameters stored on the instance, then
        :meth:`__repr_extended__`.
        """
        # assumed to return (name, default) pairs, see BaseOnnxClass
        param = self._get_param_names()
        ps = []
        for k, v in param:
            if k not in self.__dict__:
                continue  # pragma: no cover
            ov = getattr(self, k)
            # NOTE(review): because of ``or``, a parameter having a default
            # is always printed and ``ov != v`` only runs when there is no
            # default. Kept unchanged as callers may rely on this output.
            if v is not inspect._empty or ov != v:
                ps.append("%s=%s" % (k, repr(ov)))
        return "%s(%s)%s" % (
            self.__class__.__name__, ", ".join(ps), self.__repr_extended__())

    def build_onnx_function(self, opset, device, *args):
        """
        This class updates the weights.
        It assumes it can do operator on *OrtValue*.
        This can be done through ONNX graph.
        This function creates :epkg:`InferenceSession`
        which do that.

        :param opset: opset to use
        :param device: :epkg:`C_OrtDevice`
        :param args: additional arguments
        """
        raise NotImplementedError(
            "This method must be overwritten.")

    @staticmethod
    def _cache_in_clear(cache, name, bind):
        """
        Marks input *name* of binding *bind* as cleared in *cache*.

        :return: True if clearing can be skipped (the input was never
            cached or is already marked as cleared with 0), False after
            resetting the cached pointer to 0
        """
        key = id(bind)
        if key in cache:
            if name in cache[key]:
                if cache[key][name] == 0:
                    return True
                cache[key][name] = 0
                return False
        return True

    def clear_binding_inputs(self, name, bind, cache=False):
        """
        Clears binding and empty cache.

        :param name: input name used to look up the cache
        :param bind: binding whose inputs are cleared
        :param cache: if True, skips clearing when the cache says
            there is nothing to clear for *name*
        """
        if cache and self._cache_in_clear(self.cache_in_, name, bind):
            return
        bind.clear_binding_inputs()

    @staticmethod
    def _bio_cache(cache, name, bind, c_ortvalue, ptr2):
        """
        Looks up ``cache[id(bind)][name]``.

        :return: True if it already holds *ptr2* (binding can be skipped),
            otherwise stores *ptr2* and returns False
        """
        key = id(bind)
        if key in cache:
            if name in cache[key]:
                ptr = cache[key][name]
                if ptr == ptr2:
                    return True
            cache[key][name] = ptr2
        else:
            cache[key] = {name: ptr2}
        return False

    @staticmethod
    def _bio_do_bind_in(name, bind, c_ortvalue):
        "Binds a C_OrtValue as input *name*."
        bind.bind_ortvalue_input(name, c_ortvalue)

    @staticmethod
    def _bio_ptr(c):
        "Returns the data pointer of a C_OrtValue."
        return c.data_ptr()

    def _bind_input_ortvalue(self, name, bind, c_ortvalue, device,
                             cache=False):
        """
        Binds :epkg:`C_OrtValue` to the structure used by
        :epkg:`InferenceSession` to run inference.

        :param name: str
        :param bind: python structure
        :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`),
            it can be also a numpy array (CPU only)
        :param device: device (:epkg:`C_OrtDevice`)
        :param cache: avoids binding again if the data pointer did not
            change (data_ptr() for :epkg:`C_OrtValue`, buffer address for
            a numpy array), the cache is a dictionary
            `{id(bind): {name: data_pointer}}`
        :raises ProviderError: numpy array with a non CPU device
        :raises TypeError: unsupported type for *c_ortvalue*
        """
        if isinstance(c_ortvalue, C_OrtValue):
            if cache and self._bio_cache(
                    self.cache_in_, name, bind, c_ortvalue,
                    self._bio_ptr(c_ortvalue)):
                return
            self._bio_do_bind_in(name, bind, c_ortvalue)
        elif isinstance(c_ortvalue, numpy.ndarray):
            # The device type comes from the C_OrtDevice argument,
            # this class does not define any device_type method.
            if device.device_type() != device.cpu():
                raise ProviderError(  # pragma: no cover
                    "device=%s is not CPU." % ort_device_to_string(
                        device))
            ptr = c_ortvalue.__array_interface__['data'][0]
            # NOTE(review): the cache only compares addresses, a buffer
            # reused with another shape or dtype is not bound again.
            if cache and self._bio_cache(
                    self.cache_in_, name, bind, c_ortvalue, ptr):
                return
            bind.bind_input(
                name, device, c_ortvalue.dtype, c_ortvalue.shape, ptr)
        else:
            raise TypeError(  # pragma: no cover
                "Unable to bind type %r for name %r." % (
                    type(c_ortvalue), name))

    @staticmethod
    def _bio_do_bind_out(name, bind, c_ortvalue):
        "Binds a C_OrtValue as output *name*."
        bind.bind_ortvalue_output(name, c_ortvalue)

    def _bind_output_ortvalue(self, name, bind, c_ortvalue, cache=False):
        """
        Binds :epkg:`C_OrtValue` to the structure used by
        :epkg:`InferenceSession` to run inference.

        :param name: str
        :param bind: python structure
        :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`)
        :param cache: avoids binding again if the data pointer did not
            change, the cache is a dictionary
            `{id(bind): {name: c_ortvalue.data_ptr()}}`
        :raises TypeError: *c_ortvalue* is not a :epkg:`C_OrtValue`

        This method can be used for inplace computation.
        """
        if isinstance(c_ortvalue, C_OrtValue):
            if cache and self._bio_cache(
                    self.cache_out_, name, bind, c_ortvalue,
                    self._bio_ptr(c_ortvalue)):
                return
            self._bio_do_bind_out(name, bind, c_ortvalue)
        else:
            raise TypeError(  # pragma: no cover
                "Unable to bind type %r for name %r." % (
                    type(c_ortvalue), name))