From deaeb6c5b89f64bbe9d5ba0126690ae5d57ae0ce Mon Sep 17 00:00:00 2001 From: "Chen, Guobing" Date: Thu, 27 Aug 2020 06:42:28 +0800 Subject: [PATCH 01/13] Add bfloat16 based dot and conversion with single/double 1. Added bfloat16 based dot as new API: shdot 2. Implemented generic kernel and cooperlake-specific (AVX512-BF16) kernel for shdot 3. Added 4 conversion APIs for bfloat16 data type <=> single/double: shstobf16 shdtobf16 sbf16tos dbf16tod shstobf16 -- convert single float array to bfloat16 array shdtobf16 -- convert double float array to bfloat16 array sbf16tos -- convert bfloat16 array to single float array dbf16tod -- convert bfloat16 array to double float array 4. Implemented generic kernels for all 4 conversion APIs, and cooperlake-specific kernel for shstobf16 and shdtobf16 5. Update level1 thread facilitate functions and macros to support multi-threading for these new APIs 6. Fix Cooperlake platform detection/specify issue when under dynamic-arch building 7. Change the typedef of bfloat16 from unsigned short to more strict uint16_t Signed-off-by: Chen, Guobing --- Makefile.tail | 7 +- cblas.h | 11 ++ cmake/kernel.cmake | 4 +- common.h | 3 +- common_interface.h | 5 + common_level1.h | 6 + common_macro.h | 6 + common_param.h | 7 +- common_sh.h | 12 ++ common_thread.h | 19 ++- common_x86_64.h | 23 +++ driver/others/blas_l1_thread.c | 70 ++++++++- driver/others/blas_server.c | 91 +++++++++--- driver/others/blas_server_omp.c | 71 +++++++-- driver/others/blas_server_win32.c | 69 +++++++-- driver/others/dynamic.c | 44 +++++- exports/gensymbol | 4 +- interface/Makefile | 38 ++++- interface/bf16dot.c | 52 +++++++ interface/bf16to.c | 62 ++++++++ interface/tobf16.c | 61 ++++++++ kernel/Makefile.L1 | 36 +++++ kernel/setparam-ref.c | 4 +- kernel/x86_64/KERNEL | 12 ++ kernel/x86_64/bf16to.c | 114 +++++++++++++++ kernel/x86_64/dtobf16_microk_cooperlake.c | 104 +++++++++++++ kernel/x86_64/shdot.c | 115 +++++++++++++++ kernel/x86_64/shdot_microk_cooperlake.c | 159 ++++++++++++++++++++ kernel/x86_64/stobf16_microk_cooperlake.c | 86 +++++++++++ kernel/x86_64/tobf16.c | 170 ++++++++++++++++++++++ openblas_config_template.h | 3 +- 31 files changed, 1389 insertions(+), 79 deletions(-) create mode 100644 interface/bf16dot.c create mode 100644 interface/bf16to.c create mode 100644 interface/tobf16.c create mode 100644 kernel/x86_64/bf16to.c create mode 100644 kernel/x86_64/dtobf16_microk_cooperlake.c create mode 100644 kernel/x86_64/shdot.c create mode 100644 kernel/x86_64/shdot_microk_cooperlake.c create mode 100644 kernel/x86_64/stobf16_microk_cooperlake.c create mode 100644 kernel/x86_64/tobf16.c diff --git a/Makefile.tail b/Makefile.tail index 39902982b..cfc4a36fc 100644 --- a/Makefile.tail +++ b/Makefile.tail @@ -5,13 +5,14 @@ QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) CBLASOBJS_P = $(CBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) ZBLASOBJS_P = $(ZBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) XBLASOBJS_P = $(XBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) +SHEXTOBJS_P = $(SHEXTOBJS:.$(SUFFIX)=.$(PSUFFIX)) COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) -BLASOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) -BLASOBJS_P = $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) +BLASOBJS = $(SHEXTOBJS) $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) +BLASOBJS_P = $(SHEXTOBJS_P) $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) ifdef EXPRECISION BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) @@ -30,6 +31,7 @@ $(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX $(CBLASOBJS) $(CBLASOBJS_P) : override CFLAGS += -UDOUBLE -DCOMPLEX $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX $(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX +$(SHEXTOBJS) $(SHEXTOBJS_P) : override CFLAGS += -DHALF -UDOUBLE -UCOMPLEX $(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) @@ -38,6 +40,7 @@ $(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(CBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) +$(SHEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) libs :: $(BLASOBJS) $(COMMONOBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ diff --git a/cblas.h b/cblas.h index 4bc5588d8..21f3958f2 100644 --- a/cblas.h +++ b/cblas.h @@ -382,6 +382,17 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta, double *c, OPENBLAS_CONST blasint cldc); +/*** BFLOAT16 and INT8 extensions ***/ +/* convert float array to BFLOAT16 array by rounding */ +void cblas_shstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout); +/* convert double array to BFLOAT16 array by rounding */ +void cblas_shdtobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST double *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout); +/* convert BFLOAT16 array to float array */ +void cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, float *out, OPENBLAS_CONST blasint incout); +/* convert BFLOAT16 array to double array */ +void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout); +/* dot production of BFLOAT16 input arrays, and output as float */ +float cblas_shdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy); #ifdef __cplusplus } diff --git a/cmake/kernel.cmake b/cmake/kernel.cmake index 4b505a102..79eeaae6f 100644 --- a/cmake/kernel.cmake +++ b/cmake/kernel.cmake @@ -126,12 +126,14 @@ if (BUILD_HALF) set(SHAXPYKERNEL ../arm/axpy.c) set(SHAXPBYKERNEL ../arm/axpby.c) set(SHCOPYKERNEL ../arm/copy.c) - set(SHDOTKERNEL ../arm/dot.c) + set(SHDOTKERNEL ../x86_64/shdot.c) set(SHROTKERNEL ../arm/rot.c) set(SHSCALKERNEL ../arm/scal.c) set(SHNRM2KERNEL ../arm/nrm2.c) set(SHSUMKERNEL ../arm/sum.c) set(SHSWAPKERNEL ../arm/swap.c) + set(TOBF16KERNEL ../x86_64/tobf16.c) + set(BF16TOKERNEL ../x86_64/bf16to.c) endif () endmacro () diff --git a/common.h b/common.h index d6637abe4..adc162536 100644 --- a/common.h +++ b/common.h @@ -258,7 +258,8 @@ typedef unsigned long BLASULONG; #endif #ifndef BFLOAT16 -typedef unsigned short bfloat16; +#include +typedef uint16_t bfloat16; #define HALFCONVERSION 1 #endif diff --git a/common_interface.h b/common_interface.h index 78f5be6b0..35a957aa1 100644 --- a/common_interface.h +++ b/common_interface.h @@ -54,6 +54,11 @@ double BLASFUNC(dsdot) (blasint *, float *, blasint *, float *, blasint *); double BLASFUNC(ddot) (blasint *, double *, blasint *, double *, blasint *); xdouble BLASFUNC(qdot) (blasint *, xdouble *, blasint *, xdouble *, blasint *); +float BLASFUNC(shdot) (blasint *, bfloat16 *, blasint *, bfloat16 *, blasint *); +void BLASFUNC(shstobf16) (blasint *, float *, blasint *, bfloat16 *, blasint *); +void BLASFUNC(shdtobf16) (blasint *, double *, blasint *, bfloat16 *, blasint *); +void BLASFUNC(sbf16tos) (blasint *, bfloat16 *, blasint *, float *, blasint *); +void BLASFUNC(dbf16tod) (blasint *, bfloat16 *, blasint *, double *, blasint *); #ifdef RETURN_BY_STRUCT typedef struct { diff --git a/common_level1.h b/common_level1.h index 74cafb6db..88aa275a5 100644 --- a/common_level1.h +++ b/common_level1.h @@ -46,6 +46,12 @@ float sdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG); double dsdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG); double ddot_k(BLASLONG, double *, BLASLONG, double *, BLASLONG); xdouble qdot_k(BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG); +float shdot_k(BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); + +void shstobf16_k(BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG); +void shdtobf16_k(BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG); +void sbf16tos_k (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); +void dbf16tod_k (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG); openblas_complex_float cdotc_k (BLASLONG, float *, BLASLONG, float *, BLASLONG); openblas_complex_float cdotu_k (BLASLONG, float *, BLASLONG, float *, BLASLONG); diff --git a/common_macro.h b/common_macro.h index 8fe1f156f..3d6bcd9e8 100644 --- a/common_macro.h +++ b/common_macro.h @@ -646,6 +646,11 @@ #elif defined(HALF) +#define D_TO_BF16_K SHDTOBF16_K +#define D_BF16_TO_K DBF16TOD_K +#define S_TO_BF16_K SHSTOBF16_K +#define S_BF16_TO_K SBF16TOS_K + #define AMAX_K SAMAX_K #define AMIN_K SAMIN_K #define MAX_K SMAX_K @@ -657,6 +662,7 @@ #define ASUM_K SASUM_K #define DOTU_K SDOTU_K #define DOTC_K SDOTC_K +#define BF16_DOT_K SHDOT_K #define AXPYU_K SAXPYU_K #define AXPYC_K SAXPYC_K #define AXPBY_K SAXPBY_K diff --git a/common_param.h b/common_param.h index 0437482dc..a52de98ab 100644 --- a/common_param.h +++ b/common_param.h @@ -51,6 +51,11 @@ typedef struct { int shgemm_p, shgemm_q, shgemm_r; int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; + void (*shstobf16_k) (BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG); + void (*shdtobf16_k) (BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG); + void (*sbf16tos_k) (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); + void (*dbf16tod_k) (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG); + float (*shamax_k) (BLASLONG, float *, BLASLONG); float (*shamin_k) (BLASLONG, float *, BLASLONG); float (*shmax_k) (BLASLONG, float *, BLASLONG); @@ -64,7 +69,7 @@ BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG); float (*shasum_k) (BLASLONG, float *, BLASLONG); float (*shsum_k) (BLASLONG, float *, BLASLONG); int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); - float (*shdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); + float (*shdot_k) (BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float); diff --git a/common_sh.h b/common_sh.h index 7a0045762..5dc99b3bd 100644 --- a/common_sh.h +++ b/common_sh.h @@ -3,6 +3,12 @@ #ifndef DYNAMIC_ARCH +#define SHDOT_K shdot_k +#define SHSTOBF16_K shstobf16_k +#define SHDTOBF16_K shdtobf16_k +#define SBF16TOS_K sbf16tos_k +#define DBF16TOD_K dbf16tod_k + #define SHGEMM_ONCOPY shgemm_oncopy #define SHGEMM_OTCOPY shgemm_otcopy @@ -18,6 +24,12 @@ #else +#define SHDOT_K gotoblas -> shdot_k +#define SHSTOBF16_K gotoblas -> shstobf16_k +#define SHDTOBF16_K gotoblas -> shdtobf16_k +#define SBF16TOS_K gotoblas -> sbf16tos_k +#define DBF16TOD_K gotoblas -> dbf16tod_k + #define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy #define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy #define SHGEMM_INCOPY gotoblas -> shgemm_incopy diff --git a/common_thread.h b/common_thread.h index ec0c65b22..a18df0d78 100644 --- a/common_thread.h +++ b/common_thread.h @@ -59,12 +59,19 @@ extern int blas_omp_linked; #define BLAS_PTHREAD 0x4000U #define BLAS_NODE 0x2000U -#define BLAS_PREC 0x0003U -#define BLAS_SINGLE 0x0000U -#define BLAS_DOUBLE 0x0001U -#define BLAS_XDOUBLE 0x0002U -#define BLAS_REAL 0x0000U -#define BLAS_COMPLEX 0x0004U +#define BLAS_PREC 0x000FU +#define BLAS_INT8 0x0000U +#define BLAS_BFLOAT16 0x0001U +#define BLAS_SINGLE 0x0002U +#define BLAS_DOUBLE 0x0003U +#define BLAS_XDOUBLE 0x0004U +#define BLAS_STOBF16 0x0008U +#define BLAS_DTOBF16 0x0009U +#define BLAS_BF16TOS 0x000AU +#define BLAS_BF16TOD 0x000BU + +#define BLAS_REAL 0x0000U +#define BLAS_COMPLEX 0x1000U #define BLAS_TRANSA 0x0030U /* 2bit */ #define BLAS_TRANSA_N 0x0000U diff --git a/common_x86_64.h b/common_x86_64.h index bee7e8cdb..b813336c6 100644 --- a/common_x86_64.h +++ b/common_x86_64.h @@ -142,6 +142,29 @@ static __inline void cpuid(int op, int *eax, int *ebx, int *ecx, int *edx){ #endif } +static __inline void cpuid_count(int op, int count, int *eax, int *ebx, int *ecx, int *edx) +{ +#ifdef C_MSVC + int cpuInfo[4] = {-1}; + __cpuidex(cpuInfo, op, count); + *eax = cpuInfo[0]; + *ebx = cpuInfo[1]; + *ecx = cpuInfo[2]; + *edx = cpuInfo[3]; +#else +#if defined(__i386__) && defined(__PIC__) + __asm__ __volatile__ + ("mov %%ebx, %%edi;" + "cpuid;" + "xchgl %%ebx, %%edi;" + : "=a" (*eax), "=D" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc"); +#else + __asm__ __volatile__ + ("cpuid": "=a" (*eax), "=b" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc"); +#endif +#endif +} + /* #define WHEREAMI */ diff --git a/driver/others/blas_l1_thread.c b/driver/others/blas_l1_thread.c index e405c7465..04acbcc5f 100644 --- a/driver/others/blas_l1_thread.c +++ b/driver/others/blas_l1_thread.c @@ -49,9 +49,36 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha blas_arg_t args [MAX_CPU_NUMBER]; BLASLONG i, width, astride, bstride; - int num_cpu, calc_type; + int num_cpu, calc_type_a, calc_type_b; - calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2; + switch (mode & BLAS_PREC) { + case BLAS_INT8 : + case BLAS_BFLOAT16: + case BLAS_SINGLE : + case BLAS_DOUBLE : + case BLAS_XDOUBLE : + calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0); + break; + case BLAS_STOBF16 : + calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0); + calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0); + break; + case BLAS_DTOBF16 : + calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0); + calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0); + break; + case BLAS_BF16TOS : + calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0); + calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0); + break; + case BLAS_BF16TOD : + calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0); + calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0); + break; + default: + calc_type_a = calc_type_b = 0; + break; + } mode |= BLAS_LEGACY; @@ -77,8 +104,8 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha bstride = width; } - astride <<= calc_type; - bstride <<= calc_type; + astride <<= calc_type_a; + bstride <<= calc_type_b; args[num_cpu].m = width; args[num_cpu].n = n; @@ -120,9 +147,36 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL blas_arg_t args [MAX_CPU_NUMBER]; BLASLONG i, width, astride, bstride; - int num_cpu, calc_type; + int num_cpu, calc_type_a, calc_type_b; - calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2; + switch (mode & BLAS_PREC) { + case BLAS_INT8 : + case BLAS_BFLOAT16: + case BLAS_SINGLE : + case BLAS_DOUBLE : + case BLAS_XDOUBLE : + calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0); + break; + case BLAS_STOBF16 : + calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0); + calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0); + break; + case BLAS_DTOBF16 : + calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0); + calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0); + break; + case BLAS_BF16TOS : + calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0); + calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0); + break; + case BLAS_BF16TOD : + calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0); + calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0); + break; + default: + calc_type_a = calc_type_b = 0; + break; + } mode |= BLAS_LEGACY; @@ -148,8 +202,8 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL bstride = width; } - astride <<= calc_type; - bstride <<= calc_type; + astride <<= calc_type_a; + bstride <<= calc_type_b; args[num_cpu].m = width; args[num_cpu].n = n; diff --git a/driver/others/blas_server.c b/driver/others/blas_server.c index 756e51b5d..8d3dda3bf 100644 --- a/driver/others/blas_server.c +++ b/driver/others/blas_server.c @@ -192,7 +192,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ if (!(mode & BLAS_COMPLEX)){ #ifdef EXPRECISION - if (mode & BLAS_XDOUBLE){ + if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ /* REAL / Extended Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, @@ -205,7 +205,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> c, args -> ldc, sb); } else #endif - if (mode & BLAS_DOUBLE){ + if ((mode & BLAS_PREC) == BLAS_DOUBLE){ /* REAL / Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, @@ -216,21 +216,58 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); - } else { - /* REAL / Single */ - void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, - float *, BLASLONG, float *, BLASLONG, - float *, BLASLONG, void *) = func; + } else if ((mode & BLAS_PREC) == BLAS_SINGLE){ + /* REAL / Single */ + void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, + float *, BLASLONG, float *, BLASLONG, + float *, BLASLONG, void *) = func; - afunc(args -> m, args -> n, args -> k, - ((float *)args -> alpha)[0], - args -> a, args -> lda, - args -> b, args -> ldb, - args -> c, args -> ldc, sb); + afunc(args -> m, args -> n, args -> k, + ((float *)args -> alpha)[0], + args -> a, args -> lda, + args -> b, args -> ldb, + args -> c, args -> ldc, sb); +#ifdef BUILD_HALF + } else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){ + /* REAL / BFLOAT16 */ + void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16, + bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, + bfloat16 *, BLASLONG, void *) = func; + + afunc(args -> m, args -> n, args -> k, + ((bfloat16 *)args -> alpha)[0], + args -> a, args -> lda, + args -> b, args -> ldb, + args -> c, args -> ldc, sb); + } else if ((mode & BLAS_PREC) == BLAS_STOBF16){ + /* REAL / BLAS_STOBF16 */ + void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, + float *, BLASLONG, bfloat16 *, BLASLONG, + float *, BLASLONG, void *) = func; + + afunc(args -> m, args -> n, args -> k, + ((float *)args -> alpha)[0], + args -> a, args -> lda, + args -> b, args -> ldb, + args -> c, args -> ldc, sb); + } else if ((mode & BLAS_PREC) == BLAS_DTOBF16){ + /* REAL / BLAS_DTOBF16 */ + void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, + double *, BLASLONG, bfloat16 *, BLASLONG, + double *, BLASLONG, void *) = func; + + afunc(args -> m, args -> n, args -> k, + ((double *)args -> alpha)[0], + args -> a, args -> lda, + args -> b, args -> ldb, + args -> c, args -> ldc, sb); +#endif + } else { + /* REAL / Other types in future */ } } else { #ifdef EXPRECISION - if (mode & BLAS_XDOUBLE){ + if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ /* COMPLEX / Extended Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, @@ -244,7 +281,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> c, args -> ldc, sb); } else #endif - if (mode & BLAS_DOUBLE){ + if ((mode & BLAS_PREC) == BLAS_DOUBLE) { /* COMPLEX / Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double, double *, BLASLONG, double *, BLASLONG, @@ -256,7 +293,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); - } else { + } else if ((mode & BLAS_PREC) == BLAS_SINGLE) { /* COMPLEX / Single */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float, float *, BLASLONG, float *, BLASLONG, @@ -268,7 +305,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); - } + } else { + /* COMPLEX / Other types in future */ + } } } @@ -414,33 +453,37 @@ blas_queue_t *tscq; if (sb == NULL) { if (!(queue -> mode & BLAS_COMPLEX)){ #ifdef EXPRECISION - if (queue -> mode & BLAS_XDOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); } else #endif - if (queue -> mode & BLAS_DOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE) { sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); - } else { + } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); - } + } else { + /* Other types in future */ + } } else { #ifdef EXPRECISION - if (queue -> mode & BLAS_XDOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); } else #endif - if (queue -> mode & BLAS_DOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); - } else { + } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); - } + } else { + /* Other types in future */ + } } queue->sb=sb; } diff --git a/driver/others/blas_server_omp.c b/driver/others/blas_server_omp.c index d9969b599..d126955e4 100644 --- a/driver/others/blas_server_omp.c +++ b/driver/others/blas_server_omp.c @@ -142,7 +142,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ if (!(mode & BLAS_COMPLEX)){ #ifdef EXPRECISION - if (mode & BLAS_XDOUBLE){ + if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ /* REAL / Extended Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, @@ -155,7 +155,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> c, args -> ldc, sb); } else #endif - if (mode & BLAS_DOUBLE){ + if ((mode & BLAS_PREC) == BLAS_DOUBLE){ /* REAL / Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, @@ -166,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); - } else { + } else if ((mode & BLAS_PREC) == BLAS_SINGLE){ /* REAL / Single */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, @@ -177,10 +177,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); +#ifdef BUILD_HALF + } else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){ + /* REAL / BFLOAT16 */ + void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16, + bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, + bfloat16 *, BLASLONG, void *) = func; + + afunc(args -> m, args -> n, args -> k, + ((bfloat16 *)args -> alpha)[0], + args -> a, args -> lda, + args -> b, args -> ldb, + args -> c, args -> ldc, sb); + } else if ((mode & BLAS_PREC) == BLAS_STOBF16){ + /* REAL / BLAS_STOBF16 */ + void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, + float *, BLASLONG, bfloat16 *, BLASLONG, + float *, BLASLONG, void *) = func; + + afunc(args -> m, args -> n, args -> k, + ((float *)args -> alpha)[0], + args -> a, args -> lda, + args -> b, args -> ldb, + args -> c, args -> ldc, sb); + } else if ((mode & BLAS_PREC) == BLAS_DTOBF16){ + /* REAL / BLAS_DTOBF16 */ + void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, + double *, BLASLONG, bfloat16 *, BLASLONG, + double *, BLASLONG, void *) = func; + + afunc(args -> m, args -> n, args -> k, + ((double *)args -> alpha)[0], + args -> a, args -> lda, + args -> b, args -> ldb, + args -> c, args -> ldc, sb); +#endif + } else { + /* REAL / Other types in future */ } } else { #ifdef EXPRECISION - if (mode & BLAS_XDOUBLE){ + if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ /* COMPLEX / Extended Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, @@ -194,7 +231,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> c, args -> ldc, sb); } else #endif - if (mode & BLAS_DOUBLE){ + if ((mode & BLAS_PREC) == BLAS_DOUBLE){ /* COMPLEX / Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double, double *, BLASLONG, double *, BLASLONG, @@ -206,7 +243,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); - } else { + } else if ((mode & BLAS_PREC) == BLAS_SINGLE){ /* COMPLEX / Single */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float, float *, BLASLONG, float *, BLASLONG, @@ -218,8 +255,10 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); - } - } + } else { + /* COMPLEX / Other types in future */ + } + } } static void exec_threads(blas_queue_t *queue, int buf_index){ @@ -255,32 +294,36 @@ static void exec_threads(blas_queue_t *queue, int buf_index){ if (sb == NULL) { if (!(queue -> mode & BLAS_COMPLEX)){ #ifdef EXPRECISION - if (queue -> mode & BLAS_XDOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); } else #endif - if (queue -> mode & BLAS_DOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); - } else { + } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE){ sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); + } else { + /* Other types in future */ } } else { #ifdef EXPRECISION - if (queue -> mode & BLAS_XDOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); } else #endif - if (queue -> mode & BLAS_DOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); - } else { + } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); + } else { + /* Other types in future */ } } queue->sb=sb; diff --git a/driver/others/blas_server_win32.c b/driver/others/blas_server_win32.c index 5ecc4428b..d2cc91757 100644 --- a/driver/others/blas_server_win32.c +++ b/driver/others/blas_server_win32.c @@ -77,7 +77,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ if (!(mode & BLAS_COMPLEX)){ #ifdef EXPRECISION - if (mode & BLAS_XDOUBLE){ + if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ /* REAL / Extended Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, @@ -90,7 +90,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> c, args -> ldc, sb); } else #endif - if (mode & BLAS_DOUBLE){ + if ((mode & BLAS_PREC) == BLAS_DOUBLE){ /* REAL / Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, @@ -101,7 +101,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); - } else { + } else if ((mode & BLAS_PREC) == BLAS_SINGLE){ /* REAL / Single */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, @@ -112,10 +112,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); +#ifdef BUILD_HALF + } else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){ + /* REAL / BFLOAT16 */ + void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16, + bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, + bfloat16 *, BLASLONG, void *) = func; + + afunc(args -> m, args -> n, args -> k, + ((bfloat16 *)args -> alpha)[0], + args -> a, args -> lda, + args -> b, args -> ldb, + args -> c, args -> ldc, sb); + } else if ((mode & BLAS_PREC) == BLAS_STOBF16){ + /* REAL / BLAS_STOBF16 */ + void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, + float *, BLASLONG, bfloat16 *, BLASLONG, + float *, BLASLONG, void *) = func; + + afunc(args -> m, args -> n, args -> k, + ((float *)args -> alpha)[0], + args -> a, args -> lda, + args -> b, args -> ldb, + args -> c, args -> ldc, sb); + } else if ((mode & BLAS_PREC) == BLAS_DTOBF16){ + /* REAL / BLAS_DTOBF16 */ + void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, + double *, BLASLONG, bfloat16 *, BLASLONG, + double *, BLASLONG, void *) = func; + + afunc(args -> m, args -> n, args -> k, + ((double *)args -> alpha)[0], + args -> a, args -> lda, + args -> b, args -> ldb, + args -> c, args -> ldc, sb); +#endif + } else { + /* REAL / Other types in future */ } } else { #ifdef EXPRECISION - if (mode & BLAS_XDOUBLE){ + if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ /* COMPLEX / Extended Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, @@ -129,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> c, args -> ldc, sb); } else #endif - if (mode & BLAS_DOUBLE){ + if ((mode & BLAS_PREC) == BLAS_DOUBLE){ /* COMPLEX / Double */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double, double *, BLASLONG, double *, BLASLONG, @@ -141,7 +178,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); - } else { + } else if ((mode & BLAS_PREC) == BLAS_SINGLE) { /* COMPLEX / Single */ void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float, float *, BLASLONG, float *, BLASLONG, @@ -153,7 +190,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ args -> a, args -> lda, args -> b, args -> ldb, args -> c, args -> ldc, sb); - } + } else { + /* COMPLEX / Other types in future */ + } } } @@ -233,32 +272,36 @@ static DWORD WINAPI blas_thread_server(void *arg){ if (sb == NULL) { if (!(queue -> mode & BLAS_COMPLEX)){ #ifdef EXPRECISION - if (queue -> mode & BLAS_XDOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * sizeof(xdouble) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); } else #endif - if (queue -> mode & BLAS_DOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); - } else { + } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); + } else { + /* Other types in future */ } } else { #ifdef EXPRECISION - if (queue -> mode & BLAS_XDOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); } else #endif - if (queue -> mode & BLAS_DOUBLE){ + if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); - } else { + } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float) + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); + } else { + /* Other types in future */ } } queue->sb=sb; diff --git a/driver/others/dynamic.c b/driver/others/dynamic.c index 5d71b1b2c..21d2c7948 100644 --- a/driver/others/dynamic.c +++ b/driver/others/dynamic.c @@ -207,6 +207,19 @@ extern gotoblas_t gotoblas_SKYLAKEX; #else #define gotoblas_SKYLAKEX gotoblas_PRESCOTT #endif +#ifdef DYN_COOPERLAKE +extern gotoblas_t gotoblas_COOPERLAKE; +#elif defined(DYN_SKYLAKEX) +#define gotoblas_COOPERLAKE gotoblas_SKYLAKEX +#elif defined(DYN_HASWELL) +#define gotoblas_COOPERLAKE gotoblas_HASWELL +#elif defined(DYN_SANDYBRIDGE) +#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE +#elif defined(DYN_NEHALEM) +#define gotoblas_COOPERLAKE gotoblas_NEHALEM +#else +#define gotoblas_COOPERLAKE gotoblas_PRESCOTT +#endif #else // not DYNAMIC_LIST @@ -247,14 +260,17 @@ extern gotoblas_t gotoblas_EXCAVATOR; #ifdef NO_AVX2 #define gotoblas_HASWELL gotoblas_SANDYBRIDGE #define gotoblas_SKYLAKEX gotoblas_SANDYBRIDGE +#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE #define gotoblas_ZEN gotoblas_SANDYBRIDGE #else extern gotoblas_t gotoblas_HASWELL; extern gotoblas_t gotoblas_ZEN; #ifndef NO_AVX512 extern gotoblas_t gotoblas_SKYLAKEX; +extern gotoblas_t gotoblas_COOPERLAKE; #else #define gotoblas_SKYLAKEX gotoblas_HASWELL +#define gotoblas_COOPERLAKE gotoblas_HASWELL #endif #endif #else @@ -262,6 +278,7 @@ extern gotoblas_t gotoblas_SKYLAKEX; #define gotoblas_SANDYBRIDGE gotoblas_NEHALEM #define gotoblas_HASWELL gotoblas_NEHALEM #define gotoblas_SKYLAKEX gotoblas_NEHALEM +#define gotoblas_COOPERLAKE gotoblas_NEHALEM #define gotoblas_BULLDOZER gotoblas_BARCELONA #define gotoblas_PILEDRIVER gotoblas_BARCELONA #define gotoblas_STEAMROLLER gotoblas_BARCELONA @@ -343,6 +360,23 @@ int support_avx512(){ #endif } +int support_avx512_bf16(){ +#if !defined(NO_AVX) && !defined(NO_AVX512) + int eax, ebx, ecx, edx; + int ret=0; + + if (!support_avx512()) + return 0; + cpuid_count(7, 1, &eax, &ebx, &ecx, &edx); + if((eax & 32) == 32){ + ret=1; // CPUID.7.1:EAX[bit 5] indicates whether avx512_bf16 supported or not + } + return ret; +#else + return 0; +#endif +} + extern void openblas_warning(int verbose, const char * msg); #define FALLBACK_VERBOSE 1 #define NEHALEM_FALLBACK "OpenBLAS : Your OS does not support AVX instructions. OpenBLAS is using Nehalem kernels as a fallback, which may give poorer performance.\n" @@ -524,7 +558,10 @@ static gotoblas_t *get_coretype(void){ return &gotoblas_NEHALEM; //OS doesn't support AVX. Use old kernels. } } - if (model == 5) { + if (model == 5) { + // Intel Cooperlake + if(support_avx512_bf16()) + return &gotoblas_COOPERLAKE; // Intel Skylake X if (support_avx512()) return &gotoblas_SKYLAKEX; @@ -774,7 +811,8 @@ static char *corename[] = { "Steamroller", "Excavator", "Zen", - "SkylakeX" + "SkylakeX", + "Cooperlake" }; char *gotoblas_corename(void) { @@ -838,6 +876,7 @@ char *gotoblas_corename(void) { if (gotoblas == &gotoblas_EXCAVATOR) return corename[22]; if (gotoblas == &gotoblas_ZEN) return corename[23]; if (gotoblas == &gotoblas_SKYLAKEX) return corename[24]; + if (gotoblas == &gotoblas_COOPERLAKE) return corename[25]; return corename[0]; } @@ -868,6 +907,7 @@ static gotoblas_t *force_coretype(char *coretype){ switch (found) { + case 25: return (&gotoblas_COOPERLAKE); case 24: return (&gotoblas_SKYLAKEX); case 23: return (&gotoblas_ZEN); case 22: return (&gotoblas_EXCAVATOR); diff --git a/exports/gensymbol b/exports/gensymbol index 73b4be248..ce4d9bb64 100644 --- a/exports/gensymbol +++ b/exports/gensymbol @@ -46,7 +46,7 @@ ssum, dsum, scsum, dzsum ); -@halfblasobjs = (shgemm); +@halfblasobjs = (shgemm, shdot, shstobf16, shdtobf16, sbf16tos, dbf16tod); @cblasobjs = ( cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv, cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k, @@ -84,7 +84,7 @@ cblas_xerbla ); -@halfcblasobjs = (cblas_shgemm); +@halfcblasobjs = (cblas_shgemm, cblas_shdot, cblas_shstobf16, cblas_shdtobf16, cblas_sbf16tos, cblas_dbf16tod); @exblasobjs = ( qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm, diff --git a/interface/Makefile b/interface/Makefile index 2dbd60073..fde6227bc 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -47,7 +47,9 @@ SBLAS3OBJS = \ sgeadd.$(SUFFIX) ifeq ($(BUILD_HALF),1) +SHBLAS1OBJS = shdot.$(SUFFIX) SHBLAS3OBJS = shgemm.$(SUFFIX) +SHEXTOBJS = shstobf16.$(SUFFIX) shdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) endif DBLAS1OBJS = \ @@ -281,7 +283,9 @@ CSBLAS3OBJS = \ cblas_sgeadd.$(SUFFIX) ifeq ($(BUILD_HALF),1) +CSHBLAS1OBJS = cblas_shdot.$(SUFFIX) CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX) +CSHEXTOBJS = cblas_shstobf16.$(SUFFIX) cblas_shdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) endif CDBLAS1OBJS = \ @@ -374,6 +378,7 @@ override CFLAGS += -I. SBLAS1OBJS += $(CSBLAS1OBJS) SBLAS2OBJS += $(CSBLAS2OBJS) SBLAS3OBJS += $(CSBLAS3OBJS) +SHBLAS1OBJS += $(CSHBLAS1OBJS) SHBLAS3OBJS += $(CSHBLAS3OBJS) DBLAS1OBJS += $(CDBLAS1OBJS) DBLAS2OBJS += $(CDBLAS2OBJS) @@ -385,10 +390,11 @@ ZBLAS1OBJS += $(CZBLAS1OBJS) ZBLAS2OBJS += $(CZBLAS2OBJS) ZBLAS3OBJS += $(CZBLAS3OBJS) +SHEXTOBJS += $(CSHEXTOBJS) endif SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) -SHBLASOBJS = $(SHBLAS3OBJS) +SHBLASOBJS = $(SHBLAS1OBJS) $(SHBLAS3OBJS) DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS) CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS) @@ -463,7 +469,7 @@ ZBLASOBJS += $(ZLAPACKOBJS) endif -FUNCOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) +FUNCOBJS = $(SHEXTOBJS) $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) ifdef EXPRECISION FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS) @@ -491,7 +497,7 @@ endif clean :: @rm -f functable.h -level1 : $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS) +level1 : $(BEXTOBJS) $(SHBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) @@ -725,6 +731,19 @@ sdsdot.$(SUFFIX) sdsdot.$(PSUFFIX) : sdsdot.c dsdot.$(SUFFIX) dsdot.$(PSUFFIX) : dsdot.c $(CC) $(CFLAGS) -c $< -o $(@F) +ifeq ($(BUILD_HALF),1) +shdot.$(SUFFIX) shdot.$(PSUFFIX) : bf16dot.c + $(CC) $(CFLAGS) -c $< -o $(@F) +shstobf16.$(SUFFIX) shstobf16.$(PSUFFIX) : tobf16.c + $(CC) $(CFLAGS) -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F) +shdtobf16.$(SUFFIX) shdtobf16.$(PSUFFIX) : tobf16.c + $(CC) $(CFLAGS) -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F) +sbf16tos.$(SUFFIX) sbf16tos.$(PSUFFIX) : bf16to.c + $(CC) $(CFLAGS) -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F) +dbf16tod.$(SUFFIX) dbf16tod.$(PSUFFIX) : bf16to.c + $(CC) $(CFLAGS) -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F) +endif + sdot.$(SUFFIX) sdot.$(PSUFFIX) : dot.c $(CC) $(CFLAGS) -c $< -o $(@F) @@ -1463,6 +1482,19 @@ cblas_sdsdot.$(SUFFIX) cblas_sdsdot.$(PSUFFIX) : sdsdot.c cblas_dsdot.$(SUFFIX) cblas_dsdot.$(PSUFFIX) : dsdot.c $(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) +ifeq ($(BUILD_HALF),1) +cblas_shdot.$(SUFFIX) cblas_shdot.$(PSUFFIX) : bf16dot.c + $(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) +cblas_shstobf16.$(SUFFIX) cblas_shstobf16.$(PSUFFIX) : tobf16.c + $(CC) $(CFLAGS) -DCBLAS -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F) +cblas_shdtobf16.$(SUFFIX) cblas_shdtobf16.$(PSUFFIX) : tobf16.c + $(CC) $(CFLAGS) -DCBLAS -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F) +cblas_sbf16tos.$(SUFFIX) cblas_sbf16tos.$(PSUFFIX) : bf16to.c + $(CC) $(CFLAGS) -DCBLAS -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F) +cblas_dbf16tod.$(SUFFIX) cblas_dbf16tod.$(PSUFFIX) : bf16to.c + $(CC) $(CFLAGS) -DCBLAS -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F) +endif + cblas_sdot.$(SUFFIX) cblas_sdot.$(PSUFFIX) : dot.c $(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) diff --git a/interface/bf16dot.c b/interface/bf16dot.c new file mode 100644 index 000000000..33717e374 --- /dev/null +++ b/interface/bf16dot.c @@ -0,0 +1,52 @@ +#include +#include "common.h" +#ifdef FUNCTION_PROFILE +#include "functable.h" +#endif + +#ifndef CBLAS +float NAME(blasint *N, bfloat16 *x, blasint *INCX, bfloat16 *y, blasint *INCY){ + BLASLONG n = *N; + BLASLONG incx = *INCX; + BLASLONG incy = *INCY; + float ret; + PRINT_DEBUG_NAME; + + if (n <= 0) return 0.; + + IDEBUG_START; + FUNCTION_PROFILE_START(); + + if (incx < 0) x -= (n - 1) * incx; + if (incy < 0) y -= (n - 1) * incy; + ret = BF16_DOT_K(n, x, incx, y, incy); + + FUNCTION_PROFILE_END(1, 2 * n, 2 * n); + IDEBUG_END; + + return ret; + } + +#else + +float CNAME(blasint n, bfloat16 *x, blasint incx, bfloat16 *y, blasint incy){ + + float ret; + PRINT_DEBUG_CNAME; + + if (n <= 0) return 0.; + + IDEBUG_START; + FUNCTION_PROFILE_START(); + + if (incx < 0) x -= (n - 1) * incx; + if (incy < 0) y -= (n - 1) * incy; + ret = BF16_DOT_K(n, x, incx, y, incy); + + FUNCTION_PROFILE_END(1, 2 * n, 2 * n); + IDEBUG_END; + + return ret; +} + +#endif diff --git a/interface/bf16to.c b/interface/bf16to.c new file mode 100644 index 000000000..036c0b142 --- /dev/null +++ b/interface/bf16to.c @@ -0,0 +1,62 @@ +#include +#include "common.h" +#ifdef FUNCTION_PROFILE +#include "functable.h" +#endif + +#if defined(DOUBLE_PREC) +#define FLOAT_TYPE double +#elif defined(SINGLE_PREC) +#define FLOAT_TYPE float +#else +#endif + +#ifndef CBLAS +void NAME(blasint *N, bfloat16 *in, blasint *INC_IN, FLOAT_TYPE *out, blasint *INC_OUT){ + BLASLONG n = *N; + BLASLONG inc_in = *INC_IN; + BLASLONG inc_out = *INC_OUT; + + PRINT_DEBUG_NAME; + + if (n <= 0) return; + + IDEBUG_START; + FUNCTION_PROFILE_START(); + + if (inc_in < 0) in -= (n - 1) * inc_in; + if (inc_out < 0) out -= (n - 1) * inc_out; + +#if defined(DOUBLE_PREC) + D_BF16_TO_K(n, in, inc_in, out, inc_out); +#elif defined(SINGLE_PREC) + S_BF16_TO_K(n, in, inc_in, out, inc_out); +#else +#endif + + FUNCTION_PROFILE_END(1, 2 * n, 2 * n); + IDEBUG_END; +} +#else +void CNAME(blasint n, bfloat16 * in, blasint inc_in, FLOAT_TYPE * out, blasint inc_out){ + PRINT_DEBUG_CNAME; + + if (n <= 0) return; + + IDEBUG_START; + FUNCTION_PROFILE_START(); + + if (inc_in < 0) in -= (n - 1) * inc_in; + if (inc_out < 0) out -= (n - 1) * inc_out; + +#if defined(DOUBLE_PREC) + D_BF16_TO_K(n, in, inc_in, out, inc_out); +#elif defined(SINGLE_PREC) + S_BF16_TO_K(n, in, inc_in, out, inc_out); +#else +#endif + + FUNCTION_PROFILE_END(1, 2 * n, 2 * n); + IDEBUG_END; +} +#endif diff --git a/interface/tobf16.c b/interface/tobf16.c new file mode 100644 index 000000000..787d9d689 --- /dev/null +++ b/interface/tobf16.c @@ -0,0 +1,61 @@ +#include +#include "common.h" +#ifdef FUNCTION_PROFILE +#include "functable.h" +#endif + +#if defined(DOUBLE_PREC) +#define FLOAT_TYPE double +#elif defined(SINGLE_PREC) +#define FLOAT_TYPE float +#else +#endif + +#ifndef CBLAS +void NAME(blasint *N, FLOAT_TYPE *in, blasint *INC_IN, bfloat16 *out, blasint *INC_OUT){ + BLASLONG n = *N; + BLASLONG inc_in = *INC_IN; + BLASLONG inc_out = *INC_OUT; + + PRINT_DEBUG_NAME; + + if (n <= 0) return; + + IDEBUG_START; + FUNCTION_PROFILE_START(); + + if (inc_in < 0) in -= (n - 1) * inc_in; + if (inc_out < 0) out -= (n - 1) * inc_out; + +#if defined(DOUBLE_PREC) + D_TO_BF16_K(n, in, inc_in, out, inc_out); +#elif defined(SINGLE_PREC) + S_TO_BF16_K(n, in, inc_in, out, inc_out); +#else +#endif + + FUNCTION_PROFILE_END(1, 2 * n, 2 * n); + IDEBUG_END; +} +#else +void CNAME(blasint n, FLOAT_TYPE *in, blasint inc_in, bfloat16 *out, blasint inc_out){ + PRINT_DEBUG_CNAME; + + if (n <= 0) return; + + IDEBUG_START; + FUNCTION_PROFILE_START(); + + if (inc_in < 0) in -= (n - 1) * inc_in; + if (inc_out < 0) out -= (n - 1) * inc_out; + +#if defined(DOUBLE_PREC) + D_TO_BF16_K(n, in, inc_in, out, inc_out); +#elif defined(SINGLE_PREC) + S_TO_BF16_K(n, in, inc_in, out, inc_out); +#endif + + FUNCTION_PROFILE_END(1, 2 * n, 2 * n); + IDEBUG_END; +} +#endif diff --git a/kernel/Makefile.L1 b/kernel/Makefile.L1 index 970703230..c6576ee07 100644 --- a/kernel/Makefile.L1 +++ b/kernel/Makefile.L1 @@ -262,6 +262,20 @@ ifndef XDOTKERNEL XDOTKERNEL = zdot.S endif +ifeq ($(BUILD_HALF),1) +ifndef SHDOTKERNEL +SHDOTKERNEL = ../x86_64/shdot.c +endif + +ifndef TOBF16KERNEL +TOBF16KERNEL = ../x86_64/tobf16.c +endif + +ifndef BF16TOKERNEL +BF16TOKERNEL = ../x86_64/bf16to.c +endif +endif + ### NRM2 ### ifndef SNRM2KERNEL @@ -516,6 +530,15 @@ XBLASOBJS += \ xdotc_k$(TSUFFIX).$(SUFFIX) xdotu_k$(TSUFFIX).$(SUFFIX) xnrm2_k$(TSUFFIX).$(SUFFIX) xqrot_k$(TSUFFIX).$(SUFFIX) \ xscal_k$(TSUFFIX).$(SUFFIX) xswap_k$(TSUFFIX).$(SUFFIX) xsum_k$(TSUFFIX).$(SUFFIX) +ifeq ($(BUILD_HALF),1) +SHBLASOBJS += \ + shdot_k$(TSUFFIX).$(SUFFIX) +SHEXTOBJS += \ + shstobf16_k$(TSUFFIX).$(SUFFIX) shdtobf16_k$(TSUFFIX).$(SUFFIX) +SHEXTOBJS += \ + sbf16tos_k$(TSUFFIX).$(SUFFIX) dbf16tod_k$(TSUFFIX).$(SUFFIX) +endif + ### AMAX ### @@ -734,6 +757,19 @@ $(KDIR)ddot_k$(TSUFFIX).$(SUFFIX) $(KDIR)ddot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNEL $(KDIR)qdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)qdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(QDOTKERNEL) $(CC) -c $(CFLAGS) -UCOMPLEX -DXDOUBLE $< -o $@ +ifeq ($(BUILD_HALF),1) +$(KDIR)shdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)shdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHDOTKERNEL) + $(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@ +$(KDIR)shstobf16_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL) + $(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@ +$(KDIR)shdtobf16_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL) + $(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@ +$(KDIR)sbf16tos_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL) + $(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@ +$(KDIR)dbf16tod_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL) + $(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@ +endif + $(KDIR)sdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)sdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SDOTKERNEL) $(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE $< -o $@ diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index 582a1dc01..c43520310 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -62,9 +62,11 @@ gotoblas_t TABLE_NAME = { MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N), #endif + shstobf16_kTS, shdtobf16_kTS, sbf16tos_kTS, dbf16tod_kTS, + samax_kTS, samin_kTS, smax_kTS, smin_kTS, isamax_kTS, isamin_kTS, ismax_kTS, ismin_kTS, - snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, sdot_kTS, + snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, shdot_kTS, dsdot_kTS, srot_kTS, saxpy_kTS, sscal_kTS, sswap_kTS, sgemv_nTS, sgemv_tTS, sger_kTS, diff --git a/kernel/x86_64/KERNEL b/kernel/x86_64/KERNEL index 4874711bb..4a2e13bed 100644 --- a/kernel/x86_64/KERNEL +++ b/kernel/x86_64/KERNEL @@ -146,6 +146,18 @@ ifndef XDOTKERNEL XDOTKERNEL = zdot.S endif +ifndef SHDOTKERNEL +SHDOTKERNEL = shdot.c +endif + +ifndef TOBF16KERNEL +TOBF16KERNEL = tobf16.c +endif + +ifndef BF16TOKERNEL +BF16TOKERNEL = bf16to.c +endif + ifndef ISAMAXKERNEL ISAMAXKERNEL = iamax_sse.S endif diff --git a/kernel/x86_64/bf16to.c b/kernel/x86_64/bf16to.c new file mode 100644 index 000000000..fc6b5a529 --- /dev/null +++ b/kernel/x86_64/bf16to.c @@ -0,0 +1,114 @@ +/*************************************************************************** +Copyright (c) 2014, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include +#include "common.h" + +#if defined(DOUBLE) +#define FLOAT_TYPE double +#elif defined(SINGLE) +#define FLOAT_TYPE float +#else +#endif + +/* Notes for algorithm: + * - Input denormal treated as zero + * - Force to be QNAN + */ +static void bf16to_kernel_1(BLASLONG n, const bfloat16 * in, BLASLONG inc_in, FLOAT_TYPE * out, BLASLONG inc_out) +{ + BLASLONG register index_in = 0; + BLASLONG register index_out = 0; + BLASLONG register index = 0; + uint16_t * tmp = NULL; +#if defined(DOUBLE) + float float_out = 0.0; +#endif + + while(index= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9)) + +#define HAVE_TOBF16_ACCL_KERNEL 1 +#include "common.h" +#include + +static void tobf16_accl_kernel(BLASLONG n, const double * in, bfloat16 * out) +{ + /* Get the 64-bytes unaligned header number targeting for avx512 + * processing (Assume input float array is natural aligned) */ + int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 3) & 0x7; + + if (n < align_header) {align_header = n;} + + if (align_header != 0) { + unsigned char align_mask8 = (((unsigned char)0xff) >> (8-align_header)); + __m512d a = _mm512_maskz_loadu_pd(*((__mmask8*) &align_mask8), &in[0]); + _mm_mask_storeu_epi16(&out[0], *((__mmask8*) &align_mask8), (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(a))); + } + + if (n == align_header) { + return; + } else { + n -= align_header; + in += align_header; + out += align_header; + } + + int tail_index_8 = n&(~7); + int tail_index_32 = n&(~31); + int tail_index_128 = n&(~127); + unsigned char tail_mask8 = (((unsigned char) 0xff) >> (8 -(n&7))); + + /* Processing the main chunk with 128-elements per round */ + for (int i = 0; i < tail_index_128; i += 128) { + // Fold 1 + __m512 data1_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+ 0]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+ 8])), 1); + __m512 data1_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+16]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+24])), 1); + _mm512_storeu_si512(&out[i+ 0], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low)); + + // Fold 2 + __m512 data2_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+32]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+40])), 1); + __m512 data2_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+48]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+56])), 1); + _mm512_storeu_si512(&out[i+32], (__m512i) _mm512_cvtne2ps_pbh(data2_512_high, data2_512_low)); + + // Fold 3 + __m512 data3_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+64]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+72])), 1); + __m512 data3_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+80]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+88])), 1); + _mm512_storeu_si512(&out[i+64], (__m512i) _mm512_cvtne2ps_pbh(data3_512_high, data3_512_low)); + + // Fold 4 + __m512 data4_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+96]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+104])), 1); + __m512 data4_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+112]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+120])), 1); + _mm512_storeu_si512(&out[i+96], (__m512i) _mm512_cvtne2ps_pbh(data4_512_high, data4_512_low)); + } + + /* Processing the remaining <128 chunk with 32-elements per round */ + for (int j = tail_index_128; j < tail_index_32; j += 32) { + __m512 data1_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[j+ 0]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[j+ 8])), 1); + __m512 data1_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[j+16]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[j+24])), 1); + _mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low)); + } + + /* Processing the remaining <32 chunk with 8-elements per round */ + for (int j = tail_index_32; j < tail_index_8; j += 8) { + _mm_storeu_si128((__m128i *)&out[j], (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(_mm512_load_pd(&in[j])))); + } + + /* Processing the remaining <8 chunk with masked processing */ + if ((n&7) > 0) { + __m512d data_512 = _mm512_maskz_load_pd(*((__mmask8*) &tail_mask8), &in[tail_index_8]); + _mm_mask_storeu_epi16(&out[tail_index_8], *((__mmask8*) &tail_mask8), (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(data_512))); + } +} + +#endif diff --git a/kernel/x86_64/shdot.c b/kernel/x86_64/shdot.c new file mode 100644 index 000000000..5073fda2a --- /dev/null +++ b/kernel/x86_64/shdot.c @@ -0,0 +1,115 @@ +/*************************************************************************** +Copyright (c) 2014, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +#if defined(COOPERLAKE) +#include "shdot_microk_cooperlake.c" +#endif + +static float shdot_compute(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y) +{ + float d = 0.0; + +#ifdef HAVE_SHDOT_ACCL_KERNEL + if ((inc_x == 1) && (inc_y == 1)) { + return shdot_accl_kernel(n, x, y); + } +#endif + + float * x_fp32 = malloc(sizeof(float)*n); + float * y_fp32 = malloc(sizeof(float)*n); + + SBF16TOS_K(n, x, inc_x, x_fp32, 1); + SBF16TOS_K(n, y, inc_y, y_fp32, 1); + + d = SDOTU_K(n, x_fp32, 1, y_fp32, 1); + + free(x_fp32); + free(y_fp32); + + return d; +} + +#if defined(SMP) +static int shdot_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, bfloat16 dummy2, + bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y, + float *result, BLASLONG dummy3) +{ + *(float *)result = shdot_compute(n, x, inc_x, y, inc_y); + return 0; +} + +extern int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha, + void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc, + int (*function)(), int nthreads); +#endif + +float CNAME(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y) +{ + float dot_result = 0.0; + + if (n <= 0) return 0.0; + +#if defined(SMP) + int nthreads; + int thread_thres = 40960; + bfloat16 dummy_alpha; +#endif + +#if defined(SMP) + if (inc_x == 0 || inc_y == 0 || n <= thread_thres) + nthreads = 1; + else + nthreads = num_cpu_avail(1); + + int best_threads = (int) (n/(float)thread_thres + 0.5); + + if (best_threads < nthreads) { + nthreads = best_threads; + } + + if (nthreads <= 1) { + dot_result = shdot_compute(n, x, inc_x, y, inc_y); + } else { + char thread_result[MAX_CPU_NUMBER * sizeof(double) * 2]; + int mode = BLAS_BFLOAT16 | BLAS_REAL; + blas_level1_thread_with_return_value(mode, n, 0, 0, &dummy_alpha, + x, inc_x, y, inc_y, thread_result, 0, + (void *)shdot_thread_func, nthreads); + float * ptr = (float *)thread_result; + for (int i = 0; i < nthreads; i++) { + dot_result += (*ptr); + ptr = (float *)(((char *)ptr) + sizeof(double) * 2); + } + } +#else + dot_result = shdot_compute(n, x, inc_x, y, inc_y); +#endif + + return dot_result; +} diff --git a/kernel/x86_64/shdot_microk_cooperlake.c b/kernel/x86_64/shdot_microk_cooperlake.c new file mode 100644 index 000000000..e645296f1 --- /dev/null +++ b/kernel/x86_64/shdot_microk_cooperlake.c @@ -0,0 +1,159 @@ +/*************************************************************************** +Copyright (c) 2014, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +/* need a new enough GCC for avx512 support */ +#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9)) + +#define HAVE_SHDOT_ACCL_KERNEL 1 +#include "common.h" +#include + +static float shdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y) +{ + __m128 accum128 = _mm_setzero_ps(); + if (n> 127) { /* n range from 128 to inf. */ + long tail_index_32 = n&(~31); + long tail_index_128 = n&(~127); + unsigned int tail_mask_uint = (((unsigned int)0xffffffff) >> (32-(n&31))); + __mmask32 tail_mask = *((__mmask32*) &tail_mask_uint); + + __m512 accum512_0 = _mm512_setzero_ps(); + __m512 accum512_1 = _mm512_setzero_ps(); + __m512 accum512_2 = _mm512_setzero_ps(); + __m512 accum512_3 = _mm512_setzero_ps(); + + /* Processing the main chunk with 128-elements per round */ + for (long i = 0; i < tail_index_128; i += 128) { + accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[i+ 0]), (__m512bh) _mm512_loadu_si512(&y[i+ 0])); + accum512_1 = _mm512_dpbf16_ps(accum512_1, (__m512bh) _mm512_loadu_si512(&x[i+32]), (__m512bh) _mm512_loadu_si512(&y[i+32])); + accum512_2 = _mm512_dpbf16_ps(accum512_2, (__m512bh) _mm512_loadu_si512(&x[i+64]), (__m512bh) _mm512_loadu_si512(&y[i+64])); + accum512_3 = _mm512_dpbf16_ps(accum512_3, (__m512bh) _mm512_loadu_si512(&x[i+96]), (__m512bh) _mm512_loadu_si512(&y[i+96])); + } + + /* Processing the remaining <128 chunk with 32-elements per round */ + for (long j = tail_index_128; j < tail_index_32; j += 32) { + accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[j]), (__m512bh) _mm512_loadu_si512(&y[j])); + } + + /* Processing the remaining <32 chunk with masked 32-elements processing */ + if ((n&31) != 0) { + accum512_2 = _mm512_dpbf16_ps(accum512_2, + (__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &x[tail_index_32]), + (__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &y[tail_index_32])); + } + + /* Accumulate the 4 registers into 1 register */ + accum512_0 = _mm512_add_ps(accum512_0, accum512_1); + accum512_2 = _mm512_add_ps(accum512_2, accum512_3); + accum512_0 = _mm512_add_ps(accum512_0, accum512_2); + + __m256 accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1)); + accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1)); + } else if (n > 31) { /* n range from 32 to 127 */ + /* Processing <128 chunk with 32-elements per round */ + __m256 accum256 = _mm256_setzero_ps(); + __m256 accum256_1 = _mm256_setzero_ps(); + int tail_index_32 = n&(~31); + for (int j = 0; j < tail_index_32; j += 32) { + accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[j+ 0]), (__m256bh) _mm256_loadu_si256(&y[j+ 0])); + accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256(&x[j+16]), (__m256bh) _mm256_loadu_si256(&y[j+16])); + } + accum256 = _mm256_add_ps(accum256, accum256_1); + + /* Processing the remaining <32 chunk with 16-elements processing */ + if ((n&16) != 0) { + accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[tail_index_32]), (__m256bh) _mm256_loadu_si256(&y[tail_index_32])); + } + accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1)); + + /* Processing the remaining <16 chunk with 8-elements processing */ + if ((n&8) != 0) { + int tail_index_16 = n&(~15); + accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16])); + } + + /* Processing the remaining <8 chunk with masked 8-elements processing */ + if ((n&7) != 0) { + unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7))); + __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint); + int tail_index_8 = n&(~7); + accum128 = _mm_dpbf16_ps(accum128, + (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]), + (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8])); + } + } else if (n > 15) { /* n range from 16 to 31 */ + /* Processing <32 chunk with 16-elements processing */ + __m256 accum256 = _mm256_setzero_ps(); + accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[0]), (__m256bh) _mm256_loadu_si256(&y[0])); + accum128 += _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1)); + + /* Processing the remaining <16 chunk with 8-elements processing */ + if ((n&8) != 0) { + int tail_index_16 = n&(~15); + accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16])); + } + + /* Processing the remaining <8 chunk with masked 8-elements processing */ + if ((n&7) != 0) { + unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7))); + __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint); + int tail_index_8 = n&(~7); + accum128 = _mm_dpbf16_ps(accum128, + (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]), + (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8])); + } + } else if (n > 7) { /* n range from 8 to 15 */ + /* Processing <16 chunk with 8-elements processing */ + accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[0]), (__m128bh) _mm_loadu_si128(&y[0])); + + /* Processing the remaining <8 chunk with masked 8-elements processing */ + if ((n&7) != 0) { + unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7))); + __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint); + int tail_index_8 = n&(~7); + accum128 = _mm_dpbf16_ps(accum128, + (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]), + (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8])); + } + } else { /* n range from 1 to 7 */ + unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7))); + __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint); + accum128 = _mm_dpbf16_ps(accum128, + (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[0]), + (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[0])); + } + + /* Add up the 4 elements into lowest entry */ + __m128 accum128_1 = _mm_shuffle_ps(accum128, accum128, 14); + accum128 = _mm_add_ps(accum128, accum128_1); + accum128_1 = _mm_shuffle_ps(accum128, accum128, 1); + accum128 = _mm_add_ps(accum128, accum128_1); + + return accum128[0]; +} + +#endif diff --git a/kernel/x86_64/stobf16_microk_cooperlake.c b/kernel/x86_64/stobf16_microk_cooperlake.c new file mode 100644 index 000000000..2756a6934 --- /dev/null +++ b/kernel/x86_64/stobf16_microk_cooperlake.c @@ -0,0 +1,86 @@ +/*************************************************************************** +Copyright (c) 2014, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +/* need a new enough GCC for avx512 support */ +#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9)) + +#define HAVE_TOBF16_ACCL_KERNEL 1 +#include "common.h" +#include + +static void tobf16_accl_kernel(BLASLONG n, const float * in, bfloat16 * out) +{ + /* Get the 64-bytes unaligned header number targeting for avx512 + * processing (Assume input float array is natural aligned) */ + int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 2) & 0xf; + + if (n < align_header) {align_header = n;} + + if (align_header != 0) { + uint16_t align_mask16 = (((uint16_t)0xffff) >> (16-align_header)); + __m512 a = _mm512_maskz_loadu_ps(*((__mmask16*) &align_mask16), &in[0]); + _mm256_mask_storeu_epi16(&out[0], *((__mmask16*) &align_mask16), (__m256i) _mm512_cvtneps_pbh(a)); + } + + if (n == align_header) { + return; + } else { + n -= align_header; + in += align_header; + out += align_header; + } + + int tail_index_32 = n&(~31); + int tail_index_128 = n&(~127); + uint32_t tail_mask32 = (((uint32_t) 0xffffffff) >> (32-(n&31))); + uint16_t tail_mask16 = (((uint16_t) 0xffff) >> (16-(n&15))); + + /* Processing the main chunk with 128-elements per round */ + for (int i = 0; i < tail_index_128; i += 128) { + _mm512_storeu_si512(&out[i+ 0], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 16]), _mm512_load_ps(&in[i+ 0]))); + _mm512_storeu_si512(&out[i+32], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 48]), _mm512_load_ps(&in[i+32]))); + _mm512_storeu_si512(&out[i+64], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 80]), _mm512_load_ps(&in[i+64]))); + _mm512_storeu_si512(&out[i+96], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+112]), _mm512_load_ps(&in[i+96]))); + } + + /* Processing the remaining <128 chunk with 32-elements per round */ + for (int j = tail_index_128; j < tail_index_32; j += 32) { + _mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[j+ 16]), _mm512_load_ps(&in[j]))); + } + + /* Processing the remaining <32 chunk with masked processing */ + if ((n&31) > 15) { + __m512 b = _mm512_load_ps(&in[tail_index_32]); + __m512 a = _mm512_maskz_load_ps(*((__mmask16*) &tail_mask16), &in[tail_index_32+16]); + _mm512_mask_storeu_epi16(&out[tail_index_32], *((__mmask32*) &tail_mask32), (__m512i) _mm512_cvtne2ps_pbh(a, b)); + } else if ((n&31) > 0) { + __m512 a = _mm512_maskz_load_ps(*((__mmask16*) &tail_mask16), &in[tail_index_32]); + _mm256_mask_storeu_epi16(&out[tail_index_32], *((__mmask16*) &tail_mask16), (__m256i) _mm512_cvtneps_pbh(a)); + } +} + +#endif diff --git a/kernel/x86_64/tobf16.c b/kernel/x86_64/tobf16.c new file mode 100644 index 000000000..3d1796621 --- /dev/null +++ b/kernel/x86_64/tobf16.c @@ -0,0 +1,170 @@ +/*************************************************************************** +Copyright (c) 2014, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include +#include "common.h" + +#if defined(DOUBLE) +#define FLOAT_TYPE double +#elif defined(SINGLE) +#define FLOAT_TYPE float +#else +#endif + +#if defined(COOPERLAKE) +#if defined(DOUBLE) +#include "dtobf16_microk_cooperlake.c" +#elif defined(SINGLE) +#include "stobf16_microk_cooperlake.c" +#endif +#endif + +/* Notes for algorithm: + * - Round to Nearest Even used generally + * - QNAN for NAN case + * - Input denormals are treated as zero + */ +static void tobf16_generic_kernel(BLASLONG n, const FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out) +{ + BLASLONG register index_in = 0; + BLASLONG register index_out = 0; + BLASLONG register index = 0; + float float_in = 0.0; + uint32_t * uint32_in = (uint32_t *)(&float_in); + uint16_t * uint16_in = (uint16_t *)(&float_in); + + while(index> 16) & 0x1u) + 0x7fffu); +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + *(out+index_out) = uint16_in[1]; +#else + *(out+index_out) = uint16_in[0]; +#endif + break; + } + + index_in += inc_in; + index_out += inc_out; + index++; + } +} + +#ifndef HAVE_TOBF16_ACCL_KERNEL +static void tobf16_accl_kernel(BLASLONG n, const FLOAT_TYPE * in, bfloat16 * out) +{ + tobf16_generic_kernel(n, in, 1, out, 1); +} +#endif + +static void tobf16_compute(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out) +{ + if ((inc_in == 1) && (inc_out == 1)) { + tobf16_accl_kernel(n, in, out); + } else { + tobf16_generic_kernel(n, in, inc_in, out, inc_out); + } +} + +#if defined(SMP) +static int tobf16_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT_TYPE dummy2, + FLOAT_TYPE *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y, + FLOAT_TYPE *dummy3, BLASLONG dummy4) +{ + tobf16_compute(n, x, inc_x, y, inc_y); + return 0; +} + +extern int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha, + void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc, + int (*function)(), int nthreads); +#endif + +void CNAME(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out) +{ + if (n <= 0) return; + +#if defined(SMP) + int nthreads; + FLOAT_TYPE dummy_alpha; + FLOAT_TYPE dummy_c; +#endif + +#if defined(SMP) + if (inc_in == 0 || inc_out == 0 || n <= 100000) { + nthreads = 1; + } else { + if (n/100000 < 100) { + nthreads = 4; + } else { + nthreads = 16; + } + } + + if (nthreads == 1) { + tobf16_compute(n, in, inc_in, out, inc_out); + } else { +#if defined(DOUBLE) + int mode = BLAS_REAL | BLAS_DTOBF16; +#elif defined(SINGLE) + int mode = BLAS_REAL | BLAS_STOBF16; +#endif + blas_level1_thread(mode, n, 0, 0, &dummy_alpha, + in, inc_in, out, inc_out, &dummy_c, 0, + (void *)tobf16_thread_func, nthreads); + } +#else + tobf16_compute(n, in, inc_in, out, inc_out); +#endif + +} diff --git a/openblas_config_template.h b/openblas_config_template.h index 9955e5c73..858b8c5cb 100644 --- a/openblas_config_template.h +++ b/openblas_config_template.h @@ -35,7 +35,8 @@ typedef unsigned long BLASULONG; #endif #ifndef BFLOAT16 -typedef unsigned short bfloat16; +#include +typedef uint16_t bfloat16; #endif #ifdef OPENBLAS_USE64BITINT From 746ad3bd190493a7219bc02547a050772d4a4e01 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sun, 13 Sep 2020 18:40:59 +0200 Subject: [PATCH 02/13] Fix vendor match for GCC gfortran --- f_check | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/f_check b/f_check index dd4d3475c..f894aa9ac 100644 --- a/f_check +++ b/f_check @@ -69,7 +69,7 @@ if ($compiler eq "") { $bu = "_"; } - if ($data =~ /GNU/) { + if ($data =~ /GNU/ || $data =~ /GCC/ ) { $data =~ /(\d+)\.(\d+).(\d+)/; $major = $1; From 26792d2096ce0736a53bef6b8bf4ff0206ac3efa Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sun, 13 Sep 2020 21:47:55 +0200 Subject: [PATCH 03/13] Copy BUILD_* directives to the compiler options to allow ifdef in tests --- cmake/system.cmake | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/cmake/system.cmake b/cmake/system.cmake index c0f3c6ed2..aa342c3d2 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -393,6 +393,18 @@ set(REVISION "-r${OpenBLAS_VERSION}") set(MAJOR_VERSION ${OpenBLAS_MAJOR_VERSION}) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${CCOMMON_OPT}") +if (BUILD_SINGLE) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_SINGLE") +endif() +if (BUILD_DOUBLE) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_DOUBLE") +endif() +if (BUILD_COMPLEX) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_COMPLEX") +endif() +if (BUILD_COMPLEX16) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_COMPLEX16") +endif() if(NOT MSVC) set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} ${CCOMMON_OPT}") endif() From 74e358bcd514cff2e9b32c13571c09176b56a3d8 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sun, 13 Sep 2020 21:49:01 +0200 Subject: [PATCH 04/13] Remove spurious complex16 tests --- ctest/c_dblas1.c | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/ctest/c_dblas1.c b/ctest/c_dblas1.c index e49ae6007..8e13afcaa 100644 --- a/ctest/c_dblas1.c +++ b/ctest/c_dblas1.c @@ -74,16 +74,6 @@ void F77_dswap( const int *N, double *X, const int *incX, return; } -double F77_dzasum(const int *N, void *X, const int *incX) -{ - return cblas_dzasum(*N, X, *incX); -} - -double F77_dznrm2(const int *N, OPENBLAS_CONST void *X, const int *incX) -{ - return cblas_dznrm2(*N, X, *incX); -} - int F77_idamax(const int *N, OPENBLAS_CONST double *X, const int *incX) { if (*N < 1 || *incX < 1) return(0); From 593ce9e23786796a483f44436e4aca57d042f05d Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sun, 13 Sep 2020 21:50:12 +0200 Subject: [PATCH 05/13] Make building individual tests depend on BUILD_SINGLE etc defines --- test/CMakeLists.txt | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index adeee3452..f1f773cba 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -3,11 +3,18 @@ include_directories(${PROJECT_BINARY_DIR}) enable_language(Fortran) -set(OpenBLAS_Tests - sblat1 sblat2 sblat3 - dblat1 dblat2 dblat3 - cblat1 cblat2 cblat3 - zblat1 zblat2 zblat3) +if (BUILD_SINGLE) + list( APPEND OpenBLAS_Tests sblat1 sblat2 sblat3) +endif() +if (BUILD_DOUBLE) + list (APPEND OpenBLAS_Tests dblat1 dblat2 dblat3) +endif() +if (BUILD_COMPLEX) + list (APPEND OpenBLAS_Tests cblat1 cblat2 cblat3) +endif() +if (BUILD_COMPLEX16) + list (APPEND OpenBLAS_Tests zblat1 zblat2 zblat3) +endif() foreach(test_bin ${OpenBLAS_Tests}) add_executable(${test_bin} ${test_bin}.f) From ce8939863626d3a194890e87edc9b7280f73b660 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sun, 13 Sep 2020 21:52:18 +0200 Subject: [PATCH 06/13] Make tests for individual variable types conditional on the respective BUILD_ option --- utest/test_amax.c | 6 +++++- utest/test_axpy.c | 9 +++++++++ utest/test_dotu.c | 3 +++ utest/test_ismin.c | 2 ++ utest/test_min.c | 13 +++++++++++-- utest/test_potrs.c | 39 ++++++++++++++++++++++++++++++--------- utest/test_rot.c | 9 +++++++++ utest/test_swap.c | 9 +++++++++ 8 files changed, 78 insertions(+), 12 deletions(-) diff --git a/utest/test_amax.c b/utest/test_amax.c index 831804027..a9e5a1c85 100644 --- a/utest/test_amax.c +++ b/utest/test_amax.c @@ -33,6 +33,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "openblas_utest.h" +#ifdef BUILD_SINGLE CTEST(amax, samax){ blasint N=3, inc=1; float te_max=0.0, tr_max=0.0; @@ -43,7 +44,8 @@ CTEST(amax, samax){ ASSERT_DBL_NEAR_TOL((double)(tr_max), (double)(te_max), SINGLE_EPS); } - +#endif +#ifdef BUILD_DOUBLE CTEST(amax, damax){ blasint N=3, inc=1; double te_max=0.0, tr_max=0.0; @@ -54,3 +56,5 @@ CTEST(amax, damax){ ASSERT_DBL_NEAR_TOL((double)(tr_max), (double)(te_max), DOUBLE_EPS); } +#endif + diff --git a/utest/test_axpy.c b/utest/test_axpy.c index 603043073..5fd7c1b04 100644 --- a/utest/test_axpy.c +++ b/utest/test_axpy.c @@ -33,6 +33,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "openblas_utest.h" +#ifdef BUILD_DOUBLE CTEST(axpy,daxpy_inc_0) { blasint i; @@ -52,7 +53,9 @@ CTEST(axpy,daxpy_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS); } } +#endif +#ifdef BUILD_COMPLEX16 CTEST(axpy,zaxpy_inc_0) { blasint i; @@ -71,7 +74,9 @@ CTEST(axpy,zaxpy_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS); } } +#endif +#ifdef BUILD_SINGLE CTEST(axpy,saxpy_inc_0) { blasint i; @@ -90,7 +95,9 @@ CTEST(axpy,saxpy_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS); } } +#endif +#ifdef BUILD_COMPLEX CTEST(axpy,caxpy_inc_0) { blasint i; @@ -109,3 +116,5 @@ CTEST(axpy,caxpy_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS); } } +#endif + diff --git a/utest/test_dotu.c b/utest/test_dotu.c index 918541848..542286403 100644 --- a/utest/test_dotu.c +++ b/utest/test_dotu.c @@ -33,6 +33,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "openblas_utest.h" +#ifdef BUILD_COMPLEX16 CTEST( zdotu,zdotu_n_1) { blasint N=1,incX=1,incY=1; @@ -80,3 +81,5 @@ CTEST(zdotu, zdotu_offset_1) #endif } +#endif + diff --git a/utest/test_ismin.c b/utest/test_ismin.c index f23d6b545..af597807f 100644 --- a/utest/test_ismin.c +++ b/utest/test_ismin.c @@ -36,6 +36,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define ELEMENTS 50 #define INCREMENT 2 +#ifdef BUILD_SINGLE CTEST(ismin, positive_step_2){ blasint i; blasint N = ELEMENTS, inc = INCREMENT; @@ -87,3 +88,4 @@ CTEST(ismax, negative_step_2){ blasint index = BLASFUNC(ismax)(&N, x, &inc); ASSERT_EQUAL(9, index); } +#endif diff --git a/utest/test_min.c b/utest/test_min.c index fd31b5982..a627674ae 100644 --- a/utest/test_min.c +++ b/utest/test_min.c @@ -32,7 +32,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. **********************************************************************************/ #include "openblas_utest.h" - +#ifdef BUILD_SINGLE CTEST(min, smin_negative){ blasint N=3, inc=1; float te_min=0.0, tr_min=0.0; @@ -43,7 +43,9 @@ CTEST(min, smin_negative){ ASSERT_DBL_NEAR_TOL((double)(tr_min), (double)(te_min), SINGLE_EPS); } +#endif +#ifdef BUILD_DOUBLE CTEST(min, dmin_positive){ blasint N=3, inc=1; double te_min=0.0, tr_min=0.0; @@ -54,7 +56,9 @@ CTEST(min, dmin_positive){ ASSERT_DBL_NEAR_TOL((double)(tr_min), (double)(te_min), DOUBLE_EPS); } +#endif +#ifdef BUILD_SINGLE CTEST(min, smin_zero){ blasint N=3, inc=1; float te_min=0.0, tr_min=0.0; @@ -76,7 +80,9 @@ CTEST(max, smax_negative){ ASSERT_DBL_NEAR_TOL((double)(tr_max), (double)(te_max), SINGLE_EPS); } +#endif +#ifdef BUILD_DOUBLE CTEST(max, dmax_positive){ blasint N=3, inc=1; double te_max=0.0, tr_max=0.0; @@ -87,7 +93,8 @@ CTEST(max, dmax_positive){ ASSERT_DBL_NEAR_TOL((double)(tr_max), (double)(te_max), DOUBLE_EPS); } - +#endif +#ifdef BUILD_SINGLE CTEST(max, smax_zero){ blasint N=3, inc=1; float te_max=0.0, tr_max=0.0; @@ -98,3 +105,5 @@ CTEST(max, smax_zero){ ASSERT_DBL_NEAR_TOL((double)(tr_max), (double)(te_max), SINGLE_EPS); } +#endif + diff --git a/utest/test_potrs.c b/utest/test_potrs.c index 7afeb4c9d..05ce3037b 100644 --- a/utest/test_potrs.c +++ b/utest/test_potrs.c @@ -39,10 +39,10 @@ void BLASFUNC(zpotrs_(char*, BLASINT*, BLASINT*, complex double*, BLASINT*, complex double*, BLASINT*, BLASINT*); */ - //https://github.com/xianyi/OpenBLAS/issues/695 CTEST(potrf, bug_695){ +#ifdef BUILD_COMPLEX openblas_complex_float A1[100] = { openblas_make_complex_float(5.8525753, +0.0), @@ -153,7 +153,9 @@ CTEST(potrf, bug_695){ blasint info[1]; BLASFUNC(cpotrf)(&up, &n, (float*)(A1), &n, info); //printf("%g+%g*I\n", creal(A1[91]), cimag(A1[91])); +#endif +#ifdef BUILD_COMPLEX16 openblas_complex_double A2[100] = { openblas_make_complex_double(3.0607147216796875, +0.0), @@ -283,7 +285,8 @@ CTEST(potrf, bug_695){ char lo = 'L'; blasint nrhs = 2; BLASFUNC(zpotrs)(&lo, &n, &nrhs, (double*)(A2), &n, (double*)(B), &n, info); - +#endif +#ifdef BUILD_COMPLEX // note that this is exactly equal to A1 openblas_complex_float A3[100] = { @@ -393,9 +396,9 @@ CTEST(potrf, bug_695){ if(isnan(CREAL(A3[91])) || isnan(CIMAG(A3[91]))) { CTEST_ERR("%s:%d got NaN", __FILE__, __LINE__); } +#endif } - // Check potrf factorizes a small problem correctly CTEST(potrf, smoketest_trivial){ float A1s[4] = {2, 0.3, 0.3, 3}; @@ -439,31 +442,43 @@ CTEST(potrf, smoketest_trivial){ uplo = 'U'; } +#ifdef BUILD_SINGLE BLASFUNC(scopy)(&nv, A1s, &inc, As, &inc); +#endif +#ifdef BUILD_DOUBLE BLASFUNC(dcopy)(&nv, A1d, &inc, Ad, &inc); +#endif +#ifdef BUILD_COMPLEX BLASFUNC(ccopy)(&nv, (float *)A1c, &inc, (float *)Ac, &inc); +#endif +#ifdef BUILD_COMPLEX16 BLASFUNC(zcopy)(&nv, (double *)A1z, &inc, (double *)Az, &inc); +#endif +#ifdef BUILD_SINGLE BLASFUNC(spotrf)(&uplo, &n, As, &n, &info); if (info != 0) { CTEST_ERR("%s:%d info != 0", __FILE__, __LINE__); } - +#endif +#ifdef BUILD_DOUBLE BLASFUNC(dpotrf)(&uplo, &n, Ad, &n, &info); if (info != 0) { CTEST_ERR("%s:%d info != 0", __FILE__, __LINE__); } - +#endif +#ifdef BUILD_COMPLEX BLASFUNC(cpotrf)(&uplo, &n, (float *)Ac, &n, &info); if (info != 0) { CTEST_ERR("%s:%d info != 0", __FILE__, __LINE__); } - +#endif +#ifdef BUILD_COMPLEX16 BLASFUNC(zpotrf)(&uplo, &n, (double *)Az, &n, &info); if (info != 0) { CTEST_ERR("%s:%d info != 0", __FILE__, __LINE__); } - +#endif /* Fill the other triangle */ if (uplo == 'L') { for (i = 0; i < n; ++i) { @@ -495,14 +510,20 @@ CTEST(potrf, smoketest_trivial){ trans1 = 'C'; trans2 = 'N'; } - +#ifdef BUILD_SINGLE BLASFUNC(sgemm)(&trans1, &trans2, &n, &n, &n, &ones, As, &n, As, &n, &zeros, Bs, &n); +#endif +#ifdef BUILD_DOUBLE BLASFUNC(dgemm)(&trans1, &trans2, &n, &n, &n, &oned, Ad, &n, Ad, &n, &zerod, Bd, &n); +#endif +#ifdef BUILD_COMPLEX BLASFUNC(cgemm)(&trans1, &trans2, &n, &n, &n, (float *)&onec, (float *)Ac, &n, (float *)Ac, &n, (float *)&zeroc, (float *)Bc, &n); +#endif +#ifdef BUILD_COMPLEX16 BLASFUNC(zgemm)(&trans1, &trans2, &n, &n, &n, (double *)&onez, (double *)Az, &n, (double *)Az, &n, (double *)&zeroz, (double *)Bz, &n); - +#endif /* Check result is close to original */ for (i = 0; i < n; ++i) { for (j = 0; j < n; ++j) { diff --git a/utest/test_rot.c b/utest/test_rot.c index cf72ad22d..0e74ecbb3 100644 --- a/utest/test_rot.c +++ b/utest/test_rot.c @@ -33,6 +33,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "openblas_utest.h" +#ifdef BUILD_DOUBLE CTEST(rot,drot_inc_0) { blasint i=0; @@ -52,7 +53,9 @@ CTEST(rot,drot_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS); } } +#endif +#ifdef BUILD_COMPLEX16 CTEST(rot,zdrot_inc_0) { blasint i=0; @@ -72,7 +75,9 @@ CTEST(rot,zdrot_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS); } } +#endif +#ifdef BUILD_SINGLE CTEST(rot,srot_inc_0) { blasint i=0; @@ -91,7 +96,9 @@ CTEST(rot,srot_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], SINGLE_EPS); } } +#endif +#ifdef BUILD_COMPLEX CTEST(rot, csrot_inc_0) { blasint i=0; @@ -110,3 +117,5 @@ CTEST(rot, csrot_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], SINGLE_EPS); } } +#endif + diff --git a/utest/test_swap.c b/utest/test_swap.c index 259c83a5c..6d8ae8056 100644 --- a/utest/test_swap.c +++ b/utest/test_swap.c @@ -33,6 +33,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "openblas_utest.h" +#ifdef BUILD_DOUBLE CTEST(swap,dswap_inc_0) { blasint i=0; @@ -50,7 +51,9 @@ CTEST(swap,dswap_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS); } } +#endif +#ifdef BUILD_COMPLEX16 CTEST(swap,zswap_inc_0) { blasint i=0; @@ -68,7 +71,9 @@ CTEST(swap,zswap_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS); } } +#endif +#ifdef BUILD_SINGLE CTEST(swap,sswap_inc_0) { blasint i=0; @@ -86,7 +91,9 @@ CTEST(swap,sswap_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], SINGLE_EPS); } } +#endif +#ifdef BUILD_COMPLEX CTEST(swap,cswap_inc_0) { blasint i=0; @@ -104,3 +111,5 @@ CTEST(swap,cswap_inc_0) ASSERT_DBL_NEAR_TOL(y2[i], y1[i], SINGLE_EPS); } } +#endif + From ec2948f14784c3559b11f9aed07646396c3527cf Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sun, 13 Sep 2020 22:17:46 +0200 Subject: [PATCH 07/13] Make tests conditional on BUILD_DOUBLE --- utest/test_kernel_regress.c | 2 ++ utest/test_rotmg.c | 2 ++ 2 files changed, 4 insertions(+) diff --git a/utest/test_kernel_regress.c b/utest/test_kernel_regress.c index 93a30b30c..5b131bb2c 100644 --- a/utest/test_kernel_regress.c +++ b/utest/test_kernel_regress.c @@ -22,6 +22,7 @@ double m[DATASIZE*DATASIZE]; CTEST(kernel_regress,skx_avx) { +#ifdef BUILD_DOUBLE double norm; int i, j, info; srand(0); @@ -47,4 +48,5 @@ CTEST(kernel_regress,skx_avx) norm = cblas_dnrm2(DATASIZE*DATASIZE, X, 1); ASSERT_DBL_NEAR_TOL(0.0, norm, 1e-10); +#endif } diff --git a/utest/test_rotmg.c b/utest/test_rotmg.c index e5ec78983..ad435f6b0 100644 --- a/utest/test_rotmg.c +++ b/utest/test_rotmg.c @@ -33,6 +33,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "openblas_utest.h" +#ifdef BUILD_DOUBLE CTEST (drotmg,rotmg) { double te_d1, tr_d1; @@ -204,3 +205,4 @@ CTEST(drotmg, drotmg_D1_big_D2_big_flag_zero) ASSERT_DBL_NEAR_TOL(tr_param[i], te_param[i], DOUBLE_EPS); } } +#endif From de139337b8bcb1c76cd157afd4d5fd035a76efdf Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sun, 13 Sep 2020 22:20:41 +0200 Subject: [PATCH 08/13] Remove spurious tests for complex ASUM and NRM2 --- ctest/c_sblas1.c | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/ctest/c_sblas1.c b/ctest/c_sblas1.c index 1a433b287..a562014a4 100644 --- a/ctest/c_sblas1.c +++ b/ctest/c_sblas1.c @@ -21,16 +21,6 @@ void F77_saxpy(blasint *N, const float *alpha, OPENBLAS_CONST float *X, return; } -float F77_scasum(blasint *N, float *X, blasint *incX) -{ - return cblas_scasum(*N, X, *incX); -} - -float F77_scnrm2(blasint *N, OPENBLAS_CONST float *X, blasint *incX) -{ - return cblas_scnrm2(*N, X, *incX); -} - void F77_scopy(blasint *N, OPENBLAS_CONST float *X, blasint *incX, float *Y, blasint *incY) { From 4d250d0cdf9f0d234aa9c3eeff246bbe1b9edd3b Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sun, 13 Sep 2020 23:29:01 +0200 Subject: [PATCH 09/13] Rearrange ifdefs --- utest/test_potrs.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utest/test_potrs.c b/utest/test_potrs.c index 05ce3037b..2681615f4 100644 --- a/utest/test_potrs.c +++ b/utest/test_potrs.c @@ -42,7 +42,6 @@ void BLASFUNC(zpotrs_(char*, BLASINT*, BLASINT*, complex double*, //https://github.com/xianyi/OpenBLAS/issues/695 CTEST(potrf, bug_695){ -#ifdef BUILD_COMPLEX openblas_complex_float A1[100] = { openblas_make_complex_float(5.8525753, +0.0), @@ -151,11 +150,11 @@ CTEST(potrf, bug_695){ blasint n=10; blasint info[1]; +#ifdef BUILD_COMPLEX BLASFUNC(cpotrf)(&up, &n, (float*)(A1), &n, info); //printf("%g+%g*I\n", creal(A1[91]), cimag(A1[91])); #endif -#ifdef BUILD_COMPLEX16 openblas_complex_double A2[100] = { openblas_make_complex_double(3.0607147216796875, +0.0), @@ -284,9 +283,9 @@ CTEST(potrf, bug_695){ }; char lo = 'L'; blasint nrhs = 2; +#ifdef BUILD_COMPLEX16 BLASFUNC(zpotrs)(&lo, &n, &nrhs, (double*)(A2), &n, (double*)(B), &n, info); #endif -#ifdef BUILD_COMPLEX // note that this is exactly equal to A1 openblas_complex_float A3[100] = { @@ -391,6 +390,7 @@ CTEST(potrf, bug_695){ openblas_make_complex_float(-0.9617417, -1.2486815), openblas_make_complex_float(3.4629636, +0.0) }; +#ifdef BUILD_COMPLEX BLASFUNC(cpotrf)(&up, &n, (float*)(A3), &n, info); // printf("%g+%g*I\n", creal(A3[91]), cimag(A3[91])); if(isnan(CREAL(A3[91])) || isnan(CIMAG(A3[91]))) { From 9e11c2d62f23ef2483d206aaf3952e0bd09d30cb Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sun, 13 Sep 2020 23:55:11 +0200 Subject: [PATCH 10/13] Add BUILD_SINGLE etc --- Makefile.rule | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Makefile.rule b/Makefile.rule index 2c12177ee..40bd1a854 100644 --- a/Makefile.rule +++ b/Makefile.rule @@ -277,5 +277,10 @@ COMMON_PROF = -pg # If you want to enable the experimental BFLOAT16 support # BUILD_HALF = 1 # +# the below is not yet configurable, use cmake if you need to build only select types +BUILD_SINGLE = 1 +BUILD_DOUBLE = 1 +BUILD_COMPLEX = 1 +BUILD_COMPLEX16 = 1 # End of user configuration # From ba644378dce720f6bb946aa2b585c9e71f257e1f Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Mon, 14 Sep 2020 00:03:33 +0200 Subject: [PATCH 11/13] Copy BUILD_ options available to the compiler flags --- Makefile.system | 53 +++++++++++++++++++++++++++++-------------------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/Makefile.system b/Makefile.system index 1b832ba41..0ccf9eaed 100644 --- a/Makefile.system +++ b/Makefile.system @@ -295,6 +295,7 @@ endif ifeq ($(C_COMPILER), GCC) GCCVERSIONGTEQ4 := $(shell expr `$(CC) -dumpversion | cut -f1 -d.` \>= 4) GCCVERSIONGT4 := $(shell expr `$(CC) -dumpversion | cut -f1 -d.` \> 4) +GCCVERSIONEQ5 := $(shell expr `$(CC) -dumpversion | cut -f1 -d.` = 5) GCCVERSIONGT5 := $(shell expr `$(CC) -dumpversion | cut -f1 -d.` \> 5) GCCVERSIONGTEQ7 := $(shell expr `$(CC) -dumpversion | cut -f1 -d.` \>= 7) GCCVERSIONGTEQ9 := $(shell expr `$(CC) -dumpversion | cut -f1 -d.` \>= 9) @@ -593,35 +594,33 @@ endif ifeq ($(ARCH), zarch) DYNAMIC_CORE = ZARCH_GENERIC -# if the compiler accepts -march=arch11 or -march=z13 and can compile a file -# with z13-specific inline assembly, then we can include support for Z13. -# note: -march=z13 is equivalent to -march=arch11 yet some compiler releases -# only support one or the other. -# note: LLVM version 6.x supported -march=z13 yet could not handle vector -# registers in inline assembly, so the check for supporting the -march flag is -# not enough. -ZARCH_TEST_COMPILE=-c $(TOPDIR)/kernel/zarch/damin_z13.c -I$(TOPDIR) -o /dev/null > /dev/null 2> /dev/null -ZARCH_CC_SUPPORTS_ARCH11=$(shell $(CC) -march=arch11 $(ZARCH_TEST_COMPILE) && echo 1) -ZARCH_CC_SUPPORTS_Z13=$(shell $(CC) -march=z13 $(ZARCH_TEST_COMPILE) && echo 1) +# Z13 is supported since gcc-5.2, gcc-6, and in RHEL 7.3 and newer +ifeq ($(GCCVERSIONGT5), 1) + ZARCH_SUPPORT_Z13 := 1 +else ifeq ($(GCCVERSIONEQ5), 1) +ifeq ($(GCCMINORVERSIONGTEQ2), 1) + ZARCH_SUPPORT_Z13 := 1 +endif +endif -ifeq ($(or $(ZARCH_CC_SUPPORTS_ARCH11), $(ZARCH_CC_SUPPORTS_Z13)), 1) +ifeq ($(wildcard /etc/redhat-release), /etc/redhat-release) +ifeq ($(shell source /etc/os-release ; expr $$VERSION_ID \>= "7.3"), 1) + ZARCH_SUPPORT_Z13 := 1 +endif +endif + +ifeq ($(ZARCH_SUPPORT_Z13), 1) DYNAMIC_CORE += Z13 -CCOMMON_OPT += -DDYN_Z13 else -$(info OpenBLAS: Not building Z13 kernels because the compiler $(CC) does not support it) +$(info OpenBLAS: Not building Z13 kernels because gcc is older than 5.2 or 6.x) endif -# as above for z13, check for -march=arch12 and z14 support in the compiler. -ZARCH_CC_SUPPORTS_ARCH12=$(shell $(CC) -march=arch12 $(ZARCH_TEST_COMPILE) && echo 1) -ZARCH_CC_SUPPORTS_Z14=$(shell $(CC) -march=z14 $(ZARCH_TEST_COMPILE) && echo 1) -ifeq ($(or $(ZARCH_CC_SUPPORTS_ARCH12), $(ZARCH_CC_SUPPORTS_Z14)), 1) +ifeq ($(GCCVERSIONGTEQ7), 1) DYNAMIC_CORE += Z14 -CCOMMON_OPT += -DDYN_Z14 else -$(info OpenBLAS: Not building Z14 kernels because the compiler $(CC) does not support it) +$(info OpenBLAS: Not building Z14 kernels because gcc is older than 7.x) +endif endif - -endif # ARCH zarch ifeq ($(ARCH), power) DYNAMIC_CORE = POWER6 @@ -1223,6 +1222,18 @@ endif ifeq ($(BUILD_HALF), 1) CCOMMON_OPT += -DBUILD_HALF endif +ifeq ($(BUILD_SINGLE), 1) +CCOMMON_OPT += -DBUILD_SINGLE +endif +ifeq ($(BUILD_DOUBLE), 1) +CCOMMON_OPT += -DBUILD_DOUBLE +endif +ifeq ($(BUILD_COMPLEX), 1) +CCOMMON_OPT += -DBUILD_COMPLEX +endif +ifeq ($(BUILD_COMPLEX16), 1) +CCOMMON_OPT += -DBUILD_COMPLEX16 +endif CCOMMON_OPT += -DVERSION=\"$(VERSION)\" From 274d6e015b56a9f0ccad928232ed3bd88a063754 Mon Sep 17 00:00:00 2001 From: fossum Date: Mon, 14 Sep 2020 13:10:48 -0500 Subject: [PATCH 12/13] Fixing a performance bug in trsm_[LR].c. --- driver/level3/trsm_L.c | 4 ++-- driver/level3/trsm_R.c | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/driver/level3/trsm_L.c b/driver/level3/trsm_L.c index d8130ee7e..d842efa93 100644 --- a/driver/level3/trsm_L.c +++ b/driver/level3/trsm_L.c @@ -131,7 +131,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; @@ -197,7 +197,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; diff --git a/driver/level3/trsm_R.c b/driver/level3/trsm_R.c index f6a57f93f..f76a8f7f3 100644 --- a/driver/level3/trsm_R.c +++ b/driver/level3/trsm_R.c @@ -126,7 +126,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; @@ -182,7 +182,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = 0; jjs < min_j - min_l - ls + js; jjs += min_jj){ min_jj = min_j - min_l - ls + js - jjs; - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; @@ -243,7 +243,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; @@ -304,7 +304,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = 0; jjs < min_j - js + ls; jjs += min_jj){ min_jj = min_j - js + ls - jjs; - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; From dfeca46098ff7b3cc47aa053195fe1c82bce87e9 Mon Sep 17 00:00:00 2001 From: fossum Date: Tue, 15 Sep 2020 08:59:50 -0500 Subject: [PATCH 13/13] Adding performance patch for trmm, just like #2836 --- driver/level3/trmm_L.c | 8 ++++---- driver/level3/trmm_R.c | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/driver/level3/trmm_L.c b/driver/level3/trmm_L.c index 1027c0c73..ae8435d03 100644 --- a/driver/level3/trmm_L.c +++ b/driver/level3/trmm_L.c @@ -139,7 +139,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; #else - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #endif @@ -209,7 +209,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; #else - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #endif @@ -304,7 +304,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; #else - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #endif @@ -374,7 +374,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; #else - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #endif diff --git a/driver/level3/trmm_R.c b/driver/level3/trmm_R.c index e8df7fb21..3be43edde 100644 --- a/driver/level3/trmm_R.c +++ b/driver/level3/trmm_R.c @@ -126,7 +126,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; #else - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #endif @@ -150,7 +150,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; #else - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #endif @@ -207,7 +207,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; #else - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #endif @@ -262,7 +262,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; #else - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #endif @@ -287,7 +287,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; #else - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #endif @@ -348,7 +348,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; #else - if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #endif