Coverage for mlprodict/onnxrt/ops_cpu/op_fused_matmul.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

37 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from ._new_ops import OperatorSchema 

10 

11 

class FusedMatMul(OpRun):
    """
    Runtime for operator *FusedMatMul*.

    Computes ``matmul(A', B') * alpha`` where ``A'`` is *A* with its
    last two axes swapped when attribute *transA* is set (same for
    *B* and *transB*). Inputs with fewer than two dimensions are
    never transposed.
    """

    # Default values of the operator attributes.
    atts = {'alpha': 1., 'transA': 0, 'transB': 0}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Selects once, at construction time, the kernel matching
        the pair (*transA*, *transB*).

        @param      onnx_node   forwarded to ``OpRun.__init__``
        @param      desc        forwarded to ``OpRun.__init__``
        @param      options     forwarded to ``OpRun.__init__``
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=FusedMatMul.atts,
                       **options)
        # NOTE(review): transA, transB and alpha are presumably set on
        # the instance by OpRun.__init__ from expected_attributes.
        if self.transA:
            _meth = (FusedMatMul._fmatmul11 if self.transB
                     else FusedMatMul._fmatmul10)
        else:
            _meth = (FusedMatMul._fmatmul01 if self.transB
                     else FusedMatMul._fmatmul00)
        # alpha is read at call time, not captured at construction.
        self._meth = lambda a, b: _meth(a, b, self.alpha)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of this custom operator.

        @param      op_name     operator name
        @return                 a new FusedMatMulSchema instance
        @raises RuntimeError    if *op_name* is not ``'FusedMatMul'``
        """
        if op_name == "FusedMatMul":
            return FusedMatMulSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    @staticmethod
    def _swap_last_two(x):
        """
        Swaps the last two axes of *x*.

        ``x.T`` reverses *every* axis, which is only a matrix
        transposition for 2-D arrays; batched inputs need the leading
        (batch) axes left in place. Arrays with fewer than two
        dimensions are returned unchanged, as ``x.T`` would do.
        """
        if x.ndim < 2:
            return x
        return numpy.swapaxes(x, -1, -2)

    @staticmethod
    def _fmatmul00(a, b, alpha):
        "Returns ``matmul(a, b) * alpha``."
        return numpy.matmul(a, b) * alpha

    @staticmethod
    def _fmatmul01(a, b, alpha):
        "Returns ``matmul(a, b') * alpha``, *b* transposed on its last two axes."
        return numpy.matmul(a, FusedMatMul._swap_last_two(b)) * alpha

    @staticmethod
    def _fmatmul10(a, b, alpha):
        "Returns ``matmul(a', b) * alpha``, *a* transposed on its last two axes."
        return numpy.matmul(FusedMatMul._swap_last_two(a), b) * alpha

    @staticmethod
    def _fmatmul11(a, b, alpha):
        "Returns ``matmul(a', b') * alpha``, both transposed on their last two axes."
        return numpy.matmul(
            FusedMatMul._swap_last_two(a),
            FusedMatMul._swap_last_two(b)) * alpha

    def _run(self, a, b):  # pylint: disable=W0221
        "Applies the kernel selected in the constructor, returns a 1-tuple."
        return (self._meth(a, b), )

    def _infer_shapes(self, a, b):  # pylint: disable=W0221
        # NOTE(review): the first input is returned unchanged; this is
        # not the shape of a matrix product in general -- confirm
        # whether callers rely on it.
        return (a, )

    def _infer_types(self, a, b):  # pylint: disable=W0221
        "The output type is reported as the type of the first input."
        return (a, )

58 

59 

class FusedMatMulSchema(OperatorSchema):
    """
    Schema of the custom operator implemented by
    @see cl FusedMatMul, an operator added in this package.
    """

    def __init__(self):
        super().__init__('FusedMatMul')
        # Same dictionary object as the runtime class (not a copy).
        self.attributes = FusedMatMul.atts