RFC : Add half precision gemm for bfloat16 in OpenBLAS
This patch adds support for bfloat16 data type matrix multiplication kernel. For architectures that don't support bfloat16, it is defined as unsigned short (2 bytes). Default unroll sizes can be changed as per architecture as done for SGEMM and for now 8 and 4 are used for M and N. Size of ncopy/tcopy can be changed as per architecture requirement and for now, size 2 is used. Added shgemm in kernel/power/KERNEL.POWER9 and tested in powerpc64le and powerpc64. For reference, added a small test compare_sgemm_shgemm.c to compare sgemm and shgemm output. This patch does not cover OpenBLAS test, benchmark and lapack tests for shgemm. Complex type implementation can be discussed and added once this is approved.
This commit is contained in:
@@ -77,7 +77,7 @@
|
||||
#define GEMM_MULTITHREAD_THRESHOLD 4
|
||||
#endif
|
||||
|
||||
static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLASLONG) = {
|
||||
static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = {
|
||||
#ifndef GEMM3M
|
||||
GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
|
||||
GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
|
||||
@@ -108,8 +108,8 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLA
|
||||
void NAME(char *TRANSA, char *TRANSB,
|
||||
blasint *M, blasint *N, blasint *K,
|
||||
FLOAT *alpha,
|
||||
FLOAT *a, blasint *ldA,
|
||||
FLOAT *b, blasint *ldB,
|
||||
IFLOAT *a, blasint *ldA,
|
||||
IFLOAT *b, blasint *ldB,
|
||||
FLOAT *beta,
|
||||
FLOAT *c, blasint *ldC){
|
||||
|
||||
@@ -119,8 +119,8 @@ void NAME(char *TRANSA, char *TRANSB,
|
||||
blasint info;
|
||||
|
||||
char transA, transB;
|
||||
FLOAT *buffer;
|
||||
FLOAT *sa, *sb;
|
||||
IFLOAT *buffer;
|
||||
IFLOAT *sa, *sb;
|
||||
|
||||
#ifdef SMP
|
||||
double MNK;
|
||||
|
||||
Reference in New Issue
Block a user