From bb1c4fa5bdf93724075ed400e3ff5bbdabd0b31a Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 18 Aug 2021 21:17:08 +0800 Subject: [PATCH] sbgemm: cooperlake: prefetch A & B --- kernel/x86_64/sbgemm_kernel_16x4_cooperlake.c | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/kernel/x86_64/sbgemm_kernel_16x4_cooperlake.c b/kernel/x86_64/sbgemm_kernel_16x4_cooperlake.c index 0280b441e..7af51b6d8 100644 --- a/kernel/x86_64/sbgemm_kernel_16x4_cooperlake.c +++ b/kernel/x86_64/sbgemm_kernel_16x4_cooperlake.c @@ -64,6 +64,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define DECLARE_B_PAIR() \ __m512i B_lo; __m512i B_hi; +#define PREFETCH_B_STEP 32 +#define PREFETCH_B(Bx, By) \ + if (By == 0) asm("prefetcht0 %c1(%0)": : "r"(ptr_b##Bx), "n"(PREFETCH_B_STEP * 2)); \ + else asm("prefetcht0 %c3(%0, %1, %c2)": : "r"(ptr_b##Bx), "r"(n_blksize), "n"(By*2), "n"(PREFETCH_B_STEP * 2)) + #define BROADCAST_B_PAIR(Bx, By) \ BROADCAST64(ptr_b##Bx, n_blksize, By, 0, B_lo); \ BROADCAST64(ptr_b##Bx, n_blksize, By, 4, B_hi); @@ -204,17 +209,19 @@ int CNAME (BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * k_count = k; for (; k_count > 3; k_count -=4) { LOAD_A_PAIR(0); + _mm_prefetch(ptr_a0 + 128, _MM_HINT_T0); ptr_a0 += 16 * 2; - BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); - BROADCAST_B_PAIR(0, 1); MATMUL_4X(0, 0, 1); - BROADCAST_B_PAIR(0, 2); MATMUL_4X(0, 0, 2); + BROADCAST_B_PAIR(0, 0); PREFETCH_B(0, 0); MATMUL_4X(0, 0, 0); + BROADCAST_B_PAIR(0, 1); PREFETCH_B(0, 1); MATMUL_4X(0, 0, 1); + BROADCAST_B_PAIR(0, 2); PREFETCH_B(0, 2); MATMUL_4X(0, 0, 2); ptr_b0 += 4 * 2; - BROADCAST_B_PAIR(1, 0); MATMUL_4X(0, 1, 0); - BROADCAST_B_PAIR(1, 1); MATMUL_4X(0, 1, 1); - BROADCAST_B_PAIR(1, 2); MATMUL_4X(0, 1, 2); + BROADCAST_B_PAIR(1, 0); PREFETCH_B(1, 0); MATMUL_4X(0, 1, 0); + BROADCAST_B_PAIR(1, 1); PREFETCH_B(1, 1); MATMUL_4X(0, 1, 1); + BROADCAST_B_PAIR(1, 2); PREFETCH_B(1, 2); MATMUL_4X(0, 1, 2); ptr_b1 += 4 * 2; LOAD_A_PAIR(0); + _mm_prefetch(ptr_a0 + 128, _MM_HINT_T0); ptr_a0 += 16 * 2; BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); BROADCAST_B_PAIR(0, 1); MATMUL_4X(0, 0, 1);