Coverage for mlprodict/testing/einsum/einsum_impl_classes.py: 96%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

887 statements  

1# pylint: disable=C0302 

2""" 

3@file 

4@brief Classes representing the sequence of matrix operations to 

5implement einsum computation. 

6""" 

7import numpy 

8from onnx import helper, numpy_helper 

9from ...onnx_tools.onnx2py_helper import guess_proto_dtype 

10from ...npy.xop_variable import guess_numpy_type 

11from ... import __max_supported_opset__, get_ir_version 

12from .blas_lapack import gemm_dot 

13from .einsum_impl_ext import ( 

14 numpy_extended_dot, numpy_diagonal, 

15 _numpy_extended_dot_equation, 

16 numpy_extended_dot_python, 

17 numpy_extended_dot_matrix) 

18 

19 

def single_axes(axes):
    """
    *axes* contains positive values, then it is the position
    of this axis in the original matrix, otherwise it is -1
    meaning this axis is an added single dimension to align
    all the dimensions based on the einsum equation.

    :param axes: axes described above
    :return: list of integer in set `{1, 2}`, 1 for
        a single axis, 2 otherwise
    """
    if axes is None:
        return axes
    # -1 marks an inserted singleton axis, anything else is a real axis.
    return [1 if axis == -1 else 2 for axis in axes]

34 

35 

36class EinsumSubOp: 

37 """ 

38 Defines a sub operation used in Einsum decomposition. 

39 

40 :param name: name (reshape, transpose, reduce_sum, matmul, id, 

41 squeeze, diagonal, mul, batch_dot) 

42 :param inputs: inputs 

43 :param kwargs: arguments 

44 

45 Operator suffixed by `_mm` (*transpose_mm*, *reduce_sum_mm*) 

46 are equivalent to the same operator without the suffix 

47 but takes two inputs and only changes the first one. 

48 

49 Attributes `_info` summarizes the known information 

50 about dimensions. Many of them are empty because inserted. 

51 Value `1` means it was the case, `2` means it is a plain dimension. 

52 """ 

53 _allowed = {'expand_dims', 'transpose', 'reduce_sum', 'matmul', 'id', 

54 'squeeze', 'diagonal', 'mul', 'batch_dot', 

55 'transpose_mm', 'reduce_sum_mm'} 

56 

    def __init__(self, full_dim, name, *inputs, **kwargs):
        """Initializes the sub operation.

        :param full_dim: total number of dimensions every intermediate
            result must have (enforced later by ``_check_shape_``)
        :param name: operator name, must belong to ``EinsumSubOp._allowed``
        :param inputs: one or two inputs, each an integer (input index)
            or another :class:`EinsumSubOp`
        :param kwargs: operator-specific arguments (``perm``, ``axes``, ...)
        :raises ValueError: if *name* is not a known operator
        :raises RuntimeError: if the number of inputs is invalid
        :raises TypeError: if an input is neither int nor EinsumSubOp
        """
        self.full_dim = full_dim
        self.name = name
        self.inputs = inputs
        self.kwargs = kwargs
        # Filled later through add_info; stores known axis information.
        self._info = {}
        if name not in EinsumSubOp._allowed:
            raise ValueError(
                "Unexpected name %r. It should be in %r."
                "" % (name, EinsumSubOp._allowed))
        if len(inputs) not in (1, 2):
            raise RuntimeError(
                "Inputs must contains 1 or 2 inputs not %d." % len(inputs))
        if name == 'matmul' and len(inputs) != 2:
            raise RuntimeError(
                "Inputs must contains 2 inputs not %d for operator 'matmul'."
                "" % len(inputs))
        for i, inp in enumerate(inputs):
            if not isinstance(inp, (int, EinsumSubOp)):
                raise TypeError(
                    "Input %d has type %r, int or EinsumSubOp is expected."
                    "" % (i, type(inp)))
        # Operator-specific consistency checks (perm, axes, left, right).
        self._check_()

80 

    def _check_(self):
        """Validates operator-specific keyword arguments.

        Only 'transpose' and 'matmul' carry constraints checked at
        construction time; every other operator is accepted as-is.
        """
        if self.name == 'transpose':
            self._check_arg_('perm', tuple)
            perm = self.kwargs['perm']
            # A permutation must not repeat an axis.
            if len(perm) != len(set(perm)):
                raise RuntimeError(  # pragma: no cover
                    "perm has duplicated values %r (name=%r)."
                    "" % (perm, self.name))
            # An identity permutation is useless and should have been
            # removed by the decomposition.
            if list(perm) == list(range(len(perm))):
                raise ValueError(  # pragma: no cover
                    "Transpose = identity perm={}. It must be removed."
                    "".format(perm))
        elif self.name == 'matmul':
            self._check_arg_('axes', tuple)
            self._check_arg_('left', tuple)
            self._check_arg_('right', tuple)
            axes = self.kwargs['axes']
            left = self.kwargs['left']
            right = self.kwargs['right']
            # A summed axis cannot also be kept on both sides.
            for a in axes:
                if a in left and a in right:
                    raise RuntimeError(  # pragma: no cover
                        "One axis belongs to every set (axes, left, right). "
                        "axes=%r, left=%r, right=%r." % (axes, left, right))

105 

106 def __repr__(self): 

107 inps = ", ".join(map(str, self.inputs)) 

108 kw = ", ".join("%s=%r" % (k, w) for k, w in self.kwargs.items()) 

109 m = "%s(%r, %s, %s)" % ( 

110 self.__class__.__name__, self.name, inps, kw) 

111 return m 

112 

113 def dot_label(self): 

114 """ 

115 Displays some informations useful to understand the operator. 

116 """ 

117 if self.name == "matmul": 

118 ndim = self.kwargs['ndim'] 

119 axes = self.kwargs['axes'] 

120 left = self.kwargs['left'] 

121 right = self.kwargs['right'] 

122 eq = _numpy_extended_dot_equation(ndim, ndim, axes, left, right) 

123 eq = eq.replace(">", "\\\\>") 

124 return "~" + eq 

125 return None 

126 

127 def _check_arg_(self, name, typ, empty=False): 

128 if name not in self.kwargs: 

129 raise RuntimeError( # pragma: no cover 

130 "Parameter %r not found for operator %r." % (name, self.name)) 

131 if empty and self.kwargs[name] is None: 

132 return 

133 if not isinstance(self.kwargs[name], typ): 

134 raise TypeError( # pragma: no cover 

135 "Unexpected type %r for parameter %r and parameter %r." 

136 "" % (type(self.kwargs[name]), name, self.name)) 

137 

138 def _check_row_(self, row, inp=False, verbose=False): 

139 """ 

140 Checks input or output is valid. 

141 """ 

142 if verbose: 

143 if inp: 

144 print('<<' if inp else '>>', self.name, row, self.kwargs) 

145 else: 

146 print('<<' if inp else '>>', self.name, row) 

147 

148 def _compute_output_row_id(self, row, row2=None, ab=False, verbose=False): 

149 if ab: 

150 raise RuntimeError("ab option not allowed.") # pragma: no cover 

151 self._check_row_(row, True, verbose=verbose) 

152 row[:] = row2[:] 

153 self._check_row_(row, verbose=verbose) 

154 

155 def _compute_output_row_transpose(self, row, row2=None, ab=False, verbose=False): 

156 if ab: 

157 self._compute_output_row_transpose(row2, verbose=verbose) 

158 return 

159 self._check_row_(row, True, verbose=verbose) 

160 self._check_arg_('perm', tuple) 

161 if len(self.kwargs['perm']) != len(row): 

162 raise RuntimeError( # pragma: no cover 

163 "Unexpected permutation %r (row=%r)." 

164 "" % (self.kwargs['perm'], row)) 

165 perm = self.kwargs['perm'] 

166 cpy = row.copy() 

167 for i, p in enumerate(perm): 

168 row[i] = cpy[p] 

169 self._check_row_(row, verbose=verbose) 

170 

171 def _compute_output_row_transpose_mm(self, row, row2=None, ab=False, verbose=False): 

172 if not ab: 

173 raise RuntimeError("ab must be True.") # pragma: no cover 

174 self._check_row_(row, True, verbose=verbose) 

175 if row2 is None: 

176 raise RuntimeError( # pragma: no cover 

177 "transpose_mm expects a second input.") 

178 self._compute_output_row_transpose(row, row2=None, verbose=verbose) 

179 

180 def _compute_output_row_expand_dims(self, row, row2=None, ab=False, verbose=False): 

181 if ab: 

182 raise RuntimeError("ab option not allowed.") # pragma: no cover 

183 self._check_row_(row, True, verbose=verbose) 

184 self._check_arg_('axes', tuple) 

185 axes = self.kwargs['axes'] 

186 for axis in axes: 

187 if not isinstance(axis, tuple): 

188 raise TypeError( # pragma: no cover 

189 "Parameter axes of expand_dims should be a tuple of " 

190 "tuple, axes=%r." % axes) 

191 if row[axis[1]] != -1: 

192 raise RuntimeError( # pragma: no cover 

193 "Dimension should be -1 in row %r axis=%r." % ( 

194 row, self.kwargs['axis'])) 

195 self._check_row_(row, verbose=verbose) 

196 

197 def _compute_output_row_reduce_sum(self, row, row2=None, ab=False, verbose=False): 

198 if ab: 

199 raise RuntimeError("ab option not allowed.") # pragma: no cover 

200 self._check_row_(row, True, verbose=verbose) 

201 self._check_arg_('axes', tuple) 

202 for a in self.kwargs['axes']: 

203 row[a] = -1 

204 self._check_row_(row, verbose=verbose) 

205 

206 def _compute_output_row_reduce_sum_mm(self, row, row2=None, ab=False, verbose=False): 

207 if not ab: 

208 raise RuntimeError("ab must be true.") # pragma: no cover 

209 self._check_row_(row2, True, verbose=verbose) 

210 if row2 is None: 

211 raise RuntimeError( # pragma: no cover 

212 "reduce_sum_mm expects a second input.") 

213 self._compute_output_row_reduce_sum(row, row2=None, verbose=verbose) 

214 

215 def _compute_output_row_squeeze(self, row, row2=None, ab=False, verbose=False): 

216 if ab: 

217 raise RuntimeError("ab option not allowed.") # pragma: no cover 

218 self._check_row_(row, True, verbose=verbose) 

219 self._check_arg_('axes', tuple) 

220 for a in self.kwargs['axes']: 

221 row[a] = -1 

222 self._check_row_(row, verbose=verbose) 

223 

    def _compute_output_row_diagonal(self, row, row2=None, ab=False, verbose=False):
        """Merges duplicated axes of *row* into one and renumbers the rest.

        ``kwargs['diag']`` is a list of ``(choice, choices)`` pairs: every
        axis in *choices* is collapsed onto *choice*, then the removed
        axis indices are compacted so the remaining values stay contiguous.
        """
        if ab:
            raise RuntimeError("ab option not allowed.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        self._check_arg_('diag', list)
        to_remove = []
        for choice, choices in self.kwargs['diag']:
            # Every non-chosen duplicate axis disappears from the row.
            for ch in choices:
                if ch != choice:
                    to_remove.append(ch)
            # Remap row entries pointing at a duplicate onto the kept axis.
            for i in range(len(row)):  # pylint: disable=C0200
                if row[i] in choices:
                    if row[i] != choice:
                        row[i] = choice
        to_remove.sort()
        # Compact the numbering: every index above a removed one shifts
        # down by one; a row still referencing a removed index is a bug.
        for r in to_remove:
            for i in range(len(row)):  # pylint: disable=C0200
                if row[i] == r:
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected result r=%r row=%r to_remove=%r "
                        "diag=%r." % (
                            r, row, to_remove, self.kwargs['diag']))
                if row[i] > r:
                    row[i] -= 1
        self._check_row_(row, verbose=verbose)

249 

250 def _compute_output_row_matmul(self, row, row2=None, ab=False, verbose=False): 

251 if not ab: 

252 raise RuntimeError("ab must be True.") # pragma: no cover 

253 self._check_row_(row, True, verbose=verbose) 

254 self._check_row_(row2, True, verbose=verbose) 

255 self._check_arg_('axes', tuple) 

256 self._check_arg_('left', tuple) 

257 self._check_arg_('right', tuple) 

258 self._check_arg_('ndim', int) 

259 if row2 is None: 

260 raise RuntimeError( # pragma: no cover 

261 "matmul expects two inputs.") 

262 if verbose: 

263 ndim = self.kwargs['ndim'] 

264 axes = self.kwargs['axes'] 

265 left = self.kwargs['left'] 

266 right = self.kwargs['right'] 

267 print(" MATMUL %r @ %r axes=%r left=%r right=%r - eq=%s" % ( 

268 row, row2, axes, left, right, 

269 _numpy_extended_dot_equation(ndim, ndim, axes, left, right))) 

270 row2[:] = numpy.maximum(row, row2) 

271 for a in self.kwargs['axes']: 

272 if a not in self.kwargs['right']: 

273 row2[a] = -1 

274 self._check_row_(row2, verbose=verbose) 

275 

    def _compute_output_row_batch_dot(self, row, row2=None, ab=False, verbose=False):
        """Combines both rows for a batched dot product: *row2* receives
        the merged axes, summed axes not on the right are dropped (-1).

        :param row: axes of the first input
        :param row2: axes of the second input, updated in place
        :param ab: must be True (binary operator)
        :param verbose: prints out intermediate results
        :raises RuntimeError: if *ab* is False or *row2* is missing
        """
        if not ab:
            raise RuntimeError("ab must be True.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        self._check_row_(row2, True, verbose=verbose)
        self._check_arg_('batch_axes', tuple)
        # keep_axes may legitimately be None.
        self._check_arg_('keep_axes', tuple, empty=True)
        self._check_arg_('sum_axes', tuple)
        self._check_arg_('left', tuple)
        self._check_arg_('right', tuple)
        self._check_arg_('ndim', int)
        if row2 is None:
            raise RuntimeError(
                "batch_dot expects two inputs.")  # pragma: no cover
        if verbose:
            batch_axes = self.kwargs['batch_axes']
            keep_axes = self.kwargs['keep_axes']
            sum_axes = self.kwargs['sum_axes']
            left = self.kwargs['left']
            right = self.kwargs['right']
            ndim = self.kwargs['ndim']
            print(" BATCH_DOT batch_axes=%r keep_axes=%r sum_axes=%r "
                  "left=%r right=%r eq=%r" % (
                      batch_axes, keep_axes, sum_axes, left, right,
                      _numpy_extended_dot_equation(ndim, ndim, sum_axes, left, right)))
        # Output keeps the most specific axis information of both sides.
        row2[:] = numpy.maximum(row, row2)
        # Summed axes disappear unless they remain on the right side.
        for a in self.kwargs['sum_axes']:
            if a not in self.kwargs['right']:
                row2[a] = -1
        self._check_row_(row2, verbose=verbose)

306 

307 def _compute_output_row_mul(self, row, row2=None, ab=False, verbose=False): 

308 if not ab: 

309 raise RuntimeError("ab must be True.") # pragma: no cover 

310 self._check_row_(row, True, verbose=verbose) 

311 self._check_row_(row2, True, verbose=verbose) 

312 if row2 is None: 

313 raise RuntimeError("mul expects two inputs.") # pragma: no cover 

314 if verbose: 

315 print( # pragma: no cover 

316 " MUL %r @ %r" % (row, row2)) 

317 row2[:] = numpy.maximum(row, row2) 

318 self._check_row_(row2, verbose=verbose) 

319 

320 def compute_output_row(self, row, row2=None, ab=False, verbose=False): 

321 """ 

322 Updates *row* based on the operator. 

323 """ 

324 method_name = "_compute_output_row_%s" % self.name 

325 meth = getattr(self, method_name, None) 

326 if meth is None: 

327 raise NotImplementedError( # pragma: no cover 

328 "compute_output_row not implemented for %r." % self.name) 

329 if verbose and ab: 

330 print(" -- called as a binary operator") 

331 self.add_info(i_row=single_axes(row), i_row2=single_axes(row2)) 

332 meth(row, row2=row2, ab=ab, verbose=verbose) 

333 self.add_info(o_row=single_axes(row), o_row2=single_axes(row2)) 

334 

335 def add_info(self, **kwargs): 

336 """ 

337 Adds information to the node. 

338 

339 :param kwargs: dictionary 

340 """ 

341 for k, v in kwargs.items(): 

342 if k in self._info: 

343 raise KeyError( # pragma: no cover 

344 "Key %r already added (operator %r)." % (k, self.name)) 

345 self._info[k] = v 

346 

347 def _check_inputs_(self, n_expected, check_dim=False): 

348 if len(self.inputs) != n_expected: 

349 raise RuntimeError( # pragma: no cover 

350 "Number of inputs must be %d not %d for operator %r." 

351 "" % (n_expected, len(self.inputs), self.name)) 

352 

353 def _check_shape_(self, m): 

354 if len(m.shape) != self.full_dim: 

355 raise RuntimeError( # pragma: no cover 

356 "Number of dimensions %r is different from expected value " 

357 "%d." % (m.shape, self.full_dim)) 

358 

359 def _get_data(self, data, key): 

360 if isinstance(key, int): 

361 if key not in data: 

362 raise RuntimeError( # pragma: no cover 

363 "Unable to find key %d in %r." % ( 

364 key, list(sorted(data)))) 

365 return data[key] 

366 if isinstance(key, EinsumSubOp): 

367 if id(key) not in data: 

368 raise RuntimeError( # pragma: no cover 

369 "Unable to find key %d in %r." % ( 

370 id(key), list(sorted(data)))) 

371 return data[id(key)] 

372 raise TypeError( # pragma: no cover 

373 "Unexpected input type %r." % type(key)) 

374 

375 def _apply_id(self, data, verbose=False, **kwargs): 

376 self._check_inputs_(1) 

377 inp = self.inputs[0] 

378 output = self._get_data(data, inp) 

379 return output 

380 

    def _apply_diagonal(self, data, verbose=False, **kwargs):
        """Extracts a diagonal of the single input with ``numpy_diagonal``.

        ``kwargs['diag']`` must hold exactly one ``(axis, axes)`` pair.
        """
        self._check_inputs_(1)
        inp = self.inputs[0]
        m = self._get_data(data, inp)
        if verbose:
            print(  # pragma: no cover
                "- %s, shape=%r diag=%r" % (
                    self.name, m.shape, self.kwargs['diag']))
        diag = self.kwargs['diag']
        # Only a single group of duplicated indices is supported.
        if len(diag) != 1:
            raise NotImplementedError(  # pragma: no cover
                "Not implemented with more than one duplicated indice "
                "%r." % diag)
        diag0 = diag[0]
        output = numpy_diagonal(m, axis=diag0[0], axes=diag0[1])
        return output

397 

398 def _apply_expand_dims(self, data, verbose=False, **kwargs): 

399 self._check_inputs_(1) 

400 inp = self.inputs[0] 

401 m = self._get_data(data, inp) 

402 if verbose: 

403 print("- %s, shape=%r axes=%r" % ( 

404 self.name, m.shape, self.kwargs['axes'])) 

405 output = m 

406 for axis in reversed(self.kwargs['axes']): 

407 output = numpy.expand_dims(output, axis[0]) 

408 return output 

409 

410 def _apply_transpose(self, data, verbose=False, **kwargs): 

411 self._check_inputs_(1, True) 

412 inp = self.inputs[0] 

413 m = self._get_data(data, inp) 

414 self._check_shape_(m) 

415 if verbose: 

416 print("- %s, shape=%r perm=%r" % ( 

417 self.name, m.shape, self.kwargs['perm'])) 

418 output = numpy.transpose(m, self.kwargs['perm']) 

419 self._check_shape_(output) 

420 return output 

421 

422 def _apply_transpose_mm(self, data, verbose=False, **kwargs): 

423 self._check_inputs_(2, True) 

424 inp = self.inputs[0] 

425 m = self._get_data(data, inp) 

426 self._check_shape_(m) 

427 if verbose: 

428 print( # pragma: no cover 

429 "- %s, shape=%r perm=%r" % ( 

430 self.name, m.shape, self.kwargs['perm'])) 

431 output = numpy.transpose(m, self.kwargs['perm']) 

432 self._check_shape_(output) 

433 return output 

434 

    def _apply_matmul(self, data, verbose=False, **kwargs):
        """Extended dot product of the two inputs.

        :param data: dictionary of intermediate results
        :param verbose: prints out intermediate results
        :param kwargs: may contain 'matmul_impl' selecting the
            implementation: None (numpy.einsum based), 'py'
            (pure python) or 'pyf' (matrix based)
        :return: dot product with the same number of dimensions
        :raises ValueError: for an unknown 'matmul_impl' value
        """
        self._check_inputs_(2)
        inp1 = self.inputs[0]
        inp2 = self.inputs[1]
        m1 = self._get_data(data, inp1)
        m2 = self._get_data(data, inp2)
        self._check_shape_(m1)
        self._check_shape_(m2)
        axes = self.kwargs['axes']
        left = self.kwargs['left']
        right = self.kwargs['right']

        if verbose:
            print("- %s, shapes=%r @ %r axes=%r left=%r right=%r" % (
                self.name, m1.shape, m2.shape, axes, left, right))

        # Dispatch on the requested implementation.
        impl = kwargs.get('matmul_impl', None)
        if impl == 'pyf':
            output = numpy_extended_dot_matrix(m1, m2, axes, left, right,
                                               verbose=verbose)
        elif impl == 'py':
            output = numpy_extended_dot_python(m1, m2, axes, left, right,
                                               verbose=verbose)
        elif impl is None:
            output = numpy_extended_dot(m1, m2, axes, left, right,
                                        verbose=verbose)
        else:
            raise ValueError(
                "Unknown implementation of numpy_extended_dot ({}).".format(impl))
        self._check_shape_(output)
        return output

466 

467 def _apply_mul(self, data, verbose=False, **kwargs): 

468 self._check_inputs_(2) 

469 inp1 = self.inputs[0] 

470 inp2 = self.inputs[1] 

471 m1 = self._get_data(data, inp1) 

472 m2 = self._get_data(data, inp2) 

473 self._check_shape_(m1) 

474 self._check_shape_(m2) 

475 

476 if verbose: 

477 print( # pragma: no cover 

478 "- %s, shapes=%r @ %r" % (self.name, m1.shape, m2.shape)) 

479 

480 output = m1 * m2 

481 self._check_shape_(output) 

482 return output 

483 

    def _apply_batch_dot(self, data, verbose=False, **kwargs):
        """Batched dot product of the two inputs.

        Both inputs are reshaped into 3D tensors
        ``(batch, kept, summed)``, multiplied with a transposed matmul
        (or a flat gemm for degenerate batch kinds), then reshaped back
        to ``full_dim`` dimensions, padding with 1s.

        :param data: dictionary of intermediate results
        :param verbose: prints out intermediate results
        :return: result with ``full_dim`` dimensions
        :raises RuntimeError: if the inputs have different ranks
        """
        self._check_inputs_(2)
        inp1 = self.inputs[0]
        inp2 = self.inputs[1]
        m1 = self._get_data(data, inp1)
        m2 = self._get_data(data, inp2)
        self._check_shape_(m1)
        self._check_shape_(m2)
        batch_axes = self.kwargs['batch_axes']
        keep_axes = self.kwargs['keep_axes']
        sum_axes = self.kwargs['sum_axes']
        left = self.kwargs['left']
        right = self.kwargs['right']

        if verbose:
            print("- %s, shapes=%r @ %r batch_axes=%r keep_axes=%r "
                  "sum_axes=%r" % (
                      self.name, m1.shape, m2.shape, batch_axes, keep_axes, sum_axes))

        if len(m1.shape) != len(m2.shape):
            raise RuntimeError(  # pragma: no cover
                "batch_dot only work with two tensors with the same number "
                "of dimensions not %r @ %r." % (m1.shape, m2.shape))

        # Flatten the three axis groups into single dimensions.
        dim0 = int(numpy.prod([m1.shape[i] for i in batch_axes]))
        dim0b = int(numpy.prod([m2.shape[i] for i in batch_axes]))
        # -1 lets reshape infer the kept dimension when keep_axes is None.
        dimb = int(-1 if keep_axes is None else numpy.prod(
            [m1.shape[i] for i in keep_axes]))
        dim1 = int(numpy.prod([m1.shape[i] for i in sum_axes]))
        dim2 = int(numpy.prod([m2.shape[i] for i in sum_axes]))

        if verbose:
            print("- %s, reshape=%r into %r" % (
                self.name, m1.shape, (dim0, dimb, dim1)))
            print("- %s, reshape=%r into %r" % (
                self.name, m2.shape, (dim0b, dimb, dim2)))
        m1sh = m1.reshape((dim0, dimb, dim1))
        m2sh = m2.reshape((dim0b, dimb, dim2))

        batch_kind = self.get_dot_kind()
        # NOTE(review): 'N1' appears twice in this tuple; it looks like the
        # second entry was meant to be '1N' — confirm against get_dot_kind.
        if batch_kind in ('11', 'N1', 'N1'):
            # Degenerate batch: collapse to 2D and use a single gemm.
            m1sh = m1sh.reshape((-1, m1sh.shape[-1]))
            m2sh = m2sh.reshape((-1, m2sh.shape[-1]))
            if verbose:
                print("- %s, use gemm with shape %r, %r" % (
                    self.name, m1sh.shape, m2sh.shape))
            dot = gemm_dot(m1sh, m2sh, False, True)
        else:
            dot = m1sh @ numpy.transpose(m2sh, (0, 2, 1))

        # new shape
        new_shape = ([max(m1.shape[i], m2.shape[i]) for i in batch_axes] +
                     [m1.shape[i] for i in left if i not in batch_axes] +
                     [m2.shape[i] for i in right if i not in batch_axes])
        # Pad with singleton dimensions to restore full_dim.
        while len(new_shape) < len(m1.shape):
            new_shape.append(1)

        if verbose:
            taken = set(batch_axes) | set(sum_axes)
            ax = [i for i in range(len(m1.shape)) if i not in taken]
            print("- %s, shapes=%r @ %r -> %r" % (
                self.name, m1sh.shape, m2sh.shape, dot.shape))
            print("- %s, batch_axes=%r ax=%r new_shape=%r left=%r right=%r" % (
                self.name, batch_axes, ax, new_shape, left, right))

        output = dot.reshape(tuple(new_shape))
        self._check_shape_(output)
        return output

552 

553 def _apply_reduce_sum(self, data, verbose=False, **kwargs): 

554 self._check_inputs_(1) 

555 inp = self.inputs[0] 

556 m = self._get_data(data, inp) 

557 self._check_shape_(m) 

558 axes = self.kwargs['axes'] 

559 if verbose: 

560 print("- %s, shape=%r axes=%r" % ( 

561 self.name, m.shape, self.kwargs['axes'])) 

562 output = numpy.sum(m, axis=axes, keepdims=True) 

563 self._check_shape_(output) 

564 return output 

565 

566 def _apply_reduce_sum_mm(self, data, verbose=False, **kwargs): 

567 self._check_inputs_(2, True) 

568 inp = self.inputs[0] 

569 m = self._get_data(data, inp) 

570 self._check_shape_(m) 

571 if verbose: 

572 print("- %s, shape=%r axes=%r" % ( 

573 self.name, m.shape, self.kwargs['axes'])) 

574 output = numpy.sum(m, self.kwargs['axes']) 

575 self._check_shape_(output) 

576 return output 

577 

578 def _apply_squeeze(self, data, verbose=False, **kwargs): 

579 self._check_inputs_(1) 

580 inp = self.inputs[0] 

581 m = self._get_data(data, inp) 

582 axes = self.kwargs['axes'] 

583 if verbose: 

584 print("- %s, shape=%r axes=%r" % ( 

585 self.name, m.shape, self.kwargs['axes'])) 

586 output = m 

587 for a in axes[::-1]: 

588 output = numpy.squeeze(output, axis=a) 

589 return output 

590 

    def apply(self, data, verbose=False, **kwargs):
        """
        Applies one operator on the data.

        :param data: dictionary storing the results
        :param verbose: prints out intermediate results
        :param kwargs: additional parameters, see
            methods `_apply*`
        :return: output

        Known additional parameters:

        * 'matmul_impl': if None calls :epkg:`numpy:einsum` through
          @see fn numpy_extended_dot (default) or 'py' to call
          @see fn numpy_extended_dot_python instead.
        """
        if verbose:
            print()
            print("apply %r (%s)." % (
                self.name, ", ".join(map(lambda s: str(id(s)), self.inputs))))

        # Dispatch to the operator-specific implementation.
        method_name = "_apply_%s" % self.name
        meth = getattr(self, method_name, None)
        if meth is None:
            raise NotImplementedError(  # pragma: no cover
                "apply not implemented for %r." % self.name)
        output = meth(data, verbose, **kwargs)

        # Store the result under this node's id so downstream operators
        # can retrieve it through _get_data.
        data[id(self)] = output
        if verbose:
            print("+ %s, shape=%r -- %d" % (self.name, output.shape, id(self)))
        return output

623 

624 def _onnx_name(self): 

625 return 'einsum%d_%s' % (id(self), self.name[:2]) 

626 

627 def _check_onnx_opset_(self, opset, limit): 

628 if opset is not None and opset < limit: 

629 raise RuntimeError( # pragma: no cover 

630 "Opset (%r) must be >= %r for operator %r." 

631 "" % (opset, limit, self.name)) 

632 

    def _to_onnx_id(self, names, opset, verbose=False, **kwargs):
        """Yields an ONNX Identity node for the single input.

        :param names: dictionary mapping inputs to ONNX result names
        :param opset: target opset (unused, Identity exists everywhere)
        """
        self._check_inputs_(1)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        yield helper.make_node('Identity', [name], [self._onnx_name()])

638 

    def _to_onnx_expand_dims(self, names, opset, verbose=False, **kwargs):
        """Yields an axes initializer and an ONNX Unsqueeze node.

        :param names: dictionary mapping inputs to ONNX result names
        :param opset: target opset, must be >= 11
        """
        self._check_inputs_(1)
        self._check_onnx_opset_(opset, 11)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        axes = self.kwargs['axes']
        name_axes = name + '_axes'
        # Opset 13 style: axes passed as a second input tensor.
        yield numpy_helper.from_array(
            numpy.array([a[1] for a in axes], dtype=numpy.int64), name=name_axes)
        s_axes = "".join(map(str, [a[1] for a in axes]))
        yield helper.make_node(
            'Unsqueeze', [name, name_axes], [self._onnx_name()],
            name='Unsqueeze%s_%d' % (s_axes, id(self)))

652 

    def _to_onnx_squeeze(self, names, opset, verbose=False, **kwargs):
        """Yields an axes initializer and an ONNX Squeeze node.

        :param names: dictionary mapping inputs to ONNX result names
        :param opset: target opset, must be >= 11
        """
        self._check_inputs_(1)
        self._check_onnx_opset_(opset, 11)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        axes = self.kwargs['axes']
        name_axes = name + '_axes'
        # Opset 13 style: axes passed as a second input tensor.
        yield numpy_helper.from_array(
            numpy.array(axes, dtype=numpy.int64), name=name_axes)
        s_axes = "".join(map(str, axes))
        yield helper.make_node(
            'Squeeze', [name, name_axes], [self._onnx_name()],
            name='Squeeze%s_%d' % (s_axes, id(self)))

666 

    def _to_onnx_transpose(self, names, opset, verbose=False, **kwargs):
        """Yields an ONNX Transpose node with ``perm`` as attribute.

        :param names: dictionary mapping inputs to ONNX result names
        :param opset: target opset (unused, perm is an attribute everywhere)
        """
        self._check_inputs_(1)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        perm = self.kwargs['perm']
        s_perm = "".join(map(str, perm))
        yield helper.make_node(
            'Transpose', [name], [self._onnx_name()], perm=perm,
            name='Transpose%s_%d' % (s_perm, id(self)))

676 

    def _to_onnx_reduce_sum(self, names, opset, verbose=False, **kwargs):
        """Yields an axes initializer and an ONNX ReduceSum node
        (``keepdims=1`` to preserve the rank).

        :param names: dictionary mapping inputs to ONNX result names
        :param opset: target opset, must be >= 11
        """
        self._check_inputs_(1)
        self._check_onnx_opset_(opset, 11)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        axes = self.kwargs['axes']
        name_axes = self._onnx_name() + '_axes'
        # Opset 13 style: axes passed as a second input tensor.
        yield numpy_helper.from_array(
            numpy.array(axes, dtype=numpy.int64), name=name_axes)
        s_axes = "".join(map(str, axes))
        yield helper.make_node(
            'ReduceSum', [name, name_axes], [self._onnx_name()], keepdims=1,
            name='ReduceSum%s_%d' % (s_axes, id(self)))

690 

    def _to_onnx_mul(self, data, verbose=False, **kwargs):
        """Yields an ONNX Mul node combining both inputs.

        NOTE(review): unlike the other ``_to_onnx_*`` methods this one
        takes ``(data, verbose)`` instead of ``(names, opset, verbose)``;
        if the dispatcher calls all of them uniformly, the opset would
        bind to *verbose* here — confirm against the caller.
        """
        self._check_inputs_(2)
        inp1 = self.inputs[0]
        inp2 = self.inputs[1]
        m1 = self._get_data(data, inp1)
        m2 = self._get_data(data, inp2)
        yield helper.make_node('Mul', [m1, m2], [self._onnx_name()])

698 

699 def _to_onnx_batch_dot(self, names, opset, verbose=False, **kwargs): # pylint: disable=R0914 

700 self._check_inputs_(2) 

701 self._check_onnx_opset_(opset, 13) 

702 inp1, inp2 = self.inputs[:2] # pylint: disable=W0632 

703 name1 = self._get_data(names, inp1) 

704 name2 = self._get_data(names, inp2) 

705 

706 batch_axes = self.kwargs['batch_axes'] 

707 keep_axes = self.kwargs['keep_axes'] 

708 sum_axes = self.kwargs['sum_axes'] 

709 left = self.kwargs['left'] 

710 right = self.kwargs['right'] 

711 root = self._onnx_name() 

712 

713 def return_name_one(): 

714 name_one = root + "_1" 

715 return name_one, numpy_helper.from_array( 

716 numpy.array([1], dtype=numpy.int64), name=name_one) 

717 

718 name_one = None 

719 name_shape1 = root + "_shape1" 

720 name_shape2 = root + "_shape2" 

721 concat_left = [] 

722 concat_right = [] 

723 yield helper.make_node('Shape', [name1], [name_shape1]) 

724 yield helper.make_node('Shape', [name2], [name_shape2]) 

725 

726 if len(batch_axes) > 0: 

727 name_batch_axes = root + "_batch_axes" 

728 yield numpy_helper.from_array( 

729 numpy.array(batch_axes, dtype=numpy.int64), name=name_batch_axes) 

730 

731 if len(sum_axes) > 0: 

732 name_sum_axes = root + "_sum_axes" 

733 yield numpy_helper.from_array( 

734 numpy.array(sum_axes, dtype=numpy.int64), name=name_sum_axes) 

735 

736 # dim0 = int(numpy.prod([m1.shape[i] for i in batch_axes])) 

737 # dim0b = int(numpy.prod([m2.shape[i] for i in batch_axes])) 

738 if len(batch_axes) > 1: 

739 name_dim0 = root + "_dim0" 

740 name_dim0b = root + "_dim0b" 

741 name_dim0g = name_dim0 + 'g' 

742 name_dim0bg = name_dim0b + 'g' 

743 concat_left.append(name_dim0) 

744 concat_right.append(name_dim0b) 

745 yield helper.make_node( 

746 'Gather', [name_shape1, name_batch_axes], [name_dim0g]) 

747 yield helper.make_node( 

748 'Gather', [name_shape2, name_batch_axes], [name_dim0bg]) 

749 yield helper.make_node( 

750 'ReduceProd', [name_dim0g], [name_dim0], keepdims=1) 

751 yield helper.make_node( 

752 'ReduceProd', [name_dim0bg], [name_dim0b], keepdims=1) 

753 elif len(batch_axes) == 1: 

754 name_dim0g = root + "_dim0g" 

755 name_dim0bg = root + "_dim0bg" 

756 name_dim0 = name_dim0g 

757 name_dim0b = name_dim0bg 

758 concat_left.append(name_dim0) 

759 concat_right.append(name_dim0b) 

760 yield helper.make_node( 

761 'Gather', [name_shape1, name_batch_axes], [name_dim0g]) 

762 yield helper.make_node( 

763 'Gather', [name_shape2, name_batch_axes], [name_dim0bg]) 

764 else: 

765 if name_one is None: 

766 name_one, cst_init = return_name_one() 

767 yield cst_init 

768 name_dim0 = name_one 

769 name_dim0b = name_one 

770 concat_left.append(name_dim0) 

771 concat_right.append(name_dim0b) 

772 

773 # dimb = int(-1 if keep_axes is None else numpy.prod( 

774 # [m1.shape[i] for i in keep_axes])) 

775 if keep_axes in (-1, None) or len(keep_axes) == 0: 

776 name_dimb = root + "__1" 

777 concat_left.append(name_dimb) 

778 concat_right.append(name_dimb) 

779 yield numpy_helper.from_array( 

780 numpy.array([-1], dtype=numpy.int64), name=name_dimb) 

781 elif len(keep_axes) == 1: 

782 name_keep_axes = root + "_keep_axes" 

783 name_dimb = root + "_dimb" 

784 name_dimbg = name_dimb 

785 concat_left.append(name_dimb) 

786 concat_right.append(name_dimb) 

787 yield numpy_helper.from_array( 

788 numpy.array(keep_axes, dtype=numpy.int64), name=name_keep_axes) 

789 yield helper.make_node( 

790 'Gather', [name_shape1, name_keep_axes], [name_dimbg]) 

791 else: 

792 name_keep_axes = root + "_keep_axes" 

793 name_dimb = root + "_dimb" 

794 name_dimbg = name_dimb + 'g' 

795 concat_left.append(name_dimb) 

796 concat_right.append(name_dimb) 

797 yield numpy_helper.from_array( 

798 numpy.array(keep_axes, dtype=numpy.int64), name=name_keep_axes) 

799 yield helper.make_node( 

800 'Gather', [name_shape1, name_keep_axes], [name_dimbg]) 

801 yield helper.make_node( 

802 'ReduceProd', [name_dimbg], [name_dimb], keepdims=1) 

803 

804 # dim1 = int(numpy.prod([m1.shape[i] for i in sum_axes])) 

805 # dim2 = int(numpy.prod([m2.shape[i] for i in sum_axes])) 

806 

807 if len(sum_axes) == 0: 

808 if name_one is None: 

809 name_one, cst_init = return_name_one() 

810 yield cst_init 

811 name_dim1 = name_one 

812 name_dim2 = name_one 

813 concat_left.append(name_dim1) 

814 concat_right.append(name_dim2) 

815 elif len(sum_axes) == 1: 

816 name_dim1 = root + "_dim1" 

817 name_dim2 = root + "_dim2" 

818 name_dim1g = name_dim1 

819 name_dim2g = name_dim2 

820 concat_left.append(name_dim1) 

821 concat_right.append(name_dim2) 

822 yield helper.make_node( 

823 'Gather', [name_shape1, name_sum_axes], [name_dim1g]) 

824 yield helper.make_node( 

825 'Gather', [name_shape2, name_sum_axes], [name_dim2g]) 

826 else: 

827 name_dim1 = root + "_dim1" 

828 name_dim2 = root + "_dim2" 

829 name_dim1g = name_dim1 + 'g' 

830 name_dim2g = name_dim2 + 'g' 

831 concat_left.append(name_dim1) 

832 concat_right.append(name_dim2) 

833 yield helper.make_node( 

834 'Gather', [name_shape1, name_sum_axes], [name_dim1g]) 

835 yield helper.make_node( 

836 'Gather', [name_shape2, name_sum_axes], [name_dim2g]) 

837 yield helper.make_node( 

838 'ReduceProd', [name_dim1g], [name_dim1], keepdims=1) 

839 yield helper.make_node( 

840 'ReduceProd', [name_dim2g], [name_dim2], keepdims=1) 

841 

842 batch_kind = self.get_dot_kind() 

843 if batch_kind in ('11', 'N1', 'N1'): 

844 # *shape1, *shape2 

845 name_minus_one = root + "__01" 

846 yield numpy_helper.from_array( 

847 numpy.array([-1], dtype=numpy.int64), name=name_minus_one) 

848 name_agg_shape1_2 = root + "_resh1_%s" % batch_kind 

849 name_agg_shape2_2 = root + "_resh2_%s" % batch_kind 

850 yield helper.make_node( 

851 'Concat', [name_minus_one, name_dim1], [name_agg_shape1_2], axis=0) 

852 yield helper.make_node( 

853 'Concat', [name_minus_one, name_dim2], [name_agg_shape2_2], axis=0) 

854 

855 # m1sh = m1.reshape((-1, dim1)) 

856 # m2sh = m2.reshape((-1, dim2)) 

857 name_agg1_2 = root + "_aresh1" 

858 name_agg2_2 = root + "_aresh2" 

859 yield helper.make_node('Reshape', [name1, name_agg_shape1_2], [name_agg1_2]) 

860 yield helper.make_node('Reshape', [name2, name_agg_shape2_2], [name_agg2_2]) 

861 

862 # dot = gemm(m1sh, m2sh, False, True) 

863 name_dot = root + "_gemm" 

864 yield helper.make_node( 

865 'Gemm', [name_agg1_2, name_agg2_2], [name_dot], 

866 alpha=1., beta=0., transA=0, transB=1) 

867 else: 

868 # *shape1, *shape2 

869 name_agg_shape1 = root + "_resh1" 

870 name_agg_shape2 = root + "_resh2" 

871 yield helper.make_node( 

872 'Concat', concat_left, [name_agg_shape1], axis=0) 

873 yield helper.make_node( 

874 'Concat', concat_right, [name_agg_shape2], axis=0) 

875 

876 # m1sh = m1.reshape((dim0, dimb, dim1)) 

877 # m2sh = m2.reshape((dim0b, dimb, dim2)) 

878 name_agg1 = root + "_aresh1" 

879 name_agg2 = root + "_aresh2" 

880 yield helper.make_node('Reshape', [name1, name_agg_shape1], [name_agg1]) 

881 yield helper.make_node('Reshape', [name2, name_agg_shape2], [name_agg2]) 

882 

883 # dot = m1sh @ numpy.transpose(m2sh, (0, 2, 1)) 

884 name_agg2_tr = root + "_aresh2_tr" 

885 yield helper.make_node( 

886 'Transpose', [name_agg2], [name_agg2_tr], perm=[0, 2, 1], 

887 name="Transpose021_%s" % id(self)) 

888 

889 name_dot = root + "_dot" 

890 yield helper.make_node( 

891 'MatMul', [name_agg1, name_agg2_tr], [name_dot]) 

892 

893 # new_shape = ([max(m1.shape[i], m2.shape[i]) for i in batch_axes] + 

894 # [m1.shape[i] for i in left if i not in batch_axes] + 

895 # [m2.shape[i] for i in right if i not in batch_axes]) 

896 concat_final = [] 

897 if len(batch_axes) > 0: 

898 name_max_dim = root + "_max_dim" 

899 concat_final.append(name_max_dim) 

900 yield helper.make_node( 

901 'Max', [name_dim0g, name_dim0bg], [name_max_dim]) 

902 

903 left_set = list(sorted(set(left) - (set(batch_axes) & set(left)))) 

904 if len(left_set) > 0: 

905 name_left_dim = root + "_left_dim" 

906 name_left_set = root + "_left_set" 

907 yield numpy_helper.from_array( 

908 numpy.array(left_set, dtype=numpy.int64), name=name_left_set) 

909 yield helper.make_node( 

910 'Gather', [name_shape1, name_left_set], [name_left_dim]) 

911 concat_final.append(name_left_dim) 

912 

913 right_set = list(sorted(set(right) - (set(batch_axes) & set(right)))) 

914 if len(right_set) > 0: 

915 name_right_dim = root + "_right_dim" 

916 name_right_set = root + "_right_set" 

917 yield numpy_helper.from_array( 

918 numpy.array(right_set, dtype=numpy.int64), name=name_right_set) 

919 yield helper.make_node( 

920 'Gather', [name_shape2, name_right_set], [name_right_dim]) 

921 concat_final.append(name_right_dim) 

922 

923 name_new_shape = root + '_new_shape' 

924 diff = ( 

925 self.full_dim - 

926 (len(batch_axes) + len(left_set) + len(right_set))) 

927 if diff > 0: 

928 names_ones = root + "_ones" 

929 yield numpy_helper.from_array( 

930 numpy.array([1 for i in range(diff)], dtype=numpy.int64), 

931 name=names_ones) 

932 concat_final.append(names_ones) 

933 

934 yield helper.make_node( 

935 'Concat', concat_final, [name_new_shape], axis=0) 

936 

937 name_final = root + '_final' 

938 yield helper.make_node( 

939 'Reshape', [name_dot, name_new_shape], [name_final]) 

940 

941 def to_onnx(self, names, opset=None, verbose=False, **kwargs): 

942 """ 

943 Converts this node into ONNX. Enumerates all ONNX node 

944 which participate to the conversion. The last one 

945 is the final output. 

946 

947 :param names: dictionary where to find already converted name 

948 :param opset: opset 

949 :param verbose: prints out intermediate results 

950 :param kwargs: additional parameter for the conversion 

951 :return: output 

952 """ 

953 if opset is None: 

954 opset = __max_supported_opset__ # pragma: no cover 

955 if verbose: 

956 print() 

957 print("to_onnx %r (%s) opset=%r." % ( 

958 self.name, 

959 ", ".join(map(lambda s: str(id(s)), self.inputs)), 

960 opset)) 

961 

962 method_name = "_to_onnx_%s" % self.name 

963 meth = getattr(self, method_name, None) 

964 if meth is None: 

965 if self.name.endswith("_mm"): 

966 raise NotImplementedError( 

967 "to_onnx not implemented for %r." 

968 "You should call method simplify_mm_nodes " 

969 "to remove it." % self.name) 

970 raise NotImplementedError( 

971 "to_onnx not implemented for %r." % self.name) 

972 for node in meth(names, verbose=verbose, opset=opset, **kwargs): 

973 if hasattr(node, 'output'): 

974 names[id(self)] = node.output[0] 

975 if verbose: 

976 print("+ OP %r -- (%s - %d)" % 

977 (node.output[0], self.name, id(self))) 

978 elif verbose: 

979 # Initializer 

980 print("+ CT %r -- (%s - %d)" % 

981 (node.name, self.name, id(self))) 

982 yield node 

983 

984 def get_dot_kind(self): 

985 """ 

986 Every matrix multiplication can be either: 

987 

988 * a simple multiplication (`M`) (undetected) 

989 * a 2D matrix multiplication (`11`) 

990 * a broadcasted matrix multiplication (`N1` or `1N`) 

991 * a batch matrix multiplication (`NN`) 

992 

993 This method returns which kind it is. 

994 """ 

995 batch_axes = self.kwargs['batch_axes'] 

996 # keep_axes = self.kwargs['keep_axes'] 

997 # sum_axes = self.kwargs['sum_axes'] 

998 # left = self.kwargs['left'] 

999 # right = self.kwargs['right'] 

1000 info = self._info 

1001 row_left = info['i_row'] 

1002 row_right = info['i_row2'] 

1003 

1004 batch_left = [row_left[k] for k in batch_axes] 

1005 batch_right = [row_right[k] for k in batch_axes] 

1006 n_left = len(batch_left) > 0 and max(batch_left) == 2 

1007 n_right = len(batch_right) > 0 and max(batch_right) == 2 

1008 return "%s%s" % ('N' if n_left else '1', 'N' if n_right else '1') 

1009 

1010 

1011class GraphEinsumSubOp: 

1012 """ 

1013 Class gathering all nodes produced to explicit einsum 

1014 operators. 

1015 

1016 :param letters: list of distinct letters 

1017 :param mat: matrix, see @see fn analyse_einsum_equation 

1018 :param lengths: lengths of every input 

1019 :param duplicates: see @see fn analyse_einsum_equation 

1020 """ 

1021 

1022 def __init__(self, letters, mat, lengths, duplicates): 

1023 self._nodes = {} 

1024 self._mark = {} 

1025 self._ops = [] 

1026 self._inputs = {} 

1027 self.last_op = None 

1028 self.last_added_op = None 

1029 self.metadata = dict( 

1030 letters=letters, mat=mat, lengths=lengths, 

1031 mat0=mat.copy(), duplicates=duplicates) 

1032 

1033 def append(self, op): 

1034 """ 

1035 Adds one input or result. 

1036 

1037 :param op: integer (an input) or an instance of @see cl EinsumSubOp. 

1038 :return: op or None if op is an integer 

1039 """ 

1040 if isinstance(op, int): 

1041 if op in self._nodes: 

1042 raise RuntimeError( # pragma: no cover 

1043 "Key %d already added." % op) 

1044 self._nodes[op] = op 

1045 self.last_added_op = op 

1046 self._inputs[op] = op 

1047 return None 

1048 if isinstance(op, EinsumSubOp): 

1049 if op in self._nodes: 

1050 raise RuntimeError( # pragma: no cover 

1051 "Key %d already added, op=%r." % (id(op), op)) 

1052 self._nodes[id(op)] = op 

1053 self._ops.append(op) 

1054 self.last_added_op = op 

1055 return op 

1056 raise TypeError( # pragma: no cover 

1057 "Unexpected type %r." % type(op)) 

1058 

1059 def mark_last_node(self): 

1060 """ 

1061 Marks the last node as the final output. 

1062 """ 

1063 if self.last_added_op is None: 

1064 raise RuntimeError("last_added_op is None.") # pragma: no cover 

1065 self.mark(-1, self.last_added_op) 

1066 

1067 def mark(self, i, op): 

1068 """ 

1069 Marks one input or result as an intermediate result 

1070 after a full einsum step. 

1071 

1072 :param op: integer (an input) or an instance of @see cl EinsumSubOp. 

1073 """ 

1074 if not isinstance(i, int): 

1075 raise TypeError( # pragma: no cover 

1076 "i must an integer not %r." % type(i)) 

1077 if i != -1 and i not in self._inputs: 

1078 raise RuntimeError( # pragma: no cover 

1079 "Input %d was not registered in %r." % (i, self._inputs)) 

1080 if isinstance(op, EinsumSubOp): 

1081 if id(op) not in self._nodes: 

1082 raise RuntimeError( # pragma: no cover 

1083 "Key %d not found, op=%r." % (id(op), op)) 

1084 self._mark[i] = op 

1085 self._mark[id(op)] = i 

1086 self.last_op = op 

1087 else: 

1088 raise TypeError( # pragma: no cover 

1089 "Unexpected type %r." % type(i)) 

1090 

1091 def __iter__(self): 

1092 "Iterates on nodes." 

1093 for op in self._ops: 

1094 yield op 

1095 

    def to_dot(self, **kwargs):
        """
        Produces a graph in :epkg:`dot`.

        :param kwargs: additional graph option
        :return: string
        """
        # Default graphviz attributes, overridable through kwargs.
        options = {
            'orientation': 'portrait',
            'ranksep': '0.25',
            'nodesep': '0.05',
            'width': '0.5',
            'height': '0.1',
            'size': '5',
            'node': '[shape=record]',
        }
        options.update(kwargs)

        def d2s(d):
            # Renders a dictionary as space-separated 'k=v' pairs.
            it = []
            for k, v in sorted(d.items()):
                it.append("%s=%s" % (k, v))
            return " ".join(it)

        def d2sd(d):
            # Same as d2s but keeps only entries with more than one
            # value (used for duplicated letters).
            it = []
            for k, v in sorted(d.items()):
                if len(v) > 1:
                    it.append("%s=%s" % (k, ",".join(map(str, v))))
            return " ".join(it)

        rows = ["digraph{"]
        for k, v in options.items():
            # Values containing '[' are statements (e.g. 'node [shape=record]'),
            # the others plain 'key=value' graph attributes.
            if isinstance(v, str) and "[" in v:
                rows.append("{} {};".format(k, v))
            else:
                rows.append("{}={};".format(k, v))
        for k, v in self._nodes.items():
            if isinstance(v, int):
                # An integer node is a graph input: its label shows the
                # sorted letters and the row of the original matrix.
                let = [(r, self.metadata['letters'][i])
                       for i, r in enumerate(self.metadata['mat0'][v])
                       if r != -1]
                dup = self.metadata['duplicates'][v]
                if dup is None:
                    dup = ""
                else:
                    dup = " - %s" % d2sd(dup)
                let.sort()
                letters = "".join(_[1] for _ in let)
                lab = "input %d\\\\n%s\\\\n%s%s" % (
                    v, letters, str(self.metadata['mat0'][v]), dup)
                sk = v
                extended_lab = ""
            else:
                # An operator node: its name plus keyword arguments.
                lab = "%s\\\\n%s" % (v.name, d2s(v.kwargs))
                sk = id(v)
                extended_lab = v.dot_label()
                if extended_lab:
                    extended_lab = "\\\\n" + extended_lab
            # Marked nodes (intermediate einsum results) are filled in red
            # and their mark index is inserted in the label.
            if sk in self._mark and isinstance(self._mark[sk], int):
                la = self._mark[sk]
                lab = lab.replace("\\\\n", " - I%d\\\\n" % la)
                s = ('%d [label="%s%s" style=filled '
                     'fillcolor=red];' % (k, lab, extended_lab))
            else:
                s = '%d [label="%s%s"];' % (k, lab, extended_lab)
            rows.append(s)
            if not hasattr(v, 'inputs'):
                continue
            # One edge per input, inputs referenced by value (int) or id.
            for i in v.inputs:
                vid = i if isinstance(i, int) else id(i)
                s = "%d -> %d;" % (vid, k)
                rows.append(s)
        rows.append("}")
        return "\n".join(rows)

1172 

1173 def apply_sequence(self, *inputs, verbose=False, **kwargs): 

1174 """ 

1175 Applies a sequence of operations on a list of inputs. 

1176 

1177 :param inputs: inputs: 

1178 :param verbose: prints out intermediate results 

1179 :param kwargs: additional parameters, 

1180 see :meth:`apply 

1181 <mlprodict.testing.einsum.einsum_impl_classes.EinsumSubOp.apply>`. 

1182 :return: output 

1183 """ 

1184 if verbose: 

1185 print('######### apply_sequence') 

1186 data = {i: inp for i, inp in enumerate(inputs)} 

1187 last = None 

1188 for op in self: 

1189 last = op.apply(data, verbose=verbose, **kwargs) 

1190 if last is None: 

1191 raise RuntimeError( # pragma: no cover 

1192 "Sequence of operations is empty.") 

1193 return last 

1194 

1195 def clean_unused_nodes(self, verbose=False): 

1196 """ 

1197 Cleans nodes with unused outputs. 

1198 

1199 :param verbose: display intermediate information 

1200 """ 

1201 

1202 def iteration(it): 

1203 # Walks through all nodes. 

1204 is_used = {} 

1205 for node in self._ops: 

1206 if not isinstance(node, EinsumSubOp): 

1207 continue 

1208 if id(node) not in is_used: 

1209 is_used[id(node)] = [] 

1210 for inp in node.inputs: 

1211 if not isinstance(inp, EinsumSubOp): 

1212 continue 

1213 idn = id(inp) 

1214 if idn not in is_used: 

1215 is_used[idn] = [] 

1216 is_used[idn].append(id(node)) 

1217 

1218 # Remove unused nodes. 

1219 removed = [] 

1220 for k, v in is_used.items(): 

1221 if len(v) == 0: 

1222 removed.append(k) 

1223 removed = set(removed) 

1224 i_rem = [] 

1225 for i, op in enumerate(self._ops): 

1226 if not isinstance(op, EinsumSubOp): 

1227 continue 

1228 if id(op) in removed and id(op) not in self._mark: 

1229 i_rem.append((i, id(op))) 

1230 for i, idn in reversed(i_rem): 

1231 if verbose: 

1232 print("[GraphEinsumSubOp.clean_nodes] remove node " 

1233 "i=%d: %d - id=%d" % (it, i, idn)) 

1234 del self._ops[i] 

1235 del self._nodes[idn] 

1236 return len(i_rem) > 0 

1237 

1238 it = 1 

1239 while iteration(it): 

1240 it += 1 

1241 

1242 self.last_op = None 

1243 self.last_added_op = None 

1244 

1245 def simplify_mm_nodes(self, verbose=False): 

1246 """ 

1247 Node name suffixed by `mm` are an artifact to keep 

1248 the graph consistent while building it. They can 

1249 now be replaced by the equivalent node without suffix `mm`. 

1250 

1251 :param verbose: display intermediate information 

1252 """ 

1253 for op in self: 

1254 if not isinstance(op, EinsumSubOp): 

1255 continue 

1256 if op.name.endswith('_mm'): 

1257 if verbose: 

1258 print("[GraphEinsumSubOp.simplify_mm_nodes] node %r" 

1259 " - id=%d" % (op.name, id(op))) 

1260 if len(op.inputs) != 2: 

1261 raise RuntimeError( # pragma: no cover 

1262 "Expecting 2 inputs for node %r not %r id=%r." % ( 

1263 op.name, len(op.inputs), id(op))) 

1264 op.name = op.name[:-3] 

1265 op.inputs = op.inputs[:1] 

1266 

1267 def _get_forward_nodes(self): 

1268 """ 

1269 Returns the forward nodes. 

1270 """ 

1271 forward = {} 

1272 for op in self: 

1273 if isinstance(op, int): 

1274 continue 

1275 for inp in op.inputs: 

1276 key = inp if isinstance(inp, int) else id(inp) 

1277 if key in forward: 

1278 forward[key].append(op) 

1279 else: 

1280 forward[key] = [op] 

1281 return forward 

1282 

1283 def _pprint_forward(self): 

1284 rows = [] 

1285 for op in self: 

1286 line = "%r <- %s(%s)" % ( 

1287 id(op), op.name, 

1288 ", ".join(map(str, [id(_) for _ in op.inputs]))) 

1289 rows.append(line) 

1290 return "\n".join(rows) 

1291 

1292 def _replace_node_sequence(self, added, deleted): 

1293 """ 

1294 Removes a sequence of nodes. The method does not check 

1295 that the graph remains consistent. 

1296 """ 

1297 forward = self._get_forward_nodes() 

1298 key = id(deleted[-1]) 

1299 if key not in forward: 

1300 raise RuntimeError( # pragma: no cover 

1301 "Key {} missing in all forward nodes (other keys {}), " 

1302 "all keys:\n{}".format( 

1303 key, [id(_) for _ in deleted], 

1304 self._pprint_forward())) 

1305 

1306 # deletion 

1307 mark_input = None 

1308 for d in deleted: 

1309 del self._nodes[id(d)] 

1310 if id(d) in self._mark: 

1311 del self._mark[id(d)] 

1312 dels = [] 

1313 for k, v in self._mark.items(): 

1314 if id(v) == id(d): 

1315 mark_input = k 

1316 dels.append(k) 

1317 if len(dels) != 1: 

1318 raise RuntimeError( # pragma: no cover 

1319 "Input %d has more than one marked operator " 

1320 "(%r)." % (id(d), dels)) 

1321 del self._mark[dels[0]] 

1322 

1323 dels = set(id(o) for o in deleted) 

1324 rem = [] 

1325 for i, op in enumerate(self._ops): 

1326 if id(op) in dels: 

1327 rem.append(i) 

1328 if len(rem) != len(deleted): 

1329 raise RuntimeError( # pragma: no cover 

1330 "Mismatched length %r, %r, len=%r." % ( 

1331 rem, dels, len(deleted))) 

1332 for i in reversed(rem): 

1333 del self._ops[i] 

1334 self.last_add_op = None 

1335 

1336 # insertion 

1337 if added is not None: 

1338 self._ops.insert(rem[0], added) 

1339 self._nodes[id(added)] = added 

1340 for op in forward[key]: 

1341 new_inputs = list(op.inputs) 

1342 for i in range(len(op.inputs)): # pylint: disable=C0200 

1343 if id(op.inputs[i]) == key: 

1344 new_inputs[i] = added 

1345 op.inputs = tuple(new_inputs) 

1346 if mark_input is not None: 

1347 self.mark(mark_input, added) 

1348 else: 

1349 inps = deleted[0].inputs 

1350 if len(inps) != 1: 

1351 raise RuntimeError( # pragma: no cover 

1352 "More than one input. Call another method.") 

1353 inp = inps[0] 

1354 for op in forward[key]: 

1355 new_inputs = list(op.inputs) 

1356 for i in range(len(op.inputs)): # pylint: disable=C0200 

1357 if id(op.inputs[i]) == key: 

1358 new_inputs[i] = inp 

1359 op.inputs = tuple(new_inputs) 

1360 if mark_input is not None: 

1361 self.mark(mark_input, inp) 

1362 

    def remove_duplicate_transpose(self, verbose=False):
        """
        Removes consecutive transpose by merging them.

        :param verbose: display intermediate information
        """
        modif = 1
        while modif > 0:
            modif = 0
            candidates = []
            forward = self._get_forward_nodes()
            for op in self:
                if op.name == "transpose":
                    inp = op.inputs[0]
                    # A candidate is a transpose whose parent is itself a
                    # transpose consumed by no other operator.
                    if (isinstance(inp, EinsumSubOp) and
                            inp.name == 'transpose' and
                            len(forward[id(inp)]) == 1):
                        candidates.append(op)

            if len(candidates) > 0:
                modif = 1
                # Not efficient to take the first one and to
                # start again but the graph should not be too big.
                cand = candidates[0]
                op2 = cand
                op1 = cand.inputs[0]
                perm1 = op1.kwargs['perm']
                perm2 = op2.kwargs['perm']
                if len(perm1) != len(perm2):
                    raise RuntimeError(  # pragma: no cover
                        "Transposition should have the same length "
                        "%r, %r." % (perm1, perm2))
                # Composes the two permutations into a single one.
                # NOTE(review): the order perm1[perm2[i]] assumes a
                # specific application order - confirm against the
                # transpose implementation in EinsumSubOp.apply.
                perm = list(perm1)
                for i in range(len(perm)):  # pylint: disable=C0200
                    perm[i] = perm1[perm2[i]]
                if list(range(len(perm))) == perm:
                    # identity, everything needs to be removed
                    new_op = None
                else:
                    # A single transpose taking op1's input replaces both.
                    new_op = op2.__class__(
                        op2.full_dim, op2.name, op1.inputs[0],
                        perm=tuple(perm))
                self._replace_node_sequence(new_op, [op1, op2])
                if verbose:
                    print(  # pragma: no cover
                        "[GraphEinsumSubOp.remove_duplicate_transpose] remove nodes %r"
                        " - id=%d,%d + %d perm1=%r perm2=%r -> perm=%r" % (
                            op2.name, id(op1), id(op2),
                            id(new_op) if new_op is not None else -1,
                            perm1, perm2, perm))

1413 

    def to_onnx(self, output, *inputs, dtype=None, verbose=False,
                opset=None, **kwargs):
        """
        Converts the graph into ONNX.

        :param output: output name
        :param inputs: input names
        :param dtype: type used for all operators
        :param opset: desired opset, None for the last one
        :param verbose: display intermediate operators
        :param kwargs: additional parameter to use when building
            the ONNX graph, list of supported parameters:
            *name*, *ir_version*, *producer_name*,
            *producer_version*, *initializer*
        :return: ONNX graph

        Not all graphs can be converted into ONNX. Only graphs produced
        with `strategy='numpy'` can be converted otherwise the following
        error shows up:

        ::

            NotImplementedError: to_onnx not implemented for 'matmul'.
        """
        # Local import to avoid a circular dependency at module load time.
        from ...onnx_tools.optim import onnx_remove_node_unused

        # inputs
        if opset is None:
            opset = __max_supported_opset__
        if verbose:
            print("[GraphEinsumSubOp.to_onnx] %r -> %s opset=%r "
                  "dtype=%r" % (inputs, output, opset, dtype))
        onx_inputs = []
        proto = guess_proto_dtype(
            numpy.float32 if dtype is None else dtype)
        lengths = self.metadata['lengths']
        names = {}
        for inp, le in zip(inputs, lengths):
            if isinstance(inp, tuple):
                # (name, typed variable) pair: the shape is known and
                # checked against the expected rank.
                name, typ = inp
                if le != len(typ.shape):
                    raise ValueError(  # pragma: no cover
                        "Irreconcialable shapes for input %r: "
                        "%r != len(%r)." % (name, le, typ.shape))
                # NOTE(review): proto is overwritten by every typed input
                # and the last value is reused for the output below -
                # confirm all inputs are expected to share one type.
                proto = guess_proto_dtype(guess_numpy_type(typ))
                onx_inputs.append(
                    helper.make_tensor_value_info(name, proto, typ.shape))
                names[len(names)] = name
            else:
                # Plain input name: only the rank is known, every
                # dimension is left unspecified.
                onx_inputs.append(
                    helper.make_tensor_value_info(
                        inp, proto, [None for i in range(le)]))
                names[len(names)] = inp

        # output
        onx_output = helper.make_tensor_value_info(
            output, proto, [None for i in range(lengths[-1])])

        # nodes
        nodes = []
        inits = []
        if "initializer" in kwargs:
            inits.extend(kwargs['initializer'])
        for op in self:
            # Objects exposing 'output' are NodeProto; the others are
            # initializers (TensorProto).
            for onx_node in op.to_onnx(names, verbose=verbose, opset=opset):
                if hasattr(onx_node, 'output'):
                    nodes.append(onx_node)
                else:
                    inits.append(onx_node)

        # last node
        last_node = nodes[-1]
        nodes.append(helper.make_node(
            'Identity', [last_node.output[0]], [output]))

        # Builds the graph
        model = helper.make_model(
            opset_imports=[helper.make_operatorsetid('', opset)],
            ir_version=kwargs.get('ir_version', get_ir_version(opset)),
            producer_name=kwargs.get('producer_name', 'mlprodict'),
            producer_version=kwargs.get('producer_version', "0.0.dev"),
            graph=helper.make_graph(
                name=kwargs.get('name', 'einsum'),
                inputs=onx_inputs, outputs=[onx_output],
                initializer=inits, nodes=nodes))

        # Removes the nodes made useless by the conversion.
        return onnx_remove_node_unused(model)