Coverage for mlprodict/testing/einsum/einsum_impl_ext.py: 96%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

271 statements  

"""
@file
@brief Functions implementing einsum computation for two
matrices having the same number of dimensions.
"""

6import numpy 

7 

8 

def numpy_diagonal(m, axis, axes):
    """
    Extracts diagonal coefficients from an array: every axis listed in
    *axes* is indexed with the same value, *axis* is kept in the result
    and the other diagonal axes are removed.

    :param m: input array
    :param axis: axis kept among the diagonal ones
    :param axes: diagonal axes (*axis* must be one of them)
    :return: output, with one dimension less for every diagonal axis
        other than *axis*
    :raises RuntimeError: if *axis* is not in *axes*

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_diagonal

        mat = numpy.arange(8).reshape((2, 2, 2))
        print(mat)
        diag = numpy_diagonal(mat, 1, [1, 2])
        print(diag)
    """
    if axis not in axes:
        raise RuntimeError(
            "axis %r must be in axes %r." % (axis, axes))

    def _kept(k):
        # An axis survives unless it is a diagonal axis other than *axis*.
        return k not in axes or k == axis

    # Shape used while filling (removed diagonal axes collapsed to 1)
    # and final shape (removed diagonal axes dropped).
    full_shape = [dim if _kept(k) else 1 for k, dim in enumerate(m.shape)]
    squeezed_shape = [dim for k, dim in enumerate(m.shape) if _kept(k)]

    result = numpy.empty(tuple(full_shape), dtype=m.dtype)
    src = [slice(dim) for dim in m.shape]
    dst = list(src)
    for step in range(full_shape[axis]):
        # Same index on every diagonal axis of the input, the collapsed
        # axes of the output are always indexed with 0.
        for ax in axes:
            src[ax] = step
            dst[ax] = step if ax == axis else 0
        result[tuple(dst)] = m[tuple(src)]

    return result.reshape(tuple(squeezed_shape))

57 

58 

def _numpy_extended_dot_equation(m1_dim, m2_dim, axes, left, right):
    """
    Returns the equation equivalent to an extended version
    of an aligned matrix multiplication
    (see @see fn numpy_extended_dot).

    Letters are assigned by position (``a`` for axis 0, ``b`` for axis 1,
    ...). Left axes are uppercased in the first operand and in the output,
    right axes are uppercased in the second operand and in the output.
    Summation axes are lowercased in both operands and dropped from the
    output unless they are also right axes.

    :param m1_dim: number of dimensions of the first matrix
    :param m2_dim: number of dimensions of the second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :return: equation
    :raises RuntimeError: if *m1_dim* and *m2_dim* differ
    :raises ValueError: if more distinct axes are involved than there are
        dimensions or if one axis is negative or out of range

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum.einsum_impl_ext import (
            numpy_extended_dot_python, _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_python(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(b.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """
    if m1_dim != m2_dim:
        raise RuntimeError(
            "Matrices m1 and m2 must have the same number of dimensions, "
            "m1=%r, m2=%r." % (m1_dim, m2_dim))
    total = set(axes) | set(left) | set(right)
    if len(total) > m1_dim:
        raise ValueError(
            "Whole set of involved axes should be inferior to the number "
            "of dimensions: %r = {%r} | {%r} | {%r} has more than %d elements"
            "." % (total, axes, left, right, m1_dim))

    def _check_(axs, n):
        for a in axs:
            if a < 0 or a >= n:
                raise ValueError(
                    "One axis %d (in %r) is negative or above the maximum "
                    "dimension %d." % (a, axs, n))
    _check_(axes, m1_dim)
    _check_(left, m1_dim)
    _check_(right, m1_dim)

    l1 = [chr(i + 97) for i in range(m1_dim)]
    l2 = [chr(i + 97) for i in range(m1_dim)]
    l3 = [chr(i + 97) for i in range(m1_dim)]
    for a in left:
        l1[a] = l1[a].upper()
        l3[a] = l3[a].upper()
    for a in right:
        l2[a] = l2[a].upper()
        l3[a] = l3[a].upper()
    for a in axes:
        l1[a] = l1[a].lower()
        l2[a] = l2[a].lower()
        if a not in right:
            # Summed away: removed from the output subscripts.
            l3[a] = None
        else:
            l3[a] = l3[a].lower()
    eq = "%s,%s->%s" % ("".join(l1), "".join(l2),
                        "".join(s for s in l3 if s))
    return eq

133 

134 

def _common_check_numpy_extended_dot(m1, m2, axes, left, right):
    """
    Common verifications for all implementations of
    @see fn numpy_extended_dot.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :raises TypeError: if both matrices do not share the same dtype
    :raises RuntimeError: if both matrices do not have the same number
        of dimensions
    :raises ValueError: if more distinct axes are involved than there
        are dimensions, or if one axis is negative or out of range
    """
    if m1.dtype != m2.dtype:
        raise TypeError(
            "Both matrices should share the same dtype %r != %r."
            "" % (m1.dtype, m2.dtype))
    m1_dim = len(m1.shape)
    m2_dim = len(m2.shape)
    if m1_dim != m2_dim:
        raise RuntimeError(  # pragma: no cover
            "Matrices m1 and m2 must have the same number of dimensions, "
            "m1=%r, m2=%r." % (m1_dim, m2_dim))
    total = set(axes) | set(left) | set(right)
    if len(total) > m1_dim:
        raise ValueError(
            "Whole set of involved axes should be inferior to the number "
            "of dimensions: %r = {%r} | {%r} | {%r} has more than %d elements"
            "." % (total, axes, left, right, m1_dim))
    # Same range check as _numpy_extended_dot_equation: the other
    # implementations index shapes directly with these axes, where a
    # negative value would silently wrap around.
    for axs in (axes, left, right):
        for a in axs:
            if a < 0 or a >= m1_dim:
                raise ValueError(
                    "One axis %d (in %r) is negative or above the maximum "
                    "dimension %d." % (a, axs, m1_dim))

156 

157 

def numpy_extended_dot(m1, m2, axes, left, right, verbose=False):
    """
    Extended version of a matrix multiplication (:epkg:`numpy:dot`)
    with two matrices *m1*, *m2* of the same dimensions.
    Loops over *left* axes for *m1* and *right* axes for *m2*,
    summation is done over *axes*.
    Other axes must be empty.
    This multiplication combines matrix multiplication (dot)
    and broadcasted multiplication term by term.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output, same number of dimensions as the inputs, summation
        axes which are not right axes have dimension 1

    The dot product is equivalent to:

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        print("dot product")
        print(m1 @ m2)

        dm1 = m1.reshape((2, 2, 1))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[1], left=[0], right=[2],
                                 verbose=True)
        print("extended dot product")
        print(dot)

    Empty axes should be squeezed to get identical results.
    Dot product when the second matrix is transposed.

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        print("dot product")
        print(m1 @ m2.T)

        dm1 = m1.reshape((2, 1, 2))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1],
                                 verbose=True)
        print("extended dot product")
        print(dot)

    An example when right axes include the summation axis.

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        dm1 = m1.reshape((2, 2, 1))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1, 2],
                                 verbose=True)
        print(dot)

    Example in higher dimension:

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot

        m1 = numpy.arange(8).reshape((2, 2, 2))
        m2 = m1 + 10

        dot = numpy_extended_dot(m1, m2, [1], [0], [2], verbose=True)
        print(dot)

    The current implementation still uses :epkg:`numpy:einsum`
    but this should be replaced.
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)
    eq = _numpy_extended_dot_equation(
        len(m1.shape), len(m2.shape), axes, left, right)
    if verbose:
        print(" [numpy_extended_dot] %s: %r @ %r" % (eq, m1.shape, m2.shape))
    output = numpy.einsum(eq, m1, m2)
    # einsum drops every summation axis which is not a right axis,
    # it is restored as a dimension of size 1. Insertions must be done
    # in increasing order so that each position refers to the final
    # layout whatever the order in which the caller listed *axes*.
    new_shape = list(output.shape)
    for a in sorted(set(axes)):
        if a not in right:
            new_shape.insert(a, 1)
    if verbose:
        print(" [numpy_extended_dot] %r reshaped into %r " % (
            output.shape, new_shape))
    return output.reshape(tuple(new_shape))

264 

265 

def numpy_extended_dot_ouput_shape(m1, m2, axes, left, right):
    """
    Computes the output shape of results produced by function
    :func:`numpy_extended_dot
    <mlprodict.testing.einsum.einsum_impl_ext.numpy_extended_dot>` or
    :func:`numpy_extended_dot_python
    <mlprodict.testing.einsum.einsum_impl_ext.numpy_extended_dot_python>`.

    Every dimension is 1 except the left axes (dimension taken from *m1*)
    and the right axes (dimension taken from *m2*). An axis which is both
    a left and a right axis follows broadcasting rules: when one side
    has dimension 1, the other side gives the output dimension.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :return: output shape as an array of int64
    :raises RuntimeError: if an axis shared by *left* and *right* has two
        different dimensions, none of them being 1
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)
    m1_dim = len(m1.shape)

    new_shape = numpy.full(m1_dim, 1, dtype=numpy.int64)
    for i in left:
        new_shape[i] = m1.shape[i]
    for i in right:
        if i not in left:
            new_shape[i] = m2.shape[i]
            continue
        d1, d2 = m1.shape[i], m2.shape[i]
        if d1 != d2 and d1 != 1 and d2 != 1:
            raise RuntimeError(  # pragma: no cover
                "Matrices should have the same dimension for dimension %d, "
                "shapes=%r @ %r." % (i, m1.shape, m2.shape))
        # Broadcasting: a dimension equal to 1 is stretched to the other
        # one (previously m2's dimension was always used, which gave 1
        # when only m2 was broadcast).
        new_shape[i] = d2 if d1 == 1 else d1
    return new_shape

288 

289 

def _numpy_extended_dot_python_l1l2l3(m1_dim, axes, left, right):
    """
    Builds the letters describing both operands and the output for the
    pure python implementation of @see fn numpy_extended_dot.

    Axis ``k`` is named after ``chr(97 + k)``. Left axes are uppercased
    in the first operand and the output, right axes in the second operand
    and the output. Summation axes are lowercased in both operands; in
    the output they are lowercased when they are right axes and replaced
    by ``"-"`` otherwise.

    :param m1_dim: number of dimensions of both matrices
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :return: tuple of three lists of letters (first operand, second
        operand, output)
    """
    base = [chr(97 + k) for k in range(m1_dim)]
    first, second, out = list(base), list(base), list(base)

    for k in left:
        first[k] = first[k].upper()
        out[k] = out[k].upper()

    for k in right:
        second[k] = second[k].upper()
        out[k] = out[k].upper()

    for k in axes:
        first[k] = first[k].lower()
        second[k] = second[k].lower()
        out[k] = out[k].lower() if k in right else "-"

    return first, second, out

308 

309 

def _numpy_extended_dot_python_intermediate(m1_shape, m2_shape, l1, l2, l3):
    """
    Computes intermediate tables used by the pure python implementation
    of @see fn numpy_extended_dot.

    :param m1_shape: shape of the first matrix
    :param m2_shape: shape of the second matrix
    :param l1: letters of the first operand
    :param l2: letters of the second operand
    :param l3: letters of the output
    :return: tuple ``(names, kind, cols, common, broadcast, pos)``:

        * *names*: sorted distinct letters of *l1* and *l2*
        * *kind*: int64 bit flags per name, 1 if in *l1*, 2 if in *l2*,
          4 if in *l3*
        * *cols*: position of each name in *l1*, overridden by its
          position in *l2* when it appears there
        * *common*: whether each name appears in both operands
        * *broadcast*: whether a common name has different dimensions
          in both shapes
        * *pos*: int64 array of ``cols[name]`` following *names*
    """
    names = sorted(set(l1 + l2))
    kind = numpy.zeros(len(names), dtype=numpy.int64)
    cols = {}

    for idx, letter in enumerate(names):
        if letter in l1:
            kind[idx] += 1
            cols[letter] = l1.index(letter)
        if letter in l2:
            kind[idx] += 2
            cols[letter] = l2.index(letter)
        if letter in l3:
            kind[idx] += 4

    pos = numpy.array([cols[letter] for letter in names], dtype=numpy.int64)
    common = [(flag & 3) == 3 for flag in kind]
    broadcast = [is_common and m1_shape[p] != m2_shape[p]
                 for is_common, p in zip(common, pos)]

    return names, kind, cols, common, broadcast, pos

333 

334 

def _numpy_extended_dot_python_update_broadcast(
        m1, m2, axes, left, right, l1, l2, l3, names, broadcast, cols,
        kind, common, verbose=False):
    """
    Renames letters of common axes whose dimensions differ in both
    matrices (broadcasting) so that the pure python implementation
    iterates over each side independently.

    For every name flagged in *broadcast*, the case of one letter is
    flipped: the one of *m2* if *m1* has dimension 1 on that axis, the
    one of *m1* otherwise. If the axis is part of the output (bit 4 of
    *kind*), the output letter becomes the renamed one.

    Lists *l1*, *l2*, *l3* are modified in place and returned.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes (only used in messages)
    :param left: left axes (only used in messages)
    :param right: right axes (only used in messages)
    :param l1: letters of the first operand
    :param l2: letters of the second operand
    :param l3: letters of the output
    :param names: names returned by
        *_numpy_extended_dot_python_intermediate*
    :param broadcast: broadcast flags, one per name
    :param cols: position of every name
    :param kind: bit flags, one per name
    :param common: common flags, one per name (only displayed)
    :param verbose: display intermediate information
    :return: updated *l1*, *l2*, *l3*
    :raises RuntimeError: if broadcasting is flagged on an axis which is
        not shared by both operands
    """

    def dispb(c):
        return "".join("o" if b else "." for b in c)

    if verbose:
        print(  # pragma: no cover
            "[GENERICDOT] before broadcast %s,%s->%s or %s" % (
                "".join(l1), "".join(l2), "".join(l3),
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right)))
        print(  # pragma: no cover
            "[GENERICDOT] names=%s kind=%r common=%s broadcast=%s" % (
                "".join(names), kind.tolist(),
                dispb(common), dispb(broadcast)))

    for idx, is_broadcast in enumerate(broadcast):
        if is_broadcast and not (kind[idx] & 3) == 3:
            raise RuntimeError(  # pragma: no cover
                "Broadcast should only happen on common axes, "
                "axes=%r left=%r right=%r shape1=%r shape2=%r."
                "" % (axes, left, right, m1.shape, m2.shape))
        if not is_broadcast:
            continue

        # We split letters.
        col = cols[names[idx]]
        sizes = (m1.shape[col], m2.shape[col])
        letters = [l1[col], l2[col], l3[col]]
        side = 1 if sizes[0] == 1 else 0
        if verbose:
            print(  # pragma: no cover
                "[GENERICDOT] name=%s dim=%r let=%r inp=%r p=%r" % (
                    names[idx], sizes, letters, side, col))
            print(  # pragma: no cover
                " B0 l1=%r, l2=%r l3=%r" % (l1, l2, l3))

        # Flip the case of the selected letter.
        current = letters[side]
        if current.lower() == current:
            letters[side] = current.upper()
        else:
            letters[side] = current.lower()

        in_output = (kind[idx] & 4) > 0
        if in_output:
            # Summation axis is part of the output.
            l3[col] = letters[side]
        if side == 1:
            l2[col] = letters[side]
        else:
            l1[col] = letters[side]

        if verbose:
            if in_output:
                print(  # pragma: no cover
                    " B1 l1=%r, l2=%r l3=%r" % (l1, l2, l3))
            else:
                print(" B2 l1=%r, l2=%r l3=%r" % (l1, l2, l3))

    return l1, l2, l3

400 

401 

def numpy_extended_dot_python(m1, m2, axes, left, right, verbose=False):
    """
    Implementation of @see fn numpy_extended_dot in pure python.
    This implementation is not efficient but shows how to
    implement this operation without :epkg:`numpy:einsum`.

    Every combination of indices is visited and the product of the
    matching coefficients is accumulated into the output. An empty
    dimension produces an output filled with zeros.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output, its shape is given by
        @see fn numpy_extended_dot_ouput_shape

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot_python
        from mlprodict.testing.einsum.einsum_impl_ext import (
            _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_python(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(b.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """
    def dispb(c):
        return "".join("o" if b else "." for b in c)

    new_shape = numpy_extended_dot_ouput_shape(m1, m2, axes, left, right)
    m1_dim = len(m1.shape)

    # output result
    res = numpy.zeros(tuple(new_shape), dtype=m1.dtype)

    # indices
    l1, l2, l3 = _numpy_extended_dot_python_l1l2l3(m1_dim, axes, left, right)
    names, kind, cols, common, broadcast, pos = (
        _numpy_extended_dot_python_intermediate(
            m1.shape, m2.shape, l1, l2, l3))

    if any(broadcast):
        l1, l2, l3 = _numpy_extended_dot_python_update_broadcast(
            m1, m2, axes, left, right, l1, l2, l3, names, broadcast, cols,
            kind, common, verbose=verbose)

        names, kind, cols, common, broadcast, pos = (
            _numpy_extended_dot_python_intermediate(
                m1.shape, m2.shape, l1, l2, l3))

    # Position of every letter of each operand / of the output in *names*,
    # -1 for output placeholders which are not a name.
    pl1 = numpy.array([names.index(c) for c in l1], dtype=numpy.int64)
    pl2 = numpy.array([names.index(c) for c in l2], dtype=numpy.int64)
    limits = numpy.array(
        [m1.shape[pos[n]] if (kind[n] & 1) == 1 else m2.shape[pos[n]]
         for n in range(len(names))], dtype=numpy.int64)
    plo = numpy.array(
        [-1 if c not in names else names.index(c) for c in l3],
        dtype=numpy.int64)

    if verbose:
        print("[GENERICDOT] %s,%s->%s or %s" % (
            "".join(l1), "".join(l2), "".join(l3),
            _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), axes, left, right)))
        print("[GENERICDOT] shape1=%r shape2=%r shape=%r" % (
            m1.shape, m2.shape, res.shape))
        print("[GENERICDOT] axes=%r left=%r right=%r" % (axes, left, right))
        print("[GENERICDOT] pl1=%r pl2=%r plo=%r" % (pl1, pl2, plo))
        print("[GENERICDOT] names=%s kind=%r common=%s broadcast=%s" % (
            "".join(names), kind.tolist(),
            dispb(common), dispb(broadcast)))
        print("[GENERICDOT] pos=%r" % pos.tolist())
        print("[GENERICDOT] cols=%r" % cols)
        print("[GENERICDOT] limits=%r" % limits)

    # numpy.ndindex enumerates all index combinations in C order (last
    # name varies fastest, like the former hand-written counter). It
    # yields nothing when one dimension is empty (the former counter then
    # indexed out of bounds) and a single empty tuple for 0-d inputs.
    for indices in numpy.ndindex(tuple(int(n) for n in limits)):

        # The function spends most of its time in these three lines.
        t1 = tuple(indices[n] for n in pl1)
        t2 = tuple(indices[n] for n in pl2)
        to = tuple(0 if n == -1 else indices[n] for n in plo)

        c = m1[t1] * m2[t2]

        if verbose:
            print(" %r x %r -> %r v=%r I=%r" % (t1, t2, to, c, indices))

        res[to] += c

    return res

503 

504 

def numpy_extended_dot_matrix(m1, m2, axes, left, right, verbose=False):
    """
    Implementation of @see fn numpy_extended_dot using dot product,
    multiplication, transpose and reduction
    but not a custom python implementation like
    @see fn numpy_extended_dot_python.

    Three configurations are handled:

    * no summation axis and no axis shared by *left* and *right*:
      term by term multiplication,
    * summation axes shared with neither *left* nor *right*:
      reductions, transposition and reshaping followed by a batched
      matrix multiplication, then reshaping and transposition back,
    * summation axes shared with *right* but not with *left*: these
      axes are moved into *left* and the function is called again.

    Any other configuration raises a RuntimeError.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import numpy_extended_dot_matrix
        from mlprodict.testing.einsum.einsum_impl_ext import (
            _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_matrix(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(b.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)

    if verbose:
        print(  # pragma: no cover
            "[GENERICDOT] shape1=%r shape2=%r axes=%r "
            "left=%r right=%r -- %s" % (
                m1.shape, m2.shape, axes, left, right,
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right)))

    if len(axes) == 0 and len(set(left) & set(right)) == 0:
        # Simple multiplication
        res = m1 * m2
        if verbose:
            print(  # pragma: no cover
                "[GENERICDOT] Mul %r @ %r -> %r" % (
                    m1.shape, m2.shape, res.shape))
        return res

    if (len(set(axes) & set(left)) == 0 and
            len(set(axes) & set(right)) == 0):

        # No intersection between summation axes and left or right axes:
        # matrix multiplication.

        # ReduceSum: axes only belonging to *right* are summed out of m1
        # and axes only belonging to *left* out of m2 (dimensions kept),
        # presumably these dimensions are 1 in practice.
        right_no_left = set(right) - (set(right) & (set(left) | set(axes)))
        if right_no_left:
            red1 = m1.sum(axis=tuple(sorted(right_no_left)), keepdims=True)
            if verbose:
                print("[GENERICDOT] reducesumL=%r, %r -> %r" % (
                    right_no_left, m1.shape, red1.shape))
        else:
            red1 = m1

        left_no_right = set(left) - (set(left) & (set(right) | set(axes)))
        if left_no_right:
            red2 = m2.sum(axis=tuple(sorted(left_no_right)), keepdims=True)
            if verbose:
                print("[GENERICDOT] reducesumR=%r, %r -> %r" % (
                    left_no_right, m2.shape, red2.shape))
        else:
            red2 = m2

        # Transpose: axes shared by left and right first, summation axes
        # last, every other axis in between (original order preserved).
        common_axes = sorted(set(left) & set(right))
        i_axes = [(-1 if i in common_axes
                   else (1 if i in axes else 0), i)
                  for i in range(len(m1.shape))]
        i_axes.sort()
        perm = [i for _, i in i_axes]
        trm1 = numpy.transpose(red1, axes=perm)
        trm2 = numpy.transpose(red2, axes=perm)
        if verbose:
            print("[GENERICDOT] transposeL=%r, %r -> %r" % (
                perm, red1.shape, trm1.shape))
            print("[GENERICDOT] transposeR=%r, %r -> %r" % (
                perm, red2.shape, trm2.shape))
        final_shape = numpy_extended_dot_ouput_shape(
            m1, m2, axes, left, right)
        perm_left = [i for i in range(len(perm)) if perm[i] in left]
        perm_right = [i for i in range(len(perm)) if perm[i] in right]
        perm_common_axes = [i for i in range(len(perm))
                            if perm[i] in common_axes]

        if verbose:
            print("[GENERICDOT] MatMul %r @ %r -> %r -- %s" % (
                m1.shape, m2.shape, final_shape,
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right)))
            print("[GENERICDOT] axes=%r left=%r right=%r" %
                  (axes, left, right))
            print("[GENERICDOT] perm=%r perm_left=%r "
                  "perm_right=%r perm_common_axes=%r" % (
                      perm, perm_left, perm_right, perm_common_axes))

        # Reshape into (batch, rows, summation) on both sides. The batch
        # dimensions may differ when one of them is 1 (matmul broadcasts).
        dim0 = int(numpy.prod([trm1.shape[i] for i in perm_common_axes]))
        dim0b = int(numpy.prod([trm2.shape[i] for i in perm_common_axes]))
        if len(axes) > 0:
            # Summation axes were moved at the end by the transposition.
            all_axes = list(range(0, len(m1.shape)))
            new_axes = all_axes[-len(axes):]
        else:
            new_axes = []
        dim1 = int(numpy.prod([trm1.shape[i] for i in new_axes]))
        dim2 = int(numpy.prod([trm2.shape[i] for i in new_axes]))
        if dim1 != dim2:
            raise RuntimeError(  # pragma: no cover
                "Summation axis do not have the same length %d != %d, "
                "trshape1=%r trshape2=%r "
                "p_axes=%r p_left=%r p_right=%r p_common=%r"
                "." % (dim1, dim2, trm1.shape, trm2.shape,
                       new_axes, perm_left, perm_right, perm_common_axes))
        shm1 = trm1.reshape((dim0, -1, dim1))
        shm2 = trm2.reshape((dim0b, -1, dim2))

        if verbose:
            print("[GENERICDOT] Reshape %r @ %r -> %r @ %r" % (
                (dim0, -1, dim1), (dim0b, -1, dim2),
                shm1.shape, shm2.shape))
            print("[GENERICDOT] matmul")

        # Multiplication (this should be done in a different way).
        res = shm1 @ numpy.transpose(shm2, axes=(0, 2, 1))

        if verbose:
            print("[GENERICDOT] Shape after multiplication %s" % (res.shape, ))

        # Transpose again: the result is laid out as common axes, then
        # left-only axes, right-only axes and finally the other axes.
        not_in_both = []
        for i in range(0, len(m1.shape)):
            if i not in left and i not in right:
                not_in_both.append(i)
        ordered_axes = (common_axes +
                        list(i for i in left if i not in right) +
                        list(i for i in right if i not in left) +
                        not_in_both)

        perm_not_in_both = [i for i in range(len(perm))
                            if perm[i] in not_in_both]
        current_shape = ([max(trm1.shape[i], trm2.shape[i])
                          for i in sorted(perm_common_axes)] +
                         [trm1.shape[i] for i in sorted(perm_left)
                          if i not in perm_common_axes] +
                         [trm2.shape[i] for i in sorted(perm_right)
                          if i not in perm_common_axes] +
                         [1 for i in perm_not_in_both])

        if verbose:
            print("[GENERICDOT] current_shape=%r final_shape=%r "
                  "last_shape=%r" % (current_shape, final_shape, res.shape))

        if len(current_shape) != len(final_shape):
            raise RuntimeError(  # pragma: no cover
                "Shapes mismatch %r > %r, "
                "shape1=%r shape2=%r axes=%r left=%r right=%r." % (
                    current_shape, final_shape,
                    m1.shape, m2.shape, axes, left, right))

        res = res.reshape(current_shape)

        # Inverse permutation of ordered_axes.
        perm = [(a, i) for i, a in enumerate(ordered_axes)]
        perm.sort()
        perm = [p[1] for p in perm]

        if verbose:
            print("[GENERICDOT] ordered_axes=%r perm=%r" % (
                ordered_axes, perm))

        return numpy.transpose(res, axes=perm)

    # Multiplication and Matrix multiplication at the same time.
    l_axes = set(left) & set(axes)
    r_axes = set(right) & set(axes)
    if r_axes and not l_axes:
        # Summation axes which are also right axes are handled as
        # left axes, then the function is called again.
        new_axes = list(a for a in axes if a not in right)
        new_left = list(sorted(set(left) | r_axes))
        if verbose:  # pragma: no cover
            eq1 = _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), axes, left, right)
            eq2 = _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), new_axes, new_left, right)
            print("[GENERICDOT] replace left %r by %r axes %r by %r, "
                  "eq %r by %r" % (
                      left, new_left, axes, new_axes, eq1, eq2))
        return numpy_extended_dot_matrix(m1, m2, new_axes, new_left, right,
                                         verbose=verbose)
    raise RuntimeError(  # pragma: no cover
        "shape1=%r shape2=%r axes=%r left=%r right=%r eq=%s." % (
            m1.shape, m2.shape, axes, left, right,
            _numpy_extended_dot_equation(
                len(m1.shape), len(m1.shape), axes, left, right)))