Coverage for mlprodict/testing/einsum/einsum_impl_ext.py: 96%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
@brief Functions implementing einsum computation for two
matrices having the same number of dimensions.
5"""
6import numpy
def numpy_diagonal(m, axis, axes):
    """
    Extracts the diagonal coefficients of an array along several axes.

    :param m: input array
    :param axis: diagonal axis kept in the output
    :param axes: all diagonal axes (*axis* must be one of them)
    :return: array with the same dimensions as *m* except every axis
        of *axes* other than *axis*, which is removed
    :raises RuntimeError: if *axis* does not belong to *axes*

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_diagonal

        mat = numpy.arange(8).reshape((2, 2, 2))
        print(mat)
        diag = numpy_diagonal(mat, 1, [1, 2])
        print(diag)
    """
    if axis not in axes:
        raise RuntimeError(
            "axis %r must be in axes %r." % (axis, axes))

    # Intermediate shape: every diagonal axis but *axis* becomes a
    # singleton, so that the coefficients can be copied in place.
    kept_shape = [
        1 if (dim in axes and dim != axis) else size
        for dim, size in enumerate(m.shape)]
    # Final shape: the singletons introduced above are dropped.
    final_shape = [
        size for dim, size in enumerate(m.shape)
        if dim not in axes or dim == axis]

    output = numpy.empty(tuple(kept_shape), dtype=m.dtype)
    source = [slice(size) for size in m.shape]
    target = [slice(size) for size in m.shape]
    for k in range(kept_shape[axis]):
        # Every diagonal axis takes the same index k in the input,
        # only *axis* keeps it in the output.
        for a in axes:
            source[a] = k
            target[a] = k if a == axis else 0
        output[tuple(target)] = m[tuple(source)]

    return output.reshape(tuple(final_shape))
def _numpy_extended_dot_equation(m1_dim, m2_dim, axes, left, right):
    """
    Returns the equation equivalent to an extended version
    of an aligned matrix multiplication
    (see @see fn numpy_extended_dot).

    Every dimension *i* is associated to the letter ``chr(97 + i)``.
    Left axes are upper case in the first input and in the output,
    right axes are upper case in the second input and in the output.
    Summation axes are lower case in both inputs and removed from
    the output unless they also belong to *right*.

    :param m1_dim: number of dimensions of the first matrix
    :param m2_dim: number of dimensions of the second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :return: equation
    :raises RuntimeError: if *m1_dim* and *m2_dim* are different
    :raises ValueError: if the union of *axes*, *left*, *right* has more
        elements than *m1_dim* or if one axis is negative or
        greater than or equal to *m1_dim*

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum.einsum_impl_ext import (
            numpy_extended_dot_python, _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_python(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(b.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """
    if m1_dim != m2_dim:
        raise RuntimeError(
            "Matrices m1 and m2 must have the same number of dimensions, "
            "m1=%r, m2=%r." % (m1_dim, m2_dim))
    total = set(axes) | set(left) | set(right)
    if len(total) > m1_dim:
        raise ValueError(
            "Whole set of involved axes should be inferior to the number "
            "of dimensions: %r = {%r} | {%r} | {%r} has more than %d elements"
            "." % (total, axes, left, right, m1_dim))

    def _check_(axs, n):
        # Every axis must be a valid, non negative dimension index.
        for a in axs:
            if a < 0 or a >= n:
                raise ValueError(
                    "One axis %d (in %r) is negative or above the maximum "
                    "dimension %d." % (a, axs, n))
    _check_(axes, m1_dim)
    _check_(left, m1_dim)
    _check_(right, m1_dim)

    l1 = [chr(i + 97) for i in range(m1_dim)]
    l2 = [chr(i + 97) for i in range(m1_dim)]
    l3 = [chr(i + 97) for i in range(m1_dim)]
    for a in left:
        l1[a] = l1[a].upper()
        l3[a] = l3[a].upper()
    for a in right:
        l2[a] = l2[a].upper()
        l3[a] = l3[a].upper()
    # Summation axes override the case chosen above.
    for a in axes:
        l1[a] = l1[a].lower()
        l2[a] = l2[a].lower()
        if a not in right:
            # Summed and not kept: the letter disappears from the output.
            l3[a] = None
        else:
            l3[a] = l3[a].lower()
    eq = "%s,%s->%s" % ("".join(l1), "".join(l2),
                        "".join(s for s in l3 if s))
    return eq
def _common_check_numpy_extended_dot(m1, m2, axes, left, right):
    """
    Common verifications for all implementations of
    @see fn numpy_extended_dot.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :raises TypeError: if both matrices do not share the same dtype
    :raises RuntimeError: if both matrices do not have the same
        number of dimensions
    :raises ValueError: if the union of *axes*, *left*, *right*
        has more elements than the number of dimensions
    """
    if m1.dtype != m2.dtype:
        raise TypeError(
            "Both matrices should share the same dtype %r != %r."
            "" % (m1.dtype, m2.dtype))

    ndim1, ndim2 = len(m1.shape), len(m2.shape)
    if ndim1 != ndim2:
        raise RuntimeError(  # pragma: no cover
            "Matrices m1 and m2 must have the same number of dimensions, "
            "m1=%r, m2=%r." % (ndim1, ndim2))

    involved = set(axes) | set(left) | set(right)
    if len(involved) <= ndim1:
        return
    raise ValueError(
        "Whole set of involved axes should be inferior to the number "
        "of dimensions: %r = {%r} | {%r} | {%r} has more than %d elements"
        "." % (involved, axes, left, right, ndim1))
def numpy_extended_dot(m1, m2, axes, left, right, verbose=False):
    """
    Extended version of a matrix multiplication (:epkg:`numpy:dot`)
    with two matrices *m1*, *m2* of the same dimensions.
    Loops over *left* axes for *m1* and *right* axes for *m2*,
    summation is done over *axes*.
    Other axes must be empty.
    This multiplication combines matrix multiplication (dot)
    and broadcasted multiplication term by term.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output, it has as many dimensions as *m1*, summation
        axes which are not right axes are kept as dimensions of size 1
    :raises TypeError, RuntimeError, ValueError: raised by the
        verifications on the inputs and the axes

    The dot product is equivalent to:

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        print("dot product")
        print(m1 @ m2)

        dm1 = m1.reshape((2, 2, 1))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[1], left=[0], right=[2],
                                 verbose=True)
        print("extended dot product")
        print(dot)

    Empty axes should be squeezed to get identical results.
    Dot product when the second matrix is transposed.

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        print("dot product")
        print(m1 @ m2.T)

        dm1 = m1.reshape((2, 1, 2))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1],
                                 verbose=True)
        print("extended dot product")
        print(dot)

    An example when right axes include the summation axis.

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        dm1 = m1.reshape((2, 2, 1))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1, 2],
                                 verbose=True)
        print(dot)

    Example in higher dimension:

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(8).reshape((2, 2, 2))
        m2 = m1 + 10

        dot = numpy_extended_dot(m1, m2, [1], [0], [2], verbose=True)
        print(dot)

    The current implementation still uses :epkg:`numpy:einsum`
    but this should be replaced.
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)
    eq = _numpy_extended_dot_equation(
        len(m1.shape), len(m2.shape), axes, left, right)
    if verbose:
        print(" [numpy_extended_dot] %s: %r @ %r" % (eq, m1.shape, m2.shape))
    output = numpy.einsum(eq, m1, m2)
    new_shape = list(output.shape)
    # einsum removed the summation axes which are not right axes,
    # they are restored as dimensions of size 1. Insertions must be done
    # in increasing order so that every position refers to the final
    # shape, whatever the order in which *axes* was given.
    for a in sorted(axes):
        if a not in right:
            new_shape.insert(a, 1)
    if verbose:
        print(" [numpy_extended_dot] %r reshaped into %r " % (
            output.shape, new_shape))
    return output.reshape(tuple(new_shape))
def numpy_extended_dot_ouput_shape(m1, m2, axes, left, right):
    """
    Computes the output shape of results produced by function
    :func:`numpy_extended_dot
    <mlprodict.testing.einsum.einsum_impl_ext.numpy_extended_dot>` or
    :func:`numpy_extended_dot_python
    <mlprodict.testing.einsum.einsum_impl_ext.numpy_extended_dot_python>`.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :return: int64 array, dimension *i* is ``m1.shape[i]`` for a left
        axis, ``m2.shape[i]`` for a right axis, the broadcast size of
        both dimensions for an axis both left and right, 1 otherwise
    :raises RuntimeError: if an axis is both left and right and both
        dimensions differ while none of them is 1
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)
    m1_dim = len(m1.shape)

    new_shape = numpy.full(m1_dim, 1, dtype=numpy.int64)
    for i in left:
        new_shape[i] = m1.shape[i]
    for i in right:
        if i not in left:
            new_shape[i] = m2.shape[i]
            continue
        d1, d2 = m1.shape[i], m2.shape[i]
        if d1 != d2 and d1 != 1 and d2 != 1:
            raise RuntimeError(  # pragma: no cover
                "Matrices should have the same dimension for dimension %d, "
                "shapes=%r @ %r." % (i, m1.shape, m2.shape))
        # Both inputs contribute to this axis: a dimension equal to 1
        # is broadcast to the other one (numpy broadcasting rule).
        new_shape[i] = d1 if d2 == 1 else d2
    return new_shape
def _numpy_extended_dot_python_l1l2l3(m1_dim, axes, left, right):
    """
    Builds the letters indexing the first input, the second input and
    the output used by @see fn numpy_extended_dot_python.

    :param m1_dim: number of dimensions
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :return: three lists *l1*, *l2*, *l3*; left axes are upper case in
        *l1* and *l3*, right axes are upper case in *l2* and *l3*,
        summation axes are lower case in *l1* and *l2*, and in *l3* they
        are lower case if they are right axes, ``'-'`` otherwise
    """
    letters = [chr(97 + k) for k in range(m1_dim)]
    l1, l2, l3 = list(letters), list(letters), list(letters)

    for k in left:
        l1[k], l3[k] = l1[k].upper(), l3[k].upper()
    for k in right:
        l2[k], l3[k] = l2[k].upper(), l3[k].upper()
    # Summation axes are processed last, they override the case above.
    for k in axes:
        l1[k], l2[k] = l1[k].lower(), l2[k].lower()
        l3[k] = l3[k].lower() if k in right else "-"

    return l1, l2, l3
def _numpy_extended_dot_python_intermediate(m1_shape, m2_shape, l1, l2, l3):
    """
    Computes intermediate information used by
    @see fn numpy_extended_dot_python.

    :param m1_shape: shape of the first input
    :param m2_shape: shape of the second input
    :param l1: letters indexing the first input
    :param l2: letters indexing the second input
    :param l3: letters indexing the output
    :return: tuple *(names, kind, cols, common, broadcast, pos)*:
        *names* is the sorted list of letters in *l1* or *l2*,
        *kind* an int64 array where bit 1 (resp. 2, 4) is set if the
        letter appears in *l1* (resp. *l2*, *l3*),
        *cols* maps every letter to its position in *l2* if it appears
        there, in *l1* otherwise,
        *common* tells which letters appear in both *l1* and *l2*,
        *broadcast* tells which common letters have different
        dimensions in both inputs,
        *pos* is an int64 array with ``cols[name]`` for every name
    """
    names = sorted(set(l1 + l2))

    flags = []
    cols = {}
    for name in names:
        in1, in2, in3 = name in l1, name in l2, name in l3
        flags.append(int(in1) | (int(in2) << 1) | (int(in3) << 2))
        # The position in the second input takes precedence.
        cols[name] = l2.index(name) if in2 else l1.index(name)
    kind = numpy.array(flags, dtype=numpy.int64)

    pos = numpy.array([cols[name] for name in names], dtype=numpy.int64)
    common = [(k & 3) == 3 for k in kind]
    broadcast = [is_common and m1_shape[p] != m2_shape[p]
                 for is_common, p in zip(common, pos)]

    return names, kind, cols, common, broadcast, pos
def _numpy_extended_dot_python_update_broadcast(
        m1, m2, axes, left, right, l1, l2, l3, names, broadcast, cols,
        kind, common, verbose=False):
    """
    Renames letters in *l1*, *l2*, *l3* for every axis which must be
    broadcast in @see fn numpy_extended_dot_python. The letter of the
    input whose dimension is not 1 (the second one if the first
    dimension is 1) has its case swapped so that both inputs no longer
    share the same letter for this axis.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param l1: letters indexing the first input (modified inplace)
    :param l2: letters indexing the second input (modified inplace)
    :param l3: letters indexing the output (modified inplace)
    :param names: sorted letters, see
        @see fn _numpy_extended_dot_python_intermediate
    :param broadcast: flags telling which letter must be broadcast
    :param cols: position of every letter
    :param kind: bit field telling where every letter appears
    :param common: flags telling which letters appear in both inputs
    :param verbose: display intermediate information
    :return: *l1*, *l2*, *l3*
    :raises RuntimeError: if a letter must be broadcast but does not
        appear in both inputs
    """

    def dispb(flags):
        return "".join(["o" if flag else "." for flag in flags])

    def swap_case(letter):
        # Lower case letters become upper case and the other way around.
        if letter.lower() == letter:
            return letter.upper()
        return letter.lower()

    if verbose:
        print(  # pragma: no cover
            "[GENERICDOT] before broadcast %s,%s->%s or %s" % (
                "".join(l1), "".join(l2), "".join(l3),
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right)))
        print(  # pragma: no cover
            "[GENERICDOT] names=%s kind=%r common=%s broadcast=%s" % (
                "".join(names), kind.tolist(),
                dispb(common), dispb(broadcast)))

    for i, to_broadcast in enumerate(broadcast):
        if to_broadcast and (kind[i] & 3) != 3:
            raise RuntimeError(  # pragma: no cover
                "Broadcast should only happen on common axes, "
                "axes=%r left=%r right=%r shape1=%r shape2=%r."
                "" % (axes, left, right, m1.shape, m2.shape))
        if not to_broadcast:
            continue

        # The letters are split between both inputs.
        p = cols[names[i]]
        dim = (m1.shape[p], m2.shape[p])
        let = [l1[p], l2[p], l3[p]]
        # Input (0: first, 1: second) whose letter gets renamed.
        inp = 1 if dim[0] == 1 else 0
        if verbose:
            print(  # pragma: no cover
                "[GENERICDOT] name=%s dim=%r let=%r inp=%r p=%r" % (
                    names[i], dim, let, inp, p))
            print(  # pragma: no cover
                " B0 l1=%r, l2=%r l3=%r" % (l1, l2, l3))

        in_output = (kind[i] & 4) > 0
        let[inp] = swap_case(let[inp])
        if in_output:
            # The axis is part of the output which follows the
            # renamed letter.
            l3[p] = let[inp]
        renamed = l2 if inp == 1 else l1
        renamed[p] = let[inp]

        if verbose:
            if in_output:
                print(  # pragma: no cover
                    " B1 l1=%r, l2=%r l3=%r" % (l1, l2, l3))
            else:
                print(" B2 l1=%r, l2=%r l3=%r" % (l1, l2, l3))

    return l1, l2, l3
def numpy_extended_dot_python(m1, m2, axes, left, right, verbose=False):
    """
    Implementation of @see fn numpy_extended_dot in pure python.
    This implementation is not efficient but shows how to
    implement this operation without :epkg:`numpy:einsum`.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output, its shape is given by
        @see fn numpy_extended_dot_ouput_shape; when one of the loops
        is empty (a dimension equal to 0), the output is only made of
        zeros (or is empty)

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot_python
        from mlprodict.testing.einsum.einsum_impl_ext import (
            _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_python(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(b.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """
    def dispb(c):
        return "".join("o" if b else "." for b in c)

    new_shape = numpy_extended_dot_ouput_shape(m1, m2, axes, left, right)
    m1_dim = len(m1.shape)

    # output result
    res = numpy.full(tuple(new_shape), 0, dtype=m1.dtype)

    # indices
    l1, l2, l3 = _numpy_extended_dot_python_l1l2l3(m1_dim, axes, left, right)
    names, kind, cols, common, broadcast, pos = (
        _numpy_extended_dot_python_intermediate(
            m1.shape, m2.shape, l1, l2, l3))

    if any(broadcast):
        # Letters are renamed, the intermediate information
        # must be computed again.
        l1, l2, l3 = _numpy_extended_dot_python_update_broadcast(
            m1, m2, axes, left, right, l1, l2, l3, names, broadcast, cols,
            kind, common, verbose=verbose)

        names, kind, cols, common, broadcast, pos = (
            _numpy_extended_dot_python_intermediate(
                m1.shape, m2.shape, l1, l2, l3))

    # One loop index per letter, pl1/pl2/plo map every dimension of
    # the inputs/output to the loop index it follows (-1: index 0).
    indices = numpy.array([0 for n in names], dtype=numpy.int64)
    pl1 = numpy.array([names.index(c) for c in l1], dtype=numpy.int64)
    pl2 = numpy.array([names.index(c) for c in l2], dtype=numpy.int64)
    limits = numpy.array(
        [m1.shape[pos[n]] if (kind[n] & 1) == 1 else m2.shape[pos[n]]
         for n in range(len(names))], dtype=numpy.int64)
    plo = numpy.array(
        [-1 if c not in names else names.index(c) for c in l3],
        dtype=numpy.int64)

    if verbose:
        print("[GENERICDOT] %s,%s->%s or %s" % (
            "".join(l1), "".join(l2), "".join(l3),
            _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), axes, left, right)))
        print("[GENERICDOT] shape1=%r shape2=%r shape=%r" % (
            m1.shape, m2.shape, res.shape))
        print("[GENERICDOT] axes=%r left=%r right=%r" % (axes, left, right))
        print("[GENERICDOT] pl1=%r pl2=%r plo=%r" % (pl1, pl2, plo))
        print("[GENERICDOT] names=%s kind=%r common=%s broadcast=%s" % (
            "".join(names), kind.tolist(),
            dispb(common), dispb(broadcast)))
        print("[GENERICDOT] pos=%r" % pos.tolist())
        print("[GENERICDOT] cols=%r" % cols)
        print("[GENERICDOT] limits=%r" % limits)

    if len(names) == 0:
        # 0-d inputs: there is no loop, a single product to compute.
        res[()] = m1[()] * m2[()]
        return res
    if (limits == 0).any():
        # One loop is empty: the output is either empty or a sum
        # over nothing, which is zero.
        return res

    last = len(indices) - 1
    while indices[0] < limits[0]:

        # The function spends most of its time in these three lines.
        t1 = tuple(indices[n] for n in pl1)
        t2 = tuple(indices[n] for n in pl2)
        to = tuple(0 if n == -1 else indices[n] for n in plo)

        c = m1[t1] * m2[t2]

        if verbose:
            print(" %r x %r -> %r v=%r I=%r" % (t1, t2, to, c, indices))

        res[to] += c

        # Increments the multi-index like an odometer, the first index
        # reaching its limit ends the main loop.
        indices[last] += 1
        for i in range(last, 0, -1):
            if indices[i] < limits[i]:
                break
            indices[i] = 0
            indices[i - 1] += 1

    return res
def numpy_extended_dot_matrix(m1, m2, axes, left, right, verbose=False):
    """
    Implementation of @see fn numpy_extended_dot using dot product,
    multiplication, transpose and reduction
    but not a custom python implementation like
    @see fn numpy_extended_dot_python.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output
    :raises RuntimeError: if the case is not implemented (summation axes
        shared with left axes) or if intermediate shapes are inconsistent

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot_matrix
        from mlprodict.testing.einsum.einsum_impl_ext import (
            _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_matrix(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(b.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)

    if verbose:
        print(  # pragma: no cover
            "[GENERICDOT] shape1=%r shape2=%r axes=%r "
            "left=%r right=%r -- %s" % (
                m1.shape, m2.shape, axes, left, right,
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right)))

    if len(axes) == 0 and len(set(left) & set(right)) == 0:
        # Simple multiplication
        res = m1 * m2
        if verbose:
            print(  # pragma: no cover
                "[GENERICDOT] Mul %r @ %r -> %r" % (
                    m1.shape, m2.shape, res.shape))
        return res

    if (len(set(axes) & set(left)) == 0 and
            len(set(axes) & set(right)) == 0):

        # No intersection between summation axes and left or right axes:
        # matrix multiplication.
        # ReduceSum: an axis only in right is not indexed by m1 in the
        # output (and an axis only in left is not indexed by m2),
        # the corresponding input is summed over it.
        right_no_left = set(right) - (set(right) & (set(left) | set(axes)))
        if right_no_left:
            red1 = m1.sum(axis=tuple(sorted(right_no_left)), keepdims=True)
            if verbose:
                print("[GENERICDOT] reducesumL=%r, %r -> %r" % (
                    right_no_left, m1.shape, red1.shape))
        else:
            red1 = m1

        left_no_right = set(left) - (set(left) & (set(right) | set(axes)))
        if left_no_right:
            red2 = m2.sum(axis=tuple(sorted(left_no_right)), keepdims=True)
            if verbose:
                print("[GENERICDOT] reducesumR=%r, %r -> %r" % (
                    left_no_right, m2.shape, red2.shape))
        else:
            red2 = m2

        # Transpose: common axes first, then the other axes in increasing
        # order, summation axes last.
        common_axes = sorted(set(left) & set(right))
        i_axes = [(-1 if i in common_axes
                   else (1 if i in axes else 0), i)
                  for i in range(len(m1.shape))]
        i_axes.sort()
        perm = [_[1] for _ in i_axes]
        trm1 = numpy.transpose(red1, axes=perm)
        trm2 = numpy.transpose(red2, axes=perm)
        if verbose:
            print("[GENERICDOT] transposeL=%r, %r -> %r" % (
                perm, red1.shape, trm1.shape))
            print("[GENERICDOT] transposeR=%r, %r -> %r" % (
                perm, red2.shape, trm2.shape))
        final_shape = numpy_extended_dot_ouput_shape(
            m1, m2, axes, left, right)
        perm_left = [i for i in range(len(perm)) if perm[i] in left]
        perm_right = [i for i in range(len(perm)) if perm[i] in right]
        perm_common_axes = [i for i in range(len(perm))
                            if perm[i] in common_axes]

        if verbose:
            print("[GENERICDOT] MatMul %r @ %r -> %r -- %s" % (
                m1.shape, m2.shape, final_shape,
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right)))
            print("[GENERICDOT] axes=%r left=%r right=%r" %
                  (axes, left, right))
            print("[GENERICDOT] perm=%r perm_left=%r "
                  "perm_right=%r perm_common_axes=%r" % (
                      perm, perm_left, perm_right, perm_common_axes))

        # Reshape into (common, other, summation) batches of matrices.
        dim0 = int(numpy.prod([trm1.shape[i] for i in perm_common_axes]))
        dim0b = int(numpy.prod([trm2.shape[i] for i in perm_common_axes]))
        if len(axes) > 0:
            all_axes = list(range(0, len(m1.shape)))
            new_axes = all_axes[-len(axes):]
        else:
            new_axes = []
        dim1 = int(numpy.prod([trm1.shape[i] for i in new_axes]))
        dim2 = int(numpy.prod([trm2.shape[i] for i in new_axes]))
        if dim1 != dim2:
            raise RuntimeError(  # pragma: no cover
                "Summation axis do not have the same length %d != %d, "
                "trshape1=%r trshape2=%r "
                "p_axes=%r p_left=%r p_right=%r p_common=%r"
                "." % (dim1, dim2, trm1.shape, trm2.shape,
                       new_axes, perm_left, perm_right, perm_common_axes))
        shm1 = trm1.reshape((dim0, -1, dim1))
        shm2 = trm2.reshape((dim0b, -1, dim2))

        if verbose:
            print("[GENERICDOT] Reshape %r @ %r -> %r @ %r" % (
                (dim0, -1, dim1), (dim0b, -1, dim2),
                shm1.shape, shm2.shape))
            print("[GENERICDOT] matmul")

        # Multiplication (this should be done in a different way.
        res = shm1 @ numpy.transpose(shm2, axes=(0, 2, 1))

        if verbose:
            print("[GENERICDOT] Shape after multiplication %s" % (res.shape, ))

        # Transpose again. The left-only (resp. right-only) dimensions
        # of res follow the increasing order of the original axes, the
        # same order must be used in ordered_axes whatever the order
        # in which left and right were given.
        not_in_both = []
        for i in range(0, len(m1.shape)):
            if i not in left and i not in right:
                not_in_both.append(i)
        ordered_axes = (common_axes +
                        sorted(i for i in left if i not in right) +
                        sorted(i for i in right if i not in left) +
                        not_in_both)

        perm_not_in_both = [i for i in range(len(perm))
                            if perm[i] in not_in_both]
        current_shape = ([max(trm1.shape[i], trm2.shape[i])
                          for i in sorted(perm_common_axes)] +
                         [trm1.shape[i] for i in sorted(perm_left)
                          if i not in perm_common_axes] +
                         [trm2.shape[i] for i in sorted(perm_right)
                          if i not in perm_common_axes] +
                         [1 for i in perm_not_in_both])

        if verbose:
            print("[GENERICDOT] current_shape=%r final_shape=%r "
                  "last_shape=%r" % (current_shape, final_shape, res.shape))

        if len(current_shape) != len(final_shape):
            raise RuntimeError(  # pragma: no cover
                "Shapes mismatch %r > %r, "
                "shape1=%r shape2=%r axes=%r left=%r right=%r." % (
                    current_shape, final_shape,
                    m1.shape, m2.shape, axes, left, right))

        res = res.reshape(current_shape)

        # Inverse permutation of ordered_axes.
        perm = [(a, i) for i, a in enumerate(ordered_axes)]
        perm.sort()
        perm = [p[1] for p in perm]

        if verbose:
            print("[GENERICDOT] ordered_axes=%r perm=%r" % (
                ordered_axes, perm))

        return numpy.transpose(res, axes=perm)

    # Multiplication and Matrix multiplication at the same time.
    l_axes = set(left) & set(axes)
    r_axes = set(right) & set(axes)
    if r_axes and not l_axes:
        # Summation axes which are also right axes are handled as
        # left and right axes (term by term multiplication).
        new_axes = list(a for a in axes if a not in right)
        new_left = list(sorted(set(left) | r_axes))
        if verbose:  # pragma: no cover
            eq1 = _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), axes, left, right)
            eq2 = _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), new_axes, new_left, right)
            print("[GENERICDOT] replace left %r by %r axes %r by %r, "
                  "eq %r by %r" % (
                      left, new_left, axes, new_axes, eq1, eq2))
        return numpy_extended_dot_matrix(m1, m2, new_axes, new_left, right,
                                         verbose=verbose)
    raise RuntimeError(  # pragma: no cover
        "shape1=%r shape2=%r axes=%r left=%r right=%r eq=%s." % (
            m1.shape, m2.shape, axes, left, right,
            _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), axes, left, right)))
704 len(m1.shape), len(m1.shape), axes, left, right)))