Coverage for mlprodict/testing/einsum/einsum_impl_classes.py: 96%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# pylint: disable=C0302
2"""
3@file
4@brief Classes representing the sequence of matrix operations to
5implement einsum computation.
6"""
7import numpy
8from onnx import helper, numpy_helper
9from ...onnx_tools.onnx2py_helper import guess_proto_dtype
10from ...npy.xop_variable import guess_numpy_type
11from ... import __max_supported_opset__, get_ir_version
12from .blas_lapack import gemm_dot
13from .einsum_impl_ext import (
14 numpy_extended_dot, numpy_diagonal,
15 _numpy_extended_dot_equation,
16 numpy_extended_dot_python,
17 numpy_extended_dot_matrix)
def single_axes(axes):
    """
    *axes* contains positive values, then it is the position
    of this axis in the original matrix, otherwise it is -1
    meaning this axis is an added single dimension to align
    all the dimensions based on the einsum equation.

    :param axes: axes described above
    :return: list of integer in set `{1, 2}`, 1 for
        a single axis, 2 otherwise
    """
    if axes is None:
        return None
    return [1 if axis == -1 else 2 for axis in axes]
class EinsumSubOp:
    """
    Defines a sub operation used in Einsum decomposition.

    :param name: name (reshape, transpose, reduce_sum, matmul, id,
        squeeze, diagonal, mul, batch_dot)
    :param inputs: inputs
    :param kwargs: arguments

    Operator suffixed by `_mm` (*transpose_mm*, *reduce_sum_mm*)
    are equivalent to the same operator without the suffix
    but takes two inputs and only changes the first one.

    Attributes `_info` summarizes the known information
    about dimensions. Many of them are empty because inserted.
    Value `1` means it was the case, `2` means it is a plain dimension.
    """
    # Operator names accepted by the constructor; any other name
    # makes __init__ raise ValueError.
    _allowed = {'expand_dims', 'transpose', 'reduce_sum', 'matmul', 'id',
                'squeeze', 'diagonal', 'mul', 'batch_dot',
                'transpose_mm', 'reduce_sum_mm'}
57 def __init__(self, full_dim, name, *inputs, **kwargs):
58 self.full_dim = full_dim
59 self.name = name
60 self.inputs = inputs
61 self.kwargs = kwargs
62 self._info = {}
63 if name not in EinsumSubOp._allowed:
64 raise ValueError(
65 "Unexpected name %r. It should be in %r."
66 "" % (name, EinsumSubOp._allowed))
67 if len(inputs) not in (1, 2):
68 raise RuntimeError(
69 "Inputs must contains 1 or 2 inputs not %d." % len(inputs))
70 if name == 'matmul' and len(inputs) != 2:
71 raise RuntimeError(
72 "Inputs must contains 2 inputs not %d for operator 'matmul'."
73 "" % len(inputs))
74 for i, inp in enumerate(inputs):
75 if not isinstance(inp, (int, EinsumSubOp)):
76 raise TypeError(
77 "Input %d has type %r, int or EinsumSubOp is expected."
78 "" % (i, type(inp)))
79 self._check_()
    def _check_(self):
        """
        Validates operator-specific constraints on *kwargs*
        (only *transpose* and *matmul* carry constraints here).
        """
        if self.name == 'transpose':
            self._check_arg_('perm', tuple)
            perm = self.kwargs['perm']
            # A permutation must not repeat an axis.
            if len(perm) != len(set(perm)):
                raise RuntimeError(  # pragma: no cover
                    "perm has duplicated values %r (name=%r)."
                    "" % (perm, self.name))
            # An identity transpose is useless and should have been
            # removed by the decomposition.
            if list(perm) == list(range(len(perm))):
                raise ValueError(  # pragma: no cover
                    "Transpose = identity perm={}. It must be removed."
                    "".format(perm))
        elif self.name == 'matmul':
            self._check_arg_('axes', tuple)
            self._check_arg_('left', tuple)
            self._check_arg_('right', tuple)
            axes = self.kwargs['axes']
            left = self.kwargs['left']
            right = self.kwargs['right']
            # An axis cannot be summed over (axes) and kept on both sides.
            for a in axes:
                if a in left and a in right:
                    raise RuntimeError(  # pragma: no cover
                        "One axis belongs to every set (axes, left, right). "
                        "axes=%r, left=%r, right=%r." % (axes, left, right))
106 def __repr__(self):
107 inps = ", ".join(map(str, self.inputs))
108 kw = ", ".join("%s=%r" % (k, w) for k, w in self.kwargs.items())
109 m = "%s(%r, %s, %s)" % (
110 self.__class__.__name__, self.name, inps, kw)
111 return m
113 def dot_label(self):
114 """
115 Displays some informations useful to understand the operator.
116 """
117 if self.name == "matmul":
118 ndim = self.kwargs['ndim']
119 axes = self.kwargs['axes']
120 left = self.kwargs['left']
121 right = self.kwargs['right']
122 eq = _numpy_extended_dot_equation(ndim, ndim, axes, left, right)
123 eq = eq.replace(">", "\\\\>")
124 return "~" + eq
125 return None
127 def _check_arg_(self, name, typ, empty=False):
128 if name not in self.kwargs:
129 raise RuntimeError( # pragma: no cover
130 "Parameter %r not found for operator %r." % (name, self.name))
131 if empty and self.kwargs[name] is None:
132 return
133 if not isinstance(self.kwargs[name], typ):
134 raise TypeError( # pragma: no cover
135 "Unexpected type %r for parameter %r and parameter %r."
136 "" % (type(self.kwargs[name]), name, self.name))
138 def _check_row_(self, row, inp=False, verbose=False):
139 """
140 Checks input or output is valid.
141 """
142 if verbose:
143 if inp:
144 print('<<' if inp else '>>', self.name, row, self.kwargs)
145 else:
146 print('<<' if inp else '>>', self.name, row)
148 def _compute_output_row_id(self, row, row2=None, ab=False, verbose=False):
149 if ab:
150 raise RuntimeError("ab option not allowed.") # pragma: no cover
151 self._check_row_(row, True, verbose=verbose)
152 row[:] = row2[:]
153 self._check_row_(row, verbose=verbose)
155 def _compute_output_row_transpose(self, row, row2=None, ab=False, verbose=False):
156 if ab:
157 self._compute_output_row_transpose(row2, verbose=verbose)
158 return
159 self._check_row_(row, True, verbose=verbose)
160 self._check_arg_('perm', tuple)
161 if len(self.kwargs['perm']) != len(row):
162 raise RuntimeError( # pragma: no cover
163 "Unexpected permutation %r (row=%r)."
164 "" % (self.kwargs['perm'], row))
165 perm = self.kwargs['perm']
166 cpy = row.copy()
167 for i, p in enumerate(perm):
168 row[i] = cpy[p]
169 self._check_row_(row, verbose=verbose)
    def _compute_output_row_transpose_mm(self, row, row2=None, ab=False, verbose=False):
        """
        Binary flavour of transpose: requires two inputs but only
        permutes the first row, the second is left untouched.
        """
        if not ab:
            raise RuntimeError("ab must be True.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        if row2 is None:
            raise RuntimeError(  # pragma: no cover
                "transpose_mm expects a second input.")
        # Delegates to the unary version on the first row only.
        self._compute_output_row_transpose(row, row2=None, verbose=verbose)
180 def _compute_output_row_expand_dims(self, row, row2=None, ab=False, verbose=False):
181 if ab:
182 raise RuntimeError("ab option not allowed.") # pragma: no cover
183 self._check_row_(row, True, verbose=verbose)
184 self._check_arg_('axes', tuple)
185 axes = self.kwargs['axes']
186 for axis in axes:
187 if not isinstance(axis, tuple):
188 raise TypeError( # pragma: no cover
189 "Parameter axes of expand_dims should be a tuple of "
190 "tuple, axes=%r." % axes)
191 if row[axis[1]] != -1:
192 raise RuntimeError( # pragma: no cover
193 "Dimension should be -1 in row %r axis=%r." % (
194 row, self.kwargs['axis']))
195 self._check_row_(row, verbose=verbose)
197 def _compute_output_row_reduce_sum(self, row, row2=None, ab=False, verbose=False):
198 if ab:
199 raise RuntimeError("ab option not allowed.") # pragma: no cover
200 self._check_row_(row, True, verbose=verbose)
201 self._check_arg_('axes', tuple)
202 for a in self.kwargs['axes']:
203 row[a] = -1
204 self._check_row_(row, verbose=verbose)
    def _compute_output_row_reduce_sum_mm(self, row, row2=None, ab=False, verbose=False):
        """
        Binary flavour of reduce_sum: requires two inputs but only
        reduces the first row.
        """
        if not ab:
            raise RuntimeError("ab must be true.")  # pragma: no cover
        # NOTE(review): row2 is traced before the None check below;
        # harmless since _check_row_ only prints when verbose.
        self._check_row_(row2, True, verbose=verbose)
        if row2 is None:
            raise RuntimeError(  # pragma: no cover
                "reduce_sum_mm expects a second input.")
        # Delegates to the unary version on the first row only.
        self._compute_output_row_reduce_sum(row, row2=None, verbose=verbose)
215 def _compute_output_row_squeeze(self, row, row2=None, ab=False, verbose=False):
216 if ab:
217 raise RuntimeError("ab option not allowed.") # pragma: no cover
218 self._check_row_(row, True, verbose=verbose)
219 self._check_arg_('axes', tuple)
220 for a in self.kwargs['axes']:
221 row[a] = -1
222 self._check_row_(row, verbose=verbose)
    def _compute_output_row_diagonal(self, row, row2=None, ab=False, verbose=False):
        """
        Operator diagonal: merges duplicated indices of *row* into one
        (*choice*), then renumbers the remaining indices so they stay
        contiguous.
        """
        if ab:
            raise RuntimeError("ab option not allowed.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        self._check_arg_('diag', list)
        to_remove = []
        for choice, choices in self.kwargs['diag']:
            # Every index in *choices* other than *choice* disappears.
            for ch in choices:
                if ch != choice:
                    to_remove.append(ch)
            # Remap every occurrence in *row* to the kept index.
            for i in range(len(row)):  # pylint: disable=C0200
                if row[i] in choices:
                    if row[i] != choice:
                        row[i] = choice
        to_remove.sort()
        # Shift remaining indices down to fill the holes left by the
        # removed ones; a removed index must no longer appear in *row*.
        for r in to_remove:
            for i in range(len(row)):  # pylint: disable=C0200
                if row[i] == r:
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected result r=%r row=%r to_remove=%r "
                        "diag=%r." % (
                            r, row, to_remove, self.kwargs['diag']))
                if row[i] > r:
                    row[i] -= 1
        self._check_row_(row, verbose=verbose)
    def _compute_output_row_matmul(self, row, row2=None, ab=False, verbose=False):
        """
        Operator matmul (binary): the output row is written into *row2*,
        summed axes not kept on the right side become -1.
        """
        if not ab:
            raise RuntimeError("ab must be True.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        self._check_row_(row2, True, verbose=verbose)
        self._check_arg_('axes', tuple)
        self._check_arg_('left', tuple)
        self._check_arg_('right', tuple)
        self._check_arg_('ndim', int)
        if row2 is None:
            raise RuntimeError(  # pragma: no cover
                "matmul expects two inputs.")
        if verbose:
            ndim = self.kwargs['ndim']
            axes = self.kwargs['axes']
            left = self.kwargs['left']
            right = self.kwargs['right']
            print(" MATMUL %r @ %r axes=%r left=%r right=%r - eq=%s" % (
                row, row2, axes, left, right,
                _numpy_extended_dot_equation(ndim, ndim, axes, left, right)))
        # Element-wise maximum keeps the known (non -1) dimension index.
        row2[:] = numpy.maximum(row, row2)
        for a in self.kwargs['axes']:
            if a not in self.kwargs['right']:
                row2[a] = -1
        self._check_row_(row2, verbose=verbose)
    def _compute_output_row_batch_dot(self, row, row2=None, ab=False, verbose=False):
        """
        Operator batch_dot (binary): the output row is written into
        *row2*, summed axes not kept on the right side become -1.
        """
        if not ab:
            raise RuntimeError("ab must be True.")  # pragma: no cover
        self._check_row_(row, True, verbose=verbose)
        self._check_row_(row2, True, verbose=verbose)
        self._check_arg_('batch_axes', tuple)
        # keep_axes may be None (empty=True).
        self._check_arg_('keep_axes', tuple, empty=True)
        self._check_arg_('sum_axes', tuple)
        self._check_arg_('left', tuple)
        self._check_arg_('right', tuple)
        self._check_arg_('ndim', int)
        if row2 is None:
            raise RuntimeError(
                "batch_dot expects two inputs.")  # pragma: no cover
        if verbose:
            batch_axes = self.kwargs['batch_axes']
            keep_axes = self.kwargs['keep_axes']
            sum_axes = self.kwargs['sum_axes']
            left = self.kwargs['left']
            right = self.kwargs['right']
            ndim = self.kwargs['ndim']
            print(" BATCH_DOT batch_axes=%r keep_axes=%r sum_axes=%r "
                  "left=%r right=%r eq=%r" % (
                      batch_axes, keep_axes, sum_axes, left, right,
                      _numpy_extended_dot_equation(ndim, ndim, sum_axes, left, right)))
        # Element-wise maximum keeps the known (non -1) dimension index.
        row2[:] = numpy.maximum(row, row2)
        for a in self.kwargs['sum_axes']:
            if a not in self.kwargs['right']:
                row2[a] = -1
        self._check_row_(row2, verbose=verbose)
307 def _compute_output_row_mul(self, row, row2=None, ab=False, verbose=False):
308 if not ab:
309 raise RuntimeError("ab must be True.") # pragma: no cover
310 self._check_row_(row, True, verbose=verbose)
311 self._check_row_(row2, True, verbose=verbose)
312 if row2 is None:
313 raise RuntimeError("mul expects two inputs.") # pragma: no cover
314 if verbose:
315 print( # pragma: no cover
316 " MUL %r @ %r" % (row, row2))
317 row2[:] = numpy.maximum(row, row2)
318 self._check_row_(row2, verbose=verbose)
    def compute_output_row(self, row, row2=None, ab=False, verbose=False):
        """
        Updates *row* based on the operator.

        Dispatches to ``_compute_output_row_<name>`` and records the
        input/output rows (through :func:`single_axes`) into ``_info``.
        """
        method_name = "_compute_output_row_%s" % self.name
        meth = getattr(self, method_name, None)
        if meth is None:
            raise NotImplementedError(  # pragma: no cover
                "compute_output_row not implemented for %r." % self.name)
        if verbose and ab:
            print(" -- called as a binary operator")
        self.add_info(i_row=single_axes(row), i_row2=single_axes(row2))
        meth(row, row2=row2, ab=ab, verbose=verbose)
        self.add_info(o_row=single_axes(row), o_row2=single_axes(row2))
335 def add_info(self, **kwargs):
336 """
337 Adds information to the node.
339 :param kwargs: dictionary
340 """
341 for k, v in kwargs.items():
342 if k in self._info:
343 raise KeyError( # pragma: no cover
344 "Key %r already added (operator %r)." % (k, self.name))
345 self._info[k] = v
    def _check_inputs_(self, n_expected, check_dim=False):
        """
        Raises if the operator does not have *n_expected* inputs.
        *check_dim* is currently unused — kept for signature
        compatibility with the call sites.
        """
        if len(self.inputs) != n_expected:
            raise RuntimeError(  # pragma: no cover
                "Number of inputs must be %d not %d for operator %r."
                "" % (n_expected, len(self.inputs), self.name))
353 def _check_shape_(self, m):
354 if len(m.shape) != self.full_dim:
355 raise RuntimeError( # pragma: no cover
356 "Number of dimensions %r is different from expected value "
357 "%d." % (m.shape, self.full_dim))
359 def _get_data(self, data, key):
360 if isinstance(key, int):
361 if key not in data:
362 raise RuntimeError( # pragma: no cover
363 "Unable to find key %d in %r." % (
364 key, list(sorted(data))))
365 return data[key]
366 if isinstance(key, EinsumSubOp):
367 if id(key) not in data:
368 raise RuntimeError( # pragma: no cover
369 "Unable to find key %d in %r." % (
370 id(key), list(sorted(data))))
371 return data[id(key)]
372 raise TypeError( # pragma: no cover
373 "Unexpected input type %r." % type(key))
375 def _apply_id(self, data, verbose=False, **kwargs):
376 self._check_inputs_(1)
377 inp = self.inputs[0]
378 output = self._get_data(data, inp)
379 return output
    def _apply_diagonal(self, data, verbose=False, **kwargs):
        """
        Extracts a diagonal with :func:`numpy_diagonal`; only one
        group of duplicated indices is supported.
        """
        self._check_inputs_(1)
        inp = self.inputs[0]
        m = self._get_data(data, inp)
        if verbose:
            print(  # pragma: no cover
                "- %s, shape=%r diag=%r" % (
                    self.name, m.shape, self.kwargs['diag']))
        diag = self.kwargs['diag']
        if len(diag) != 1:
            raise NotImplementedError(  # pragma: no cover
                "Not implemented with more than one duplicated indice "
                "%r." % diag)
        # diag[0] is a pair (axis, axes) forwarded to numpy_diagonal.
        diag0 = diag[0]
        output = numpy_diagonal(m, axis=diag0[0], axes=diag0[1])
        return output
398 def _apply_expand_dims(self, data, verbose=False, **kwargs):
399 self._check_inputs_(1)
400 inp = self.inputs[0]
401 m = self._get_data(data, inp)
402 if verbose:
403 print("- %s, shape=%r axes=%r" % (
404 self.name, m.shape, self.kwargs['axes']))
405 output = m
406 for axis in reversed(self.kwargs['axes']):
407 output = numpy.expand_dims(output, axis[0])
408 return output
410 def _apply_transpose(self, data, verbose=False, **kwargs):
411 self._check_inputs_(1, True)
412 inp = self.inputs[0]
413 m = self._get_data(data, inp)
414 self._check_shape_(m)
415 if verbose:
416 print("- %s, shape=%r perm=%r" % (
417 self.name, m.shape, self.kwargs['perm']))
418 output = numpy.transpose(m, self.kwargs['perm'])
419 self._check_shape_(output)
420 return output
422 def _apply_transpose_mm(self, data, verbose=False, **kwargs):
423 self._check_inputs_(2, True)
424 inp = self.inputs[0]
425 m = self._get_data(data, inp)
426 self._check_shape_(m)
427 if verbose:
428 print( # pragma: no cover
429 "- %s, shape=%r perm=%r" % (
430 self.name, m.shape, self.kwargs['perm']))
431 output = numpy.transpose(m, self.kwargs['perm'])
432 self._check_shape_(output)
433 return output
    def _apply_matmul(self, data, verbose=False, **kwargs):
        """
        Extended dot product between the two inputs; the implementation
        is selected by ``kwargs['matmul_impl']``: ``'pyf'`` uses
        :func:`numpy_extended_dot_matrix`, ``'py'`` uses
        :func:`numpy_extended_dot_python`, ``None`` (default) uses
        :func:`numpy_extended_dot`.
        """
        self._check_inputs_(2)
        inp1 = self.inputs[0]
        inp2 = self.inputs[1]
        m1 = self._get_data(data, inp1)
        m2 = self._get_data(data, inp2)
        self._check_shape_(m1)
        self._check_shape_(m2)
        axes = self.kwargs['axes']
        left = self.kwargs['left']
        right = self.kwargs['right']

        if verbose:
            print("- %s, shapes=%r @ %r axes=%r left=%r right=%r" % (
                self.name, m1.shape, m2.shape, axes, left, right))

        impl = kwargs.get('matmul_impl', None)
        if impl == 'pyf':
            output = numpy_extended_dot_matrix(m1, m2, axes, left, right,
                                               verbose=verbose)
        elif impl == 'py':
            output = numpy_extended_dot_python(m1, m2, axes, left, right,
                                               verbose=verbose)
        elif impl is None:
            output = numpy_extended_dot(m1, m2, axes, left, right,
                                        verbose=verbose)
        else:
            raise ValueError(
                "Unknown implementation of numpy_extended_dot ({}).".format(impl))
        self._check_shape_(output)
        return output
467 def _apply_mul(self, data, verbose=False, **kwargs):
468 self._check_inputs_(2)
469 inp1 = self.inputs[0]
470 inp2 = self.inputs[1]
471 m1 = self._get_data(data, inp1)
472 m2 = self._get_data(data, inp2)
473 self._check_shape_(m1)
474 self._check_shape_(m2)
476 if verbose:
477 print( # pragma: no cover
478 "- %s, shapes=%r @ %r" % (self.name, m1.shape, m2.shape))
480 output = m1 * m2
481 self._check_shape_(output)
482 return output
    def _apply_batch_dot(self, data, verbose=False, **kwargs):
        """
        Batched dot product: both inputs are reshaped into 3D tensors
        (batch, kept, summed), multiplied, then reshaped back to
        ``full_dim`` dimensions (padded with trailing 1s).
        """
        self._check_inputs_(2)
        inp1 = self.inputs[0]
        inp2 = self.inputs[1]
        m1 = self._get_data(data, inp1)
        m2 = self._get_data(data, inp2)
        self._check_shape_(m1)
        self._check_shape_(m2)
        batch_axes = self.kwargs['batch_axes']
        keep_axes = self.kwargs['keep_axes']
        sum_axes = self.kwargs['sum_axes']
        left = self.kwargs['left']
        right = self.kwargs['right']

        if verbose:
            print("- %s, shapes=%r @ %r batch_axes=%r keep_axes=%r "
                  "sum_axes=%r" % (
                      self.name, m1.shape, m2.shape, batch_axes, keep_axes, sum_axes))

        if len(m1.shape) != len(m2.shape):
            raise RuntimeError(  # pragma: no cover
                "batch_dot only work with two tensors with the same number "
                "of dimensions not %r @ %r." % (m1.shape, m2.shape))

        # Collapse batch, kept and summed axes into single dimensions.
        dim0 = int(numpy.prod([m1.shape[i] for i in batch_axes]))
        dim0b = int(numpy.prod([m2.shape[i] for i in batch_axes]))
        dimb = int(-1 if keep_axes is None else numpy.prod(
            [m1.shape[i] for i in keep_axes]))
        dim1 = int(numpy.prod([m1.shape[i] for i in sum_axes]))
        dim2 = int(numpy.prod([m2.shape[i] for i in sum_axes]))

        if verbose:
            print("- %s, reshape=%r into %r" % (
                self.name, m1.shape, (dim0, dimb, dim1)))
            print("- %s, reshape=%r into %r" % (
                self.name, m2.shape, (dim0b, dimb, dim2)))
        m1sh = m1.reshape((dim0, dimb, dim1))
        m2sh = m2.reshape((dim0b, dimb, dim2))

        batch_kind = self.get_dot_kind()
        # NOTE(review): 'N1' appears twice in this tuple; '1N' was
        # probably intended — confirm against get_dot_kind.
        if batch_kind in ('11', 'N1', 'N1'):
            # Degenerate batch: a plain 2D GEMM is enough.
            m1sh = m1sh.reshape((-1, m1sh.shape[-1]))
            m2sh = m2sh.reshape((-1, m2sh.shape[-1]))
            if verbose:
                print("- %s, use gemm with shape %r, %r" % (
                    self.name, m1sh.shape, m2sh.shape))
            dot = gemm_dot(m1sh, m2sh, False, True)
        else:
            dot = m1sh @ numpy.transpose(m2sh, (0, 2, 1))

        # new shape
        new_shape = ([max(m1.shape[i], m2.shape[i]) for i in batch_axes] +
                     [m1.shape[i] for i in left if i not in batch_axes] +
                     [m2.shape[i] for i in right if i not in batch_axes])
        # Pad with trailing 1s so the result keeps full_dim dimensions.
        while len(new_shape) < len(m1.shape):
            new_shape.append(1)

        if verbose:
            taken = set(batch_axes) | set(sum_axes)
            ax = [i for i in range(len(m1.shape)) if i not in taken]
            print("- %s, shapes=%r @ %r -> %r" % (
                self.name, m1sh.shape, m2sh.shape, dot.shape))
            print("- %s, batch_axes=%r ax=%r new_shape=%r left=%r right=%r" % (
                self.name, batch_axes, ax, new_shape, left, right))

        output = dot.reshape(tuple(new_shape))
        self._check_shape_(output)
        return output
553 def _apply_reduce_sum(self, data, verbose=False, **kwargs):
554 self._check_inputs_(1)
555 inp = self.inputs[0]
556 m = self._get_data(data, inp)
557 self._check_shape_(m)
558 axes = self.kwargs['axes']
559 if verbose:
560 print("- %s, shape=%r axes=%r" % (
561 self.name, m.shape, self.kwargs['axes']))
562 output = numpy.sum(m, axis=axes, keepdims=True)
563 self._check_shape_(output)
564 return output
566 def _apply_reduce_sum_mm(self, data, verbose=False, **kwargs):
567 self._check_inputs_(2, True)
568 inp = self.inputs[0]
569 m = self._get_data(data, inp)
570 self._check_shape_(m)
571 if verbose:
572 print("- %s, shape=%r axes=%r" % (
573 self.name, m.shape, self.kwargs['axes']))
574 output = numpy.sum(m, self.kwargs['axes'])
575 self._check_shape_(output)
576 return output
578 def _apply_squeeze(self, data, verbose=False, **kwargs):
579 self._check_inputs_(1)
580 inp = self.inputs[0]
581 m = self._get_data(data, inp)
582 axes = self.kwargs['axes']
583 if verbose:
584 print("- %s, shape=%r axes=%r" % (
585 self.name, m.shape, self.kwargs['axes']))
586 output = m
587 for a in axes[::-1]:
588 output = numpy.squeeze(output, axis=a)
589 return output
    def apply(self, data, verbose=False, **kwargs):
        """
        Applies one operator on the data.

        :param data: dictionary storing the results
        :param verbose: prints out intermediate results
        :param kwargs: additional parameters, see
            methods `_apply*`
        :return: output

        Known additional parameters:

        * 'matmul_impl': if None calls :epkg:`numpy:einsum` through
          @see fn numpy_extended_dot (default) or 'py' to call
          @see fn numpy_extended_dot_python instead.
        """
        if verbose:
            print()
            print("apply %r (%s)." % (
                self.name, ", ".join(map(lambda s: str(id(s)), self.inputs))))

        # Dispatches to _apply_<name>.
        method_name = "_apply_%s" % self.name
        meth = getattr(self, method_name, None)
        if meth is None:
            raise NotImplementedError(  # pragma: no cover
                "apply not implemented for %r." % self.name)
        output = meth(data, verbose, **kwargs)

        # The result is stored under the id of this node so that the
        # following operators can retrieve it.
        data[id(self)] = output
        if verbose:
            print("+ %s, shape=%r -- %d" % (self.name, output.shape, id(self)))
        return output
624 def _onnx_name(self):
625 return 'einsum%d_%s' % (id(self), self.name[:2])
627 def _check_onnx_opset_(self, opset, limit):
628 if opset is not None and opset < limit:
629 raise RuntimeError( # pragma: no cover
630 "Opset (%r) must be >= %r for operator %r."
631 "" % (opset, limit, self.name))
    def _to_onnx_id(self, names, opset, verbose=False, **kwargs):
        """ONNX conversion of operator id: a single Identity node."""
        self._check_inputs_(1)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        yield helper.make_node('Identity', [name], [self._onnx_name()])
    def _to_onnx_expand_dims(self, names, opset, verbose=False, **kwargs):
        """
        ONNX conversion of operator expand_dims: an Unsqueeze node
        with its axes provided as an initializer (opset >= 11 required).
        """
        self._check_inputs_(1)
        self._check_onnx_opset_(opset, 11)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        axes = self.kwargs['axes']
        name_axes = name + '_axes'
        # axes is a tuple of pairs; the second item of each pair is the
        # axis index to insert.
        yield numpy_helper.from_array(
            numpy.array([a[1] for a in axes], dtype=numpy.int64), name=name_axes)
        s_axes = "".join(map(str, [a[1] for a in axes]))
        yield helper.make_node(
            'Unsqueeze', [name, name_axes], [self._onnx_name()],
            name='Unsqueeze%s_%d' % (s_axes, id(self)))
    def _to_onnx_squeeze(self, names, opset, verbose=False, **kwargs):
        """
        ONNX conversion of operator squeeze: a Squeeze node with its
        axes provided as an initializer (opset >= 11 required).
        """
        self._check_inputs_(1)
        self._check_onnx_opset_(opset, 11)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        axes = self.kwargs['axes']
        # NOTE(review): the initializer name is derived from the input
        # name, unlike _to_onnx_reduce_sum which derives it from the
        # output name — could collide if the same input is squeezed
        # twice; confirm.
        name_axes = name + '_axes'
        yield numpy_helper.from_array(
            numpy.array(axes, dtype=numpy.int64), name=name_axes)
        s_axes = "".join(map(str, axes))
        yield helper.make_node(
            'Squeeze', [name, name_axes], [self._onnx_name()],
            name='Squeeze%s_%d' % (s_axes, id(self)))
    def _to_onnx_transpose(self, names, opset, verbose=False, **kwargs):
        """ONNX conversion of operator transpose: a single Transpose node."""
        self._check_inputs_(1)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        perm = self.kwargs['perm']
        s_perm = "".join(map(str, perm))
        yield helper.make_node(
            'Transpose', [name], [self._onnx_name()], perm=perm,
            name='Transpose%s_%d' % (s_perm, id(self)))
    def _to_onnx_reduce_sum(self, names, opset, verbose=False, **kwargs):
        """
        ONNX conversion of operator reduce_sum: a ReduceSum node with
        keepdims=1 and its axes provided as an initializer
        (opset >= 11 required).
        """
        self._check_inputs_(1)
        self._check_onnx_opset_(opset, 11)
        inp = self.inputs[0]
        name = self._get_data(names, inp)
        axes = self.kwargs['axes']
        name_axes = self._onnx_name() + '_axes'
        yield numpy_helper.from_array(
            numpy.array(axes, dtype=numpy.int64), name=name_axes)
        s_axes = "".join(map(str, axes))
        yield helper.make_node(
            'ReduceSum', [name, name_axes], [self._onnx_name()], keepdims=1,
            name='ReduceSum%s_%d' % (s_axes, id(self)))
    def _to_onnx_mul(self, data, verbose=False, **kwargs):
        """
        ONNX conversion of operator mul: a single Mul node.

        NOTE(review): the signature differs from the other _to_onnx_*
        methods (no positional *opset*); *opset* presumably arrives
        through **kwargs — confirm against the dispatcher in to_onnx.
        """
        self._check_inputs_(2)
        inp1 = self.inputs[0]
        inp2 = self.inputs[1]
        m1 = self._get_data(data, inp1)
        m2 = self._get_data(data, inp2)
        yield helper.make_node('Mul', [m1, m2], [self._onnx_name()])
    def _to_onnx_batch_dot(self, names, opset, verbose=False, **kwargs):  # pylint: disable=R0914
        """
        ONNX conversion of operator batch_dot (opset >= 13 required).
        Mirrors :meth:`_apply_batch_dot`: both inputs are reshaped to
        3D (batch, kept, summed) with dynamically computed shapes,
        multiplied (Gemm or MatMul), then reshaped to the final shape.
        All nodes are yielded one by one.
        """
        self._check_inputs_(2)
        self._check_onnx_opset_(opset, 13)
        inp1, inp2 = self.inputs[:2]  # pylint: disable=W0632
        name1 = self._get_data(names, inp1)
        name2 = self._get_data(names, inp2)

        batch_axes = self.kwargs['batch_axes']
        keep_axes = self.kwargs['keep_axes']
        sum_axes = self.kwargs['sum_axes']
        left = self.kwargs['left']
        right = self.kwargs['right']
        root = self._onnx_name()

        def return_name_one():
            # Shared constant [1], created lazily by the branches which
            # need a placeholder dimension.
            name_one = root + "_1"
            return name_one, numpy_helper.from_array(
                numpy.array([1], dtype=numpy.int64), name=name_one)

        name_one = None
        name_shape1 = root + "_shape1"
        name_shape2 = root + "_shape2"
        # concat_left/right accumulate the dimension results forming the
        # 3D reshape targets of each input.
        concat_left = []
        concat_right = []
        yield helper.make_node('Shape', [name1], [name_shape1])
        yield helper.make_node('Shape', [name2], [name_shape2])

        if len(batch_axes) > 0:
            name_batch_axes = root + "_batch_axes"
            yield numpy_helper.from_array(
                numpy.array(batch_axes, dtype=numpy.int64), name=name_batch_axes)

        if len(sum_axes) > 0:
            name_sum_axes = root + "_sum_axes"
            yield numpy_helper.from_array(
                numpy.array(sum_axes, dtype=numpy.int64), name=name_sum_axes)

        # dim0 = int(numpy.prod([m1.shape[i] for i in batch_axes]))
        # dim0b = int(numpy.prod([m2.shape[i] for i in batch_axes]))
        if len(batch_axes) > 1:
            # Several batch axes: gather them then multiply (ReduceProd).
            name_dim0 = root + "_dim0"
            name_dim0b = root + "_dim0b"
            name_dim0g = name_dim0 + 'g'
            name_dim0bg = name_dim0b + 'g'
            concat_left.append(name_dim0)
            concat_right.append(name_dim0b)
            yield helper.make_node(
                'Gather', [name_shape1, name_batch_axes], [name_dim0g])
            yield helper.make_node(
                'Gather', [name_shape2, name_batch_axes], [name_dim0bg])
            yield helper.make_node(
                'ReduceProd', [name_dim0g], [name_dim0], keepdims=1)
            yield helper.make_node(
                'ReduceProd', [name_dim0bg], [name_dim0b], keepdims=1)
        elif len(batch_axes) == 1:
            # A single batch axis: the gathered value is the dimension.
            name_dim0g = root + "_dim0g"
            name_dim0bg = root + "_dim0bg"
            name_dim0 = name_dim0g
            name_dim0b = name_dim0bg
            concat_left.append(name_dim0)
            concat_right.append(name_dim0b)
            yield helper.make_node(
                'Gather', [name_shape1, name_batch_axes], [name_dim0g])
            yield helper.make_node(
                'Gather', [name_shape2, name_batch_axes], [name_dim0bg])
        else:
            # No batch axis: use constant 1.
            if name_one is None:
                name_one, cst_init = return_name_one()
                yield cst_init
            name_dim0 = name_one
            name_dim0b = name_one
            concat_left.append(name_dim0)
            concat_right.append(name_dim0b)

        # dimb = int(-1 if keep_axes is None else numpy.prod(
        #     [m1.shape[i] for i in keep_axes]))
        if keep_axes in (-1, None) or len(keep_axes) == 0:
            # Unknown kept dimension: let Reshape infer it with -1.
            name_dimb = root + "__1"
            concat_left.append(name_dimb)
            concat_right.append(name_dimb)
            yield numpy_helper.from_array(
                numpy.array([-1], dtype=numpy.int64), name=name_dimb)
        elif len(keep_axes) == 1:
            name_keep_axes = root + "_keep_axes"
            name_dimb = root + "_dimb"
            name_dimbg = name_dimb
            concat_left.append(name_dimb)
            concat_right.append(name_dimb)
            yield numpy_helper.from_array(
                numpy.array(keep_axes, dtype=numpy.int64), name=name_keep_axes)
            yield helper.make_node(
                'Gather', [name_shape1, name_keep_axes], [name_dimbg])
        else:
            name_keep_axes = root + "_keep_axes"
            name_dimb = root + "_dimb"
            name_dimbg = name_dimb + 'g'
            concat_left.append(name_dimb)
            concat_right.append(name_dimb)
            yield numpy_helper.from_array(
                numpy.array(keep_axes, dtype=numpy.int64), name=name_keep_axes)
            yield helper.make_node(
                'Gather', [name_shape1, name_keep_axes], [name_dimbg])
            yield helper.make_node(
                'ReduceProd', [name_dimbg], [name_dimb], keepdims=1)

        # dim1 = int(numpy.prod([m1.shape[i] for i in sum_axes]))
        # dim2 = int(numpy.prod([m2.shape[i] for i in sum_axes]))

        if len(sum_axes) == 0:
            # Nothing summed: use constant 1.
            if name_one is None:
                name_one, cst_init = return_name_one()
                yield cst_init
            name_dim1 = name_one
            name_dim2 = name_one
            concat_left.append(name_dim1)
            concat_right.append(name_dim2)
        elif len(sum_axes) == 1:
            name_dim1 = root + "_dim1"
            name_dim2 = root + "_dim2"
            name_dim1g = name_dim1
            name_dim2g = name_dim2
            concat_left.append(name_dim1)
            concat_right.append(name_dim2)
            yield helper.make_node(
                'Gather', [name_shape1, name_sum_axes], [name_dim1g])
            yield helper.make_node(
                'Gather', [name_shape2, name_sum_axes], [name_dim2g])
        else:
            name_dim1 = root + "_dim1"
            name_dim2 = root + "_dim2"
            name_dim1g = name_dim1 + 'g'
            name_dim2g = name_dim2 + 'g'
            concat_left.append(name_dim1)
            concat_right.append(name_dim2)
            yield helper.make_node(
                'Gather', [name_shape1, name_sum_axes], [name_dim1g])
            yield helper.make_node(
                'Gather', [name_shape2, name_sum_axes], [name_dim2g])
            yield helper.make_node(
                'ReduceProd', [name_dim1g], [name_dim1], keepdims=1)
            yield helper.make_node(
                'ReduceProd', [name_dim2g], [name_dim2], keepdims=1)

        batch_kind = self.get_dot_kind()
        # NOTE(review): 'N1' appears twice in this tuple; '1N' was
        # probably intended — confirm against get_dot_kind.
        if batch_kind in ('11', 'N1', 'N1'):
            # Degenerate batch: a 2D Gemm is enough.
            # *shape1, *shape2
            name_minus_one = root + "__01"
            yield numpy_helper.from_array(
                numpy.array([-1], dtype=numpy.int64), name=name_minus_one)
            name_agg_shape1_2 = root + "_resh1_%s" % batch_kind
            name_agg_shape2_2 = root + "_resh2_%s" % batch_kind
            yield helper.make_node(
                'Concat', [name_minus_one, name_dim1], [name_agg_shape1_2], axis=0)
            yield helper.make_node(
                'Concat', [name_minus_one, name_dim2], [name_agg_shape2_2], axis=0)

            # m1sh = m1.reshape((-1, dim1))
            # m2sh = m2.reshape((-1, dim2))
            name_agg1_2 = root + "_aresh1"
            name_agg2_2 = root + "_aresh2"
            yield helper.make_node('Reshape', [name1, name_agg_shape1_2], [name_agg1_2])
            yield helper.make_node('Reshape', [name2, name_agg_shape2_2], [name_agg2_2])

            # dot = gemm(m1sh, m2sh, False, True)
            name_dot = root + "_gemm"
            yield helper.make_node(
                'Gemm', [name_agg1_2, name_agg2_2], [name_dot],
                alpha=1., beta=0., transA=0, transB=1)
        else:
            # *shape1, *shape2
            name_agg_shape1 = root + "_resh1"
            name_agg_shape2 = root + "_resh2"
            yield helper.make_node(
                'Concat', concat_left, [name_agg_shape1], axis=0)
            yield helper.make_node(
                'Concat', concat_right, [name_agg_shape2], axis=0)

            # m1sh = m1.reshape((dim0, dimb, dim1))
            # m2sh = m2.reshape((dim0b, dimb, dim2))
            name_agg1 = root + "_aresh1"
            name_agg2 = root + "_aresh2"
            yield helper.make_node('Reshape', [name1, name_agg_shape1], [name_agg1])
            yield helper.make_node('Reshape', [name2, name_agg_shape2], [name_agg2])

            # dot = m1sh @ numpy.transpose(m2sh, (0, 2, 1))
            name_agg2_tr = root + "_aresh2_tr"
            yield helper.make_node(
                'Transpose', [name_agg2], [name_agg2_tr], perm=[0, 2, 1],
                name="Transpose021_%s" % id(self))

            name_dot = root + "_dot"
            yield helper.make_node(
                'MatMul', [name_agg1, name_agg2_tr], [name_dot])

        # new_shape = ([max(m1.shape[i], m2.shape[i]) for i in batch_axes] +
        #              [m1.shape[i] for i in left if i not in batch_axes] +
        #              [m2.shape[i] for i in right if i not in batch_axes])
        concat_final = []
        if len(batch_axes) > 0:
            # max(m1.shape[i], m2.shape[i]) for the batch axes.
            name_max_dim = root + "_max_dim"
            concat_final.append(name_max_dim)
            yield helper.make_node(
                'Max', [name_dim0g, name_dim0bg], [name_max_dim])

        left_set = list(sorted(set(left) - (set(batch_axes) & set(left))))
        if len(left_set) > 0:
            name_left_dim = root + "_left_dim"
            name_left_set = root + "_left_set"
            yield numpy_helper.from_array(
                numpy.array(left_set, dtype=numpy.int64), name=name_left_set)
            yield helper.make_node(
                'Gather', [name_shape1, name_left_set], [name_left_dim])
            concat_final.append(name_left_dim)

        right_set = list(sorted(set(right) - (set(batch_axes) & set(right))))
        if len(right_set) > 0:
            name_right_dim = root + "_right_dim"
            name_right_set = root + "_right_set"
            yield numpy_helper.from_array(
                numpy.array(right_set, dtype=numpy.int64), name=name_right_set)
            yield helper.make_node(
                'Gather', [name_shape2, name_right_set], [name_right_dim])
            concat_final.append(name_right_dim)

        name_new_shape = root + '_new_shape'
        # Pad the final shape with 1s up to full_dim dimensions.
        diff = (
            self.full_dim -
            (len(batch_axes) + len(left_set) + len(right_set)))
        if diff > 0:
            names_ones = root + "_ones"
            yield numpy_helper.from_array(
                numpy.array([1 for i in range(diff)], dtype=numpy.int64),
                name=names_ones)
            concat_final.append(names_ones)

        yield helper.make_node(
            'Concat', concat_final, [name_new_shape], axis=0)

        name_final = root + '_final'
        yield helper.make_node(
            'Reshape', [name_dot, name_new_shape], [name_final])
941 def to_onnx(self, names, opset=None, verbose=False, **kwargs):
942 """
943 Converts this node into ONNX. Enumerates all ONNX node
944 which participate to the conversion. The last one
945 is the final output.
947 :param names: dictionary where to find already converted name
948 :param opset: opset
949 :param verbose: prints out intermediate results
950 :param kwargs: additional parameter for the conversion
951 :return: output
952 """
953 if opset is None:
954 opset = __max_supported_opset__ # pragma: no cover
955 if verbose:
956 print()
957 print("to_onnx %r (%s) opset=%r." % (
958 self.name,
959 ", ".join(map(lambda s: str(id(s)), self.inputs)),
960 opset))
962 method_name = "_to_onnx_%s" % self.name
963 meth = getattr(self, method_name, None)
964 if meth is None:
965 if self.name.endswith("_mm"):
966 raise NotImplementedError(
967 "to_onnx not implemented for %r."
968 "You should call method simplify_mm_nodes "
969 "to remove it." % self.name)
970 raise NotImplementedError(
971 "to_onnx not implemented for %r." % self.name)
972 for node in meth(names, verbose=verbose, opset=opset, **kwargs):
973 if hasattr(node, 'output'):
974 names[id(self)] = node.output[0]
975 if verbose:
976 print("+ OP %r -- (%s - %d)" %
977 (node.output[0], self.name, id(self)))
978 elif verbose:
979 # Initializer
980 print("+ CT %r -- (%s - %d)" %
981 (node.name, self.name, id(self)))
982 yield node
984 def get_dot_kind(self):
985 """
986 Every matrix multiplication can be either:
988 * a simple multiplication (`M`) (undetected)
989 * a 2D matrix multiplication (`11`)
990 * a broadcasted matrix multiplication (`N1` or `1N`)
991 * a batch matrix multiplication (`NN`)
993 This method returns which kind it is.
994 """
995 batch_axes = self.kwargs['batch_axes']
996 # keep_axes = self.kwargs['keep_axes']
997 # sum_axes = self.kwargs['sum_axes']
998 # left = self.kwargs['left']
999 # right = self.kwargs['right']
1000 info = self._info
1001 row_left = info['i_row']
1002 row_right = info['i_row2']
1004 batch_left = [row_left[k] for k in batch_axes]
1005 batch_right = [row_right[k] for k in batch_axes]
1006 n_left = len(batch_left) > 0 and max(batch_left) == 2
1007 n_right = len(batch_right) > 0 and max(batch_right) == 2
1008 return "%s%s" % ('N' if n_left else '1', 'N' if n_right else '1')
class GraphEinsumSubOp:
    """
    Class gathering all nodes produced to explicit einsum
    operators.

    :param letters: list of distinct letters
    :param mat: matrix, see @see fn analyse_einsum_equation
    :param lengths: lengths of every input
    :param duplicates: see @see fn analyse_einsum_equation
    """

    def __init__(self, letters, mat, lengths, duplicates):
        # maps id(node) (or an input index) to the node itself
        self._nodes = {}
        # two-way mapping between input indices and marked operators
        self._mark = {}
        # operators in insertion order
        self._ops = []
        # registered graph inputs (integer indices)
        self._inputs = {}
        self.last_op = None
        self.last_added_op = None
        # mat0 keeps a pristine copy of *mat* before it is modified
        self.metadata = {
            'letters': letters, 'mat': mat, 'lengths': lengths,
            'mat0': mat.copy(), 'duplicates': duplicates}
1033 def append(self, op):
1034 """
1035 Adds one input or result.
1037 :param op: integer (an input) or an instance of @see cl EinsumSubOp.
1038 :return: op or None if op is an integer
1039 """
1040 if isinstance(op, int):
1041 if op in self._nodes:
1042 raise RuntimeError( # pragma: no cover
1043 "Key %d already added." % op)
1044 self._nodes[op] = op
1045 self.last_added_op = op
1046 self._inputs[op] = op
1047 return None
1048 if isinstance(op, EinsumSubOp):
1049 if op in self._nodes:
1050 raise RuntimeError( # pragma: no cover
1051 "Key %d already added, op=%r." % (id(op), op))
1052 self._nodes[id(op)] = op
1053 self._ops.append(op)
1054 self.last_added_op = op
1055 return op
1056 raise TypeError( # pragma: no cover
1057 "Unexpected type %r." % type(op))
1059 def mark_last_node(self):
1060 """
1061 Marks the last node as the final output.
1062 """
1063 if self.last_added_op is None:
1064 raise RuntimeError("last_added_op is None.") # pragma: no cover
1065 self.mark(-1, self.last_added_op)
1067 def mark(self, i, op):
1068 """
1069 Marks one input or result as an intermediate result
1070 after a full einsum step.
1072 :param op: integer (an input) or an instance of @see cl EinsumSubOp.
1073 """
1074 if not isinstance(i, int):
1075 raise TypeError( # pragma: no cover
1076 "i must an integer not %r." % type(i))
1077 if i != -1 and i not in self._inputs:
1078 raise RuntimeError( # pragma: no cover
1079 "Input %d was not registered in %r." % (i, self._inputs))
1080 if isinstance(op, EinsumSubOp):
1081 if id(op) not in self._nodes:
1082 raise RuntimeError( # pragma: no cover
1083 "Key %d not found, op=%r." % (id(op), op))
1084 self._mark[i] = op
1085 self._mark[id(op)] = i
1086 self.last_op = op
1087 else:
1088 raise TypeError( # pragma: no cover
1089 "Unexpected type %r." % type(i))
1091 def __iter__(self):
1092 "Iterates on nodes."
1093 for op in self._ops:
1094 yield op
    def to_dot(self, **kwargs):
        """
        Produces a graph in :epkg:`dot`.

        :param kwargs: additional graph option
        :return: string
        """
        # Default dot rendering options, overridable through *kwargs*.
        options = {
            'orientation': 'portrait',
            'ranksep': '0.25',
            'nodesep': '0.05',
            'width': '0.5',
            'height': '0.1',
            'size': '5',
            'node': '[shape=record]',
        }
        options.update(kwargs)

        def d2s(d):
            # Renders a dictionary as 'k=v k=v ...'.
            it = []
            for k, v in sorted(d.items()):
                it.append("%s=%s" % (k, v))
            return " ".join(it)

        def d2sd(d):
            # Same as d2s but only keeps entries with more than one value.
            it = []
            for k, v in sorted(d.items()):
                if len(v) > 1:
                    it.append("%s=%s" % (k, ",".join(map(str, v))))
            return " ".join(it)

        rows = ["digraph{"]
        for k, v in options.items():
            if isinstance(v, str) and "[" in v:
                # statement option such as 'node [shape=record]'
                rows.append("{} {};".format(k, v))
            else:
                rows.append("{}={};".format(k, v))
        for k, v in self._nodes.items():
            if isinstance(v, int):
                # v is a graph input: label it with its letters
                # taken from the original einsum equation.
                let = [(r, self.metadata['letters'][i])
                       for i, r in enumerate(self.metadata['mat0'][v])
                       if r != -1]
                dup = self.metadata['duplicates'][v]
                if dup is None:
                    dup = ""
                else:
                    dup = " - %s" % d2sd(dup)
                let.sort()
                letters = "".join(_[1] for _ in let)
                lab = "input %d\\\\n%s\\\\n%s%s" % (
                    v, letters, str(self.metadata['mat0'][v]), dup)
                sk = v
                extended_lab = ""
            else:
                # v is an operator (EinsumSubOp): show name and arguments.
                lab = "%s\\\\n%s" % (v.name, d2s(v.kwargs))
                sk = id(v)
                extended_lab = v.dot_label()
                if extended_lab:
                    extended_lab = "\\\\n" + extended_lab

            if sk in self._mark and isinstance(self._mark[sk], int):
                # Marked nodes (intermediate einsum results) are
                # highlighted in red and tagged with their input index.
                la = self._mark[sk]
                lab = lab.replace("\\\\n", " - I%d\\\\n" % la)
                s = ('%d [label="%s%s" style=filled '
                     'fillcolor=red];' % (k, lab, extended_lab))
            else:
                s = '%d [label="%s%s"];' % (k, lab, extended_lab)
            rows.append(s)
            if not hasattr(v, 'inputs'):
                # graph inputs have no predecessors
                continue
            for i in v.inputs:
                # one edge per input, inputs may be indices or operators
                vid = i if isinstance(i, int) else id(i)
                s = "%d -> %d;" % (vid, k)
                rows.append(s)
        rows.append("}")
        return "\n".join(rows)
1173 def apply_sequence(self, *inputs, verbose=False, **kwargs):
1174 """
1175 Applies a sequence of operations on a list of inputs.
1177 :param inputs: inputs:
1178 :param verbose: prints out intermediate results
1179 :param kwargs: additional parameters,
1180 see :meth:`apply
1181 <mlprodict.testing.einsum.einsum_impl_classes.EinsumSubOp.apply>`.
1182 :return: output
1183 """
1184 if verbose:
1185 print('######### apply_sequence')
1186 data = {i: inp for i, inp in enumerate(inputs)}
1187 last = None
1188 for op in self:
1189 last = op.apply(data, verbose=verbose, **kwargs)
1190 if last is None:
1191 raise RuntimeError( # pragma: no cover
1192 "Sequence of operations is empty.")
1193 return last
1195 def clean_unused_nodes(self, verbose=False):
1196 """
1197 Cleans nodes with unused outputs.
1199 :param verbose: display intermediate information
1200 """
1202 def iteration(it):
1203 # Walks through all nodes.
1204 is_used = {}
1205 for node in self._ops:
1206 if not isinstance(node, EinsumSubOp):
1207 continue
1208 if id(node) not in is_used:
1209 is_used[id(node)] = []
1210 for inp in node.inputs:
1211 if not isinstance(inp, EinsumSubOp):
1212 continue
1213 idn = id(inp)
1214 if idn not in is_used:
1215 is_used[idn] = []
1216 is_used[idn].append(id(node))
1218 # Remove unused nodes.
1219 removed = []
1220 for k, v in is_used.items():
1221 if len(v) == 0:
1222 removed.append(k)
1223 removed = set(removed)
1224 i_rem = []
1225 for i, op in enumerate(self._ops):
1226 if not isinstance(op, EinsumSubOp):
1227 continue
1228 if id(op) in removed and id(op) not in self._mark:
1229 i_rem.append((i, id(op)))
1230 for i, idn in reversed(i_rem):
1231 if verbose:
1232 print("[GraphEinsumSubOp.clean_nodes] remove node "
1233 "i=%d: %d - id=%d" % (it, i, idn))
1234 del self._ops[i]
1235 del self._nodes[idn]
1236 return len(i_rem) > 0
1238 it = 1
1239 while iteration(it):
1240 it += 1
1242 self.last_op = None
1243 self.last_added_op = None
1245 def simplify_mm_nodes(self, verbose=False):
1246 """
1247 Node name suffixed by `mm` are an artifact to keep
1248 the graph consistent while building it. They can
1249 now be replaced by the equivalent node without suffix `mm`.
1251 :param verbose: display intermediate information
1252 """
1253 for op in self:
1254 if not isinstance(op, EinsumSubOp):
1255 continue
1256 if op.name.endswith('_mm'):
1257 if verbose:
1258 print("[GraphEinsumSubOp.simplify_mm_nodes] node %r"
1259 " - id=%d" % (op.name, id(op)))
1260 if len(op.inputs) != 2:
1261 raise RuntimeError( # pragma: no cover
1262 "Expecting 2 inputs for node %r not %r id=%r." % (
1263 op.name, len(op.inputs), id(op)))
1264 op.name = op.name[:-3]
1265 op.inputs = op.inputs[:1]
1267 def _get_forward_nodes(self):
1268 """
1269 Returns the forward nodes.
1270 """
1271 forward = {}
1272 for op in self:
1273 if isinstance(op, int):
1274 continue
1275 for inp in op.inputs:
1276 key = inp if isinstance(inp, int) else id(inp)
1277 if key in forward:
1278 forward[key].append(op)
1279 else:
1280 forward[key] = [op]
1281 return forward
1283 def _pprint_forward(self):
1284 rows = []
1285 for op in self:
1286 line = "%r <- %s(%s)" % (
1287 id(op), op.name,
1288 ", ".join(map(str, [id(_) for _ in op.inputs])))
1289 rows.append(line)
1290 return "\n".join(rows)
1292 def _replace_node_sequence(self, added, deleted):
1293 """
1294 Removes a sequence of nodes. The method does not check
1295 that the graph remains consistent.
1296 """
1297 forward = self._get_forward_nodes()
1298 key = id(deleted[-1])
1299 if key not in forward:
1300 raise RuntimeError( # pragma: no cover
1301 "Key {} missing in all forward nodes (other keys {}), "
1302 "all keys:\n{}".format(
1303 key, [id(_) for _ in deleted],
1304 self._pprint_forward()))
1306 # deletion
1307 mark_input = None
1308 for d in deleted:
1309 del self._nodes[id(d)]
1310 if id(d) in self._mark:
1311 del self._mark[id(d)]
1312 dels = []
1313 for k, v in self._mark.items():
1314 if id(v) == id(d):
1315 mark_input = k
1316 dels.append(k)
1317 if len(dels) != 1:
1318 raise RuntimeError( # pragma: no cover
1319 "Input %d has more than one marked operator "
1320 "(%r)." % (id(d), dels))
1321 del self._mark[dels[0]]
1323 dels = set(id(o) for o in deleted)
1324 rem = []
1325 for i, op in enumerate(self._ops):
1326 if id(op) in dels:
1327 rem.append(i)
1328 if len(rem) != len(deleted):
1329 raise RuntimeError( # pragma: no cover
1330 "Mismatched length %r, %r, len=%r." % (
1331 rem, dels, len(deleted)))
1332 for i in reversed(rem):
1333 del self._ops[i]
1334 self.last_add_op = None
1336 # insertion
1337 if added is not None:
1338 self._ops.insert(rem[0], added)
1339 self._nodes[id(added)] = added
1340 for op in forward[key]:
1341 new_inputs = list(op.inputs)
1342 for i in range(len(op.inputs)): # pylint: disable=C0200
1343 if id(op.inputs[i]) == key:
1344 new_inputs[i] = added
1345 op.inputs = tuple(new_inputs)
1346 if mark_input is not None:
1347 self.mark(mark_input, added)
1348 else:
1349 inps = deleted[0].inputs
1350 if len(inps) != 1:
1351 raise RuntimeError( # pragma: no cover
1352 "More than one input. Call another method.")
1353 inp = inps[0]
1354 for op in forward[key]:
1355 new_inputs = list(op.inputs)
1356 for i in range(len(op.inputs)): # pylint: disable=C0200
1357 if id(op.inputs[i]) == key:
1358 new_inputs[i] = inp
1359 op.inputs = tuple(new_inputs)
1360 if mark_input is not None:
1361 self.mark(mark_input, inp)
1363 def remove_duplicate_transpose(self, verbose=False):
1364 """
1365 Removes consecutive transpose by merging them.
1367 :param verbose: display intermediate information
1368 """
1369 modif = 1
1370 while modif > 0:
1371 modif = 0
1372 candidates = []
1373 forward = self._get_forward_nodes()
1374 for op in self:
1375 if op.name == "transpose":
1376 inp = op.inputs[0]
1377 if (isinstance(inp, EinsumSubOp) and
1378 inp.name == 'transpose' and
1379 len(forward[id(inp)]) == 1):
1380 candidates.append(op)
1382 if len(candidates) > 0:
1383 modif = 1
1384 # Not efficient to take the first one and to
1385 # start again but the graph should not be too big.
1386 cand = candidates[0]
1387 op2 = cand
1388 op1 = cand.inputs[0]
1389 perm1 = op1.kwargs['perm']
1390 perm2 = op2.kwargs['perm']
1391 if len(perm1) != len(perm2):
1392 raise RuntimeError( # pragma: no cover
1393 "Transposition should have the same length "
1394 "%r, %r." % (perm1, perm2))
1395 perm = list(perm1)
1396 for i in range(len(perm)): # pylint: disable=C0200
1397 perm[i] = perm1[perm2[i]]
1398 if list(range(len(perm))) == perm:
1399 # identity, everything needs to be removed
1400 new_op = None
1401 else:
1402 new_op = op2.__class__(
1403 op2.full_dim, op2.name, op1.inputs[0],
1404 perm=tuple(perm))
1405 self._replace_node_sequence(new_op, [op1, op2])
1406 if verbose:
1407 print( # pragma: no cover
1408 "[GraphEinsumSubOp.remove_duplicate_transpose] remove nodes %r"
1409 " - id=%d,%d + %d perm1=%r perm2=%r -> perm=%r" % (
1410 op2.name, id(op1), id(op2),
1411 id(new_op) if new_op is not None else -1,
1412 perm1, perm2, perm))
    def to_onnx(self, output, *inputs, dtype=None, verbose=False,
                opset=None, **kwargs):
        """
        Converts the graph into ONNX.

        :param output: output name
        :param inputs: input names
        :param dtype: type used for all operators
        :param opset: desired opset, None for the last one
        :param verbose: display intermediate operators
        :param kwargs: additional parameter to use when building
            the ONNX graph, list of supported parameters:
            *name*, *ir_version*, *producer_name*,
            *producer_version*, *initializer*
        :return: ONNX graph

        Not all graphs can be converted into ONNX. Only graphs produced
        with `strategy='numpy'` can be converted otherwise the following
        error shows up:

        ::

            NotImplementedError: to_onnx not implemented for 'matmul'.
        """
        # local import, presumably to avoid a circular dependency -- TODO confirm
        from ...onnx_tools.optim import onnx_remove_node_unused

        # inputs
        if opset is None:
            opset = __max_supported_opset__
        if verbose:
            print("[GraphEinsumSubOp.to_onnx] %r -> %s opset=%r "
                  "dtype=%r" % (inputs, output, opset, dtype))
        onx_inputs = []
        proto = guess_proto_dtype(
            numpy.float32 if dtype is None else dtype)
        lengths = self.metadata['lengths']
        names = {}
        for inp, le in zip(inputs, lengths):
            if isinstance(inp, tuple):
                # input given as (name, typed variable): the declared
                # type overrides *dtype* and fixes the shape
                name, typ = inp
                if le != len(typ.shape):
                    raise ValueError(  # pragma: no cover
                        "Irreconcialable shapes for input %r: "
                        "%r != len(%r)." % (name, le, typ.shape))
                proto = guess_proto_dtype(guess_numpy_type(typ))
                onx_inputs.append(
                    helper.make_tensor_value_info(name, proto, typ.shape))
                names[len(names)] = name
            else:
                # only a name is given: rank is known, dimensions are not
                onx_inputs.append(
                    helper.make_tensor_value_info(
                        inp, proto, [None for i in range(le)]))
                names[len(names)] = inp

        # output
        onx_output = helper.make_tensor_value_info(
            output, proto, [None for i in range(lengths[-1])])

        # nodes
        nodes = []
        inits = []
        if "initializer" in kwargs:
            inits.extend(kwargs['initializer'])
        for op in self:
            # every sub operation yields ONNX nodes (with an 'output'
            # attribute) and initializers (without one)
            for onx_node in op.to_onnx(names, verbose=verbose, opset=opset):
                if hasattr(onx_node, 'output'):
                    nodes.append(onx_node)
                else:
                    inits.append(onx_node)

        # last node: rename the final result into the requested output
        last_node = nodes[-1]
        nodes.append(helper.make_node(
            'Identity', [last_node.output[0]], [output]))

        # Builds the graph
        model = helper.make_model(
            opset_imports=[helper.make_operatorsetid('', opset)],
            ir_version=kwargs.get('ir_version', get_ir_version(opset)),
            producer_name=kwargs.get('producer_name', 'mlprodict'),
            producer_version=kwargs.get('producer_version', "0.0.dev"),
            graph=helper.make_graph(
                name=kwargs.get('name', 'einsum'),
                inputs=onx_inputs, outputs=[onx_output],
                initializer=inits, nodes=nodes))

        # prune nodes whose outputs are never consumed
        return onnx_remove_node_unused(model)