RFC : Add half precision gemm for bfloat16 in OpenBLAS
This patch adds support for bfloat16 data type matrix multiplication kernel. For architectures that don't support bfloat16, it is defined as unsigned short (2 bytes). Default unroll sizes can be changed as per architecture as done for SGEMM and for now 8 and 4 are used for M and N. Size of ncopy/tcopy can be changed as per architecture requirement and for now, size 2 is used. Added shgemm in kernel/power/KERNEL.POWER9 and tested in powerpc64le and powerpc64. For reference, added a small test compare_sgemm_shgemm.c to compare sgemm and shgemm output. This patch does not cover OpenBLAS test, benchmark and lapack tests for shgemm. Complex type implementation can be discussed and added once this is approved.
This commit is contained in:
@@ -19,6 +19,7 @@ ifeq ($(ARCH), MIPS)
|
||||
USE_GEMM3M = 1
|
||||
endif
|
||||
|
||||
SHBLASOBJS += shgemm_nn.$(SUFFIX) shgemm_nt.$(SUFFIX) shgemm_tn.$(SUFFIX) shgemm_tt.$(SUFFIX)
|
||||
SBLASOBJS += \
|
||||
sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \
|
||||
strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \
|
||||
@@ -204,6 +205,7 @@ COMMONOBJS += syrk_thread.$(SUFFIX)
|
||||
|
||||
ifndef USE_SIMPLE_THREADED_LEVEL3
|
||||
|
||||
SHBLASOBJS += shgemm_thread_nn.$(SUFFIX) shgemm_thread_nt.$(SUFFIX) shgemm_thread_tn.$(SUFFIX) shgemm_thread_tt.$(SUFFIX)
|
||||
SBLASOBJS += sgemm_thread_nn.$(SUFFIX) sgemm_thread_nt.$(SUFFIX) sgemm_thread_tn.$(SUFFIX) sgemm_thread_tt.$(SUFFIX)
|
||||
DBLASOBJS += dgemm_thread_nn.$(SUFFIX) dgemm_thread_nt.$(SUFFIX) dgemm_thread_tn.$(SUFFIX) dgemm_thread_tt.$(SUFFIX)
|
||||
QBLASOBJS += qgemm_thread_nn.$(SUFFIX) qgemm_thread_nt.$(SUFFIX) qgemm_thread_tn.$(SUFFIX) qgemm_thread_tt.$(SUFFIX)
|
||||
@@ -283,6 +285,18 @@ endif
|
||||
|
||||
all ::
|
||||
|
||||
shgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
|
||||
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
|
||||
|
||||
shgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h
|
||||
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
|
||||
|
||||
shgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h
|
||||
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
|
||||
|
||||
shgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h
|
||||
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
|
||||
|
||||
sgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
|
||||
$(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
|
||||
|
||||
@@ -478,6 +492,17 @@ gemm_thread_variable.$(SUFFIX) : gemm_thread_variable.c ../../common.h
|
||||
beta_thread.$(SUFFIX) : beta_thread.c ../../common.h
|
||||
$(CC) -c $(CFLAGS) $< -o $(@F)
|
||||
|
||||
shgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
|
||||
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
|
||||
|
||||
shgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
|
||||
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
|
||||
|
||||
shgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
|
||||
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
|
||||
|
||||
shgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
|
||||
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
|
||||
|
||||
sgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
|
||||
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
|
||||
@@ -2652,6 +2677,18 @@ xtrsm_RCLU.$(SUFFIX) : trsm_R.c
|
||||
xtrsm_RCLN.$(SUFFIX) : trsm_R.c
|
||||
$(CC) -c $(CFLAGS) -DCOMPLEX -DXDOUBLE -DTRANSA -UUPPER -UUNIT -DCONJ $< -o $(@F)
|
||||
|
||||
shgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
|
||||
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
|
||||
|
||||
shgemm_nt.$(PSUFFIX) : gemm.c level3.c ../../param.h
|
||||
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
|
||||
|
||||
shgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h
|
||||
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
|
||||
|
||||
shgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h
|
||||
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
|
||||
|
||||
sgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
|
||||
$(CC) $(PFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
|
||||
|
||||
@@ -2848,6 +2885,18 @@ beta_thread.$(PSUFFIX) : beta_thread.c ../../common.h
|
||||
$(CC) -c $(PFLAGS) $< -o $(@F)
|
||||
|
||||
|
||||
shgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
|
||||
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
|
||||
|
||||
shgemm_thread_nt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
|
||||
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
|
||||
|
||||
shgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
|
||||
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
|
||||
|
||||
shgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
|
||||
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
|
||||
|
||||
sgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
|
||||
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
|
||||
|
||||
|
||||
@@ -62,18 +62,18 @@
|
||||
#ifndef ICOPY_OPERATION
|
||||
#if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \
|
||||
defined(RN) || defined(RT) || defined(RC) || defined(RR)
|
||||
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#else
|
||||
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef OCOPY_OPERATION
|
||||
#if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \
|
||||
defined(NR) || defined(TR) || defined(CR) || defined(RR)
|
||||
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#else
|
||||
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -173,7 +173,8 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
|
||||
XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){
|
||||
BLASLONG k, lda, ldb, ldc;
|
||||
FLOAT *alpha, *beta;
|
||||
FLOAT *a, *b, *c;
|
||||
IFLOAT *a, *b;
|
||||
FLOAT *c;
|
||||
BLASLONG m_from, m_to, n_from, n_to;
|
||||
|
||||
BLASLONG ls, is, js;
|
||||
@@ -198,8 +199,8 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
|
||||
|
||||
k = K;
|
||||
|
||||
a = (FLOAT *)A;
|
||||
b = (FLOAT *)B;
|
||||
a = (IFLOAT *)A;
|
||||
b = (IFLOAT *)B;
|
||||
c = (FLOAT *)C;
|
||||
|
||||
lda = LDA;
|
||||
|
||||
@@ -117,18 +117,18 @@ typedef struct {
|
||||
#ifndef ICOPY_OPERATION
|
||||
#if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \
|
||||
defined(RN) || defined(RT) || defined(RC) || defined(RR)
|
||||
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#else
|
||||
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef OCOPY_OPERATION
|
||||
#if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \
|
||||
defined(NR) || defined(TR) || defined(CR) || defined(RR)
|
||||
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#else
|
||||
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -219,15 +219,16 @@ typedef struct {
|
||||
#define STOP_RPCC(COUNTER)
|
||||
#endif
|
||||
|
||||
static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){
|
||||
static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){
|
||||
|
||||
FLOAT *buffer[DIVIDE_RATE];
|
||||
IFLOAT *buffer[DIVIDE_RATE];
|
||||
|
||||
BLASLONG k, lda, ldb, ldc;
|
||||
BLASLONG m_from, m_to, n_from, n_to;
|
||||
|
||||
FLOAT *alpha, *beta;
|
||||
FLOAT *a, *b, *c;
|
||||
IFLOAT *a, *b;
|
||||
FLOAT *c;
|
||||
job_t *job = (job_t *)args -> common;
|
||||
|
||||
BLASLONG nthreads_m;
|
||||
@@ -255,8 +256,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
|
||||
|
||||
k = K;
|
||||
|
||||
a = (FLOAT *)A;
|
||||
b = (FLOAT *)B;
|
||||
a = (IFLOAT *)A;
|
||||
b = (IFLOAT *)B;
|
||||
c = (FLOAT *)C;
|
||||
|
||||
lda = LDA;
|
||||
@@ -425,7 +426,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
|
||||
/* Apply kernel with local region of A and part of other region of B */
|
||||
START_RPCC();
|
||||
KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - js, div_n), min_l, alpha,
|
||||
sa, (FLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
|
||||
sa, (IFLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
|
||||
c, ldc, m_from, js);
|
||||
STOP_RPCC(kernel);
|
||||
|
||||
@@ -469,7 +470,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
|
||||
/* Apply kernel with local region of A and part of region of B */
|
||||
START_RPCC();
|
||||
KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - js, div_n), min_l, alpha,
|
||||
sa, (FLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
|
||||
sa, (IFLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
|
||||
c, ldc, is, js);
|
||||
STOP_RPCC(kernel);
|
||||
|
||||
@@ -532,7 +533,7 @@ static int round_up(int remainder, int width, int multiple)
|
||||
|
||||
|
||||
static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
|
||||
*range_n, FLOAT *sa, FLOAT *sb,
|
||||
*range_n, IFLOAT *sa, IFLOAT *sb,
|
||||
BLASLONG nthreads_m, BLASLONG nthreads_n) {
|
||||
|
||||
#ifndef USE_OPENMP
|
||||
@@ -728,7 +729,7 @@ EnterCriticalSection((PCRITICAL_SECTION)&level3_lock);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){
|
||||
int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){
|
||||
|
||||
BLASLONG m = args -> m;
|
||||
BLASLONG n = args -> n;
|
||||
|
||||
Reference in New Issue
Block a user