From 7eb55504b1727eebcb0f451fa5b148dbea303b69 Mon Sep 17 00:00:00 2001 From: Rajalakshmi Srinivasaraghavan Date: Tue, 14 Apr 2020 14:55:08 -0500 Subject: [PATCH 01/15] RFC : Add half precision gemm for bfloat16 in OpenBLAS This patch adds support for bfloat16 data type matrix multiplication kernel. For architectures that don't support bfloat16, it is defined as unsigned short (2 bytes). Default unroll sizes can be changed as per architecture as done for SGEMM and for now 8 and 4 are used for M and N. Size of ncopy/tcopy can be changed as per architecture requirement and for now, size 2 is used. Added shgemm in kernel/power/KERNEL.POWER9 and tested in powerpc64le and powerpc64. For reference, added a small test compare_sgemm_shgemm.c to compare sgemm and shgemm output. This patch does not cover OpenBLAS test, benchmark and lapack tests for shgemm. Complex type implementation can be discussed and added once this is approved. --- Makefile.system | 2 + Makefile.tail | 7 ++- cmake/prebuild.cmake | 4 ++ cmake/system.cmake | 2 + common.h | 15 ++++++ common_interface.h | 5 ++ common_level3.h | 20 +++++++ common_macro.h | 51 ++++++++++++++++++ common_param.h | 44 +++++++++++++++ common_sh.h | 65 ++++++++++++++++++++++ driver/level3/Makefile | 49 +++++++++++++++++ driver/level3/level3.c | 15 +++--- driver/level3/level3_thread.c | 27 +++++----- driver/others/parameter.c | 17 ++++++ getarch_2nd.c | 2 + interface/Makefile | 17 ++++-- interface/gemm.c | 10 ++-- kernel/Makefile.L3 | 73 +++++++++++++++++++++++++ kernel/generic/gemm_beta.c | 2 +- kernel/generic/gemm_ncopy_2.c | 6 +-- kernel/generic/gemm_tcopy_2.c | 6 +-- kernel/generic/gemmkernel_2x2.c | 75 ++++++++++++++++---------- kernel/power/KERNEL.POWER9 | 11 ++++ kernel/setparam-ref.c | 30 +++++++++++ lapack/getrf/potrf_parallel.c | 3 ++ param.h | 6 +++ test/compare_sgemm_shgemm.c | 95 +++++++++++++++++++++++++++++++++ 27 files changed, 594 insertions(+), 65 deletions(-) create mode 100644 common_sh.h create mode 100644 test/compare_sgemm_shgemm.c diff --git a/Makefile.system b/Makefile.system index 2998c0e6a..0e176987c 100644 --- a/Makefile.system +++ b/Makefile.system @@ -1390,6 +1390,8 @@ export FUNCTION_PROFILE export TARGET_CORE export NO_AVX512 +export SHGEMM_UNROLL_M +export SHGEMM_UNROLL_N export SGEMM_UNROLL_M export SGEMM_UNROLL_N export DGEMM_UNROLL_M diff --git a/Makefile.tail b/Makefile.tail index 2adede1a5..39902982b 100644 --- a/Makefile.tail +++ b/Makefile.tail @@ -1,3 +1,4 @@ +SHBLASOBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) @@ -9,8 +10,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) -BLASOBJS = $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) -BLASOBJS_P = $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) +BLASOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) +BLASOBJS_P = $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) ifdef EXPRECISION BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) @@ -22,6 +23,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P) endif +$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHALF -UDOUBLE -UCOMPLEX $(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX $(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX $(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX @@ -29,6 +31,7 @@ $(CBLASOBJS) $(CBLASOBJS_P) : override CFLAGS += -UDOUBLE -DCOMPLEX $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX $(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX +$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) diff --git a/cmake/prebuild.cmake b/cmake/prebuild.cmake index 44e1473d1..e0696093b 100644 --- a/cmake/prebuild.cmake +++ b/cmake/prebuild.cmake @@ -16,6 +16,8 @@ # HAVE_SSE2 # HAVE_SSE3 # MAKE +# SHGEMM_UNROLL_M +# SHGEMM_UNROLL_N # SGEMM_UNROLL_M # SGEMM_UNROLL_N # DGEMM_UNROLL_M @@ -437,6 +439,8 @@ if (DEFINED CORE AND CMAKE_CROSSCOMPILING AND NOT (${HOST_OS} STREQUAL "WINDOWSS set(ZGEMM_UNROLL_N 2) set(SYMV_P 8) endif() + set(SHGEMM_UNROLL_M 8) + set(SHGEMM_UNROLL_N 4) # Or should this actually be NUM_CORES? if (${NUM_THREADS} GREATER 0) diff --git a/cmake/system.cmake b/cmake/system.cmake index ce980a7b9..65e5aa508 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -530,6 +530,8 @@ endif () #export FUNCTION_PROFILE #export TARGET_CORE # +#export SHGEMM_UNROLL_M +#export SHGEMM_UNROLL_N #export SGEMM_UNROLL_M #export SGEMM_UNROLL_N #export DGEMM_UNROLL_M diff --git a/common.h b/common.h index 762968e6f..1d8bf07e5 100644 --- a/common.h +++ b/common.h @@ -297,6 +297,17 @@ typedef int blasint; #define SIZE 8 #define BASE_SHIFT 3 #define ZBASE_SHIFT 4 +#elif defined(HALF) +#ifndef BFLOAT16 +typedef unsigned short bfloat16; +#define HALFCONVERSION 1 +#endif +#define IFLOAT bfloat16 +#define XFLOAT IFLOAT +#define FLOAT float +#define SIZE 2 +#define BASE_SHIFT 1 +#define ZBASE_SHIFT 2 #else #define FLOAT float #define SIZE 4 @@ -308,6 +319,10 @@ typedef int blasint; #define XFLOAT FLOAT #endif +#ifndef IFLOAT +#define IFLOAT FLOAT +#endif + #ifndef COMPLEX #define COMPSIZE 1 #else diff --git a/common_interface.h b/common_interface.h index c350ac8ec..081043af1 100644 --- a/common_interface.h +++ b/common_interface.h @@ -37,6 +37,9 @@ /*********************************************************************/ #ifndef ASSEMBLER +#ifndef BFLOAT16 +typedef unsigned short bfloat16; +#endif #ifdef __cplusplus extern "C" { @@ -469,6 +472,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint /* Level 3 routines */ +void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *, + bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *, float *, blasint *, float *, blasint *, float *, float *, blasint *); void BLASFUNC(dgemm)(char *, char *, blasint *, blasint *, blasint *, double *, diff --git a/common_level3.h b/common_level3.h index 6fa902be8..8194ba6ce 100644 --- a/common_level3.h +++ b/common_level3.h @@ -37,6 +37,9 @@ /*********************************************************************/ #ifndef ASSEMBLER +#ifndef BFLOAT16 +typedef unsigned short bfloat16; +#endif #ifdef __CUDACC__ __global__ void cuda_sgemm_kernel(int, int, int, float *, float *, float *); @@ -55,6 +58,8 @@ extern void sgemm_kernel_direct(BLASLONG M, BLASLONG N, BLASLONG K, extern int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); +int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, + bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); int dgemm_beta(BLASLONG, BLASLONG, BLASLONG, double, @@ -76,6 +81,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG); #endif +int shgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); +int shgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); +int shgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); +int shgemm_otcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); int sgemm_incopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b); int sgemm_itcopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b); int sgemm_oncopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b); @@ -499,6 +508,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); +int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG); @@ -527,6 +537,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG); int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG); +int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); + int sgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); @@ -619,6 +634,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG); #endif +int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); + int sgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); diff --git a/common_macro.h b/common_macro.h index 13bb85794..b438c83ba 100644 --- a/common_macro.h +++ b/common_macro.h @@ -39,6 +39,7 @@ #ifndef COMMON_MACRO #define COMMON_MACRO +#include "common_sh.h" #include "common_s.h" #include "common_d.h" #include "common_q.h" @@ -642,6 +643,53 @@ #define IMATCOPY_K_RT DIMATCOPY_K_RT #define GEADD_K DGEADD_K + +#elif defined(HALF) + +#define GEMM_BETA SHGEMM_BETA +#define GEMM_KERNEL_N SHGEMM_KERNEL +#define GEMM_KERNEL_L SHGEMM_KERNEL +#define GEMM_KERNEL_R SHGEMM_KERNEL +#define GEMM_KERNEL_B SHGEMM_KERNEL + +#define GEMM_NN SHGEMM_NN +#define GEMM_CN SHGEMM_TN +#define GEMM_TN SHGEMM_TN +#define GEMM_NC SHGEMM_NT +#define GEMM_NT SHGEMM_NT +#define GEMM_CC SHGEMM_TT +#define GEMM_CT SHGEMM_TT +#define GEMM_TC SHGEMM_TT +#define GEMM_TT SHGEMM_TT +#define GEMM_NR SHGEMM_NN +#define GEMM_TR SHGEMM_TN +#define GEMM_CR SHGEMM_TN +#define GEMM_RN SHGEMM_NN +#define GEMM_RT SHGEMM_NT +#define GEMM_RC SHGEMM_NT +#define GEMM_RR SHGEMM_NN +#define GEMM_ONCOPY SHGEMM_ONCOPY +#define GEMM_OTCOPY SHGEMM_OTCOPY +#define GEMM_INCOPY SHGEMM_INCOPY +#define GEMM_ITCOPY SHGEMM_ITCOPY + +#define GEMM_THREAD_NN SHGEMM_THREAD_NN +#define GEMM_THREAD_CN SHGEMM_THREAD_TN +#define GEMM_THREAD_TN SHGEMM_THREAD_TN +#define GEMM_THREAD_NC SHGEMM_THREAD_NT +#define GEMM_THREAD_NT SHGEMM_THREAD_NT +#define GEMM_THREAD_CC SHGEMM_THREAD_TT +#define GEMM_THREAD_CT SHGEMM_THREAD_TT +#define GEMM_THREAD_TC SHGEMM_THREAD_TT +#define GEMM_THREAD_TT SHGEMM_THREAD_TT +#define GEMM_THREAD_NR SHGEMM_THREAD_NN +#define GEMM_THREAD_TR SHGEMM_THREAD_TN +#define GEMM_THREAD_CR SHGEMM_THREAD_TN +#define GEMM_THREAD_RN SHGEMM_THREAD_NN +#define GEMM_THREAD_RT SHGEMM_THREAD_NT +#define GEMM_THREAD_RC SHGEMM_THREAD_NT +#define GEMM_THREAD_RR SHGEMM_THREAD_NN + #else #define AMAX_K SAMAX_K @@ -2202,6 +2250,9 @@ #if defined(ARCH_X86) || defined(ARCH_X86_64) || defined(ARCH_IA64) || defined(ARCH_MIPS64) || defined(ARCH_ARM64) extern BLASLONG gemm_offset_a; extern BLASLONG gemm_offset_b; +extern BLASLONG shgemm_p; +extern BLASLONG shgemm_q; +extern BLASLONG shgemm_r; extern BLASLONG sgemm_p; extern BLASLONG sgemm_q; extern BLASLONG sgemm_r; diff --git a/common_param.h b/common_param.h index 574d5e176..f1cac38d1 100644 --- a/common_param.h +++ b/common_param.h @@ -84,6 +84,16 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + int shgemm_p, shgemm_q, shgemm_r; + int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; + int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); + int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); + + int (*shgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_oncopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_otcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*sgemm_incopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_itcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_oncopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); @@ -907,6 +917,13 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 gotoblas -> exclusive_cache +#define SHGEMM_P gotoblas -> shgemm_p +#define SHGEMM_Q gotoblas -> shgemm_q +#define SHGEMM_R gotoblas -> shgemm_r +#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m +#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n +#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn + #define SGEMM_P gotoblas -> sgemm_p #define SGEMM_Q gotoblas -> sgemm_q #define SGEMM_R gotoblas -> sgemm_r @@ -984,6 +1001,17 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 0 #endif +#define SHGEMM_P SHGEMM_DEFAULT_P +#define SHGEMM_Q SHGEMM_DEFAULT_Q +#define SHGEMM_R SHGEMM_DEFAULT_R +#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M +#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N +#ifdef SHGEMM_DEFAULT_UNROLL_MN +#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN +#else +#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N)) +#endif + #define SGEMM_P SGEMM_DEFAULT_P #define SGEMM_Q SGEMM_DEFAULT_Q #define SGEMM_R SGEMM_DEFAULT_R @@ -1119,6 +1147,18 @@ extern gotoblas_t *gotoblas; #define GEMM_DEFAULT_R DGEMM_DEFAULT_R #define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M #define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N +#elif defined(HALF) +#define GEMM_P SHGEMM_P +#define GEMM_Q SHGEMM_Q +#define GEMM_R SHGEMM_R +#define GEMM_UNROLL_M SHGEMM_UNROLL_M +#define GEMM_UNROLL_N SHGEMM_UNROLL_N +#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN +#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P +#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q +#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R +#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M +#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N #else #define GEMM_P SGEMM_P #define GEMM_Q SGEMM_Q @@ -1204,6 +1244,10 @@ extern gotoblas_t *gotoblas; #define GEMM_THREAD gemm_thread_n #endif +#ifndef SHGEMM_DEFAULT_R +#define SHGEMM_DEFAULT_R (((BUFFER_SIZE - ((SHGEMM_DEFAULT_P * SHGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SHGEMM_DEFAULT_Q * 4) - 15) & ~15) +#endif + #ifndef SGEMM_DEFAULT_R #define SGEMM_DEFAULT_R (((BUFFER_SIZE - ((SGEMM_DEFAULT_P * SGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SGEMM_DEFAULT_Q * 4) - 15) & ~15) #endif diff --git a/common_sh.h b/common_sh.h new file mode 100644 index 000000000..8859694f1 --- /dev/null +++ b/common_sh.h @@ -0,0 +1,65 @@ +#ifndef COMMON_H_H +#define COMMON_H_H + +#ifndef DYNAMIC_ARCH + +#define SHGEMM_ONCOPY shgemm_oncopy +#define SHGEMM_OTCOPY shgemm_otcopy + +#if SHGEMM_DEFAULT_UNROLL_M == SHGEMM_DEFAULT_UNROLL_N +#define SHGEMM_INCOPY shgemm_oncopy +#define SHGEMM_ITCOPY shgemm_otcopy +#else +#define SHGEMM_INCOPY shgemm_incopy +#define SHGEMM_ITCOPY shgemm_itcopy +#endif +#define SHGEMM_BETA shgemm_beta +#define SHGEMM_KERNEL shgemm_kernel + +#else + +#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy +#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy +#define SHGEMM_INCOPY gotoblas -> shgemm_incopy +#define SHGEMM_ITCOPY gotoblas -> shgemm_itcopy +#define SHGEMM_BETA gotoblas -> shgemm_beta +#define SHGEMM_KERNEL gotoblas -> shgemm_kernel + +#endif + +#define SHGEMM_NN shgemm_nn +#define SHGEMM_CN shgemm_tn +#define SHGEMM_TN shgemm_tn +#define SHGEMM_NC shgemm_nt +#define SHGEMM_NT shgemm_nt +#define SHGEMM_CC shgemm_tt +#define SHGEMM_CT shgemm_tt +#define SHGEMM_TC shgemm_tt +#define SHGEMM_TT shgemm_tt +#define SHGEMM_NR shgemm_nn +#define SHGEMM_TR shgemm_tn +#define SHGEMM_CR shgemm_tn +#define SHGEMM_RN shgemm_nn +#define SHGEMM_RT shgemm_nt +#define SHGEMM_RC shgemm_nt +#define SHGEMM_RR shgemm_nn + +#define SHGEMM_THREAD_NN shgemm_thread_nn +#define SHGEMM_THREAD_CN shgemm_thread_tn +#define SHGEMM_THREAD_TN shgemm_thread_tn +#define SHGEMM_THREAD_NC shgemm_thread_nt +#define SHGEMM_THREAD_NT shgemm_thread_nt +#define SHGEMM_THREAD_CC shgemm_thread_tt +#define SHGEMM_THREAD_CT shgemm_thread_tt +#define SHGEMM_THREAD_TC shgemm_thread_tt +#define SHGEMM_THREAD_TT shgemm_thread_tt +#define SHGEMM_THREAD_NR shgemm_thread_nn +#define SHGEMM_THREAD_TR shgemm_thread_tn +#define SHGEMM_THREAD_CR shgemm_thread_tn +#define SHGEMM_THREAD_RN shgemm_thread_nn +#define SHGEMM_THREAD_RT shgemm_thread_nt +#define SHGEMM_THREAD_RC shgemm_thread_nt +#define SHGEMM_THREAD_RR shgemm_thread_nn + +#endif + diff --git a/driver/level3/Makefile b/driver/level3/Makefile index e320092e3..881b4ee35 100644 --- a/driver/level3/Makefile +++ b/driver/level3/Makefile @@ -19,6 +19,7 @@ ifeq ($(ARCH), MIPS) USE_GEMM3M = 1 endif +SHBLASOBJS += shgemm_nn.$(SUFFIX) shgemm_nt.$(SUFFIX) shgemm_tn.$(SUFFIX) shgemm_tt.$(SUFFIX) SBLASOBJS += \ sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \ strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \ @@ -204,6 +205,7 @@ COMMONOBJS += syrk_thread.$(SUFFIX) ifndef USE_SIMPLE_THREADED_LEVEL3 +SHBLASOBJS += shgemm_thread_nn.$(SUFFIX) shgemm_thread_nt.$(SUFFIX) shgemm_thread_tn.$(SUFFIX) shgemm_thread_tt.$(SUFFIX) SBLASOBJS += sgemm_thread_nn.$(SUFFIX) sgemm_thread_nt.$(SUFFIX) sgemm_thread_tn.$(SUFFIX) sgemm_thread_tt.$(SUFFIX) DBLASOBJS += dgemm_thread_nn.$(SUFFIX) dgemm_thread_nt.$(SUFFIX) dgemm_thread_tn.$(SUFFIX) dgemm_thread_tt.$(SUFFIX) QBLASOBJS += qgemm_thread_nn.$(SUFFIX) qgemm_thread_nt.$(SUFFIX) qgemm_thread_tn.$(SUFFIX) qgemm_thread_tt.$(SUFFIX) @@ -283,6 +285,18 @@ endif all :: +shgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +shgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +shgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +shgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) @@ -478,6 +492,17 @@ gemm_thread_variable.$(SUFFIX) : gemm_thread_variable.c ../../common.h beta_thread.$(SUFFIX) : beta_thread.c ../../common.h $(CC) -c $(CFLAGS) $< -o $(@F) +shgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +shgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +shgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +shgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) sgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) @@ -2652,6 +2677,18 @@ xtrsm_RCLU.$(SUFFIX) : trsm_R.c xtrsm_RCLN.$(SUFFIX) : trsm_R.c $(CC) -c $(CFLAGS) -DCOMPLEX -DXDOUBLE -DTRANSA -UUPPER -UUNIT -DCONJ $< -o $(@F) +shgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +shgemm_nt.$(PSUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +shgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +shgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h $(CC) $(PFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) @@ -2848,6 +2885,18 @@ beta_thread.$(PSUFFIX) : beta_thread.c ../../common.h $(CC) -c $(PFLAGS) $< -o $(@F) +shgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +shgemm_thread_nt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +shgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +shgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) diff --git a/driver/level3/level3.c b/driver/level3/level3.c index 9aa67286f..c6bbb9ca9 100644 --- a/driver/level3/level3.c +++ b/driver/level3/level3.c @@ -62,18 +62,18 @@ #ifndef ICOPY_OPERATION #if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \ defined(RN) || defined(RT) || defined(RC) || defined(RR) -#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); #else -#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); #endif #endif #ifndef OCOPY_OPERATION #if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \ defined(NR) || defined(TR) || defined(CR) || defined(RR) -#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); #else -#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); #endif #endif @@ -173,7 +173,8 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){ BLASLONG k, lda, ldb, ldc; FLOAT *alpha, *beta; - FLOAT *a, *b, *c; + IFLOAT *a, *b; + FLOAT *c; BLASLONG m_from, m_to, n_from, n_to; BLASLONG ls, is, js; @@ -198,8 +199,8 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, k = K; - a = (FLOAT *)A; - b = (FLOAT *)B; + a = (IFLOAT *)A; + b = (IFLOAT *)B; c = (FLOAT *)C; lda = LDA; diff --git a/driver/level3/level3_thread.c b/driver/level3/level3_thread.c index ca0085e71..5a8d497d2 100644 --- a/driver/level3/level3_thread.c +++ b/driver/level3/level3_thread.c @@ -117,18 +117,18 @@ typedef struct { #ifndef ICOPY_OPERATION #if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \ defined(RN) || defined(RT) || defined(RC) || defined(RR) -#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); #else -#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); #endif #endif #ifndef OCOPY_OPERATION #if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \ defined(NR) || defined(TR) || defined(CR) || defined(RR) -#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); #else -#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); +#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); #endif #endif @@ -219,15 +219,16 @@ typedef struct { #define STOP_RPCC(COUNTER) #endif -static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){ +static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){ - FLOAT *buffer[DIVIDE_RATE]; + IFLOAT *buffer[DIVIDE_RATE]; BLASLONG k, lda, ldb, ldc; BLASLONG m_from, m_to, n_from, n_to; FLOAT *alpha, *beta; - FLOAT *a, *b, *c; + IFLOAT *a, *b; + FLOAT *c; job_t *job = (job_t *)args -> common; BLASLONG nthreads_m; @@ -255,8 +256,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, k = K; - a = (FLOAT *)A; - b = (FLOAT *)B; + a = (IFLOAT *)A; + b = (IFLOAT *)B; c = (FLOAT *)C; lda = LDA; @@ -425,7 +426,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, /* Apply kernel with local region of A and part of other region of B */ START_RPCC(); KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - js, div_n), min_l, alpha, - sa, (FLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside], + sa, (IFLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside], c, ldc, m_from, js); STOP_RPCC(kernel); @@ -469,7 +470,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, /* Apply kernel with local region of A and part of region of B */ START_RPCC(); KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - js, div_n), min_l, alpha, - sa, (FLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside], + sa, (IFLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside], c, ldc, is, js); STOP_RPCC(kernel); @@ -532,7 +533,7 @@ static int round_up(int remainder, int width, int multiple) static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG - *range_n, FLOAT *sa, FLOAT *sb, + *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG nthreads_m, BLASLONG nthreads_n) { #ifndef USE_OPENMP @@ -728,7 +729,7 @@ EnterCriticalSection((PCRITICAL_SECTION)&level3_lock); return 0; } -int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){ +int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){ BLASLONG m = args -> m; BLASLONG n = args -> n; diff --git a/driver/others/parameter.c b/driver/others/parameter.c index 8bf7da78b..b1f3befae 100644 --- a/driver/others/parameter.c +++ b/driver/others/parameter.c @@ -62,6 +62,11 @@ BLASLONG gemm_offset_b = DEFAULT_GEMM_OFFSET_B; BLASLONG gemm_offset_b = GEMM_OFFSET_B; #endif +#if SHGEMM_P == shgemm_p +BLASLONG shgemm_p = DEFAULT_GEMM_P; +#else +BLASLONG shgemm_p = SHGEMM_P; +#endif #if SGEMM_P == sgemm_p BLASLONG sgemm_p = DEFAULT_GEMM_P; #else @@ -83,6 +88,11 @@ BLASLONG zgemm_p = DEFAULT_GEMM_P; BLASLONG zgemm_p = ZGEMM_P; #endif +#if SHGEMM_Q == shgemm_q +BLASLONG shgemm_q = DEFAULT_GEMM_Q; +#else +BLASLONG shgemm_q = SHGEMM_Q; +#endif #if SGEMM_Q == sgemm_q BLASLONG sgemm_q = DEFAULT_GEMM_Q; #else @@ -104,6 +114,11 @@ BLASLONG zgemm_q = DEFAULT_GEMM_Q; BLASLONG zgemm_q = ZGEMM_Q; #endif +#if SHGEMM_R == shgemm_r +BLASLONG shgemm_r = DEFAULT_GEMM_R; +#else +BLASLONG shgemm_r = SHGEMM_R; +#endif #if SGEMM_R == sgemm_r BLASLONG sgemm_r = DEFAULT_GEMM_R; #else @@ -597,6 +612,7 @@ void blas_set_parameter(void){ size = BITMASK(cpuid3, 16, 0xff); + shgemm_p = 192 * (size + 1); sgemm_p = 192 * (size + 1); dgemm_p = 96 * (size + 1); cgemm_p = 96 * (size + 1); @@ -610,6 +626,7 @@ void blas_set_parameter(void){ xgemm_p = 16 * (size + 1); #endif + shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15; sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15; dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15; cgemm_r = (((BUFFER_SIZE - ((CGEMM_P * CGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (CGEMM_Q * 8)) - 15) & ~15; diff --git a/getarch_2nd.c b/getarch_2nd.c index cf9c578cb..a1d0ccac8 100644 --- a/getarch_2nd.c +++ b/getarch_2nd.c @@ -9,6 +9,8 @@ int main(int argc, char **argv) { if ( (argc <= 1) || ((argc >= 2) && (*argv[1] == '0'))) { + printf("SHGEMM_UNROLL_M=%d\n", SHGEMM_DEFAULT_UNROLL_M); + printf("SHGEMM_UNROLL_N=%d\n", SHGEMM_DEFAULT_UNROLL_N); printf("SGEMM_UNROLL_M=%d\n", SGEMM_DEFAULT_UNROLL_M); printf("SGEMM_UNROLL_N=%d\n", SGEMM_DEFAULT_UNROLL_N); printf("DGEMM_UNROLL_M=%d\n", DGEMM_DEFAULT_UNROLL_M); diff --git a/interface/Makefile b/interface/Makefile index 3f0dcca28..741f6bac0 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -46,6 +46,7 @@ SBLAS3OBJS = \ somatcopy.$(SUFFIX) simatcopy.$(SUFFIX)\ sgeadd.$(SUFFIX) +SHBLAS3OBJS = shgemm.$(SUFFIX) DBLAS1OBJS = \ daxpy.$(SUFFIX) dswap.$(SUFFIX) \ @@ -277,6 +278,8 @@ CSBLAS3OBJS = \ cblas_ssyrk.$(SUFFIX) cblas_ssyr2k.$(SUFFIX) cblas_somatcopy.$(SUFFIX) cblas_simatcopy.$(SUFFIX)\ cblas_sgeadd.$(SUFFIX) +CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX) + CDBLAS1OBJS = \ cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \ cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \ @@ -367,6 +370,7 @@ override CFLAGS += -I. SBLAS1OBJS += $(CSBLAS1OBJS) SBLAS2OBJS += $(CSBLAS2OBJS) SBLAS3OBJS += $(CSBLAS3OBJS) +SHBLAS3OBJS += $(CSHBLAS3OBJS) DBLAS1OBJS += $(CDBLAS1OBJS) DBLAS2OBJS += $(CDBLAS2OBJS) DBLAS3OBJS += $(CDBLAS3OBJS) @@ -380,6 +384,7 @@ ZBLAS3OBJS += $(CZBLAS3OBJS) endif SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) +SHBLASOBJS = $(SHBLAS3OBJS) DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS) CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS) @@ -454,7 +459,7 @@ ZBLASOBJS += $(ZLAPACKOBJS) endif -FUNCOBJS = $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) +FUNCOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) ifdef EXPRECISION FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS) @@ -488,10 +493,10 @@ level1 : $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $ level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ -level3 : $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) +level3 : $(SHBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ -$(CSBLASOBJS) $(CSBLASOBJS_P) $(CDBLASOBJS) $(CDBLASOBJS_P) $(CQBLASOBJS) $(CQBLASOBJS_P) \ +$(CSHBLASOBJS) $(CSHBLASOBJS_P) $(CSBLASOBJS) $(CSBLASOBJS_P) $(CDBLASOBJS) $(CDBLASOBJS_P) $(CQBLASOBJS) $(CQBLASOBJS_P) \ $(CCBLASOBJS) $(CCBLASOBJS_P) $(CZBLASOBJS) $(CZBLASOBJS_P) $(CXBLASOBJS) $(CXBLASOBJS_P) : override CFLAGS += -DCBLAS srot.$(SUFFIX) srot.$(PSUFFIX) : rot.c @@ -1209,6 +1214,9 @@ zhpr2.$(SUFFIX) zhpr2.$(PSUFFIX) : zhpr2.c xhpr2.$(SUFFIX) xhpr2.$(PSUFFIX) : zhpr2.c $(CC) -c $(CFLAGS) $< -o $(@F) +shgemm.$(SUFFIX) shgemm.$(PSUFFIX) : gemm.c ../param.h + $(CC) -c $(CFLAGS) $< -o $(@F) + sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -c $(CFLAGS) $< -o $(@F) @@ -1770,6 +1778,9 @@ cblas_zhemv.$(SUFFIX) cblas_zhemv.$(PSUFFIX) : zhemv.c cblas_sgemm.$(SUFFIX) cblas_sgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) +cblas_shgemm.$(SUFFIX) cblas_shgemm.$(PSUFFIX) : gemm.c ../param.h + $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) + cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) diff --git a/interface/gemm.c b/interface/gemm.c index 0b18d9a8c..99388e7d9 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -77,7 +77,7 @@ #define GEMM_MULTITHREAD_THRESHOLD 4 #endif -static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLASLONG) = { +static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = { #ifndef GEMM3M GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN, GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT, @@ -108,8 +108,8 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLA void NAME(char *TRANSA, char *TRANSB, blasint *M, blasint *N, blasint *K, FLOAT *alpha, - FLOAT *a, blasint *ldA, - FLOAT *b, blasint *ldB, + IFLOAT *a, blasint *ldA, + IFLOAT *b, blasint *ldB, FLOAT *beta, FLOAT *c, blasint *ldC){ @@ -119,8 +119,8 @@ void NAME(char *TRANSA, char *TRANSB, blasint info; char transA, transB; - FLOAT *buffer; - FLOAT *sa, *sb; + IFLOAT *buffer; + IFLOAT *sa, *sb; #ifdef SMP double MNK; diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 6d96abb2e..aee610efb 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -59,6 +59,10 @@ ifeq ($(CORE), Z14) USE_TRMM = 1 endif +SHKERNELOBJS += \ + shgemm_kernel$(TSUFFIX).$(SUFFIX) \ + $(SHGEMMINCOPYOBJ) $(SHGEMMITCOPYOBJ) \ + $(SHGEMMONCOPYOBJ) $(SHGEMMOTCOPYOBJ) SKERNELOBJS += \ sgemm_kernel$(TSUFFIX).$(SUFFIX) \ @@ -93,6 +97,7 @@ XKERNELOBJS += \ $(XGEMMINCOPYOBJ) $(XGEMMITCOPYOBJ) \ $(XGEMMONCOPYOBJ) $(XGEMMOTCOPYOBJ) +SHBLASOBJS += $(SHKERNELOBJS) SBLASOBJS += $(SKERNELOBJS) DBLASOBJS += $(DKERNELOBJS) QBLASOBJS += $(QKERNELOBJS) @@ -100,6 +105,7 @@ CBLASOBJS += $(CKERNELOBJS) ZBLASOBJS += $(ZKERNELOBJS) XBLASOBJS += $(XKERNELOBJS) +SHBLASOBJS += shgemm_beta$(TSUFFIX).$(SUFFIX) SBLASOBJS += \ sgemm_beta$(TSUFFIX).$(SUFFIX) \ strmm_kernel_LN$(TSUFFIX).$(SUFFIX) strmm_kernel_LT$(TSUFFIX).$(SUFFIX) \ @@ -390,6 +396,10 @@ ZBLASOBJS += \ zgeadd_k$(TSUFFIX).$(SUFFIX) +SHGEMMINCOPYOBJ_P = $(SHGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +SHGEMMITCOPYOBJ_P = $(SHGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +SHGEMMONCOPYOBJ_P = $(SHGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +SHGEMMOTCOPYOBJ_P = $(SHGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMINCOPYOBJ_P = $(SGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMITCOPYOBJ_P = $(SGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMONCOPYOBJ_P = $(SGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) @@ -415,6 +425,9 @@ XGEMMITCOPYOBJ_P = $(XGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) XGEMMONCOPYOBJ_P = $(XGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) XGEMMOTCOPYOBJ_P = $(XGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +$(KDIR)shgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA) + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + $(KDIR)sgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_BETA) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ @@ -433,6 +446,36 @@ $(KDIR)zgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_BETA) $(KDIR)xgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMM_BETA) $(CC) $(CFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@ +$(KDIR)$(SHGEMMONCOPYOBJ) : $(KERNELDIR)/$(SHGEMMONCOPY) + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)$(SHGEMMOTCOPYOBJ) : $(KERNELDIR)/$(SHGEMMOTCOPY) +ifeq ($(OS), AIX) + $(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemmotcopy.s + m4 shgemmotcopy.s > shgemmotcopy_nomacros.s + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemmotcopy_nomacros.s -o $@ + rm shgemmotcopy.s shgemmotcopy_nomacros.s +else + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ +endif + +ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N)) + +$(KDIR)$(SHGEMMINCOPYOBJ) : $(KERNELDIR)/$(SHGEMMINCOPY) + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)$(SHGEMMITCOPYOBJ) : $(KERNELDIR)/$(SHGEMMITCOPY) +ifeq ($(OS), AIX) + $(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemmitcopy.s + m4 shgemmitcopy.s > shgemmitcopy_nomacros.s + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemmitcopy_nomacros.s -o $@ + rm shgemmitcopy.s shgemmitcopy_nomacros.s +else + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ +endif + +endif + $(KDIR)$(SGEMMONCOPYOBJ) : $(KERNELDIR)/$(SGEMMONCOPY) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ @@ -590,6 +633,16 @@ else $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ endif +$(KDIR)shgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND) +ifeq ($(OS), AIX) + $(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemm_kernel$(TSUFFIX).s + m4 shgemm_kernel$(TSUFFIX).s > shgemm_kernel$(TSUFFIX)_nomacros.s + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemm_kernel$(TSUFFIX)_nomacros.s -o $@ + rm shgemm_kernel$(TSUFFIX).s shgemm_kernel$(TSUFFIX)_nomacros.s +else + $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ +endif + $(KDIR)dgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMMKERNEL) $(DGEMMDEPEND) ifeq ($(OS), AIX) $(CC) $(CFLAGS) -E -DDOUBLE -UCOMPLEX $< -o dgemm_kernel$(TSUFFIX).s @@ -2206,6 +2259,9 @@ $(KDIR)xtrsm_oltncopy$(TSUFFIX).$(SUFFIX) : generic/ztrsm_ltcopy_$(XGEMM_UNROLL_ $(KDIR)sgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMM_BETA) $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ +$(KDIR)shgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA) + $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + $(KDIR)dgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DGEMM_BETA) $(CC) $(PFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ @@ -2221,6 +2277,20 @@ $(KDIR)zgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(ZGEMM_BETA) $(KDIR)xgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(XGEMM_BETA) $(CC) $(PFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@ +$(SHGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMONCOPY) + $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + +$(SHGEMMOTCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMOTCOPY) + $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + +ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N)) +$(SHGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMINCOPY) + $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + +$(SHGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMITCOPY) + $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + +endif $(SGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SGEMMONCOPY) $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ @@ -2325,6 +2395,9 @@ endif endif +$(KDIR)shgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND) + $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@ + $(KDIR)sgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMMKERNEL) $(SGEMMDEPEND) $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ diff --git a/kernel/generic/gemm_beta.c b/kernel/generic/gemm_beta.c index fa9d7680d..ccb772cc7 100644 --- a/kernel/generic/gemm_beta.c +++ b/kernel/generic/gemm_beta.c @@ -39,7 +39,7 @@ #include "common.h" int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, - FLOAT *dummy2, BLASLONG dummy3, FLOAT *dummy4, BLASLONG dummy5, + IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, FLOAT *c, BLASLONG ldc){ diff --git a/kernel/generic/gemm_ncopy_2.c b/kernel/generic/gemm_ncopy_2.c index b728c713f..415860f81 100644 --- a/kernel/generic/gemm_ncopy_2.c +++ b/kernel/generic/gemm_ncopy_2.c @@ -39,10 +39,10 @@ #include #include "common.h" -int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){ +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){ BLASLONG i, j; - FLOAT *a_offset, *a_offset1, *a_offset2; - FLOAT *b_offset; + IFLOAT *a_offset, *a_offset1, *a_offset2; + IFLOAT *b_offset; a_offset = a; b_offset = b; diff --git a/kernel/generic/gemm_tcopy_2.c b/kernel/generic/gemm_tcopy_2.c index 5695b13c2..b4aa4de57 100644 --- a/kernel/generic/gemm_tcopy_2.c +++ b/kernel/generic/gemm_tcopy_2.c @@ -39,11 +39,11 @@ #include #include "common.h" -int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){ +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){ BLASLONG i, j; - FLOAT *a_offset, *a_offset1, *a_offset2; - FLOAT *b_offset, *b_offset1, *b_offset2; + IFLOAT *a_offset, *a_offset1, *a_offset2; + IFLOAT *b_offset, *b_offset1, *b_offset2; a_offset = a; b_offset = b; diff --git a/kernel/generic/gemmkernel_2x2.c b/kernel/generic/gemmkernel_2x2.c index 01f1c67b5..26a88db6d 100644 --- a/kernel/generic/gemmkernel_2x2.c +++ b/kernel/generic/gemmkernel_2x2.c @@ -1,13 +1,32 @@ #include "common.h" -int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FLOAT* C,BLASLONG ldc +#if defined(HALF) && defined(HALFCONVERSION) +float +bfloat16tof32 (bfloat16 f16) +{ + float result = 0; + unsigned short* q = (unsigned short*)(&result); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + q[0] = f16; +#else + q[1] = f16; +#endif + return result; +} +#define BF16TOF32(x) (bfloat16tof32(x)) +#else +#define BF16TOF32(x) x +#endif +int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc #ifdef TRMMKERNEL ,BLASLONG offset #endif ) { BLASLONG i,j,k; - FLOAT *C0,*C1,*ptrba,*ptrbb; - FLOAT res0,res1,res2,res3,load0,load1,load2,load3,load4,load5,load6,load7; + FLOAT *C0,*C1; + IFLOAT *ptrba,*ptrbb; + FLOAT res0,res1,res2,res3; + IFLOAT load0,load1,load2,load3,load4,load5,load6,load7; for (j=0; j +#include +#include "common.h" +#define SGEMM BLASFUNC(sgemm) +#define SHGEMM BLASFUNC(shgemm) +typedef union +{ + unsigned short v; + struct + { +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + unsigned short s:1; + unsigned short e:8; + unsigned short m:7; +#else + unsigned short m:7; + unsigned short e:8; + unsigned short s:1; +#endif + } bits; +} bfloat16_bits; + +int +main (int argc, char *argv[]) +{ + int m, n, k; + int i, j, l; + int ret = 0; + int loop = 20; + char transA = 'N', transB = 'N'; + float alpha = 1.0, beta = 0.0; + char transa = 'N'; + char transb = 'N'; + + for (int x = 0; x <= loop; x++) + { + m = k = n = x; + float A[m * k]; + float B[k * n]; + float C[m * n]; + bfloat16_bits AA[m * k], BB[k * n]; + float CC[m * n]; + + for (int j = 0; j < m; j++) + { + for (int i = 0; i < m; i++) + { + A[j * k + i] = j * 9.0; + B[j * k + i] = i * 2.0; + C[j * k + i] = 0; + AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16; + BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16; + CC[j * k + i] = 0; + } + } + SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, + &m, B, &k, &beta, C, &m); + SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA, + &m, BB, &k, &beta, CC, &m); + + for (i = 0; i < n; i++) + for (j = 0; j < m; j++) + for (l = 0; l < k; l++) + if (CC[i * m + j] != C[i * m + j]) + ret++; + } + fprintf (stderr, "Return code: %d\n", ret); + return ret; +} From ff010f496e255de706067ff54b57e38b69f33c0d Mon Sep 17 00:00:00 2001 From: Rajalakshmi Srinivasaraghavan Date: Tue, 14 Apr 2020 20:38:53 -0500 Subject: [PATCH 02/15] Build shgemm for all architecture --- kernel/Makefile.L3 | 13 +++++++++++++ kernel/power/KERNEL.POWER9 | 11 ----------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index aee610efb..baf0c1c8a 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -59,6 +59,19 @@ ifeq ($(CORE), Z14) USE_TRMM = 1 endif +#ifndef SHGEMMKERNEL +SHGEMM_BETA = ../generic/gemm_beta.c +SHGEMMKERNEL = ../generic/gemmkernel_2x2.c +SHGEMMINCOPY = ../generic/gemm_ncopy_2.c +SHGEMMITCOPY = ../generic/gemm_tcopy_2.c +SHGEMMONCOPY = ../generic/gemm_ncopy_2.c +SHGEMMOTCOPY = ../generic/gemm_tcopy_2.c +SHGEMMINCOPYOBJ = shgemm_incopy$(TSUFFIX).$(SUFFIX) +SHGEMMITCOPYOBJ = shgemm_itcopy$(TSUFFIX).$(SUFFIX) +SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX) +SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX) +#endif + SHKERNELOBJS += \ shgemm_kernel$(TSUFFIX).$(SUFFIX) \ $(SHGEMMINCOPYOBJ) $(SHGEMMITCOPYOBJ) \ diff --git a/kernel/power/KERNEL.POWER9 b/kernel/power/KERNEL.POWER9 index dedb015e8..aabb5d976 100644 --- a/kernel/power/KERNEL.POWER9 +++ b/kernel/power/KERNEL.POWER9 @@ -12,17 +12,6 @@ DTRMMKERNEL = dgemm_kernel_power9.S CTRMMKERNEL = cgemm_kernel_power9.S ZTRMMKERNEL = zgemm_kernel_power9.S -SHGEMM_BETA = ../generic/gemm_beta.c -SHGEMMKERNEL = ../generic/gemmkernel_2x2.c -SHGEMMINCOPY = ../generic/gemm_ncopy_2.c -SHGEMMITCOPY = ../generic/gemm_tcopy_2.c -SHGEMMONCOPY = ../generic/gemm_ncopy_2.c -SHGEMMOTCOPY = ../generic/gemm_tcopy_2.c -SHGEMMINCOPYOBJ = shgemm_incopy$(TSUFFIX).$(SUFFIX) -SHGEMMITCOPYOBJ = shgemm_itcopy$(TSUFFIX).$(SUFFIX) -SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX) -SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX) - SGEMMKERNEL = sgemm_kernel_power9.S SGEMMINCOPY = ../generic/gemm_ncopy_16.c SGEMMITCOPY = sgemm_tcopy_16_power8.S From ac6a22ae7801888df527eb426647b0b55e79f60c Mon Sep 17 00:00:00 2001 From: Rajalakshmi Srinivasaraghavan Date: Tue, 14 Apr 2020 22:58:39 -0500 Subject: [PATCH 03/15] Update header --- common_param.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/common_param.h b/common_param.h index f1cac38d1..6276f7f51 100644 --- a/common_param.h +++ b/common_param.h @@ -41,6 +41,9 @@ #ifndef ASSEMBLER +#ifndef BFLOAT16 +typedef unsigned short bfloat16; +#endif #ifdef DYNAMIC_ARCH typedef struct { From a87793e03c4a073b533ceadafa54cf6c01a66f18 Mon Sep 17 00:00:00 2001 From: Rajalakshmi Srinivasaraghavan Date: Wed, 15 Apr 2020 09:09:50 -0500 Subject: [PATCH 04/15] Fix DYNAMIC_ARCH compilation errors --- common_param.h | 106 +++++++++++++++++++++++++++++--- kernel/generic/gemmkernel_2x2.c | 2 +- kernel/setparam-ref.c | 46 +++++++++++++- 3 files changed, 142 insertions(+), 12 deletions(-) diff --git a/common_param.h b/common_param.h index 6276f7f51..446d42452 100644 --- a/common_param.h +++ b/common_param.h @@ -41,15 +41,110 @@ #ifndef ASSEMBLER +#ifdef DYNAMIC_ARCH + #ifndef BFLOAT16 typedef unsigned short bfloat16; #endif -#ifdef DYNAMIC_ARCH typedef struct { int dtb_entries; int offsetA, offsetB, align; +#if 1 + int shgemm_p, shgemm_q, shgemm_r; + int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; + + float (*shamax_k) (BLASLONG, float *, BLASLONG); + float (*shamin_k) (BLASLONG, float *, BLASLONG); + float (*shmax_k) (BLASLONG, float *, BLASLONG); + float (*shmin_k) (BLASLONG, float *, BLASLONG); +BLASLONG (*ishamax_k)(BLASLONG, float *, BLASLONG); +BLASLONG (*ishamin_k)(BLASLONG, float *, BLASLONG); +BLASLONG (*ishmax_k) (BLASLONG, float *, BLASLONG); +BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG); + + float (*shnrm2_k) (BLASLONG, float *, BLASLONG); + float (*shasum_k) (BLASLONG, float *, BLASLONG); + float (*shsum_k) (BLASLONG, float *, BLASLONG); + int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); + float (*shdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); + double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); + + int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float); + + int (*shaxpy_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + int (*shscal_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + int (*shswap_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + + int (*shgemv_n) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*shgemv_t) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*shger_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + + int (*shsymv_L) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*shsymv_U) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + + int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); + int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); + + int (*shgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_oncopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_otcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + + int (*shtrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrsm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrsm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + + int (*shtrsm_iunucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iunncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iutucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iutncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ilnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ilnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ounucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ounncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_outucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_outncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_olnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_olnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_oltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + + int (*shtrmm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrmm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrmm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrmm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + + int (*shtrmm_iunucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iunncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iutucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iutncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ilnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ilnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ounucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ounncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_outucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_outncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_olnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_olnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_oltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + + int (*shsymm_iutcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shsymm_iltcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shsymm_outcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shsymm_oltcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + + int (*shneg_tcopy) (BLASLONG, BLASLONG, float *, BLASLONG, float *); + int (*shlaswp_ncopy) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG, blasint *, float *); + +#endif int sgemm_p, sgemm_q, sgemm_r; int sgemm_unroll_m, sgemm_unroll_n, sgemm_unroll_mn; @@ -87,15 +182,6 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); - int shgemm_p, shgemm_q, shgemm_r; - int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; - int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); - int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); - - int (*shgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); - int (*shgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); - int (*shgemm_oncopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); - int (*shgemm_otcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); int (*sgemm_incopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_itcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); diff --git a/kernel/generic/gemmkernel_2x2.c b/kernel/generic/gemmkernel_2x2.c index 26a88db6d..cc7bb8e48 100644 --- a/kernel/generic/gemmkernel_2x2.c +++ b/kernel/generic/gemmkernel_2x2.c @@ -1,6 +1,6 @@ #include "common.h" #if defined(HALF) && defined(HALFCONVERSION) -float +static float bfloat16tof32 (bfloat16 f16) { float result = 0; diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index 12d038901..79cd151f6 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -60,6 +60,15 @@ gotoblas_t TABLE_NAME = { #else MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N), #endif + + samax_kTS, samin_kTS, smax_kTS, smin_kTS, + isamax_kTS, isamin_kTS, ismax_kTS, ismin_kTS, + snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, sdot_kTS, + dsdot_kTS, + srot_kTS, saxpy_kTS, sscal_kTS, sswap_kTS, + sgemv_nTS, sgemv_tTS, sger_kTS, + ssymv_LTS, ssymv_UTS, + shgemm_kernelTS, shgemm_betaTS, #if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N shgemm_incopyTS, shgemm_itcopyTS, @@ -67,7 +76,42 @@ gotoblas_t TABLE_NAME = { shgemm_oncopyTS, shgemm_otcopyTS, #endif shgemm_oncopyTS, shgemm_otcopyTS, - sgemm_kernelTS, sgemm_betaTS, + + strsm_kernel_LNTS, strsm_kernel_LTTS, strsm_kernel_RNTS, strsm_kernel_RTTS, +#if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N + strsm_iunucopyTS, strsm_iunncopyTS, strsm_iutucopyTS, strsm_iutncopyTS, + strsm_ilnucopyTS, strsm_ilnncopyTS, strsm_iltucopyTS, strsm_iltncopyTS, +#else + strsm_ounucopyTS, strsm_ounncopyTS, strsm_outucopyTS, strsm_outncopyTS, + strsm_olnucopyTS, strsm_olnncopyTS, strsm_oltucopyTS, strsm_oltncopyTS, +#endif + strsm_ounucopyTS, strsm_ounncopyTS, strsm_outucopyTS, strsm_outncopyTS, + strsm_olnucopyTS, strsm_olnncopyTS, strsm_oltucopyTS, strsm_oltncopyTS, + strmm_kernel_RNTS, strmm_kernel_RTTS, strmm_kernel_LNTS, strmm_kernel_LTTS, +#if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N + strmm_iunucopyTS, strmm_iunncopyTS, strmm_iutucopyTS, strmm_iutncopyTS, + strmm_ilnucopyTS, strmm_ilnncopyTS, strmm_iltucopyTS, strmm_iltncopyTS, +#else + strmm_ounucopyTS, strmm_ounncopyTS, strmm_outucopyTS, strmm_outncopyTS, + strmm_olnucopyTS, strmm_olnncopyTS, strmm_oltucopyTS, strmm_oltncopyTS, +#endif + strmm_ounucopyTS, strmm_ounncopyTS, strmm_outucopyTS, strmm_outncopyTS, + strmm_olnucopyTS, strmm_olnncopyTS, strmm_oltucopyTS, strmm_oltncopyTS, +#if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N + ssymm_iutcopyTS, ssymm_iltcopyTS, +#else + ssymm_outcopyTS, ssymm_oltcopyTS, +#endif + ssymm_outcopyTS, ssymm_oltcopyTS, + +#ifndef NO_LAPACK + sneg_tcopyTS, slaswp_ncopyTS, +#else + NULL,NULL, +#endif + + + 0, 0, 0, SGEMM_DEFAULT_UNROLL_M, SGEMM_DEFAULT_UNROLL_N, #ifdef SGEMM_DEFAULT_UNROLL_MN SGEMM_DEFAULT_UNROLL_MN, From 67cc4b9e16d2e8c017731d2b9eabb5c6b45a9ad5 Mon Sep 17 00:00:00 2001 From: Rajalakshmi Srinivasaraghavan Date: Wed, 15 Apr 2020 19:15:23 -0500 Subject: [PATCH 05/15] Fix warnings in clang and export symbol --- common.h | 9 +-- common_interface.h | 3 - common_level3.h | 3 - common_param.h | 4 -- common_sh.h | 4 +- exports/gensymbol | 4 +- kernel/common_param.h | 129 ++++++++++++++++++++++++++++++++++++++++++ kernel/setparam-ref.c | 8 +-- 8 files changed, 140 insertions(+), 24 deletions(-) diff --git a/common.h b/common.h index 1d8bf07e5..e2c8cdee5 100644 --- a/common.h +++ b/common.h @@ -257,6 +257,11 @@ typedef long BLASLONG; typedef unsigned long BLASULONG; #endif +#ifndef BFLOAT16 +typedef unsigned short bfloat16; +#define HALFCONVERSION 1 +#endif + #ifdef USE64BITINT typedef BLASLONG blasint; #if defined(OS_WINDOWS) && defined(__64BIT__) @@ -298,10 +303,6 @@ typedef int blasint; #define BASE_SHIFT 3 #define ZBASE_SHIFT 4 #elif defined(HALF) -#ifndef BFLOAT16 -typedef unsigned short bfloat16; -#define HALFCONVERSION 1 -#endif #define IFLOAT bfloat16 #define XFLOAT IFLOAT #define FLOAT float diff --git a/common_interface.h b/common_interface.h index 081043af1..78f5be6b0 100644 --- a/common_interface.h +++ b/common_interface.h @@ -37,9 +37,6 @@ /*********************************************************************/ #ifndef ASSEMBLER -#ifndef BFLOAT16 -typedef unsigned short bfloat16; -#endif #ifdef __cplusplus extern "C" { diff --git a/common_level3.h b/common_level3.h index 8194ba6ce..4e44a5e73 100644 --- a/common_level3.h +++ b/common_level3.h @@ -37,9 +37,6 @@ /*********************************************************************/ #ifndef ASSEMBLER -#ifndef BFLOAT16 -typedef unsigned short bfloat16; -#endif #ifdef __CUDACC__ __global__ void cuda_sgemm_kernel(int, int, int, float *, float *, float *); diff --git a/common_param.h b/common_param.h index 446d42452..19a34fa3d 100644 --- a/common_param.h +++ b/common_param.h @@ -43,10 +43,6 @@ #ifdef DYNAMIC_ARCH -#ifndef BFLOAT16 -typedef unsigned short bfloat16; -#endif - typedef struct { int dtb_entries; int offsetA, offsetB, align; diff --git a/common_sh.h b/common_sh.h index 8859694f1..7a0045762 100644 --- a/common_sh.h +++ b/common_sh.h @@ -1,5 +1,5 @@ -#ifndef COMMON_H_H -#define COMMON_H_H +#ifndef COMMON_SH_H +#define COMMON_SH_H #ifndef DYNAMIC_ARCH diff --git a/exports/gensymbol b/exports/gensymbol index d2894e6c8..235446f14 100644 --- a/exports/gensymbol +++ b/exports/gensymbol @@ -30,7 +30,7 @@ icamax,icamin,idamax,idamin,idmax,idmin,isamax,isamin,ismax,ismin, izamax,izamin,lsame,samax,samin,sasum,saxpy,scabs1,scamax, scamin,scasum,scnrm2,scopy,sdot,sdsdot,sgbmv,sgemm,sgemv,sger, - smax,smin,snrm2, + shgemm, smax,smin,snrm2, srot,srotg,srotm,srotmg,ssbmv,sscal,sspmv,sspr2,sspr,sswap, ssymm,ssymv,ssyr2,ssyr2k,ssyr,ssyrk,stbmv,stbsv,stpmv,stpsv, strmm,strmv,strsm,strsv,zaxpy,zcopy,zdotc,zdotu,zdrot, @@ -67,7 +67,7 @@ cblas_isamax, cblas_izamax, cblas_sasum, cblas_saxpy, cblas_scasum, cblas_scnrm2, cblas_scopy, cblas_sdot, cblas_sdsdot, cblas_sgbmv, cblas_sgemm, - cblas_sgemv, cblas_sger, cblas_snrm2, cblas_srot, cblas_srotg, + cblas_sgemv, cblas_sger, cblas_shgemm, cblas_snrm2, cblas_srot, cblas_srotg, cblas_srotm, cblas_srotmg, cblas_ssbmv, cblas_sscal, cblas_sspmv, cblas_sspr2, cblas_sspr, cblas_sswap, cblas_ssymm, cblas_ssymv, cblas_ssyr2, cblas_ssyr2k, cblas_ssyr, cblas_ssyrk, cblas_stbmv, cblas_stbsv, cblas_stpmv, cblas_stpsv, cblas_strmm, cblas_strmv, cblas_strsm, diff --git a/kernel/common_param.h b/kernel/common_param.h index eab14b0a6..29bb65e5c 100644 --- a/kernel/common_param.h +++ b/kernel/common_param.h @@ -47,6 +47,100 @@ typedef struct { int dtb_entries; int offsetA, offsetB, align; +#if 1 + int shgemm_p, shgemm_q, shgemm_r; + int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; + + float (*shamax_k) (BLASLONG, float *, BLASLONG); + float (*shamin_k) (BLASLONG, float *, BLASLONG); + float (*shmax_k) (BLASLONG, float *, BLASLONG); + float (*shmin_k) (BLASLONG, float *, BLASLONG); +BLASLONG (*ishamax_k)(BLASLONG, float *, BLASLONG); +BLASLONG (*ishamin_k)(BLASLONG, float *, BLASLONG); +BLASLONG (*ishmax_k) (BLASLONG, float *, BLASLONG); +BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG); + + float (*shnrm2_k) (BLASLONG, float *, BLASLONG); + float (*shasum_k) (BLASLONG, float *, BLASLONG); + float (*shsum_k) (BLASLONG, float *, BLASLONG); + int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); + float (*shdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); + double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); + + int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float); + + int (*shaxpy_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + int (*shscal_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + int (*shswap_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + + int (*shgemv_n) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*shgemv_t) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*shger_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + + int (*shsymv_L) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*shsymv_U) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + + int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); + int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); + + int (*shgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_oncopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_otcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + + int (*shtrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrsm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrsm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + + int (*shtrsm_iunucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iunncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iutucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iutncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ilnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ilnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ounucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ounncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_outucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_outncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_olnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_olnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_oltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + + int (*shtrmm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrmm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrmm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrmm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + + int (*shtrmm_iunucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iunncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iutucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iutncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ilnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ilnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ounucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ounncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_outucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_outncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_olnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_olnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_oltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + + int (*shsymm_iutcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shsymm_iltcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shsymm_outcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shsymm_oltcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + + int (*shneg_tcopy) (BLASLONG, BLASLONG, float *, BLASLONG, float *); + int (*shlaswp_ncopy) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG, blasint *, float *); + +#endif int sgemm_p, sgemm_q, sgemm_r; int sgemm_unroll_m, sgemm_unroll_n, sgemm_unroll_mn; @@ -84,6 +178,7 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + int (*sgemm_incopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_itcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_oncopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); @@ -907,6 +1002,13 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 gotoblas -> exclusive_cache +#define SHGEMM_P gotoblas -> shgemm_p +#define SHGEMM_Q gotoblas -> shgemm_q +#define SHGEMM_R gotoblas -> shgemm_r +#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m +#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n +#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn + #define SGEMM_P gotoblas -> sgemm_p #define SGEMM_Q gotoblas -> sgemm_q #define SGEMM_R gotoblas -> sgemm_r @@ -984,6 +1086,17 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 0 #endif +#define SHGEMM_P SHGEMM_DEFAULT_P +#define SHGEMM_Q SHGEMM_DEFAULT_Q +#define SHGEMM_R SHGEMM_DEFAULT_R +#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M +#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N +#ifdef SHGEMM_DEFAULT_UNROLL_MN +#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN +#else +#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N)) +#endif + #define SGEMM_P SGEMM_DEFAULT_P #define SGEMM_Q SGEMM_DEFAULT_Q #define SGEMM_R SGEMM_DEFAULT_R @@ -1119,6 +1232,18 @@ extern gotoblas_t *gotoblas; #define GEMM_DEFAULT_R DGEMM_DEFAULT_R #define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M #define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N +#elif defined(HALF) +#define GEMM_P SHGEMM_P +#define GEMM_Q SHGEMM_Q +#define GEMM_R SHGEMM_R +#define GEMM_UNROLL_M SHGEMM_UNROLL_M +#define GEMM_UNROLL_N SHGEMM_UNROLL_N +#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN +#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P +#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q +#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R +#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M +#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N #else #define GEMM_P SGEMM_P #define GEMM_Q SGEMM_Q @@ -1204,6 +1329,10 @@ extern gotoblas_t *gotoblas; #define GEMM_THREAD gemm_thread_n #endif +#ifndef SHGEMM_DEFAULT_R +#define SHGEMM_DEFAULT_R (((BUFFER_SIZE - ((SHGEMM_DEFAULT_P * SHGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SHGEMM_DEFAULT_Q * 4) - 15) & ~15UL) +#endif + #ifndef SGEMM_DEFAULT_R #define SGEMM_DEFAULT_R (((BUFFER_SIZE - ((SGEMM_DEFAULT_P * SGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SGEMM_DEFAULT_Q * 4) - 15) & ~15UL) #endif diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index 79cd151f6..b7cf0f112 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -958,6 +958,8 @@ static void init_parameter(void) { (void) l2; /* dirty trick to suppress unused variable warning for targets */ /* where the GEMM unrolling parameters do not depend on l2 */ + TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P; + TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R; TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q; TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q; TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q; @@ -1329,7 +1331,6 @@ static void init_parameter(void) { - TABLE_NAME.shgemm_p = ((TABLE_NAME.shgemm_p + SHGEMM_DEFAULT_UNROLL_M - 1)/SHGEMM_DEFAULT_UNROLL_M) * SHGEMM_DEFAULT_UNROLL_M; TABLE_NAME.sgemm_p = ((TABLE_NAME.sgemm_p + SGEMM_DEFAULT_UNROLL_M - 1)/SGEMM_DEFAULT_UNROLL_M) * SGEMM_DEFAULT_UNROLL_M; TABLE_NAME.dgemm_p = ((TABLE_NAME.dgemm_p + DGEMM_DEFAULT_UNROLL_M - 1)/DGEMM_DEFAULT_UNROLL_M) * DGEMM_DEFAULT_UNROLL_M; TABLE_NAME.cgemm_p = ((TABLE_NAME.cgemm_p + CGEMM_DEFAULT_UNROLL_M - 1)/CGEMM_DEFAULT_UNROLL_M) * CGEMM_DEFAULT_UNROLL_M; @@ -1357,11 +1358,6 @@ static void init_parameter(void) { fprintf(stderr, "L2 = %8d DGEMM_P .. %d\n", l2, TABLE_NAME.dgemm_p); #endif - TABLE_NAME.shgemm_r = (((BUFFER_SIZE - - ((TABLE_NAME.shgemm_p * TABLE_NAME.shgemm_q * 4 + TABLE_NAME.offsetA - + TABLE_NAME.align) & ~TABLE_NAME.align) - ) / (TABLE_NAME.shgemm_q * 4) - 15) & ~15); - TABLE_NAME.sgemm_r = (((BUFFER_SIZE - ((TABLE_NAME.sgemm_p * TABLE_NAME.sgemm_q * 4 + TABLE_NAME.offsetA + TABLE_NAME.align) & ~TABLE_NAME.align) From 22bb50fb8115909ab8ba4a977913cd6adc1b3290 Mon Sep 17 00:00:00 2001 From: Rajalakshmi Srinivasaraghavan Date: Fri, 17 Apr 2020 13:35:17 -0500 Subject: [PATCH 06/15] cmake fixes --- CMakeLists.txt | 6 ++ cmake/kernel.cmake | 39 +++++++- cmake/utils.cmake | 7 ++ common_macro.h | 213 +++++++++++++++++++++++++++++++++++++++++- ctest/CMakeLists.txt | 3 + kernel/CMakeLists.txt | 21 ++++- lapack/CMakeLists.txt | 4 + 7 files changed, 287 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 951271717..20cf741c4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -89,6 +89,7 @@ endif () # set which float types we want to build for if (NOT DEFINED BUILD_SINGLE AND NOT DEFINED BUILD_DOUBLE AND NOT DEFINED BUILD_COMPLEX AND NOT DEFINED BUILD_COMPLEX16) # if none are defined, build for all + set(BUILD_HALF true) set(BUILD_SINGLE true) set(BUILD_DOUBLE true) set(BUILD_COMPLEX true) @@ -120,6 +121,11 @@ if (BUILD_COMPLEX16) list(APPEND FLOAT_TYPES "ZCOMPLEX") # defines COMPLEX and DOUBLE endif () +if (BUILD_SINGLE OR BUILD_HALF) + message(STATUS "Building Half Precision") + list(APPEND FLOAT_TYPES "HALF") # defines nothing +endif () + if (NOT DEFINED CORE OR "${CORE}" STREQUAL "UNKNOWN") message(FATAL_ERROR "Detecting CPU failed. Please set TARGET explicitly, e.g. make TARGET=your_cpu_target. Please read README for details.") endif () diff --git a/cmake/kernel.cmake b/cmake/kernel.cmake index 9b238f004..7b64a03fc 100644 --- a/cmake/kernel.cmake +++ b/cmake/kernel.cmake @@ -113,11 +113,29 @@ macro(SetDefaultL1) set(ZSUMKERNEL zsum.S) set(QSUMKERNEL sum.S) set(XSUMKERNEL zsum.S) + set(SHAMINKERNEL ../arm/amin.c) + set(SHAMAXKERNEL amax.S) + set(SHMAXKERNEL ../arm/max.c) + set(SHMINKERNEL ../arm/min.c) + set(ISHAMAXKERNEL iamax.S) + set(ISHAMINKERNEL ../arm/iamin.c) + set(ISHMAXKERNEL ../arm/imax.c) + set(ISHMINKERNEL ../arm/imin.c) + set(SHASUMKERNEL asum.S) + set(SHAXPYKERNEL axpy.S) + set(SHAXPBYKERNEL ../arm/axpby.c) + set(SHCOPYKERNEL copy.S) + set(SHDOTKERNEL dot.S) + set(SHROTKERNEL rot.S) + set(SHSCALKERNEL scal.S) + set(SHNRM2KERNEL nrm2.S) + set(SHSUMKERNEL sum.S) + set(SHSWAPKERNEL swap.S) endmacro () macro(SetDefaultL2) - set(SGEMVNKERNEL gemv_n.S) - set(SGEMVTKERNEL gemv_t.S) + set(SGEMVNKERNEL ../arm/gemv_n.c) + set(SGEMVTKERNEL ../arm/gemv_t.c) set(DGEMVNKERNEL gemv_n.S) set(DGEMVTKERNEL gemv_t.S) set(CGEMVNKERNEL zgemv_n.S) @@ -161,6 +179,10 @@ macro(SetDefaultL2) set(XHEMV_L_KERNEL ../generic/zhemv_k.c) set(XHEMV_V_KERNEL ../generic/zhemv_k.c) set(XHEMV_M_KERNEL ../generic/zhemv_k.c) + set(SHGEMVNKERNEL ../arm/gemv_n.c) + set(SHGEMVTKERNEL ../arm/gemv_t.c) + set(SHGERKERNEL ../generic/ger.c) + endmacro () macro(SetDefaultL3) @@ -168,4 +190,17 @@ macro(SetDefaultL3) set(DGEADD_KERNEL ../generic/geadd.c) set(CGEADD_KERNEL ../generic/zgeadd.c) set(ZGEADD_KERNEL ../generic/zgeadd.c) + set(SHGEADD_KERNEL ../generic/geadd.c) + set(SHGEMMKERNEL ../generic/gemmkernel_2x2.c) + set(SHGEMM_BETA ../generic/gemm_beta.c) + set(SHGEMMINCOPY ../generic/gemm_ncopy_2.c) + set(SHGEMMITCOPY ../generic/gemm_tcopy_2.c) + set(SHGEMMONCOPY ../generic/gemm_ncopy_2.c) + set(SHGEMMOTCOPY ../generic/gemm_tcopy_2.c) + set(SHGEMMINCOPYOBJ shgemm_incopy.o) + set(SHGEMMITCOPYOBJ shgemm_itcopy.o) + set(SHGEMMONCOPYOBJ shgemm_oncopy.o) + set(SHGEMMOTCOPYOBJ shgemm_otcopy.o) + + endmacro () diff --git a/cmake/utils.cmake b/cmake/utils.cmake index fd93f8a70..831ddffe6 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -163,6 +163,7 @@ function(GenerateNamedObjects sources_in) if (complex_only) list(REMOVE_ITEM float_list "SINGLE") list(REMOVE_ITEM float_list "DOUBLE") + list(REMOVE_ITEM float_list "HALF") elseif (real_only) list(REMOVE_ITEM float_list "COMPLEX") list(REMOVE_ITEM float_list "ZCOMPLEX") @@ -176,6 +177,9 @@ function(GenerateNamedObjects sources_in) if (NOT no_float_type) string(SUBSTRING ${float_type} 0 1 float_char) string(TOLOWER ${float_char} float_char) + if (${float_type} STREQUAL "HALF") + set (float_char "sh") + endif () endif () if (NOT name_in) @@ -210,6 +214,9 @@ function(GenerateNamedObjects sources_in) if (${float_type} STREQUAL "DOUBLE" OR ${float_type} STREQUAL "ZCOMPLEX") list(APPEND obj_defines "DOUBLE") endif () + if (${float_type} STREQUAL "HALF") + list(APPEND obj_defines "HALF") + endif () if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX") list(APPEND obj_defines "COMPLEX") if (mangle_complex_sources) diff --git a/common_macro.h b/common_macro.h index b438c83ba..2166e62a2 100644 --- a/common_macro.h +++ b/common_macro.h @@ -646,6 +646,19 @@ #elif defined(HALF) +#define AXPYU_K SAXPYU_K +#define AXPYC_K SAXPYC_K +#define SCAL_K SSCAL_K +#define GEMV_N SGEMV_N +#define GEMV_T SGEMV_T +#define SYMV_U SSYMV_U +#define SYMV_L SSYMV_L +#define GERU_K SGERU_K +#define GERC_K SGERC_K +#define GERV_K SGERV_K +#define GERD_K SGERD_K +#define SYMV_THREAD_U SSYMV_THREAD_U +#define SYMV_THREAD_L SSYMV_THREAD_L #define GEMM_BETA SHGEMM_BETA #define GEMM_KERNEL_N SHGEMM_KERNEL #define GEMM_KERNEL_L SHGEMM_KERNEL @@ -672,6 +685,20 @@ #define GEMM_OTCOPY SHGEMM_OTCOPY #define GEMM_INCOPY SHGEMM_INCOPY #define GEMM_ITCOPY SHGEMM_ITCOPY +#define SYMM_THREAD_LU SSYMM_THREAD_LU +#define SYMM_THREAD_LL SSYMM_THREAD_LL +#define SYMM_THREAD_RU SSYMM_THREAD_RU +#define SYMM_THREAD_RL SSYMM_THREAD_RL +#define SYMM_LU SSYMM_LU +#define SYMM_LL SSYMM_LL +#define SYMM_RU SSYMM_RU +#define SYMM_RL SSYMM_RL + + +#define HEMM_THREAD_LU SHEMM_THREAD_LU +#define HEMM_THREAD_LL SHEMM_THREAD_LL +#define HEMM_THREAD_RU SHEMM_THREAD_RU +#define HEMM_THREAD_RL SHEMM_THREAD_RL #define GEMM_THREAD_NN SHGEMM_THREAD_NN #define GEMM_THREAD_CN SHGEMM_THREAD_TN @@ -690,6 +717,186 @@ #define GEMM_THREAD_RC SHGEMM_THREAD_NT #define GEMM_THREAD_RR SHGEMM_THREAD_NN +#ifdef UNIT + +#define TRMM_OUNCOPY STRMM_OUNUCOPY +#define TRMM_OUTCOPY STRMM_OUTUCOPY +#define TRMM_OLNCOPY STRMM_OLNUCOPY +#define TRMM_OLTCOPY STRMM_OLTUCOPY +#define TRSM_OUNCOPY STRSM_OUNUCOPY +#define TRSM_OUTCOPY STRSM_OUTUCOPY +#define TRSM_OLNCOPY STRSM_OLNUCOPY +#define TRSM_OLTCOPY STRSM_OLTUCOPY + +#define TRMM_IUNCOPY STRMM_IUNUCOPY +#define TRMM_IUTCOPY STRMM_IUTUCOPY +#define TRMM_ILNCOPY STRMM_ILNUCOPY +#define TRMM_ILTCOPY STRMM_ILTUCOPY +#define TRSM_IUNCOPY STRSM_IUNUCOPY +#define TRSM_IUTCOPY STRSM_IUTUCOPY +#define TRSM_ILNCOPY STRSM_ILNUCOPY +#define TRSM_ILTCOPY STRSM_ILTUCOPY + +#else + +#define TRMM_OUNCOPY STRMM_OUNNCOPY +#define TRMM_OUTCOPY STRMM_OUTNCOPY +#define TRMM_OLNCOPY STRMM_OLNNCOPY +#define TRMM_OLTCOPY STRMM_OLTNCOPY +#define TRSM_OUNCOPY STRSM_OUNNCOPY +#define TRSM_OUTCOPY STRSM_OUTNCOPY +#define TRSM_OLNCOPY STRSM_OLNNCOPY +#define TRSM_OLTCOPY STRSM_OLTNCOPY + +#define TRMM_IUNCOPY STRMM_IUNNCOPY +#define TRMM_IUTCOPY STRMM_IUTNCOPY +#define TRMM_ILNCOPY STRMM_ILNNCOPY +#define TRMM_ILTCOPY STRMM_ILTNCOPY +#define TRSM_IUNCOPY STRSM_IUNNCOPY +#define TRSM_IUTCOPY STRSM_IUTNCOPY +#define TRSM_ILNCOPY STRSM_ILNNCOPY +#define TRSM_ILTCOPY STRSM_ILTNCOPY + +#define TRMM_KERNEL_LN STRMM_KERNEL_LN +#define TRMM_KERNEL_LT STRMM_KERNEL_LT +#define TRMM_KERNEL_LR STRMM_KERNEL_LN +#define TRMM_KERNEL_LC STRMM_KERNEL_LT +#define TRMM_KERNEL_RN STRMM_KERNEL_RN +#define TRMM_KERNEL_RT STRMM_KERNEL_RT +#define TRMM_KERNEL_RR STRMM_KERNEL_RN +#define TRMM_KERNEL_RC STRMM_KERNEL_RT + +#define TRSM_KERNEL_LN STRSM_KERNEL_LN +#define TRSM_KERNEL_LT STRSM_KERNEL_LT +#define TRSM_KERNEL_LR STRSM_KERNEL_LN +#define TRSM_KERNEL_LC STRSM_KERNEL_LT +#define TRSM_KERNEL_RN STRSM_KERNEL_RN +#define TRSM_KERNEL_RT STRSM_KERNEL_RT +#define TRSM_KERNEL_RR STRSM_KERNEL_RN +#define TRSM_KERNEL_RC STRSM_KERNEL_RT + +#define SYMM_IUTCOPY SSYMM_IUTCOPY +#define SYMM_ILTCOPY SSYMM_ILTCOPY +#define SYMM_OUTCOPY SSYMM_OUTCOPY +#define SYMM_OLTCOPY SSYMM_OLTCOPY +#define TRMM_LNUU STRMM_LNUU +#define TRMM_LNUN STRMM_LNUN +#define TRMM_LNLU STRMM_LNLU +#define TRMM_LNLN STRMM_LNLN +#define TRMM_LTUU STRMM_LTUU +#define TRMM_LTUN STRMM_LTUN +#define TRMM_LTLU STRMM_LTLU +#define TRMM_LTLN STRMM_LTLN +#define TRMM_LRUU STRMM_LNUU +#define TRMM_LRUN STRMM_LNUN +#define TRMM_LRLU STRMM_LNLU +#define TRMM_LRLN STRMM_LNLN +#define TRMM_LCUU STRMM_LTUU +#define TRMM_LCUN STRMM_LTUN +#define TRMM_LCLU STRMM_LTLU +#define TRMM_LCLN STRMM_LTLN +#define TRMM_RNUU STRMM_RNUU +#define TRMM_RNUN STRMM_RNUN +#define TRMM_RNLU STRMM_RNLU +#define TRMM_RNLN STRMM_RNLN +#define TRMM_RTUU STRMM_RTUU +#define TRMM_RTUN STRMM_RTUN +#define TRMM_RTLU STRMM_RTLU +#define TRMM_RTLN STRMM_RTLN +#define TRMM_RRUU STRMM_RNUU +#define TRMM_RRUN STRMM_RNUN +#define TRMM_RRLU STRMM_RNLU +#define TRMM_RRLN STRMM_RNLN +#define TRMM_RCUU STRMM_RTUU +#define TRMM_RCUN STRMM_RTUN +#define TRMM_RCLU STRMM_RTLU +#define TRMM_RCLN STRMM_RTLN + +#define TRSM_LNUU STRSM_LNUU +#define TRSM_LNUN STRSM_LNUN +#define TRSM_LNLU STRSM_LNLU +#define TRSM_LNLN STRSM_LNLN +#define TRSM_LTUU STRSM_LTUU +#define TRSM_LTUN STRSM_LTUN +#define TRSM_LTLU STRSM_LTLU +#define TRSM_LTLN STRSM_LTLN +#define TRSM_LRUU STRSM_LNUU +#define TRSM_LRUN STRSM_LNUN +#define TRSM_LRLU STRSM_LNLU +#define TRSM_LRLN STRSM_LNLN +#define TRSM_LCUU STRSM_LTUU +#define TRSM_LCUN STRSM_LTUN +#define TRSM_LCLU STRSM_LTLU +#define TRSM_LCLN STRSM_LTLN +#define TRSM_RNUU STRSM_RNUU +#define TRSM_RNUN STRSM_RNUN +#define TRSM_RNLU STRSM_RNLU +#define TRSM_RNLN STRSM_RNLN +#define TRSM_RTUU STRSM_RTUU +#define TRSM_RTUN STRSM_RTUN +#define TRSM_RTLU STRSM_RTLU +#define TRSM_RTLN STRSM_RTLN +#define TRSM_RRUU STRSM_RNUU +#define TRSM_RRUN STRSM_RNUN +#define TRSM_RRLU STRSM_RNLU +#define TRSM_RRLN STRSM_RNLN +#define TRSM_RCUU STRSM_RTUU +#define TRSM_RCUN STRSM_RTUN +#define TRSM_RCLU STRSM_RTLU +#define TRSM_RCLN STRSM_RTLN +#define SYRK_UN SSYRK_UN +#define SYRK_UT SSYRK_UT +#define SYRK_LN SSYRK_LN +#define SYRK_LT SSYRK_LT +#define SYRK_UR SSYRK_UN +#define SYRK_UC SSYRK_UT +#define SYRK_LR SSYRK_LN +#define SYRK_LC SSYRK_LT + +#define SYRK_KERNEL_U SSYRK_KERNEL_U +#define SYRK_KERNEL_L SSYRK_KERNEL_L + +#define HERK_UN SSYRK_UN +#define HERK_LN SSYRK_LN +#define HERK_UC SSYRK_UT +#define HERK_LC SSYRK_LT + +#define HER2K_UN SSYR2K_UN +#define HER2K_LN SSYR2K_LN +#define HER2K_UC SSYR2K_UT +#define HER2K_LC SSYR2K_LT + +#define SYR2K_UN SSYR2K_UN +#define SYR2K_UT SSYR2K_UT +#define SYR2K_LN SSYR2K_LN +#define SYR2K_LT SSYR2K_LT +#define SYR2K_UR SSYR2K_UN +#define SYR2K_UC SSYR2K_UT +#define SYR2K_LR SSYR2K_LN +#define SYR2K_LC SSYR2K_LT + +#define SYR2K_KERNEL_U SSYR2K_KERNEL_U +#define SYR2K_KERNEL_L SSYR2K_KERNEL_L +#define SYRK_THREAD_UN SSYRK_THREAD_UN +#define SYRK_THREAD_UT SSYRK_THREAD_UT +#define SYRK_THREAD_LN SSYRK_THREAD_LN +#define SYRK_THREAD_LT SSYRK_THREAD_LT +#define SYRK_THREAD_UR SSYRK_THREAD_UR +#define SYRK_THREAD_UC SSYRK_THREAD_UC +#define SYRK_THREAD_LR SSYRK_THREAD_LN +#define SYRK_THREAD_LC SSYRK_THREAD_LT + +#define HERK_THREAD_UN SSYRK_THREAD_UN +#define HERK_THREAD_UT SSYRK_THREAD_UT +#define HERK_THREAD_LN SSYRK_THREAD_LN +#define HERK_THREAD_LT SSYRK_THREAD_LT +#define HERK_THREAD_UR SSYRK_THREAD_UR +#define HERK_THREAD_UC SSYRK_THREAD_UC +#define HERK_THREAD_LR SSYRK_THREAD_LN +#define HERK_THREAD_LC SSYRK_THREAD_LT + +#endif + #else #define AMAX_K SAMAX_K @@ -721,14 +928,14 @@ #define GEMV_S SGEMV_S #define GEMV_D SGEMV_D + +#define SYMV_U SSYMV_U +#define SYMV_L SSYMV_L #define GERU_K SGERU_K #define GERC_K SGERC_K #define GERV_K SGERV_K #define GERD_K SGERD_K -#define SYMV_U SSYMV_U -#define SYMV_L SSYMV_L - #define SYMV_THREAD_U SSYMV_THREAD_U #define SYMV_THREAD_L SSYMV_THREAD_L diff --git a/ctest/CMakeLists.txt b/ctest/CMakeLists.txt index 14c9d1944..8d301c239 100644 --- a/ctest/CMakeLists.txt +++ b/ctest/CMakeLists.txt @@ -12,6 +12,9 @@ FILE(WRITE ${CMAKE_CURRENT_BINARY_DIR}/test_cblas_helper.sh foreach(float_type ${FLOAT_TYPES}) string(SUBSTRING ${float_type} 0 1 float_char_upper) string(TOLOWER ${float_char_upper} float_char) + if (${float_char} STREQUAL "h") + continue() + endif() #level1 add_executable(x${float_char}cblat1 c_${float_char}blat1.f diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index 35e0fff25..4113a1647 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -41,6 +41,9 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) foreach (float_type ${FLOAT_TYPES}) # a bit of metaprogramming here to pull out the appropriate KERNEL var string(SUBSTRING ${float_type} 0 1 float_char) + if (${float_type} STREQUAL "HALF") + set (float_char "SH") + endif () GenerateNamedObjects("${KERNELDIR}/${${float_char}AMAXKERNEL}" "USE_ABS" "amax_k" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}AMINKERNEL}" "USE_ABS;USE_MIN" "amin_k" false "" "" false ${float_type}) if (DEFINED ${float_char}MAXKERNEL) @@ -93,6 +96,9 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) GenerateNamedObjects("generic/ger.c" "" "ger_k" false "" "" "" 3) foreach (float_type ${FLOAT_TYPES}) string(SUBSTRING ${float_type} 0 1 float_char) + if (${float_type} STREQUAL "HALF") + set (float_char "SH") + endif () if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX") GenerateNamedObjects("${KERNELDIR}/${${float_char}GERUKERNEL}" "" "geru_k" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}GERCKERNEL}" "CONJ" "gerc_k" false "" "" false ${float_type}) @@ -128,13 +134,19 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) set(USE_TRMM true) endif () - foreach (float_type SINGLE DOUBLE) + foreach (float_type SINGLE DOUBLE HALF) string(SUBSTRING ${float_type} 0 1 float_char) + if (${float_type} STREQUAL "HALF") + set (float_char "SH") + endif () GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type}) endforeach() foreach (float_type ${FLOAT_TYPES}) string(SUBSTRING ${float_type} 0 1 float_char) + if (${float_type} STREQUAL "HALF") + set (float_char "SH") + endif () if (${float_char}GEMMINCOPY) GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMINCOPY}" "${float_type}" "${${float_char}GEMMINCOPYOBJ}" false "" "" true ${float_type}) endif () @@ -470,9 +482,13 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) GenerateNamedObjects("${KERNELDIR}/${${float_char}GEADD_KERNEL}" "" "geadd_k" false "" "" false ${float_type}) endforeach () + # Makefile.LA if(NOT NO_LAPACK) foreach (float_type ${FLOAT_TYPES}) + if (${float_type} STREQUAL "HALF") + set (float_char "SH") + endif () if (NOT DEFINED ${float_char}NEG_TCOPY) if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C" OR ${float_char} STREQUAL "X") set(${float_char}NEG_TCOPY ../generic/zneg_tcopy.c) @@ -516,6 +532,9 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) foreach (float_type ${FLOAT_TYPES}) # a bit of metaprogramming here to pull out the appropriate KERNEL var string(SUBSTRING ${float_type} 0 1 float_char) + if (${float_type} STREQUAL "HALF") + set (float_char "SH") + endif () GenerateNamedObjects("generic/neg_tcopy_${${float_char}GEMM_UNROLL_M}.c" "" "neg_tcopy" false "" ${TSUFFIX} false ${float_type}) GenerateNamedObjects("generic/laswp_ncopy_${${float_char}GEMM_UNROLL_N}.c" "" "laswp_ncopy" false "" ${TSUFFIX} false ${float_type}) endforeach () diff --git a/lapack/CMakeLists.txt b/lapack/CMakeLists.txt index e21a9aabb..778e6f8fa 100644 --- a/lapack/CMakeLists.txt +++ b/lapack/CMakeLists.txt @@ -2,6 +2,7 @@ include_directories(${PROJECT_SOURCE_DIR}) include_directories(${PROJECT_BINARY_DIR}) +list (REMOVE_ITEM FLOAT_TYPES "HALF") set(LAPACK_SOURCES potrf/potrf_U_single.c @@ -45,6 +46,9 @@ GenerateNamedObjects("laswp/generic/laswp_k_4.c" "" "laswp_plus" false "" "" fa GenerateNamedObjects("laswp/generic/laswp_k_4.c" "MINUS" "laswp_minus" false "" "" false 3) foreach (float_type ${FLOAT_TYPES}) +if (${float_type} STREQUAL "HALF") + continue() +endif() GenerateNamedObjects("getrf/getrf_single.c" "UNIT" "getrf_single" false "" "" false ${float_type}) endforeach () From 9f6d6f6cb69ba871a887ecc9751fbc2d529e1b98 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Fri, 17 Apr 2020 22:27:58 +0200 Subject: [PATCH 07/15] use saxpy.c instead of axpy.S for SHAXPY --- cmake/kernel.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/kernel.cmake b/cmake/kernel.cmake index 7b64a03fc..c8244d833 100644 --- a/cmake/kernel.cmake +++ b/cmake/kernel.cmake @@ -122,7 +122,7 @@ macro(SetDefaultL1) set(ISHMAXKERNEL ../arm/imax.c) set(ISHMINKERNEL ../arm/imin.c) set(SHASUMKERNEL asum.S) - set(SHAXPYKERNEL axpy.S) + set(SHAXPYKERNEL saxpy.c) set(SHAXPBYKERNEL ../arm/axpby.c) set(SHCOPYKERNEL copy.S) set(SHDOTKERNEL dot.S) From f361de30a363d9f262daa9272525468c3b884e27 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sat, 18 Apr 2020 11:07:16 +0200 Subject: [PATCH 08/15] Use generic axpy.c for SHAXPY as x86 lacks saxpy.c --- cmake/kernel.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/kernel.cmake b/cmake/kernel.cmake index c8244d833..38096ad18 100644 --- a/cmake/kernel.cmake +++ b/cmake/kernel.cmake @@ -122,7 +122,7 @@ macro(SetDefaultL1) set(ISHMAXKERNEL ../arm/imax.c) set(ISHMINKERNEL ../arm/imin.c) set(SHASUMKERNEL asum.S) - set(SHAXPYKERNEL saxpy.c) + set(SHAXPYKERNEL ../arm/axpy.c) set(SHAXPBYKERNEL ../arm/axpby.c) set(SHCOPYKERNEL copy.S) set(SHDOTKERNEL dot.S) From e7afe8a969af29e2f25e3d3349c03c9c912b669e Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sat, 18 Apr 2020 11:10:15 +0200 Subject: [PATCH 09/15] Define AXPBY_K fallback for float16 --- common_macro.h | 1 + 1 file changed, 1 insertion(+) diff --git a/common_macro.h b/common_macro.h index 2166e62a2..95e5b1061 100644 --- a/common_macro.h +++ b/common_macro.h @@ -648,6 +648,7 @@ #define AXPYU_K SAXPYU_K #define AXPYC_K SAXPYC_K +#define AXPBY_K SAXPBY_K #define SCAL_K SSCAL_K #define GEMV_N SGEMV_N #define GEMV_T SGEMV_T From 0a19bd813cad97a5adc8577d1b103afadfbd911c Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sat, 18 Apr 2020 12:52:51 +0200 Subject: [PATCH 10/15] Use generic codes for shamax and shcopy --- cmake/kernel.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/kernel.cmake b/cmake/kernel.cmake index 38096ad18..27d1ad630 100644 --- a/cmake/kernel.cmake +++ b/cmake/kernel.cmake @@ -114,7 +114,7 @@ macro(SetDefaultL1) set(QSUMKERNEL sum.S) set(XSUMKERNEL zsum.S) set(SHAMINKERNEL ../arm/amin.c) - set(SHAMAXKERNEL amax.S) + set(SHAMAXKERNEL ../arm/amax.c) set(SHMAXKERNEL ../arm/max.c) set(SHMINKERNEL ../arm/min.c) set(ISHAMAXKERNEL iamax.S) @@ -124,7 +124,7 @@ macro(SetDefaultL1) set(SHASUMKERNEL asum.S) set(SHAXPYKERNEL ../arm/axpy.c) set(SHAXPBYKERNEL ../arm/axpby.c) - set(SHCOPYKERNEL copy.S) + set(SHCOPYKERNEL ../arm/copy.c) set(SHDOTKERNEL dot.S) set(SHROTKERNEL rot.S) set(SHSCALKERNEL scal.S) From a83a59b0381e719011685cda3081e20aa59eaaee Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sat, 18 Apr 2020 15:53:51 +0200 Subject: [PATCH 11/15] Use generic kernels for ishama,shasum,shdot,shrot --- cmake/kernel.cmake | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cmake/kernel.cmake b/cmake/kernel.cmake index 27d1ad630..f50244e7d 100644 --- a/cmake/kernel.cmake +++ b/cmake/kernel.cmake @@ -117,16 +117,16 @@ macro(SetDefaultL1) set(SHAMAXKERNEL ../arm/amax.c) set(SHMAXKERNEL ../arm/max.c) set(SHMINKERNEL ../arm/min.c) - set(ISHAMAXKERNEL iamax.S) + set(ISHAMAXKERNEL ../arm/iamax.c) set(ISHAMINKERNEL ../arm/iamin.c) set(ISHMAXKERNEL ../arm/imax.c) set(ISHMINKERNEL ../arm/imin.c) - set(SHASUMKERNEL asum.S) + set(SHASUMKERNEL ../arm/asum.c) set(SHAXPYKERNEL ../arm/axpy.c) set(SHAXPBYKERNEL ../arm/axpby.c) set(SHCOPYKERNEL ../arm/copy.c) - set(SHDOTKERNEL dot.S) - set(SHROTKERNEL rot.S) + set(SHDOTKERNEL ../arm/dot.c) + set(SHROTKERNEL ../arm/rot.c) set(SHSCALKERNEL scal.S) set(SHNRM2KERNEL nrm2.S) set(SHSUMKERNEL sum.S) From c7d668c2481303e2fab76d86e9b47fe40b361c22 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sat, 18 Apr 2020 16:04:38 +0200 Subject: [PATCH 12/15] Update common_macro.h --- common_macro.h | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/common_macro.h b/common_macro.h index 95e5b1061..9eff94e8e 100644 --- a/common_macro.h +++ b/common_macro.h @@ -646,6 +646,17 @@ #elif defined(HALF) +#define AMAX_K SAMAX_K +#define AMIN_K SAMIN_K +#define MAX_K SMAX_K +#define MIN_K SMIN_K +#define IAMAX_K ISAMAX_K +#define IAMIN_K ISAMIN_K +#define IMAX_K ISMAX_K +#define IMIN_K ISMIN_K +#define ASUM_K SASUM_K +#define DOTU_K SDOTU_K +#define DOTC_K SDOTC_K #define AXPYU_K SAXPYU_K #define AXPYC_K SAXPYC_K #define AXPBY_K SAXPBY_K @@ -658,6 +669,10 @@ #define GERC_K SGERC_K #define GERV_K SGERV_K #define GERD_K SGERD_K +#define SUM_K SSUM_K +#define SWAP_K SSWAP_K +#define ROT_K SROT_K +#define COPY_K SCOPY_K #define SYMV_THREAD_U SSYMV_THREAD_U #define SYMV_THREAD_L SSYMV_THREAD_L #define GEMM_BETA SHGEMM_BETA From 7dbb59b256d47507fa8a11c03b98857b957e42d1 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sat, 18 Apr 2020 21:34:14 +0200 Subject: [PATCH 13/15] Update common_macro.h --- common_macro.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/common_macro.h b/common_macro.h index 9eff94e8e..8fe1f156f 100644 --- a/common_macro.h +++ b/common_macro.h @@ -673,6 +673,7 @@ #define SWAP_K SSWAP_K #define ROT_K SROT_K #define COPY_K SCOPY_K +#define NRM2_K SNRM2_K #define SYMV_THREAD_U SSYMV_THREAD_U #define SYMV_THREAD_L SSYMV_THREAD_L #define GEMM_BETA SHGEMM_BETA @@ -911,6 +912,17 @@ #define HERK_THREAD_LR SSYRK_THREAD_LN #define HERK_THREAD_LC SSYRK_THREAD_LT +#define OMATCOPY_K_CN SOMATCOPY_K_CN +#define OMATCOPY_K_RN SOMATCOPY_K_RN +#define OMATCOPY_K_CT SOMATCOPY_K_CT +#define OMATCOPY_K_RT SOMATCOPY_K_RT +#define IMATCOPY_K_CN SIMATCOPY_K_CN +#define IMATCOPY_K_RN SIMATCOPY_K_RN +#define IMATCOPY_K_CT SIMATCOPY_K_CT +#define IMATCOPY_K_RT SIMATCOPY_K_RT + +#define GEADD_K SGEADD_K + #endif #else From d0737b014288c2808ab679c0a609a37a5f5be286 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sat, 18 Apr 2020 21:36:28 +0200 Subject: [PATCH 14/15] Update kernel.cmake --- cmake/kernel.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/kernel.cmake b/cmake/kernel.cmake index f50244e7d..19e760c56 100644 --- a/cmake/kernel.cmake +++ b/cmake/kernel.cmake @@ -128,7 +128,7 @@ macro(SetDefaultL1) set(SHDOTKERNEL ../arm/dot.c) set(SHROTKERNEL ../arm/rot.c) set(SHSCALKERNEL scal.S) - set(SHNRM2KERNEL nrm2.S) + set(SHNRM2KERNEL ../arm/nrm2.c) set(SHSUMKERNEL sum.S) set(SHSWAPKERNEL swap.S) endmacro () From 4f70512b978c39237d6e7e17bfeaa336b69f957d Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sun, 19 Apr 2020 08:10:26 +0200 Subject: [PATCH 15/15] Update kernel.cmake --- cmake/kernel.cmake | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cmake/kernel.cmake b/cmake/kernel.cmake index 19e760c56..1c1fed571 100644 --- a/cmake/kernel.cmake +++ b/cmake/kernel.cmake @@ -127,10 +127,10 @@ macro(SetDefaultL1) set(SHCOPYKERNEL ../arm/copy.c) set(SHDOTKERNEL ../arm/dot.c) set(SHROTKERNEL ../arm/rot.c) - set(SHSCALKERNEL scal.S) + set(SHSCALKERNEL ../arm/scal.c) set(SHNRM2KERNEL ../arm/nrm2.c) - set(SHSUMKERNEL sum.S) - set(SHSWAPKERNEL swap.S) + set(SHSUMKERNEL ../arm/sum.c) + set(SHSWAPKERNEL ../arm/swap.c) endmacro () macro(SetDefaultL2)