diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr.c b/kernel/x86_64/sbgemm_kernel_16x16_spr.c index b7b4e36a3..112adbe1c 100644 --- a/kernel/x86_64/sbgemm_kernel_16x16_spr.c +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr.c @@ -127,10 +127,14 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA BLASLONG m_count = m; BLASLONG n_count, k_count; + IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64))); IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64))); tilecfg cfg; + if (k < 32) + goto tail_k; + for (; m_count > 31; m_count -= 32) { ptr_b = B; @@ -140,6 +144,7 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA ptr_c11 = ptr_c10 + 16; ptr_c += 32 * ldc; n_count = n; + TCONF(cfg, 16, 16, 32); for (; n_count > 31; n_count -= 32) { ptr_a0 = ptr_a; ptr_a1 = ptr_a + 16 * k; @@ -150,7 +155,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA lda = 32; ldb = 32; - TCONF(cfg, 16, 16, 32); LOAD_C(0, 0); LOAD_C(0, 1); LOAD_C(1, 0); LOAD_C(1, 1); k_count = k; @@ -167,47 +171,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA } STORE_C(0, 0); STORE_C(0, 1); STORE_C(1, 0); STORE_C(1, 1); - if (k_count > 1) { - /* still have more than 2*k */ - int remain_k2 = k_count & ~1; - k_count -= remain_k2; - lda = remain_k2; - TCONF(cfg, 16, 16, remain_k2); - /* reconfig will clear all tiles, - * need to store/load again - */ - LOAD_C(0, 0); LOAD_C(0, 1); - LOAD_C(1, 0); LOAD_C(1, 1); - - LOAD_A(0, x); LOAD_A(1, x); - ptr_a0 += 16 * remain_k2; - ptr_a1 += 16 * remain_k2; - LOAD_B(x, 0); LOAD_B(x, 1); - ptr_b0 += 16 * remain_k2; - ptr_b1 += 16 * remain_k2; - - MATMUL(0, 0); MATMUL(0, 1); - MATMUL(1, 0); MATMUL(1, 1); - - STORE_C(0, 0); STORE_C(0, 1); - STORE_C(1, 0); STORE_C(1, 1); - } - if (k_count > 0) { - /* still have odd tail k, need to transform into 2*k */ - TCONF(cfg, 16, 16, 2); - - LOAD_C(0, 0); LOAD_C(0, 1); - LOAD_C(1, 0); LOAD_C(1, 1); - - LOAD_A_TAIL(0, x); LOAD_A_TAIL(1, x); - LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1); - - MATMUL(0, 0); MATMUL(0, 1); - MATMUL(1, 0); MATMUL(1, 1); - - STORE_C(0, 0); STORE_C(0, 1); - STORE_C(1, 0); STORE_C(1, 1); - } ptr_c00 += 32; ptr_c01 += 32; ptr_c10 += 32; @@ -240,45 +203,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA } STORE_C(0, 0); STORE_C(1, 0); - if (k_count > 1) { - /* still have more than 2*k */ - int remain_k2 = k_count & ~1; - k_count -= remain_k2; - lda = remain_k2; - TCONF(cfg, 16, tail_n, remain_k2); - /* reconfig will clear all tiles, - * need to store/load again - */ - LOAD_C(0, 0); - LOAD_C(1, 0); - - LOAD_A(0, x); LOAD_A(1, x); - ptr_a0 += 16 * remain_k2; - ptr_a1 += 16 * remain_k2; - LOAD_B(x, 0); - ptr_b0 += tail_n * remain_k2; - - MATMUL(0, 0); - MATMUL(1, 0); - - STORE_C(0, 0); - STORE_C(1, 0); - } - if (k_count > 0) { - /* still have odd tail k, need to transform into 2*k */ - TCONF(cfg, 16, tail_n, 2); - - LOAD_C(0, 0); - LOAD_C(1, 0); - - LOAD_A_TAIL(0, x); LOAD_A_TAIL(1, x); - MASK_LOAD_B_TAIL(x, 0); - MATMUL(0, 0); - MATMUL(1, 0); - - STORE_C(0, 0); - STORE_C(1, 0); - } ptr_c00 += tail_n; ptr_c10 += tail_n; } @@ -295,6 +219,7 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA ptr_c01 = ptr_c00 + 16; ptr_c += tail_m * ldc; n_count = n; + TCONF(cfg, tail_m, 16, 32); for (; n_count > 31; n_count -= 32) { ptr_a0 = ptr_a; @@ -304,7 +229,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA lda = 32; ldb = 32; - TCONF(cfg, tail_m, 16, 32); LOAD_C(0, 0); LOAD_C(0, 1); k_count = k; for (; k_count > 31; k_count -= 32) { @@ -317,40 +241,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA MATMUL(0, 0); MATMUL(0, 1); } STORE_C(0, 0); STORE_C(0, 1); - if (k_count > 1) { - /* still have more than 2*k */ - int remain_k2 = k_count & ~1; - k_count -= remain_k2; - lda = remain_k2; - TCONF(cfg, tail_m, 16, remain_k2); - /* reconfig will clear all tiles, - * need to store/load again - */ - LOAD_C(0, 0); LOAD_C(0, 1); - - LOAD_A(0, x); - ptr_a0 += tail_m * remain_k2; - LOAD_B(x, 0); LOAD_B(x, 1); - ptr_b0 += 16 * remain_k2; - ptr_b1 += 16 * remain_k2; - - MATMUL(0, 0); MATMUL(0, 1); - - STORE_C(0, 0); STORE_C(0, 1); - } - if (k_count > 0) { - /* still have odd tail k, need to transform into 2*k */ - TCONF(cfg, tail_m, 16, 2); - - LOAD_C(0, 0); LOAD_C(0, 1); - - MASK_LOAD_A_TAIL(0, x); - LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1); - - MATMUL(0, 0); MATMUL(0, 1); - - STORE_C(0, 0); STORE_C(0, 1); - } ptr_c00 += 32; ptr_c01 += 32; } @@ -376,41 +266,95 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA MATMUL(0, 0); } STORE_C(0, 0); - if (k_count > 1) { - /* still have more than 2*k */ - int remain_k2 = k_count & ~1; - k_count -= remain_k2; - lda = remain_k2; - TCONF(cfg, tail_m, tail_n, remain_k2); - /* reconfig will clear all tiles, - * need to store/load again - */ - LOAD_C(0, 0); - - LOAD_A(0, x); - ptr_a0 += tail_m * remain_k2; - LOAD_B(x, 0); - ptr_b0 += tail_n * remain_k2; - - MATMUL(0, 0); - - STORE_C(0, 0); - } - if (k_count > 0) { - /* still have odd tail k, need to transform into 2*k */ - TCONF(cfg, tail_m, tail_n, 2); - - LOAD_C(0, 0); - - MASK_LOAD_A_TAIL(0, x); - MASK_LOAD_B_TAIL(x, 0); - MATMUL(0, 0); - - STORE_C(0, 0); - } ptr_c00 += tail_n; } ptr_a += tail_m * k; } + +tail_k: + // process for k < 32 + BLASLONG k32 = k & ~31; + BLASLONG k2 = k & ~1; + int remain_k2 = k2 - k32; + if (remain_k2 > 0) { + m_count = m; + ptr_a = A; + ptr_c = C; + for (; m_count > 0; m_count -= 16) { + int tail_m = (m_count > 16) ? 16: m_count; + __mmask16 amask = (1UL << tail_m) - 1; + + ptr_a0 = ptr_a + tail_m * k32; + ptr_a += tail_m * k; + ptr_b = B; + ptr_c00 = ptr_c; + ptr_c += tail_m * ldc; + n_count = n; + lda = remain_k2; + ldb = 32; + TCONF(cfg, tail_m, 16, remain_k2); + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k32; + LOAD_C(0, 0); + LOAD_A(0, x); + LOAD_B(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } + if (n_count > 0) { + int tail_n = (n_count > 16) ? 16: n_count; + __mmask16 bmask = (1UL << tail_n) - 1; + ptr_b0 = ptr_b + tail_n * k32; + ldb = 2 * tail_n; + TCONF(cfg, tail_m, tail_n, remain_k2); + LOAD_C(0, 0); + LOAD_A(0, x); + LOAD_B(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + } + } + } + if (k2 != k) { + m_count = m; + ptr_a = A; + ptr_c = C; + for (; m_count > 0; m_count -= 16) { + int tail_m = (m_count > 16) ? 16: m_count; + __mmask16 amask = (1UL << tail_m) - 1; + + ptr_a0 = ptr_a + tail_m * k2; + ptr_a += tail_m * k; + ptr_b = B; + ptr_c00 = ptr_c; + ptr_c += tail_m * ldc; + n_count = n; + TCONF(cfg, tail_m, 16, 2); + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k2; + LOAD_C(0, 0); + MASK_LOAD_A_TAIL(0, x); + LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } + if (n_count > 0) { + int tail_n = (n_count > 16) ? 16: n_count; + __mmask16 bmask = (1UL << tail_n) - 1; + ptr_b0 = ptr_b + tail_n * k2; + TCONF(cfg, tail_m, tail_n, 2); + LOAD_C(0, 0); + MASK_LOAD_A_TAIL(0, x); + MASK_LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + } + } + + } return 0; }