Merge pull request #2796 from Guobing-Chen/BF16_dot_coversion_apis
Add bfloat16 based dot and conversion with single/double
This commit is contained in:
commit
91c84e1c01
|
@ -5,13 +5,14 @@ QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
|
|||
CBLASOBJS_P = $(CBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
|
||||
ZBLASOBJS_P = $(ZBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
|
||||
XBLASOBJS_P = $(XBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
|
||||
SHEXTOBJS_P = $(SHEXTOBJS:.$(SUFFIX)=.$(PSUFFIX))
|
||||
|
||||
COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))
|
||||
|
||||
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))
|
||||
|
||||
BLASOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
|
||||
BLASOBJS_P = $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
|
||||
BLASOBJS = $(SHEXTOBJS) $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
|
||||
BLASOBJS_P = $(SHEXTOBJS_P) $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
|
||||
|
||||
ifdef EXPRECISION
|
||||
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
|
||||
|
@ -30,6 +31,7 @@ $(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX
|
|||
$(CBLASOBJS) $(CBLASOBJS_P) : override CFLAGS += -UDOUBLE -DCOMPLEX
|
||||
$(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX
|
||||
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
|
||||
$(SHEXTOBJS) $(SHEXTOBJS_P) : override CFLAGS += -DHALF -UDOUBLE -UCOMPLEX
|
||||
|
||||
$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
|
||||
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
|
||||
|
@ -38,6 +40,7 @@ $(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
|
|||
$(CBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
|
||||
$(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
|
||||
$(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
|
||||
$(SHEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
|
||||
|
||||
libs :: $(BLASOBJS) $(COMMONOBJS)
|
||||
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
|
||||
|
|
11
cblas.h
11
cblas.h
|
@ -382,6 +382,17 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint
|
|||
void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta,
|
||||
double *c, OPENBLAS_CONST blasint cldc);
|
||||
|
||||
/*** BFLOAT16 and INT8 extensions ***/
|
||||
/* convert float array to BFLOAT16 array by rounding */
|
||||
void cblas_shstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout);
|
||||
/* convert double array to BFLOAT16 array by rounding */
|
||||
void cblas_shdtobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST double *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout);
|
||||
/* convert BFLOAT16 array to float array */
|
||||
void cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, float *out, OPENBLAS_CONST blasint incout);
|
||||
/* convert BFLOAT16 array to double array */
|
||||
void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout);
|
||||
/* dot production of BFLOAT16 input arrays, and output as float */
|
||||
float cblas_shdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -126,12 +126,14 @@ if (BUILD_HALF)
|
|||
set(SHAXPYKERNEL ../arm/axpy.c)
|
||||
set(SHAXPBYKERNEL ../arm/axpby.c)
|
||||
set(SHCOPYKERNEL ../arm/copy.c)
|
||||
set(SHDOTKERNEL ../arm/dot.c)
|
||||
set(SHDOTKERNEL ../x86_64/shdot.c)
|
||||
set(SHROTKERNEL ../arm/rot.c)
|
||||
set(SHSCALKERNEL ../arm/scal.c)
|
||||
set(SHNRM2KERNEL ../arm/nrm2.c)
|
||||
set(SHSUMKERNEL ../arm/sum.c)
|
||||
set(SHSWAPKERNEL ../arm/swap.c)
|
||||
set(TOBF16KERNEL ../x86_64/tobf16.c)
|
||||
set(BF16TOKERNEL ../x86_64/bf16to.c)
|
||||
endif ()
|
||||
endmacro ()
|
||||
|
||||
|
|
3
common.h
3
common.h
|
@ -258,7 +258,8 @@ typedef unsigned long BLASULONG;
|
|||
#endif
|
||||
|
||||
#ifndef BFLOAT16
|
||||
typedef unsigned short bfloat16;
|
||||
#include <stdint.h>
|
||||
typedef uint16_t bfloat16;
|
||||
#define HALFCONVERSION 1
|
||||
#endif
|
||||
|
||||
|
|
|
@ -54,6 +54,11 @@ double BLASFUNC(dsdot) (blasint *, float *, blasint *, float *, blasint *);
|
|||
double BLASFUNC(ddot) (blasint *, double *, blasint *, double *, blasint *);
|
||||
xdouble BLASFUNC(qdot) (blasint *, xdouble *, blasint *, xdouble *, blasint *);
|
||||
|
||||
float BLASFUNC(shdot) (blasint *, bfloat16 *, blasint *, bfloat16 *, blasint *);
|
||||
void BLASFUNC(shstobf16) (blasint *, float *, blasint *, bfloat16 *, blasint *);
|
||||
void BLASFUNC(shdtobf16) (blasint *, double *, blasint *, bfloat16 *, blasint *);
|
||||
void BLASFUNC(sbf16tos) (blasint *, bfloat16 *, blasint *, float *, blasint *);
|
||||
void BLASFUNC(dbf16tod) (blasint *, bfloat16 *, blasint *, double *, blasint *);
|
||||
|
||||
#ifdef RETURN_BY_STRUCT
|
||||
typedef struct {
|
||||
|
|
|
@ -46,6 +46,12 @@ float sdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG);
|
|||
double dsdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG);
|
||||
double ddot_k(BLASLONG, double *, BLASLONG, double *, BLASLONG);
|
||||
xdouble qdot_k(BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
|
||||
float shdot_k(BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
|
||||
|
||||
void shstobf16_k(BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG);
|
||||
void shdtobf16_k(BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG);
|
||||
void sbf16tos_k (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
|
||||
void dbf16tod_k (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG);
|
||||
|
||||
openblas_complex_float cdotc_k (BLASLONG, float *, BLASLONG, float *, BLASLONG);
|
||||
openblas_complex_float cdotu_k (BLASLONG, float *, BLASLONG, float *, BLASLONG);
|
||||
|
|
|
@ -646,6 +646,11 @@
|
|||
|
||||
#elif defined(HALF)
|
||||
|
||||
#define D_TO_BF16_K SHDTOBF16_K
|
||||
#define D_BF16_TO_K DBF16TOD_K
|
||||
#define S_TO_BF16_K SHSTOBF16_K
|
||||
#define S_BF16_TO_K SBF16TOS_K
|
||||
|
||||
#define AMAX_K SAMAX_K
|
||||
#define AMIN_K SAMIN_K
|
||||
#define MAX_K SMAX_K
|
||||
|
@ -657,6 +662,7 @@
|
|||
#define ASUM_K SASUM_K
|
||||
#define DOTU_K SDOTU_K
|
||||
#define DOTC_K SDOTC_K
|
||||
#define BF16_DOT_K SHDOT_K
|
||||
#define AXPYU_K SAXPYU_K
|
||||
#define AXPYC_K SAXPYC_K
|
||||
#define AXPBY_K SAXPBY_K
|
||||
|
|
|
@ -51,6 +51,11 @@ typedef struct {
|
|||
int shgemm_p, shgemm_q, shgemm_r;
|
||||
int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn;
|
||||
|
||||
void (*shstobf16_k) (BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG);
|
||||
void (*shdtobf16_k) (BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG);
|
||||
void (*sbf16tos_k) (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
|
||||
void (*dbf16tod_k) (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG);
|
||||
|
||||
float (*shamax_k) (BLASLONG, float *, BLASLONG);
|
||||
float (*shamin_k) (BLASLONG, float *, BLASLONG);
|
||||
float (*shmax_k) (BLASLONG, float *, BLASLONG);
|
||||
|
@ -64,7 +69,7 @@ BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG);
|
|||
float (*shasum_k) (BLASLONG, float *, BLASLONG);
|
||||
float (*shsum_k) (BLASLONG, float *, BLASLONG);
|
||||
int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
|
||||
float (*shdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
|
||||
float (*shdot_k) (BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
|
||||
double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
|
||||
|
||||
int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float);
|
||||
|
|
12
common_sh.h
12
common_sh.h
|
@ -3,6 +3,12 @@
|
|||
|
||||
#ifndef DYNAMIC_ARCH
|
||||
|
||||
#define SHDOT_K shdot_k
|
||||
#define SHSTOBF16_K shstobf16_k
|
||||
#define SHDTOBF16_K shdtobf16_k
|
||||
#define SBF16TOS_K sbf16tos_k
|
||||
#define DBF16TOD_K dbf16tod_k
|
||||
|
||||
#define SHGEMM_ONCOPY shgemm_oncopy
|
||||
#define SHGEMM_OTCOPY shgemm_otcopy
|
||||
|
||||
|
@ -18,6 +24,12 @@
|
|||
|
||||
#else
|
||||
|
||||
#define SHDOT_K gotoblas -> shdot_k
|
||||
#define SHSTOBF16_K gotoblas -> shstobf16_k
|
||||
#define SHDTOBF16_K gotoblas -> shdtobf16_k
|
||||
#define SBF16TOS_K gotoblas -> sbf16tos_k
|
||||
#define DBF16TOD_K gotoblas -> dbf16tod_k
|
||||
|
||||
#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy
|
||||
#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy
|
||||
#define SHGEMM_INCOPY gotoblas -> shgemm_incopy
|
||||
|
|
|
@ -59,12 +59,19 @@ extern int blas_omp_linked;
|
|||
#define BLAS_PTHREAD 0x4000U
|
||||
#define BLAS_NODE 0x2000U
|
||||
|
||||
#define BLAS_PREC 0x0003U
|
||||
#define BLAS_SINGLE 0x0000U
|
||||
#define BLAS_DOUBLE 0x0001U
|
||||
#define BLAS_XDOUBLE 0x0002U
|
||||
#define BLAS_REAL 0x0000U
|
||||
#define BLAS_COMPLEX 0x0004U
|
||||
#define BLAS_PREC 0x000FU
|
||||
#define BLAS_INT8 0x0000U
|
||||
#define BLAS_BFLOAT16 0x0001U
|
||||
#define BLAS_SINGLE 0x0002U
|
||||
#define BLAS_DOUBLE 0x0003U
|
||||
#define BLAS_XDOUBLE 0x0004U
|
||||
#define BLAS_STOBF16 0x0008U
|
||||
#define BLAS_DTOBF16 0x0009U
|
||||
#define BLAS_BF16TOS 0x000AU
|
||||
#define BLAS_BF16TOD 0x000BU
|
||||
|
||||
#define BLAS_REAL 0x0000U
|
||||
#define BLAS_COMPLEX 0x1000U
|
||||
|
||||
#define BLAS_TRANSA 0x0030U /* 2bit */
|
||||
#define BLAS_TRANSA_N 0x0000U
|
||||
|
|
|
@ -142,6 +142,29 @@ static __inline void cpuid(int op, int *eax, int *ebx, int *ecx, int *edx){
|
|||
#endif
|
||||
}
|
||||
|
||||
static __inline void cpuid_count(int op, int count, int *eax, int *ebx, int *ecx, int *edx)
|
||||
{
|
||||
#ifdef C_MSVC
|
||||
int cpuInfo[4] = {-1};
|
||||
__cpuidex(cpuInfo, op, count);
|
||||
*eax = cpuInfo[0];
|
||||
*ebx = cpuInfo[1];
|
||||
*ecx = cpuInfo[2];
|
||||
*edx = cpuInfo[3];
|
||||
#else
|
||||
#if defined(__i386__) && defined(__PIC__)
|
||||
__asm__ __volatile__
|
||||
("mov %%ebx, %%edi;"
|
||||
"cpuid;"
|
||||
"xchgl %%ebx, %%edi;"
|
||||
: "=a" (*eax), "=D" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc");
|
||||
#else
|
||||
__asm__ __volatile__
|
||||
("cpuid": "=a" (*eax), "=b" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc");
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
#define WHEREAMI
|
||||
*/
|
||||
|
|
|
@ -49,9 +49,36 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha
|
|||
blas_arg_t args [MAX_CPU_NUMBER];
|
||||
|
||||
BLASLONG i, width, astride, bstride;
|
||||
int num_cpu, calc_type;
|
||||
int num_cpu, calc_type_a, calc_type_b;
|
||||
|
||||
calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2;
|
||||
switch (mode & BLAS_PREC) {
|
||||
case BLAS_INT8 :
|
||||
case BLAS_BFLOAT16:
|
||||
case BLAS_SINGLE :
|
||||
case BLAS_DOUBLE :
|
||||
case BLAS_XDOUBLE :
|
||||
calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0);
|
||||
break;
|
||||
case BLAS_STOBF16 :
|
||||
calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0);
|
||||
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
|
||||
break;
|
||||
case BLAS_DTOBF16 :
|
||||
calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0);
|
||||
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
|
||||
break;
|
||||
case BLAS_BF16TOS :
|
||||
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
|
||||
calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0);
|
||||
break;
|
||||
case BLAS_BF16TOD :
|
||||
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
|
||||
calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0);
|
||||
break;
|
||||
default:
|
||||
calc_type_a = calc_type_b = 0;
|
||||
break;
|
||||
}
|
||||
|
||||
mode |= BLAS_LEGACY;
|
||||
|
||||
|
@ -77,8 +104,8 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha
|
|||
bstride = width;
|
||||
}
|
||||
|
||||
astride <<= calc_type;
|
||||
bstride <<= calc_type;
|
||||
astride <<= calc_type_a;
|
||||
bstride <<= calc_type_b;
|
||||
|
||||
args[num_cpu].m = width;
|
||||
args[num_cpu].n = n;
|
||||
|
@ -120,9 +147,36 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL
|
|||
blas_arg_t args [MAX_CPU_NUMBER];
|
||||
|
||||
BLASLONG i, width, astride, bstride;
|
||||
int num_cpu, calc_type;
|
||||
int num_cpu, calc_type_a, calc_type_b;
|
||||
|
||||
calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2;
|
||||
switch (mode & BLAS_PREC) {
|
||||
case BLAS_INT8 :
|
||||
case BLAS_BFLOAT16:
|
||||
case BLAS_SINGLE :
|
||||
case BLAS_DOUBLE :
|
||||
case BLAS_XDOUBLE :
|
||||
calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0);
|
||||
break;
|
||||
case BLAS_STOBF16 :
|
||||
calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0);
|
||||
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
|
||||
break;
|
||||
case BLAS_DTOBF16 :
|
||||
calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0);
|
||||
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
|
||||
break;
|
||||
case BLAS_BF16TOS :
|
||||
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
|
||||
calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0);
|
||||
break;
|
||||
case BLAS_BF16TOD :
|
||||
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
|
||||
calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0);
|
||||
break;
|
||||
default:
|
||||
calc_type_a = calc_type_b = 0;
|
||||
break;
|
||||
}
|
||||
|
||||
mode |= BLAS_LEGACY;
|
||||
|
||||
|
@ -148,8 +202,8 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL
|
|||
bstride = width;
|
||||
}
|
||||
|
||||
astride <<= calc_type;
|
||||
bstride <<= calc_type;
|
||||
astride <<= calc_type_a;
|
||||
bstride <<= calc_type_b;
|
||||
|
||||
args[num_cpu].m = width;
|
||||
args[num_cpu].n = n;
|
||||
|
|
|
@ -192,7 +192,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
|
||||
if (!(mode & BLAS_COMPLEX)){
|
||||
#ifdef EXPRECISION
|
||||
if (mode & BLAS_XDOUBLE){
|
||||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
|
||||
/* REAL / Extended Double */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble,
|
||||
xdouble *, BLASLONG, xdouble *, BLASLONG,
|
||||
|
@ -205,7 +205,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> c, args -> ldc, sb);
|
||||
} else
|
||||
#endif
|
||||
if (mode & BLAS_DOUBLE){
|
||||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){
|
||||
/* REAL / Double */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
|
||||
double *, BLASLONG, double *, BLASLONG,
|
||||
|
@ -216,21 +216,58 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
} else {
|
||||
/* REAL / Single */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
|
||||
float *, BLASLONG, float *, BLASLONG,
|
||||
float *, BLASLONG, void *) = func;
|
||||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){
|
||||
/* REAL / Single */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
|
||||
float *, BLASLONG, float *, BLASLONG,
|
||||
float *, BLASLONG, void *) = func;
|
||||
|
||||
afunc(args -> m, args -> n, args -> k,
|
||||
((float *)args -> alpha)[0],
|
||||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
afunc(args -> m, args -> n, args -> k,
|
||||
((float *)args -> alpha)[0],
|
||||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
#ifdef BUILD_HALF
|
||||
} else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){
|
||||
/* REAL / BFLOAT16 */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16,
|
||||
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG,
|
||||
bfloat16 *, BLASLONG, void *) = func;
|
||||
|
||||
afunc(args -> m, args -> n, args -> k,
|
||||
((bfloat16 *)args -> alpha)[0],
|
||||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
} else if ((mode & BLAS_PREC) == BLAS_STOBF16){
|
||||
/* REAL / BLAS_STOBF16 */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
|
||||
float *, BLASLONG, bfloat16 *, BLASLONG,
|
||||
float *, BLASLONG, void *) = func;
|
||||
|
||||
afunc(args -> m, args -> n, args -> k,
|
||||
((float *)args -> alpha)[0],
|
||||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
} else if ((mode & BLAS_PREC) == BLAS_DTOBF16){
|
||||
/* REAL / BLAS_DTOBF16 */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
|
||||
double *, BLASLONG, bfloat16 *, BLASLONG,
|
||||
double *, BLASLONG, void *) = func;
|
||||
|
||||
afunc(args -> m, args -> n, args -> k,
|
||||
((double *)args -> alpha)[0],
|
||||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
#endif
|
||||
} else {
|
||||
/* REAL / Other types in future */
|
||||
}
|
||||
} else {
|
||||
#ifdef EXPRECISION
|
||||
if (mode & BLAS_XDOUBLE){
|
||||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
|
||||
/* COMPLEX / Extended Double */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble,
|
||||
xdouble *, BLASLONG, xdouble *, BLASLONG,
|
||||
|
@ -244,7 +281,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> c, args -> ldc, sb);
|
||||
} else
|
||||
#endif
|
||||
if (mode & BLAS_DOUBLE){
|
||||
if ((mode & BLAS_PREC) == BLAS_DOUBLE) {
|
||||
/* COMPLEX / Double */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double,
|
||||
double *, BLASLONG, double *, BLASLONG,
|
||||
|
@ -256,7 +293,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
} else {
|
||||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE) {
|
||||
/* COMPLEX / Single */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float,
|
||||
float *, BLASLONG, float *, BLASLONG,
|
||||
|
@ -268,7 +305,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
}
|
||||
} else {
|
||||
/* COMPLEX / Other types in future */
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -414,33 +453,37 @@ blas_queue_t *tscq;
|
|||
if (sb == NULL) {
|
||||
if (!(queue -> mode & BLAS_COMPLEX)){
|
||||
#ifdef EXPRECISION
|
||||
if (queue -> mode & BLAS_XDOUBLE){
|
||||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
|
||||
sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
} else
|
||||
#endif
|
||||
if (queue -> mode & BLAS_DOUBLE){
|
||||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE) {
|
||||
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
|
||||
} else {
|
||||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
|
||||
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
}
|
||||
} else {
|
||||
/* Other types in future */
|
||||
}
|
||||
} else {
|
||||
#ifdef EXPRECISION
|
||||
if (queue -> mode & BLAS_XDOUBLE){
|
||||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
|
||||
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
} else
|
||||
#endif
|
||||
if (queue -> mode & BLAS_DOUBLE){
|
||||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
|
||||
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
} else {
|
||||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
|
||||
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
}
|
||||
} else {
|
||||
/* Other types in future */
|
||||
}
|
||||
}
|
||||
queue->sb=sb;
|
||||
}
|
||||
|
|
|
@ -142,7 +142,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
|
||||
if (!(mode & BLAS_COMPLEX)){
|
||||
#ifdef EXPRECISION
|
||||
if (mode & BLAS_XDOUBLE){
|
||||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
|
||||
/* REAL / Extended Double */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble,
|
||||
xdouble *, BLASLONG, xdouble *, BLASLONG,
|
||||
|
@ -155,7 +155,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> c, args -> ldc, sb);
|
||||
} else
|
||||
#endif
|
||||
if (mode & BLAS_DOUBLE){
|
||||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){
|
||||
/* REAL / Double */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
|
||||
double *, BLASLONG, double *, BLASLONG,
|
||||
|
@ -166,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
} else {
|
||||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){
|
||||
/* REAL / Single */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
|
||||
float *, BLASLONG, float *, BLASLONG,
|
||||
|
@ -177,10 +177,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
#ifdef BUILD_HALF
|
||||
} else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){
|
||||
/* REAL / BFLOAT16 */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16,
|
||||
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG,
|
||||
bfloat16 *, BLASLONG, void *) = func;
|
||||
|
||||
afunc(args -> m, args -> n, args -> k,
|
||||
((bfloat16 *)args -> alpha)[0],
|
||||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
} else if ((mode & BLAS_PREC) == BLAS_STOBF16){
|
||||
/* REAL / BLAS_STOBF16 */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
|
||||
float *, BLASLONG, bfloat16 *, BLASLONG,
|
||||
float *, BLASLONG, void *) = func;
|
||||
|
||||
afunc(args -> m, args -> n, args -> k,
|
||||
((float *)args -> alpha)[0],
|
||||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
} else if ((mode & BLAS_PREC) == BLAS_DTOBF16){
|
||||
/* REAL / BLAS_DTOBF16 */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
|
||||
double *, BLASLONG, bfloat16 *, BLASLONG,
|
||||
double *, BLASLONG, void *) = func;
|
||||
|
||||
afunc(args -> m, args -> n, args -> k,
|
||||
((double *)args -> alpha)[0],
|
||||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
#endif
|
||||
} else {
|
||||
/* REAL / Other types in future */
|
||||
}
|
||||
} else {
|
||||
#ifdef EXPRECISION
|
||||
if (mode & BLAS_XDOUBLE){
|
||||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
|
||||
/* COMPLEX / Extended Double */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble,
|
||||
xdouble *, BLASLONG, xdouble *, BLASLONG,
|
||||
|
@ -194,7 +231,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> c, args -> ldc, sb);
|
||||
} else
|
||||
#endif
|
||||
if (mode & BLAS_DOUBLE){
|
||||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){
|
||||
/* COMPLEX / Double */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double,
|
||||
double *, BLASLONG, double *, BLASLONG,
|
||||
|
@ -206,7 +243,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
} else {
|
||||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){
|
||||
/* COMPLEX / Single */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float,
|
||||
float *, BLASLONG, float *, BLASLONG,
|
||||
|
@ -218,8 +255,10 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
/* COMPLEX / Other types in future */
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void exec_threads(blas_queue_t *queue, int buf_index){
|
||||
|
@ -255,32 +294,36 @@ static void exec_threads(blas_queue_t *queue, int buf_index){
|
|||
if (sb == NULL) {
|
||||
if (!(queue -> mode & BLAS_COMPLEX)){
|
||||
#ifdef EXPRECISION
|
||||
if (queue -> mode & BLAS_XDOUBLE){
|
||||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
|
||||
sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
} else
|
||||
#endif
|
||||
if (queue -> mode & BLAS_DOUBLE){
|
||||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
|
||||
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
|
||||
} else {
|
||||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE){
|
||||
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
} else {
|
||||
/* Other types in future */
|
||||
}
|
||||
} else {
|
||||
#ifdef EXPRECISION
|
||||
if (queue -> mode & BLAS_XDOUBLE){
|
||||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
|
||||
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
} else
|
||||
#endif
|
||||
if (queue -> mode & BLAS_DOUBLE){
|
||||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
|
||||
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
} else {
|
||||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
|
||||
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
} else {
|
||||
/* Other types in future */
|
||||
}
|
||||
}
|
||||
queue->sb=sb;
|
||||
|
|
|
@ -77,7 +77,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
|
||||
if (!(mode & BLAS_COMPLEX)){
|
||||
#ifdef EXPRECISION
|
||||
if (mode & BLAS_XDOUBLE){
|
||||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
|
||||
/* REAL / Extended Double */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble,
|
||||
xdouble *, BLASLONG, xdouble *, BLASLONG,
|
||||
|
@ -90,7 +90,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> c, args -> ldc, sb);
|
||||
} else
|
||||
#endif
|
||||
if (mode & BLAS_DOUBLE){
|
||||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){
|
||||
/* REAL / Double */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
|
||||
double *, BLASLONG, double *, BLASLONG,
|
||||
|
@ -101,7 +101,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
} else {
|
||||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){
|
||||
/* REAL / Single */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
|
||||
float *, BLASLONG, float *, BLASLONG,
|
||||
|
@ -112,10 +112,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
#ifdef BUILD_HALF
|
||||
} else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){
|
||||
/* REAL / BFLOAT16 */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16,
|
||||
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG,
|
||||
bfloat16 *, BLASLONG, void *) = func;
|
||||
|
||||
afunc(args -> m, args -> n, args -> k,
|
||||
((bfloat16 *)args -> alpha)[0],
|
||||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
} else if ((mode & BLAS_PREC) == BLAS_STOBF16){
|
||||
/* REAL / BLAS_STOBF16 */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
|
||||
float *, BLASLONG, bfloat16 *, BLASLONG,
|
||||
float *, BLASLONG, void *) = func;
|
||||
|
||||
afunc(args -> m, args -> n, args -> k,
|
||||
((float *)args -> alpha)[0],
|
||||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
} else if ((mode & BLAS_PREC) == BLAS_DTOBF16){
|
||||
/* REAL / BLAS_DTOBF16 */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
|
||||
double *, BLASLONG, bfloat16 *, BLASLONG,
|
||||
double *, BLASLONG, void *) = func;
|
||||
|
||||
afunc(args -> m, args -> n, args -> k,
|
||||
((double *)args -> alpha)[0],
|
||||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
#endif
|
||||
} else {
|
||||
/* REAL / Other types in future */
|
||||
}
|
||||
} else {
|
||||
#ifdef EXPRECISION
|
||||
if (mode & BLAS_XDOUBLE){
|
||||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
|
||||
/* COMPLEX / Extended Double */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble,
|
||||
xdouble *, BLASLONG, xdouble *, BLASLONG,
|
||||
|
@ -129,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> c, args -> ldc, sb);
|
||||
} else
|
||||
#endif
|
||||
if (mode & BLAS_DOUBLE){
|
||||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){
|
||||
/* COMPLEX / Double */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double,
|
||||
double *, BLASLONG, double *, BLASLONG,
|
||||
|
@ -141,7 +178,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
} else {
|
||||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE) {
|
||||
/* COMPLEX / Single */
|
||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float,
|
||||
float *, BLASLONG, float *, BLASLONG,
|
||||
|
@ -153,7 +190,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
|
|||
args -> a, args -> lda,
|
||||
args -> b, args -> ldb,
|
||||
args -> c, args -> ldc, sb);
|
||||
}
|
||||
} else {
|
||||
/* COMPLEX / Other types in future */
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -233,32 +272,36 @@ static DWORD WINAPI blas_thread_server(void *arg){
|
|||
if (sb == NULL) {
|
||||
if (!(queue -> mode & BLAS_COMPLEX)){
|
||||
#ifdef EXPRECISION
|
||||
if (queue -> mode & BLAS_XDOUBLE){
|
||||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
|
||||
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * sizeof(xdouble)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
} else
|
||||
#endif
|
||||
if (queue -> mode & BLAS_DOUBLE){
|
||||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
|
||||
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
|
||||
} else {
|
||||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
|
||||
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
} else {
|
||||
/* Other types in future */
|
||||
}
|
||||
} else {
|
||||
#ifdef EXPRECISION
|
||||
if (queue -> mode & BLAS_XDOUBLE){
|
||||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
|
||||
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
} else
|
||||
#endif
|
||||
if (queue -> mode & BLAS_DOUBLE){
|
||||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
|
||||
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
} else {
|
||||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
|
||||
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
|
||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
|
||||
} else {
|
||||
/* Other types in future */
|
||||
}
|
||||
}
|
||||
queue->sb=sb;
|
||||
|
|
|
@ -207,6 +207,19 @@ extern gotoblas_t gotoblas_SKYLAKEX;
|
|||
#else
|
||||
#define gotoblas_SKYLAKEX gotoblas_PRESCOTT
|
||||
#endif
|
||||
#ifdef DYN_COOPERLAKE
|
||||
extern gotoblas_t gotoblas_COOPERLAKE;
|
||||
#elif defined(DYN_SKYLAKEX)
|
||||
#define gotoblas_COOPERLAKE gotoblas_SKYLAKEX
|
||||
#elif defined(DYN_HASWELL)
|
||||
#define gotoblas_COOPERLAKE gotoblas_HASWELL
|
||||
#elif defined(DYN_SANDYBRIDGE)
|
||||
#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE
|
||||
#elif defined(DYN_NEHALEM)
|
||||
#define gotoblas_COOPERLAKE gotoblas_NEHALEM
|
||||
#else
|
||||
#define gotoblas_COOPERLAKE gotoblas_PRESCOTT
|
||||
#endif
|
||||
|
||||
|
||||
#else // not DYNAMIC_LIST
|
||||
|
@ -247,14 +260,17 @@ extern gotoblas_t gotoblas_EXCAVATOR;
|
|||
#ifdef NO_AVX2
|
||||
#define gotoblas_HASWELL gotoblas_SANDYBRIDGE
|
||||
#define gotoblas_SKYLAKEX gotoblas_SANDYBRIDGE
|
||||
#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE
|
||||
#define gotoblas_ZEN gotoblas_SANDYBRIDGE
|
||||
#else
|
||||
extern gotoblas_t gotoblas_HASWELL;
|
||||
extern gotoblas_t gotoblas_ZEN;
|
||||
#ifndef NO_AVX512
|
||||
extern gotoblas_t gotoblas_SKYLAKEX;
|
||||
extern gotoblas_t gotoblas_COOPERLAKE;
|
||||
#else
|
||||
#define gotoblas_SKYLAKEX gotoblas_HASWELL
|
||||
#define gotoblas_COOPERLAKE gotoblas_HASWELL
|
||||
#endif
|
||||
#endif
|
||||
#else
|
||||
|
@ -262,6 +278,7 @@ extern gotoblas_t gotoblas_SKYLAKEX;
|
|||
#define gotoblas_SANDYBRIDGE gotoblas_NEHALEM
|
||||
#define gotoblas_HASWELL gotoblas_NEHALEM
|
||||
#define gotoblas_SKYLAKEX gotoblas_NEHALEM
|
||||
#define gotoblas_COOPERLAKE gotoblas_NEHALEM
|
||||
#define gotoblas_BULLDOZER gotoblas_BARCELONA
|
||||
#define gotoblas_PILEDRIVER gotoblas_BARCELONA
|
||||
#define gotoblas_STEAMROLLER gotoblas_BARCELONA
|
||||
|
@ -343,6 +360,23 @@ int support_avx512(){
|
|||
#endif
|
||||
}
|
||||
|
||||
int support_avx512_bf16(){
|
||||
#if !defined(NO_AVX) && !defined(NO_AVX512)
|
||||
int eax, ebx, ecx, edx;
|
||||
int ret=0;
|
||||
|
||||
if (!support_avx512())
|
||||
return 0;
|
||||
cpuid_count(7, 1, &eax, &ebx, &ecx, &edx);
|
||||
if((eax & 32) == 32){
|
||||
ret=1; // CPUID.7.1:EAX[bit 5] indicates whether avx512_bf16 supported or not
|
||||
}
|
||||
return ret;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
extern void openblas_warning(int verbose, const char * msg);
|
||||
#define FALLBACK_VERBOSE 1
|
||||
#define NEHALEM_FALLBACK "OpenBLAS : Your OS does not support AVX instructions. OpenBLAS is using Nehalem kernels as a fallback, which may give poorer performance.\n"
|
||||
|
@ -524,7 +558,10 @@ static gotoblas_t *get_coretype(void){
|
|||
return &gotoblas_NEHALEM; //OS doesn't support AVX. Use old kernels.
|
||||
}
|
||||
}
|
||||
if (model == 5) {
|
||||
if (model == 5) {
|
||||
// Intel Cooperlake
|
||||
if(support_avx512_bf16())
|
||||
return &gotoblas_COOPERLAKE;
|
||||
// Intel Skylake X
|
||||
if (support_avx512())
|
||||
return &gotoblas_SKYLAKEX;
|
||||
|
@ -774,7 +811,8 @@ static char *corename[] = {
|
|||
"Steamroller",
|
||||
"Excavator",
|
||||
"Zen",
|
||||
"SkylakeX"
|
||||
"SkylakeX",
|
||||
"Cooperlake"
|
||||
};
|
||||
|
||||
char *gotoblas_corename(void) {
|
||||
|
@ -838,6 +876,7 @@ char *gotoblas_corename(void) {
|
|||
if (gotoblas == &gotoblas_EXCAVATOR) return corename[22];
|
||||
if (gotoblas == &gotoblas_ZEN) return corename[23];
|
||||
if (gotoblas == &gotoblas_SKYLAKEX) return corename[24];
|
||||
if (gotoblas == &gotoblas_COOPERLAKE) return corename[25];
|
||||
return corename[0];
|
||||
}
|
||||
|
||||
|
@ -868,6 +907,7 @@ static gotoblas_t *force_coretype(char *coretype){
|
|||
|
||||
switch (found)
|
||||
{
|
||||
case 25: return (&gotoblas_COOPERLAKE);
|
||||
case 24: return (&gotoblas_SKYLAKEX);
|
||||
case 23: return (&gotoblas_ZEN);
|
||||
case 22: return (&gotoblas_EXCAVATOR);
|
||||
|
|
|
@ -46,7 +46,7 @@
|
|||
ssum, dsum, scsum, dzsum
|
||||
);
|
||||
|
||||
@halfblasobjs = (shgemm);
|
||||
@halfblasobjs = (shgemm, shdot, shstobf16, shdtobf16, sbf16tos, dbf16tod);
|
||||
@cblasobjs = (
|
||||
cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv,
|
||||
cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k,
|
||||
|
@ -84,7 +84,7 @@
|
|||
cblas_xerbla
|
||||
);
|
||||
|
||||
@halfcblasobjs = (cblas_shgemm);
|
||||
@halfcblasobjs = (cblas_shgemm, cblas_shdot, cblas_shstobf16, cblas_shdtobf16, cblas_sbf16tos, cblas_dbf16tod);
|
||||
|
||||
@exblasobjs = (
|
||||
qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm,
|
||||
|
|
|
@ -47,7 +47,9 @@ SBLAS3OBJS = \
|
|||
sgeadd.$(SUFFIX)
|
||||
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
SHBLAS1OBJS = shdot.$(SUFFIX)
|
||||
SHBLAS3OBJS = shgemm.$(SUFFIX)
|
||||
SHEXTOBJS = shstobf16.$(SUFFIX) shdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX)
|
||||
endif
|
||||
|
||||
DBLAS1OBJS = \
|
||||
|
@ -281,7 +283,9 @@ CSBLAS3OBJS = \
|
|||
cblas_sgeadd.$(SUFFIX)
|
||||
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
CSHBLAS1OBJS = cblas_shdot.$(SUFFIX)
|
||||
CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX)
|
||||
CSHEXTOBJS = cblas_shstobf16.$(SUFFIX) cblas_shdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX)
|
||||
endif
|
||||
|
||||
CDBLAS1OBJS = \
|
||||
|
@ -374,6 +378,7 @@ override CFLAGS += -I.
|
|||
SBLAS1OBJS += $(CSBLAS1OBJS)
|
||||
SBLAS2OBJS += $(CSBLAS2OBJS)
|
||||
SBLAS3OBJS += $(CSBLAS3OBJS)
|
||||
SHBLAS1OBJS += $(CSHBLAS1OBJS)
|
||||
SHBLAS3OBJS += $(CSHBLAS3OBJS)
|
||||
DBLAS1OBJS += $(CDBLAS1OBJS)
|
||||
DBLAS2OBJS += $(CDBLAS2OBJS)
|
||||
|
@ -385,10 +390,11 @@ ZBLAS1OBJS += $(CZBLAS1OBJS)
|
|||
ZBLAS2OBJS += $(CZBLAS2OBJS)
|
||||
ZBLAS3OBJS += $(CZBLAS3OBJS)
|
||||
|
||||
SHEXTOBJS += $(CSHEXTOBJS)
|
||||
endif
|
||||
|
||||
SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS)
|
||||
SHBLASOBJS = $(SHBLAS3OBJS)
|
||||
SHBLASOBJS = $(SHBLAS1OBJS) $(SHBLAS3OBJS)
|
||||
DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS)
|
||||
QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS)
|
||||
CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS)
|
||||
|
@ -463,7 +469,7 @@ ZBLASOBJS += $(ZLAPACKOBJS)
|
|||
|
||||
endif
|
||||
|
||||
FUNCOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
|
||||
FUNCOBJS = $(SHEXTOBJS) $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
|
||||
|
||||
ifdef EXPRECISION
|
||||
FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS)
|
||||
|
@ -491,7 +497,7 @@ endif
|
|||
clean ::
|
||||
@rm -f functable.h
|
||||
|
||||
level1 : $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS)
|
||||
level1 : $(BEXTOBJS) $(SHBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS)
|
||||
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
|
||||
|
||||
level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
|
||||
|
@ -725,6 +731,19 @@ sdsdot.$(SUFFIX) sdsdot.$(PSUFFIX) : sdsdot.c
|
|||
dsdot.$(SUFFIX) dsdot.$(PSUFFIX) : dsdot.c
|
||||
$(CC) $(CFLAGS) -c $< -o $(@F)
|
||||
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
shdot.$(SUFFIX) shdot.$(PSUFFIX) : bf16dot.c
|
||||
$(CC) $(CFLAGS) -c $< -o $(@F)
|
||||
shstobf16.$(SUFFIX) shstobf16.$(PSUFFIX) : tobf16.c
|
||||
$(CC) $(CFLAGS) -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F)
|
||||
shdtobf16.$(SUFFIX) shdtobf16.$(PSUFFIX) : tobf16.c
|
||||
$(CC) $(CFLAGS) -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F)
|
||||
sbf16tos.$(SUFFIX) sbf16tos.$(PSUFFIX) : bf16to.c
|
||||
$(CC) $(CFLAGS) -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F)
|
||||
dbf16tod.$(SUFFIX) dbf16tod.$(PSUFFIX) : bf16to.c
|
||||
$(CC) $(CFLAGS) -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F)
|
||||
endif
|
||||
|
||||
sdot.$(SUFFIX) sdot.$(PSUFFIX) : dot.c
|
||||
$(CC) $(CFLAGS) -c $< -o $(@F)
|
||||
|
||||
|
@ -1463,6 +1482,19 @@ cblas_sdsdot.$(SUFFIX) cblas_sdsdot.$(PSUFFIX) : sdsdot.c
|
|||
cblas_dsdot.$(SUFFIX) cblas_dsdot.$(PSUFFIX) : dsdot.c
|
||||
$(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F)
|
||||
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
cblas_shdot.$(SUFFIX) cblas_shdot.$(PSUFFIX) : bf16dot.c
|
||||
$(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F)
|
||||
cblas_shstobf16.$(SUFFIX) cblas_shstobf16.$(PSUFFIX) : tobf16.c
|
||||
$(CC) $(CFLAGS) -DCBLAS -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F)
|
||||
cblas_shdtobf16.$(SUFFIX) cblas_shdtobf16.$(PSUFFIX) : tobf16.c
|
||||
$(CC) $(CFLAGS) -DCBLAS -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F)
|
||||
cblas_sbf16tos.$(SUFFIX) cblas_sbf16tos.$(PSUFFIX) : bf16to.c
|
||||
$(CC) $(CFLAGS) -DCBLAS -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F)
|
||||
cblas_dbf16tod.$(SUFFIX) cblas_dbf16tod.$(PSUFFIX) : bf16to.c
|
||||
$(CC) $(CFLAGS) -DCBLAS -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F)
|
||||
endif
|
||||
|
||||
cblas_sdot.$(SUFFIX) cblas_sdot.$(PSUFFIX) : dot.c
|
||||
$(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F)
|
||||
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
#include <stdio.h>
|
||||
#include "common.h"
|
||||
#ifdef FUNCTION_PROFILE
|
||||
#include "functable.h"
|
||||
#endif
|
||||
|
||||
#ifndef CBLAS
|
||||
float NAME(blasint *N, bfloat16 *x, blasint *INCX, bfloat16 *y, blasint *INCY){
|
||||
BLASLONG n = *N;
|
||||
BLASLONG incx = *INCX;
|
||||
BLASLONG incy = *INCY;
|
||||
float ret;
|
||||
PRINT_DEBUG_NAME;
|
||||
|
||||
if (n <= 0) return 0.;
|
||||
|
||||
IDEBUG_START;
|
||||
FUNCTION_PROFILE_START();
|
||||
|
||||
if (incx < 0) x -= (n - 1) * incx;
|
||||
if (incy < 0) y -= (n - 1) * incy;
|
||||
ret = BF16_DOT_K(n, x, incx, y, incy);
|
||||
|
||||
FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
|
||||
IDEBUG_END;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
float CNAME(blasint n, bfloat16 *x, blasint incx, bfloat16 *y, blasint incy){
|
||||
|
||||
float ret;
|
||||
PRINT_DEBUG_CNAME;
|
||||
|
||||
if (n <= 0) return 0.;
|
||||
|
||||
IDEBUG_START;
|
||||
FUNCTION_PROFILE_START();
|
||||
|
||||
if (incx < 0) x -= (n - 1) * incx;
|
||||
if (incy < 0) y -= (n - 1) * incy;
|
||||
ret = BF16_DOT_K(n, x, incx, y, incy);
|
||||
|
||||
FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
|
||||
IDEBUG_END;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,62 @@
|
|||
#include <stdio.h>
|
||||
#include "common.h"
|
||||
#ifdef FUNCTION_PROFILE
|
||||
#include "functable.h"
|
||||
#endif
|
||||
|
||||
#if defined(DOUBLE_PREC)
|
||||
#define FLOAT_TYPE double
|
||||
#elif defined(SINGLE_PREC)
|
||||
#define FLOAT_TYPE float
|
||||
#else
|
||||
#endif
|
||||
|
||||
#ifndef CBLAS
|
||||
void NAME(blasint *N, bfloat16 *in, blasint *INC_IN, FLOAT_TYPE *out, blasint *INC_OUT){
|
||||
BLASLONG n = *N;
|
||||
BLASLONG inc_in = *INC_IN;
|
||||
BLASLONG inc_out = *INC_OUT;
|
||||
|
||||
PRINT_DEBUG_NAME;
|
||||
|
||||
if (n <= 0) return;
|
||||
|
||||
IDEBUG_START;
|
||||
FUNCTION_PROFILE_START();
|
||||
|
||||
if (inc_in < 0) in -= (n - 1) * inc_in;
|
||||
if (inc_out < 0) out -= (n - 1) * inc_out;
|
||||
|
||||
#if defined(DOUBLE_PREC)
|
||||
D_BF16_TO_K(n, in, inc_in, out, inc_out);
|
||||
#elif defined(SINGLE_PREC)
|
||||
S_BF16_TO_K(n, in, inc_in, out, inc_out);
|
||||
#else
|
||||
#endif
|
||||
|
||||
FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
|
||||
IDEBUG_END;
|
||||
}
|
||||
#else
|
||||
void CNAME(blasint n, bfloat16 * in, blasint inc_in, FLOAT_TYPE * out, blasint inc_out){
|
||||
PRINT_DEBUG_CNAME;
|
||||
|
||||
if (n <= 0) return;
|
||||
|
||||
IDEBUG_START;
|
||||
FUNCTION_PROFILE_START();
|
||||
|
||||
if (inc_in < 0) in -= (n - 1) * inc_in;
|
||||
if (inc_out < 0) out -= (n - 1) * inc_out;
|
||||
|
||||
#if defined(DOUBLE_PREC)
|
||||
D_BF16_TO_K(n, in, inc_in, out, inc_out);
|
||||
#elif defined(SINGLE_PREC)
|
||||
S_BF16_TO_K(n, in, inc_in, out, inc_out);
|
||||
#else
|
||||
#endif
|
||||
|
||||
FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
|
||||
IDEBUG_END;
|
||||
}
|
||||
#endif
|
|
@ -0,0 +1,61 @@
|
|||
#include <stdio.h>
|
||||
#include "common.h"
|
||||
#ifdef FUNCTION_PROFILE
|
||||
#include "functable.h"
|
||||
#endif
|
||||
|
||||
#if defined(DOUBLE_PREC)
|
||||
#define FLOAT_TYPE double
|
||||
#elif defined(SINGLE_PREC)
|
||||
#define FLOAT_TYPE float
|
||||
#else
|
||||
#endif
|
||||
|
||||
#ifndef CBLAS
|
||||
void NAME(blasint *N, FLOAT_TYPE *in, blasint *INC_IN, bfloat16 *out, blasint *INC_OUT){
|
||||
BLASLONG n = *N;
|
||||
BLASLONG inc_in = *INC_IN;
|
||||
BLASLONG inc_out = *INC_OUT;
|
||||
|
||||
PRINT_DEBUG_NAME;
|
||||
|
||||
if (n <= 0) return;
|
||||
|
||||
IDEBUG_START;
|
||||
FUNCTION_PROFILE_START();
|
||||
|
||||
if (inc_in < 0) in -= (n - 1) * inc_in;
|
||||
if (inc_out < 0) out -= (n - 1) * inc_out;
|
||||
|
||||
#if defined(DOUBLE_PREC)
|
||||
D_TO_BF16_K(n, in, inc_in, out, inc_out);
|
||||
#elif defined(SINGLE_PREC)
|
||||
S_TO_BF16_K(n, in, inc_in, out, inc_out);
|
||||
#else
|
||||
#endif
|
||||
|
||||
FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
|
||||
IDEBUG_END;
|
||||
}
|
||||
#else
|
||||
void CNAME(blasint n, FLOAT_TYPE *in, blasint inc_in, bfloat16 *out, blasint inc_out){
|
||||
PRINT_DEBUG_CNAME;
|
||||
|
||||
if (n <= 0) return;
|
||||
|
||||
IDEBUG_START;
|
||||
FUNCTION_PROFILE_START();
|
||||
|
||||
if (inc_in < 0) in -= (n - 1) * inc_in;
|
||||
if (inc_out < 0) out -= (n - 1) * inc_out;
|
||||
|
||||
#if defined(DOUBLE_PREC)
|
||||
D_TO_BF16_K(n, in, inc_in, out, inc_out);
|
||||
#elif defined(SINGLE_PREC)
|
||||
S_TO_BF16_K(n, in, inc_in, out, inc_out);
|
||||
#endif
|
||||
|
||||
FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
|
||||
IDEBUG_END;
|
||||
}
|
||||
#endif
|
|
@ -262,6 +262,20 @@ ifndef XDOTKERNEL
|
|||
XDOTKERNEL = zdot.S
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
ifndef SHDOTKERNEL
|
||||
SHDOTKERNEL = ../x86_64/shdot.c
|
||||
endif
|
||||
|
||||
ifndef TOBF16KERNEL
|
||||
TOBF16KERNEL = ../x86_64/tobf16.c
|
||||
endif
|
||||
|
||||
ifndef BF16TOKERNEL
|
||||
BF16TOKERNEL = ../x86_64/bf16to.c
|
||||
endif
|
||||
endif
|
||||
|
||||
### NRM2 ###
|
||||
|
||||
ifndef SNRM2KERNEL
|
||||
|
@ -516,6 +530,15 @@ XBLASOBJS += \
|
|||
xdotc_k$(TSUFFIX).$(SUFFIX) xdotu_k$(TSUFFIX).$(SUFFIX) xnrm2_k$(TSUFFIX).$(SUFFIX) xqrot_k$(TSUFFIX).$(SUFFIX) \
|
||||
xscal_k$(TSUFFIX).$(SUFFIX) xswap_k$(TSUFFIX).$(SUFFIX) xsum_k$(TSUFFIX).$(SUFFIX)
|
||||
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
SHBLASOBJS += \
|
||||
shdot_k$(TSUFFIX).$(SUFFIX)
|
||||
SHEXTOBJS += \
|
||||
shstobf16_k$(TSUFFIX).$(SUFFIX) shdtobf16_k$(TSUFFIX).$(SUFFIX)
|
||||
SHEXTOBJS += \
|
||||
sbf16tos_k$(TSUFFIX).$(SUFFIX) dbf16tod_k$(TSUFFIX).$(SUFFIX)
|
||||
endif
|
||||
|
||||
### AMAX ###
|
||||
|
||||
|
||||
|
@ -734,6 +757,19 @@ $(KDIR)ddot_k$(TSUFFIX).$(SUFFIX) $(KDIR)ddot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNEL
|
|||
$(KDIR)qdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)qdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(QDOTKERNEL)
|
||||
$(CC) -c $(CFLAGS) -UCOMPLEX -DXDOUBLE $< -o $@
|
||||
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
$(KDIR)shdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)shdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHDOTKERNEL)
|
||||
$(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@
|
||||
$(KDIR)shstobf16_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL)
|
||||
$(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@
|
||||
$(KDIR)shdtobf16_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL)
|
||||
$(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@
|
||||
$(KDIR)sbf16tos_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL)
|
||||
$(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@
|
||||
$(KDIR)dbf16tod_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL)
|
||||
$(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@
|
||||
endif
|
||||
|
||||
$(KDIR)sdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)sdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SDOTKERNEL)
|
||||
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE $< -o $@
|
||||
|
||||
|
|
|
@ -62,9 +62,11 @@ gotoblas_t TABLE_NAME = {
|
|||
MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
|
||||
#endif
|
||||
|
||||
shstobf16_kTS, shdtobf16_kTS, sbf16tos_kTS, dbf16tod_kTS,
|
||||
|
||||
samax_kTS, samin_kTS, smax_kTS, smin_kTS,
|
||||
isamax_kTS, isamin_kTS, ismax_kTS, ismin_kTS,
|
||||
snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, sdot_kTS,
|
||||
snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, shdot_kTS,
|
||||
dsdot_kTS,
|
||||
srot_kTS, saxpy_kTS, sscal_kTS, sswap_kTS,
|
||||
sgemv_nTS, sgemv_tTS, sger_kTS,
|
||||
|
|
|
@ -146,6 +146,18 @@ ifndef XDOTKERNEL
|
|||
XDOTKERNEL = zdot.S
|
||||
endif
|
||||
|
||||
ifndef SHDOTKERNEL
|
||||
SHDOTKERNEL = shdot.c
|
||||
endif
|
||||
|
||||
ifndef TOBF16KERNEL
|
||||
TOBF16KERNEL = tobf16.c
|
||||
endif
|
||||
|
||||
ifndef BF16TOKERNEL
|
||||
BF16TOKERNEL = bf16to.c
|
||||
endif
|
||||
|
||||
ifndef ISAMAXKERNEL
|
||||
ISAMAXKERNEL = iamax_sse.S
|
||||
endif
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
/***************************************************************************
|
||||
Copyright (c) 2014, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include <stddef.h>
|
||||
#include "common.h"
|
||||
|
||||
#if defined(DOUBLE)
|
||||
#define FLOAT_TYPE double
|
||||
#elif defined(SINGLE)
|
||||
#define FLOAT_TYPE float
|
||||
#else
|
||||
#endif
|
||||
|
||||
/* Notes for algorithm:
|
||||
* - Input denormal treated as zero
|
||||
* - Force to be QNAN
|
||||
*/
|
||||
static void bf16to_kernel_1(BLASLONG n, const bfloat16 * in, BLASLONG inc_in, FLOAT_TYPE * out, BLASLONG inc_out)
|
||||
{
|
||||
BLASLONG register index_in = 0;
|
||||
BLASLONG register index_out = 0;
|
||||
BLASLONG register index = 0;
|
||||
uint16_t * tmp = NULL;
|
||||
#if defined(DOUBLE)
|
||||
float float_out = 0.0;
|
||||
#endif
|
||||
|
||||
while(index<n) {
|
||||
#if defined(DOUBLE)
|
||||
float_out = 0.0;
|
||||
tmp = (uint16_t*)(&float_out);
|
||||
#else
|
||||
*(out+index_out) = 0;
|
||||
tmp = (uint16_t*)(out+index_out);
|
||||
#endif
|
||||
|
||||
switch((*(in+index_in)) & 0xff80u) {
|
||||
case (0x0000u): /* Type 1: Positive denormal */
|
||||
tmp[1] = 0x0000u;
|
||||
tmp[0] = 0x0000u;
|
||||
break;
|
||||
case (0x8000u): /* Type 2: Negative denormal */
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
tmp[1] = 0x8000u;
|
||||
tmp[0] = 0x0000u;
|
||||
#else
|
||||
tmp[1] = 0x0000u;
|
||||
tmp[0] = 0x8000u;
|
||||
#endif
|
||||
break;
|
||||
case (0x7f80u): /* Type 3: Positive infinity or NAN */
|
||||
case (0xff80u): /* Type 4: Negative infinity or NAN */
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
tmp[1] = *(in+index_in);
|
||||
#else
|
||||
tmp[0] = *(in+index_in);
|
||||
#endif
|
||||
/* Specific for NAN */
|
||||
if (((*(in+index_in)) & 0x007fu) != 0) {
|
||||
/* Force to be QNAN */
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
tmp[1] |= 0x0040u;
|
||||
#else
|
||||
tmp[0] |= 0x0040u;
|
||||
#endif
|
||||
}
|
||||
break;
|
||||
default: /* Type 5: Normal case */
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
tmp[1] = *(in+index_in);
|
||||
#else
|
||||
tmp[0] = *(in+index_in);
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
#if defined(DOUBLE)
|
||||
*(out+index_out) = (double)float_out;
|
||||
#endif
|
||||
index_in += inc_in;
|
||||
index_out += inc_out;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
|
||||
void CNAME(BLASLONG n, bfloat16 * in, BLASLONG inc_in, FLOAT_TYPE * out, BLASLONG inc_out)
|
||||
{
|
||||
if (n <= 0) return;
|
||||
|
||||
bf16to_kernel_1(n, in, inc_in, out, inc_out);
|
||||
}
|
|
@ -0,0 +1,104 @@
|
|||
/***************************************************************************
|
||||
Copyright (c) 2014, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
/* need a new enough GCC for avx512 support */
|
||||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
|
||||
|
||||
#define HAVE_TOBF16_ACCL_KERNEL 1
|
||||
#include "common.h"
|
||||
#include <immintrin.h>
|
||||
|
||||
static void tobf16_accl_kernel(BLASLONG n, const double * in, bfloat16 * out)
|
||||
{
|
||||
/* Get the 64-bytes unaligned header number targeting for avx512
|
||||
* processing (Assume input float array is natural aligned) */
|
||||
int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 3) & 0x7;
|
||||
|
||||
if (n < align_header) {align_header = n;}
|
||||
|
||||
if (align_header != 0) {
|
||||
unsigned char align_mask8 = (((unsigned char)0xff) >> (8-align_header));
|
||||
__m512d a = _mm512_maskz_loadu_pd(*((__mmask8*) &align_mask8), &in[0]);
|
||||
_mm_mask_storeu_epi16(&out[0], *((__mmask8*) &align_mask8), (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(a)));
|
||||
}
|
||||
|
||||
if (n == align_header) {
|
||||
return;
|
||||
} else {
|
||||
n -= align_header;
|
||||
in += align_header;
|
||||
out += align_header;
|
||||
}
|
||||
|
||||
int tail_index_8 = n&(~7);
|
||||
int tail_index_32 = n&(~31);
|
||||
int tail_index_128 = n&(~127);
|
||||
unsigned char tail_mask8 = (((unsigned char) 0xff) >> (8 -(n&7)));
|
||||
|
||||
/* Processing the main chunk with 128-elements per round */
|
||||
for (int i = 0; i < tail_index_128; i += 128) {
|
||||
// Fold 1
|
||||
__m512 data1_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+ 0]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+ 8])), 1);
|
||||
__m512 data1_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+16]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+24])), 1);
|
||||
_mm512_storeu_si512(&out[i+ 0], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low));
|
||||
|
||||
// Fold 2
|
||||
__m512 data2_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+32]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+40])), 1);
|
||||
__m512 data2_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+48]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+56])), 1);
|
||||
_mm512_storeu_si512(&out[i+32], (__m512i) _mm512_cvtne2ps_pbh(data2_512_high, data2_512_low));
|
||||
|
||||
// Fold 3
|
||||
__m512 data3_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+64]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+72])), 1);
|
||||
__m512 data3_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+80]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+88])), 1);
|
||||
_mm512_storeu_si512(&out[i+64], (__m512i) _mm512_cvtne2ps_pbh(data3_512_high, data3_512_low));
|
||||
|
||||
// Fold 4
|
||||
__m512 data4_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+96]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+104])), 1);
|
||||
__m512 data4_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+112]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+120])), 1);
|
||||
_mm512_storeu_si512(&out[i+96], (__m512i) _mm512_cvtne2ps_pbh(data4_512_high, data4_512_low));
|
||||
}
|
||||
|
||||
/* Processing the remaining <128 chunk with 32-elements per round */
|
||||
for (int j = tail_index_128; j < tail_index_32; j += 32) {
|
||||
__m512 data1_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[j+ 0]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[j+ 8])), 1);
|
||||
__m512 data1_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[j+16]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[j+24])), 1);
|
||||
_mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low));
|
||||
}
|
||||
|
||||
/* Processing the remaining <32 chunk with 8-elements per round */
|
||||
for (int j = tail_index_32; j < tail_index_8; j += 8) {
|
||||
_mm_storeu_si128((__m128i *)&out[j], (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(_mm512_load_pd(&in[j]))));
|
||||
}
|
||||
|
||||
/* Processing the remaining <8 chunk with masked processing */
|
||||
if ((n&7) > 0) {
|
||||
__m512d data_512 = _mm512_maskz_load_pd(*((__mmask8*) &tail_mask8), &in[tail_index_8]);
|
||||
_mm_mask_storeu_epi16(&out[tail_index_8], *((__mmask8*) &tail_mask8), (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(data_512)));
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,115 @@
|
|||
/***************************************************************************
|
||||
Copyright (c) 2014, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#if defined(COOPERLAKE)
|
||||
#include "shdot_microk_cooperlake.c"
|
||||
#endif
|
||||
|
||||
static float shdot_compute(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y)
|
||||
{
|
||||
float d = 0.0;
|
||||
|
||||
#ifdef HAVE_SHDOT_ACCL_KERNEL
|
||||
if ((inc_x == 1) && (inc_y == 1)) {
|
||||
return shdot_accl_kernel(n, x, y);
|
||||
}
|
||||
#endif
|
||||
|
||||
float * x_fp32 = malloc(sizeof(float)*n);
|
||||
float * y_fp32 = malloc(sizeof(float)*n);
|
||||
|
||||
SBF16TOS_K(n, x, inc_x, x_fp32, 1);
|
||||
SBF16TOS_K(n, y, inc_y, y_fp32, 1);
|
||||
|
||||
d = SDOTU_K(n, x_fp32, 1, y_fp32, 1);
|
||||
|
||||
free(x_fp32);
|
||||
free(y_fp32);
|
||||
|
||||
return d;
|
||||
}
|
||||
|
||||
#if defined(SMP)
|
||||
static int shdot_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, bfloat16 dummy2,
|
||||
bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y,
|
||||
float *result, BLASLONG dummy3)
|
||||
{
|
||||
*(float *)result = shdot_compute(n, x, inc_x, y, inc_y);
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha,
|
||||
void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc,
|
||||
int (*function)(), int nthreads);
|
||||
#endif
|
||||
|
||||
float CNAME(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y)
|
||||
{
|
||||
float dot_result = 0.0;
|
||||
|
||||
if (n <= 0) return 0.0;
|
||||
|
||||
#if defined(SMP)
|
||||
int nthreads;
|
||||
int thread_thres = 40960;
|
||||
bfloat16 dummy_alpha;
|
||||
#endif
|
||||
|
||||
#if defined(SMP)
|
||||
if (inc_x == 0 || inc_y == 0 || n <= thread_thres)
|
||||
nthreads = 1;
|
||||
else
|
||||
nthreads = num_cpu_avail(1);
|
||||
|
||||
int best_threads = (int) (n/(float)thread_thres + 0.5);
|
||||
|
||||
if (best_threads < nthreads) {
|
||||
nthreads = best_threads;
|
||||
}
|
||||
|
||||
if (nthreads <= 1) {
|
||||
dot_result = shdot_compute(n, x, inc_x, y, inc_y);
|
||||
} else {
|
||||
char thread_result[MAX_CPU_NUMBER * sizeof(double) * 2];
|
||||
int mode = BLAS_BFLOAT16 | BLAS_REAL;
|
||||
blas_level1_thread_with_return_value(mode, n, 0, 0, &dummy_alpha,
|
||||
x, inc_x, y, inc_y, thread_result, 0,
|
||||
(void *)shdot_thread_func, nthreads);
|
||||
float * ptr = (float *)thread_result;
|
||||
for (int i = 0; i < nthreads; i++) {
|
||||
dot_result += (*ptr);
|
||||
ptr = (float *)(((char *)ptr) + sizeof(double) * 2);
|
||||
}
|
||||
}
|
||||
#else
|
||||
dot_result = shdot_compute(n, x, inc_x, y, inc_y);
|
||||
#endif
|
||||
|
||||
return dot_result;
|
||||
}
|
|
@ -0,0 +1,159 @@
|
|||
/***************************************************************************
|
||||
Copyright (c) 2014, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
/* need a new enough GCC for avx512 support */
|
||||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
|
||||
|
||||
#define HAVE_SHDOT_ACCL_KERNEL 1
|
||||
#include "common.h"
|
||||
#include <immintrin.h>
|
||||
|
||||
static float shdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
|
||||
{
|
||||
__m128 accum128 = _mm_setzero_ps();
|
||||
if (n> 127) { /* n range from 128 to inf. */
|
||||
long tail_index_32 = n&(~31);
|
||||
long tail_index_128 = n&(~127);
|
||||
unsigned int tail_mask_uint = (((unsigned int)0xffffffff) >> (32-(n&31)));
|
||||
__mmask32 tail_mask = *((__mmask32*) &tail_mask_uint);
|
||||
|
||||
__m512 accum512_0 = _mm512_setzero_ps();
|
||||
__m512 accum512_1 = _mm512_setzero_ps();
|
||||
__m512 accum512_2 = _mm512_setzero_ps();
|
||||
__m512 accum512_3 = _mm512_setzero_ps();
|
||||
|
||||
/* Processing the main chunk with 128-elements per round */
|
||||
for (long i = 0; i < tail_index_128; i += 128) {
|
||||
accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[i+ 0]), (__m512bh) _mm512_loadu_si512(&y[i+ 0]));
|
||||
accum512_1 = _mm512_dpbf16_ps(accum512_1, (__m512bh) _mm512_loadu_si512(&x[i+32]), (__m512bh) _mm512_loadu_si512(&y[i+32]));
|
||||
accum512_2 = _mm512_dpbf16_ps(accum512_2, (__m512bh) _mm512_loadu_si512(&x[i+64]), (__m512bh) _mm512_loadu_si512(&y[i+64]));
|
||||
accum512_3 = _mm512_dpbf16_ps(accum512_3, (__m512bh) _mm512_loadu_si512(&x[i+96]), (__m512bh) _mm512_loadu_si512(&y[i+96]));
|
||||
}
|
||||
|
||||
/* Processing the remaining <128 chunk with 32-elements per round */
|
||||
for (long j = tail_index_128; j < tail_index_32; j += 32) {
|
||||
accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[j]), (__m512bh) _mm512_loadu_si512(&y[j]));
|
||||
}
|
||||
|
||||
/* Processing the remaining <32 chunk with masked 32-elements processing */
|
||||
if ((n&31) != 0) {
|
||||
accum512_2 = _mm512_dpbf16_ps(accum512_2,
|
||||
(__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &x[tail_index_32]),
|
||||
(__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &y[tail_index_32]));
|
||||
}
|
||||
|
||||
/* Accumulate the 4 registers into 1 register */
|
||||
accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
|
||||
accum512_2 = _mm512_add_ps(accum512_2, accum512_3);
|
||||
accum512_0 = _mm512_add_ps(accum512_0, accum512_2);
|
||||
|
||||
__m256 accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
|
||||
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
|
||||
} else if (n > 31) { /* n range from 32 to 127 */
|
||||
/* Processing <128 chunk with 32-elements per round */
|
||||
__m256 accum256 = _mm256_setzero_ps();
|
||||
__m256 accum256_1 = _mm256_setzero_ps();
|
||||
int tail_index_32 = n&(~31);
|
||||
for (int j = 0; j < tail_index_32; j += 32) {
|
||||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[j+ 0]), (__m256bh) _mm256_loadu_si256(&y[j+ 0]));
|
||||
accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256(&x[j+16]), (__m256bh) _mm256_loadu_si256(&y[j+16]));
|
||||
}
|
||||
accum256 = _mm256_add_ps(accum256, accum256_1);
|
||||
|
||||
/* Processing the remaining <32 chunk with 16-elements processing */
|
||||
if ((n&16) != 0) {
|
||||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[tail_index_32]), (__m256bh) _mm256_loadu_si256(&y[tail_index_32]));
|
||||
}
|
||||
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
|
||||
|
||||
/* Processing the remaining <16 chunk with 8-elements processing */
|
||||
if ((n&8) != 0) {
|
||||
int tail_index_16 = n&(~15);
|
||||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16]));
|
||||
}
|
||||
|
||||
/* Processing the remaining <8 chunk with masked 8-elements processing */
|
||||
if ((n&7) != 0) {
|
||||
unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
|
||||
__mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
|
||||
int tail_index_8 = n&(~7);
|
||||
accum128 = _mm_dpbf16_ps(accum128,
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]),
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8]));
|
||||
}
|
||||
} else if (n > 15) { /* n range from 16 to 31 */
|
||||
/* Processing <32 chunk with 16-elements processing */
|
||||
__m256 accum256 = _mm256_setzero_ps();
|
||||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[0]), (__m256bh) _mm256_loadu_si256(&y[0]));
|
||||
accum128 += _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
|
||||
|
||||
/* Processing the remaining <16 chunk with 8-elements processing */
|
||||
if ((n&8) != 0) {
|
||||
int tail_index_16 = n&(~15);
|
||||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16]));
|
||||
}
|
||||
|
||||
/* Processing the remaining <8 chunk with masked 8-elements processing */
|
||||
if ((n&7) != 0) {
|
||||
unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
|
||||
__mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
|
||||
int tail_index_8 = n&(~7);
|
||||
accum128 = _mm_dpbf16_ps(accum128,
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]),
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8]));
|
||||
}
|
||||
} else if (n > 7) { /* n range from 8 to 15 */
|
||||
/* Processing <16 chunk with 8-elements processing */
|
||||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[0]), (__m128bh) _mm_loadu_si128(&y[0]));
|
||||
|
||||
/* Processing the remaining <8 chunk with masked 8-elements processing */
|
||||
if ((n&7) != 0) {
|
||||
unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
|
||||
__mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
|
||||
int tail_index_8 = n&(~7);
|
||||
accum128 = _mm_dpbf16_ps(accum128,
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]),
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8]));
|
||||
}
|
||||
} else { /* n range from 1 to 7 */
|
||||
unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
|
||||
__mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
|
||||
accum128 = _mm_dpbf16_ps(accum128,
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[0]),
|
||||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[0]));
|
||||
}
|
||||
|
||||
/* Add up the 4 elements into lowest entry */
|
||||
__m128 accum128_1 = _mm_shuffle_ps(accum128, accum128, 14);
|
||||
accum128 = _mm_add_ps(accum128, accum128_1);
|
||||
accum128_1 = _mm_shuffle_ps(accum128, accum128, 1);
|
||||
accum128 = _mm_add_ps(accum128, accum128_1);
|
||||
|
||||
return accum128[0];
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,86 @@
|
|||
/***************************************************************************
|
||||
Copyright (c) 2014, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
/* need a new enough GCC for avx512 support */
|
||||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
|
||||
|
||||
#define HAVE_TOBF16_ACCL_KERNEL 1
|
||||
#include "common.h"
|
||||
#include <immintrin.h>
|
||||
|
||||
static void tobf16_accl_kernel(BLASLONG n, const float * in, bfloat16 * out)
|
||||
{
|
||||
/* Get the 64-bytes unaligned header number targeting for avx512
|
||||
* processing (Assume input float array is natural aligned) */
|
||||
int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 2) & 0xf;
|
||||
|
||||
if (n < align_header) {align_header = n;}
|
||||
|
||||
if (align_header != 0) {
|
||||
uint16_t align_mask16 = (((uint16_t)0xffff) >> (16-align_header));
|
||||
__m512 a = _mm512_maskz_loadu_ps(*((__mmask16*) &align_mask16), &in[0]);
|
||||
_mm256_mask_storeu_epi16(&out[0], *((__mmask16*) &align_mask16), (__m256i) _mm512_cvtneps_pbh(a));
|
||||
}
|
||||
|
||||
if (n == align_header) {
|
||||
return;
|
||||
} else {
|
||||
n -= align_header;
|
||||
in += align_header;
|
||||
out += align_header;
|
||||
}
|
||||
|
||||
int tail_index_32 = n&(~31);
|
||||
int tail_index_128 = n&(~127);
|
||||
uint32_t tail_mask32 = (((uint32_t) 0xffffffff) >> (32-(n&31)));
|
||||
uint16_t tail_mask16 = (((uint16_t) 0xffff) >> (16-(n&15)));
|
||||
|
||||
/* Processing the main chunk with 128-elements per round */
|
||||
for (int i = 0; i < tail_index_128; i += 128) {
|
||||
_mm512_storeu_si512(&out[i+ 0], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 16]), _mm512_load_ps(&in[i+ 0])));
|
||||
_mm512_storeu_si512(&out[i+32], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 48]), _mm512_load_ps(&in[i+32])));
|
||||
_mm512_storeu_si512(&out[i+64], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 80]), _mm512_load_ps(&in[i+64])));
|
||||
_mm512_storeu_si512(&out[i+96], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+112]), _mm512_load_ps(&in[i+96])));
|
||||
}
|
||||
|
||||
/* Processing the remaining <128 chunk with 32-elements per round */
|
||||
for (int j = tail_index_128; j < tail_index_32; j += 32) {
|
||||
_mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[j+ 16]), _mm512_load_ps(&in[j])));
|
||||
}
|
||||
|
||||
/* Processing the remaining <32 chunk with masked processing */
|
||||
if ((n&31) > 15) {
|
||||
__m512 b = _mm512_load_ps(&in[tail_index_32]);
|
||||
__m512 a = _mm512_maskz_load_ps(*((__mmask16*) &tail_mask16), &in[tail_index_32+16]);
|
||||
_mm512_mask_storeu_epi16(&out[tail_index_32], *((__mmask32*) &tail_mask32), (__m512i) _mm512_cvtne2ps_pbh(a, b));
|
||||
} else if ((n&31) > 0) {
|
||||
__m512 a = _mm512_maskz_load_ps(*((__mmask16*) &tail_mask16), &in[tail_index_32]);
|
||||
_mm256_mask_storeu_epi16(&out[tail_index_32], *((__mmask16*) &tail_mask16), (__m256i) _mm512_cvtneps_pbh(a));
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,170 @@
|
|||
/***************************************************************************
|
||||
Copyright (c) 2014, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include <stddef.h>
|
||||
#include "common.h"
|
||||
|
||||
#if defined(DOUBLE)
|
||||
#define FLOAT_TYPE double
|
||||
#elif defined(SINGLE)
|
||||
#define FLOAT_TYPE float
|
||||
#else
|
||||
#endif
|
||||
|
||||
#if defined(COOPERLAKE)
|
||||
#if defined(DOUBLE)
|
||||
#include "dtobf16_microk_cooperlake.c"
|
||||
#elif defined(SINGLE)
|
||||
#include "stobf16_microk_cooperlake.c"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/* Notes for algorithm:
|
||||
* - Round to Nearest Even used generally
|
||||
* - QNAN for NAN case
|
||||
* - Input denormals are treated as zero
|
||||
*/
|
||||
static void tobf16_generic_kernel(BLASLONG n, const FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
|
||||
{
|
||||
BLASLONG register index_in = 0;
|
||||
BLASLONG register index_out = 0;
|
||||
BLASLONG register index = 0;
|
||||
float float_in = 0.0;
|
||||
uint32_t * uint32_in = (uint32_t *)(&float_in);
|
||||
uint16_t * uint16_in = (uint16_t *)(&float_in);
|
||||
|
||||
while(index<n) {
|
||||
#if defined(DOUBLE)
|
||||
float_in = (float)(*(in+index_in));
|
||||
#else
|
||||
float_in = *(in+index_in);
|
||||
#endif
|
||||
|
||||
switch((*uint32_in) & 0xff800000u) {
|
||||
case (0x00000000u): /* Type 1: Positive denormal */
|
||||
*(out+index_out) = 0x0000u;
|
||||
break;
|
||||
case (0x80000000u): /* Type 2: Negative denormal */
|
||||
*(out+index_out) = 0x8000u;
|
||||
break;
|
||||
case (0x7f800000u): /* Type 3: Positive infinity or NAN */
|
||||
case (0xff800000u): /* Type 4: Negative infinity or NAN */
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
*(out+index_out) = uint16_in[1];
|
||||
#else
|
||||
*(out+index_out) = uint16_in[0];
|
||||
#endif
|
||||
/* Specific for NAN */
|
||||
if (((*uint32_in) & 0x007fffffu) != 0) {
|
||||
/* Force to be QNAN */
|
||||
*(out+index_out) |= 0x0040u;
|
||||
}
|
||||
break;
|
||||
default: /* Type 5: Normal case */
|
||||
(*uint32_in) += ((((*uint32_in) >> 16) & 0x1u) + 0x7fffu);
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
*(out+index_out) = uint16_in[1];
|
||||
#else
|
||||
*(out+index_out) = uint16_in[0];
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
|
||||
index_in += inc_in;
|
||||
index_out += inc_out;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef HAVE_TOBF16_ACCL_KERNEL
|
||||
static void tobf16_accl_kernel(BLASLONG n, const FLOAT_TYPE * in, bfloat16 * out)
|
||||
{
|
||||
tobf16_generic_kernel(n, in, 1, out, 1);
|
||||
}
|
||||
#endif
|
||||
|
||||
static void tobf16_compute(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
|
||||
{
|
||||
if ((inc_in == 1) && (inc_out == 1)) {
|
||||
tobf16_accl_kernel(n, in, out);
|
||||
} else {
|
||||
tobf16_generic_kernel(n, in, inc_in, out, inc_out);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(SMP)
|
||||
static int tobf16_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT_TYPE dummy2,
|
||||
FLOAT_TYPE *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y,
|
||||
FLOAT_TYPE *dummy3, BLASLONG dummy4)
|
||||
{
|
||||
tobf16_compute(n, x, inc_x, y, inc_y);
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha,
|
||||
void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc,
|
||||
int (*function)(), int nthreads);
|
||||
#endif
|
||||
|
||||
void CNAME(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
|
||||
{
|
||||
if (n <= 0) return;
|
||||
|
||||
#if defined(SMP)
|
||||
int nthreads;
|
||||
FLOAT_TYPE dummy_alpha;
|
||||
FLOAT_TYPE dummy_c;
|
||||
#endif
|
||||
|
||||
#if defined(SMP)
|
||||
if (inc_in == 0 || inc_out == 0 || n <= 100000) {
|
||||
nthreads = 1;
|
||||
} else {
|
||||
if (n/100000 < 100) {
|
||||
nthreads = 4;
|
||||
} else {
|
||||
nthreads = 16;
|
||||
}
|
||||
}
|
||||
|
||||
if (nthreads == 1) {
|
||||
tobf16_compute(n, in, inc_in, out, inc_out);
|
||||
} else {
|
||||
#if defined(DOUBLE)
|
||||
int mode = BLAS_REAL | BLAS_DTOBF16;
|
||||
#elif defined(SINGLE)
|
||||
int mode = BLAS_REAL | BLAS_STOBF16;
|
||||
#endif
|
||||
blas_level1_thread(mode, n, 0, 0, &dummy_alpha,
|
||||
in, inc_in, out, inc_out, &dummy_c, 0,
|
||||
(void *)tobf16_thread_func, nthreads);
|
||||
}
|
||||
#else
|
||||
tobf16_compute(n, in, inc_in, out, inc_out);
|
||||
#endif
|
||||
|
||||
}
|
|
@ -35,7 +35,8 @@ typedef unsigned long BLASULONG;
|
|||
#endif
|
||||
|
||||
#ifndef BFLOAT16
|
||||
typedef unsigned short bfloat16;
|
||||
#include <stdint.h>
|
||||
typedef uint16_t bfloat16;
|
||||
#endif
|
||||
|
||||
#ifdef OPENBLAS_USE64BITINT
|
||||
|
|
Loading…
Reference in New Issue