Coverage for onnxcustom/training/data_loader.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Manipulate data for training.
4"""
5import numpy
6from ..utils.onnxruntime_helper import (
7 get_ort_device, numpy_to_ort_value, ort_device_to_string)
class OrtDataLoader:
    """
    Draws consecutive observations from a dataset by batch.
    Every batch is made of *batch_size* consecutive rows; the first
    row of a batch is either drawn at random or moved forward
    sequentially depending on *random_iter*.

    :param X: features
    :param y: labels, a one dimension array is reshaped into
        a single column
    :param sample_weight: weight or None
    :param batch_size: batch size (consecutive observations)
    :param device: :epkg:`C_OrtDevice` or a string such as `'cpu'`
    :param random_iter: random iteration

    See example :ref:`l-orttraining-nn-gpu`.
    """

    def __init__(self, X, y, sample_weight=None,
                 batch_size=20, device='cpu', random_iter=True):
        if len(y.shape) == 1:
            y = y.reshape((-1, 1))
        if X.shape[0] != y.shape[0]:
            raise ValueError(  # pragma: no cover
                "Shape mismatch X.shape=%r, y.shape=%r." % (X.shape, y.shape))

        self.batch_size = batch_size
        self.device = get_ort_device(device)
        self.random_iter = random_iter

        self.X_np = numpy.ascontiguousarray(X)
        # y is already at least 2D here. It must not be flattened into a
        # single column: a (n, k) label matrix would become (n * k, 1)
        # and would not be aligned with the rows of X anymore.
        self.y_np = numpy.ascontiguousarray(y)

        self.X_ort = numpy_to_ort_value(self.X_np, self.device)
        self.y_ort = numpy_to_ort_value(self.y_np, self.device)

        # desc[k] = (shape, dtype) for X, y and optionally the weights.
        self.desc = [(self.X_np.shape, self.X_np.dtype),
                     (self.y_np.shape, self.y_np.dtype)]

        if sample_weight is None:
            self.w_np = None
            self.w_ort = None
        else:
            if X.shape[0] != sample_weight.shape[0]:
                raise ValueError(  # pragma: no cover
                    "Shape mismatch X.shape=%r, sample_weight.shape=%r."
                    "" % (X.shape, sample_weight.shape))
            self.w_np = numpy.ascontiguousarray(
                sample_weight).reshape((-1, ))
            self.w_ort = numpy_to_ort_value(self.w_np, self.device)
            self.desc.append((self.w_np.shape, self.w_np.dtype))

    def __getstate__(self):
        "Removes any non pickable attribute."
        state = {
            att: getattr(self, att)
            for att in ['X_np', 'y_np', 'w_np',
                        'desc', 'batch_size', 'random_iter']}
        # The device is stored as a string and rebuilt in __setstate__.
        state['device'] = ort_device_to_string(self.device)
        return state

    def __setstate__(self, state):
        "Restores any non pickable attribute."
        for att, v in state.items():
            setattr(self, att, v)
        self.device = get_ort_device(self.device)
        self.X_ort = numpy_to_ort_value(self.X_np, self.device)
        self.y_ort = numpy_to_ort_value(self.y_np, self.device)
        if self.w_np is None:
            self.w_ort = None
        else:
            self.w_ort = numpy_to_ort_value(
                self.w_np, self.device)
        return self

    def __repr__(self):
        "usual"
        return "%s(..., ..., batch_size=%r, device=%r)" % (
            self.__class__.__name__, self.batch_size,
            ort_device_to_string(self.device))

    def __len__(self):
        "Returns the number of observations."
        return self.desc[0][0][0]

    def _next_iter(self, previous):
        """
        Returns the first row of the next batch.

        :param previous: first row of the previous batch,
            -1 if there was none
        :return: index of the first row of the next batch
        """
        if self.random_iter:
            last_start = len(self) - self.batch_size
            # The upper bound of randint is exclusive: +1 is needed,
            # otherwise the last observation is never part of a batch.
            return numpy.random.randint(0, last_start + 1)
        if previous == -1:
            return 0
        i = previous + self.batch_size
        if i + self.batch_size > len(self):
            # The last batch is shifted backwards to stay complete.
            i = len(self) - self.batch_size
        return i

    def _batch_starts(self):
        """
        Yields the first row of every batch of one pass over the data.
        Yields a single *None* when the whole dataset must be
        returned as one batch (batch size is not positive or not
        smaller than the number of observations).
        """
        n_obs = len(self)
        if self.batch_size <= 0 or n_obs - self.batch_size <= 0:
            yield None
            return
        seen = 0
        start = -1
        while seen < n_obs:
            start = self._next_iter(start)
            seen += self.batch_size
            yield start

    def iter_numpy(self):
        """
        Iterates over the datasets by drawing
        *batch_size* consecutive observations.
        The function yields tuples of numpy arrays,
        *(X, y)* or *(X, y, w)* if weights were specified.

        :raises RuntimeError: if the device is not CPU
        """
        if self.device.device_type() != self.device.cpu():
            raise RuntimeError(  # pragma: no cover
                "Only CPU device is allowed if numpy arrays are requested "
                "not %r." % ort_device_to_string(self.device))
        for start in self._batch_starts():
            if start is None:
                yield self.data_np
            else:
                stop = start + self.batch_size
                yield tuple(arr[start:stop] for arr in self.data_np)

    def iter_ortvalue(self):
        """
        Iterates over the datasets by drawing
        *batch_size* consecutive observations.
        This iterator is slow as it copies the data of every
        batch. The function yields tuples of :epkg:`C_OrtValue`,
        *(X, y)* or *(X, y, w)* if weights were specified.
        """
        for start in self._batch_starts():
            if start is None:
                yield self.data_ort
                continue
            stop = start + self.batch_size
            if self.w_ort is None:
                arrays = (self.X_np, self.y_np)
            else:
                arrays = (self.X_np, self.y_np, self.w_np)
            yield tuple(
                numpy_to_ort_value(arr[start:stop], self.device)
                for arr in arrays)

    def iter_bind(self, bind, names):
        """
        Iterates over the datasets by drawing
        *batch_size* consecutive observations.
        Modifies a bind structure: inputs *names[0]* (features),
        *names[1]* (labels) and *names[2]* (weights, only if weights
        were specified) are bound to the rows of the current batch.

        :param bind: object exposing method *bind_input*
        :param names: three or four input names
        :return: iterator on the number of rows of every batch
        :raises NotImplementedError: if *names* does not contain
            three or four names
        """
        if len(names) not in (3, 4):
            raise NotImplementedError(
                "The dataloader expects three (feature name, label name, "
                "learning rate) or (feature name, label name, sample_weight, "
                "learning rate), not %r." % names)
        # NOTE(review): when weights are present and only three names
        # are given, names[2] (the learning rate according to the
        # message above) receives the weights. To be confirmed
        # against the callers.

        n_col_x = self.desc[0][0][1]
        n_col_y = self.desc[1][0][1]
        size_x = self.desc[0][1].itemsize
        size_y = self.desc[1][1].itemsize
        size_w = None if len(self.desc) <= 2 else self.desc[2][1].itemsize

        def local_bind(bind, offset, n):
            # This function assumes the data is contiguous.
            # The pointer is moved forward by *offset* rows.
            shape_X = (n, n_col_x)
            shape_y = (n, n_col_y)

            try:
                bind.bind_input(
                    names[0], self.device, self.desc[0][1], shape_X,
                    self.X_ort.data_ptr() + offset * n_col_x * size_x)
            except RuntimeError as e:  # pragma: no cover
                raise RuntimeError(
                    "Unable to bind data input (X) %r, device=%r desc=%r "
                    "data_ptr=%r offset=%r n_col_x=%r size_x=%r "
                    "type(bind)=%r" % (
                        names[0], self.device, self.desc[0][1],
                        self.X_ort.data_ptr(), offset, n_col_x, size_x,
                        type(bind))) from e
            try:
                bind.bind_input(
                    names[1], self.device, self.desc[1][1], shape_y,
                    self.y_ort.data_ptr() + offset * n_col_y * size_y)
            except RuntimeError as e:  # pragma: no cover
                raise RuntimeError(
                    "Unable to bind data input (y) %r, device=%r desc=%r "
                    "data_ptr=%r offset=%r n_col_y=%r size_y=%r "
                    "type(bind)=%r" % (
                        names[1], self.device, self.desc[1][1],
                        self.y_ort.data_ptr(), offset, n_col_y, size_y,
                        type(bind))) from e

        def local_bindw(bind, offset, n):
            # This function assumes the data is contiguous.
            shape_w = (n, )

            try:
                bind.bind_input(
                    names[2], self.device, self.desc[2][1], shape_w,
                    self.w_ort.data_ptr() + offset * size_w)
            except RuntimeError as e:  # pragma: no cover
                raise RuntimeError(
                    "Unable to bind data input (w) %r, device=%r desc=%r "
                    "data_ptr=%r offset=%r size_w=%r "
                    "type(bind)=%r" % (
                        names[2], self.device, self.desc[2][1],
                        self.w_ort.data_ptr(), offset, size_w,
                        type(bind))) from e

        has_weight = self.w_ort is not None
        for start in self._batch_starts():
            if start is None:
                # Single batch covering the whole dataset.
                offset, n = 0, len(self)
            else:
                offset, n = start, self.batch_size
            local_bind(bind, offset, n)
            if has_weight:
                local_bindw(bind, offset, n)
            yield n

    @property
    def data_np(self):
        "Returns a tuple of the datasets in numpy."
        if self.w_np is None:
            return self.X_np, self.y_np
        return self.X_np, self.y_np, self.w_np

    @property
    def data_ort(self):
        "Returns a tuple of the datasets in onnxruntime C_OrtValue."
        if self.w_ort is None:
            return self.X_ort, self.y_ort
        return self.X_ort, self.y_ort, self.w_ort