diff --git a/cblas.h b/cblas.h index f0220eb99..a5ad25ad7 100644 --- a/cblas.h +++ b/cblas.h @@ -400,6 +400,8 @@ void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPE float cblas_sbdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy); void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum CBLAS_TRANSPOSE trans, OPENBLAS_CONST blasint m, OPENBLAS_CONST blasint n, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *a, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float beta, float *y, OPENBLAS_CONST blasint incy); +void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, + OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc); #ifdef __cplusplus } #endif /* __cplusplus */ diff --git a/common_level3.h b/common_level3.h index 187402a9a..5080ada10 100644 --- a/common_level3.h +++ b/common_level3.h @@ -516,6 +516,13 @@ int qgemm_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, xdouble *, xd #endif #ifdef SMALL_MATRIX_OPT +int sbgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta); + +int sbgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); +int sbgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); +int sbgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); +int sbgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + int sgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta); int sgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); @@ -530,6 +537,11 @@ int dgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLO int dgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); int dgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); +int sbgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc); +int sbgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc); +int sbgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc); +int sbgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc); + int sgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); int sgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); int sgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); diff --git a/common_macro.h b/common_macro.h index aeb9a205b..cf2a3fd88 100644 --- a/common_macro.h +++ b/common_macro.h @@ -942,17 +942,17 @@ #define GEADD_K SGEADD_K -#define GEMM_SMALL_MATRIX_PERMIT SGEMM_SMALL_MATRIX_PERMIT +#define GEMM_SMALL_MATRIX_PERMIT SBGEMM_SMALL_MATRIX_PERMIT -#define GEMM_SMALL_KERNEL_NN SGEMM_SMALL_KERNEL_NN -#define GEMM_SMALL_KERNEL_NT SGEMM_SMALL_KERNEL_NT -#define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN -#define GEMM_SMALL_KERNEL_TT SGEMM_SMALL_KERNEL_TT +#define GEMM_SMALL_KERNEL_NN SBGEMM_SMALL_KERNEL_NN +#define GEMM_SMALL_KERNEL_NT SBGEMM_SMALL_KERNEL_NT +#define GEMM_SMALL_KERNEL_TN SBGEMM_SMALL_KERNEL_TN +#define GEMM_SMALL_KERNEL_TT SBGEMM_SMALL_KERNEL_TT -#define GEMM_SMALL_KERNEL_B0_NN SGEMM_SMALL_KERNEL_B0_NN -#define GEMM_SMALL_KERNEL_B0_NT SGEMM_SMALL_KERNEL_B0_NT -#define GEMM_SMALL_KERNEL_B0_TN SGEMM_SMALL_KERNEL_B0_TN -#define GEMM_SMALL_KERNEL_B0_TT SGEMM_SMALL_KERNEL_B0_TT +#define GEMM_SMALL_KERNEL_B0_NN SBGEMM_SMALL_KERNEL_B0_NN +#define GEMM_SMALL_KERNEL_B0_NT SBGEMM_SMALL_KERNEL_B0_NT +#define GEMM_SMALL_KERNEL_B0_TN SBGEMM_SMALL_KERNEL_B0_TN +#define GEMM_SMALL_KERNEL_B0_TT SBGEMM_SMALL_KERNEL_B0_TT #endif diff --git a/common_param.h b/common_param.h index 7e8bea4fe..31fba9059 100644 --- a/common_param.h +++ b/common_param.h @@ -145,6 +145,19 @@ BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG); int (*sbneg_tcopy) (BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sblaswp_ncopy) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG, blasint *, float *); +#ifdef SMALL_MATRIX_OPT + int (*sbgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta); + + int (*sbgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + int (*sbgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + int (*sbgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + int (*sbgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + + int (*sbgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*sbgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*sbgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*sbgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc); +#endif #endif #if defined(BUILD_SINGLE) || defined(BUILD_COMPLEX) diff --git a/common_sb.h b/common_sb.h index 9976e812e..d21e7a563 100644 --- a/common_sb.h +++ b/common_sb.h @@ -24,6 +24,7 @@ #define SBGEMM_BETA sbgemm_beta #define SBGEMM_KERNEL sbgemm_kernel +#define SBGEMM_SMALL_MATRIX_PERMIT sbgemm_small_matrix_permit #else #define SBDOT_K gotoblas -> sbdot_k @@ -41,8 +42,19 @@ #define SBGEMM_BETA gotoblas -> sbgemm_beta #define SBGEMM_KERNEL gotoblas -> sbgemm_kernel +#define SBGEMM_SMALL_MATRIX_PERMIT gotoblas -> sbgemm_small_matrix_permit #endif +#define SBGEMM_SMALL_KERNEL_NN FUNC_OFFSET(sbgemm_small_kernel_nn) +#define SBGEMM_SMALL_KERNEL_NT FUNC_OFFSET(sbgemm_small_kernel_nt) +#define SBGEMM_SMALL_KERNEL_TN FUNC_OFFSET(sbgemm_small_kernel_tn) +#define SBGEMM_SMALL_KERNEL_TT FUNC_OFFSET(sbgemm_small_kernel_tt) + +#define SBGEMM_SMALL_KERNEL_B0_NN FUNC_OFFSET(sbgemm_small_kernel_b0_nn) +#define SBGEMM_SMALL_KERNEL_B0_NT FUNC_OFFSET(sbgemm_small_kernel_b0_nt) +#define SBGEMM_SMALL_KERNEL_B0_TN FUNC_OFFSET(sbgemm_small_kernel_b0_tn) +#define SBGEMM_SMALL_KERNEL_B0_TT FUNC_OFFSET(sbgemm_small_kernel_b0_tt) + #define SBGEMM_NN sbgemm_nn #define SBGEMM_CN sbgemm_tn #define SBGEMM_TN sbgemm_tn diff --git a/interface/gemm.c b/interface/gemm.c index 3497d8651..6dcc54041 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -105,7 +105,7 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, B #endif }; -#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE) && !defined(BFLOAT16) +#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE) #define USE_SMALL_MATRIX_OPT 1 #else #define USE_SMALL_MATRIX_OPT 0 @@ -131,8 +131,8 @@ static size_t gemm_small_kernel_b0[] = { GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, 0, 0, }; -#define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx)) -#define GEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx)) +#define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx)) +#define GEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx)) #else static size_t zgemm_small_kernel[] = { @@ -273,8 +273,8 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS blasint m, blasint n, blasint k, #ifndef COMPLEX FLOAT alpha, - FLOAT *a, blasint lda, - FLOAT *b, blasint ldb, + IFLOAT *a, blasint lda, + IFLOAT *b, blasint ldb, FLOAT beta, FLOAT *c, blasint ldc) { #else diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index ef11e391c..49b7c78fb 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -450,6 +450,15 @@ endif ###### BLAS small matrix optimization ##### ifeq ($(SMALL_MATRIX_OPT), 1) +ifeq ($(BUILD_BFLOAT16),1) +SBBLASOBJS += \ + sbgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \ + sbgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ + sbgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ + sbgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \ + sbgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) +endif + SBLASOBJS += \ sgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \ sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ @@ -4424,6 +4433,72 @@ $(KDIR)sgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL $(KDIR)sgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_TT) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@ + +ifeq ($(BUILD_BFLOAT16), 1) +ifndef SBGEMM_SMALL_M_PERMIT +SBGEMM_SMALL_M_PERMIT = ../generic/gemm_small_matrix_permit.c +endif + +ifndef SBGEMM_SMALL_K_NN +SBGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c +endif + +ifndef SBGEMM_SMALL_K_NT +SBGEMM_SMALL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c +endif + +ifndef SBGEMM_SMALL_K_TN +SBGEMM_SMALL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c +endif + +ifndef SBGEMM_SMALL_K_TT +SBGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c +endif + +$(KDIR)sbgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_M_PERMIT) + $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)sbgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_NN) + $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)sbgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_NT) + $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)sbgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_TN) + $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)sbgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_TT) + $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ + +ifndef SBGEMM_SMALL_K_B0_NN +SBGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_nn.c +endif + +ifndef SBGEMM_SMALL_K_B0_NT +SBGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_nt.c +endif + +ifndef SBGEMM_SMALL_K_B0_TN +SBGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_tn.c +endif + +ifndef SBGEMM_SMALL_K_B0_TT +SBGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_tt.c +endif + +$(KDIR)sbgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_B0_NN) + $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@ + +$(KDIR)sbgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_B0_NT) + $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@ + +$(KDIR)sbgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_B0_TN) + $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@ + +$(KDIR)sbgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_B0_TT) + $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@ +endif + ifndef CGEMM_SMALL_M_PERMIT CGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c endif diff --git a/kernel/generic/gemm_small_matrix_kernel_nn.c b/kernel/generic/gemm_small_matrix_kernel_nn.c index 71700a1fa..543e7e047 100644 --- a/kernel/generic/gemm_small_matrix_kernel_nn.c +++ b/kernel/generic/gemm_small_matrix_kernel_nn.c @@ -28,9 +28,9 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common.h" #ifdef B0 -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb,FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) #else -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) #endif { //naive implemtation diff --git a/kernel/generic/gemm_small_matrix_kernel_nt.c b/kernel/generic/gemm_small_matrix_kernel_nt.c index b287b3837..d4a7aec6a 100644 --- a/kernel/generic/gemm_small_matrix_kernel_nt.c +++ b/kernel/generic/gemm_small_matrix_kernel_nt.c @@ -28,9 +28,9 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common.h" #ifdef B0 -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) #else -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) #endif { //naive implemtation diff --git a/kernel/generic/gemm_small_matrix_kernel_tn.c b/kernel/generic/gemm_small_matrix_kernel_tn.c index c41ea7211..2747337f2 100644 --- a/kernel/generic/gemm_small_matrix_kernel_tn.c +++ b/kernel/generic/gemm_small_matrix_kernel_tn.c @@ -28,9 +28,9 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common.h" #ifdef B0 -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb,FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) #else -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) #endif { //naive implemtation diff --git a/kernel/generic/gemm_small_matrix_kernel_tt.c b/kernel/generic/gemm_small_matrix_kernel_tt.c index 734510c67..eec926bc7 100644 --- a/kernel/generic/gemm_small_matrix_kernel_tt.c +++ b/kernel/generic/gemm_small_matrix_kernel_tt.c @@ -28,9 +28,9 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common.h" #ifdef B0 -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) #else -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) #endif { //naive implemtation diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index f303d0dc6..19b7b5f0b 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -112,6 +112,11 @@ gotoblas_t TABLE_NAME = { #else NULL,NULL, #endif +#ifdef SMALL_MATRIX_OPT + sbgemm_small_matrix_permitTS, + sbgemm_small_kernel_nnTS, sbgemm_small_kernel_ntTS, sbgemm_small_kernel_tnTS, sbgemm_small_kernel_ttTS, + sbgemm_small_kernel_b0_nnTS, sbgemm_small_kernel_b0_ntTS, sbgemm_small_kernel_b0_tnTS, sbgemm_small_kernel_b0_ttTS, +#endif #endif #if ( BUILD_SINGLE==1) || (BUILD_DOUBLE==1) || (BUILD_COMPLEX==1) || (BUILD_COMPLEX16==1) diff --git a/kernel/x86_64/KERNEL.COOPERLAKE b/kernel/x86_64/KERNEL.COOPERLAKE index 0b2f3c0ed..6272dd73d 100644 --- a/kernel/x86_64/KERNEL.COOPERLAKE +++ b/kernel/x86_64/KERNEL.COOPERLAKE @@ -1 +1,11 @@ include $(KERNELDIR)/KERNEL.SKYLAKEX + +SBGEMM_SMALL_M_PERMIT = sbgemm_small_kernel_permit_cooperlake.c +SBGEMM_SMALL_K_NN = sbgemm_small_kernel_nn_cooperlake.c +SBGEMM_SMALL_K_B0_NN = sbgemm_small_kernel_nn_cooperlake.c +SBGEMM_SMALL_K_NT = sbgemm_small_kernel_nt_cooperlake.c +SBGEMM_SMALL_K_B0_NT = sbgemm_small_kernel_nt_cooperlake.c +SBGEMM_SMALL_K_TN = sbgemm_small_kernel_tn_cooperlake.c +SBGEMM_SMALL_K_B0_TN = sbgemm_small_kernel_tn_cooperlake.c +SBGEMM_SMALL_K_TT = sbgemm_small_kernel_tt_cooperlake.c +SBGEMM_SMALL_K_B0_TT = sbgemm_small_kernel_tt_cooperlake.c diff --git a/kernel/x86_64/sbgemm_block_microk_cooperlake.c b/kernel/x86_64/sbgemm_block_microk_cooperlake.c index 147c5ebdd..2c27221ac 100644 --- a/kernel/x86_64/sbgemm_block_microk_cooperlake.c +++ b/kernel/x86_64/sbgemm_block_microk_cooperlake.c @@ -1,6 +1,5 @@ -//#include "sbgemm.h" - #include + // Walk around those intrinsics that missed by compiler #define MM256_LOADU_EPI16(addr) \ _mm256_maskz_loadu_epi16(~0, (addr)) @@ -1747,7 +1746,7 @@ void COL_MAJOR_OTCOPY_KERNEL_Kx8m(BLASLONG k, BLASLONG n, bfloat16 * B, BLASLONG } // Scale matrix C when beta is not ZERO or ONE -void sbgemm_scal_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc) +void sbgemm_scal_operation(BLASLONG M, BLASLONG N, float beta, float *C, BLASLONG ldc) { float * C_addr0 = C; float * C_addr1 = C + ldc; @@ -1759,12 +1758,6 @@ void sbgemm_scal_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST __m512 array_512_0, array_512_1, array_512_2, array_512_3; __m512 BETAVECTOR = _mm512_set1_ps(beta); - if (Order == CblasRowMajor) { - blasint tmp = M; - M = N; - N = tmp; - } - BLASLONG tag_n_Nx = N & (~3); BLASLONG tag_n_Mx = M & (~15); unsigned short tail_mask = (((unsigned short)0xffff) >> (16-M+tag_n_Mx)); @@ -1828,7 +1821,7 @@ void sbgemm_scal_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST } // Zero C matrix when Beta is 0 -void sbgemm_zero_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, float *C, OPENBLAS_CONST blasint ldc) +void sbgemm_zero_operation(BLASLONG M, BLASLONG N, float *C, BLASLONG ldc) { float * C_addr0 = C; float * C_addr1 = C + ldc; @@ -1839,12 +1832,6 @@ void sbgemm_zero_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST __m512 ZEROVECTOR = _mm512_setzero_ps(); - if (Order == CblasRowMajor) { - blasint tmp = M; - M = N; - N = tmp; - } - BLASLONG tag_n_Nx = N & (~3); BLASLONG tag_n_Mx = M & (~15); unsigned short tail_mask = (((unsigned short)0xffff) >> (16-M+tag_n_Mx)); diff --git a/kernel/x86_64/sbgemm_microk_cooperlake_template.c b/kernel/x86_64/sbgemm_microk_cooperlake_template.c index c71595813..b8ed9838e 100644 --- a/kernel/x86_64/sbgemm_microk_cooperlake_template.c +++ b/kernel/x86_64/sbgemm_microk_cooperlake_template.c @@ -1,8 +1,6 @@ -#include "sbgemm.h" #include "bf16_common_macros.h" #include -/* These macros are needed and should be placed at the right place #define BF16_BLOCK_STEP_N 8 #define BF16_BLOCK_THRES_K 1024 #define BF16_BLOCK_THRES_M 32 @@ -14,7 +12,6 @@ #define ONE 1.e0f #define ZERO 0.e0f -*/ #undef STORE16_COMPLETE_RESULT #undef STORE16_MASK_COMPLETE_RESULT @@ -1798,6 +1795,7 @@ void sbgemm_blocking_kernel_tt_one(blasint M, blasint N, blasint K, float alpha, } /* ----------------------------------------- End of TT kernels --------------------------------------- */ +/* #ifndef ONE_ALPHA // ALPHA is not ONE void sbgemm_internal_kernel_alpha(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, float *C, OPENBLAS_CONST blasint ldc) @@ -1836,3 +1834,4 @@ void sbgemm_internal_kernel_one(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_ } } } +*/ diff --git a/kernel/x86_64/sbgemm_small_kernel_nn_cooperlake.c b/kernel/x86_64/sbgemm_small_kernel_nn_cooperlake.c new file mode 100644 index 000000000..ec40a5054 --- /dev/null +++ b/kernel/x86_64/sbgemm_small_kernel_nn_cooperlake.c @@ -0,0 +1,2 @@ +#define TRANS_NN +#include "sbgemm_small_kernel_template_cooperlake.c" diff --git a/kernel/x86_64/sbgemm_small_kernel_nt_cooperlake.c b/kernel/x86_64/sbgemm_small_kernel_nt_cooperlake.c new file mode 100644 index 000000000..1cdfd2936 --- /dev/null +++ b/kernel/x86_64/sbgemm_small_kernel_nt_cooperlake.c @@ -0,0 +1,2 @@ +#define TRANS_NT +#include "sbgemm_small_kernel_template_cooperlake.c" diff --git a/kernel/x86_64/sbgemm_small_kernel_permit_cooperlake.c b/kernel/x86_64/sbgemm_small_kernel_permit_cooperlake.c new file mode 100644 index 000000000..823aafbdd --- /dev/null +++ b/kernel/x86_64/sbgemm_small_kernel_permit_cooperlake.c @@ -0,0 +1,42 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +#include "sbgemm_block_microk_cooperlake.c" +// Define micro kernels for ALPHA not ONE scenarios +#undef ONE_ALPHA +#include "sbgemm_microk_cooperlake_template.c" + +// Define micro kernels for ALPHA as ONE scenarios +#define ONE_ALPHA 1 +#include "sbgemm_microk_cooperlake_template.c" + +int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta) +{ + return 1; +} diff --git a/kernel/x86_64/sbgemm_small_kernel_template_cooperlake.c b/kernel/x86_64/sbgemm_small_kernel_template_cooperlake.c new file mode 100644 index 000000000..1ab7a34ab --- /dev/null +++ b/kernel/x86_64/sbgemm_small_kernel_template_cooperlake.c @@ -0,0 +1,96 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" +#include + +extern void sbgemm_scal_operation(BLASLONG M, BLASLONG N, float beta, float *C, BLASLONG ldc); +extern void sbgemm_zero_operation(BLASLONG M, BLASLONG N, float *C, BLASLONG ldc); + +extern void sbgemm_blocking_kernel_nn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B); +extern void sbgemm_blocking_kernel_nn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B); +extern void sbgemm_blocking_kernel_nt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B); +extern void sbgemm_blocking_kernel_nt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B); +extern void sbgemm_blocking_kernel_tn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B); +extern void sbgemm_blocking_kernel_tn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B); +extern void sbgemm_blocking_kernel_tt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B); +extern void sbgemm_blocking_kernel_tt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B); + +#if defined(TRANS_NN) +#define SBGEMM_BLOCKING_KERNEL_ONE sbgemm_blocking_kernel_nn_one +#define SBGEMM_BLOCKING_KERNEL_ALPHA sbgemm_blocking_kernel_nn_alpha +#elif defined(TRANS_NT) +#define SBGEMM_BLOCKING_KERNEL_ONE sbgemm_blocking_kernel_nt_one +#define SBGEMM_BLOCKING_KERNEL_ALPHA sbgemm_blocking_kernel_nt_alpha +#elif defined(TRANS_TN) +#define SBGEMM_BLOCKING_KERNEL_ONE sbgemm_blocking_kernel_tn_one +#define SBGEMM_BLOCKING_KERNEL_ALPHA sbgemm_blocking_kernel_tn_alpha +#elif defined(TRANS_TT) +#define SBGEMM_BLOCKING_KERNEL_ONE sbgemm_blocking_kernel_tt_one +#define SBGEMM_BLOCKING_KERNEL_ALPHA sbgemm_blocking_kernel_tt_alpha +#endif + +#define BF16_BLOCK_THRES_K 1024 +// If we want to adjust this to be bigger, need to change COL_MAJOR_INCOPY_KERNEL_Kx32 kernel to be bigger also +#define BF16_BLOCK_THRES_M 32 +#define BF16_BLOCK_THRES_N 1024 + +#define MALLOC_ALIGN64(ptr, size, raw_ptr) \ + raw_ptr = malloc((size) + 63); \ + ptr = (bfloat16 *)(((uintptr_t) raw_ptr + 63) & ~(uintptr_t)63) + + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + bfloat16 * block_A; + bfloat16 * block_B; + void* raw_ptrA; + void* raw_ptrB; + + MALLOC_ALIGN64(block_A, sizeof(bfloat16) * BF16_BLOCK_THRES_K * BF16_BLOCK_THRES_M, raw_ptrA); + MALLOC_ALIGN64(block_B, sizeof(bfloat16) * BF16_BLOCK_THRES_N * BF16_BLOCK_THRES_K, raw_ptrB); + +#if defined(B0) + sbgemm_zero_operation(M, N, C, ldc); +#else + sbgemm_scal_operation(M, N, beta, C, ldc); +#endif + + if (alpha == ONE) { + SBGEMM_BLOCKING_KERNEL_ONE(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + } else { + SBGEMM_BLOCKING_KERNEL_ALPHA(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + } + + free(raw_ptrA); + free(raw_ptrB); + return 0; +} diff --git a/kernel/x86_64/sbgemm_small_kernel_tn_cooperlake.c b/kernel/x86_64/sbgemm_small_kernel_tn_cooperlake.c new file mode 100644 index 000000000..f1a0d0d0c --- /dev/null +++ b/kernel/x86_64/sbgemm_small_kernel_tn_cooperlake.c @@ -0,0 +1,2 @@ +#define TRANS_TN +#include "sbgemm_small_kernel_template_cooperlake.c" diff --git a/kernel/x86_64/sbgemm_small_kernel_tt_cooperlake.c b/kernel/x86_64/sbgemm_small_kernel_tt_cooperlake.c new file mode 100644 index 000000000..8a2a597bc --- /dev/null +++ b/kernel/x86_64/sbgemm_small_kernel_tt_cooperlake.c @@ -0,0 +1,2 @@ +#define TRANS_TT +#include "sbgemm_small_kernel_template_cooperlake.c"