diff --git a/common.h b/common.h index 1d8bf07e5..e2c8cdee5 100644 --- a/common.h +++ b/common.h @@ -257,6 +257,11 @@ typedef long BLASLONG; typedef unsigned long BLASULONG; #endif +#ifndef BFLOAT16 +typedef unsigned short bfloat16; +#define HALFCONVERSION 1 +#endif + #ifdef USE64BITINT typedef BLASLONG blasint; #if defined(OS_WINDOWS) && defined(__64BIT__) @@ -298,10 +303,6 @@ typedef int blasint; #define BASE_SHIFT 3 #define ZBASE_SHIFT 4 #elif defined(HALF) -#ifndef BFLOAT16 -typedef unsigned short bfloat16; -#define HALFCONVERSION 1 -#endif #define IFLOAT bfloat16 #define XFLOAT IFLOAT #define FLOAT float diff --git a/common_interface.h b/common_interface.h index 081043af1..78f5be6b0 100644 --- a/common_interface.h +++ b/common_interface.h @@ -37,9 +37,6 @@ /*********************************************************************/ #ifndef ASSEMBLER -#ifndef BFLOAT16 -typedef unsigned short bfloat16; -#endif #ifdef __cplusplus extern "C" { diff --git a/common_level3.h b/common_level3.h index 8194ba6ce..4e44a5e73 100644 --- a/common_level3.h +++ b/common_level3.h @@ -37,9 +37,6 @@ /*********************************************************************/ #ifndef ASSEMBLER -#ifndef BFLOAT16 -typedef unsigned short bfloat16; -#endif #ifdef __CUDACC__ __global__ void cuda_sgemm_kernel(int, int, int, float *, float *, float *); diff --git a/common_param.h b/common_param.h index 446d42452..19a34fa3d 100644 --- a/common_param.h +++ b/common_param.h @@ -43,10 +43,6 @@ #ifdef DYNAMIC_ARCH -#ifndef BFLOAT16 -typedef unsigned short bfloat16; -#endif - typedef struct { int dtb_entries; int offsetA, offsetB, align; diff --git a/common_sh.h b/common_sh.h index 8859694f1..7a0045762 100644 --- a/common_sh.h +++ b/common_sh.h @@ -1,5 +1,5 @@ -#ifndef COMMON_H_H -#define COMMON_H_H +#ifndef COMMON_SH_H +#define COMMON_SH_H #ifndef DYNAMIC_ARCH diff --git a/exports/gensymbol b/exports/gensymbol index d2894e6c8..235446f14 100644 --- a/exports/gensymbol +++ b/exports/gensymbol @@ -30,7 +30,7 @@ icamax,icamin,idamax,idamin,idmax,idmin,isamax,isamin,ismax,ismin, izamax,izamin,lsame,samax,samin,sasum,saxpy,scabs1,scamax, scamin,scasum,scnrm2,scopy,sdot,sdsdot,sgbmv,sgemm,sgemv,sger, - smax,smin,snrm2, + shgemm, smax,smin,snrm2, srot,srotg,srotm,srotmg,ssbmv,sscal,sspmv,sspr2,sspr,sswap, ssymm,ssymv,ssyr2,ssyr2k,ssyr,ssyrk,stbmv,stbsv,stpmv,stpsv, strmm,strmv,strsm,strsv,zaxpy,zcopy,zdotc,zdotu,zdrot, @@ -67,7 +67,7 @@ cblas_isamax, cblas_izamax, cblas_sasum, cblas_saxpy, cblas_scasum, cblas_scnrm2, cblas_scopy, cblas_sdot, cblas_sdsdot, cblas_sgbmv, cblas_sgemm, - cblas_sgemv, cblas_sger, cblas_snrm2, cblas_srot, cblas_srotg, + cblas_sgemv, cblas_sger, cblas_shgemm, cblas_snrm2, cblas_srot, cblas_srotg, cblas_srotm, cblas_srotmg, cblas_ssbmv, cblas_sscal, cblas_sspmv, cblas_sspr2, cblas_sspr, cblas_sswap, cblas_ssymm, cblas_ssymv, cblas_ssyr2, cblas_ssyr2k, cblas_ssyr, cblas_ssyrk, cblas_stbmv, cblas_stbsv, cblas_stpmv, cblas_stpsv, cblas_strmm, cblas_strmv, cblas_strsm, diff --git a/kernel/common_param.h b/kernel/common_param.h index eab14b0a6..29bb65e5c 100644 --- a/kernel/common_param.h +++ b/kernel/common_param.h @@ -47,6 +47,100 @@ typedef struct { int dtb_entries; int offsetA, offsetB, align; +#if 1 + int shgemm_p, shgemm_q, shgemm_r; + int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; + + float (*shamax_k) (BLASLONG, float *, BLASLONG); + float (*shamin_k) (BLASLONG, float *, BLASLONG); + float (*shmax_k) (BLASLONG, float *, BLASLONG); + float (*shmin_k) (BLASLONG, float *, BLASLONG); +BLASLONG (*ishamax_k)(BLASLONG, float *, BLASLONG); +BLASLONG (*ishamin_k)(BLASLONG, float *, BLASLONG); +BLASLONG (*ishmax_k) (BLASLONG, float *, BLASLONG); +BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG); + + float (*shnrm2_k) (BLASLONG, float *, BLASLONG); + float (*shasum_k) (BLASLONG, float *, BLASLONG); + float (*shsum_k) (BLASLONG, float *, BLASLONG); + int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); + float (*shdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); + double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); + + int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float); + + int (*shaxpy_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + int (*shscal_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + int (*shswap_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + + int (*shgemv_n) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*shgemv_t) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*shger_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + + int (*shsymv_L) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*shsymv_U) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + + int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); + int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); + + int (*shgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_oncopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_otcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + + int (*shtrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrsm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrsm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + + int (*shtrsm_iunucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iunncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iutucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iutncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ilnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ilnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ounucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ounncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_outucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_outncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_olnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_olnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_oltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + + int (*shtrmm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrmm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrmm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrmm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + + int (*shtrmm_iunucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iunncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iutucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iutncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ilnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ilnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ounucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ounncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_outucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_outncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_olnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_olnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_oltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + + int (*shsymm_iutcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shsymm_iltcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shsymm_outcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shsymm_oltcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + + int (*shneg_tcopy) (BLASLONG, BLASLONG, float *, BLASLONG, float *); + int (*shlaswp_ncopy) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG, blasint *, float *); + +#endif int sgemm_p, sgemm_q, sgemm_r; int sgemm_unroll_m, sgemm_unroll_n, sgemm_unroll_mn; @@ -84,6 +178,7 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + int (*sgemm_incopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_itcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_oncopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); @@ -907,6 +1002,13 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 gotoblas -> exclusive_cache +#define SHGEMM_P gotoblas -> shgemm_p +#define SHGEMM_Q gotoblas -> shgemm_q +#define SHGEMM_R gotoblas -> shgemm_r +#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m +#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n +#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn + #define SGEMM_P gotoblas -> sgemm_p #define SGEMM_Q gotoblas -> sgemm_q #define SGEMM_R gotoblas -> sgemm_r @@ -984,6 +1086,17 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 0 #endif +#define SHGEMM_P SHGEMM_DEFAULT_P +#define SHGEMM_Q SHGEMM_DEFAULT_Q +#define SHGEMM_R SHGEMM_DEFAULT_R +#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M +#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N +#ifdef SHGEMM_DEFAULT_UNROLL_MN +#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN +#else +#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N)) +#endif + #define SGEMM_P SGEMM_DEFAULT_P #define SGEMM_Q SGEMM_DEFAULT_Q #define SGEMM_R SGEMM_DEFAULT_R @@ -1119,6 +1232,18 @@ extern gotoblas_t *gotoblas; #define GEMM_DEFAULT_R DGEMM_DEFAULT_R #define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M #define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N +#elif defined(HALF) +#define GEMM_P SHGEMM_P +#define GEMM_Q SHGEMM_Q +#define GEMM_R SHGEMM_R +#define GEMM_UNROLL_M SHGEMM_UNROLL_M +#define GEMM_UNROLL_N SHGEMM_UNROLL_N +#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN +#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P +#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q +#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R +#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M +#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N #else #define GEMM_P SGEMM_P #define GEMM_Q SGEMM_Q @@ -1204,6 +1329,10 @@ extern gotoblas_t *gotoblas; #define GEMM_THREAD gemm_thread_n #endif +#ifndef SHGEMM_DEFAULT_R +#define SHGEMM_DEFAULT_R (((BUFFER_SIZE - ((SHGEMM_DEFAULT_P * SHGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SHGEMM_DEFAULT_Q * 4) - 15) & ~15UL) +#endif + #ifndef SGEMM_DEFAULT_R #define SGEMM_DEFAULT_R (((BUFFER_SIZE - ((SGEMM_DEFAULT_P * SGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SGEMM_DEFAULT_Q * 4) - 15) & ~15UL) #endif diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index 79cd151f6..b7cf0f112 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -958,6 +958,8 @@ static void init_parameter(void) { (void) l2; /* dirty trick to suppress unused variable warning for targets */ /* where the GEMM unrolling parameters do not depend on l2 */ + TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P; + TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R; TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q; TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q; TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q; @@ -1329,7 +1331,6 @@ static void init_parameter(void) { - TABLE_NAME.shgemm_p = ((TABLE_NAME.shgemm_p + SHGEMM_DEFAULT_UNROLL_M - 1)/SHGEMM_DEFAULT_UNROLL_M) * SHGEMM_DEFAULT_UNROLL_M; TABLE_NAME.sgemm_p = ((TABLE_NAME.sgemm_p + SGEMM_DEFAULT_UNROLL_M - 1)/SGEMM_DEFAULT_UNROLL_M) * SGEMM_DEFAULT_UNROLL_M; TABLE_NAME.dgemm_p = ((TABLE_NAME.dgemm_p + DGEMM_DEFAULT_UNROLL_M - 1)/DGEMM_DEFAULT_UNROLL_M) * DGEMM_DEFAULT_UNROLL_M; TABLE_NAME.cgemm_p = ((TABLE_NAME.cgemm_p + CGEMM_DEFAULT_UNROLL_M - 1)/CGEMM_DEFAULT_UNROLL_M) * CGEMM_DEFAULT_UNROLL_M; @@ -1357,11 +1358,6 @@ static void init_parameter(void) { fprintf(stderr, "L2 = %8d DGEMM_P .. %d\n", l2, TABLE_NAME.dgemm_p); #endif - TABLE_NAME.shgemm_r = (((BUFFER_SIZE - - ((TABLE_NAME.shgemm_p * TABLE_NAME.shgemm_q * 4 + TABLE_NAME.offsetA - + TABLE_NAME.align) & ~TABLE_NAME.align) - ) / (TABLE_NAME.shgemm_q * 4) - 15) & ~15); - TABLE_NAME.sgemm_r = (((BUFFER_SIZE - ((TABLE_NAME.sgemm_p * TABLE_NAME.sgemm_q * 4 + TABLE_NAME.offsetA + TABLE_NAME.align) & ~TABLE_NAME.align)