Coverage for mlprodict/onnxrt/ops_cpu/op_einsum.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from ..shape_object import ShapeObject
class Einsum(OpRun):
    """
    Runtime operator *Einsum*, evaluated with :func:`numpy.einsum`.

    The operator has a single attribute, *equation*: the
    Einstein-summation subscripts string forwarded to
    :func:`numpy.einsum`. Any number of inputs is accepted
    (see ``python_inputs``).
    """

    atts = {'equation': ''}
    python_inputs = ['*inputs']

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds the operator and validates/normalizes *equation*.

        :param onnx_node: ONNX node this operator runs
        :param desc: optional description forwarded to :class:`OpRun`
        :param options: additional options forwarded to :class:`OpRun`
        :raises TypeError: if *equation* is neither str nor bytes,
            or if it is empty once surrounding whitespace is removed
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Einsum.atts,
                       **options)
        if not isinstance(self.equation, (str, bytes)):
            raise TypeError(  # pragma: no cover
                "equation must be string but is %r." % type(self.equation))
        if isinstance(self.equation, bytes):
            # numpy.einsum's optimize=True path parses the subscripts only
            # when they are a str; a bytes equation would make that call
            # fail and silently fall back to the non-optimized call in
            # _run. Normalizing once here also gives _infer_shapes and the
            # generated python code a str equation.
            self.equation = self.equation.decode('utf-8')
        self.equation = self.equation.strip()
        if len(self.equation) == 0:
            raise TypeError("equation is empty.")  # pragma: no cover

    def _run(self, *args):  # pylint: disable=W0221
        """
        Evaluates the einsum equation on *args*.

        The call is first attempted with ``optimize=True``. If that
        raises TypeError (presumably a numpy version without the
        *optimize* keyword — TODO confirm), the call is retried
        without it.

        :return: 1-tuple holding the result of :func:`numpy.einsum`
        """
        try:
            return (numpy.einsum(self.equation, *args, optimize=True), )
        except TypeError:
            return (numpy.einsum(self.equation, *args), )

    def _infer_shapes(self, *args):  # pylint: disable=W0221
        """
        Infers the output shape with ``ShapeObject.einsum_shape``.

        If that helper raises RuntimeError, the method returns
        ``ShapeObject(None, dtype=...)``, which presumably means the
        shape is unknown. That fallback carries the dtype of the
        first input.
        """
        try:
            return (ShapeObject.einsum_shape(self.equation, *args), )
        except RuntimeError:  # pragma: no cover
            return (ShapeObject(None, dtype=args[0].dtype), )

    def _infer_types(self, *args):  # pylint: disable=W0221
        """The output type is the type of the first input."""
        return (args[0], )

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        """
        Runs the operator and estimates its temporary memory usage.

        The estimate is three times the element count of the largest
        input, multiplied by the item size of the first input's dtype.

        :return: ``dict(temp=...)`` followed by the outputs
            returned by ``self.run``
        """
        res = self.run(*args)
        maxi = max(a.size for a in args)
        return (dict(temp=maxi * 3 * args[0].dtype.itemsize), ) + res

    def to_python(self, inputs):
        """
        Returns the pair (import statement, code) used to express this
        operator as python code. The code refers to names ``equation``
        and ``inputs``.
        """
        return ("import numpy",
                "return numpy.einsum(equation, *inputs)")