diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr.c b/kernel/x86_64/sbgemm_kernel_16x16_spr.c
index 112adbe1c..dbfacd6ab 100644
--- a/kernel/x86_64/sbgemm_kernel_16x16_spr.c
+++ b/kernel/x86_64/sbgemm_kernel_16x16_spr.c
@@ -64,6 +64,28 @@ typedef struct {
     cfg.tile_colsb[7] = n * 4; \
     _tile_loadconfig(&cfg);
 
+/* CONFIG for handling k2 and odd tail at the same time
+ * tile0 -- A (m x 2k)
+ * tile1 -- A (m x 1)
+ * tile2 -- B (2k x n)
+ * tile3 -- B (1 x n)
+ * tile4 -- C (m x n)
+ */
+#define TCONF_TAIL(cfg, m, n, k2) \
+    memset(&cfg, 0, sizeof(tilecfg)); \
+    cfg.palette_id = 1; \
+    cfg.tile_rows[0] = m; \
+    cfg.tile_rows[1] = m; \
+    cfg.tile_rows[2] = k2>>1; \
+    cfg.tile_rows[3] = 1; \
+    cfg.tile_rows[4] = m; \
+    cfg.tile_colsb[0] = k2<<1; \
+    cfg.tile_colsb[1] = 4; \
+    cfg.tile_colsb[2] = n * 4; \
+    cfg.tile_colsb[3] = n * 4; \
+    cfg.tile_colsb[4] = n * 4; \
+    _tile_loadconfig(&cfg);
+
 #define T_A0 0
 #define T_A1 1
 #define T_B0 2
@@ -104,6 +126,7 @@ typedef struct {
 #define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4)
 
 #define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N)
+#define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N)
 
 #define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4)
 
@@ -275,86 +298,123 @@ tail_k:
     // process for k < 32
     BLASLONG k32 = k & ~31;
     BLASLONG k2 = k & ~1;
-    int remain_k2 = k2 - k32;
-    if (remain_k2 > 0) {
+    if (k32 != k) {
+        int remain_k2 = k2 - k32;
         m_count = m;
         ptr_a = A;
         ptr_c = C;
-        for (; m_count > 0; m_count -= 16) {
-            int tail_m = (m_count > 16) ? 16: m_count;
-            __mmask16 amask = (1UL << tail_m) - 1;
+        if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0)
+            for (; m_count > 0; m_count -= 16) {
+                int tail_m = (m_count > 16) ? 16: m_count;
+                __mmask16 amask = (1UL << tail_m) - 1;
 
-            ptr_a0 = ptr_a + tail_m * k32;
-            ptr_a += tail_m * k;
-            ptr_b = B;
-            ptr_c00 = ptr_c;
-            ptr_c += tail_m * ldc;
-            n_count = n;
-            lda = remain_k2;
-            ldb = 32;
-            TCONF(cfg, tail_m, 16, remain_k2);
-            for (; n_count > 15; n_count -= 16) {
-                ptr_b0 = ptr_b + 16 * k32;
-                LOAD_C(0, 0);
-                LOAD_A(0, x);
-                LOAD_B(x, 0);
-                MATMUL(0, 0);
-                STORE_C(0, 0);
-                ptr_b += 16 * k;
-                ptr_c00 += 16;
+                ptr_a0 = ptr_a + tail_m * k32;
+                ptr_a1 = ptr_a + tail_m * k2;
+                ptr_a += tail_m * k;
+                ptr_b = B;
+                ptr_c00 = ptr_c;
+                ptr_c += tail_m * ldc;
+                n_count = n;
+                lda = remain_k2;
+                ldb = 32;
+                TCONF_TAIL(cfg, tail_m, 16, remain_k2);
+                for (; n_count > 15; n_count -= 16) {
+                    ptr_b0 = ptr_b + 16 * k32;
+                    ptr_b1 = ptr_b + 16 * k2;
+                    LOAD_C(0, 0);
+                    LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
+                    LOAD_B(x, 0); LOAD_B_TAIL(x, 1);
+                    MATMUL(0, 0); MATMUL_TAIL(1, 1);
+                    STORE_C(0, 0);
+                    ptr_b += 16 * k;
+                    ptr_c00 += 16;
+                }
+                if (n_count > 0) {
+                    int tail_n = (n_count > 16) ? 16: n_count;
+                    __mmask16 bmask = (1UL << tail_n) - 1;
+                    ptr_b0 = ptr_b + tail_n * k32;
+                    ptr_b1 = ptr_b + tail_n * k2;
+                    ldb = 2 * tail_n;
+                    TCONF_TAIL(cfg, tail_m, tail_n, remain_k2);
+                    LOAD_C(0, 0);
+                    LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
+                    LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1);
+                    MATMUL(0, 0); MATMUL_TAIL(1, 1);
+                    STORE_C(0, 0);
+                }
             }
-            if (n_count > 0) {
-                int tail_n = (n_count > 16) ? 16: n_count;
-                __mmask16 bmask = (1UL << tail_n) - 1;
-                ptr_b0 = ptr_b + tail_n * k32;
-                ldb = 2 * tail_n;
-                TCONF(cfg, tail_m, tail_n, remain_k2);
-                LOAD_C(0, 0);
-                LOAD_A(0, x);
-                LOAD_B(x, 0);
-                MATMUL(0, 0);
-                STORE_C(0, 0);
+
+        } else if (remain_k2 > 0) { // k%32 = 2x
+            for (; m_count > 0; m_count -= 16) {
+                int tail_m = (m_count > 16) ? 16: m_count;
+
+                ptr_a0 = ptr_a + tail_m * k32;
+                ptr_a += tail_m * k;
+                ptr_b = B;
+                ptr_c00 = ptr_c;
+                ptr_c += tail_m * ldc;
+                n_count = n;
+                lda = remain_k2;
+                ldb = 32;
+                TCONF(cfg, tail_m, 16, remain_k2);
+                for (; n_count > 15; n_count -= 16) {
+                    ptr_b0 = ptr_b + 16 * k32;
+                    LOAD_C(0, 0);
+                    LOAD_A(0, x);
+                    LOAD_B(x, 0);
+                    MATMUL(0, 0);
+                    STORE_C(0, 0);
+                    ptr_b += 16 * k;
+                    ptr_c00 += 16;
+                }
+                if (n_count > 0) {
+                    int tail_n = (n_count > 16) ? 16: n_count;
+                    ptr_b0 = ptr_b + tail_n * k32;
+                    ldb = 2 * tail_n;
+                    TCONF(cfg, tail_m, tail_n, remain_k2);
+                    LOAD_C(0, 0);
+                    LOAD_A(0, x);
+                    LOAD_B(x, 0);
+                    MATMUL(0, 0);
+                    STORE_C(0, 0);
+                }
             }
-        }
+        } else { // k%32 = 1
+            for (; m_count > 0; m_count -= 16) {
+                int tail_m = (m_count > 16) ? 16: m_count;
+                __mmask16 amask = (1UL << tail_m) - 1;
+
+                ptr_a0 = ptr_a + tail_m * k2;
+                ptr_a += tail_m * k;
+                ptr_b = B;
+                ptr_c00 = ptr_c;
+                ptr_c += tail_m * ldc;
+                n_count = n;
+                TCONF(cfg, tail_m, 16, 2);
+                for (; n_count > 15; n_count -= 16) {
+                    ptr_b0 = ptr_b + 16 * k2;
+                    LOAD_C(0, 0);
+                    MASK_LOAD_A_TAIL(0, x);
+                    LOAD_B_TAIL(x, 0);
+                    MATMUL(0, 0);
+                    STORE_C(0, 0);
+                    ptr_b += 16 * k;
+                    ptr_c00 += 16;
+                }
+                if (n_count > 0) {
+                    int tail_n = (n_count > 16) ? 16: n_count;
+                    __mmask16 bmask = (1UL << tail_n) - 1;
+                    ptr_b0 = ptr_b + tail_n * k2;
+                    TCONF(cfg, tail_m, tail_n, 2);
+                    LOAD_C(0, 0);
+                    MASK_LOAD_A_TAIL(0, x);
+                    MASK_LOAD_B_TAIL(x, 0);
+                    MATMUL(0, 0);
+                    STORE_C(0, 0);
+                }
+            }
+        }
     }
-    if (k2 != k) {
-        m_count = m;
-        ptr_a = A;
-        ptr_c = C;
-        for (; m_count > 0; m_count -= 16) {
-            int tail_m = (m_count > 16) ? 16: m_count;
-            __mmask16 amask = (1UL << tail_m) - 1;
-
-            ptr_a0 = ptr_a + tail_m * k2;
-            ptr_a += tail_m * k;
-            ptr_b = B;
-            ptr_c00 = ptr_c;
-            ptr_c += tail_m * ldc;
-            n_count = n;
-            TCONF(cfg, tail_m, 16, 2);
-            for (; n_count > 15; n_count -= 16) {
-                ptr_b0 = ptr_b + 16 * k2;
-                LOAD_C(0, 0);
-                MASK_LOAD_A_TAIL(0, x);
-                LOAD_B_TAIL(x, 0);
-                MATMUL(0, 0);
-                STORE_C(0, 0);
-                ptr_b += 16 * k;
-                ptr_c00 += 16;
-            }
-            if (n_count > 0) {
-                int tail_n = (n_count > 16) ? 16: n_count;
-                __mmask16 bmask = (1UL << tail_n) - 1;
-                ptr_b0 = ptr_b + tail_n * k2;
-                TCONF(cfg, tail_m, tail_n, 2);
-                LOAD_C(0, 0);
-                MASK_LOAD_A_TAIL(0, x);
-                MASK_LOAD_B_TAIL(x, 0);
-                MATMUL(0, 0);
-                STORE_C(0, 0);
-            }
-        }
-
-    }
     return 0;
 }
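Note on the new control flow in tail_k: once the 32-deep main loop has consumed k32 = k & ~31 columns, the leftover k - k32 (at most 31) splits into an even part remain_k2 = k2 - k32 and, when k is odd, one final column. The patch replaces the old two-pass tail with a single three-way branch. Below is a minimal standalone sketch, not part of the patch; tail_case() is a hypothetical helper that only mirrors the k32/k2/remain_k2 arithmetic used to pick the branch:

#include <stdio.h>

/* Hypothetical illustration: which tail branch the rewritten tail_k
 * selects for a given k. Branch bodies are elided. */
static const char *tail_case(long k)
{
    long k32 = k & ~31L;        /* columns handled by the 32-deep main loop */
    long k2 = k & ~1L;          /* k rounded down to an even count */
    long remain_k2 = k2 - k32;  /* even leftover, 0..30 */

    if (k32 == k)
        return "no tail";
    if (remain_k2 > 0 && k2 != k)
        return "k%32 = 2x+1: fused TCONF_TAIL path (even part + odd column)";
    if (remain_k2 > 0)
        return "k%32 = 2x: plain TCONF path on the even leftover";
    return "k%32 = 1: masked odd-column path only";
}

int main(void)
{
    long ks[] = {64, 66, 67, 33, 35};
    for (unsigned i = 0; i < sizeof ks / sizeof ks[0]; i++)
        printf("k=%ld -> %s\n", ks[i], tail_case(ks[i]));
    return 0;
}

For k = 67, for example, this selects the fused branch: k32 = 64, remain_k2 = 2, odd column at index 66, and one TCONF_TAIL configuration covers both pieces.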
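As for what the fused branch computes per C tile: MATMUL(0, 0) accumulates the even leftover through T_A0/T_B0, and MATMUL_TAIL(1, 1) adds the rank-1 odd-column update through T_A1/T_B1 into the same accumulator tile T_C00 (which is why the macro hard-codes T_C00). A scalar reference, illustrative only: the kernel actually reads packed, pair-interleaved bf16 buffers, which this sketch flattens to row-major float, and tail_tile_ref is a hypothetical name:

/* Scalar reference of the fused tail update on one m x n tile:
 *   C += A[:, k32:k2] * B[k32:k2, :]   -- MATMUL(0, 0) via T_A0/T_B0
 *   C += A[:, k2]     * B[k2, :]       -- MATMUL_TAIL(1, 1) via T_A1/T_B1
 * bf16 operands are modeled as plain float to show only the arithmetic. */
static void tail_tile_ref(int m, int n, long k, long k32, long k2,
                          const float *A, long lda,  /* A: m x k, row-major */
                          const float *B, long ldb,  /* B: k x n, row-major */
                          float *C, long ldc)
{
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++) {
            float acc = 0.0f;
            for (long p = k32; p < k2; p++)  /* even leftover: MATMUL(0, 0) */
                acc += A[i * lda + p] * B[p * ldb + j];
            if (k2 != k)                     /* odd column: MATMUL_TAIL(1, 1) */
                acc += A[i * lda + k2] * B[k2 * ldb + j];
            C[i * ldc + j] += acc;           /* one LOAD_C/STORE_C round trip */
        }
}

The point of the fusion is visible here: the old code swept C twice (once for the even leftover, once again for the odd column), while the new branch accumulates both contributions between a single LOAD_C/STORE_C pair per tile.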