Coverage for mlprodict/testing/einsum/blas_lapack.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

129 statements  

1""" 

2@file 

3@brief Direct calls to libraries :epkg:`BLAS` and :epkg:`LAPACK`. 

4""" 

5import numpy 

6from scipy.linalg.blas import sgemm, dgemm # pylint: disable=E0611 

7from .direct_blas_lapack import ( # pylint: disable=E0401,E0611 

8 dgemm_dot, sgemm_dot) 

9 

10 

def pygemm(transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc):
    """
    Pure python implementation of GEMM.

    Computes ``C <- alpha * op(A) @ op(B) + beta * C`` in place, where
    ``op(X)`` is *X* or its transpose depending on *transA* / *transB*.
    Matrices are flat vectors stored in column-major order:
    element ``(i, j)`` of *C* is ``C[i + j * ldc]``.

    :param transA: is *A* transposed?
    :param transB: is *B* transposed?
    :param M: number of rows of ``op(A)`` and *C*
    :param N: number of columns of ``op(B)`` and *C*
    :param K: number of columns of ``op(A)`` and rows of ``op(B)``
    :param alpha: scalar multiplying the product
    :param A: flat vector holding *A*, must have ``M * K`` elements
    :param lda: leading dimension of *A*
    :param B: flat vector holding *B*, must have ``N * K`` elements
    :param ldb: leading dimension of *B*
    :param beta: scalar multiplying the initial content of *C*
    :param C: flat vector holding *C*, must have ``M * N`` elements,
        updated in place
    :param ldc: leading dimension of *C*
    :raises ValueError: if an argument is not a vector or if its size
        does not match the dimensions
    :raises IndexError: if a leading dimension sends an index out of
        the buffer
    """
    if len(A.shape) != 1:
        raise ValueError(  # pragma: no cover
            "A must be a vector.")
    if len(B.shape) != 1:
        raise ValueError(  # pragma: no cover
            "B must be a vector.")
    if len(C.shape) != 1:
        raise ValueError(
            "C must be a vector.")
    if A.shape[0] != M * K:
        raise ValueError(
            "Dimension mismatch for A.shape=%r M=%r N=%r K=%r." % (
                A.shape, M, N, K))
    if B.shape[0] != N * K:
        raise ValueError(
            "Dimension mismatch for B.shape=%r M=%r N=%r K=%r." % (
                B.shape, M, N, K))
    if C.shape[0] != N * M:
        raise ValueError(  # pragma: no cover
            "Dimension mismatch for C.shape=%r M=%r N=%r K=%r." % (
                C.shape, M, N, K))

    # Position of op(A)[i, k] is i * a_i_stride + k * a_k_stride
    # (column-major storage, transposition swaps the strides).
    if transA:
        a_i_stride, a_k_stride = lda, 1
    else:
        a_i_stride, a_k_stride = 1, lda

    # Position of op(B)[k, j] is j * b_j_stride + k * b_k_stride.
    if transB:
        b_j_stride, b_k_stride = 1, ldb
    else:
        b_j_stride, b_k_stride = ldb, 1

    # C is never transposed: C[i, j] is C[i + j * ldc].
    c_i_stride = 1
    c_j_stride = ldc

    size_a = A.shape[0]
    size_b = B.shape[0]
    size_c = C.shape[0]

    n_loop = 0
    for j in range(N):
        for i in range(M):
            total = 0
            for k in range(K):
                n_loop += 1
                a_index = i * a_i_stride + k * a_k_stride
                if a_index >= size_a:
                    raise IndexError(
                        "A: i=%d a_index=%d >= %d "
                        "(a_i_stride=%d a_k_stride=%d)" % (
                            i, a_index, size_a, a_i_stride, a_k_stride))
                b_index = j * b_j_stride + k * b_k_stride
                if b_index >= size_b:
                    raise IndexError(
                        "B: j=%d b_index=%d >= %d "
                        "(b_j_stride=%d b_k_stride=%d)" % (
                            j, b_index, size_b, b_j_stride, b_k_stride))
                total += A[a_index] * B[b_index]

            c_index = i * c_i_stride + j * c_j_stride
            if c_index >= size_c:
                raise IndexError("C: %d >= %d" % (c_index, size_c))
            C[c_index] = alpha * total + beta * C[c_index]

    # Sanity check: every (i, j, k) triple must have been visited once.
    if n_loop != M * N * K:
        raise RuntimeError(  # pragma: no cover
            "Unexpected number of loops: %d != %d = (%d * %d * %d) "
            "lda=%d ldb=%d ldc=%d" % (
                n_loop, M * N * K, M, N, K, lda, ldb, ldc))

89 

90 

def gemm_dot(A, B, transA=False, transB=False):
    """
    Implements dot product with gemm when possible.

    Computes ``op(A) @ op(B)`` where ``op(X)`` is ``X.T`` when the
    corresponding *trans* flag is set and ``X`` otherwise.
    For float32 / float64 matrices, the product is computed with
    ``sgemm_dot`` / ``dgemm_dot`` when all dimensions of *A* and *B*
    are equal, and with scipy's ``sgemm`` / ``dgemm`` otherwise.
    Any other dtype falls back to numpy's ``@`` operator.

    :param A: first matrix
    :param B: second matrix
    :param transA: is first matrix transposed?
    :param transB: is second matrix transposed?
    :return: matrix of shape ``(op(A).shape[0], op(B).shape[1])``
    :raises TypeError: if *A* and *B* do not share the same dtype
    :raises ValueError: if a matrix is not 2D or if the inner
        dimensions of ``op(A)`` and ``op(B)`` do not match
    """
    if A.dtype != B.dtype:
        raise TypeError(  # pragma: no cover
            "Matrices A and B must have the same dtype not "
            "%r, %r." % (A.dtype, B.dtype))
    if len(A.shape) != 2:
        raise ValueError(  # pragma: no cover
            "Matrix A does not have 2 dimensions but %d." % len(A.shape))
    if len(B.shape) != 2:
        raise ValueError(  # pragma: no cover
            "Matrix B does not have 2 dimensions but %d." % len(B.shape))

    transA = bool(transA)
    transB = bool(transB)

    # op(A) is (n_rows x k_a), op(B) is (k_b x n_cols).
    n_rows, k_a = (A.shape[1], A.shape[0]) if transA else A.shape
    k_b, n_cols = (B.shape[1], B.shape[0]) if transB else B.shape
    if k_a != k_b:
        # Checked here so every code path reports the same exception type.
        raise ValueError(
            "Dimension mismatch, op(A).shape=%r, op(B).shape=%r "
            "(transA=%r, transB=%r)." % (
                (n_rows, k_a), (k_b, n_cols), transA, transB))

    if A.dtype == numpy.float32:
        is_float32 = True
    elif A.dtype == numpy.float64:
        is_float32 = False
    else:
        # No BLAS routine for this dtype, numpy does the product.
        op_a = A.T if transA else A
        op_b = B.T if transB else B
        return op_a @ op_b

    C = numpy.zeros((n_rows, n_cols), dtype=A.dtype)

    all_dims = A.shape + B.shape
    if min(all_dims) == max(all_dims):
        # Square case: the direct wrappers receive C-contiguous inputs,
        # with B before A and the flags swapped accordingly
        # (presumably computing the row-major product through a
        # column-major routine — TODO confirm in direct_blas_lapack).
        if not A.flags['C_CONTIGUOUS']:
            A = numpy.ascontiguousarray(A)
        if not B.flags['C_CONTIGUOUS']:
            B = numpy.ascontiguousarray(B)
        direct_dot = sgemm_dot if is_float32 else dgemm_dot
        direct_dot(B, A, transB, transA, C)
        return C

    # General case: scipy BLAS, alpha=1, beta=0, overwrite_c=1.
    blas_gemm = sgemm if is_float32 else dgemm
    return blas_gemm(1, A, B, 0, C, int(transA), int(transB), 1)