Coverage for mlprodict/testing/einsum/blas_lapack.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Direct calls to libraries :epkg:`BLAS` and :epkg:`LAPACK`.
4"""
5import numpy
6from scipy.linalg.blas import sgemm, dgemm # pylint: disable=E0611
7from .direct_blas_lapack import ( # pylint: disable=E0401,E0611
8 dgemm_dot, sgemm_dot)
def pygemm(transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc):
    """
    Pure python implementation of GEMM.

    Computes ``C = alpha * op(A) @ op(B) + beta * C`` in place, where
    ``op(X)`` is ``X.T`` when the corresponding *trans* flag is true and
    ``X`` otherwise. Every matrix is stored as a flat vector in
    column-major order: element ``(i, j)`` of *C* lives at index
    ``i + j * ldc``, and likewise for *A* and *B* with *lda* and *ldb*.

    :param transA: is *A* transposed?
    :param transB: is *B* transposed?
    :param M: number of rows of ``op(A)`` and of *C*
    :param N: number of columns of ``op(B)`` and of *C*
    :param K: number of columns of ``op(A)`` and rows of ``op(B)``
    :param alpha: scalar multiplying ``op(A) @ op(B)``
    :param A: flat array holding ``M * K`` elements
    :param lda: leading dimension of *A*
    :param B: flat array holding ``N * K`` elements
    :param ldb: leading dimension of *B*
    :param beta: scalar multiplying *C*; when it is zero *C* is not read
        (as in reference BLAS), so NaN or inf values already in *C* are
        overwritten instead of propagated
    :param C: flat array holding ``M * N`` elements, updated in place
    :param ldc: leading dimension of *C*
    :return: nothing, the result is written into *C*
    :raises ValueError: if *A*, *B* or *C* is not 1-D or has an unexpected
        number of elements
    :raises IndexError: if a leading dimension makes an index fall outside
        of its array
    """
    if len(A.shape) != 1:
        raise ValueError(  # pragma: no cover
            "A must be a vector.")
    if len(B.shape) != 1:
        raise ValueError(  # pragma: no cover
            "B must be a vector.")
    if len(C.shape) != 1:
        raise ValueError(
            "C must be a vector.")
    if A.shape[0] != M * K:
        raise ValueError(
            "Dimension mismatch for A.shape=%r M=%r N=%r K=%r." % (
                A.shape, M, N, K))
    if B.shape[0] != N * K:
        raise ValueError(
            "Dimension mismatch for B.shape=%r M=%r N=%r K=%r." % (
                B.shape, M, N, K))
    if C.shape[0] != N * M:
        raise ValueError(  # pragma: no cover
            "Dimension mismatch for C.shape=%r M=%r N=%r K=%r." % (
                C.shape, M, N, K))

    # op(A)[i, k] is stored at i * a_i_stride + k * a_k_stride (column-major).
    if transA:
        a_i_stride, a_k_stride = lda, 1
    else:
        a_i_stride, a_k_stride = 1, lda

    # op(B)[k, j] is stored at j * b_j_stride + k * b_k_stride.
    if transB:
        b_j_stride, b_k_stride = 1, ldb
    else:
        b_j_stride, b_k_stride = ldb, 1

    # C[i, j] is stored at i + j * ldc.
    c_i_stride, c_j_stride = 1, ldc

    a_size, b_size, c_size = A.shape[0], B.shape[0], C.shape[0]
    # Reference BLAS does not read C when beta is zero: C may then be
    # uninitialised, and 0 * NaN would otherwise leak NaN into the result.
    ignore_c = beta == 0

    n_loop = 0
    for j in range(N):
        for i in range(M):
            total = 0
            for k in range(K):
                n_loop += 1
                a_index = i * a_i_stride + k * a_k_stride
                if a_index >= a_size:
                    raise IndexError(
                        "A: i=%d a_index=%d >= %d "
                        "(a_i_stride=%d a_k_stride=%d)" % (
                            i, a_index, a_size, a_i_stride, a_k_stride))

                b_index = j * b_j_stride + k * b_k_stride
                if b_index >= b_size:
                    raise IndexError(
                        "B: j=%d b_index=%d >= %d "
                        "(b_j_stride=%d b_k_stride=%d)" % (
                            j, b_index, b_size, b_j_stride, b_k_stride))

                total += A[a_index] * B[b_index]

            c_index = i * c_i_stride + j * c_j_stride
            if c_index >= c_size:
                raise IndexError("C: %d >= %d" % (c_index, c_size))
            if ignore_c:
                C[c_index] = alpha * total
            else:
                C[c_index] = alpha * total + beta * C[c_index]

    # Sanity check: every (i, j, k) triple must have been visited once.
    if n_loop != M * N * K:
        raise RuntimeError(  # pragma: no cover
            "Unexpected number of loops: %d != %d = (%d * %d * %d) "
            "lda=%d ldb=%d ldc=%d" % (
                n_loop, M * N * K, M, N, K, lda, ldb, ldc))
def gemm_dot(A, B, transA=False, transB=False):
    """
    Implements dot product with gemm when possible.

    Computes ``op(A) @ op(B)`` where ``op(X)`` is ``X.T`` when the
    corresponding *trans* flag is true and ``X`` otherwise.
    For float32 and float64 inputs the product is delegated to BLAS:
    to the direct wrappers ``sgemm_dot`` / ``dgemm_dot`` when all four
    dimensions are equal, to scipy's ``sgemm`` / ``dgemm`` otherwise.
    Any other dtype falls back to numpy's ``@`` operator.

    :param A: first matrix
    :param B: second matrix
    :param transA: is first matrix transposed?
    :param transB: is second matrix transposed?
    :return: the product ``op(A) @ op(B)``
    :raises TypeError: if *A* and *B* do not share the same dtype
    :raises ValueError: if *A* or *B* does not have 2 dimensions
    """
    if A.dtype != B.dtype:
        raise TypeError(  # pragma: no cover
            "Matrices A and B must have the same dtype not "
            "%r, %r." % (A.dtype, B.dtype))
    if len(A.shape) != 2:
        raise ValueError(  # pragma: no cover
            "Matrix A does not have 2 dimensions but %d." % len(A.shape))
    if len(B.shape) != 2:
        raise ValueError(  # pragma: no cover
            "Matrix B does not have 2 dimensions but %d." % len(B.shape))

    if A.dtype == numpy.float32:
        single = True
    elif A.dtype == numpy.float64:
        single = False
    else:
        # No BLAS routine for this dtype, numpy handles the product.
        op_a = A.T if transA else A
        op_b = B.T if transB else B
        return op_a @ op_b

    # Shape of op(A) @ op(B).
    n_rows = A.shape[1] if transA else A.shape[0]
    n_cols = B.shape[0] if transB else B.shape[1]
    C = numpy.zeros((n_rows, n_cols), dtype=A.dtype)

    all_dims = A.shape + B.shape
    if min(all_dims) == max(all_dims):
        # Square matrices of identical size: use the direct BLAS wrapper.
        # Inputs are made C-contiguous (copied only when needed) and passed
        # as (B, A) with swapped transposition flags, presumably because a
        # column-major routine sees a C-contiguous array as its transpose.
        A = numpy.ascontiguousarray(A)
        B = numpy.ascontiguousarray(B)
        direct_gemm = sgemm_dot if single else dgemm_dot
        direct_gemm(B, A, bool(transB), bool(transA), C)
        return C

    # General case: scipy computes alpha * op(A) @ op(B) + beta * C,
    # here with alpha=1, beta=0 and overwrite_c=1.
    scipy_gemm = sgemm if single else dgemm
    return scipy_gemm(1, A, B, 0, C,
                      1 if transA else 0, 1 if transB else 0, 1)