diff --git a/Makefile.system b/Makefile.system index 2998c0e6a..0e176987c 100644 --- a/Makefile.system +++ b/Makefile.system @@ -1390,6 +1390,8 @@ export FUNCTION_PROFILE export TARGET_CORE export NO_AVX512 +export SHGEMM_UNROLL_M +export SHGEMM_UNROLL_N export SGEMM_UNROLL_M export SGEMM_UNROLL_N export DGEMM_UNROLL_M diff --git a/Makefile.tail b/Makefile.tail index 2adede1a5..39902982b 100644 --- a/Makefile.tail +++ b/Makefile.tail @@ -1,3 +1,4 @@ +SHBLASOBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) @@ -9,8 +10,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) -BLASOBJS = $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) -BLASOBJS_P = $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) +BLASOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) +BLASOBJS_P = $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) ifdef EXPRECISION BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) @@ -22,6 +23,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P) endif +$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHALF -UDOUBLE -UCOMPLEX $(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX $(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX $(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX @@ -29,6 +31,7 @@ $(CBLASOBJS) $(CBLASOBJS_P) : override CFLAGS += -UDOUBLE -DCOMPLEX $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX $(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX +$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) diff --git a/cmake/prebuild.cmake b/cmake/prebuild.cmake index 44e1473d1..e0696093b 100644 --- a/cmake/prebuild.cmake +++ b/cmake/prebuild.cmake @@ -16,6 +16,8 @@ # HAVE_SSE2 # HAVE_SSE3 # MAKE +# SHGEMM_UNROLL_M +# SHGEMM_UNROLL_N # SGEMM_UNROLL_M # SGEMM_UNROLL_N # DGEMM_UNROLL_M @@ -437,6 +439,8 @@ if (DEFINED CORE AND CMAKE_CROSSCOMPILING AND NOT (${HOST_OS} STREQUAL "WINDOWSS set(ZGEMM_UNROLL_N 2) set(SYMV_P 8) endif() + set(SHGEMM_UNROLL_M 8) + set(SHGEMM_UNROLL_N 4) # Or should this actually be NUM_CORES? if (${NUM_THREADS} GREATER 0) diff --git a/cmake/system.cmake b/cmake/system.cmake index ce980a7b9..65e5aa508 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -530,6 +530,8 @@ endif () #export FUNCTION_PROFILE #export TARGET_CORE # +#export SHGEMM_UNROLL_M +#export SHGEMM_UNROLL_N #export SGEMM_UNROLL_M #export SGEMM_UNROLL_N #export DGEMM_UNROLL_M diff --git a/common.h b/common.h index 762968e6f..1d8bf07e5 100644 --- a/common.h +++ b/common.h @@ -297,6 +297,17 @@ typedef int blasint; #define SIZE 8 #define BASE_SHIFT 3 #define ZBASE_SHIFT 4 +#elif defined(HALF) +#ifndef BFLOAT16 +typedef unsigned short bfloat16; +#define HALFCONVERSION 1 +#endif +#define IFLOAT bfloat16 +#define XFLOAT IFLOAT +#define FLOAT float +#define SIZE 2 +#define BASE_SHIFT 1 +#define ZBASE_SHIFT 2 #else #define FLOAT float #define SIZE 4 @@ -308,6 +319,10 @@ typedef int blasint; #define XFLOAT FLOAT #endif +#ifndef IFLOAT +#define IFLOAT FLOAT +#endif + #ifndef COMPLEX #define COMPSIZE 1 #else diff --git a/common_interface.h b/common_interface.h index c350ac8ec..081043af1 100644 --- a/common_interface.h +++ b/common_interface.h @@ -37,6 +37,9 @@ /*********************************************************************/ #ifndef ASSEMBLER +#ifndef BFLOAT16 +typedef unsigned short bfloat16; +#endif #ifdef __cplusplus extern "C" { @@ -469,6 +472,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint /* Level 3 routines */ +void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *, + bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *, float *, blasint *, float *, blasint *, float *, float *, blasint *); void BLASFUNC(dgemm)(char *, char *, blasint *, blasint *, blasint *, double *, diff --git a/common_level3.h b/common_level3.h index 6fa902be8..8194ba6ce 100644 --- a/common_level3.h +++ b/common_level3.h @@ -37,6 +37,9 @@ /*********************************************************************/ #ifndef ASSEMBLER +#ifndef BFLOAT16 +typedef unsigned short bfloat16; +#endif #ifdef __CUDACC__ __global__ void cuda_sgemm_kernel(int, int, int, float *, float *, float *); @@ -55,6 +58,8 @@ extern void sgemm_kernel_direct(BLASLONG M, BLASLONG N, BLASLONG K, extern int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); +int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, + bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); int dgemm_beta(BLASLONG, BLASLONG, BLASLONG, double, @@ -76,6 +81,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG); #endif +int shgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); +int shgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); +int shgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); +int shgemm_otcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); int sgemm_incopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b); int sgemm_itcopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b); int sgemm_oncopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b); @@ -499,6 +508,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); +int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG); @@ -527,6 +537,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG); int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG); +int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); + int sgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); @@ -619,6 +634,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG); #endif +int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); + int sgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); diff --git a/common_macro.h b/common_macro.h index 13bb85794..b438c83ba 100644 --- a/common_macro.h +++ b/common_macro.h @@ -39,6 +39,7 @@ #ifndef COMMON_MACRO #define COMMON_MACRO +#include "common_sh.h" #include "common_s.h" #include "common_d.h" #include "common_q.h" @@ -642,6 +643,53 @@ #define IMATCOPY_K_RT DIMATCOPY_K_RT #define GEADD_K DGEADD_K + +#elif defined(HALF) + +#define GEMM_BETA SHGEMM_BETA +#define GEMM_KERNEL_N SHGEMM_KERNEL +#define GEMM_KERNEL_L SHGEMM_KERNEL +#define GEMM_KERNEL_R SHGEMM_KERNEL +#define GEMM_KERNEL_B SHGEMM_KERNEL + +#define GEMM_NN SHGEMM_NN +#define GEMM_CN SHGEMM_TN +#define GEMM_TN SHGEMM_TN +#define GEMM_NC SHGEMM_NT +#define GEMM_NT SHGEMM_NT +#define GEMM_CC SHGEMM_TT +#define GEMM_CT SHGEMM_TT +#define GEMM_TC SHGEMM_TT +#define GEMM_TT SHGEMM_TT +#define GEMM_NR SHGEMM_NN +#define GEMM_TR SHGEMM_TN +#define GEMM_CR SHGEMM_TN +#define GEMM_RN SHGEMM_NN +#define GEMM_RT SHGEMM_NT +#define GEMM_RC SHGEMM_NT +#define GEMM_RR SHGEMM_NN +#define GEMM_ONCOPY SHGEMM_ONCOPY +#define GEMM_OTCOPY SHGEMM_OTCOPY +#define GEMM_INCOPY SHGEMM_INCOPY +#define GEMM_ITCOPY SHGEMM_ITCOPY + +#define GEMM_THREAD_NN SHGEMM_THREAD_NN +#define GEMM_THREAD_CN SHGEMM_THREAD_TN +#define GEMM_THREAD_TN SHGEMM_THREAD_TN +#define GEMM_THREAD_NC SHGEMM_THREAD_NT +#define GEMM_THREAD_NT SHGEMM_THREAD_NT +#define GEMM_THREAD_CC SHGEMM_THREAD_TT +#define GEMM_THREAD_CT SHGEMM_THREAD_TT +#define GEMM_THREAD_TC SHGEMM_THREAD_TT +#define GEMM_THREAD_TT SHGEMM_THREAD_TT +#define GEMM_THREAD_NR SHGEMM_THREAD_NN +#define GEMM_THREAD_TR SHGEMM_THREAD_TN +#define GEMM_THREAD_CR SHGEMM_THREAD_TN +#define GEMM_THREAD_RN SHGEMM_THREAD_NN +#define GEMM_THREAD_RT SHGEMM_THREAD_NT +#define GEMM_THREAD_RC SHGEMM_THREAD_NT +#define GEMM_THREAD_RR SHGEMM_THREAD_NN + #else #define AMAX_K SAMAX_K @@ -2202,6 +2250,9 @@ #if defined(ARCH_X86) || defined(ARCH_X86_64) || defined(ARCH_IA64) || defined(ARCH_MIPS64) || defined(ARCH_ARM64) extern BLASLONG gemm_offset_a; extern BLASLONG gemm_offset_b; +extern BLASLONG shgemm_p; +extern BLASLONG shgemm_q; +extern BLASLONG shgemm_r; extern BLASLONG sgemm_p; extern BLASLONG sgemm_q; extern BLASLONG sgemm_r; diff --git a/common_param.h b/common_param.h index 574d5e176..f1cac38d1 100644 --- a/common_param.h +++ b/common_param.h @@ -84,6 +84,16 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + int shgemm_p, shgemm_q, shgemm_r; + int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; + int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); + int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); + + int (*shgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_oncopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_otcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*sgemm_incopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_itcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_oncopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); @@ -907,6 +917,13 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 gotoblas -> exclusive_cache +#define SHGEMM_P gotoblas -> shgemm_p +#define SHGEMM_Q gotoblas -> shgemm_q +#define SHGEMM_R gotoblas -> shgemm_r +#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m +#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n +#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn + #define SGEMM_P gotoblas -> sgemm_p #define SGEMM_Q gotoblas -> sgemm_q #define SGEMM_R gotoblas -> sgemm_r @@ -984,6 +1001,17 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 0 #endif +#define SHGEMM_P SHGEMM_DEFAULT_P +#define SHGEMM_Q SHGEMM_DEFAULT_Q +#define SHGEMM_R SHGEMM_DEFAULT_R +#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M +#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N +#ifdef SHGEMM_DEFAULT_UNROLL_MN +#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN +#else +#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N)) +#endif + #define SGEMM_P SGEMM_DEFAULT_P #define SGEMM_Q SGEMM_DEFAULT_Q #define SGEMM_R SGEMM_DEFAULT_R @@ -1119,6 +1147,18 @@ extern gotoblas_t *gotoblas; #define GEMM_DEFAULT_R DGEMM_DEFAULT_R #define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M #define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N +#elif defined(HALF) +#define GEMM_P SHGEMM_P +#define GEMM_Q SHGEMM_Q +#define GEMM_R SHGEMM_R +#define GEMM_UNROLL_M SHGEMM_UNROLL_M +#define GEMM_UNROLL_N SHGEMM_UNROLL_N +#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN +#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P +#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q +#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R +#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M +#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N #else #define GEMM_P SGEMM_P #define GEMM_Q SGEMM_Q @@ -1204,6 +1244,10 @@ extern gotoblas_t *gotoblas; #define GEMM_THREAD gemm_thread_n #endif +#ifndef SHGEMM_DEFAULT_R +#define SHGEMM_DEFAULT_R (((BUFFER_SIZE - ((SHGEMM_DEFAULT_P * SHGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SHGEMM_DEFAULT_Q * 4) - 15) & ~15) +#endif + #ifndef SGEMM_DEFAULT_R #define SGEMM_DEFAULT_R (((BUFFER_SIZE - ((SGEMM_DEFAULT_P * SGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SGEMM_DEFAULT_Q * 4) - 15) & ~15) #endif diff --git a/common_sh.h b/common_sh.h new file mode 100644 index 000000000..8859694f1 --- /dev/null +++ b/common_sh.h @@ -0,0 +1,65 @@ +#ifndef COMMON_H_H +#define COMMON_H_H + +#ifndef DYNAMIC_ARCH + +#define SHGEMM_ONCOPY shgemm_oncopy +#define SHGEMM_OTCOPY shgemm_otcopy + +#if SHGEMM_DEFAULT_UNROLL_M == SHGEMM_DEFAULT_UNROLL_N +#define SHGEMM_INCOPY shgemm_oncopy +#define SHGEMM_ITCOPY shgemm_otcopy +#else +#define SHGEMM_INCOPY shgemm_incopy +#define SHGEMM_ITCOPY shgemm_itcopy +#endif +#define SHGEMM_BETA shgemm_beta +#define SHGEMM_KERNEL shgemm_kernel + +#else + +#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy +#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy +#define SHGEMM_INCOPY gotoblas -> shgemm_incopy +#define SHGEMM_ITCOPY gotoblas -> shgemm_itcopy +#define SHGEMM_BETA gotoblas -> shgemm_beta +#define SHGEMM_KERNEL gotoblas -> shgemm_kernel + +#endif + +#define SHGEMM_NN shgemm_nn +#define SHGEMM_CN shgemm_tn +#define SHGEMM_TN shgemm_tn +#define SHGEMM_NC shgemm_nt +#define SHGEMM_NT shgemm_nt +#define SHGEMM_CC shgemm_tt +#define SHGEMM_CT shgemm_tt +#define SHGEMM_TC shgemm_tt +#define SHGEMM_TT shgemm_tt +#define SHGEMM_NR shgemm_nn +#define SHGEMM_TR shgemm_tn +#define SHGEMM_CR shgemm_tn +#define SHGEMM_RN shgemm_nn +#define SHGEMM_RT shgemm_nt +#define SHGEMM_RC shgemm_nt +#define SHGEMM_RR shgemm_nn + +#define SHGEMM_THREAD_NN shgemm_thread_nn +#define SHGEMM_THREAD_CN shgemm_thread_tn +#define SHGEMM_THREAD_TN shgemm_thread_tn +#define SHGEMM_THREAD_NC shgemm_thread_nt +#define SHGEMM_THREAD_NT shgemm_thread_nt +#define SHGEMM_THREAD_CC shgemm_thread_tt +#define SHGEMM_THREAD_CT shgemm_thread_tt +#define SHGEMM_THREAD_TC shgemm_thread_tt +#define SHGEMM_THREAD_TT shgemm_thread_tt +#define SHGEMM_THREAD_NR shgemm_thread_nn +#define SHGEMM_THREAD_TR shgemm_thread_tn +#define SHGEMM_THREAD_CR shgemm_thread_tn +#define SHGEMM_THREAD_RN shgemm_thread_nn +#define SHGEMM_THREAD_RT shgemm_thread_nt +#define SHGEMM_THREAD_RC shgemm_thread_nt +#define SHGEMM_THREAD_RR shgemm_thread_nn + +#endif + diff --git a/driver/level3/Makefile b/driver/level3/Makefile index e320092e3..881b4ee35 100644 --- a/driver/level3/Makefile +++ b/driver/level3/Makefile @@ -19,6 +19,7 @@ ifeq ($(ARCH), MIPS) USE_GEMM3M = 1 endif +SHBLASOBJS += shgemm_nn.$(SUFFIX) shgemm_nt.$(SUFFIX) shgemm_tn.$(SUFFIX) shgemm_tt.$(SUFFIX) SBLASOBJS += \ sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \ strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \ @@ -204,6 +205,7 @@ COMMONOBJS += syrk_thread.$(SUFFIX) ifndef USE_SIMPLE_THREADED_LEVEL3 +SHBLASOBJS += shgemm_thread_nn.$(SUFFIX) shgemm_thread_nt.$(SUFFIX) shgemm_thread_tn.$(SUFFIX) shgemm_thread_tt.$(SUFFIX) SBLASOBJS += sgemm_thread_nn.$(SUFFIX) sgemm_thread_nt.$(SUFFIX) sgemm_thread_tn.$(SUFFIX) sgemm_thread_tt.$(SUFFIX) DBLASOBJS += dgemm_thread_nn.$(SUFFIX) dgemm_thread_nt.$(SUFFIX) dgemm_thread_tn.$(SUFFIX) dgemm_thread_tt.$(SUFFIX) QBLASOBJS += qgemm_thread_nn.$(SUFFIX) qgemm_thread_nt.$(SUFFIX) qgemm_thread_tn.$(SUFFIX) qgemm_thread_tt.$(SUFFIX) @@ -283,6 +285,18 @@ endif all :: +shgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +shgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +shgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +shgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) @@ -478,6 +492,17 @@ gemm_thread_variable.$(SUFFIX) : gemm_thread_variable.c ../../common.h beta_thread.$(SUFFIX) : beta_thread.c ../../common.h $(CC) -c $(CFLAGS) $< -o $(@F) +shgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +shgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +shgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +shgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) sgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) @@ -2652,6 +2677,18 @@ xtrsm_RCLU.$(SUFFIX) : trsm_R.c xtrsm_RCLN.$(SUFFIX) : trsm_R.c $(CC) -c $(CFLAGS) -DCOMPLEX -DXDOUBLE -DTRANSA -UUPPER -UUNIT -DCONJ $< -o $(@F) +shgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +shgemm_nt.$(PSUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +shgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +shgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h $(CC) $(PFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) @@ -2848,6 +2885,18 @@ beta_thread.$(PSUFFIX) : beta_thread.c ../../common.h $(CC) -c $(PFLAGS) $< -o $(@F) +shgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +shgemm_thread_nt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +shgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +shgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) diff --git a/driver/level3/level3.c b/driver/level3/level3.c index 9aa67286f..c6bbb9ca9 100644 --- a/driver/level3/level3.c +++ b/driver/level3/level3.c @@ -62,18 +62,18 @@ #ifndef ICOPY_OPERATION #if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \ defined(RN) || defined(RT) || defined(RC) || defined(RR) -#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); #else -#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); #endif #endif #ifndef OCOPY_OPERATION #if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \ defined(NR) || defined(TR) || defined(CR) || defined(RR) -#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); #else -#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); #endif #endif @@ -173,7 +173,8 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){ BLASLONG k, lda, ldb, ldc; FLOAT *alpha, *beta; - FLOAT *a, *b, *c; + IFLOAT *a, *b; + FLOAT *c; BLASLONG m_from, m_to, n_from, n_to; BLASLONG ls, is, js; @@ -198,8 +199,8 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, k = K; - a = (FLOAT *)A; - b = (FLOAT *)B; + a = (IFLOAT *)A; + b = (IFLOAT *)B; c = (FLOAT *)C; lda = LDA; diff --git a/driver/level3/level3_thread.c b/driver/level3/level3_thread.c index ca0085e71..5a8d497d2 100644 --- a/driver/level3/level3_thread.c +++ b/driver/level3/level3_thread.c @@ -117,18 +117,18 @@ typedef struct { #ifndef ICOPY_OPERATION #if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \ defined(RN) || defined(RT) || defined(RC) || defined(RR) -#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); #else -#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); #endif #endif #ifndef OCOPY_OPERATION #if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \ defined(NR) || defined(TR) || defined(CR) || defined(RR) -#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); #else -#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); #endif #endif @@ -219,15 +219,16 @@ typedef struct { #define STOP_RPCC(COUNTER) #endif -static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){ +static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){ - FLOAT *buffer[DIVIDE_RATE]; + IFLOAT *buffer[DIVIDE_RATE]; BLASLONG k, lda, ldb, ldc; BLASLONG m_from, m_to, n_from, n_to; FLOAT *alpha, *beta; - FLOAT *a, *b, *c; + IFLOAT *a, *b; + FLOAT *c; job_t *job = (job_t *)args -> common; BLASLONG nthreads_m; @@ -255,8 +256,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, k = K; - a = (FLOAT *)A; - b = (FLOAT *)B; + a = (IFLOAT *)A; + b = (IFLOAT *)B; c = (FLOAT *)C; lda = LDA; @@ -425,7 +426,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, /* Apply kernel with local region of A and part of other region of B */ START_RPCC(); KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - js, div_n), min_l, alpha, - sa, (FLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside], + sa, (IFLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside], c, ldc, m_from, js); STOP_RPCC(kernel); @@ -469,7 +470,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, /* Apply kernel with local region of A and part of region of B */ START_RPCC(); KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - js, div_n), min_l, alpha, - sa, (FLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside], + sa, (IFLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside], c, ldc, is, js); STOP_RPCC(kernel); @@ -532,7 +533,7 @@ static int round_up(int remainder, int width, int multiple) static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG - *range_n, FLOAT *sa, FLOAT *sb, + *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG nthreads_m, BLASLONG nthreads_n) { #ifndef USE_OPENMP @@ -728,7 +729,7 @@ EnterCriticalSection((PCRITICAL_SECTION)&level3_lock); return 0; } -int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){ +int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){ BLASLONG m = args -> m; BLASLONG n = args -> n; diff --git a/driver/others/parameter.c b/driver/others/parameter.c index 8bf7da78b..b1f3befae 100644 --- a/driver/others/parameter.c +++ b/driver/others/parameter.c @@ -62,6 +62,11 @@ BLASLONG gemm_offset_b = DEFAULT_GEMM_OFFSET_B; BLASLONG gemm_offset_b = GEMM_OFFSET_B; #endif +#if SHGEMM_P == shgemm_p +BLASLONG shgemm_p = DEFAULT_GEMM_P; +#else +BLASLONG shgemm_p = SHGEMM_P; +#endif #if SGEMM_P == sgemm_p BLASLONG sgemm_p = DEFAULT_GEMM_P; #else @@ -83,6 +88,11 @@ BLASLONG zgemm_p = DEFAULT_GEMM_P; BLASLONG zgemm_p = ZGEMM_P; #endif +#if SHGEMM_Q == shgemm_q +BLASLONG shgemm_q = DEFAULT_GEMM_Q; +#else +BLASLONG shgemm_q = SHGEMM_Q; +#endif #if SGEMM_Q == sgemm_q BLASLONG sgemm_q = DEFAULT_GEMM_Q; #else @@ -104,6 +114,11 @@ BLASLONG zgemm_q = DEFAULT_GEMM_Q; BLASLONG zgemm_q = ZGEMM_Q; #endif +#if SHGEMM_R == shgemm_r +BLASLONG shgemm_r = DEFAULT_GEMM_R; +#else +BLASLONG shgemm_r = SHGEMM_R; +#endif #if SGEMM_R == sgemm_r BLASLONG sgemm_r = DEFAULT_GEMM_R; #else @@ -597,6 +612,7 @@ void blas_set_parameter(void){ size = BITMASK(cpuid3, 16, 0xff); + shgemm_p = 192 * (size + 1); sgemm_p = 192 * (size + 1); dgemm_p = 96 * (size + 1); cgemm_p = 96 * (size + 1); @@ -610,6 +626,7 @@ void blas_set_parameter(void){ xgemm_p = 16 * (size + 1); #endif + shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15; sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15; dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15; cgemm_r = (((BUFFER_SIZE - ((CGEMM_P * CGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (CGEMM_Q * 8)) - 15) & ~15; diff --git a/getarch_2nd.c b/getarch_2nd.c index cf9c578cb..a1d0ccac8 100644 --- a/getarch_2nd.c +++ b/getarch_2nd.c @@ -9,6 +9,8 @@ int main(int argc, char **argv) { if ( (argc <= 1) || ((argc >= 2) && (*argv[1] == '0'))) { + printf("SHGEMM_UNROLL_M=%d\n", SHGEMM_DEFAULT_UNROLL_M); + printf("SHGEMM_UNROLL_N=%d\n", SHGEMM_DEFAULT_UNROLL_N); printf("SGEMM_UNROLL_M=%d\n", SGEMM_DEFAULT_UNROLL_M); printf("SGEMM_UNROLL_N=%d\n", SGEMM_DEFAULT_UNROLL_N); printf("DGEMM_UNROLL_M=%d\n", DGEMM_DEFAULT_UNROLL_M); diff --git a/interface/Makefile b/interface/Makefile index 3f0dcca28..741f6bac0 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -46,6 +46,7 @@ SBLAS3OBJS = \ somatcopy.$(SUFFIX) simatcopy.$(SUFFIX)\ sgeadd.$(SUFFIX) +SHBLAS3OBJS = shgemm.$(SUFFIX) DBLAS1OBJS = \ daxpy.$(SUFFIX) dswap.$(SUFFIX) \ @@ -277,6 +278,8 @@ CSBLAS3OBJS = \ cblas_ssyrk.$(SUFFIX) cblas_ssyr2k.$(SUFFIX) cblas_somatcopy.$(SUFFIX) cblas_simatcopy.$(SUFFIX)\ cblas_sgeadd.$(SUFFIX) +CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX) + CDBLAS1OBJS = \ cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \ cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \ @@ -367,6 +370,7 @@ override CFLAGS += -I. SBLAS1OBJS += $(CSBLAS1OBJS) SBLAS2OBJS += $(CSBLAS2OBJS) SBLAS3OBJS += $(CSBLAS3OBJS) +SHBLAS3OBJS += $(CSHBLAS3OBJS) DBLAS1OBJS += $(CDBLAS1OBJS) DBLAS2OBJS += $(CDBLAS2OBJS) DBLAS3OBJS += $(CDBLAS3OBJS) @@ -380,6 +384,7 @@ ZBLAS3OBJS += $(CZBLAS3OBJS) endif SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) +SHBLASOBJS = $(SHBLAS3OBJS) DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS) CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS) @@ -454,7 +459,7 @@ ZBLASOBJS += $(ZLAPACKOBJS) endif -FUNCOBJS = $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) +FUNCOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) ifdef EXPRECISION FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS) @@ -488,10 +493,10 @@ level1 : $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $ level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ -level3 : $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) +level3 : $(SHBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ -$(CSBLASOBJS) $(CSBLASOBJS_P) $(CDBLASOBJS) $(CDBLASOBJS_P) $(CQBLASOBJS) $(CQBLASOBJS_P) \ +$(CSHBLASOBJS) $(CSHBLASOBJS_P) $(CSBLASOBJS) $(CSBLASOBJS_P) $(CDBLASOBJS) $(CDBLASOBJS_P) $(CQBLASOBJS) $(CQBLASOBJS_P) \ $(CCBLASOBJS) $(CCBLASOBJS_P) $(CZBLASOBJS) $(CZBLASOBJS_P) $(CXBLASOBJS) $(CXBLASOBJS_P) : override CFLAGS += -DCBLAS srot.$(SUFFIX) srot.$(PSUFFIX) : rot.c @@ -1209,6 +1214,9 @@ zhpr2.$(SUFFIX) zhpr2.$(PSUFFIX) : zhpr2.c xhpr2.$(SUFFIX) xhpr2.$(PSUFFIX) : zhpr2.c $(CC) -c $(CFLAGS) $< -o $(@F) +shgemm.$(SUFFIX) shgemm.$(PSUFFIX) : gemm.c ../param.h + $(CC) -c $(CFLAGS) $< -o $(@F) + sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -c $(CFLAGS) $< -o $(@F) @@ -1770,6 +1778,9 @@ cblas_zhemv.$(SUFFIX) cblas_zhemv.$(PSUFFIX) : zhemv.c cblas_sgemm.$(SUFFIX) cblas_sgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) +cblas_shgemm.$(SUFFIX) cblas_shgemm.$(PSUFFIX) : gemm.c ../param.h + $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) + cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) diff --git a/interface/gemm.c b/interface/gemm.c index 0b18d9a8c..99388e7d9 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -77,7 +77,7 @@ #define GEMM_MULTITHREAD_THRESHOLD 4 #endif -static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLASLONG) = { +static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = { #ifndef GEMM3M GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN, GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT, @@ -108,8 +108,8 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLA void NAME(char *TRANSA, char *TRANSB, blasint *M, blasint *N, blasint *K, FLOAT *alpha, - FLOAT *a, blasint *ldA, - FLOAT *b, blasint *ldB, + IFLOAT *a, blasint *ldA, + IFLOAT *b, blasint *ldB, FLOAT *beta, FLOAT *c, blasint *ldC){ @@ -119,8 +119,8 @@ void NAME(char *TRANSA, char *TRANSB, blasint info; char transA, transB; - FLOAT *buffer; - FLOAT *sa, *sb; + IFLOAT *buffer; + IFLOAT *sa, *sb; #ifdef SMP double MNK; diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 6d96abb2e..aee610efb 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -59,6 +59,10 @@ ifeq ($(CORE), Z14) USE_TRMM = 1 endif +SHKERNELOBJS += \ + shgemm_kernel$(TSUFFIX).$(SUFFIX) \ + $(SHGEMMINCOPYOBJ) $(SHGEMMITCOPYOBJ) \ + $(SHGEMMONCOPYOBJ) $(SHGEMMOTCOPYOBJ) SKERNELOBJS += \ sgemm_kernel$(TSUFFIX).$(SUFFIX) \ @@ -93,6 +97,7 @@ XKERNELOBJS += \ $(XGEMMINCOPYOBJ) $(XGEMMITCOPYOBJ) \ $(XGEMMONCOPYOBJ) $(XGEMMOTCOPYOBJ) +SHBLASOBJS += $(SHKERNELOBJS) SBLASOBJS += $(SKERNELOBJS) DBLASOBJS += $(DKERNELOBJS) QBLASOBJS += $(QKERNELOBJS) @@ -100,6 +105,7 @@ CBLASOBJS += $(CKERNELOBJS) ZBLASOBJS += $(ZKERNELOBJS) XBLASOBJS += $(XKERNELOBJS) +SHBLASOBJS += shgemm_beta$(TSUFFIX).$(SUFFIX) SBLASOBJS += \ sgemm_beta$(TSUFFIX).$(SUFFIX) \ strmm_kernel_LN$(TSUFFIX).$(SUFFIX) strmm_kernel_LT$(TSUFFIX).$(SUFFIX) \ @@ -390,6 +396,10 @@ ZBLASOBJS += \ zgeadd_k$(TSUFFIX).$(SUFFIX) +SHGEMMINCOPYOBJ_P = $(SHGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +SHGEMMITCOPYOBJ_P = $(SHGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +SHGEMMONCOPYOBJ_P = $(SHGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +SHGEMMOTCOPYOBJ_P = $(SHGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMINCOPYOBJ_P = $(SGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMITCOPYOBJ_P = $(SGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMONCOPYOBJ_P = $(SGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) @@ -415,6 +425,9 @@ XGEMMITCOPYOBJ_P = $(XGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) XGEMMONCOPYOBJ_P = $(XGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) XGEMMOTCOPYOBJ_P = $(XGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +$(KDIR)shgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA) + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + $(KDIR)sgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_BETA) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ @@ -433,6 +446,36 @@ $(KDIR)zgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_BETA) $(KDIR)xgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMM_BETA) $(CC) $(CFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@ +$(KDIR)$(SHGEMMONCOPYOBJ) : $(KERNELDIR)/$(SHGEMMONCOPY) + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)$(SHGEMMOTCOPYOBJ) : $(KERNELDIR)/$(SHGEMMOTCOPY) +ifeq ($(OS), AIX) + $(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemmotcopy.s + m4 shgemmotcopy.s > shgemmotcopy_nomacros.s + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemmotcopy_nomacros.s -o $@ + rm shgemmotcopy.s shgemmotcopy_nomacros.s +else + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ +endif + +ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N)) + +$(KDIR)$(SHGEMMINCOPYOBJ) : $(KERNELDIR)/$(SHGEMMINCOPY) + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)$(SHGEMMITCOPYOBJ) : $(KERNELDIR)/$(SHGEMMITCOPY) +ifeq ($(OS), AIX) + $(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemmitcopy.s + m4 shgemmitcopy.s > shgemmitcopy_nomacros.s + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemmitcopy_nomacros.s -o $@ + rm shgemmitcopy.s shgemmitcopy_nomacros.s +else + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ +endif + +endif + $(KDIR)$(SGEMMONCOPYOBJ) : $(KERNELDIR)/$(SGEMMONCOPY) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ @@ -590,6 +633,16 @@ else $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ endif +$(KDIR)shgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND) +ifeq ($(OS), AIX) + $(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemm_kernel$(TSUFFIX).s + m4 shgemm_kernel$(TSUFFIX).s > shgemm_kernel$(TSUFFIX)_nomacros.s + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemm_kernel$(TSUFFIX)_nomacros.s -o $@ + rm shgemm_kernel$(TSUFFIX).s shgemm_kernel$(TSUFFIX)_nomacros.s +else + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ +endif + $(KDIR)dgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMMKERNEL) $(DGEMMDEPEND) ifeq ($(OS), AIX) $(CC) $(CFLAGS) -E -DDOUBLE -UCOMPLEX $< -o dgemm_kernel$(TSUFFIX).s @@ -2206,6 +2259,9 @@ $(KDIR)xtrsm_oltncopy$(TSUFFIX).$(SUFFIX) : generic/ztrsm_ltcopy_$(XGEMM_UNROLL_ $(KDIR)sgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMM_BETA) $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ +$(KDIR)shgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA) + $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + $(KDIR)dgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DGEMM_BETA) $(CC) $(PFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ @@ -2221,6 +2277,20 @@ $(KDIR)zgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(ZGEMM_BETA) $(KDIR)xgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(XGEMM_BETA) $(CC) $(PFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@ +$(SHGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMONCOPY) + $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + +$(SHGEMMOTCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMOTCOPY) + $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + +ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N)) +$(SHGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMINCOPY) + $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + +$(SHGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMITCOPY) + $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + +endif $(SGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SGEMMONCOPY) $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ @@ -2325,6 +2395,9 @@ endif endif +$(KDIR)shgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND) + $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + $(KDIR)sgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMMKERNEL) $(SGEMMDEPEND) $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ diff --git a/kernel/generic/gemm_beta.c b/kernel/generic/gemm_beta.c index fa9d7680d..ccb772cc7 100644 --- a/kernel/generic/gemm_beta.c +++ b/kernel/generic/gemm_beta.c @@ -39,7 +39,7 @@ #include "common.h" int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, - FLOAT *dummy2, BLASLONG dummy3, FLOAT *dummy4, BLASLONG dummy5, + IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, FLOAT *c, BLASLONG ldc){ diff --git a/kernel/generic/gemm_ncopy_2.c b/kernel/generic/gemm_ncopy_2.c index b728c713f..415860f81 100644 --- a/kernel/generic/gemm_ncopy_2.c +++ b/kernel/generic/gemm_ncopy_2.c @@ -39,10 +39,10 @@ #include #include "common.h" -int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){ +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){ BLASLONG i, j; - FLOAT *a_offset, *a_offset1, *a_offset2; - FLOAT *b_offset; + IFLOAT *a_offset, *a_offset1, *a_offset2; + IFLOAT *b_offset; a_offset = a; b_offset = b; diff --git a/kernel/generic/gemm_tcopy_2.c b/kernel/generic/gemm_tcopy_2.c index 5695b13c2..b4aa4de57 100644 --- a/kernel/generic/gemm_tcopy_2.c +++ b/kernel/generic/gemm_tcopy_2.c @@ -39,11 +39,11 @@ #include #include "common.h" -int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){ +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){ BLASLONG i, j; - FLOAT *a_offset, *a_offset1, *a_offset2; - FLOAT *b_offset, *b_offset1, *b_offset2; + IFLOAT *a_offset, *a_offset1, *a_offset2; + IFLOAT *b_offset, *b_offset1, *b_offset2; a_offset = a; b_offset = b; diff --git a/kernel/generic/gemmkernel_2x2.c b/kernel/generic/gemmkernel_2x2.c index 01f1c67b5..26a88db6d 100644 --- a/kernel/generic/gemmkernel_2x2.c +++ b/kernel/generic/gemmkernel_2x2.c @@ -1,13 +1,32 @@ #include "common.h" -int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FLOAT* C,BLASLONG ldc +#if defined(HALF) && defined(HALFCONVERSION) +float +bfloat16tof32 (bfloat16 f16) +{ + float result = 0; + unsigned short* q = (unsigned short*)(&result); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + q[0] = f16; +#else + q[1] = f16; +#endif + return result; +} +#define BF16TOF32(x) (bfloat16tof32(x)) +#else +#define BF16TOF32(x) x +#endif +int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc #ifdef TRMMKERNEL ,BLASLONG offset #endif ) { BLASLONG i,j,k; - FLOAT *C0,*C1,*ptrba,*ptrbb; - FLOAT res0,res1,res2,res3,load0,load1,load2,load3,load4,load5,load6,load7; + FLOAT *C0,*C1; + IFLOAT *ptrba,*ptrbb; + FLOAT res0,res1,res2,res3; + IFLOAT load0,load1,load2,load3,load4,load5,load6,load7; for (j=0; j +#include +#include "common.h" +#define SGEMM BLASFUNC(sgemm) +#define SHGEMM BLASFUNC(shgemm) +typedef union +{ + unsigned short v; + struct + { +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + unsigned short s:1; + unsigned short e:8; + unsigned short m:7; +#else + unsigned short m:7; + unsigned short e:8; + unsigned short s:1; +#endif + } bits; +} bfloat16_bits; + +int +main (int argc, char *argv[]) +{ + int m, n, k; + int i, j, l; + int ret = 0; + int loop = 20; + char transA = 'N', transB = 'N'; + float alpha = 1.0, beta = 0.0; + char transa = 'N'; + char transb = 'N'; + + for (int x = 0; x <= loop; x++) + { + m = k = n = x; + float A[m * k]; + float B[k * n]; + float C[m * n]; + bfloat16_bits AA[m * k], BB[k * n]; + float CC[m * n]; + + for (int j = 0; j < m; j++) + { + for (int i = 0; i < m; i++) + { + A[j * k + i] = j * 9.0; + B[j * k + i] = i * 2.0; + C[j * k + i] = 0; + AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16; + BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16; + CC[j * k + i] = 0; + } + } + SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, + &m, B, &k, &beta, C, &m); + SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA, + &m, BB, &k, &beta, CC, &m); + + for (i = 0; i < n; i++) + for (j = 0; j < m; j++) + for (l = 0; l < k; l++) + if (CC[i * m + j] != C[i * m + j]) + ret++; + } + fprintf (stderr, "Return code: %d\n", ret); + return ret; +}