Small Matrix: skylakex: sgemm nn: add n6 to improve performance

This commit is contained in:
Wangyang Guo 2021-05-13 10:16:54 +00:00
parent 4c9d9940fd
commit a87736346f
1 changed files with 87 additions and 3 deletions

View File

@ -110,6 +110,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
BLASLONG m4 = M & ~3;
BLASLONG m2 = M & ~1;
BLASLONG n6 = N - (N % 6);
BLASLONG n4 = N & ~3;
BLASLONG n2 = N & ~1;
@ -165,7 +166,34 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
}
}
for (; i < m32; i += 32) {
for (j = 0; j < n4; j += 4) {
for (j = 0; j < n6; j += 6) {
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2);
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3);
DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4);
DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5);
for (k = 0; k < K; k++) {
LOAD_A_512(0, x); LOAD_A_512(1, x);
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1);
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3);
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5);
MATMUL_512(0, 0); MATMUL_512(1, 0);
MATMUL_512(0, 1); MATMUL_512(1, 1);
MATMUL_512(0, 2); MATMUL_512(1, 2);
MATMUL_512(0, 3); MATMUL_512(1, 3);
MATMUL_512(0, 4); MATMUL_512(1, 4);
MATMUL_512(0, 5); MATMUL_512(1, 5);
}
STORE_512(0, 0); STORE_512(1, 0);
STORE_512(0, 1); STORE_512(1, 1);
STORE_512(0, 2); STORE_512(1, 2);
STORE_512(0, 3); STORE_512(1, 3);
STORE_512(0, 4); STORE_512(1, 4);
STORE_512(0, 5); STORE_512(1, 5);
}
for (;j < n4; j += 4) {
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2);
@ -208,7 +236,34 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
}
}
for (; i < m16; i += 16) {
for (j = 0; j < n4; j += 4) {
for (j = 0; j < n6; j += 6) {
DECLARE_RESULT_512(0, 0);
DECLARE_RESULT_512(0, 1);
DECLARE_RESULT_512(0, 2);
DECLARE_RESULT_512(0, 3);
DECLARE_RESULT_512(0, 4);
DECLARE_RESULT_512(0, 5);
for (k = 0; k < K; k++) {
LOAD_A_512(0, x);
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1);
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3);
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5);
MATMUL_512(0, 0);
MATMUL_512(0, 1);
MATMUL_512(0, 2);
MATMUL_512(0, 3);
MATMUL_512(0, 4);
MATMUL_512(0, 5);
}
STORE_512(0, 0);
STORE_512(0, 1);
STORE_512(0, 2);
STORE_512(0, 3);
STORE_512(0, 4);
STORE_512(0, 5);
}
for (; j < n4; j += 4) {
DECLARE_RESULT_512(0, 0);
DECLARE_RESULT_512(0, 1);
DECLARE_RESULT_512(0, 2);
@ -228,6 +283,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
STORE_512(0, 2);
STORE_512(0, 3);
}
for (; j < n2; j += 2) {
DECLARE_RESULT_512(0, 0);
DECLARE_RESULT_512(0, 1);
@ -254,7 +310,34 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
if (!mm) return 0;
if (mm > 8 || K < 32) {
register __mmask16 mask asm("k1") = (1UL << mm) - 1;
for (j = 0; j < n4; j += 4) {
for (j = 0; j < n6; j += 6) {
DECLARE_RESULT_512(0, 0);
DECLARE_RESULT_512(0, 1);
DECLARE_RESULT_512(0, 2);
DECLARE_RESULT_512(0, 3);
DECLARE_RESULT_512(0, 4);
DECLARE_RESULT_512(0, 5);
for (k = 0; k < K; k++) {
MASK_LOAD_A_512(0, x);
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1);
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3);
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5);
MATMUL_512(0, 0);
MATMUL_512(0, 1);
MATMUL_512(0, 2);
MATMUL_512(0, 3);
MATMUL_512(0, 4);
MATMUL_512(0, 5);
}
MASK_STORE_512(0, 0);
MASK_STORE_512(0, 1);
MASK_STORE_512(0, 2);
MASK_STORE_512(0, 3);
MASK_STORE_512(0, 4);
MASK_STORE_512(0, 5);
}
for (; j < n4; j += 4) {
DECLARE_RESULT_512(0, 0);
DECLARE_RESULT_512(0, 1);
DECLARE_RESULT_512(0, 2);
@ -274,6 +357,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
MASK_STORE_512(0, 2);
MASK_STORE_512(0, 3);
}
for (; j < n2; j += 2) {
DECLARE_RESULT_512(0, 0);
DECLARE_RESULT_512(0, 1);