From 4989e039a5b37de140b41df9a042720599336e29 Mon Sep 17 00:00:00 2001
From: Honglin Zhu
Date: Thu, 27 Oct 2022 14:10:26 +0800
Subject: [PATCH] Define SBGEMM_ALIGN_K for DYNAMIC_ARCH build

Replace the align_k member at the end of gotoblas_t with sbgemm_align_k
inside the BUILD_BFLOAT16 section, and fill it from SBGEMM_ALIGN_K in the
table initializer instead of assigning it in init_parameter().

SBGEMM_ALIGN_K defaults to 1 in param.h and is redefined to 4 for
NEOVERSEN2. It must be a power of two because it is used as a bit mask
when rounding min_l up to pad_min_l. HALF builds without DYNAMIC_ARCH now
use the macro directly instead of a NEOVERSEN2-only hard-coded 4.
---
NOTE(review): this patch was rebuilt from a copy whose line breaks and
indentation had been collapsed. Blank context lines were restored from the
hunk line counts; intra-line whitespace is conventional, not verified.
Confirm against the tree, or apply with --ignore-whitespace.

 common_param.h                |  2 +-
 driver/level3/level3.c        | 11 +++++------
 driver/level3/level3_thread.c | 10 +++++-----
 kernel/setparam-ref.c         |  8 ++------
 param.h                       |  5 +++++
 5 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/common_param.h b/common_param.h
index 091840343..e14ef2782 100644
--- a/common_param.h
+++ b/common_param.h
@@ -50,6 +50,7 @@ typedef struct {
 #ifdef BUILD_BFLOAT16
   int sbgemm_p, sbgemm_q, sbgemm_r;
   int sbgemm_unroll_m, sbgemm_unroll_n, sbgemm_unroll_mn;
+  int sbgemm_align_k;
 
   void (*sbstobf16_k) (BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG);
   void (*sbdtobf16_k) (BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG);
@@ -1193,7 +1194,6 @@ BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG);
 #ifdef BUILD_COMPLEX16
   int (*zgeadd_k) (BLASLONG, BLASLONG, double, double, double *, BLASLONG, double, double, double *, BLASLONG);
 #endif
-  int align_k; // must be 2^n
 } gotoblas_t;
 
 extern gotoblas_t *gotoblas;
diff --git a/driver/level3/level3.c b/driver/level3/level3.c
index d3281345d..b7328876b 100644
--- a/driver/level3/level3.c
+++ b/driver/level3/level3.c
@@ -305,13 +305,12 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
       }
 
       BLASLONG pad_min_l = min_l;
-
-#if defined(HALF) && defined(DYNAMIC_ARCH)
-      pad_min_l = (min_l + gotoblas->align_k - 1) & ~(gotoblas->align_k-1);
+#if defined(HALF)
+#if defined(DYNAMIC_ARCH)
+      pad_min_l = (min_l + gotoblas->sbgemm_align_k - 1) & ~(gotoblas->sbgemm_align_k-1);
+#else
+      pad_min_l = (min_l + SBGEMM_ALIGN_K - 1) & ~(SBGEMM_ALIGN_K - 1);
 #endif
-
-#if defined(HALF) && !defined(DYNAMIC_ARCH) && defined(NEOVERSEN2)
-      pad_min_l = (min_l + 3) & ~3;
 #endif
 
       /* First, we have to move data A to L2 cache */
diff --git a/driver/level3/level3_thread.c b/driver/level3/level3_thread.c
index 95c8e6d19..02b60b50d 100644
--- a/driver/level3/level3_thread.c
+++ b/driver/level3/level3_thread.c
@@ -327,12 +327,12 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
 
     BLASLONG pad_min_l = min_l;
 
-#if defined(HALF) && defined(DYNAMIC_ARCH)
-    pad_min_l = (min_l + gotoblas->align_k - 1) & ~(gotoblas->align_k-1);
+#if defined(HALF)
+#if defined(DYNAMIC_ARCH)
+    pad_min_l = (min_l + gotoblas->sbgemm_align_k - 1) & ~(gotoblas->sbgemm_align_k-1);
+#else
+    pad_min_l = (min_l + SBGEMM_ALIGN_K - 1) & ~(SBGEMM_ALIGN_K - 1);
 #endif
-
-#if defined(HALF) && !defined(DYNAMIC_ARCH) && defined(NEOVERSEN2)
-    pad_min_l = (min_l + 3) & ~3;
 #endif
 
     /* Determine step size in m
diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c
index effcf8965..16d19af1b 100644
--- a/kernel/setparam-ref.c
+++ b/kernel/setparam-ref.c
@@ -62,6 +62,8 @@ gotoblas_t TABLE_NAME = {
   MAX(SBGEMM_DEFAULT_UNROLL_M, SBGEMM_DEFAULT_UNROLL_N),
 #endif
 
+  SBGEMM_ALIGN_K,
+
   sbstobf16_kTS, sbdtobf16_kTS, sbf16tos_kTS, dbf16tod_kTS,
 
   samax_kTS, samin_kTS, smax_kTS, smin_kTS,
@@ -973,12 +975,6 @@ static void init_parameter(void) {
   TABLE_NAME.xgemm3m_r = TABLE_NAME.qgemm_r;
 #endif
 #endif
-
-#if defined(NEOVERSEN2) && BUILD_BFLOAT16 == 1
-  TABLE_NAME.align_k = 4;
-#else
-  TABLE_NAME.align_k = 1;
-#endif
 }
 #else // (ARCH_ARM64)
 
diff --git a/param.h b/param.h
index b9b9a55e8..514b13a3a 100644
--- a/param.h
+++ b/param.h
@@ -79,6 +79,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define SBGEMM_DEFAULT_P 256
 #define SBGEMM_DEFAULT_R 256
 #define SBGEMM_DEFAULT_Q 256
+#define SBGEMM_ALIGN_K 1 // must be 2^x
+
 #ifdef OPTERON
 
 #define SNUMOPT 4
@@ -3394,6 +3396,9 @@ is a big desktop or server with abundant cache rather than a phone or embedded d
 
 #elif defined(NEOVERSEN2)
 
+#undef SBGEMM_ALIGN_K
+#define SBGEMM_ALIGN_K 4
+
 #undef SBGEMM_DEFAULT_UNROLL_M
 #undef SBGEMM_DEFAULT_UNROLL_N
 #define SBGEMM_DEFAULT_UNROLL_M 8