diff --git a/kernel/x86_64/sbgemm_kernel_16x4_cooperlake.c b/kernel/x86_64/sbgemm_kernel_16x4_cooperlake.c index 4c1f50650..0280b441e 100644 --- a/kernel/x86_64/sbgemm_kernel_16x4_cooperlake.c +++ b/kernel/x86_64/sbgemm_kernel_16x4_cooperlake.c @@ -201,7 +201,31 @@ int CNAME (BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * DECLARE_B_PAIR(); DECLARE_RESULT_4X(0, 0, 0); DECLARE_RESULT_4X(0, 0, 1); DECLARE_RESULT_4X(0, 0, 2); DECLARE_RESULT_4X(0, 1, 0); DECLARE_RESULT_4X(0, 1, 1); DECLARE_RESULT_4X(0, 1, 2); - for (k_count = k; k_count > 1; k_count -=2) { + k_count = k; + for (; k_count > 3; k_count -=4) { + LOAD_A_PAIR(0); + ptr_a0 += 16 * 2; + BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); + BROADCAST_B_PAIR(0, 1); MATMUL_4X(0, 0, 1); + BROADCAST_B_PAIR(0, 2); MATMUL_4X(0, 0, 2); + ptr_b0 += 4 * 2; + BROADCAST_B_PAIR(1, 0); MATMUL_4X(0, 1, 0); + BROADCAST_B_PAIR(1, 1); MATMUL_4X(0, 1, 1); + BROADCAST_B_PAIR(1, 2); MATMUL_4X(0, 1, 2); + ptr_b1 += 4 * 2; + + LOAD_A_PAIR(0); + ptr_a0 += 16 * 2; + BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); + BROADCAST_B_PAIR(0, 1); MATMUL_4X(0, 0, 1); + BROADCAST_B_PAIR(0, 2); MATMUL_4X(0, 0, 2); + ptr_b0 += 4 * 2; + BROADCAST_B_PAIR(1, 0); MATMUL_4X(0, 1, 0); + BROADCAST_B_PAIR(1, 1); MATMUL_4X(0, 1, 1); + BROADCAST_B_PAIR(1, 2); MATMUL_4X(0, 1, 2); + ptr_b1 += 4 * 2; + } + for (; k_count > 1; k_count -=2) { LOAD_A_PAIR(0); ptr_a0 += 16 * 2; BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0);