Coverage for mlprodict/onnxrt/ops_cpu/op_gemm.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

41 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class Gemm(OpRun):
    """
    Runtime operator evaluating ``alpha * A' . B' + beta * C`` with *numpy*,
    where ``A'`` (resp. ``B'``) is the first (resp. second) input, transposed
    when attribute *transA* (resp. *transB*) is truthy.
    """

    # Expected attributes and their default values, handed to ``OpRun``.
    atts = {'alpha': 1., 'beta': 1., 'transA': 0, 'transB': 0}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Initializes the base class, then selects once the kernel matching
        the transposition flags so that no flag is tested at run time.

        @param      onnx_node   forwarded to ``OpRun.__init__``
        @param      desc        forwarded to ``OpRun.__init__``
        @param      options     forwarded to ``OpRun.__init__``
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Gemm.atts,
                       **options)
        # transA, transB, alpha, beta are presumably set on the instance by
        # OpRun.__init__ from *expected_attributes* -- verify against OpRun.
        kernels = {
            (False, False): Gemm._gemm00,
            (False, True): Gemm._gemm01,
            (True, False): Gemm._gemm10,
            (True, True): Gemm._gemm11,
        }
        kernel = kernels[bool(self.transA), bool(self.transB)]

        def _apply(a, b, c):
            # alpha and beta are read at call time, not frozen here.
            return kernel(a, b, c, self.alpha, self.beta)

        self._meth = _apply

    @staticmethod
    def _scale_and_add(product, c, alpha, beta):
        """
        Scales *product* by *alpha* and, when a bias is given and *beta*
        is not zero, adds ``c * beta`` to the scaled result in place.

        @param      product     result of the matrix product
        @param      c           bias term or None
        @param      alpha       factor applied to *product*
        @param      beta        factor applied to *c*
        @return                 ``product * alpha`` (``+ c * beta``)
        """
        scaled = product * alpha
        if c is None or beta == 0:
            return scaled
        # In-place: the result keeps the shape and dtype of *scaled*.
        scaled += c * beta
        return scaled

    @staticmethod
    def _gemm00(a, b, c, alpha, beta):
        """
        Kernel without transposition: ``dot(a, b) * alpha (+ c * beta)``.
        """
        return Gemm._scale_and_add(numpy.dot(a, b), c, alpha, beta)

    @staticmethod
    def _gemm01(a, b, c, alpha, beta):
        """
        Kernel with *b* transposed: ``dot(a, b.T) * alpha (+ c * beta)``.
        """
        return Gemm._scale_and_add(numpy.dot(a, b.T), c, alpha, beta)

    @staticmethod
    def _gemm10(a, b, c, alpha, beta):
        """
        Kernel with *a* transposed: ``dot(a.T, b) * alpha (+ c * beta)``.
        """
        return Gemm._scale_and_add(numpy.dot(a.T, b), c, alpha, beta)

    @staticmethod
    def _gemm11(a, b, c, alpha, beta):
        """
        Kernel with both inputs transposed:
        ``dot(a.T, b.T) * alpha (+ c * beta)``.
        """
        return Gemm._scale_and_add(numpy.dot(a.T, b.T), c, alpha, beta)

    def _run(self, a, b, c=None):  # pylint: disable=W0221
        """
        Applies the kernel chosen in the constructor.

        @param      a       first input
        @param      b       second input
        @param      c       optional bias, skipped when None
        @return             tuple holding the single output
        """
        result = self._meth(a, b, c)
        return (result, )

    def _infer_shapes(self, a, b, c=None):  # pylint: disable=W0221
        """
        Returns the first argument unchanged as the only output shape.

        NOTE(review): the product computed by ``_run`` generally does not
        have the shape of *a*; confirm callers accept this approximation.
        """
        inferred = (a, )
        return inferred

    def _infer_types(self, a, b, c=None):  # pylint: disable=W0221
        """
        Returns the first argument unchanged as the only output type.
        """
        inferred = (a, )
        return inferred