sbgemm: spr: reduce tile config loading by handling tail k separately

This commit is contained in:
Wangyang Guo 2021-09-15 01:11:15 -07:00
parent 0abbcd19c1
commit 88154ed02d
1 changed file with 92 additions and 148 deletions

View File

@ -127,10 +127,14 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
BLASLONG m_count = m;
BLASLONG n_count, k_count;
IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64)));
IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64)));
tilecfg cfg;
if (k < 32)
goto tail_k;
for (; m_count > 31; m_count -= 32) {
ptr_b = B;
@ -140,6 +144,7 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
ptr_c11 = ptr_c10 + 16;
ptr_c += 32 * ldc;
n_count = n;
TCONF(cfg, 16, 16, 32);
for (; n_count > 31; n_count -= 32) {
ptr_a0 = ptr_a;
ptr_a1 = ptr_a + 16 * k;
@ -150,7 +155,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
lda = 32;
ldb = 32;
TCONF(cfg, 16, 16, 32);
LOAD_C(0, 0); LOAD_C(0, 1);
LOAD_C(1, 0); LOAD_C(1, 1);
k_count = k;
@ -167,47 +171,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
}
STORE_C(0, 0); STORE_C(0, 1);
STORE_C(1, 0); STORE_C(1, 1);
if (k_count > 1) {
/* still have more than 2*k */
int remain_k2 = k_count & ~1;
k_count -= remain_k2;
lda = remain_k2;
TCONF(cfg, 16, 16, remain_k2);
/* reconfig will clear all tiles,
* need to store/load again
*/
LOAD_C(0, 0); LOAD_C(0, 1);
LOAD_C(1, 0); LOAD_C(1, 1);
LOAD_A(0, x); LOAD_A(1, x);
ptr_a0 += 16 * remain_k2;
ptr_a1 += 16 * remain_k2;
LOAD_B(x, 0); LOAD_B(x, 1);
ptr_b0 += 16 * remain_k2;
ptr_b1 += 16 * remain_k2;
MATMUL(0, 0); MATMUL(0, 1);
MATMUL(1, 0); MATMUL(1, 1);
STORE_C(0, 0); STORE_C(0, 1);
STORE_C(1, 0); STORE_C(1, 1);
}
if (k_count > 0) {
/* still have odd tail k, need to transform into 2*k */
TCONF(cfg, 16, 16, 2);
LOAD_C(0, 0); LOAD_C(0, 1);
LOAD_C(1, 0); LOAD_C(1, 1);
LOAD_A_TAIL(0, x); LOAD_A_TAIL(1, x);
LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1);
MATMUL(0, 0); MATMUL(0, 1);
MATMUL(1, 0); MATMUL(1, 1);
STORE_C(0, 0); STORE_C(0, 1);
STORE_C(1, 0); STORE_C(1, 1);
}
ptr_c00 += 32;
ptr_c01 += 32;
ptr_c10 += 32;
@ -240,45 +203,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
}
STORE_C(0, 0);
STORE_C(1, 0);
if (k_count > 1) {
/* still have more than 2*k */
int remain_k2 = k_count & ~1;
k_count -= remain_k2;
lda = remain_k2;
TCONF(cfg, 16, tail_n, remain_k2);
/* reconfig will clear all tiles,
* need to store/load again
*/
LOAD_C(0, 0);
LOAD_C(1, 0);
LOAD_A(0, x); LOAD_A(1, x);
ptr_a0 += 16 * remain_k2;
ptr_a1 += 16 * remain_k2;
LOAD_B(x, 0);
ptr_b0 += tail_n * remain_k2;
MATMUL(0, 0);
MATMUL(1, 0);
STORE_C(0, 0);
STORE_C(1, 0);
}
if (k_count > 0) {
/* still have odd tail k, need to transform into 2*k */
TCONF(cfg, 16, tail_n, 2);
LOAD_C(0, 0);
LOAD_C(1, 0);
LOAD_A_TAIL(0, x); LOAD_A_TAIL(1, x);
MASK_LOAD_B_TAIL(x, 0);
MATMUL(0, 0);
MATMUL(1, 0);
STORE_C(0, 0);
STORE_C(1, 0);
}
ptr_c00 += tail_n;
ptr_c10 += tail_n;
}
@ -295,6 +219,7 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
ptr_c01 = ptr_c00 + 16;
ptr_c += tail_m * ldc;
n_count = n;
TCONF(cfg, tail_m, 16, 32);
for (; n_count > 31; n_count -= 32) {
ptr_a0 = ptr_a;
@ -304,7 +229,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
lda = 32;
ldb = 32;
TCONF(cfg, tail_m, 16, 32);
LOAD_C(0, 0); LOAD_C(0, 1);
k_count = k;
for (; k_count > 31; k_count -= 32) {
@ -317,40 +241,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
MATMUL(0, 0); MATMUL(0, 1);
}
STORE_C(0, 0); STORE_C(0, 1);
if (k_count > 1) {
/* still have more than 2*k */
int remain_k2 = k_count & ~1;
k_count -= remain_k2;
lda = remain_k2;
TCONF(cfg, tail_m, 16, remain_k2);
/* reconfig will clear all tiles,
* need to store/load again
*/
LOAD_C(0, 0); LOAD_C(0, 1);
LOAD_A(0, x);
ptr_a0 += tail_m * remain_k2;
LOAD_B(x, 0); LOAD_B(x, 1);
ptr_b0 += 16 * remain_k2;
ptr_b1 += 16 * remain_k2;
MATMUL(0, 0); MATMUL(0, 1);
STORE_C(0, 0); STORE_C(0, 1);
}
if (k_count > 0) {
/* still have odd tail k, need to transform into 2*k */
TCONF(cfg, tail_m, 16, 2);
LOAD_C(0, 0); LOAD_C(0, 1);
MASK_LOAD_A_TAIL(0, x);
LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1);
MATMUL(0, 0); MATMUL(0, 1);
STORE_C(0, 0); STORE_C(0, 1);
}
ptr_c00 += 32;
ptr_c01 += 32;
}
@ -376,41 +266,95 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
MATMUL(0, 0);
}
STORE_C(0, 0);
if (k_count > 1) {
/* still have more than 2*k */
int remain_k2 = k_count & ~1;
k_count -= remain_k2;
lda = remain_k2;
TCONF(cfg, tail_m, tail_n, remain_k2);
/* reconfig will clear all tiles,
* need to store/load again
*/
LOAD_C(0, 0);
LOAD_A(0, x);
ptr_a0 += tail_m * remain_k2;
LOAD_B(x, 0);
ptr_b0 += tail_n * remain_k2;
MATMUL(0, 0);
STORE_C(0, 0);
}
if (k_count > 0) {
/* still have odd tail k, need to transform into 2*k */
TCONF(cfg, tail_m, tail_n, 2);
LOAD_C(0, 0);
MASK_LOAD_A_TAIL(0, x);
MASK_LOAD_B_TAIL(x, 0);
MATMUL(0, 0);
STORE_C(0, 0);
}
ptr_c00 += tail_n;
}
ptr_a += tail_m * k;
}
tail_k:
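// process the k remainder past the last multiple of 32 (k % 32):
// first the even part (remain_k2), then a single odd element if k is odd.
// Reached directly when k < 32, and by fall-through after the main loops otherwise.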
BLASLONG k32 = k & ~31;
BLASLONG k2 = k & ~1;
int remain_k2 = k2 - k32;
if (remain_k2 > 0) {
m_count = m;
ptr_a = A;
ptr_c = C;
for (; m_count > 0; m_count -= 16) {
int tail_m = (m_count > 16) ? 16: m_count;
__mmask16 amask = (1UL << tail_m) - 1;
ptr_a0 = ptr_a + tail_m * k32;
ptr_a += tail_m * k;
ptr_b = B;
ptr_c00 = ptr_c;
ptr_c += tail_m * ldc;
n_count = n;
lda = remain_k2;
ldb = 32;
TCONF(cfg, tail_m, 16, remain_k2);
for (; n_count > 15; n_count -= 16) {
ptr_b0 = ptr_b + 16 * k32;
LOAD_C(0, 0);
LOAD_A(0, x);
LOAD_B(x, 0);
MATMUL(0, 0);
STORE_C(0, 0);
ptr_b += 16 * k;
ptr_c00 += 16;
}
if (n_count > 0) {
int tail_n = (n_count > 16) ? 16: n_count;
__mmask16 bmask = (1UL << tail_n) - 1;
ptr_b0 = ptr_b + tail_n * k32;
ldb = 2 * tail_n;
TCONF(cfg, tail_m, tail_n, remain_k2);
LOAD_C(0, 0);
LOAD_A(0, x);
LOAD_B(x, 0);
MATMUL(0, 0);
STORE_C(0, 0);
}
}
}
if (k2 != k) {
m_count = m;
ptr_a = A;
ptr_c = C;
for (; m_count > 0; m_count -= 16) {
int tail_m = (m_count > 16) ? 16: m_count;
__mmask16 amask = (1UL << tail_m) - 1;
ptr_a0 = ptr_a + tail_m * k2;
ptr_a += tail_m * k;
ptr_b = B;
ptr_c00 = ptr_c;
ptr_c += tail_m * ldc;
n_count = n;
TCONF(cfg, tail_m, 16, 2);
for (; n_count > 15; n_count -= 16) {
ptr_b0 = ptr_b + 16 * k2;
LOAD_C(0, 0);
MASK_LOAD_A_TAIL(0, x);
LOAD_B_TAIL(x, 0);
MATMUL(0, 0);
STORE_C(0, 0);
ptr_b += 16 * k;
ptr_c00 += 16;
}
if (n_count > 0) {
int tail_n = (n_count > 16) ? 16: n_count;
__mmask16 bmask = (1UL << tail_n) - 1;
ptr_b0 = ptr_b + tail_n * k2;
TCONF(cfg, tail_m, tail_n, 2);
LOAD_C(0, 0);
MASK_LOAD_A_TAIL(0, x);
MASK_LOAD_B_TAIL(x, 0);
MATMUL(0, 0);
STORE_C(0, 0);
}
}
}
return 0;
}