Coverage for mlprodict/onnxrt/ops_cpu/op_fused_matmul.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from ._new_ops import OperatorSchema
class FusedMatMul(OpRun):
    """
    Runtime for the *FusedMatMul* contrib operator:
    ``alpha * matmul(A', B')`` where each input is optionally
    transposed according to *transA* / *transB*.
    """

    atts = {'alpha': 1., 'transA': 0, 'transB': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=FusedMatMul.atts,
                       **options)
        # Resolve the transpose combination once at construction time
        # so _run only performs the matmul itself.
        dispatch = {
            (False, False): FusedMatMul._fmatmul00,
            (False, True): FusedMatMul._fmatmul01,
            (True, False): FusedMatMul._fmatmul10,
            (True, True): FusedMatMul._fmatmul11,
        }
        chosen = dispatch[(bool(self.transA), bool(self.transB))]
        self._meth = lambda a, b: chosen(a, b, self.alpha)

    def _find_custom_operator_schema(self, op_name):
        # Only one custom schema is handled by this operator.
        if op_name == "FusedMatMul":
            return FusedMatMulSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    @staticmethod
    def _fmatmul00(a, b, alpha):
        # no transposition
        return numpy.matmul(a, b) * alpha

    @staticmethod
    def _fmatmul01(a, b, alpha):
        # B transposed
        return numpy.matmul(a, b.T) * alpha

    @staticmethod
    def _fmatmul10(a, b, alpha):
        # A transposed
        return numpy.matmul(a.T, b) * alpha

    @staticmethod
    def _fmatmul11(a, b, alpha):
        # both transposed
        return numpy.matmul(a.T, b.T) * alpha

    def _run(self, a, b):  # pylint: disable=W0221
        return (self._meth(a, b), )

    def _infer_shapes(self, a, b):  # pylint: disable=W0221
        # Output shape follows the first input's shape descriptor.
        return (a, )

    def _infer_types(self, a, b):  # pylint: disable=W0221
        # Output type follows the first input's type.
        return (a, )
class FusedMatMulSchema(OperatorSchema):
    """
    Defines a schema for operators added in this package
    such as @see cl FusedMatMul.
    """

    def __init__(self):
        super().__init__('FusedMatMul')
        # Reuse the operator's declared attributes as the schema's.
        self.attributes = FusedMatMul.atts