Merge pull request #2559 from RajalakshmiSR/shgemm

Add half precision gemm for bfloat16 in OpenBLAS
This commit is contained in:
Martin Kroeker 2020-04-19 22:09:55 +02:00 committed by GitHub
commit 8a6d26458b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 1162 additions and 73 deletions

View File

@ -89,6 +89,7 @@ endif ()
# set which float types we want to build for # set which float types we want to build for
if (NOT DEFINED BUILD_SINGLE AND NOT DEFINED BUILD_DOUBLE AND NOT DEFINED BUILD_COMPLEX AND NOT DEFINED BUILD_COMPLEX16) if (NOT DEFINED BUILD_SINGLE AND NOT DEFINED BUILD_DOUBLE AND NOT DEFINED BUILD_COMPLEX AND NOT DEFINED BUILD_COMPLEX16)
# if none are defined, build for all # if none are defined, build for all
set(BUILD_HALF true)
set(BUILD_SINGLE true) set(BUILD_SINGLE true)
set(BUILD_DOUBLE true) set(BUILD_DOUBLE true)
set(BUILD_COMPLEX true) set(BUILD_COMPLEX true)
@ -120,6 +121,11 @@ if (BUILD_COMPLEX16)
list(APPEND FLOAT_TYPES "ZCOMPLEX") # defines COMPLEX and DOUBLE list(APPEND FLOAT_TYPES "ZCOMPLEX") # defines COMPLEX and DOUBLE
endif () endif ()
# Half precision (bfloat16) support is opt-in: add the HALF float type only
# when BUILD_HALF itself is enabled. Requesting BUILD_SINGLE alone must not
# pull in the half precision objects.
if (BUILD_HALF)
  message(STATUS "Building Half Precision")
  list(APPEND FLOAT_TYPES "HALF") # defines nothing
endif ()
if (NOT DEFINED CORE OR "${CORE}" STREQUAL "UNKNOWN") if (NOT DEFINED CORE OR "${CORE}" STREQUAL "UNKNOWN")
message(FATAL_ERROR "Detecting CPU failed. Please set TARGET explicitly, e.g. make TARGET=your_cpu_target. Please read README for details.") message(FATAL_ERROR "Detecting CPU failed. Please set TARGET explicitly, e.g. make TARGET=your_cpu_target. Please read README for details.")
endif () endif ()

View File

@ -1390,6 +1390,8 @@ export FUNCTION_PROFILE
export TARGET_CORE export TARGET_CORE
export NO_AVX512 export NO_AVX512
export SHGEMM_UNROLL_M
export SHGEMM_UNROLL_N
export SGEMM_UNROLL_M export SGEMM_UNROLL_M
export SGEMM_UNROLL_N export SGEMM_UNROLL_N
export DGEMM_UNROLL_M export DGEMM_UNROLL_M

View File

@ -1,3 +1,4 @@
SHBLASOBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
@ -9,8 +10,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))
BLASOBJS = $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) BLASOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
BLASOBJS_P = $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) BLASOBJS_P = $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
ifdef EXPRECISION ifdef EXPRECISION
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
@ -22,6 +23,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P) BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P)
endif endif
$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHALF -UDOUBLE -UCOMPLEX
$(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX $(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX
$(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX $(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX
$(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX $(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX
@ -29,6 +31,7 @@ $(CBLASOBJS) $(CBLASOBJS_P) : override CFLAGS += -UDOUBLE -DCOMPLEX
$(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX $(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)

View File

@ -113,11 +113,29 @@ macro(SetDefaultL1)
set(ZSUMKERNEL zsum.S) set(ZSUMKERNEL zsum.S)
set(QSUMKERNEL sum.S) set(QSUMKERNEL sum.S)
set(XSUMKERNEL zsum.S) set(XSUMKERNEL zsum.S)
set(SHAMINKERNEL ../arm/amin.c)
set(SHAMAXKERNEL ../arm/amax.c)
set(SHMAXKERNEL ../arm/max.c)
set(SHMINKERNEL ../arm/min.c)
set(ISHAMAXKERNEL ../arm/iamax.c)
set(ISHAMINKERNEL ../arm/iamin.c)
set(ISHMAXKERNEL ../arm/imax.c)
set(ISHMINKERNEL ../arm/imin.c)
set(SHASUMKERNEL ../arm/asum.c)
set(SHAXPYKERNEL ../arm/axpy.c)
set(SHAXPBYKERNEL ../arm/axpby.c)
set(SHCOPYKERNEL ../arm/copy.c)
set(SHDOTKERNEL ../arm/dot.c)
set(SHROTKERNEL ../arm/rot.c)
set(SHSCALKERNEL ../arm/scal.c)
set(SHNRM2KERNEL ../arm/nrm2.c)
set(SHSUMKERNEL ../arm/sum.c)
set(SHSWAPKERNEL ../arm/swap.c)
endmacro () endmacro ()
macro(SetDefaultL2) macro(SetDefaultL2)
set(SGEMVNKERNEL gemv_n.S) set(SGEMVNKERNEL ../arm/gemv_n.c)
set(SGEMVTKERNEL gemv_t.S) set(SGEMVTKERNEL ../arm/gemv_t.c)
set(DGEMVNKERNEL gemv_n.S) set(DGEMVNKERNEL gemv_n.S)
set(DGEMVTKERNEL gemv_t.S) set(DGEMVTKERNEL gemv_t.S)
set(CGEMVNKERNEL zgemv_n.S) set(CGEMVNKERNEL zgemv_n.S)
@ -161,6 +179,10 @@ macro(SetDefaultL2)
set(XHEMV_L_KERNEL ../generic/zhemv_k.c) set(XHEMV_L_KERNEL ../generic/zhemv_k.c)
set(XHEMV_V_KERNEL ../generic/zhemv_k.c) set(XHEMV_V_KERNEL ../generic/zhemv_k.c)
set(XHEMV_M_KERNEL ../generic/zhemv_k.c) set(XHEMV_M_KERNEL ../generic/zhemv_k.c)
set(SHGEMVNKERNEL ../arm/gemv_n.c)
set(SHGEMVTKERNEL ../arm/gemv_t.c)
set(SHGERKERNEL ../generic/ger.c)
endmacro () endmacro ()
macro(SetDefaultL3) macro(SetDefaultL3)
@ -168,4 +190,17 @@ macro(SetDefaultL3)
set(DGEADD_KERNEL ../generic/geadd.c) set(DGEADD_KERNEL ../generic/geadd.c)
set(CGEADD_KERNEL ../generic/zgeadd.c) set(CGEADD_KERNEL ../generic/zgeadd.c)
set(ZGEADD_KERNEL ../generic/zgeadd.c) set(ZGEADD_KERNEL ../generic/zgeadd.c)
set(SHGEADD_KERNEL ../generic/geadd.c)
set(SHGEMMKERNEL ../generic/gemmkernel_2x2.c)
set(SHGEMM_BETA ../generic/gemm_beta.c)
set(SHGEMMINCOPY ../generic/gemm_ncopy_2.c)
set(SHGEMMITCOPY ../generic/gemm_tcopy_2.c)
set(SHGEMMONCOPY ../generic/gemm_ncopy_2.c)
set(SHGEMMOTCOPY ../generic/gemm_tcopy_2.c)
set(SHGEMMINCOPYOBJ shgemm_incopy.o)
set(SHGEMMITCOPYOBJ shgemm_itcopy.o)
set(SHGEMMONCOPYOBJ shgemm_oncopy.o)
set(SHGEMMOTCOPYOBJ shgemm_otcopy.o)
endmacro () endmacro ()

View File

@ -16,6 +16,8 @@
# HAVE_SSE2 # HAVE_SSE2
# HAVE_SSE3 # HAVE_SSE3
# MAKE # MAKE
# SHGEMM_UNROLL_M
# SHGEMM_UNROLL_N
# SGEMM_UNROLL_M # SGEMM_UNROLL_M
# SGEMM_UNROLL_N # SGEMM_UNROLL_N
# DGEMM_UNROLL_M # DGEMM_UNROLL_M
@ -437,6 +439,8 @@ if (DEFINED CORE AND CMAKE_CROSSCOMPILING AND NOT (${HOST_OS} STREQUAL "WINDOWSS
set(ZGEMM_UNROLL_N 2) set(ZGEMM_UNROLL_N 2)
set(SYMV_P 8) set(SYMV_P 8)
endif() endif()
set(SHGEMM_UNROLL_M 8)
set(SHGEMM_UNROLL_N 4)
# Or should this actually be NUM_CORES? # Or should this actually be NUM_CORES?
if (${NUM_THREADS} GREATER 0) if (${NUM_THREADS} GREATER 0)

View File

@ -530,6 +530,8 @@ endif ()
#export FUNCTION_PROFILE #export FUNCTION_PROFILE
#export TARGET_CORE #export TARGET_CORE
# #
#export SHGEMM_UNROLL_M
#export SHGEMM_UNROLL_N
#export SGEMM_UNROLL_M #export SGEMM_UNROLL_M
#export SGEMM_UNROLL_N #export SGEMM_UNROLL_N
#export DGEMM_UNROLL_M #export DGEMM_UNROLL_M

View File

@ -163,6 +163,7 @@ function(GenerateNamedObjects sources_in)
if (complex_only) if (complex_only)
list(REMOVE_ITEM float_list "SINGLE") list(REMOVE_ITEM float_list "SINGLE")
list(REMOVE_ITEM float_list "DOUBLE") list(REMOVE_ITEM float_list "DOUBLE")
list(REMOVE_ITEM float_list "HALF")
elseif (real_only) elseif (real_only)
list(REMOVE_ITEM float_list "COMPLEX") list(REMOVE_ITEM float_list "COMPLEX")
list(REMOVE_ITEM float_list "ZCOMPLEX") list(REMOVE_ITEM float_list "ZCOMPLEX")
@ -176,6 +177,9 @@ function(GenerateNamedObjects sources_in)
if (NOT no_float_type) if (NOT no_float_type)
string(SUBSTRING ${float_type} 0 1 float_char) string(SUBSTRING ${float_type} 0 1 float_char)
string(TOLOWER ${float_char} float_char) string(TOLOWER ${float_char} float_char)
# float_char was just set to the lower-cased first letter of float_type;
# for HALF that single letter is replaced by the two-letter prefix "sh".
if (${float_type} STREQUAL "HALF")
set (float_char "sh")
endif ()
endif () endif ()
if (NOT name_in) if (NOT name_in)
@ -210,6 +214,9 @@ function(GenerateNamedObjects sources_in)
if (${float_type} STREQUAL "DOUBLE" OR ${float_type} STREQUAL "ZCOMPLEX") if (${float_type} STREQUAL "DOUBLE" OR ${float_type} STREQUAL "ZCOMPLEX")
list(APPEND obj_defines "DOUBLE") list(APPEND obj_defines "DOUBLE")
endif () endif ()
# Objects generated for the HALF float type get "HALF" appended to their
# list of defines (obj_defines).
if (${float_type} STREQUAL "HALF")
list(APPEND obj_defines "HALF")
endif ()
if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX") if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX")
list(APPEND obj_defines "COMPLEX") list(APPEND obj_defines "COMPLEX")
if (mangle_complex_sources) if (mangle_complex_sources)

View File

@ -257,6 +257,11 @@ typedef long BLASLONG;
typedef unsigned long BLASULONG; typedef unsigned long BLASULONG;
#endif #endif
/* When BFLOAT16 is not defined, bfloat16 is declared as a plain
   unsigned short and HALFCONVERSION is set to 1.
   NOTE(review): HALFCONVERSION presumably signals that bfloat16 <-> float
   conversion has to be done in software -- confirm against its users. */
#ifndef BFLOAT16
typedef unsigned short bfloat16;
#define HALFCONVERSION 1
#endif
#ifdef USE64BITINT #ifdef USE64BITINT
typedef BLASLONG blasint; typedef BLASLONG blasint;
#if defined(OS_WINDOWS) && defined(__64BIT__) #if defined(OS_WINDOWS) && defined(__64BIT__)
@ -297,6 +302,13 @@ typedef int blasint;
#define SIZE 8 #define SIZE 8
#define BASE_SHIFT 3 #define BASE_SHIFT 3
#define ZBASE_SHIFT 4 #define ZBASE_SHIFT 4
#elif defined(HALF)
#define IFLOAT bfloat16
#define XFLOAT IFLOAT
#define FLOAT float
#define SIZE 2
#define BASE_SHIFT 1
#define ZBASE_SHIFT 2
#else #else
#define FLOAT float #define FLOAT float
#define SIZE 4 #define SIZE 4
@ -308,6 +320,10 @@ typedef int blasint;
#define XFLOAT FLOAT #define XFLOAT FLOAT
#endif #endif
/* IFLOAT falls back to FLOAT whenever no earlier branch has defined it.
   In the lines visible here only the HALF branch defines IFLOAT itself
   (as bfloat16). */
#ifndef IFLOAT
#define IFLOAT FLOAT
#endif
#ifndef COMPLEX #ifndef COMPLEX
#define COMPSIZE 1 #define COMPSIZE 1
#else #else

View File

@ -469,6 +469,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint
/* Level 3 routines */ /* Level 3 routines */
void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *, void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
float *, blasint *, float *, blasint *, float *, float *, blasint *); float *, blasint *, float *, blasint *, float *, float *, blasint *);
void BLASFUNC(dgemm)(char *, char *, blasint *, blasint *, blasint *, double *, void BLASFUNC(dgemm)(char *, char *, blasint *, blasint *, blasint *, double *,

View File

@ -55,6 +55,8 @@ extern void sgemm_kernel_direct(BLASLONG M, BLASLONG N, BLASLONG K,
extern int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); extern int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
int dgemm_beta(BLASLONG, BLASLONG, BLASLONG, double, int dgemm_beta(BLASLONG, BLASLONG, BLASLONG, double,
@ -76,6 +78,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *,
xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG); xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
#endif #endif
int shgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int shgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int shgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int shgemm_otcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int sgemm_incopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b); int sgemm_incopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b);
int sgemm_itcopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b); int sgemm_itcopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b);
int sgemm_oncopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b); int sgemm_oncopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b);
@ -499,6 +505,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl
int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG); int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG);
@ -527,6 +534,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float
int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG); int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG);
int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG); int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG);
int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
int sgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
int sgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
@ -619,6 +631,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON
int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG); int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG);
#endif #endif
int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
int sgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
int sgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG); int sgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);

View File

@ -39,6 +39,7 @@
#ifndef COMMON_MACRO #ifndef COMMON_MACRO
#define COMMON_MACRO #define COMMON_MACRO
#include "common_sh.h"
#include "common_s.h" #include "common_s.h"
#include "common_d.h" #include "common_d.h"
#include "common_q.h" #include "common_q.h"
@ -642,6 +643,288 @@
#define IMATCOPY_K_RT DIMATCOPY_K_RT #define IMATCOPY_K_RT DIMATCOPY_K_RT
#define GEADD_K DGEADD_K #define GEADD_K DGEADD_K
#elif defined(HALF)
#define AMAX_K SAMAX_K
#define AMIN_K SAMIN_K
#define MAX_K SMAX_K
#define MIN_K SMIN_K
#define IAMAX_K ISAMAX_K
#define IAMIN_K ISAMIN_K
#define IMAX_K ISMAX_K
#define IMIN_K ISMIN_K
#define ASUM_K SASUM_K
#define DOTU_K SDOTU_K
#define DOTC_K SDOTC_K
#define AXPYU_K SAXPYU_K
#define AXPYC_K SAXPYC_K
#define AXPBY_K SAXPBY_K
#define SCAL_K SSCAL_K
#define GEMV_N SGEMV_N
#define GEMV_T SGEMV_T
#define SYMV_U SSYMV_U
#define SYMV_L SSYMV_L
#define GERU_K SGERU_K
#define GERC_K SGERC_K
#define GERV_K SGERV_K
#define GERD_K SGERD_K
#define SUM_K SSUM_K
#define SWAP_K SSWAP_K
#define ROT_K SROT_K
#define COPY_K SCOPY_K
#define NRM2_K SNRM2_K
#define SYMV_THREAD_U SSYMV_THREAD_U
#define SYMV_THREAD_L SSYMV_THREAD_L
#define GEMM_BETA SHGEMM_BETA
#define GEMM_KERNEL_N SHGEMM_KERNEL
#define GEMM_KERNEL_L SHGEMM_KERNEL
#define GEMM_KERNEL_R SHGEMM_KERNEL
#define GEMM_KERNEL_B SHGEMM_KERNEL
#define GEMM_NN SHGEMM_NN
#define GEMM_CN SHGEMM_TN
#define GEMM_TN SHGEMM_TN
#define GEMM_NC SHGEMM_NT
#define GEMM_NT SHGEMM_NT
#define GEMM_CC SHGEMM_TT
#define GEMM_CT SHGEMM_TT
#define GEMM_TC SHGEMM_TT
#define GEMM_TT SHGEMM_TT
#define GEMM_NR SHGEMM_NN
#define GEMM_TR SHGEMM_TN
#define GEMM_CR SHGEMM_TN
#define GEMM_RN SHGEMM_NN
#define GEMM_RT SHGEMM_NT
#define GEMM_RC SHGEMM_NT
#define GEMM_RR SHGEMM_NN
#define GEMM_ONCOPY SHGEMM_ONCOPY
#define GEMM_OTCOPY SHGEMM_OTCOPY
#define GEMM_INCOPY SHGEMM_INCOPY
#define GEMM_ITCOPY SHGEMM_ITCOPY
#define SYMM_THREAD_LU SSYMM_THREAD_LU
#define SYMM_THREAD_LL SSYMM_THREAD_LL
#define SYMM_THREAD_RU SSYMM_THREAD_RU
#define SYMM_THREAD_RL SSYMM_THREAD_RL
#define SYMM_LU SSYMM_LU
#define SYMM_LL SSYMM_LL
#define SYMM_RU SSYMM_RU
#define SYMM_RL SSYMM_RL
#define HEMM_THREAD_LU SHEMM_THREAD_LU
#define HEMM_THREAD_LL SHEMM_THREAD_LL
#define HEMM_THREAD_RU SHEMM_THREAD_RU
#define HEMM_THREAD_RL SHEMM_THREAD_RL
#define GEMM_THREAD_NN SHGEMM_THREAD_NN
#define GEMM_THREAD_CN SHGEMM_THREAD_TN
#define GEMM_THREAD_TN SHGEMM_THREAD_TN
#define GEMM_THREAD_NC SHGEMM_THREAD_NT
#define GEMM_THREAD_NT SHGEMM_THREAD_NT
#define GEMM_THREAD_CC SHGEMM_THREAD_TT
#define GEMM_THREAD_CT SHGEMM_THREAD_TT
#define GEMM_THREAD_TC SHGEMM_THREAD_TT
#define GEMM_THREAD_TT SHGEMM_THREAD_TT
#define GEMM_THREAD_NR SHGEMM_THREAD_NN
#define GEMM_THREAD_TR SHGEMM_THREAD_TN
#define GEMM_THREAD_CR SHGEMM_THREAD_TN
#define GEMM_THREAD_RN SHGEMM_THREAD_NN
#define GEMM_THREAD_RT SHGEMM_THREAD_NT
#define GEMM_THREAD_RC SHGEMM_THREAD_NT
#define GEMM_THREAD_RR SHGEMM_THREAD_NN
/* Generic TRMM/TRSM copy-routine aliases for the HALF build.
   Only these sixteen copy aliases depend on UNIT: they resolve to the
   *UCOPY names when UNIT is defined and to the *NCOPY names otherwise. */
#ifdef UNIT
#define TRMM_OUNCOPY STRMM_OUNUCOPY
#define TRMM_OUTCOPY STRMM_OUTUCOPY
#define TRMM_OLNCOPY STRMM_OLNUCOPY
#define TRMM_OLTCOPY STRMM_OLTUCOPY
#define TRSM_OUNCOPY STRSM_OUNUCOPY
#define TRSM_OUTCOPY STRSM_OUTUCOPY
#define TRSM_OLNCOPY STRSM_OLNUCOPY
#define TRSM_OLTCOPY STRSM_OLTUCOPY
#define TRMM_IUNCOPY STRMM_IUNUCOPY
#define TRMM_IUTCOPY STRMM_IUTUCOPY
#define TRMM_ILNCOPY STRMM_ILNUCOPY
#define TRMM_ILTCOPY STRMM_ILTUCOPY
#define TRSM_IUNCOPY STRSM_IUNUCOPY
#define TRSM_IUTCOPY STRSM_IUTUCOPY
#define TRSM_ILNCOPY STRSM_ILNUCOPY
#define TRSM_ILTCOPY STRSM_ILTUCOPY
#else
#define TRMM_OUNCOPY STRMM_OUNNCOPY
#define TRMM_OUTCOPY STRMM_OUTNCOPY
#define TRMM_OLNCOPY STRMM_OLNNCOPY
#define TRMM_OLTCOPY STRMM_OLTNCOPY
#define TRSM_OUNCOPY STRSM_OUNNCOPY
#define TRSM_OUTCOPY STRSM_OUTNCOPY
#define TRSM_OLNCOPY STRSM_OLNNCOPY
#define TRSM_OLTCOPY STRSM_OLTNCOPY
#define TRMM_IUNCOPY STRMM_IUNNCOPY
#define TRMM_IUTCOPY STRMM_IUTNCOPY
#define TRMM_ILNCOPY STRMM_ILNNCOPY
#define TRMM_ILTCOPY STRMM_ILTNCOPY
#define TRSM_IUNCOPY STRSM_IUNNCOPY
#define TRSM_IUTCOPY STRSM_IUTNCOPY
#define TRSM_ILNCOPY STRSM_ILNNCOPY
#define TRSM_ILTCOPY STRSM_ILTNCOPY
#endif
/* Everything below is independent of UNIT, so it is defined outside the
   conditional above; otherwise these aliases would be missing whenever a
   file is compiled with UNIT defined. All of them map to S-prefixed names. */
#define TRMM_KERNEL_LN STRMM_KERNEL_LN
#define TRMM_KERNEL_LT STRMM_KERNEL_LT
#define TRMM_KERNEL_LR STRMM_KERNEL_LN
#define TRMM_KERNEL_LC STRMM_KERNEL_LT
#define TRMM_KERNEL_RN STRMM_KERNEL_RN
#define TRMM_KERNEL_RT STRMM_KERNEL_RT
#define TRMM_KERNEL_RR STRMM_KERNEL_RN
#define TRMM_KERNEL_RC STRMM_KERNEL_RT
#define TRSM_KERNEL_LN STRSM_KERNEL_LN
#define TRSM_KERNEL_LT STRSM_KERNEL_LT
#define TRSM_KERNEL_LR STRSM_KERNEL_LN
#define TRSM_KERNEL_LC STRSM_KERNEL_LT
#define TRSM_KERNEL_RN STRSM_KERNEL_RN
#define TRSM_KERNEL_RT STRSM_KERNEL_RT
#define TRSM_KERNEL_RR STRSM_KERNEL_RN
#define TRSM_KERNEL_RC STRSM_KERNEL_RT
#define SYMM_IUTCOPY SSYMM_IUTCOPY
#define SYMM_ILTCOPY SSYMM_ILTCOPY
#define SYMM_OUTCOPY SSYMM_OUTCOPY
#define SYMM_OLTCOPY SSYMM_OLTCOPY
#define TRMM_LNUU STRMM_LNUU
#define TRMM_LNUN STRMM_LNUN
#define TRMM_LNLU STRMM_LNLU
#define TRMM_LNLN STRMM_LNLN
#define TRMM_LTUU STRMM_LTUU
#define TRMM_LTUN STRMM_LTUN
#define TRMM_LTLU STRMM_LTLU
#define TRMM_LTLN STRMM_LTLN
#define TRMM_LRUU STRMM_LNUU
#define TRMM_LRUN STRMM_LNUN
#define TRMM_LRLU STRMM_LNLU
#define TRMM_LRLN STRMM_LNLN
#define TRMM_LCUU STRMM_LTUU
#define TRMM_LCUN STRMM_LTUN
#define TRMM_LCLU STRMM_LTLU
#define TRMM_LCLN STRMM_LTLN
#define TRMM_RNUU STRMM_RNUU
#define TRMM_RNUN STRMM_RNUN
#define TRMM_RNLU STRMM_RNLU
#define TRMM_RNLN STRMM_RNLN
#define TRMM_RTUU STRMM_RTUU
#define TRMM_RTUN STRMM_RTUN
#define TRMM_RTLU STRMM_RTLU
#define TRMM_RTLN STRMM_RTLN
#define TRMM_RRUU STRMM_RNUU
#define TRMM_RRUN STRMM_RNUN
#define TRMM_RRLU STRMM_RNLU
#define TRMM_RRLN STRMM_RNLN
#define TRMM_RCUU STRMM_RTUU
#define TRMM_RCUN STRMM_RTUN
#define TRMM_RCLU STRMM_RTLU
#define TRMM_RCLN STRMM_RTLN
#define TRSM_LNUU STRSM_LNUU
#define TRSM_LNUN STRSM_LNUN
#define TRSM_LNLU STRSM_LNLU
#define TRSM_LNLN STRSM_LNLN
#define TRSM_LTUU STRSM_LTUU
#define TRSM_LTUN STRSM_LTUN
#define TRSM_LTLU STRSM_LTLU
#define TRSM_LTLN STRSM_LTLN
#define TRSM_LRUU STRSM_LNUU
#define TRSM_LRUN STRSM_LNUN
#define TRSM_LRLU STRSM_LNLU
#define TRSM_LRLN STRSM_LNLN
#define TRSM_LCUU STRSM_LTUU
#define TRSM_LCUN STRSM_LTUN
#define TRSM_LCLU STRSM_LTLU
#define TRSM_LCLN STRSM_LTLN
#define TRSM_RNUU STRSM_RNUU
#define TRSM_RNUN STRSM_RNUN
#define TRSM_RNLU STRSM_RNLU
#define TRSM_RNLN STRSM_RNLN
#define TRSM_RTUU STRSM_RTUU
#define TRSM_RTUN STRSM_RTUN
#define TRSM_RTLU STRSM_RTLU
#define TRSM_RTLN STRSM_RTLN
#define TRSM_RRUU STRSM_RNUU
#define TRSM_RRUN STRSM_RNUN
#define TRSM_RRLU STRSM_RNLU
#define TRSM_RRLN STRSM_RNLN
#define TRSM_RCUU STRSM_RTUU
#define TRSM_RCUN STRSM_RTUN
#define TRSM_RCLU STRSM_RTLU
#define TRSM_RCLN STRSM_RTLN
#define SYRK_UN SSYRK_UN
#define SYRK_UT SSYRK_UT
#define SYRK_LN SSYRK_LN
#define SYRK_LT SSYRK_LT
#define SYRK_UR SSYRK_UN
#define SYRK_UC SSYRK_UT
#define SYRK_LR SSYRK_LN
#define SYRK_LC SSYRK_LT
#define SYRK_KERNEL_U SSYRK_KERNEL_U
#define SYRK_KERNEL_L SSYRK_KERNEL_L
#define HERK_UN SSYRK_UN
#define HERK_LN SSYRK_LN
#define HERK_UC SSYRK_UT
#define HERK_LC SSYRK_LT
#define HER2K_UN SSYR2K_UN
#define HER2K_LN SSYR2K_LN
#define HER2K_UC SSYR2K_UT
#define HER2K_LC SSYR2K_LT
#define SYR2K_UN SSYR2K_UN
#define SYR2K_UT SSYR2K_UT
#define SYR2K_LN SSYR2K_LN
#define SYR2K_LT SSYR2K_LT
#define SYR2K_UR SSYR2K_UN
#define SYR2K_UC SSYR2K_UT
#define SYR2K_LR SSYR2K_LN
#define SYR2K_LC SSYR2K_LT
#define SYR2K_KERNEL_U SSYR2K_KERNEL_U
#define SYR2K_KERNEL_L SSYR2K_KERNEL_L
#define SYRK_THREAD_UN SSYRK_THREAD_UN
#define SYRK_THREAD_UT SSYRK_THREAD_UT
#define SYRK_THREAD_LN SSYRK_THREAD_LN
#define SYRK_THREAD_LT SSYRK_THREAD_LT
#define SYRK_THREAD_UR SSYRK_THREAD_UR
#define SYRK_THREAD_UC SSYRK_THREAD_UC
#define SYRK_THREAD_LR SSYRK_THREAD_LN
#define SYRK_THREAD_LC SSYRK_THREAD_LT
#define HERK_THREAD_UN SSYRK_THREAD_UN
#define HERK_THREAD_UT SSYRK_THREAD_UT
#define HERK_THREAD_LN SSYRK_THREAD_LN
#define HERK_THREAD_LT SSYRK_THREAD_LT
#define HERK_THREAD_UR SSYRK_THREAD_UR
#define HERK_THREAD_UC SSYRK_THREAD_UC
#define HERK_THREAD_LR SSYRK_THREAD_LN
#define HERK_THREAD_LC SSYRK_THREAD_LT
#define OMATCOPY_K_CN SOMATCOPY_K_CN
#define OMATCOPY_K_RN SOMATCOPY_K_RN
#define OMATCOPY_K_CT SOMATCOPY_K_CT
#define OMATCOPY_K_RT SOMATCOPY_K_RT
#define IMATCOPY_K_CN SIMATCOPY_K_CN
#define IMATCOPY_K_RN SIMATCOPY_K_RN
#define IMATCOPY_K_CT SIMATCOPY_K_CT
#define IMATCOPY_K_RT SIMATCOPY_K_RT
#define GEADD_K SGEADD_K
#else #else
#define AMAX_K SAMAX_K #define AMAX_K SAMAX_K
@ -673,14 +956,14 @@
#define GEMV_S SGEMV_S #define GEMV_S SGEMV_S
#define GEMV_D SGEMV_D #define GEMV_D SGEMV_D
#define SYMV_U SSYMV_U
#define SYMV_L SSYMV_L
#define GERU_K SGERU_K #define GERU_K SGERU_K
#define GERC_K SGERC_K #define GERC_K SGERC_K
#define GERV_K SGERV_K #define GERV_K SGERV_K
#define GERD_K SGERD_K #define GERD_K SGERD_K
#define SYMV_U SSYMV_U
#define SYMV_L SSYMV_L
#define SYMV_THREAD_U SSYMV_THREAD_U #define SYMV_THREAD_U SSYMV_THREAD_U
#define SYMV_THREAD_L SSYMV_THREAD_L #define SYMV_THREAD_L SSYMV_THREAD_L
@ -2202,6 +2485,9 @@
#if defined(ARCH_X86) || defined(ARCH_X86_64) || defined(ARCH_IA64) || defined(ARCH_MIPS64) || defined(ARCH_ARM64) #if defined(ARCH_X86) || defined(ARCH_X86_64) || defined(ARCH_IA64) || defined(ARCH_MIPS64) || defined(ARCH_ARM64)
extern BLASLONG gemm_offset_a; extern BLASLONG gemm_offset_a;
extern BLASLONG gemm_offset_b; extern BLASLONG gemm_offset_b;
extern BLASLONG shgemm_p;
extern BLASLONG shgemm_q;
extern BLASLONG shgemm_r;
extern BLASLONG sgemm_p; extern BLASLONG sgemm_p;
extern BLASLONG sgemm_q; extern BLASLONG sgemm_q;
extern BLASLONG sgemm_r; extern BLASLONG sgemm_r;

View File

@ -47,6 +47,100 @@ typedef struct {
int dtb_entries; int dtb_entries;
int offsetA, offsetB, align; int offsetA, offsetB, align;
#if 1
int shgemm_p, shgemm_q, shgemm_r;
int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn;
float (*shamax_k) (BLASLONG, float *, BLASLONG);
float (*shamin_k) (BLASLONG, float *, BLASLONG);
float (*shmax_k) (BLASLONG, float *, BLASLONG);
float (*shmin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*ishamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*ishamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*ishmax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG);
float (*shnrm2_k) (BLASLONG, float *, BLASLONG);
float (*shasum_k) (BLASLONG, float *, BLASLONG);
float (*shsum_k) (BLASLONG, float *, BLASLONG);
int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
float (*shdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float);
int (*shaxpy_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*shscal_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*shswap_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*shgemv_n) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*shgemv_t) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*shger_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*shsymv_L) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*shsymv_U) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
int (*shgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
int (*shgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
int (*shgemm_oncopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
int (*shgemm_otcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
int (*shtrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrsm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrsm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrsm_iunucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_iunncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_iutucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_iutncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_ilnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_ilnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_iltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_iltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_ounucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_ounncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_outucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_outncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_olnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_olnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_oltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrmm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrmm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrmm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrmm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrmm_iunucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_iunncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_iutucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_iutncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_ilnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_ilnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_iltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_iltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_ounucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_ounncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_outucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_outncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_olnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_olnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_oltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shsymm_iutcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shsymm_iltcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shsymm_outcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shsymm_oltcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shneg_tcopy) (BLASLONG, BLASLONG, float *, BLASLONG, float *);
int (*shlaswp_ncopy) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG, blasint *, float *);
#endif
int sgemm_p, sgemm_q, sgemm_r; int sgemm_p, sgemm_q, sgemm_r;
int sgemm_unroll_m, sgemm_unroll_n, sgemm_unroll_mn; int sgemm_unroll_m, sgemm_unroll_n, sgemm_unroll_mn;
@ -84,6 +178,7 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*sgemm_incopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_incopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *);
int (*sgemm_itcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_itcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *);
int (*sgemm_oncopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_oncopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *);
@ -907,6 +1002,13 @@ extern gotoblas_t *gotoblas;
#define HAVE_EX_L2 gotoblas -> exclusive_cache #define HAVE_EX_L2 gotoblas -> exclusive_cache
/* Half-precision GEMM (SHGEMM) blocking sizes and unroll factors.
   In this branch every parameter is read at run time through the
   `gotoblas` pointer rather than being a compile-time constant. */
#define SHGEMM_P gotoblas -> shgemm_p
#define SHGEMM_Q gotoblas -> shgemm_q
#define SHGEMM_R gotoblas -> shgemm_r
#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m
#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n
#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn
#define SGEMM_P gotoblas -> sgemm_p #define SGEMM_P gotoblas -> sgemm_p
#define SGEMM_Q gotoblas -> sgemm_q #define SGEMM_Q gotoblas -> sgemm_q
#define SGEMM_R gotoblas -> sgemm_r #define SGEMM_R gotoblas -> sgemm_r
@ -984,6 +1086,17 @@ extern gotoblas_t *gotoblas;
#define HAVE_EX_L2 0 #define HAVE_EX_L2 0
#endif #endif
/* Half-precision GEMM (SHGEMM) parameters for this branch: each one resolves
   directly to its compile-time SHGEMM_DEFAULT_* constant. */
#define SHGEMM_P SHGEMM_DEFAULT_P
#define SHGEMM_Q SHGEMM_DEFAULT_Q
#define SHGEMM_R SHGEMM_DEFAULT_R
#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
/* Use an explicit MN unroll when one is supplied; otherwise fall back to the
   larger of the M and N unroll factors. */
#ifdef SHGEMM_DEFAULT_UNROLL_MN
#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN
#else
#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N))
#endif
#define SGEMM_P SGEMM_DEFAULT_P #define SGEMM_P SGEMM_DEFAULT_P
#define SGEMM_Q SGEMM_DEFAULT_Q #define SGEMM_Q SGEMM_DEFAULT_Q
#define SGEMM_R SGEMM_DEFAULT_R #define SGEMM_R SGEMM_DEFAULT_R
@ -1119,6 +1232,18 @@ extern gotoblas_t *gotoblas;
#define GEMM_DEFAULT_R DGEMM_DEFAULT_R #define GEMM_DEFAULT_R DGEMM_DEFAULT_R
#define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M #define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M
#define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N #define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N
#elif defined(HALF)
/* Built with -DHALF: the type-generic GEMM_* names used by the shared driver
   code map onto the SHGEMM_* parameters. */
#define GEMM_P SHGEMM_P
#define GEMM_Q SHGEMM_Q
#define GEMM_R SHGEMM_R
#define GEMM_UNROLL_M SHGEMM_UNROLL_M
#define GEMM_UNROLL_N SHGEMM_UNROLL_N
#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN
/* The GEMM_DEFAULT_* names map onto the SHGEMM defaults in the same way. */
#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P
#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q
#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R
#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
#else #else
#define GEMM_P SGEMM_P #define GEMM_P SGEMM_P
#define GEMM_Q SGEMM_Q #define GEMM_Q SGEMM_Q
@ -1204,6 +1329,10 @@ extern gotoblas_t *gotoblas;
#define GEMM_THREAD gemm_thread_n #define GEMM_THREAD gemm_thread_n
#endif #endif
/* Fallback SHGEMM_DEFAULT_R, used only when no value was defined earlier.
   Computation: take BUFFER_SIZE, subtract the (P * Q * 4 + offset) term
   rounded with GEMM_DEFAULT_ALIGN as a mask, divide by (Q * 4), subtract 15
   and clear the low four bits so the result is a multiple of 16.
   NOTE(review): the factor 4 is the same one the SGEMM_DEFAULT_R formula
   uses; confirm it is the intended element size for half-precision inputs. */
#ifndef SHGEMM_DEFAULT_R
#define SHGEMM_DEFAULT_R (((BUFFER_SIZE - ((SHGEMM_DEFAULT_P * SHGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SHGEMM_DEFAULT_Q * 4) - 15) & ~15)
#endif
#ifndef SGEMM_DEFAULT_R #ifndef SGEMM_DEFAULT_R
#define SGEMM_DEFAULT_R (((BUFFER_SIZE - ((SGEMM_DEFAULT_P * SGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SGEMM_DEFAULT_Q * 4) - 15) & ~15) #define SGEMM_DEFAULT_R (((BUFFER_SIZE - ((SGEMM_DEFAULT_P * SGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SGEMM_DEFAULT_Q * 4) - 15) & ~15)
#endif #endif

65
common_sh.h Normal file
View File

@ -0,0 +1,65 @@
/* common_sh.h - name mapping for the half-precision GEMM (SHGEMM) routines.
   Maps the upper-case SHGEMM_* names onto either direct symbols or members
   of the run-time `gotoblas` table, depending on DYNAMIC_ARCH. */
#ifndef COMMON_SH_H
#define COMMON_SH_H
#ifndef DYNAMIC_ARCH
/* Without DYNAMIC_ARCH the copy/beta/kernel names refer to symbols directly. */
#define SHGEMM_ONCOPY shgemm_oncopy
#define SHGEMM_OTCOPY shgemm_otcopy
/* When the M and N unroll factors are equal, the inner-copy names are aliased
   to the outer-copy routines instead of separate incopy/itcopy symbols. */
#if SHGEMM_DEFAULT_UNROLL_M == SHGEMM_DEFAULT_UNROLL_N
#define SHGEMM_INCOPY shgemm_oncopy
#define SHGEMM_ITCOPY shgemm_otcopy
#else
#define SHGEMM_INCOPY shgemm_incopy
#define SHGEMM_ITCOPY shgemm_itcopy
#endif
#define SHGEMM_BETA shgemm_beta
#define SHGEMM_KERNEL shgemm_kernel
#else
/* With DYNAMIC_ARCH the same names are looked up through the `gotoblas`
   pointer at run time. */
#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy
#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy
#define SHGEMM_INCOPY gotoblas -> shgemm_incopy
#define SHGEMM_ITCOPY gotoblas -> shgemm_itcopy
#define SHGEMM_BETA gotoblas -> shgemm_beta
#define SHGEMM_KERNEL gotoblas -> shgemm_kernel
#endif
/* Driver entry points for all 16 (transA, transB) combinations. Only four
   routines exist (nn, nt, tn, tt): every "C" variant is mapped to the
   corresponding "T" routine and every "R" variant to the "N" routine. */
#define SHGEMM_NN shgemm_nn
#define SHGEMM_CN shgemm_tn
#define SHGEMM_TN shgemm_tn
#define SHGEMM_NC shgemm_nt
#define SHGEMM_NT shgemm_nt
#define SHGEMM_CC shgemm_tt
#define SHGEMM_CT shgemm_tt
#define SHGEMM_TC shgemm_tt
#define SHGEMM_TT shgemm_tt
#define SHGEMM_NR shgemm_nn
#define SHGEMM_TR shgemm_tn
#define SHGEMM_CR shgemm_tn
#define SHGEMM_RN shgemm_nn
#define SHGEMM_RT shgemm_nt
#define SHGEMM_RC shgemm_nt
#define SHGEMM_RR shgemm_nn
/* The same 16-to-4 mapping for the shgemm_thread_* entry points. */
#define SHGEMM_THREAD_NN shgemm_thread_nn
#define SHGEMM_THREAD_CN shgemm_thread_tn
#define SHGEMM_THREAD_TN shgemm_thread_tn
#define SHGEMM_THREAD_NC shgemm_thread_nt
#define SHGEMM_THREAD_NT shgemm_thread_nt
#define SHGEMM_THREAD_CC shgemm_thread_tt
#define SHGEMM_THREAD_CT shgemm_thread_tt
#define SHGEMM_THREAD_TC shgemm_thread_tt
#define SHGEMM_THREAD_TT shgemm_thread_tt
#define SHGEMM_THREAD_NR shgemm_thread_nn
#define SHGEMM_THREAD_TR shgemm_thread_tn
#define SHGEMM_THREAD_CR shgemm_thread_tn
#define SHGEMM_THREAD_RN shgemm_thread_nn
#define SHGEMM_THREAD_RT shgemm_thread_nt
#define SHGEMM_THREAD_RC shgemm_thread_nt
#define SHGEMM_THREAD_RR shgemm_thread_nn
#endif

View File

@ -12,6 +12,9 @@ FILE(WRITE ${CMAKE_CURRENT_BINARY_DIR}/test_cblas_helper.sh
foreach(float_type ${FLOAT_TYPES}) foreach(float_type ${FLOAT_TYPES})
string(SUBSTRING ${float_type} 0 1 float_char_upper) string(SUBSTRING ${float_type} 0 1 float_char_upper)
string(TOLOWER ${float_char_upper} float_char) string(TOLOWER ${float_char_upper} float_char)
# Skip the HALF entry ("h" after lower-casing its first letter): the targets
# below are built from c_${float_char}blat*.f sources, and presumably no
# c_hblat* test sources exist -- TODO confirm.
# The expansion is quoted so an empty or list-valued float_char cannot break
# the if() argument list and the comparison stays a plain string comparison.
if ("${float_char}" STREQUAL "h")
  continue()
endif()
#level1 #level1
add_executable(x${float_char}cblat1 add_executable(x${float_char}cblat1
c_${float_char}blat1.f c_${float_char}blat1.f

View File

@ -19,6 +19,7 @@ ifeq ($(ARCH), MIPS)
USE_GEMM3M = 1 USE_GEMM3M = 1
endif endif
# Half-precision GEMM drivers: one object per transpose combination.
SHBLASOBJS += shgemm_nn.$(SUFFIX) shgemm_nt.$(SUFFIX) shgemm_tn.$(SUFFIX) shgemm_tt.$(SUFFIX)
SBLASOBJS += \ SBLASOBJS += \
sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \ sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \
strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \ strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \
@ -204,6 +205,7 @@ COMMONOBJS += syrk_thread.$(SUFFIX)
ifndef USE_SIMPLE_THREADED_LEVEL3 ifndef USE_SIMPLE_THREADED_LEVEL3
# Threaded half-precision GEMM drivers, one per transpose combination.
SHBLASOBJS += shgemm_thread_nn.$(SUFFIX) shgemm_thread_nt.$(SUFFIX) shgemm_thread_tn.$(SUFFIX) shgemm_thread_tt.$(SUFFIX)
SBLASOBJS += sgemm_thread_nn.$(SUFFIX) sgemm_thread_nt.$(SUFFIX) sgemm_thread_tn.$(SUFFIX) sgemm_thread_tt.$(SUFFIX) SBLASOBJS += sgemm_thread_nn.$(SUFFIX) sgemm_thread_nt.$(SUFFIX) sgemm_thread_tn.$(SUFFIX) sgemm_thread_tt.$(SUFFIX)
DBLASOBJS += dgemm_thread_nn.$(SUFFIX) dgemm_thread_nt.$(SUFFIX) dgemm_thread_tn.$(SUFFIX) dgemm_thread_tt.$(SUFFIX) DBLASOBJS += dgemm_thread_nn.$(SUFFIX) dgemm_thread_nt.$(SUFFIX) dgemm_thread_tn.$(SUFFIX) dgemm_thread_tt.$(SUFFIX)
QBLASOBJS += qgemm_thread_nn.$(SUFFIX) qgemm_thread_nt.$(SUFFIX) qgemm_thread_tn.$(SUFFIX) qgemm_thread_tt.$(SUFFIX) QBLASOBJS += qgemm_thread_nn.$(SUFFIX) qgemm_thread_nt.$(SUFFIX) qgemm_thread_tn.$(SUFFIX) qgemm_thread_tt.$(SUFFIX)
@ -283,6 +285,18 @@ endif
all :: all ::
# Half-precision GEMM drivers (NN, NT, TN, TT). Each object is compiled from
# gemm.c ($< is the first prerequisite) with -DHALF and the matching
# -DNN/-DNT/-DTN/-DTT selector; level3.c and param.h are listed only as
# prerequisites so a change to either triggers a rebuild.
shgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
shgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
shgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
shgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
sgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h sgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
@ -478,6 +492,17 @@ gemm_thread_variable.$(SUFFIX) : gemm_thread_variable.c ../../common.h
beta_thread.$(SUFFIX) : beta_thread.c ../../common.h beta_thread.$(SUFFIX) : beta_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F) $(CC) -c $(CFLAGS) $< -o $(@F)
# Threaded half-precision GEMM drivers (NN, NT, TN, TT). Same source file as
# the non-threaded rules (gemm.c via $<), compiled with -DTHREADED_LEVEL3 in
# addition to -DHALF and the transpose selector; level3_thread.c and param.h
# are rebuild-trigger prerequisites.
shgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
shgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
shgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
shgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
sgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h sgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
@ -2652,6 +2677,18 @@ xtrsm_RCLU.$(SUFFIX) : trsm_R.c
xtrsm_RCLN.$(SUFFIX) : trsm_R.c xtrsm_RCLN.$(SUFFIX) : trsm_R.c
$(CC) -c $(CFLAGS) -DCOMPLEX -DXDOUBLE -DTRANSA -UUPPER -UUNIT -DCONJ $< -o $(@F) $(CC) -c $(CFLAGS) -DCOMPLEX -DXDOUBLE -DTRANSA -UUPPER -UUNIT -DCONJ $< -o $(@F)
# $(PSUFFIX) variants of the half-precision GEMM drivers: identical to the
# $(SUFFIX) rules except that they are compiled with $(PFLAGS) instead of
# $(CFLAGS).
shgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
shgemm_nt.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
shgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
shgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
sgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h sgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) $(CC) $(PFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
@ -2848,6 +2885,18 @@ beta_thread.$(PSUFFIX) : beta_thread.c ../../common.h
$(CC) -c $(PFLAGS) $< -o $(@F) $(CC) -c $(PFLAGS) $< -o $(@F)
# $(PSUFFIX) variants of the threaded half-precision GEMM drivers: compiled
# with $(PFLAGS) instead of $(CFLAGS), otherwise the same flags as the
# $(SUFFIX) rules (-DTHREADED_LEVEL3 -DHALF plus the transpose selector).
shgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
shgemm_thread_nt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
shgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
shgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
sgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h sgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

View File

@ -62,18 +62,18 @@
#ifndef ICOPY_OPERATION #ifndef ICOPY_OPERATION
#if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \ #if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \
defined(RN) || defined(RT) || defined(RC) || defined(RR) defined(RN) || defined(RT) || defined(RC) || defined(RR)
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); #define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
#else #else
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); #define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
#endif #endif
#endif #endif
#ifndef OCOPY_OPERATION #ifndef OCOPY_OPERATION
#if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \ #if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \
defined(NR) || defined(TR) || defined(CR) || defined(RR) defined(NR) || defined(TR) || defined(CR) || defined(RR)
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); #define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
#else #else
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); #define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
#endif #endif
#endif #endif
@ -173,7 +173,8 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){ XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){
BLASLONG k, lda, ldb, ldc; BLASLONG k, lda, ldb, ldc;
FLOAT *alpha, *beta; FLOAT *alpha, *beta;
FLOAT *a, *b, *c; IFLOAT *a, *b;
FLOAT *c;
BLASLONG m_from, m_to, n_from, n_to; BLASLONG m_from, m_to, n_from, n_to;
BLASLONG ls, is, js; BLASLONG ls, is, js;
@ -198,8 +199,8 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
k = K; k = K;
a = (FLOAT *)A; a = (IFLOAT *)A;
b = (FLOAT *)B; b = (IFLOAT *)B;
c = (FLOAT *)C; c = (FLOAT *)C;
lda = LDA; lda = LDA;

View File

@ -117,18 +117,18 @@ typedef struct {
#ifndef ICOPY_OPERATION #ifndef ICOPY_OPERATION
#if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \ #if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \
defined(RN) || defined(RT) || defined(RC) || defined(RR) defined(RN) || defined(RT) || defined(RC) || defined(RR)
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); #define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
#else #else
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); #define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
#endif #endif
#endif #endif
#ifndef OCOPY_OPERATION #ifndef OCOPY_OPERATION
#if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \ #if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \
defined(NR) || defined(TR) || defined(CR) || defined(RR) defined(NR) || defined(TR) || defined(CR) || defined(RR)
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER); #define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
#else #else
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER); #define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
#endif #endif
#endif #endif
@ -219,15 +219,16 @@ typedef struct {
#define STOP_RPCC(COUNTER) #define STOP_RPCC(COUNTER)
#endif #endif
static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){
FLOAT *buffer[DIVIDE_RATE]; IFLOAT *buffer[DIVIDE_RATE];
BLASLONG k, lda, ldb, ldc; BLASLONG k, lda, ldb, ldc;
BLASLONG m_from, m_to, n_from, n_to; BLASLONG m_from, m_to, n_from, n_to;
FLOAT *alpha, *beta; FLOAT *alpha, *beta;
FLOAT *a, *b, *c; IFLOAT *a, *b;
FLOAT *c;
job_t *job = (job_t *)args -> common; job_t *job = (job_t *)args -> common;
BLASLONG nthreads_m; BLASLONG nthreads_m;
@ -255,8 +256,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
k = K; k = K;
a = (FLOAT *)A; a = (IFLOAT *)A;
b = (FLOAT *)B; b = (IFLOAT *)B;
c = (FLOAT *)C; c = (FLOAT *)C;
lda = LDA; lda = LDA;
@ -425,7 +426,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
/* Apply kernel with local region of A and part of other region of B */ /* Apply kernel with local region of A and part of other region of B */
START_RPCC(); START_RPCC();
KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - js, div_n), min_l, alpha, KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - js, div_n), min_l, alpha,
sa, (FLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside], sa, (IFLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
c, ldc, m_from, js); c, ldc, m_from, js);
STOP_RPCC(kernel); STOP_RPCC(kernel);
@ -469,7 +470,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
/* Apply kernel with local region of A and part of region of B */ /* Apply kernel with local region of A and part of region of B */
START_RPCC(); START_RPCC();
KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - js, div_n), min_l, alpha, KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - js, div_n), min_l, alpha,
sa, (FLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside], sa, (IFLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
c, ldc, is, js); c, ldc, is, js);
STOP_RPCC(kernel); STOP_RPCC(kernel);
@ -532,7 +533,7 @@ static int round_up(int remainder, int width, int multiple)
static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
*range_n, FLOAT *sa, FLOAT *sb, *range_n, IFLOAT *sa, IFLOAT *sb,
BLASLONG nthreads_m, BLASLONG nthreads_n) { BLASLONG nthreads_m, BLASLONG nthreads_n) {
#ifndef USE_OPENMP #ifndef USE_OPENMP
@ -728,7 +729,7 @@ EnterCriticalSection((PCRITICAL_SECTION)&level3_lock);
return 0; return 0;
} }
int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){
BLASLONG m = args -> m; BLASLONG m = args -> m;
BLASLONG n = args -> n; BLASLONG n = args -> n;

View File

@ -62,6 +62,11 @@ BLASLONG gemm_offset_b = DEFAULT_GEMM_OFFSET_B;
BLASLONG gemm_offset_b = GEMM_OFFSET_B; BLASLONG gemm_offset_b = GEMM_OFFSET_B;
#endif #endif
/* Initial value of the run-time variable shgemm_p.
   In a #if expression an identifier that is not a macro evaluates to 0, so
   the test is true when SHGEMM_P expands to the plain identifier shgemm_p
   (presumably the case where the parameter is this variable itself); the
   variable then starts at DEFAULT_GEMM_P. Otherwise SHGEMM_P is used as the
   initializer. */
#if SHGEMM_P == shgemm_p
BLASLONG shgemm_p = DEFAULT_GEMM_P;
#else
BLASLONG shgemm_p = SHGEMM_P;
#endif
#if SGEMM_P == sgemm_p #if SGEMM_P == sgemm_p
BLASLONG sgemm_p = DEFAULT_GEMM_P; BLASLONG sgemm_p = DEFAULT_GEMM_P;
#else #else
@ -83,6 +88,11 @@ BLASLONG zgemm_p = DEFAULT_GEMM_P;
BLASLONG zgemm_p = ZGEMM_P; BLASLONG zgemm_p = ZGEMM_P;
#endif #endif
/* Initial value of the run-time variable shgemm_q.
   A non-macro identifier evaluates to 0 inside #if, so the test is true when
   SHGEMM_Q expands to the plain identifier shgemm_q; the variable then starts
   at DEFAULT_GEMM_Q. Otherwise SHGEMM_Q is used as the initializer. */
#if SHGEMM_Q == shgemm_q
BLASLONG shgemm_q = DEFAULT_GEMM_Q;
#else
BLASLONG shgemm_q = SHGEMM_Q;
#endif
#if SGEMM_Q == sgemm_q #if SGEMM_Q == sgemm_q
BLASLONG sgemm_q = DEFAULT_GEMM_Q; BLASLONG sgemm_q = DEFAULT_GEMM_Q;
#else #else
@ -104,6 +114,11 @@ BLASLONG zgemm_q = DEFAULT_GEMM_Q;
BLASLONG zgemm_q = ZGEMM_Q; BLASLONG zgemm_q = ZGEMM_Q;
#endif #endif
/* Initial value of the run-time variable shgemm_r.
   A non-macro identifier evaluates to 0 inside #if, so the test is true when
   SHGEMM_R expands to the plain identifier shgemm_r; the variable then starts
   at DEFAULT_GEMM_R. Otherwise SHGEMM_R is used as the initializer. */
#if SHGEMM_R == shgemm_r
BLASLONG shgemm_r = DEFAULT_GEMM_R;
#else
BLASLONG shgemm_r = SHGEMM_R;
#endif
#if SGEMM_R == sgemm_r #if SGEMM_R == sgemm_r
BLASLONG sgemm_r = DEFAULT_GEMM_R; BLASLONG sgemm_r = DEFAULT_GEMM_R;
#else #else
@ -597,6 +612,7 @@ void blas_set_parameter(void){
size = BITMASK(cpuid3, 16, 0xff); size = BITMASK(cpuid3, 16, 0xff);
shgemm_p = 192 * (size + 1);
sgemm_p = 192 * (size + 1); sgemm_p = 192 * (size + 1);
dgemm_p = 96 * (size + 1); dgemm_p = 96 * (size + 1);
cgemm_p = 96 * (size + 1); cgemm_p = 96 * (size + 1);
@ -610,6 +626,7 @@ void blas_set_parameter(void){
xgemm_p = 16 * (size + 1); xgemm_p = 16 * (size + 1);
#endif #endif
shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15;
sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15; sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15;
dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15; dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15;
cgemm_r = (((BUFFER_SIZE - ((CGEMM_P * CGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (CGEMM_Q * 8)) - 15) & ~15; cgemm_r = (((BUFFER_SIZE - ((CGEMM_P * CGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (CGEMM_Q * 8)) - 15) & ~15;

View File

@ -30,7 +30,7 @@
icamax,icamin,idamax,idamin,idmax,idmin,isamax,isamin,ismax,ismin, icamax,icamin,idamax,idamin,idmax,idmin,isamax,isamin,ismax,ismin,
izamax,izamin,lsame,samax,samin,sasum,saxpy,scabs1,scamax, izamax,izamin,lsame,samax,samin,sasum,saxpy,scabs1,scamax,
scamin,scasum,scnrm2,scopy,sdot,sdsdot,sgbmv,sgemm,sgemv,sger, scamin,scasum,scnrm2,scopy,sdot,sdsdot,sgbmv,sgemm,sgemv,sger,
smax,smin,snrm2, shgemm, smax,smin,snrm2,
srot,srotg,srotm,srotmg,ssbmv,sscal,sspmv,sspr2,sspr,sswap, srot,srotg,srotm,srotmg,ssbmv,sscal,sspmv,sspr2,sspr,sswap,
ssymm,ssymv,ssyr2,ssyr2k,ssyr,ssyrk,stbmv,stbsv,stpmv,stpsv, ssymm,ssymv,ssyr2,ssyr2k,ssyr,ssyrk,stbmv,stbsv,stpmv,stpsv,
strmm,strmv,strsm,strsv,zaxpy,zcopy,zdotc,zdotu,zdrot, strmm,strmv,strsm,strsv,zaxpy,zcopy,zdotc,zdotu,zdrot,
@ -67,7 +67,7 @@
cblas_isamax, cblas_izamax, cblas_isamax, cblas_izamax,
cblas_sasum, cblas_saxpy, cblas_sasum, cblas_saxpy,
cblas_scasum, cblas_scnrm2, cblas_scopy, cblas_sdot, cblas_sdsdot, cblas_sgbmv, cblas_sgemm, cblas_scasum, cblas_scnrm2, cblas_scopy, cblas_sdot, cblas_sdsdot, cblas_sgbmv, cblas_sgemm,
cblas_sgemv, cblas_sger, cblas_snrm2, cblas_srot, cblas_srotg, cblas_sgemv, cblas_sger, cblas_shgemm, cblas_snrm2, cblas_srot, cblas_srotg,
cblas_srotm, cblas_srotmg, cblas_ssbmv, cblas_sscal, cblas_sspmv, cblas_sspr2, cblas_sspr, cblas_srotm, cblas_srotmg, cblas_ssbmv, cblas_sscal, cblas_sspmv, cblas_sspr2, cblas_sspr,
cblas_sswap, cblas_ssymm, cblas_ssymv, cblas_ssyr2, cblas_ssyr2k, cblas_ssyr, cblas_ssyrk, cblas_sswap, cblas_ssymm, cblas_ssymv, cblas_ssyr2, cblas_ssyr2k, cblas_ssyr, cblas_ssyrk,
cblas_stbmv, cblas_stbsv, cblas_stpmv, cblas_stpsv, cblas_strmm, cblas_strmv, cblas_strsm, cblas_stbmv, cblas_stbsv, cblas_stpmv, cblas_stpsv, cblas_strmm, cblas_strmv, cblas_strsm,

View File

@ -9,6 +9,8 @@
int main(int argc, char **argv) { int main(int argc, char **argv) {
if ( (argc <= 1) || ((argc >= 2) && (*argv[1] == '0'))) { if ( (argc <= 1) || ((argc >= 2) && (*argv[1] == '0'))) {
printf("SHGEMM_UNROLL_M=%d\n", SHGEMM_DEFAULT_UNROLL_M);
printf("SHGEMM_UNROLL_N=%d\n", SHGEMM_DEFAULT_UNROLL_N);
printf("SGEMM_UNROLL_M=%d\n", SGEMM_DEFAULT_UNROLL_M); printf("SGEMM_UNROLL_M=%d\n", SGEMM_DEFAULT_UNROLL_M);
printf("SGEMM_UNROLL_N=%d\n", SGEMM_DEFAULT_UNROLL_N); printf("SGEMM_UNROLL_N=%d\n", SGEMM_DEFAULT_UNROLL_N);
printf("DGEMM_UNROLL_M=%d\n", DGEMM_DEFAULT_UNROLL_M); printf("DGEMM_UNROLL_M=%d\n", DGEMM_DEFAULT_UNROLL_M);

View File

@ -46,6 +46,7 @@ SBLAS3OBJS = \
somatcopy.$(SUFFIX) simatcopy.$(SUFFIX)\ somatcopy.$(SUFFIX) simatcopy.$(SUFFIX)\
sgeadd.$(SUFFIX) sgeadd.$(SUFFIX)
# Half-precision Level 3 interface objects: shgemm is the only entry.
SHBLAS3OBJS = shgemm.$(SUFFIX)
DBLAS1OBJS = \ DBLAS1OBJS = \
daxpy.$(SUFFIX) dswap.$(SUFFIX) \ daxpy.$(SUFFIX) dswap.$(SUFFIX) \
@ -277,6 +278,8 @@ CSBLAS3OBJS = \
cblas_ssyrk.$(SUFFIX) cblas_ssyr2k.$(SUFFIX) cblas_somatcopy.$(SUFFIX) cblas_simatcopy.$(SUFFIX)\ cblas_ssyrk.$(SUFFIX) cblas_ssyr2k.$(SUFFIX) cblas_somatcopy.$(SUFFIX) cblas_simatcopy.$(SUFFIX)\
cblas_sgeadd.$(SUFFIX) cblas_sgeadd.$(SUFFIX)
# CBLAS counterpart of the half-precision Level 3 objects.
CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX)
CDBLAS1OBJS = \ CDBLAS1OBJS = \
cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \ cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \
cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \ cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \
@ -367,6 +370,7 @@ override CFLAGS += -I.
SBLAS1OBJS += $(CSBLAS1OBJS) SBLAS1OBJS += $(CSBLAS1OBJS)
SBLAS2OBJS += $(CSBLAS2OBJS) SBLAS2OBJS += $(CSBLAS2OBJS)
SBLAS3OBJS += $(CSBLAS3OBJS) SBLAS3OBJS += $(CSBLAS3OBJS)
# Add the CBLAS half-precision objects to the half-precision Level 3 list.
SHBLAS3OBJS += $(CSHBLAS3OBJS)
DBLAS1OBJS += $(CDBLAS1OBJS) DBLAS1OBJS += $(CDBLAS1OBJS)
DBLAS2OBJS += $(CDBLAS2OBJS) DBLAS2OBJS += $(CDBLAS2OBJS)
DBLAS3OBJS += $(CDBLAS3OBJS) DBLAS3OBJS += $(CDBLAS3OBJS)
@ -380,6 +384,7 @@ ZBLAS3OBJS += $(CZBLAS3OBJS)
endif endif
SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS)
# Half precision has only Level 3 objects, so the full list equals that set.
SHBLASOBJS = $(SHBLAS3OBJS)
DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS)
QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS) QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS)
CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS) CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS)
@ -454,7 +459,7 @@ ZBLASOBJS += $(ZLAPACKOBJS)
endif endif
FUNCOBJS = $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) FUNCOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
ifdef EXPRECISION ifdef EXPRECISION
FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS) FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS)
@ -488,10 +493,10 @@ level1 : $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $
level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
level3 : $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) level3 : $(SHBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
$(CSBLASOBJS) $(CSBLASOBJS_P) $(CDBLASOBJS) $(CDBLASOBJS_P) $(CQBLASOBJS) $(CQBLASOBJS_P) \ $(CSHBLASOBJS) $(CSHBLASOBJS_P) $(CSBLASOBJS) $(CSBLASOBJS_P) $(CDBLASOBJS) $(CDBLASOBJS_P) $(CQBLASOBJS) $(CQBLASOBJS_P) \
$(CCBLASOBJS) $(CCBLASOBJS_P) $(CZBLASOBJS) $(CZBLASOBJS_P) $(CXBLASOBJS) $(CXBLASOBJS_P) : override CFLAGS += -DCBLAS $(CCBLASOBJS) $(CCBLASOBJS_P) $(CZBLASOBJS) $(CZBLASOBJS_P) $(CXBLASOBJS) $(CXBLASOBJS_P) : override CFLAGS += -DCBLAS
srot.$(SUFFIX) srot.$(PSUFFIX) : rot.c srot.$(SUFFIX) srot.$(PSUFFIX) : rot.c
@ -1209,6 +1214,9 @@ zhpr2.$(SUFFIX) zhpr2.$(PSUFFIX) : zhpr2.c
xhpr2.$(SUFFIX) xhpr2.$(PSUFFIX) : zhpr2.c xhpr2.$(SUFFIX) xhpr2.$(PSUFFIX) : zhpr2.c
$(CC) -c $(CFLAGS) $< -o $(@F) $(CC) -c $(CFLAGS) $< -o $(@F)
# Half-precision GEMM interface object, built from the shared gemm.c ($<).
# The recipe adds no type-selection flags of its own, so any -DHALF must
# arrive through $(CFLAGS) -- NOTE(review): confirm where that is set.
shgemm.$(SUFFIX) shgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -c $(CFLAGS) $< -o $(@F)
sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -c $(CFLAGS) $< -o $(@F) $(CC) -c $(CFLAGS) $< -o $(@F)
@ -1770,6 +1778,9 @@ cblas_zhemv.$(SUFFIX) cblas_zhemv.$(PSUFFIX) : zhemv.c
cblas_sgemm.$(SUFFIX) cblas_sgemm.$(PSUFFIX) : gemm.c ../param.h cblas_sgemm.$(SUFFIX) cblas_sgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
# CBLAS half-precision GEMM interface object: same gemm.c source ($<),
# compiled with -DCBLAS in addition to $(CFLAGS).
cblas_shgemm.$(SUFFIX) cblas_shgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)

View File

@ -77,7 +77,7 @@
#define GEMM_MULTITHREAD_THRESHOLD 4 #define GEMM_MULTITHREAD_THRESHOLD 4
#endif #endif
static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLASLONG) = { static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = {
#ifndef GEMM3M #ifndef GEMM3M
GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN, GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT, GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
@ -108,8 +108,8 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLA
void NAME(char *TRANSA, char *TRANSB, void NAME(char *TRANSA, char *TRANSB,
blasint *M, blasint *N, blasint *K, blasint *M, blasint *N, blasint *K,
FLOAT *alpha, FLOAT *alpha,
FLOAT *a, blasint *ldA, IFLOAT *a, blasint *ldA,
FLOAT *b, blasint *ldB, IFLOAT *b, blasint *ldB,
FLOAT *beta, FLOAT *beta,
FLOAT *c, blasint *ldC){ FLOAT *c, blasint *ldC){
@ -119,8 +119,8 @@ void NAME(char *TRANSA, char *TRANSB,
blasint info; blasint info;
char transA, transB; char transA, transB;
FLOAT *buffer; IFLOAT *buffer;
FLOAT *sa, *sb; IFLOAT *sa, *sb;
#ifdef SMP #ifdef SMP
double MNK; double MNK;

View File

@ -41,6 +41,9 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
foreach (float_type ${FLOAT_TYPES}) foreach (float_type ${FLOAT_TYPES})
# a bit of metaprogramming here to pull out the appropriate KERNEL var # a bit of metaprogramming here to pull out the appropriate KERNEL var
string(SUBSTRING ${float_type} 0 1 float_char) string(SUBSTRING ${float_type} 0 1 float_char)
# HALF uses the two-letter "SH" prefix for its kernel variables instead of
# the first letter of the type name computed just above.
# The expansion is quoted so an empty or list-valued float_type cannot break
# the if() argument list.
if ("${float_type}" STREQUAL "HALF")
  set (float_char "SH")
endif ()
GenerateNamedObjects("${KERNELDIR}/${${float_char}AMAXKERNEL}" "USE_ABS" "amax_k" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}AMAXKERNEL}" "USE_ABS" "amax_k" false "" "" false ${float_type})
GenerateNamedObjects("${KERNELDIR}/${${float_char}AMINKERNEL}" "USE_ABS;USE_MIN" "amin_k" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}AMINKERNEL}" "USE_ABS;USE_MIN" "amin_k" false "" "" false ${float_type})
if (DEFINED ${float_char}MAXKERNEL) if (DEFINED ${float_char}MAXKERNEL)
@ -93,6 +96,9 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
GenerateNamedObjects("generic/ger.c" "" "ger_k" false "" "" "" 3) GenerateNamedObjects("generic/ger.c" "" "ger_k" false "" "" "" 3)
foreach (float_type ${FLOAT_TYPES}) foreach (float_type ${FLOAT_TYPES})
string(SUBSTRING ${float_type} 0 1 float_char) string(SUBSTRING ${float_type} 0 1 float_char)
# HALF uses the two-letter "SH" prefix for its kernel variables instead of
# the first letter of the type name computed just above.
# The expansion is quoted so an empty or list-valued float_type cannot break
# the if() argument list.
if ("${float_type}" STREQUAL "HALF")
  set (float_char "SH")
endif ()
if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX") if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX")
GenerateNamedObjects("${KERNELDIR}/${${float_char}GERUKERNEL}" "" "geru_k" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}GERUKERNEL}" "" "geru_k" false "" "" false ${float_type})
GenerateNamedObjects("${KERNELDIR}/${${float_char}GERCKERNEL}" "CONJ" "gerc_k" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}GERCKERNEL}" "CONJ" "gerc_k" false "" "" false ${float_type})
@ -128,13 +134,19 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
set(USE_TRMM true) set(USE_TRMM true)
endif () endif ()
foreach (float_type SINGLE DOUBLE) foreach (float_type SINGLE DOUBLE HALF)
string(SUBSTRING ${float_type} 0 1 float_char) string(SUBSTRING ${float_type} 0 1 float_char)
if (${float_type} STREQUAL "HALF")
set (float_char "SH")
endif ()
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
endforeach() endforeach()
foreach (float_type ${FLOAT_TYPES}) foreach (float_type ${FLOAT_TYPES})
string(SUBSTRING ${float_type} 0 1 float_char) string(SUBSTRING ${float_type} 0 1 float_char)
if (${float_type} STREQUAL "HALF")
set (float_char "SH")
endif ()
if (${float_char}GEMMINCOPY) if (${float_char}GEMMINCOPY)
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMINCOPY}" "${float_type}" "${${float_char}GEMMINCOPYOBJ}" false "" "" true ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMINCOPY}" "${float_type}" "${${float_char}GEMMINCOPYOBJ}" false "" "" true ${float_type})
endif () endif ()
@ -470,9 +482,13 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEADD_KERNEL}" "" "geadd_k" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}GEADD_KERNEL}" "" "geadd_k" false "" "" false ${float_type})
endforeach () endforeach ()
# Makefile.LA # Makefile.LA
if(NOT NO_LAPACK) if(NOT NO_LAPACK)
foreach (float_type ${FLOAT_TYPES}) foreach (float_type ${FLOAT_TYPES})
# Derive the one-letter type prefix from the current float_type on every
# iteration. Without this, float_char is only assigned here for HALF and
# otherwise keeps whatever value the preceding loop left behind, so the
# *NEG_TCOPY / *LASWP_NCOPY defaults below would be keyed on the wrong prefix.
string(SUBSTRING ${float_type} 0 1 float_char)
# HALF uses the two-letter "SH" prefix rather than the first letter of its name.
if (${float_type} STREQUAL "HALF")
  set (float_char "SH")
endif ()
if (NOT DEFINED ${float_char}NEG_TCOPY) if (NOT DEFINED ${float_char}NEG_TCOPY)
if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C" OR ${float_char} STREQUAL "X") if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C" OR ${float_char} STREQUAL "X")
set(${float_char}NEG_TCOPY ../generic/zneg_tcopy.c) set(${float_char}NEG_TCOPY ../generic/zneg_tcopy.c)
@ -516,6 +532,9 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
foreach (float_type ${FLOAT_TYPES}) foreach (float_type ${FLOAT_TYPES})
# a bit of metaprogramming here to pull out the appropriate KERNEL var # a bit of metaprogramming here to pull out the appropriate KERNEL var
string(SUBSTRING ${float_type} 0 1 float_char) string(SUBSTRING ${float_type} 0 1 float_char)
if (${float_type} STREQUAL "HALF")
set (float_char "SH")
endif ()
GenerateNamedObjects("generic/neg_tcopy_${${float_char}GEMM_UNROLL_M}.c" "" "neg_tcopy" false "" ${TSUFFIX} false ${float_type}) GenerateNamedObjects("generic/neg_tcopy_${${float_char}GEMM_UNROLL_M}.c" "" "neg_tcopy" false "" ${TSUFFIX} false ${float_type})
GenerateNamedObjects("generic/laswp_ncopy_${${float_char}GEMM_UNROLL_N}.c" "" "laswp_ncopy" false "" ${TSUFFIX} false ${float_type}) GenerateNamedObjects("generic/laswp_ncopy_${${float_char}GEMM_UNROLL_N}.c" "" "laswp_ncopy" false "" ${TSUFFIX} false ${float_type})
endforeach () endforeach ()

View File

@ -59,6 +59,23 @@ ifeq ($(CORE), Z14)
USE_TRMM = 1 USE_TRMM = 1
endif endif
# Generic fallback sources and object names for the SHGEMM (half precision)
# kernel, beta routine and the four packing/copy routines.
# NOTE: make treats the "#ifndef SHGEMMKERNEL" / "#endif" lines below as
# comments, not as conditionals, so these assignments are unconditional and
# override any earlier makefile assignment of the same variables.
# NOTE(review): if a guard was intended, it needs make's "ifndef"/"endif" --
# confirm that every setter of SHGEMMKERNEL also sets SHGEMM_BETA and the
# *COPYOBJ names before enabling it.
#ifndef SHGEMMKERNEL
SHGEMM_BETA = ../generic/gemm_beta.c
SHGEMMKERNEL = ../generic/gemmkernel_2x2.c
SHGEMMINCOPY = ../generic/gemm_ncopy_2.c
SHGEMMITCOPY = ../generic/gemm_tcopy_2.c
SHGEMMONCOPY = ../generic/gemm_ncopy_2.c
SHGEMMOTCOPY = ../generic/gemm_tcopy_2.c
SHGEMMINCOPYOBJ = shgemm_incopy$(TSUFFIX).$(SUFFIX)
SHGEMMITCOPYOBJ = shgemm_itcopy$(TSUFFIX).$(SUFFIX)
SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX)
SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX)
#endif
SHKERNELOBJS += \
shgemm_kernel$(TSUFFIX).$(SUFFIX) \
$(SHGEMMINCOPYOBJ) $(SHGEMMITCOPYOBJ) \
$(SHGEMMONCOPYOBJ) $(SHGEMMOTCOPYOBJ)
SKERNELOBJS += \ SKERNELOBJS += \
sgemm_kernel$(TSUFFIX).$(SUFFIX) \ sgemm_kernel$(TSUFFIX).$(SUFFIX) \
@ -93,6 +110,7 @@ XKERNELOBJS += \
$(XGEMMINCOPYOBJ) $(XGEMMITCOPYOBJ) \ $(XGEMMINCOPYOBJ) $(XGEMMITCOPYOBJ) \
$(XGEMMONCOPYOBJ) $(XGEMMOTCOPYOBJ) $(XGEMMONCOPYOBJ) $(XGEMMOTCOPYOBJ)
SHBLASOBJS += $(SHKERNELOBJS)
SBLASOBJS += $(SKERNELOBJS) SBLASOBJS += $(SKERNELOBJS)
DBLASOBJS += $(DKERNELOBJS) DBLASOBJS += $(DKERNELOBJS)
QBLASOBJS += $(QKERNELOBJS) QBLASOBJS += $(QKERNELOBJS)
@ -100,6 +118,7 @@ CBLASOBJS += $(CKERNELOBJS)
ZBLASOBJS += $(ZKERNELOBJS) ZBLASOBJS += $(ZKERNELOBJS)
XBLASOBJS += $(XKERNELOBJS) XBLASOBJS += $(XKERNELOBJS)
SHBLASOBJS += shgemm_beta$(TSUFFIX).$(SUFFIX)
SBLASOBJS += \ SBLASOBJS += \
sgemm_beta$(TSUFFIX).$(SUFFIX) \ sgemm_beta$(TSUFFIX).$(SUFFIX) \
strmm_kernel_LN$(TSUFFIX).$(SUFFIX) strmm_kernel_LT$(TSUFFIX).$(SUFFIX) \ strmm_kernel_LN$(TSUFFIX).$(SUFFIX) strmm_kernel_LT$(TSUFFIX).$(SUFFIX) \
@ -390,6 +409,10 @@ ZBLASOBJS += \
zgeadd_k$(TSUFFIX).$(SUFFIX) zgeadd_k$(TSUFFIX).$(SUFFIX)
SHGEMMINCOPYOBJ_P = $(SHGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SHGEMMITCOPYOBJ_P = $(SHGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SHGEMMONCOPYOBJ_P = $(SHGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SHGEMMOTCOPYOBJ_P = $(SHGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SGEMMINCOPYOBJ_P = $(SGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMINCOPYOBJ_P = $(SGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SGEMMITCOPYOBJ_P = $(SGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMITCOPYOBJ_P = $(SGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SGEMMONCOPYOBJ_P = $(SGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMONCOPYOBJ_P = $(SGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
@ -415,6 +438,9 @@ XGEMMITCOPYOBJ_P = $(XGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
XGEMMONCOPYOBJ_P = $(XGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) XGEMMONCOPYOBJ_P = $(XGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
XGEMMOTCOPYOBJ_P = $(XGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) XGEMMOTCOPYOBJ_P = $(XGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
$(KDIR)shgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA)
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)sgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_BETA) $(KDIR)sgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
@ -433,6 +459,36 @@ $(KDIR)zgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_BETA)
$(KDIR)xgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMM_BETA) $(KDIR)xgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMM_BETA)
$(CC) $(CFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@ $(CC) $(CFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@
$(KDIR)$(SHGEMMONCOPYOBJ) : $(KERNELDIR)/$(SHGEMMONCOPY)
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)$(SHGEMMOTCOPYOBJ) : $(KERNELDIR)/$(SHGEMMOTCOPY)
ifeq ($(OS), AIX)
$(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemmotcopy.s
m4 shgemmotcopy.s > shgemmotcopy_nomacros.s
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemmotcopy_nomacros.s -o $@
rm shgemmotcopy.s shgemmotcopy_nomacros.s
else
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
endif
ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))
$(KDIR)$(SHGEMMINCOPYOBJ) : $(KERNELDIR)/$(SHGEMMINCOPY)
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)$(SHGEMMITCOPYOBJ) : $(KERNELDIR)/$(SHGEMMITCOPY)
ifeq ($(OS), AIX)
$(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemmitcopy.s
m4 shgemmitcopy.s > shgemmitcopy_nomacros.s
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemmitcopy_nomacros.s -o $@
rm shgemmitcopy.s shgemmitcopy_nomacros.s
else
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
endif
endif
$(KDIR)$(SGEMMONCOPYOBJ) : $(KERNELDIR)/$(SGEMMONCOPY) $(KDIR)$(SGEMMONCOPYOBJ) : $(KERNELDIR)/$(SGEMMONCOPY)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
@ -590,6 +646,16 @@ else
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
endif endif
$(KDIR)shgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
ifeq ($(OS), AIX)
$(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemm_kernel$(TSUFFIX).s
m4 shgemm_kernel$(TSUFFIX).s > shgemm_kernel$(TSUFFIX)_nomacros.s
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemm_kernel$(TSUFFIX)_nomacros.s -o $@
rm shgemm_kernel$(TSUFFIX).s shgemm_kernel$(TSUFFIX)_nomacros.s
else
$(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
endif
$(KDIR)dgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMMKERNEL) $(DGEMMDEPEND) $(KDIR)dgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMMKERNEL) $(DGEMMDEPEND)
ifeq ($(OS), AIX) ifeq ($(OS), AIX)
$(CC) $(CFLAGS) -E -DDOUBLE -UCOMPLEX $< -o dgemm_kernel$(TSUFFIX).s $(CC) $(CFLAGS) -E -DDOUBLE -UCOMPLEX $< -o dgemm_kernel$(TSUFFIX).s
@ -2206,6 +2272,9 @@ $(KDIR)xtrsm_oltncopy$(TSUFFIX).$(SUFFIX) : generic/ztrsm_ltcopy_$(XGEMM_UNROLL_
$(KDIR)sgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMM_BETA) $(KDIR)sgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMM_BETA)
$(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)shgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA)
$(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)dgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DGEMM_BETA) $(KDIR)dgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DGEMM_BETA)
$(CC) $(PFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ $(CC) $(PFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@
@ -2221,6 +2290,20 @@ $(KDIR)zgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(ZGEMM_BETA)
$(KDIR)xgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(XGEMM_BETA) $(KDIR)xgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(XGEMM_BETA)
$(CC) $(PFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@ $(CC) $(PFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@
$(SHGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMONCOPY)
$(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
$(SHGEMMOTCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMOTCOPY)
$(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))
$(SHGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMINCOPY)
$(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
$(SHGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMITCOPY)
$(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
endif
$(SGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SGEMMONCOPY) $(SGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SGEMMONCOPY)
$(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
@ -2325,6 +2408,9 @@ endif
endif endif
$(KDIR)shgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
$(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)sgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMMKERNEL) $(SGEMMDEPEND) $(KDIR)sgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMMKERNEL) $(SGEMMDEPEND)
$(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@

View File

@ -47,6 +47,100 @@ typedef struct {
int dtb_entries; int dtb_entries;
int offsetA, offsetB, align; int offsetA, offsetB, align;
#if 1
int shgemm_p, shgemm_q, shgemm_r;
int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn;
float (*shamax_k) (BLASLONG, float *, BLASLONG);
float (*shamin_k) (BLASLONG, float *, BLASLONG);
float (*shmax_k) (BLASLONG, float *, BLASLONG);
float (*shmin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*ishamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*ishamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*ishmax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG);
float (*shnrm2_k) (BLASLONG, float *, BLASLONG);
float (*shasum_k) (BLASLONG, float *, BLASLONG);
float (*shsum_k) (BLASLONG, float *, BLASLONG);
int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
float (*shdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float);
int (*shaxpy_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*shscal_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*shswap_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*shgemv_n) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*shgemv_t) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*shger_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*shsymv_L) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*shsymv_U) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
int (*shgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
int (*shgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
int (*shgemm_oncopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
int (*shgemm_otcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
int (*shtrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrsm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrsm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrsm_iunucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_iunncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_iutucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_iutncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_ilnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_ilnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_iltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_iltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_ounucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_ounncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_outucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_outncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_olnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_olnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_oltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrsm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
int (*shtrmm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrmm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrmm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrmm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*shtrmm_iunucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_iunncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_iutucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_iutncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_ilnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_ilnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_iltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_iltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_ounucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_ounncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_outucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_outncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_olnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_olnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_oltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shtrmm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shsymm_iutcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shsymm_iltcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shsymm_outcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shsymm_oltcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *);
int (*shneg_tcopy) (BLASLONG, BLASLONG, float *, BLASLONG, float *);
int (*shlaswp_ncopy) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG, blasint *, float *);
#endif
int sgemm_p, sgemm_q, sgemm_r; int sgemm_p, sgemm_q, sgemm_r;
int sgemm_unroll_m, sgemm_unroll_n, sgemm_unroll_mn; int sgemm_unroll_m, sgemm_unroll_n, sgemm_unroll_mn;
@ -84,6 +178,7 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*sgemm_incopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_incopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *);
int (*sgemm_itcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_itcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *);
int (*sgemm_oncopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_oncopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *);
@ -907,6 +1002,13 @@ extern gotoblas_t *gotoblas;
#define HAVE_EX_L2 gotoblas -> exclusive_cache #define HAVE_EX_L2 gotoblas -> exclusive_cache
#define SHGEMM_P gotoblas -> shgemm_p
#define SHGEMM_Q gotoblas -> shgemm_q
#define SHGEMM_R gotoblas -> shgemm_r
#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m
#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n
#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn
#define SGEMM_P gotoblas -> sgemm_p #define SGEMM_P gotoblas -> sgemm_p
#define SGEMM_Q gotoblas -> sgemm_q #define SGEMM_Q gotoblas -> sgemm_q
#define SGEMM_R gotoblas -> sgemm_r #define SGEMM_R gotoblas -> sgemm_r
@ -984,6 +1086,17 @@ extern gotoblas_t *gotoblas;
#define HAVE_EX_L2 0 #define HAVE_EX_L2 0
#endif #endif
#define SHGEMM_P SHGEMM_DEFAULT_P
#define SHGEMM_Q SHGEMM_DEFAULT_Q
#define SHGEMM_R SHGEMM_DEFAULT_R
#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
#ifdef SHGEMM_DEFAULT_UNROLL_MN
#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN
#else
#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N))
#endif
#define SGEMM_P SGEMM_DEFAULT_P #define SGEMM_P SGEMM_DEFAULT_P
#define SGEMM_Q SGEMM_DEFAULT_Q #define SGEMM_Q SGEMM_DEFAULT_Q
#define SGEMM_R SGEMM_DEFAULT_R #define SGEMM_R SGEMM_DEFAULT_R
@ -1119,6 +1232,18 @@ extern gotoblas_t *gotoblas;
#define GEMM_DEFAULT_R DGEMM_DEFAULT_R #define GEMM_DEFAULT_R DGEMM_DEFAULT_R
#define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M #define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M
#define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N #define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N
#elif defined(HALF)
#define GEMM_P SHGEMM_P
#define GEMM_Q SHGEMM_Q
#define GEMM_R SHGEMM_R
#define GEMM_UNROLL_M SHGEMM_UNROLL_M
#define GEMM_UNROLL_N SHGEMM_UNROLL_N
#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN
#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P
#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q
#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R
#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
#else #else
#define GEMM_P SGEMM_P #define GEMM_P SGEMM_P
#define GEMM_Q SGEMM_Q #define GEMM_Q SGEMM_Q
@ -1204,6 +1329,10 @@ extern gotoblas_t *gotoblas;
#define GEMM_THREAD gemm_thread_n #define GEMM_THREAD gemm_thread_n
#endif #endif
/* Fallback R blocking for SHGEMM, used only when SHGEMM_DEFAULT_R is not
   already defined. It is computed from BUFFER_SIZE with the same expression
   that SGEMM_DEFAULT_R uses, with the SHGEMM P/Q values substituted.
   NOTE(review): the factor 4 matches sizeof(float); presumably bfloat16
   operands are 2 bytes, which would make this bound conservative -- confirm
   that this is intentional. */
#ifndef SHGEMM_DEFAULT_R
#define SHGEMM_DEFAULT_R (((BUFFER_SIZE - ((SHGEMM_DEFAULT_P * SHGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SHGEMM_DEFAULT_Q * 4) - 15) & ~15UL)
#endif
#ifndef SGEMM_DEFAULT_R #ifndef SGEMM_DEFAULT_R
#define SGEMM_DEFAULT_R (((BUFFER_SIZE - ((SGEMM_DEFAULT_P * SGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SGEMM_DEFAULT_Q * 4) - 15) & ~15UL) #define SGEMM_DEFAULT_R (((BUFFER_SIZE - ((SGEMM_DEFAULT_P * SGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SGEMM_DEFAULT_Q * 4) - 15) & ~15UL)
#endif #endif

View File

@ -39,7 +39,7 @@
#include "common.h" #include "common.h"
int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta,
FLOAT *dummy2, BLASLONG dummy3, FLOAT *dummy4, BLASLONG dummy5, IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5,
FLOAT *c, BLASLONG ldc){ FLOAT *c, BLASLONG ldc){

View File

@ -39,10 +39,10 @@
#include <stdio.h> #include <stdio.h>
#include "common.h" #include "common.h"
int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){
BLASLONG i, j; BLASLONG i, j;
FLOAT *a_offset, *a_offset1, *a_offset2; IFLOAT *a_offset, *a_offset1, *a_offset2;
FLOAT *b_offset; IFLOAT *b_offset;
a_offset = a; a_offset = a;
b_offset = b; b_offset = b;

View File

@ -39,11 +39,11 @@
#include <stdio.h> #include <stdio.h>
#include "common.h" #include "common.h"
int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){
BLASLONG i, j; BLASLONG i, j;
FLOAT *a_offset, *a_offset1, *a_offset2; IFLOAT *a_offset, *a_offset1, *a_offset2;
FLOAT *b_offset, *b_offset1, *b_offset2; IFLOAT *b_offset, *b_offset1, *b_offset2;
a_offset = a; a_offset = a;
b_offset = b; b_offset = b;

View File

@ -1,13 +1,32 @@
#include "common.h" #include "common.h"
int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FLOAT* C,BLASLONG ldc #if defined(HALF) && defined(HALFCONVERSION)
/*
 * Widen a bfloat16 value to float.
 *
 * The 16 bits of f16 become the most significant 16 bits of the result and
 * the low 16 bits are zero, which is what the previous implementation
 * produced by storing f16 into the upper half-word of a zeroed float.
 *
 * The conversion goes through a union and a shift instead of an
 * (unsigned short *) cast:
 *   - reading/writing a float through an unsigned short lvalue breaks the
 *     strict-aliasing rules, so an optimising compiler may drop the store;
 *   - the shift is independent of byte order, so no __BYTE_ORDER__ test is
 *     needed (that test evaluated to 0 == 0, i.e. "big endian", on compilers
 *     that define neither macro).
 *
 * Assumes unsigned int and float are both 32 bits wide.
 */
static float
bfloat16tof32 (bfloat16 f16)
{
  union { float f; unsigned int u; } cvt;
  /* Mask first so only the 16 payload bits are shifted into place. */
  cvt.u = ((unsigned int)f16 & 0xffffU) << 16;
  return cvt.f;
}
#define BF16TOF32(x) (bfloat16tof32(x))
#else
/* No conversion requested: pass the operand through (parenthesised so the
   macro is safe inside larger expressions). */
#define BF16TOF32(x) (x)
#endif
int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc
#ifdef TRMMKERNEL #ifdef TRMMKERNEL
,BLASLONG offset ,BLASLONG offset
#endif #endif
) )
{ {
BLASLONG i,j,k; BLASLONG i,j,k;
FLOAT *C0,*C1,*ptrba,*ptrbb; FLOAT *C0,*C1;
FLOAT res0,res1,res2,res3,load0,load1,load2,load3,load4,load5,load6,load7; IFLOAT *ptrba,*ptrbb;
FLOAT res0,res1,res2,res3;
IFLOAT load0,load1,load2,load3,load4,load5,load6,load7;
for (j=0; j<bn/2; j+=1) for (j=0; j<bn/2; j+=1)
{ {
C0 = C; C0 = C;
@ -24,36 +43,36 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
{ {
load0 = ptrba[2*0+0]; load0 = ptrba[2*0+0];
load1 = ptrbb[2*0+0]; load1 = ptrbb[2*0+0];
res0 = res0+load0*load1; res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
load2 = ptrba[2*0+1]; load2 = ptrba[2*0+1];
res1 = res1+load2*load1; res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
load3 = ptrbb[2*0+1]; load3 = ptrbb[2*0+1];
res2 = res2+load0*load3; res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
res3 = res3+load2*load3; res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
load4 = ptrba[2*1+0]; load4 = ptrba[2*1+0];
load5 = ptrbb[2*1+0]; load5 = ptrbb[2*1+0];
res0 = res0+load4*load5; res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
load6 = ptrba[2*1+1]; load6 = ptrba[2*1+1];
res1 = res1+load6*load5; res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
load7 = ptrbb[2*1+1]; load7 = ptrbb[2*1+1];
res2 = res2+load4*load7; res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
res3 = res3+load6*load7; res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
load0 = ptrba[2*2+0]; load0 = ptrba[2*2+0];
load1 = ptrbb[2*2+0]; load1 = ptrbb[2*2+0];
res0 = res0+load0*load1; res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
load2 = ptrba[2*2+1]; load2 = ptrba[2*2+1];
res1 = res1+load2*load1; res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
load3 = ptrbb[2*2+1]; load3 = ptrbb[2*2+1];
res2 = res2+load0*load3; res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
res3 = res3+load2*load3; res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
load4 = ptrba[2*3+0]; load4 = ptrba[2*3+0];
load5 = ptrbb[2*3+0]; load5 = ptrbb[2*3+0];
res0 = res0+load4*load5; res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
load6 = ptrba[2*3+1]; load6 = ptrba[2*3+1];
res1 = res1+load6*load5; res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
load7 = ptrbb[2*3+1]; load7 = ptrbb[2*3+1];
res2 = res2+load4*load7; res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
res3 = res3+load6*load7; res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
ptrba = ptrba+8; ptrba = ptrba+8;
ptrbb = ptrbb+8; ptrbb = ptrbb+8;
} }
@ -61,12 +80,12 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
{ {
load0 = ptrba[2*0+0]; load0 = ptrba[2*0+0];
load1 = ptrbb[2*0+0]; load1 = ptrbb[2*0+0];
res0 = res0+load0*load1; res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
load2 = ptrba[2*0+1]; load2 = ptrba[2*0+1];
res1 = res1+load2*load1; res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
load3 = ptrbb[2*0+1]; load3 = ptrbb[2*0+1];
res2 = res2+load0*load3; res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
res3 = res3+load2*load3; res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
ptrba = ptrba+2; ptrba = ptrba+2;
ptrbb = ptrbb+2; ptrbb = ptrbb+2;
} }
@ -90,9 +109,9 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
{ {
load0 = ptrba[0+0]; load0 = ptrba[0+0];
load1 = ptrbb[2*0+0]; load1 = ptrbb[2*0+0];
res0 = res0+load0*load1; res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
load2 = ptrbb[2*0+1]; load2 = ptrbb[2*0+1];
res1 = res1+load0*load2; res1 = res1+BF16TOF32(load0)*BF16TOF32(load2);
ptrba = ptrba+1; ptrba = ptrba+1;
ptrbb = ptrbb+2; ptrbb = ptrbb+2;
} }
@ -121,9 +140,9 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
{ {
load0 = ptrba[2*0+0]; load0 = ptrba[2*0+0];
load1 = ptrbb[0+0]; load1 = ptrbb[0+0];
res0 = res0+load0*load1; res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
load2 = ptrba[2*0+1]; load2 = ptrba[2*0+1];
res1 = res1+load2*load1; res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
ptrba = ptrba+2; ptrba = ptrba+2;
ptrbb = ptrbb+1; ptrbb = ptrbb+1;
} }
@ -141,7 +160,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
{ {
load0 = ptrba[0+0]; load0 = ptrba[0+0];
load1 = ptrbb[0+0]; load1 = ptrbb[0+0];
res0 = res0+load0*load1; res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
ptrba = ptrba+1; ptrba = ptrba+1;
ptrbb = ptrbb+1; ptrbb = ptrbb+1;
} }

View File

@ -53,6 +53,64 @@ gotoblas_t TABLE_NAME = {
GEMM_DEFAULT_OFFSET_A, GEMM_DEFAULT_OFFSET_B, GEMM_DEFAULT_ALIGN, GEMM_DEFAULT_OFFSET_A, GEMM_DEFAULT_OFFSET_B, GEMM_DEFAULT_ALIGN,
0, 0, 0,
SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N,
#ifdef SHGEMM_DEFAULT_UNROLL_MN
SHGEMM_DEFAULT_UNROLL_MN,
#else
MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
#endif
samax_kTS, samin_kTS, smax_kTS, smin_kTS,
isamax_kTS, isamin_kTS, ismax_kTS, ismin_kTS,
snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, sdot_kTS,
dsdot_kTS,
srot_kTS, saxpy_kTS, sscal_kTS, sswap_kTS,
sgemv_nTS, sgemv_tTS, sger_kTS,
ssymv_LTS, ssymv_UTS,
shgemm_kernelTS, shgemm_betaTS,
#if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N
shgemm_incopyTS, shgemm_itcopyTS,
#else
shgemm_oncopyTS, shgemm_otcopyTS,
#endif
shgemm_oncopyTS, shgemm_otcopyTS,
strsm_kernel_LNTS, strsm_kernel_LTTS, strsm_kernel_RNTS, strsm_kernel_RTTS,
#if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N
strsm_iunucopyTS, strsm_iunncopyTS, strsm_iutucopyTS, strsm_iutncopyTS,
strsm_ilnucopyTS, strsm_ilnncopyTS, strsm_iltucopyTS, strsm_iltncopyTS,
#else
strsm_ounucopyTS, strsm_ounncopyTS, strsm_outucopyTS, strsm_outncopyTS,
strsm_olnucopyTS, strsm_olnncopyTS, strsm_oltucopyTS, strsm_oltncopyTS,
#endif
strsm_ounucopyTS, strsm_ounncopyTS, strsm_outucopyTS, strsm_outncopyTS,
strsm_olnucopyTS, strsm_olnncopyTS, strsm_oltucopyTS, strsm_oltncopyTS,
strmm_kernel_RNTS, strmm_kernel_RTTS, strmm_kernel_LNTS, strmm_kernel_LTTS,
#if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N
strmm_iunucopyTS, strmm_iunncopyTS, strmm_iutucopyTS, strmm_iutncopyTS,
strmm_ilnucopyTS, strmm_ilnncopyTS, strmm_iltucopyTS, strmm_iltncopyTS,
#else
strmm_ounucopyTS, strmm_ounncopyTS, strmm_outucopyTS, strmm_outncopyTS,
strmm_olnucopyTS, strmm_olnncopyTS, strmm_oltucopyTS, strmm_oltncopyTS,
#endif
strmm_ounucopyTS, strmm_ounncopyTS, strmm_outucopyTS, strmm_outncopyTS,
strmm_olnucopyTS, strmm_olnncopyTS, strmm_oltucopyTS, strmm_oltncopyTS,
#if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N
ssymm_iutcopyTS, ssymm_iltcopyTS,
#else
ssymm_outcopyTS, ssymm_oltcopyTS,
#endif
ssymm_outcopyTS, ssymm_oltcopyTS,
#ifndef NO_LAPACK
sneg_tcopyTS, slaswp_ncopyTS,
#else
NULL,NULL,
#endif
0, 0, 0, 0, 0, 0,
SGEMM_DEFAULT_UNROLL_M, SGEMM_DEFAULT_UNROLL_N, SGEMM_DEFAULT_UNROLL_M, SGEMM_DEFAULT_UNROLL_N,
#ifdef SGEMM_DEFAULT_UNROLL_MN #ifdef SGEMM_DEFAULT_UNROLL_MN
@ -648,16 +706,19 @@ gotoblas_t TABLE_NAME = {
#if defined(ARCH_ARM64) #if defined(ARCH_ARM64)
static void init_parameter(void) { static void init_parameter(void) {
TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P; TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P; TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P; TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P;
TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P; TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P;
TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q; TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q; TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q; TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
TABLE_NAME.zgemm_q = ZGEMM_DEFAULT_Q; TABLE_NAME.zgemm_q = ZGEMM_DEFAULT_Q;
TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R; TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R; TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R; TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R;
@ -721,17 +782,20 @@ static void init_parameter(void) {
#if defined(ARCH_POWER) #if defined(ARCH_POWER)
static void init_parameter(void) { static void init_parameter(void) {
TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P; TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P; TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P; TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P;
TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P; TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P;
TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R; TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R; TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R; TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R;
TABLE_NAME.zgemm_r = ZGEMM_DEFAULT_R; TABLE_NAME.zgemm_r = ZGEMM_DEFAULT_R;
TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q; TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q; TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q; TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
@ -741,17 +805,20 @@ static void init_parameter(void) {
#if defined(ARCH_ZARCH) #if defined(ARCH_ZARCH)
static void init_parameter(void) { static void init_parameter(void) {
TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P; TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P; TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P; TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P;
TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P; TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P;
TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R; TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R; TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R; TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R;
TABLE_NAME.zgemm_r = ZGEMM_DEFAULT_R; TABLE_NAME.zgemm_r = ZGEMM_DEFAULT_R;
TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q; TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q; TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q; TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
@ -891,6 +958,9 @@ static void init_parameter(void) {
(void) l2; /* dirty trick to suppress unused variable warning for targets */ (void) l2; /* dirty trick to suppress unused variable warning for targets */
/* where the GEMM unrolling parameters do not depend on l2 */ /* where the GEMM unrolling parameters do not depend on l2 */
TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q; TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q; TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q; TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;

View File

@ -2,6 +2,7 @@
include_directories(${PROJECT_SOURCE_DIR}) include_directories(${PROJECT_SOURCE_DIR})
include_directories(${PROJECT_BINARY_DIR}) include_directories(${PROJECT_BINARY_DIR})
list (REMOVE_ITEM FLOAT_TYPES "HALF")
set(LAPACK_SOURCES set(LAPACK_SOURCES
potrf/potrf_U_single.c potrf/potrf_U_single.c
@ -45,6 +46,9 @@ GenerateNamedObjects("laswp/generic/laswp_k_4.c" "" "laswp_plus" false "" "" fa
GenerateNamedObjects("laswp/generic/laswp_k_4.c" "MINUS" "laswp_minus" false "" "" false 3) GenerateNamedObjects("laswp/generic/laswp_k_4.c" "MINUS" "laswp_minus" false "" "" false 3)
# Generate one getrf_single object per enabled floating-point type.
# "HALF" is skipped: no getrf object is generated for it here.
foreach (float_type ${FLOAT_TYPES})
  # Quote the expansion so an empty list entry cannot break the if() syntax
  # and a value that happens to name another variable is not dereferenced.
  if ("${float_type}" STREQUAL "HALF")
    continue()
  endif ()
  GenerateNamedObjects("getrf/getrf_single.c" "UNIT" "getrf_single" false "" "" false ${float_type})
endforeach ()

View File

@ -380,6 +380,9 @@ static int thread_driver(blas_arg_t *args, FLOAT *sa, FLOAT *sb){
#elif defined(DOUBLE) #elif defined(DOUBLE)
mode = BLAS_DOUBLE | BLAS_REAL; mode = BLAS_DOUBLE | BLAS_REAL;
mask = MAX(DGEMM_UNROLL_M, DGEMM_UNROLL_N) - 1; mask = MAX(DGEMM_UNROLL_M, DGEMM_UNROLL_N) - 1;
#elif defined(HALF)
mode = BLAS_HALF | BLAS_REAL;
mask = MAX(SHGEMM_UNROLL_M, SHGEMM_UNROLL_N) - 1;
#else #else
mode = BLAS_SINGLE | BLAS_REAL; mode = BLAS_SINGLE | BLAS_REAL;
mask = MAX(SGEMM_UNROLL_M, SGEMM_UNROLL_N) - 1; mask = MAX(SGEMM_UNROLL_M, SGEMM_UNROLL_N) - 1;

View File

@ -72,6 +72,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef PARAM_H #ifndef PARAM_H
#define PARAM_H #define PARAM_H
/* Default SHGEMM parameters.  They are defined right after the include
   guard, ahead of the target-specific #ifdef sections that follow.
   NOTE(review): P/Q/R are presumably the GEMM blocking sizes and
   UNROLL_M/N/MN the kernel unroll factors, as for the other precisions;
   confirm against the kernels that consume them. */
#define SHGEMM_DEFAULT_UNROLL_N 4
#define SHGEMM_DEFAULT_UNROLL_M 8
#define SHGEMM_DEFAULT_UNROLL_MN 32
#define SHGEMM_DEFAULT_P 256
#define SHGEMM_DEFAULT_R 256
#define SHGEMM_DEFAULT_Q 256
#ifdef OPTERON #ifdef OPTERON
#define SNUMOPT 4 #define SNUMOPT 4

View File

@ -0,0 +1,95 @@
/***************************************************************************
Copyright (c) 2020, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include <stdio.h>
#include <stdint.h>
#include "common.h"
#define SGEMM BLASFUNC(sgemm)
#define SHGEMM BLASFUNC(shgemm)
/* View of a bfloat16 value: either the raw 16-bit pattern (v) or its
   sign (1 bit), exponent (8 bits) and mantissa (7 bits) fields.
   The field order is chosen from the compiler-reported byte order so
   that s/e/m line up with the bits of v.  Bit-field allocation order is
   implementation-defined; this layout follows the GCC/Clang convention. */
typedef union
{
  unsigned short v;
  struct
  {
/* Test that the macros exist before comparing them: when a compiler
   defines neither, both would evaluate to 0 inside #if and the
   big-endian layout would be selected unconditionally. */
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) \
    && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
    unsigned short s:1;
    unsigned short e:8;
    unsigned short m:7;
#else
    unsigned short m:7;
    unsigned short e:8;
    unsigned short s:1;
#endif
  } bits;
} bfloat16_bits;
/* Return the 16 high-order bits of the binary32 encoding of VALUE, i.e.
   VALUE truncated (not rounded) to bfloat16.  The bits are read through a
   union rather than a pointer cast, which avoids a strict-aliasing
   violation.  Assumes float is a 32-bit type, as the test always did. */
static unsigned short
float_to_bfloat16_bits (float value)
{
  union
  {
    float f;
    uint32_t u;
  } pun;

  pun.f = value;
  return (unsigned short) (pun.u >> 16);
}

/* Compare SHGEMM against SGEMM on square matrices of order 1..20.

   A and B are filled with small integers that are exactly representable
   in bfloat16, so both products are compared for exact equality.

   Returns the number of mismatching result elements, saturated at 255
   (0 means every element matched).  The same value is printed to stderr. */
int
main (int argc, char *argv[])
{
  int m, n, k;
  int ret = 0;
  int loop = 20;
  char transA = 'N', transB = 'N';
  float alpha = 1.0, beta = 0.0;

  (void) argc;
  (void) argv;

  /* Start at order 1: order 0 would declare zero-length variable-length
     arrays (undefined behaviour in C) and pass a leading dimension of 0,
     which the BLAS GEMM interface does not allow. */
  for (int x = 1; x <= loop; x++)
    {
      m = k = n = x;
      float A[m * k];
      float B[k * n];
      float C[m * n];
      bfloat16_bits AA[m * k], BB[k * n];
      float CC[m * n];

      for (int row = 0; row < m; row++)
        {
          for (int col = 0; col < m; col++)
            {
              int idx = row * k + col;

              A[idx] = row * 9.0;
              B[idx] = col * 2.0;
              C[idx] = 0;
              AA[idx].v = float_to_bfloat16_bits (A[idx]);
              BB[idx].v = float_to_bfloat16_bits (B[idx]);
              CC[idx] = 0;
            }
        }

      SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
             &m, B, &k, &beta, C, &m);
      SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA,
              &m, BB, &k, &beta, CC, &m);

      /* Count every differing element exactly once. */
      for (int idx = 0; idx < m * n; idx++)
        if (CC[idx] != C[idx])
          ret++;
    }

  /* Exit statuses are commonly truncated to 8 bits; saturate so that a
     mismatch count that is a multiple of 256 cannot be reported as 0. */
  if (ret > 255)
    ret = 255;

  fprintf (stderr, "Return code: %d\n", ret);
  return ret;
}