Merge pull request #3353 from guowangy/bf16-small-matrix-cooperlake

Enable existing SBGEMM kernel for Cooperlake by small-matrix path
This commit is contained in:
Martin Kroeker 2021-08-30 20:39:51 +02:00 committed by GitHub
commit 806221440b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 302 additions and 41 deletions

View File

@ -400,6 +400,8 @@ void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPE
float cblas_sbdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum CBLAS_TRANSPOSE trans, OPENBLAS_CONST blasint m, OPENBLAS_CONST blasint n, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *a, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float beta, float *y, OPENBLAS_CONST blasint incy);
void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
#ifdef __cplusplus
}
#endif /* __cplusplus */

View File

@ -516,6 +516,13 @@ int qgemm_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, xdouble *, xd
#endif
#ifdef SMALL_MATRIX_OPT
int sbgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta);
int sbgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
int sbgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
int sbgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
int sbgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
int sgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta);
int sgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
@ -530,6 +537,11 @@ int dgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLO
int dgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc);
int dgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc);
int sbgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc);
int sbgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc);
int sbgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc);
int sbgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc);
int sgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
int sgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
int sgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc);

View File

@ -942,17 +942,17 @@
#define GEADD_K SGEADD_K
#define GEMM_SMALL_MATRIX_PERMIT SGEMM_SMALL_MATRIX_PERMIT
#define GEMM_SMALL_MATRIX_PERMIT SBGEMM_SMALL_MATRIX_PERMIT
#define GEMM_SMALL_KERNEL_NN SGEMM_SMALL_KERNEL_NN
#define GEMM_SMALL_KERNEL_NT SGEMM_SMALL_KERNEL_NT
#define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN
#define GEMM_SMALL_KERNEL_TT SGEMM_SMALL_KERNEL_TT
#define GEMM_SMALL_KERNEL_NN SBGEMM_SMALL_KERNEL_NN
#define GEMM_SMALL_KERNEL_NT SBGEMM_SMALL_KERNEL_NT
#define GEMM_SMALL_KERNEL_TN SBGEMM_SMALL_KERNEL_TN
#define GEMM_SMALL_KERNEL_TT SBGEMM_SMALL_KERNEL_TT
#define GEMM_SMALL_KERNEL_B0_NN SGEMM_SMALL_KERNEL_B0_NN
#define GEMM_SMALL_KERNEL_B0_NT SGEMM_SMALL_KERNEL_B0_NT
#define GEMM_SMALL_KERNEL_B0_TN SGEMM_SMALL_KERNEL_B0_TN
#define GEMM_SMALL_KERNEL_B0_TT SGEMM_SMALL_KERNEL_B0_TT
#define GEMM_SMALL_KERNEL_B0_NN SBGEMM_SMALL_KERNEL_B0_NN
#define GEMM_SMALL_KERNEL_B0_NT SBGEMM_SMALL_KERNEL_B0_NT
#define GEMM_SMALL_KERNEL_B0_TN SBGEMM_SMALL_KERNEL_B0_TN
#define GEMM_SMALL_KERNEL_B0_TT SBGEMM_SMALL_KERNEL_B0_TT
#endif

View File

@ -145,6 +145,19 @@ BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);
int (*sbneg_tcopy) (BLASLONG, BLASLONG, float *, BLASLONG, float *);
int (*sblaswp_ncopy) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG, blasint *, float *);
#ifdef SMALL_MATRIX_OPT
int (*sbgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta);
int (*sbgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
int (*sbgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
int (*sbgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
int (*sbgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
int (*sbgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc);
int (*sbgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc);
int (*sbgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc);
int (*sbgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc);
#endif
#endif
#if defined(BUILD_SINGLE) || defined(BUILD_COMPLEX)

View File

@ -24,6 +24,7 @@
#define SBGEMM_BETA sbgemm_beta
#define SBGEMM_KERNEL sbgemm_kernel
#define SBGEMM_SMALL_MATRIX_PERMIT sbgemm_small_matrix_permit
#else
#define SBDOT_K gotoblas -> sbdot_k
@ -41,8 +42,19 @@
#define SBGEMM_BETA gotoblas -> sbgemm_beta
#define SBGEMM_KERNEL gotoblas -> sbgemm_kernel
#define SBGEMM_SMALL_MATRIX_PERMIT gotoblas -> sbgemm_small_matrix_permit
#endif
#define SBGEMM_SMALL_KERNEL_NN FUNC_OFFSET(sbgemm_small_kernel_nn)
#define SBGEMM_SMALL_KERNEL_NT FUNC_OFFSET(sbgemm_small_kernel_nt)
#define SBGEMM_SMALL_KERNEL_TN FUNC_OFFSET(sbgemm_small_kernel_tn)
#define SBGEMM_SMALL_KERNEL_TT FUNC_OFFSET(sbgemm_small_kernel_tt)
#define SBGEMM_SMALL_KERNEL_B0_NN FUNC_OFFSET(sbgemm_small_kernel_b0_nn)
#define SBGEMM_SMALL_KERNEL_B0_NT FUNC_OFFSET(sbgemm_small_kernel_b0_nt)
#define SBGEMM_SMALL_KERNEL_B0_TN FUNC_OFFSET(sbgemm_small_kernel_b0_tn)
#define SBGEMM_SMALL_KERNEL_B0_TT FUNC_OFFSET(sbgemm_small_kernel_b0_tt)
#define SBGEMM_NN sbgemm_nn
#define SBGEMM_CN sbgemm_tn
#define SBGEMM_TN sbgemm_tn

View File

@ -105,7 +105,7 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, B
#endif
};
#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE) && !defined(BFLOAT16)
#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE)
#define USE_SMALL_MATRIX_OPT 1
#else
#define USE_SMALL_MATRIX_OPT 0
@ -131,8 +131,8 @@ static size_t gemm_small_kernel_b0[] = {
GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, 0, 0,
};
#define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx))
#define GEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx))
#define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx))
#define GEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx))
#else
static size_t zgemm_small_kernel[] = {
@ -273,8 +273,8 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
blasint m, blasint n, blasint k,
#ifndef COMPLEX
FLOAT alpha,
FLOAT *a, blasint lda,
FLOAT *b, blasint ldb,
IFLOAT *a, blasint lda,
IFLOAT *b, blasint ldb,
FLOAT beta,
FLOAT *c, blasint ldc) {
#else

View File

@ -450,6 +450,15 @@ endif
###### BLAS small matrix optimization #####
ifeq ($(SMALL_MATRIX_OPT), 1)
ifeq ($(BUILD_BFLOAT16),1)
SBBLASOBJS += \
sbgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \
sbgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \
sbgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \
sbgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \
sbgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX)
endif
SBLASOBJS += \
sgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \
sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \
@ -4424,6 +4433,72 @@ $(KDIR)sgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL
$(KDIR)sgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_TT)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@
ifeq ($(BUILD_BFLOAT16), 1)
ifndef SBGEMM_SMALL_M_PERMIT
SBGEMM_SMALL_M_PERMIT = ../generic/gemm_small_matrix_permit.c
endif
ifndef SBGEMM_SMALL_K_NN
SBGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c
endif
ifndef SBGEMM_SMALL_K_NT
SBGEMM_SMALL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c
endif
ifndef SBGEMM_SMALL_K_TN
SBGEMM_SMALL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c
endif
ifndef SBGEMM_SMALL_K_TT
SBGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c
endif
$(KDIR)sbgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_M_PERMIT)
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)sbgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_NN)
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)sbgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_NT)
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)sbgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_TN)
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)sbgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_TT)
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
ifndef SBGEMM_SMALL_K_B0_NN
SBGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_nn.c
endif
ifndef SBGEMM_SMALL_K_B0_NT
SBGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_nt.c
endif
ifndef SBGEMM_SMALL_K_B0_TN
SBGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_tn.c
endif
ifndef SBGEMM_SMALL_K_B0_TT
SBGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_tt.c
endif
$(KDIR)sbgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_B0_NN)
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@
$(KDIR)sbgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_B0_NT)
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@
$(KDIR)sbgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_B0_TN)
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@
$(KDIR)sbgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_B0_TT)
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@
endif
ifndef CGEMM_SMALL_M_PERMIT
CGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c
endif

View File

@ -28,9 +28,9 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "common.h"
#ifdef B0
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb,FLOAT * C, BLASLONG ldc)
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc)
#else
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
#endif
{
//naive implementation

View File

@ -28,9 +28,9 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "common.h"
#ifdef B0
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc)
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc)
#else
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
#endif
{
//naive implementation

View File

@ -28,9 +28,9 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "common.h"
#ifdef B0
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb,FLOAT * C, BLASLONG ldc)
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc)
#else
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
#endif
{
//naive implementation

View File

@ -28,9 +28,9 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "common.h"
#ifdef B0
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc)
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc)
#else
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
#endif
{
//naive implementation

View File

@ -112,6 +112,11 @@ gotoblas_t TABLE_NAME = {
#else
NULL,NULL,
#endif
#ifdef SMALL_MATRIX_OPT
sbgemm_small_matrix_permitTS,
sbgemm_small_kernel_nnTS, sbgemm_small_kernel_ntTS, sbgemm_small_kernel_tnTS, sbgemm_small_kernel_ttTS,
sbgemm_small_kernel_b0_nnTS, sbgemm_small_kernel_b0_ntTS, sbgemm_small_kernel_b0_tnTS, sbgemm_small_kernel_b0_ttTS,
#endif
#endif
#if ( BUILD_SINGLE==1) || (BUILD_DOUBLE==1) || (BUILD_COMPLEX==1) || (BUILD_COMPLEX16==1)

View File

@ -1 +1,11 @@
include $(KERNELDIR)/KERNEL.SKYLAKEX
SBGEMM_SMALL_M_PERMIT = sbgemm_small_kernel_permit_cooperlake.c
SBGEMM_SMALL_K_NN = sbgemm_small_kernel_nn_cooperlake.c
SBGEMM_SMALL_K_B0_NN = sbgemm_small_kernel_nn_cooperlake.c
SBGEMM_SMALL_K_NT = sbgemm_small_kernel_nt_cooperlake.c
SBGEMM_SMALL_K_B0_NT = sbgemm_small_kernel_nt_cooperlake.c
SBGEMM_SMALL_K_TN = sbgemm_small_kernel_tn_cooperlake.c
SBGEMM_SMALL_K_B0_TN = sbgemm_small_kernel_tn_cooperlake.c
SBGEMM_SMALL_K_TT = sbgemm_small_kernel_tt_cooperlake.c
SBGEMM_SMALL_K_B0_TT = sbgemm_small_kernel_tt_cooperlake.c

View File

@ -1,6 +1,5 @@
//#include "sbgemm.h"
#include <immintrin.h>
// Work around intrinsics that are missing from some compilers
#define MM256_LOADU_EPI16(addr) \
_mm256_maskz_loadu_epi16(~0, (addr))
@ -1747,7 +1746,7 @@ void COL_MAJOR_OTCOPY_KERNEL_Kx8m(BLASLONG k, BLASLONG n, bfloat16 * B, BLASLONG
}
// Scale matrix C when beta is not ZERO or ONE
void sbgemm_scal_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc)
void sbgemm_scal_operation(BLASLONG M, BLASLONG N, float beta, float *C, BLASLONG ldc)
{
float * C_addr0 = C;
float * C_addr1 = C + ldc;
@ -1759,12 +1758,6 @@ void sbgemm_scal_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST
__m512 array_512_0, array_512_1, array_512_2, array_512_3;
__m512 BETAVECTOR = _mm512_set1_ps(beta);
if (Order == CblasRowMajor) {
blasint tmp = M;
M = N;
N = tmp;
}
BLASLONG tag_n_Nx = N & (~3);
BLASLONG tag_n_Mx = M & (~15);
unsigned short tail_mask = (((unsigned short)0xffff) >> (16-M+tag_n_Mx));
@ -1828,7 +1821,7 @@ void sbgemm_scal_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST
}
// Zero C matrix when Beta is 0
void sbgemm_zero_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, float *C, OPENBLAS_CONST blasint ldc)
void sbgemm_zero_operation(BLASLONG M, BLASLONG N, float *C, BLASLONG ldc)
{
float * C_addr0 = C;
float * C_addr1 = C + ldc;
@ -1839,12 +1832,6 @@ void sbgemm_zero_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST
__m512 ZEROVECTOR = _mm512_setzero_ps();
if (Order == CblasRowMajor) {
blasint tmp = M;
M = N;
N = tmp;
}
BLASLONG tag_n_Nx = N & (~3);
BLASLONG tag_n_Mx = M & (~15);
unsigned short tail_mask = (((unsigned short)0xffff) >> (16-M+tag_n_Mx));

View File

@ -1,8 +1,6 @@
#include "sbgemm.h"
#include "bf16_common_macros.h"
#include <immintrin.h>
/* These macros are needed and should be placed at the right place
#define BF16_BLOCK_STEP_N 8
#define BF16_BLOCK_THRES_K 1024
#define BF16_BLOCK_THRES_M 32
@ -14,7 +12,6 @@
#define ONE 1.e0f
#define ZERO 0.e0f
*/
#undef STORE16_COMPLETE_RESULT
#undef STORE16_MASK_COMPLETE_RESULT
@ -1798,6 +1795,7 @@ void sbgemm_blocking_kernel_tt_one(blasint M, blasint N, blasint K, float alpha,
}
/* ----------------------------------------- End of TT kernels --------------------------------------- */
/*
#ifndef ONE_ALPHA // ALPHA is not ONE
void sbgemm_internal_kernel_alpha(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, float *C, OPENBLAS_CONST blasint ldc)
@ -1836,3 +1834,4 @@ void sbgemm_internal_kernel_one(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_
}
}
}
*/

View File

@ -0,0 +1,2 @@
// NN variant: TRANS_NN is set before pulling in the shared small-kernel
// template so that the template is compiled for the non-transposed A /
// non-transposed B case.
#define TRANS_NN
#include "sbgemm_small_kernel_template_cooperlake.c"

View File

@ -0,0 +1,2 @@
// NT variant: TRANS_NT is set before pulling in the shared small-kernel
// template so that the template is compiled for the NT transpose combination.
#define TRANS_NT
#include "sbgemm_small_kernel_template_cooperlake.c"

View File

@ -0,0 +1,42 @@
/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include "common.h"
#include "sbgemm_block_microk_cooperlake.c"
// Define micro kernels for ALPHA not ONE scenarios
#undef ONE_ALPHA
#include "sbgemm_microk_cooperlake_template.c"
// Define micro kernels for ALPHA as ONE scenarios
#define ONE_ALPHA 1
#include "sbgemm_microk_cooperlake_template.c"
/*
 * Small-matrix permit hook for the Cooper Lake SBGEMM kernels.
 *
 * Every argument (transpose flags, problem sizes M/N/K, alpha, beta) is
 * accepted but ignored: the function answers 1 unconditionally.
 * NOTE(review): presumably a non-zero answer tells the GEMM interface to take
 * the small-matrix kernel path - confirm against the caller.
 */
int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta)
{
	/* None of the inputs takes part in the decision. */
	(void)transa;
	(void)transb;
	(void)M;
	(void)N;
	(void)K;
	(void)alpha;
	(void)beta;

	const int permit = 1;
	return permit;
}

View File

@ -0,0 +1,96 @@
/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include "common.h"
#include <memory.h>
// Prototypes of helpers implemented outside this translation unit.

// Pre-processing of C: scale it by beta, or overwrite it with zeros.
extern void sbgemm_scal_operation(BLASLONG M, BLASLONG N, float beta, float *C, BLASLONG ldc);
extern void sbgemm_zero_operation(BLASLONG M, BLASLONG N, float *C, BLASLONG ldc);

// Blocking kernels, one pair per transpose combination (nn / nt / tn / tt).
// The "_one" flavour is called when alpha == ONE, the "_alpha" flavour otherwise.
// block_A and block_B are caller-provided scratch buffers.
// NOTE(review): sizes here are blasint while the entry point below receives
// BLASLONG - confirm the two types have the same width in every build.
extern void sbgemm_blocking_kernel_nn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
extern void sbgemm_blocking_kernel_nn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
extern void sbgemm_blocking_kernel_nt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
extern void sbgemm_blocking_kernel_nt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
extern void sbgemm_blocking_kernel_tn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
extern void sbgemm_blocking_kernel_tn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
extern void sbgemm_blocking_kernel_tt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
extern void sbgemm_blocking_kernel_tt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
// Bind the two generic kernel names used by the entry point to the blocking
// kernels of the transpose combination this file is being compiled for.
// Exactly one of TRANS_NN / TRANS_NT / TRANS_TN / TRANS_TT is expected to be
// defined by the file that includes this template; if none is defined, both
// names stay undefined.
#if defined(TRANS_NN)
#define SBGEMM_BLOCKING_KERNEL_ONE sbgemm_blocking_kernel_nn_one
#define SBGEMM_BLOCKING_KERNEL_ALPHA sbgemm_blocking_kernel_nn_alpha
#elif defined(TRANS_NT)
#define SBGEMM_BLOCKING_KERNEL_ONE sbgemm_blocking_kernel_nt_one
#define SBGEMM_BLOCKING_KERNEL_ALPHA sbgemm_blocking_kernel_nt_alpha
#elif defined(TRANS_TN)
#define SBGEMM_BLOCKING_KERNEL_ONE sbgemm_blocking_kernel_tn_one
#define SBGEMM_BLOCKING_KERNEL_ALPHA sbgemm_blocking_kernel_tn_alpha
#elif defined(TRANS_TT)
#define SBGEMM_BLOCKING_KERNEL_ONE sbgemm_blocking_kernel_tt_one
#define SBGEMM_BLOCKING_KERNEL_ALPHA sbgemm_blocking_kernel_tt_alpha
#endif
// Block thresholds. In this file they only determine the size of the scratch
// buffers handed to the blocking kernels: block_A holds
// BF16_BLOCK_THRES_K * BF16_BLOCK_THRES_M bfloat16 elements and block_B holds
// BF16_BLOCK_THRES_N * BF16_BLOCK_THRES_K.
#define BF16_BLOCK_THRES_K 1024
// If this is to be made bigger, the COL_MAJOR_INCOPY_KERNEL_Kx32 kernel needs to be made bigger as well
#define BF16_BLOCK_THRES_M 32
#define BF16_BLOCK_THRES_N 1024
/*
 * MALLOC_ALIGN64(ptr, size, raw_ptr)
 *
 * Allocates (size) + 63 bytes with malloc() and derives a 64-byte aligned
 * bfloat16 pointer inside that allocation.
 *
 *   ptr     - lvalue receiving the aligned pointer (bfloat16 *)
 *   size    - number of usable bytes required starting at ptr
 *   raw_ptr - lvalue receiving the pointer returned by malloc(); this, not
 *             ptr, is what must later be passed to free()
 *
 * Allocation failure is not handled here: raw_ptr is NULL in that case and
 * the caller has to check it before using ptr.
 *
 * The body is wrapped in do { } while (0) so the macro expands to a single
 * statement and stays correct under an unbraced if/else.
 */
#define MALLOC_ALIGN64(ptr, size, raw_ptr) \
	do { \
		(raw_ptr) = malloc((size) + 63); \
		(ptr) = (bfloat16 *)(((uintptr_t)(raw_ptr) + 63) & ~(uintptr_t)63); \
	} while (0)
/*
 * Small-matrix SBGEMM entry point built on the Cooper Lake blocking kernels.
 *
 * The transpose combination is fixed at compile time through
 * SBGEMM_BLOCKING_KERNEL_ONE / SBGEMM_BLOCKING_KERNEL_ALPHA; the B0 build
 * omits the beta parameter.
 *
 * Steps:
 *   1. allocate two 64-byte aligned scratch buffers for the blocking kernel;
 *   2. prepare C: overwrite it with zeros (B0 build) or scale it by beta;
 *   3. call the "_one" kernel when alpha == ONE, the "_alpha" kernel otherwise
 *      (presumably the kernel then accumulates the product into C - confirm
 *      against the kernel sources);
 *   4. release the scratch buffers.
 *
 * Returns 0 on success. Returns 1 without touching C when a scratch buffer
 * cannot be allocated; previously that case handed a NULL buffer to the kernel.
 */
#if defined(B0)
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc)
#else
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
#endif
{
	bfloat16 * block_A;
	bfloat16 * block_B;
	void* raw_ptrA;
	void* raw_ptrB;

	MALLOC_ALIGN64(block_A, sizeof(bfloat16) * BF16_BLOCK_THRES_K * BF16_BLOCK_THRES_M, raw_ptrA);
	MALLOC_ALIGN64(block_B, sizeof(bfloat16) * BF16_BLOCK_THRES_N * BF16_BLOCK_THRES_K, raw_ptrB);

	/* Bail out before C is modified if either allocation failed.
	   free(NULL) is a no-op, so both pointers can be released unconditionally. */
	if (raw_ptrA == NULL || raw_ptrB == NULL) {
		free(raw_ptrA);
		free(raw_ptrB);
		return 1;
	}

#if defined(B0)
	sbgemm_zero_operation(M, N, C, ldc);
#else
	sbgemm_scal_operation(M, N, beta, C, ldc);
#endif

	if (alpha == ONE) {
		SBGEMM_BLOCKING_KERNEL_ONE(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
	} else {
		SBGEMM_BLOCKING_KERNEL_ALPHA(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
	}

	/* The raw malloc() pointers, not the aligned ones, are what gets freed. */
	free(raw_ptrA);
	free(raw_ptrB);
	return 0;
}

View File

@ -0,0 +1,2 @@
// TN variant: TRANS_TN is set before pulling in the shared small-kernel
// template so that the template is compiled for the TN transpose combination.
#define TRANS_TN
#include "sbgemm_small_kernel_template_cooperlake.c"

View File

@ -0,0 +1,2 @@
// TT variant: TRANS_TT is set before pulling in the shared small-kernel
// template so that the template is compiled for the TT transpose combination.
#define TRANS_TT
#include "sbgemm_small_kernel_template_cooperlake.c"