From f2485352a603f11bd3c7b45788c1c80449ba76eb Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 16 Sep 2021 01:04:01 -0700 Subject: [PATCH] sbgemm: spr: only load A once in tail_k handling --- kernel/x86_64/sbgemm_kernel_16x16_spr.c | 62 ++++++++++++++----------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr.c b/kernel/x86_64/sbgemm_kernel_16x16_spr.c index dbfacd6ab..b34035896 100644 --- a/kernel/x86_64/sbgemm_kernel_16x16_spr.c +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr.c @@ -317,17 +317,19 @@ tail_k: n_count = n; lda = remain_k2; ldb = 32; - TCONF_TAIL(cfg, tail_m, 16, remain_k2); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k32; - ptr_b1 = ptr_b + 16 * k2; - LOAD_C(0, 0); + if (n_count > 15) { + TCONF_TAIL(cfg, tail_m, 16, remain_k2); LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); - LOAD_B(x, 0); LOAD_B_TAIL(x, 1); - MATMUL(0, 0); MATMUL_TAIL(1, 1); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k32; + ptr_b1 = ptr_b + 16 * k2; + LOAD_C(0, 0); + LOAD_B(x, 0); LOAD_B_TAIL(x, 1); + MATMUL(0, 0); MATMUL_TAIL(1, 1); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } } if (n_count > 0) { int tail_n = (n_count > 16) ? 16: n_count; @@ -356,16 +358,18 @@ tail_k: n_count = n; lda = remain_k2; ldb = 32; - TCONF(cfg, tail_m, 16, remain_k2); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k32; - LOAD_C(0, 0); + if (n_count > 15) { + TCONF(cfg, tail_m, 16, remain_k2); LOAD_A(0, x); - LOAD_B(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k32; + LOAD_C(0, 0); + LOAD_B(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } } if (n_count > 0) { int tail_n = (n_count > 16) ? 16: n_count; @@ -390,16 +394,18 @@ tail_k: ptr_c00 = ptr_c; ptr_c += tail_m * ldc; n_count = n; - TCONF(cfg, tail_m, 16, 2); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k2; - LOAD_C(0, 0); + if (n_count > 15) { + TCONF(cfg, tail_m, 16, 2); MASK_LOAD_A_TAIL(0, x); - LOAD_B_TAIL(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k2; + LOAD_C(0, 0); + LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } } if (n_count > 0) { int tail_n = (n_count > 16) ? 16: n_count;