diff --git a/common_param.h b/common_param.h index 8b39ca2fc..c082d248e 100644 --- a/common_param.h +++ b/common_param.h @@ -53,6 +53,7 @@ typedef struct { int sbgemm_p, sbgemm_q, sbgemm_r; int sbgemm_unroll_m, sbgemm_unroll_n, sbgemm_unroll_mn; int sbgemm_align_k; + int need_amxtile_permission; // 0 default, 1 for device support amx. void (*sbstobf16_k) (BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG); void (*sbdtobf16_k) (BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG); diff --git a/interface/gemm.c b/interface/gemm.c index 71cc77a1b..285b99eb9 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -154,6 +154,23 @@ static size_t zgemm_small_kernel_b0[] = { #endif #endif +#if defined(__linux__) && defined(BFLOAT16) +#define XFEATURE_XTILEDATA 18 +#define ARCH_REQ_XCOMP_PERM 0x1023 +static int openblas_amxtile_permission = 0; +static int init_amxtile_permission() { + long status = + syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); + if (status != 0) { + fprintf(stderr, "XTILEDATA permission not granted in your device(Linux, " + "Intel Sapphier Rapids), skip sbgemm calculation\n"); + return -1; + } + openblas_amxtile_permission = 1; + return 0; +} +#endif + #ifndef CBLAS void NAME(char *TRANSA, char *TRANSB, @@ -455,6 +472,20 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS #endif +#if defined(__linux__) && defined(BFLOAT16) +#if defined(DYNAMIC_ARCH) + if (gotoblas->need_amxtile_permission && + openblas_amxtile_permission == 0 && init_amxtile_permission() == -1) { + return; + } +#endif +#if !defined(DYNAMIC_ARCH) && defined(SAPPHIRERAPIDS) + if (openblas_amxtile_permission == 0 && init_amxtile_permission() == -1) { + return; + } +#endif +#endif // defined(__linux__) && defined(BFLOAT16) + if ((args.m == 0) || (args.n == 0)) return; #if 0 diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index 7832c0a87..4c361f155 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -66,6 +66,7 @@ gotoblas_t TABLE_NAME = { #endif SBGEMM_ALIGN_K, + 0, // need_amxtile_permission sbstobf16_kTS, sbdtobf16_kTS, sbf16tos_kTS, dbf16tod_kTS, @@ -1809,6 +1810,12 @@ static void init_parameter(void) { #endif +#ifdef SAPPHIRERAPIDS +#if (BUILD_BFLOAT16 == 1) + TABLE_NAME.need_amxtile_permission = 1; +#endif +#endif + #if BUILD_COMPLEX==1 #ifdef CGEMM3M_DEFAULT_P TABLE_NAME.cgemm3m_p = CGEMM3M_DEFAULT_P;