From a87793e03c4a073b533ceadafa54cf6c01a66f18 Mon Sep 17 00:00:00 2001 From: Rajalakshmi Srinivasaraghavan Date: Wed, 15 Apr 2020 09:09:50 -0500 Subject: [PATCH] Fix DYNAMIC_ARCH compilation errors --- common_param.h | 106 +++++++++++++++++++++++++++++--- kernel/generic/gemmkernel_2x2.c | 2 +- kernel/setparam-ref.c | 46 +++++++++++++- 3 files changed, 142 insertions(+), 12 deletions(-) diff --git a/common_param.h b/common_param.h index 6276f7f51..446d42452 100644 --- a/common_param.h +++ b/common_param.h @@ -41,15 +41,110 @@ #ifndef ASSEMBLER +#ifdef DYNAMIC_ARCH + #ifndef BFLOAT16 typedef unsigned short bfloat16; #endif -#ifdef DYNAMIC_ARCH typedef struct { int dtb_entries; int offsetA, offsetB, align; +#if 1 + int shgemm_p, shgemm_q, shgemm_r; + int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; + + float (*shamax_k) (BLASLONG, float *, BLASLONG); + float (*shamin_k) (BLASLONG, float *, BLASLONG); + float (*shmax_k) (BLASLONG, float *, BLASLONG); + float (*shmin_k) (BLASLONG, float *, BLASLONG); +BLASLONG (*ishamax_k)(BLASLONG, float *, BLASLONG); +BLASLONG (*ishamin_k)(BLASLONG, float *, BLASLONG); +BLASLONG (*ishmax_k) (BLASLONG, float *, BLASLONG); +BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG); + + float (*shnrm2_k) (BLASLONG, float *, BLASLONG); + float (*shasum_k) (BLASLONG, float *, BLASLONG); + float (*shsum_k) (BLASLONG, float *, BLASLONG); + int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); + float (*shdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); + double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); + + int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float); + + int (*shaxpy_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + int (*shscal_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + int (*shswap_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + + int (*shgemv_n) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*shgemv_t) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*shger_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + + int (*shsymv_L) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*shsymv_U) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + + int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); + int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); + + int (*shgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_oncopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*shgemm_otcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + + int (*shtrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrsm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrsm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + + int (*shtrsm_iunucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iunncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iutucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iutncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ilnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ilnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_iltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ounucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_ounncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_outucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_outncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_olnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_olnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_oltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + int (*shtrsm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); + + int (*shtrmm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrmm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrmm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + int (*shtrmm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); + + int (*shtrmm_iunucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iunncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iutucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iutncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ilnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ilnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_iltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ounucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_ounncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_outucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_outncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_olnucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_olnncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_oltucopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shtrmm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + + int (*shsymm_iutcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shsymm_iltcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shsymm_outcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + int (*shsymm_oltcopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, BLASLONG, float *); + + int (*shneg_tcopy) (BLASLONG, BLASLONG, float *, BLASLONG, float *); + int (*shlaswp_ncopy) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG, blasint *, float *); + +#endif int sgemm_p, sgemm_q, sgemm_r; int sgemm_unroll_m, sgemm_unroll_n, sgemm_unroll_mn; @@ -87,15 +182,6 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); - int shgemm_p, shgemm_q, shgemm_r; - int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; - int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); - int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); - - int (*shgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); - int (*shgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); - int (*shgemm_oncopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); - int (*shgemm_otcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); int (*sgemm_incopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sgemm_itcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); diff --git a/kernel/generic/gemmkernel_2x2.c b/kernel/generic/gemmkernel_2x2.c index 26a88db6d..cc7bb8e48 100644 --- a/kernel/generic/gemmkernel_2x2.c +++ b/kernel/generic/gemmkernel_2x2.c @@ -1,6 +1,6 @@ #include "common.h" #if defined(HALF) && defined(HALFCONVERSION) -float +static float bfloat16tof32 (bfloat16 f16) { float result = 0; diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index 12d038901..79cd151f6 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -60,6 +60,15 @@ gotoblas_t TABLE_NAME = { #else MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N), #endif + + samax_kTS, samin_kTS, smax_kTS, smin_kTS, + isamax_kTS, isamin_kTS, ismax_kTS, ismin_kTS, + snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, sdot_kTS, + dsdot_kTS, + srot_kTS, saxpy_kTS, sscal_kTS, sswap_kTS, + sgemv_nTS, sgemv_tTS, sger_kTS, + ssymv_LTS, ssymv_UTS, + shgemm_kernelTS, shgemm_betaTS, #if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N shgemm_incopyTS, shgemm_itcopyTS, @@ -67,7 +76,42 @@ gotoblas_t TABLE_NAME = { shgemm_oncopyTS, shgemm_otcopyTS, #endif shgemm_oncopyTS, shgemm_otcopyTS, - sgemm_kernelTS, sgemm_betaTS, + + strsm_kernel_LNTS, strsm_kernel_LTTS, strsm_kernel_RNTS, strsm_kernel_RTTS, +#if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N + strsm_iunucopyTS, strsm_iunncopyTS, strsm_iutucopyTS, strsm_iutncopyTS, + strsm_ilnucopyTS, strsm_ilnncopyTS, strsm_iltucopyTS, strsm_iltncopyTS, +#else + strsm_ounucopyTS, strsm_ounncopyTS, strsm_outucopyTS, strsm_outncopyTS, + strsm_olnucopyTS, strsm_olnncopyTS, strsm_oltucopyTS, strsm_oltncopyTS, +#endif + strsm_ounucopyTS, strsm_ounncopyTS, strsm_outucopyTS, strsm_outncopyTS, + strsm_olnucopyTS, strsm_olnncopyTS, strsm_oltucopyTS, strsm_oltncopyTS, + strmm_kernel_RNTS, strmm_kernel_RTTS, strmm_kernel_LNTS, strmm_kernel_LTTS, +#if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N + strmm_iunucopyTS, strmm_iunncopyTS, strmm_iutucopyTS, strmm_iutncopyTS, + strmm_ilnucopyTS, strmm_ilnncopyTS, strmm_iltucopyTS, strmm_iltncopyTS, +#else + strmm_ounucopyTS, strmm_ounncopyTS, strmm_outucopyTS, strmm_outncopyTS, + strmm_olnucopyTS, strmm_olnncopyTS, strmm_oltucopyTS, strmm_oltncopyTS, +#endif + strmm_ounucopyTS, strmm_ounncopyTS, strmm_outucopyTS, strmm_outncopyTS, + strmm_olnucopyTS, strmm_olnncopyTS, strmm_oltucopyTS, strmm_oltncopyTS, +#if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N + ssymm_iutcopyTS, ssymm_iltcopyTS, +#else + ssymm_outcopyTS, ssymm_oltcopyTS, +#endif + ssymm_outcopyTS, ssymm_oltcopyTS, + +#ifndef NO_LAPACK + sneg_tcopyTS, slaswp_ncopyTS, +#else + NULL,NULL, +#endif + + + 0, 0, 0, SGEMM_DEFAULT_UNROLL_M, SGEMM_DEFAULT_UNROLL_N, #ifdef SGEMM_DEFAULT_UNROLL_MN SGEMM_DEFAULT_UNROLL_MN,