Coverage for mlprodict/onnxrt/ops_cpu/op_gemm.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
class Gemm(OpRun):
    """
    Runtime implementation of the ONNX *Gemm* operator:
    ``Y = alpha * op(A) . op(B) + beta * C`` where ``op`` transposes its
    argument when the matching attribute ``transA`` / ``transB`` is set.
    Input *C* is optional.
    """

    atts = {'alpha': 1., 'beta': 1., 'transA': 0, 'transB': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Gemm.atts,
                       **options)
        # The transposition variant only depends on attributes known at
        # construction, so it is selected once here. A plain function is
        # stored (rather than a lambda closing over ``self``) so that the
        # instance does not reference itself and remains picklable.
        if self.transA:
            self._gemm_impl = (Gemm._gemm11 if self.transB
                               else Gemm._gemm10)
        else:
            self._gemm_impl = (Gemm._gemm01 if self.transB
                               else Gemm._gemm00)

    def _meth(self, a, b, c):
        """
        Applies the variant selected in the constructor.
        *alpha* and *beta* are read from the instance at call time.

        :param a: first input
        :param b: second input
        :param c: optional third input or None
        :return: result array
        """
        return self._gemm_impl(a, b, c, self.alpha, self.beta)

    @staticmethod
    def _gemm_generic(a, b, c, alpha, beta, trans_a=False, trans_b=False):
        """
        Computes ``alpha * op(a) . op(b) + beta * c``.

        :param a: left operand, transposed first when *trans_a* is true
        :param b: right operand, transposed first when *trans_b* is true
        :param c: optional array added after scaling by *beta*, or None
        :param alpha: scale applied to the matrix product
        :param beta: scale applied to *c*
        :param trans_a: transpose *a* before the product
        :param trans_b: transpose *b* before the product
        :return: result array
        """
        if trans_a:
            a = a.T
        if trans_b:
            b = b.T
        res = numpy.dot(a, b) * alpha
        # When beta is 0 (or c is missing) the c term is skipped entirely
        # instead of adding ``0 * c``.
        if c is not None and beta != 0:
            # ``res`` is a fresh array produced above, so the in-place
            # addition cannot alter any caller-owned input.
            res += c * beta
        return res

    @staticmethod
    def _gemm00(a, b, c, alpha, beta):
        "Computes ``alpha * a . b + beta * c``."
        return Gemm._gemm_generic(a, b, c, alpha, beta, False, False)

    @staticmethod
    def _gemm01(a, b, c, alpha, beta):
        "Computes ``alpha * a . b.T + beta * c``."
        return Gemm._gemm_generic(a, b, c, alpha, beta, False, True)

    @staticmethod
    def _gemm10(a, b, c, alpha, beta):
        "Computes ``alpha * a.T . b + beta * c``."
        return Gemm._gemm_generic(a, b, c, alpha, beta, True, False)

    @staticmethod
    def _gemm11(a, b, c, alpha, beta):
        "Computes ``alpha * a.T . b.T + beta * c``."
        return Gemm._gemm_generic(a, b, c, alpha, beta, True, True)

    def _run(self, a, b, c=None):  # pylint: disable=W0221
        """
        Executes the operator and returns a 1-tuple with the result.
        """
        return (self._meth(a, b, c), )

    def _infer_shapes(self, a, b, c=None):  # pylint: disable=W0221
        """
        Returns the first input's shape description unchanged.
        """
        # NOTE(review): the Gemm output is (M, N); returning *a* as-is
        # differs from that when transA is set or when K != N. Confirm
        # whether callers rely on this before changing it.
        return (a, )

    def _infer_types(self, a, b, c=None):  # pylint: disable=W0221
        """
        Returns the first input's type description unchanged.
        """
        return (a, )