From f249ccb741653b8f72783bdd20912fc2baef50ca Mon Sep 17 00:00:00 2001 From: Honglin Zhu Date: Thu, 18 May 2023 23:51:37 +0800 Subject: [PATCH] Fix spr sbgemm error --- cpuid_x86.c | 8 ++++++-- kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c | 11 +++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/cpuid_x86.c b/cpuid_x86.c index 69cbba90e..c2486e380 100644 --- a/cpuid_x86.c +++ b/cpuid_x86.c @@ -1479,6 +1479,8 @@ int get_cpuname(void){ else return CPUTYPE_NEHALEM; case 15: // Sapphire Rapids + if(support_amx_bf16()) + return CPUTYPE_SAPPHIRERAPIDS; if(support_avx512_bf16()) return CPUTYPE_COOPERLAKE; if(support_avx512()) @@ -1845,7 +1847,8 @@ static char *cpuname[] = { "ZEN", "SKYLAKEX", "DHYANA", - "COOPERLAKE" + "COOPERLAKE", + "SAPPHIRERAPIDS", }; static char *lowercpuname[] = { @@ -1902,7 +1905,8 @@ static char *lowercpuname[] = { "zen", "skylakex", "dhyana", - "cooperlake" + "cooperlake", + "sapphirerapids", }; static char *corename[] = { diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c b/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c index 90e0a32c7..5ee3c8532 100644 --- a/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c @@ -97,33 +97,32 @@ typedef struct { #define T_C10 6 #define T_C11 7 -// FIXME: gcc11 seem have problem in tile load/store address calc, -// need to multiply with element size (2 or 4) here. + #define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2) #define LOAD_A_TAIL(M, N) {\ __m256i ymm = _mm256_loadu_epi16(ptr_a##M); \ __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ - _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ + _tile_loadd(T_A##M, tail_a + 16 * M, 2 * 2); \ } #define MASK_LOAD_A_TAIL(M, N) {\ __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \ __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ - _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ + _tile_loadd(T_A##M, tail_a + 16 * M, 2 * 2); \ } #define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2) #define LOAD_B_TAIL(M, N) {\ __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \ __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ _mm512_storeu_epi16(tail_b + 16 * N, zmm); \ - _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ + _tile_loadd(T_B##N, tail_b + 16 * N, 2 * 2); \ } #define MASK_LOAD_B_TAIL(M, N) {\ __m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \ __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ _mm512_storeu_epi16(tail_b + 16 * N, zmm); \ - _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ + _tile_loadd(T_B##N, tail_b + 16 * N, 2 * 2); \ } #define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N)