Fix spr sbgemm error
This commit is contained in:
parent
941a34bb96
commit
f249ccb741
|
@ -1479,6 +1479,8 @@ int get_cpuname(void){
|
||||||
else
|
else
|
||||||
return CPUTYPE_NEHALEM;
|
return CPUTYPE_NEHALEM;
|
||||||
case 15: // Sapphire Rapids
|
case 15: // Sapphire Rapids
|
||||||
|
if(support_amx_bf16())
|
||||||
|
return CPUTYPE_SAPPHIRERAPIDS;
|
||||||
if(support_avx512_bf16())
|
if(support_avx512_bf16())
|
||||||
return CPUTYPE_COOPERLAKE;
|
return CPUTYPE_COOPERLAKE;
|
||||||
if(support_avx512())
|
if(support_avx512())
|
||||||
|
@ -1845,7 +1847,8 @@ static char *cpuname[] = {
|
||||||
"ZEN",
|
"ZEN",
|
||||||
"SKYLAKEX",
|
"SKYLAKEX",
|
||||||
"DHYANA",
|
"DHYANA",
|
||||||
"COOPERLAKE"
|
"COOPERLAKE",
|
||||||
|
"SAPPHIRERAPIDS",
|
||||||
};
|
};
|
||||||
|
|
||||||
static char *lowercpuname[] = {
|
static char *lowercpuname[] = {
|
||||||
|
@ -1902,7 +1905,8 @@ static char *lowercpuname[] = {
|
||||||
"zen",
|
"zen",
|
||||||
"skylakex",
|
"skylakex",
|
||||||
"dhyana",
|
"dhyana",
|
||||||
"cooperlake"
|
"cooperlake",
|
||||||
|
"sapphirerapids",
|
||||||
};
|
};
|
||||||
|
|
||||||
static char *corename[] = {
|
static char *corename[] = {
|
||||||
|
|
|
@ -97,33 +97,32 @@ typedef struct {
|
||||||
#define T_C10 6
|
#define T_C10 6
|
||||||
#define T_C11 7
|
#define T_C11 7
|
||||||
|
|
||||||
// FIXME: gcc11 seem have problem in tile load/store address calc,
|
|
||||||
// need to multiply with element size (2 or 4) here.
|
|
||||||
#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2)
|
#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2)
|
||||||
#define LOAD_A_TAIL(M, N) {\
|
#define LOAD_A_TAIL(M, N) {\
|
||||||
__m256i ymm = _mm256_loadu_epi16(ptr_a##M); \
|
__m256i ymm = _mm256_loadu_epi16(ptr_a##M); \
|
||||||
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \
|
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \
|
||||||
_mm512_storeu_epi16(tail_a + 16 * M, zmm); \
|
_mm512_storeu_epi16(tail_a + 16 * M, zmm); \
|
||||||
_tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
|
_tile_loadd(T_A##M, tail_a + 16 * M, 2 * 2); \
|
||||||
}
|
}
|
||||||
#define MASK_LOAD_A_TAIL(M, N) {\
|
#define MASK_LOAD_A_TAIL(M, N) {\
|
||||||
__m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \
|
__m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \
|
||||||
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \
|
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \
|
||||||
_mm512_storeu_epi16(tail_a + 16 * M, zmm); \
|
_mm512_storeu_epi16(tail_a + 16 * M, zmm); \
|
||||||
_tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
|
_tile_loadd(T_A##M, tail_a + 16 * M, 2 * 2); \
|
||||||
}
|
}
|
||||||
#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2)
|
#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2)
|
||||||
#define LOAD_B_TAIL(M, N) {\
|
#define LOAD_B_TAIL(M, N) {\
|
||||||
__m256i ymm = _mm256_loadu_epi16(ptr_b##N); \
|
__m256i ymm = _mm256_loadu_epi16(ptr_b##N); \
|
||||||
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \
|
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \
|
||||||
_mm512_storeu_epi16(tail_b + 16 * N, zmm); \
|
_mm512_storeu_epi16(tail_b + 16 * N, zmm); \
|
||||||
_tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \
|
_tile_loadd(T_B##N, tail_b + 16 * N, 2 * 2); \
|
||||||
}
|
}
|
||||||
#define MASK_LOAD_B_TAIL(M, N) {\
|
#define MASK_LOAD_B_TAIL(M, N) {\
|
||||||
__m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \
|
__m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \
|
||||||
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \
|
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \
|
||||||
_mm512_storeu_epi16(tail_b + 16 * N, zmm); \
|
_mm512_storeu_epi16(tail_b + 16 * N, zmm); \
|
||||||
_tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \
|
_tile_loadd(T_B##N, tail_b + 16 * N, 2 * 2); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N)
|
#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N)
|
||||||
|
|
Loading…
Reference in New Issue