Coverage for mlprodict/testing/einsum/einsum_impl.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Main functions decomposing einsum computation into
4more simple functions.
5"""
6import numpy
7from .einsum_impl_classes import EinsumSubOp, GraphEinsumSubOp
def analyse_einsum_equation(equation):
    """
    Analyses an einsum equation.

    :param equation: :epkg:`numpy:einsum` equation, it must have an
        explicit output (``"inputs->output"``), spaces around every
        term are ignored
    :return: four results, the sorted letters joined into one string,
        a matrix (see below), the lengths of every term (all inputs
        followed by the output), the duplicates
    :raises NotImplementedError: if the equation does not have exactly
        two non-empty sides
    :raises ValueError: if an input contains a character which is not
        a lower or upper ASCII letter, or if the output uses a letter
        no input uses

    The returned matrix has one row per input followed by one row
    for the output, and one column per letter. It is defined as follows:

    .. math::

        m_{ij}=\\left\\{\\begin{array}{ll}-1 &
        \\text{if letter j is not involved in input i} \\\\
        p & \\text{p is position of letter j in equation i}
        \\end{array}\\right.

    When a letter is repeated inside a term, the matrix keeps its last
    position. *duplicates* is a list with one item per term, *None* if
    no letter is repeated in this term, otherwise a dictionary
    ``{letter: [positions]}`` covering every letter of the term.
    """
    sides = equation.strip(' ,').split("->")
    if len(sides) != 2 or not sides[0] or not sides[1].strip():
        raise NotImplementedError(
            "The function only implements the case when there are "
            "two sides in the equation: %r." % equation)
    inputs = [term.strip() for term in sides[0].split(',')]
    # The output is stripped the same way as the inputs, otherwise
    # an equation such as "ab, bc -> ac" is rejected because of the
    # space following the arrow.
    output = sides[1].strip()

    # Set of letters
    all_letters = set()
    for inp in inputs:
        all_letters.update(inp)
    letters = sorted(all_letters)
    for c in letters:
        if not(('a' <= c <= 'z') or ('A' <= c <= 'Z')):
            raise ValueError(
                "Equation %r must only contain lower or upper letters "
                "but %r is not." % (equation, c))

    rev = {c: i for i, c in enumerate(letters)}
    for c in output:
        if c not in rev:
            raise ValueError(
                "Output contains one unexpected letter %r in "
                "equation %r." % (c, equation))

    # One row per input, the last row is the output.
    mat = numpy.full((len(inputs) + 1, len(letters)), -1, dtype=numpy.int8)
    for i, term in enumerate(inputs + [output]):
        for k, c in enumerate(term):
            mat[i, rev[c]] = k
    lengths = [len(term) for term in inputs + [output]]

    # Look for duplicates
    duplicates = []
    for term in inputs + [output]:
        if len(term) == len(set(term)):
            duplicates.append(None)
            continue
        # Some letters are repeated, every position is recorded.
        positions = {}
        for k, c in enumerate(term):
            positions.setdefault(c, []).append(k)
        duplicates.append(positions)

    return "".join(letters), mat, lengths, duplicates
def decompose_einsum_equation(equation, *shapes, strategy="simple",
                              clean=False, verbose=False):
    """
    Builds the graph of elementary operations equivalent to
    an :epkg:`numpy:einsum` equation, optionally knowing the
    shapes of the inputs.

    :param equation: a string
    :param shapes: sequence of input shapes, every shape must be a tuple
    :param strategy: the way the equation is decomposed (see below)
    :param clean: removes the unnecessary nodes from the graph
    :param verbose: verbosity
    :return: instance of @see cl GraphEinsumSubOp
    :raises TypeError: if one of the shapes is not a tuple
    :raises ValueError: if *strategy* is unknown

    About *strategy*:

    * `'simple'`: align all dimensions in the alphabetical order,
      some generic matrix multiplication remains implemented with
      :epkg:`numpy:einsum` but only with two matrices aligned on
      the same dimension (see @see fn numpy_extended_dot)
    * `'numpy'`: same as `simple` but the decomposition does not use
      :epkg:`numpy:einsum` anymore but only multiplication or
      matrix multiplication merged into a single operator called
      *batch_dot* (see @see fn numpy_extended_dot_matrix)

    Available operations: *expand_dims*, *transpose*, *matmul*, *reduce_sum*,
    *id*, *squeeze*, *diagonal*. It analyses an equation and produces a graph
    where node are instance of class @see cl EinsumSubOp.

    .. runpython::
        :showcode:

        from mlprodict.testing.einsum import decompose_einsum_equation
        seq = decompose_einsum_equation("bac,cd,def->ebc")
        for op in seq:
            print(op)

    It can be better displayed as the following.

    .. gdot::
        :script: DOT-SECTION
        :process:

        from mlprodict.testing.einsum import decompose_einsum_equation
        seq = decompose_einsum_equation(
            "bac,cd,def->ebc", (2, 2, 2), (2, 2), (2, 2, 2))
        print("DOT-SECTION", seq.to_dot())

    See notebook :ref:`einsumdecompositionrst`.
    """
    for shape in shapes:
        if not isinstance(shape, tuple):
            raise TypeError(
                "All shapes must be tuples for %r is not." % shape)

    # Both strategies share the same decomposition, only the operator
    # implementing the matrix multiplication changes.
    if strategy == "simple":
        matmul_operator = 'matmul'
    elif strategy == "numpy":
        matmul_operator = 'batch_dot'
    else:
        raise ValueError("Unknown strategy %r." % strategy)
    graph = _decompose_einsum_equation_simple(
        equation, *shapes, verbose=verbose, op_matmul=matmul_operator)

    if not clean:
        graph.mark_last_node()
        return graph

    # Last step: clean unused nodes. An identity node is appended
    # after the last added operator before the graph is simplified.
    tail = graph.last_added_op
    graph.append(EinsumSubOp(tail.full_dim, 'id', tail))
    graph.mark_last_node()
    graph.simplify_mm_nodes(verbose=verbose)
    graph.remove_duplicate_transpose(verbose=verbose)
    graph.clean_unused_nodes(verbose=verbose)
    return graph
def apply_einsum_sequence(seq, *inputs, verbose=False, **kwargs):
    """
    Runs a sequence of operations on a list of inputs.
    The sequence is the one returned by function
    @see fn decompose_einsum_equation, the computation itself
    is delegated to its method *apply_sequence*.

    :param seq: sequence of operations
    :param inputs: inputs
    :param verbose: verbosity, forwarded to *apply_sequence*
    :param kwargs: additional parameters,
        see :meth:`apply_sequence
        <mlprodict.testing.einsum.einsum_impl_classes.
        GraphEinsumSubOp.apply_sequence>`.
    :return: output

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import (
            decompose_einsum_equation, apply_einsum_sequence)

        m1 = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
        m2 = numpy.arange(4).reshape((2, 2)) + 100
        m3 = numpy.arange(8).reshape((2, 2, 2)) + 1000

        seq = decompose_einsum_equation("bac,cd,def->ebc")
        res = apply_einsum_sequence(seq, m1, m2, m3)
        print(res)

    See notebook :ref:`einsumdecompositionrst`.
    """
    options = dict(verbose=verbose, **kwargs)
    return seq.apply_sequence(*inputs, **options)
def is_transpose_identity(perm):
    """
    Tells if the permutation *perm* leaves the axes unchanged
    (identity), which means ``perm[i] == i`` for every *i*.

    :param perm: permutation
    :return: boolean
    """
    expected = range(len(perm))
    return all(axis == index for axis, index in zip(perm, expected))
201def _basic_verification(lengths, shapes, equation):
202 if len(lengths) - 1 != len(shapes):
203 raise ValueError(
204 "Equation %r has %d inputs but %d shapes are given."
205 "" % (equation, len(lengths), len(shapes)))
206 for i, (le, sh) in enumerate(zip(lengths, shapes)):
207 if le != len(sh):
208 raise ValueError(
209 "Inputs %d has %d dimensions but shapes %r has %d "
210 " in equation %r." % (i, le, sh, len(sh), equation))
def _apply_transpose_reshape(op, row):
    """
    Puts all dimensions in the same order. The function is a generator,
    it always yields an *expand_dims* operator and then a *transpose*
    operator unless the permutation is the identity.

    :param op: integer (for one input) or an operator
    :param row: vector with one value per letter, -1 if the letter
        is missing from this input, otherwise a value used to order
        the existing axes
    :return: the created operators, the last one yielded is the result
    """
    ndim = len(row)
    missing = []
    present = []
    n_seen = 0
    for col, pos in enumerate(row):
        if pos == -1:
            # (number of existing axes seen so far, column of the letter)
            missing.append((n_seen, col))
        else:
            n_seen += 1
            present.append((pos, col))
    op = EinsumSubOp(ndim, 'expand_dims', op, axes=tuple(missing))
    yield op

    # Existing axes ordered by their value in row are mapped, in that
    # order, to the columns which are not -1.
    present.sort()
    present_cols = [col for col, pos in enumerate(row) if pos != -1]
    new_perm = numpy.arange(ndim)
    for (_, source), col in zip(present, present_cols):
        new_perm[source] = col
    if not is_transpose_identity(new_perm):
        op = EinsumSubOp(ndim, 'transpose', op, perm=tuple(new_perm))
        yield op
def _apply_squeeze_transpose(op, row_last, row_output):
    """
    Puts output dimension in the expected order. The function is
    a generator, it yields a *transpose* operator unless the permutation
    is the identity, then a *squeeze* operator if at least one letter
    is missing from the output.

    :param op: operator producing the last result
    :param row_last: row of the last result, only its length is used
    :param row_output: vector with one value per letter, -1 if the letter
        is not part of the output, otherwise a value used to order
        the output axes
    :return: the created operators, the last one yielded is the result
    """
    ndim = len(row_last)
    kept = []
    squeezed = []
    for col, pos in enumerate(row_output):
        if pos == -1:
            squeezed.append(col)
        else:
            kept.append((pos, col))

    # Columns kept in the output receive, in their original order,
    # the columns sorted by the value they have in row_output.
    kept.sort()
    kept_cols = [col for col, pos in enumerate(row_output) if pos != -1]
    new_perm = numpy.arange(ndim)
    for col, (_, source) in zip(kept_cols, kept):
        new_perm[col] = source

    if not is_transpose_identity(new_perm):
        op = EinsumSubOp(ndim, 'transpose', op,
                         perm=tuple(new_perm))
        yield op
    if len(squeezed) > 0:
        op = EinsumSubOp(ndim, 'squeeze', op, axes=tuple(squeezed))
        yield op
def _apply_einsum_matmul(fd, op1, op2, axes, left, right, ndim,
                         op_matmul, row1, row2, verbose=False):
    """
    Decomposes the generic matrix multiplication into numpy operations
    depending on the operator to use for matrix multiplication
    *op_matmul* (see @see fn decompose_einsum_equation).
    The function is a generator yielding instances of
    @see cl EinsumSubOp, the last one yielded is the result.

    :param fd: first argument given to every created @see cl EinsumSubOp
    :param op1: left operand
    :param op2: right operand
    :param axes: indices of the dimensions to sum over
    :param left: indices of the dimensions coming from the left operand
    :param right: indices of the dimensions coming from the right operand
    :param ndim: total number of dimensions
    :param op_matmul: `'matmul'`, `'batch_dot'` or `'dot'`,
        `'matmul'` yields a single *matmul* operator, there is no
        dedicated branch for `'dot'`, it follows the same path
        as `'batch_dot'`
    :param row1: row describing the left operand, a dimension exists
        if its value is >= 0
    :param row2: row describing the right operand, a dimension exists
        if its value is >= 0
    :param verbose: displays the chosen decomposition
    :raises ValueError: if *op_matmul* is unknown
    :raises NotImplementedError: if *axes* shares a dimension with
        *left* or *right* and *op_matmul* is not `'matmul'`
    """
    allowed = {'matmul', 'batch_dot', 'dot'}
    if op_matmul not in allowed:
        raise ValueError(  # pragma: no cover
            "Unknown operator op_matmul=%r not in %r." % (op_matmul, allowed))
    if op_matmul == 'matmul':
        # A single operator receives every parameter unchanged.
        if verbose:  # pragma: no cover
            print(" -- MATMUL -> matmul axes=%r left=%r right=%r"
                  "" % (axes, left, right))
        yield EinsumSubOp(fd, 'matmul', op1, op2,
                          axes=axes, left=left, right=right, ndim=ndim)

    elif len(axes) == 0 and len(set(left) & set(right)) == 0:
        # Nothing to sum and no dimension shared by both sides:
        # an element-wise multiplication is used.
        if verbose:  # pragma: no cover
            print(" -- MATMUL -> mul axes=%r left=%r right=%r"
                  "" % (axes, left, right))
        yield EinsumSubOp(fd, 'mul', op1, op2)

    elif (len(set(axes) & set(left)) == 0 and
            len(set(axes) & set(right)) == 0):

        # No intersection between axes and left or right:
        # matrix multiplication
        if verbose:  # pragma: no cover
            print(" -- MATMUL -> batch_dot axes=%r left=%r right=%r"
                  "" % (axes, left, right))

        # common_axes: dimensions in both left and right plus the
        # dimensions which belong to none of left, right, axes.
        all_axes = set(left) | set(right) | set(axes)
        common_axes = list(set(left) & set(right))
        for i in range(ndim):
            if i not in all_axes:
                common_axes.append(i)
        common_axes.sort()

        # ReduceSum*
        # Dimensions of right existing in row1 but which are neither
        # in left nor in axes are summed in the left operand.
        has_dim = set(i for i in range(len(row1)) if row1[i] >= 0)
        right_no_left = (set(right) & has_dim) - \
            (set(right) & (set(left) | set(axes)))
        if right_no_left:
            if verbose:  # pragma: no cover
                print(' -- MATMUL reduce1 has_dim=%r axes=%r' %
                      (has_dim, right_no_left))
            op1 = EinsumSubOp(fd, 'reduce_sum_mm', op1, op2,
                              axes=tuple(sorted(right_no_left)))
            yield op1

        # Symmetric case: dimensions of left existing in row2 but which
        # are neither in right nor in axes are summed in the right operand.
        has_dim = set(i for i in range(len(row2)) if row2[i] >= 0)
        left_no_right = (set(left) & has_dim) - \
            (set(left) & (set(right) | set(axes)))
        if left_no_right:
            if verbose:  # pragma: no cover
                print(' -- MATMUL reduce2 has_dim=%r axes=%r' %
                      (has_dim, left_no_right))
            op2 = EinsumSubOp(fd, 'reduce_sum', op2,
                              axes=tuple(sorted(left_no_right)))
            yield op2

        # Transpose
        # The sort key puts common axes first (-1), then the other
        # dimensions (0), the summed axes come last (1).
        i_axes = [(-1 if i in common_axes
                   else (1 if i in axes else 0), i)
                  for i in range(ndim)]
        i_axes.sort()
        perm = [_[1] for _ in i_axes]
        # Positions, after the permutation, of the dimensions of left
        # and right.
        perm_left = [i for i in range(len(perm)) if perm[i] in left]
        perm_right = [i for i in range(len(perm)) if perm[i] in right]
        if not is_transpose_identity(perm):
            # Both operands receive the same permutation.
            op1 = EinsumSubOp(fd, 'transpose_mm', op1, op2, perm=tuple(perm))
            yield op1
            op2 = EinsumSubOp(fd, 'transpose', op2, perm=tuple(perm))
            yield op2

        # Reshape
        # After the permutation, summed axes are the last len(axes)
        # dimensions and common axes the first len(common_axes) ones.
        all_axes = list(range(0, ndim))
        new_axes = all_axes[-len(axes):] if len(axes) > 0 else []
        new_common_axes = all_axes[:len(common_axes)]
        not_in_both = []
        for i in range(0, ndim):
            if i not in left and i not in right and i not in common_axes:
                not_in_both.append(i)

        op = EinsumSubOp(fd, 'batch_dot', op1, op2,
                         batch_axes=tuple(new_common_axes),
                         keep_axes=None, sum_axes=tuple(new_axes),
                         left=tuple(perm_left), right=tuple(perm_right),
                         ndim=ndim)
        yield op

        # Transpose again
        # ordered_axes lists the dimensions in the order common axes,
        # left only, right only, remaining ones; rev_perm holds, for
        # every dimension taken in increasing order, its index in that
        # list.
        ordered_axes = (common_axes +
                        list(i for i in left if i not in right) +
                        list(i for i in right if i not in left) +
                        not_in_both)
        rev_perm = [(a, i) for i, a in enumerate(ordered_axes)]
        rev_perm.sort()
        rev_perm = [p[1] for p in rev_perm]

        if not is_transpose_identity(rev_perm):
            # op_unused is yielded but not reused by this function,
            # the result is the transpose yielded after it.
            op_unused = EinsumSubOp(fd, 'transpose_mm', op1,
                                    op, perm=tuple(rev_perm))
            yield op_unused
            op = EinsumSubOp(fd, 'transpose', op, perm=tuple(rev_perm))
            yield op
    else:
        raise NotImplementedError(  # pragma: no cover
            "axes and right or left have axes in common, "
            "axes=%r left=%r right=%r ndim=%r." % (
                axes, left, right, ndim))
def _decompose_einsum_equation_simple(equation, *shapes, verbose=False,
                                      op_matmul='matmul'):
    """
    Applies strategies `simple` and `numpy`
    described in function @see fn decompose_einsum_equation.

    :param equation: einsum equation
    :param shapes: input shapes, if none is given, every dimension
        of every input is assumed to be 2
    :param verbose: displays intermediate results
    :param op_matmul: which operator to use for matrix multiplication,
        a single operator *matmul*, or *batch_dot* with *transposes*,
        *reduce_sum*, or just *dot*
    :return: instance of @see cl GraphEinsumSubOp
    """
    letters, mat, lengths, duplicates = analyse_einsum_equation(equation)
    if len(letters) != mat.shape[1]:
        raise RuntimeError(  # pragma: no cover
            "Unexpected number of letters %r, shape=%r." % (
                letters, mat.shape))
    if len(shapes) == 0:
        # Default shapes, the last length is the output and is skipped.
        shapes = [(2, ) * le for le in lengths[:-1]]
    _basic_verification(lengths, shapes, equation)

    # last_row, current_row (row = shape)
    # rows[0] describes the result accumulated so far, rows[1] the
    # input being processed, both start at -1 for every letter.
    rows = numpy.full((2, mat.shape[1]), -1)
    graph = GraphEinsumSubOp(letters, mat, lengths, duplicates)
    fd = mat.shape[1]
    if verbose:
        print("EQUATION=%r" % equation)
        print("LETTERS=%r" % letters, "LENGTHS=%r" % lengths)
        print("DUPLICATES=%r" % duplicates)

    for i, sh in enumerate(shapes):
        if verbose:
            print()
            print("######### ROW %d shape=%r row=%r" % (i, sh, rows[1, :]))
        graph.append(i)

        # Input matrix aligned to the same dimensions.
        op = EinsumSubOp(fd, 'id', i)
        op.compute_output_row(rows[1, :], mat[i, :], verbose=verbose)
        marked = graph.append(op)

        duplicate = duplicates[i]
        if duplicate is not None:
            # Diagonal
            # Every repeated letter adds (first position, all positions).
            diag = []
            for _, v in duplicate.items():
                if len(v) == 1:
                    continue
                diag.append((v[0], tuple(v)))
            op = EinsumSubOp(fd, 'diagonal', op, diag=diag)
            op.compute_output_row(rows[1, :], mat[i, :], verbose=verbose)
            # The row updated by the diagonal operator drives
            # the alignment.
            tr_row = rows[1, :]
            marked = graph.append(op)
        else:
            diag = None
            tr_row = mat[i]

        for op in _apply_transpose_reshape(op, tr_row):
            op.compute_output_row(rows[1, :], verbose=verbose)
            marked = graph.append(op)

        # Reduction? (a dimension not used later)
        # A letter is reduced if no following row of mat (next inputs
        # or output) uses it, the current row has it and the last
        # result does not.
        red = []
        for d in range(0, mat.shape[1]):
            if (mat[i + 1:, d].max() == -1 and rows[1, d] != -1 and
                    rows[0, d] == -1):
                red.append(d)
        if len(red) > 0:
            if verbose:
                print(" -- REDUCE1 row=%d axes=%r" % (i, red))
                print(mat)
                print(' -')
                print(rows)
            op = EinsumSubOp(fd, 'reduce_sum',
                             graph.last_added_op, axes=tuple(red))
            op.compute_output_row(rows[1, :], verbose=verbose)
            marked = graph.append(op)

        if graph.last_op is not None:
            # Matrix multiplication?
            # A letter present in both rows is kept on both sides if
            # a following row of mat still uses it, otherwise it is
            # summed (common_dims). A letter present in only one row
            # goes to the corresponding side.
            common_dims = []
            left = []
            right = []
            for d in range(0, mat.shape[1]):
                if rows[:, d].min() >= 0:
                    if mat[i + 1:, d].max() >= 0:
                        left.append(d)
                        right.append(d)
                    else:
                        common_dims.append(d)
                else:
                    if rows[0, d] >= 0:
                        left.append(d)
                    if rows[1, d] >= 0:
                        right.append(d)
            if verbose:
                print(" -- MATMUL common_dims=%r" % common_dims)
                print(rows)
            for iop in _apply_einsum_matmul(
                    fd, graph.last_op, op, axes=tuple(common_dims),
                    left=tuple(left), right=tuple(right),
                    ndim=rows.shape[1], op_matmul=op_matmul,
                    row1=rows[0, :], row2=rows[1, :], verbose=verbose):
                op = iop
                op.compute_output_row(rows[0, :], rows[1, :],
                                      ab=True, verbose=verbose)
                marked = graph.append(op)

        # End
        # The current row becomes the last result for the next input.
        graph.mark(i, marked)
        rows[0, :] = rows[1, :]

    # Final output
    if verbose:
        print()
        print("######### FIN row=%r" % rows[1, :])

    if mat[len(shapes), :].max() >= 0:
        # The last row of mat is the output.
        rows[1, :] = mat[len(shapes), :]
        red = []
        for d in range(0, mat.shape[1]):
            # NOTE(review): the test is `> 0` whereas every other
            # presence test in this function is `>= 0` or `!= -1`,
            # confirm a letter stored as 0 in rows[0] must not be
            # reduced here.
            if rows[0, d] > 0 and rows[1, d] == -1:
                red.append(d)
            elif rows[0, d] == -1 and rows[1, d] >= 0:
                raise RuntimeError(  # pragma: no cover
                    "Issue in equation %r, variable %d, last_result is %r, "
                    "output is %r." % (equation, d, rows[0, :], rows[1, :]))
        if len(red) > 0:
            if verbose:  # pragma: no cover
                print("-- REDUCE2 axes=%r" % red)
                print(mat)
            op = EinsumSubOp(fd, 'reduce_sum', op, axes=tuple(red))
            graph.append(op)
            op.compute_output_row(rows[1, :], verbose=verbose)

        # Removes empty axes.
        for op in _apply_squeeze_transpose(op, rows[1, :], mat[len(shapes), :]):
            op.compute_output_row(rows[1, :], verbose=verbose)
            graph.append(op)
    return graph