Coverage for mlprodict/testing/experimental.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Experimental implementation.
4"""
5from collections import OrderedDict
6import numpy
def custom_pad(arr, paddings, constant=0, verbose=False):
    """
    Implements function
    `pad <https://numpy.org/doc/stable/reference/
    generated/numpy.pad.html>`_ in python,
    only the constant version.

    :param arr: array to pad
    :param paddings: 2D array of shape *(len(arr.shape), 2)*,
        ``paddings[i, 0]`` (resp. ``paddings[i, 1]``) is the number of
        cells to add before (resp. after) dimension *i*
    :param constant: value used to fill the added cells
    :param verbose: if True, initializes the result with ``-1`` instead of
        uninitialized memory, which makes debugging easier
    :return: padded array
    :raises ValueError: if *paddings* does not have one row per dimension
    :raises NotImplementedError: if any padding is negative
    """
    if paddings.shape[0] != len(arr.shape):
        raise ValueError(  # pragma: no cover
            "Input shape {} and paddings {} are inconsistent.".format(
                arr.shape, paddings))
    if min(paddings.ravel()) < 0:
        raise NotImplementedError("Negative paddings is not implemented yet.")
    if not arr.flags['C_CONTIGUOUS']:
        arr = numpy.ascontiguousarray(arr)  # pragma: no cover

    new_shape = tuple(
        a + s for a, s in zip(arr.shape, numpy.sum(paddings, axis=1, keepdims=0)))

    # cumulative_copy[i] is the number of flat elements spanned by one step
    # along dimension i-1 of the output; cumulative_copy[0] is the total
    # output size.  (The equivalent product over the input shape was computed
    # by the previous version but never used, so it was removed.)
    cumulative_copy = [1]
    for a in reversed(new_shape):
        cumulative_copy.insert(0, a * cumulative_copy[0])

    input_arr = arr.ravel()
    if verbose:
        res = numpy.zeros(cumulative_copy[0], dtype=arr.dtype) - 1
    else:
        res = numpy.empty(cumulative_copy[0], dtype=arr.dtype)

    # preparation: flat position of the first copied input element,
    # and the lengths of one contiguous input/output row
    first_index = sum(
        p * c for p, c in zip(paddings[:, 0], cumulative_copy[1:]))
    dh_input = arr.shape[-1]
    dh_copy = new_shape[-1]

    # Marker value guaranteed to differ from *constant*.  Every candidate
    # row start in the output is tagged with the marker; the padding pass
    # below overwrites the markers that fall inside padded areas, so only
    # rows that must receive input data keep their marker.
    no_constant = 1 if constant == 0 else 0
    res[first_index:cumulative_copy[0]:dh_copy] = no_constant

    # padding: fill the left/right borders of every dimension with *constant*
    for i, sh in enumerate(new_shape):
        upper_number = cumulative_copy[0] // cumulative_copy[i]
        contiguous = cumulative_copy[i + 1]
        big_index = 0
        p_left = paddings[i, 0] * contiguous
        p_right = paddings[i, 1] * contiguous
        dp = sh * contiguous - p_right
        for _ in range(upper_number):
            if p_left > 0:
                res[big_index:big_index + p_left] = constant
            if p_right > 0:
                index = big_index + dp
                res[index:index + p_right] = constant
            big_index += cumulative_copy[i]

    # copy: rows whose marker survived the padding pass receive input data
    index_input = 0
    index_copy = first_index
    while index_copy < cumulative_copy[0]:
        if res[index_copy] == no_constant:
            res[index_copy:index_copy + dh_input] = \
                input_arr[index_input:index_input + dh_input]
            index_input += dh_input
        index_copy += dh_copy

    # final
    return res.reshape(new_shape)
def custom_einsum(equation, x, y, verbose=False):
    """
    Experimental implementation of operator Einsum
    when it does a matrix multiplication.
    Case: ``bsnh,btnh->bnts`` with shapes
    `(1,512,12,64)` and `(1,512,12,64)`.

    :param equation: equation (``<left>,<right>-><output>``)
    :param x: first matrix
    :param y: second matrix
    :param verbose: display internal information
    :return: result of *einsum*
    :raises RuntimeError: if *x* and *y* have different dtypes
    :raises ValueError: if the equation is inconsistent with the shapes
    :raises NotImplementedError: if the equation requires more than one
        summation index

    This implementation does not perform any transpose,
    it does a direct computation of the final result.
    It does not implement diagonal summation (square product).
    """
    def _check_eq(eq, sh):
        # One letter of the equation must map to exactly one dimension.
        if len(eq) != len(sh):
            raise ValueError(
                "Unable to map equation %r to shape %r." % (eq, sh))

    def _split(eq, sh):
        # Maps letter -> (dimension size, axis position), insertion-ordered.
        dx = OrderedDict((e, (v, i)) for i, (e, v) in enumerate(zip(eq, sh)))
        return dx

    def _interpret(dx, dy, eqr):
        # Classifies the letters of the output equation *eqr*:
        # c_trp: letters present in both inputs (kept in the output),
        # c_uni: letters present in only one input,
        # c_sum: letters present in both inputs but absent from the output
        #        (summation indices).
        # Returns (output shape mapping, c_trp, c_uni, c_sum).
        c_uni = []
        c_trp = []
        c_sum = []
        for r in eqr:
            if r in dx:
                if r in dy:
                    if dx[r][0] != dy[r][0]:
                        raise ValueError(
                            "Dimension mismatch for letter "
                            "%r dx=%r dy=%r." % (r, dx, dy))
                    c_trp.append(r)
                else:
                    c_uni.append((r, None))
            elif r in dy:
                c_uni.append((None, r))
            else:
                raise ValueError(  # pragma: no cover
                    "Unexpected letter %r in result %r." % (r, eqr))
        for c in dx:
            if c not in eqr:
                if c not in dy:
                    raise ValueError(  # pragma: no cover
                        "Unable to guess what to do with column %r (left side)" % c)
                if dx[c][0] != dy[c][0]:
                    raise ValueError(  # pragma: no cover
                        "Dimension mismatch for letter "
                        "%r dx=%r dy=%r." % (c, dx, dy))
                c_sum.append(c)
        for c in dy:
            if c not in eqr and c not in dx:
                raise ValueError(  # pragma: no cover
                    "Unable to guess what to do with column %r (right side)" % c)
        # shape: letter -> (dimension size, position in the output)
        shape = OrderedDict()
        for i, r in enumerate(eqr):
            if r in c_trp:
                shape[r] = (dx[r][0], i)
            else:
                for a, b in c_uni:
                    if a == r:
                        shape[r] = (dx[r][0], i)
                        break
                    if b == r:
                        shape[r] = (dy[r][0], i)
                        break
        if len(shape) != len(eqr):
            raise RuntimeError(  # pragma: no cover
                "Unable to compute the output shape "
                "dx=%r dy=%r eqr=%r got shape=%r." % (dx, dy, eqr, shape))
        return shape, c_trp, c_uni, c_sum

    def _inc(d):
        # Converts letter -> (size, position) into letter -> (stride, position)
        # where stride is the flat-index increment for one step along that
        # axis in a C-contiguous layout (computed right to left).
        t = 1
        drev = list(reversed(d.items()))
        res = []
        for c, (sh, p) in drev:
            res.append((c, (t, p)))
            t *= sh
        return OrderedDict(reversed(res))

    def prod(seq):
        # Product of a sequence of integers.
        p = 1
        for s in seq:
            p *= s
        return p

    def get_index(cd, shape, index, col_sum):
        # Flat index in the raveled input for the multi-index *index*
        # (letters absent from *cd* contribute nothing), plus the stride
        # of the summation letter *col_sum*.
        ind = 0
        for c, i in zip(shape, index):
            if c in cd:
                inc = cd[c][0]
            ind += inc * i
        return ind, cd[col_sum][0]

    def get_incs(cd, shape):
        # Stride of each output letter in the raveled input
        # (0 when the letter does not appear in that input).
        incs = []
        for c in shape:
            inc = cd[c][0] if c in cd else 0
            incs.append(inc)
        return incs

    if x.dtype != y.dtype:
        raise RuntimeError("x and y must have the same dtype.")
    # Split the equation into left input, right input and output terms.
    eqx = equation.split(',')[0]
    eqy = equation.split(',')[-1].split('->')[0]
    eqr = equation.split('->')[-1]
    _check_eq(eqx, x.shape)
    _check_eq(eqy, y.shape)
    dx = _split(eqx, x.shape)
    dy = _split(eqy, y.shape)
    shape, __, _, c_sum = _interpret(dx, dy, eqr)
    cdx = _inc(dx)
    cdy = _inc(dy)
    xrav = x.ravel()
    yrav = y.ravel()
    full_size = prod(v[0] for v in shape.values())
    zrav = numpy.empty((full_size, ), dtype=x.dtype)

    # loop
    if len(c_sum) != 1:
        raise NotImplementedError(
            "More than one summation indices %r in equation %r." % (
                c_sum, equation))
    # zeros[0] provides a correctly-typed scalar zero to start each sum.
    zeros = numpy.zeros((1, ), dtype=x.dtype)
    shape_dims = [v[0] for v in shape.values()]
    index = [0 for s in shape]
    len_index = len(index)
    loop_size = dx[c_sum[0]][0]

    # Starting flat indices into x and y, and the per-summation-step strides.
    i_left_loop, inc_left = get_index(cdx, shape, index, c_sum[0])
    i_right_loop, inc_right = get_index(cdy, shape, index, c_sum[0])
    # Per-output-letter strides into x and y, used to update the flat
    # indices incrementally instead of recomputing them from *index*.
    left_incs = get_incs(cdx, shape)
    right_incs = get_incs(cdy, shape)

    if verbose:
        def MakeString(*args):
            return "".join(map(str, args))

        print(MakeString("equation=", equation))
        print(MakeString("c_sum=", c_sum))
        print(MakeString("full_size=", full_size))
        print(MakeString("loop_size=", loop_size))
        print(MakeString("i_left_loop=", i_left_loop))
        print(MakeString("i_right_loop=", i_right_loop))
        print(MakeString("inc_left=", inc_left))
        print(MakeString("inc_right=", inc_right))
        print(MakeString("left_incs=", left_incs))
        print(MakeString("right_incs=", right_incs))
        print(MakeString("shape=", shape))
        print(MakeString("cdx=", cdx))
        print(MakeString("cdy=", cdy))

    # Enumerate every output cell in C order; *index* is the multi-index,
    # i_left_loop / i_right_loop are the matching flat indices into x / y.
    for i in range(0, full_size):

        i_left = i_left_loop
        i_right = i_right_loop

        # summation over the single summation letter
        add = zeros[0]
        for _ in range(loop_size):
            add += xrav[i_left] * yrav[i_right]
            i_left += inc_left
            i_right += inc_right
        zrav[i] = add

        if verbose:
            print(MakeString(
                " -- index=", index, " ii=", i,
                " i_left_loop=", i_left_loop, " i_right_loop=", i_right_loop,
                " add=", add))

        # increment the output multi-index (rightmost axis fastest),
        # keeping the flat indices in sync with carry propagation
        pos = len_index - 1
        index[pos] += 1
        i_left_loop += left_incs[pos]
        i_right_loop += right_incs[pos]
        while pos > 0 and index[pos] >= shape_dims[pos]:
            i_left_loop -= left_incs[pos] * index[pos]
            i_right_loop -= right_incs[pos] * index[pos]
            index[pos] = 0
            pos -= 1
            index[pos] += 1
            i_left_loop += left_incs[pos]
            i_right_loop += right_incs[pos]

    new_shape = tuple(v[0] for v in shape.values())
    return zrav.reshape(new_shape)