From a87736346fd3988618c0d8895827566fce5a5487 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 13 May 2021 10:16:54 +0000 Subject: [PATCH] Small Matrix: skylakex: sgemm nn: add n6 to improve performance --- .../x86_64/sgemm_small_kernel_nn_skylakex.c | 90 ++++++++++++++++++- 1 file changed, 87 insertions(+), 3 deletions(-) diff --git a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c index c9f43f9a2..a67541161 100644 --- a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c +++ b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c @@ -110,6 +110,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp BLASLONG m4 = M & ~3; BLASLONG m2 = M & ~1; + BLASLONG n6 = N - (N % 6); BLASLONG n4 = N & ~3; BLASLONG n2 = N & ~1; @@ -165,7 +166,34 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp } } for (; i < m32; i += 32) { - for (j = 0; j < n4; j += 4) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); + DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + MATMUL_512(0, 4); MATMUL_512(1, 4); + MATMUL_512(0, 5); MATMUL_512(1, 5); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + STORE_512(0, 4); STORE_512(1, 4); + STORE_512(0, 5); STORE_512(1, 5); + } + for (;j < n4; j += 4) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); @@ -208,7 +236,34 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp } } for (; i < m16; i += 16) { - for (j = 0; j < n4; j += 4) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + STORE_512(0, 4); + STORE_512(0, 5); + } + for (; j < n4; j += 4) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(0, 2); @@ -228,6 +283,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp STORE_512(0, 2); STORE_512(0, 3); } + for (; j < n2; j += 2) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(0, 1); @@ -254,7 +310,34 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp if (!mm) return 0; if (mm > 8 || K < 32) { register __mmask16 mask asm("k1") = (1UL << mm) - 1; - for (j = 0; j < n4; j += 4) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + MASK_STORE_512(0, 4); + MASK_STORE_512(0, 5); + } + for (; j < n4; j += 4) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(0, 2); @@ -274,6 +357,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp MASK_STORE_512(0, 2); MASK_STORE_512(0, 3); } + for (; j < n2; j += 2) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(0, 1);