sbgemm: spr: reduce tile config loading by separating tail k handling
This commit is contained in:
parent
0abbcd19c1
commit
88154ed02d
|
@ -127,10 +127,14 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
|||
BLASLONG m_count = m;
|
||||
BLASLONG n_count, k_count;
|
||||
|
||||
|
||||
IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64)));
|
||||
IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64)));
|
||||
tilecfg cfg;
|
||||
|
||||
if (k < 32)
|
||||
goto tail_k;
|
||||
|
||||
for (; m_count > 31; m_count -= 32) {
|
||||
ptr_b = B;
|
||||
|
||||
|
@ -140,6 +144,7 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
|||
ptr_c11 = ptr_c10 + 16;
|
||||
ptr_c += 32 * ldc;
|
||||
n_count = n;
|
||||
TCONF(cfg, 16, 16, 32);
|
||||
for (; n_count > 31; n_count -= 32) {
|
||||
ptr_a0 = ptr_a;
|
||||
ptr_a1 = ptr_a + 16 * k;
|
||||
|
@ -150,7 +155,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
|||
|
||||
lda = 32;
|
||||
ldb = 32;
|
||||
TCONF(cfg, 16, 16, 32);
|
||||
LOAD_C(0, 0); LOAD_C(0, 1);
|
||||
LOAD_C(1, 0); LOAD_C(1, 1);
|
||||
k_count = k;
|
||||
|
@ -167,47 +171,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
|||
}
|
||||
STORE_C(0, 0); STORE_C(0, 1);
|
||||
STORE_C(1, 0); STORE_C(1, 1);
|
||||
if (k_count > 1) {
|
||||
/* still have more than 2*k */
|
||||
int remain_k2 = k_count & ~1;
|
||||
k_count -= remain_k2;
|
||||
lda = remain_k2;
|
||||
TCONF(cfg, 16, 16, remain_k2);
|
||||
/* reconfig will clear all tiles,
|
||||
* need to store/load again
|
||||
*/
|
||||
LOAD_C(0, 0); LOAD_C(0, 1);
|
||||
LOAD_C(1, 0); LOAD_C(1, 1);
|
||||
|
||||
LOAD_A(0, x); LOAD_A(1, x);
|
||||
ptr_a0 += 16 * remain_k2;
|
||||
ptr_a1 += 16 * remain_k2;
|
||||
LOAD_B(x, 0); LOAD_B(x, 1);
|
||||
ptr_b0 += 16 * remain_k2;
|
||||
ptr_b1 += 16 * remain_k2;
|
||||
|
||||
MATMUL(0, 0); MATMUL(0, 1);
|
||||
MATMUL(1, 0); MATMUL(1, 1);
|
||||
|
||||
STORE_C(0, 0); STORE_C(0, 1);
|
||||
STORE_C(1, 0); STORE_C(1, 1);
|
||||
}
|
||||
if (k_count > 0) {
|
||||
/* still have odd tail k, need to transform into 2*k */
|
||||
TCONF(cfg, 16, 16, 2);
|
||||
|
||||
LOAD_C(0, 0); LOAD_C(0, 1);
|
||||
LOAD_C(1, 0); LOAD_C(1, 1);
|
||||
|
||||
LOAD_A_TAIL(0, x); LOAD_A_TAIL(1, x);
|
||||
LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1);
|
||||
|
||||
MATMUL(0, 0); MATMUL(0, 1);
|
||||
MATMUL(1, 0); MATMUL(1, 1);
|
||||
|
||||
STORE_C(0, 0); STORE_C(0, 1);
|
||||
STORE_C(1, 0); STORE_C(1, 1);
|
||||
}
|
||||
ptr_c00 += 32;
|
||||
ptr_c01 += 32;
|
||||
ptr_c10 += 32;
|
||||
|
@ -240,45 +203,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
|||
}
|
||||
STORE_C(0, 0);
|
||||
STORE_C(1, 0);
|
||||
if (k_count > 1) {
|
||||
/* still have more than 2*k */
|
||||
int remain_k2 = k_count & ~1;
|
||||
k_count -= remain_k2;
|
||||
lda = remain_k2;
|
||||
TCONF(cfg, 16, tail_n, remain_k2);
|
||||
/* reconfig will clear all tiles,
|
||||
* need to store/load again
|
||||
*/
|
||||
LOAD_C(0, 0);
|
||||
LOAD_C(1, 0);
|
||||
|
||||
LOAD_A(0, x); LOAD_A(1, x);
|
||||
ptr_a0 += 16 * remain_k2;
|
||||
ptr_a1 += 16 * remain_k2;
|
||||
LOAD_B(x, 0);
|
||||
ptr_b0 += tail_n * remain_k2;
|
||||
|
||||
MATMUL(0, 0);
|
||||
MATMUL(1, 0);
|
||||
|
||||
STORE_C(0, 0);
|
||||
STORE_C(1, 0);
|
||||
}
|
||||
if (k_count > 0) {
|
||||
/* still have odd tail k, need to transform into 2*k */
|
||||
TCONF(cfg, 16, tail_n, 2);
|
||||
|
||||
LOAD_C(0, 0);
|
||||
LOAD_C(1, 0);
|
||||
|
||||
LOAD_A_TAIL(0, x); LOAD_A_TAIL(1, x);
|
||||
MASK_LOAD_B_TAIL(x, 0);
|
||||
MATMUL(0, 0);
|
||||
MATMUL(1, 0);
|
||||
|
||||
STORE_C(0, 0);
|
||||
STORE_C(1, 0);
|
||||
}
|
||||
ptr_c00 += tail_n;
|
||||
ptr_c10 += tail_n;
|
||||
}
|
||||
|
@ -295,6 +219,7 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
|||
ptr_c01 = ptr_c00 + 16;
|
||||
ptr_c += tail_m * ldc;
|
||||
n_count = n;
|
||||
TCONF(cfg, tail_m, 16, 32);
|
||||
for (; n_count > 31; n_count -= 32) {
|
||||
ptr_a0 = ptr_a;
|
||||
|
||||
|
@ -304,7 +229,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
|||
|
||||
lda = 32;
|
||||
ldb = 32;
|
||||
TCONF(cfg, tail_m, 16, 32);
|
||||
LOAD_C(0, 0); LOAD_C(0, 1);
|
||||
k_count = k;
|
||||
for (; k_count > 31; k_count -= 32) {
|
||||
|
@ -317,40 +241,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
|||
MATMUL(0, 0); MATMUL(0, 1);
|
||||
}
|
||||
STORE_C(0, 0); STORE_C(0, 1);
|
||||
if (k_count > 1) {
|
||||
/* still have more than 2*k */
|
||||
int remain_k2 = k_count & ~1;
|
||||
k_count -= remain_k2;
|
||||
lda = remain_k2;
|
||||
TCONF(cfg, tail_m, 16, remain_k2);
|
||||
/* reconfig will clear all tiles,
|
||||
* need to store/load again
|
||||
*/
|
||||
LOAD_C(0, 0); LOAD_C(0, 1);
|
||||
|
||||
LOAD_A(0, x);
|
||||
ptr_a0 += tail_m * remain_k2;
|
||||
LOAD_B(x, 0); LOAD_B(x, 1);
|
||||
ptr_b0 += 16 * remain_k2;
|
||||
ptr_b1 += 16 * remain_k2;
|
||||
|
||||
MATMUL(0, 0); MATMUL(0, 1);
|
||||
|
||||
STORE_C(0, 0); STORE_C(0, 1);
|
||||
}
|
||||
if (k_count > 0) {
|
||||
/* still have odd tail k, need to transform into 2*k */
|
||||
TCONF(cfg, tail_m, 16, 2);
|
||||
|
||||
LOAD_C(0, 0); LOAD_C(0, 1);
|
||||
|
||||
MASK_LOAD_A_TAIL(0, x);
|
||||
LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1);
|
||||
|
||||
MATMUL(0, 0); MATMUL(0, 1);
|
||||
|
||||
STORE_C(0, 0); STORE_C(0, 1);
|
||||
}
|
||||
ptr_c00 += 32;
|
||||
ptr_c01 += 32;
|
||||
}
|
||||
|
@ -376,41 +266,95 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
|||
MATMUL(0, 0);
|
||||
}
|
||||
STORE_C(0, 0);
|
||||
if (k_count > 1) {
|
||||
/* still have more than 2*k */
|
||||
int remain_k2 = k_count & ~1;
|
||||
k_count -= remain_k2;
|
||||
lda = remain_k2;
|
||||
TCONF(cfg, tail_m, tail_n, remain_k2);
|
||||
/* reconfig will clear all tiles,
|
||||
* need to store/load again
|
||||
*/
|
||||
LOAD_C(0, 0);
|
||||
|
||||
LOAD_A(0, x);
|
||||
ptr_a0 += tail_m * remain_k2;
|
||||
LOAD_B(x, 0);
|
||||
ptr_b0 += tail_n * remain_k2;
|
||||
|
||||
MATMUL(0, 0);
|
||||
|
||||
STORE_C(0, 0);
|
||||
}
|
||||
if (k_count > 0) {
|
||||
/* still have odd tail k, need to transform into 2*k */
|
||||
TCONF(cfg, tail_m, tail_n, 2);
|
||||
|
||||
LOAD_C(0, 0);
|
||||
|
||||
MASK_LOAD_A_TAIL(0, x);
|
||||
MASK_LOAD_B_TAIL(x, 0);
|
||||
MATMUL(0, 0);
|
||||
|
||||
STORE_C(0, 0);
|
||||
}
|
||||
ptr_c00 += tail_n;
|
||||
}
|
||||
ptr_a += tail_m * k;
|
||||
}
|
||||
|
||||
tail_k:
|
||||
// process for k < 32
|
||||
BLASLONG k32 = k & ~31;
|
||||
BLASLONG k2 = k & ~1;
|
||||
int remain_k2 = k2 - k32;
|
||||
if (remain_k2 > 0) {
|
||||
m_count = m;
|
||||
ptr_a = A;
|
||||
ptr_c = C;
|
||||
for (; m_count > 0; m_count -= 16) {
|
||||
int tail_m = (m_count > 16) ? 16: m_count;
|
||||
__mmask16 amask = (1UL << tail_m) - 1;
|
||||
|
||||
ptr_a0 = ptr_a + tail_m * k32;
|
||||
ptr_a += tail_m * k;
|
||||
ptr_b = B;
|
||||
ptr_c00 = ptr_c;
|
||||
ptr_c += tail_m * ldc;
|
||||
n_count = n;
|
||||
lda = remain_k2;
|
||||
ldb = 32;
|
||||
TCONF(cfg, tail_m, 16, remain_k2);
|
||||
for (; n_count > 15; n_count -= 16) {
|
||||
ptr_b0 = ptr_b + 16 * k32;
|
||||
LOAD_C(0, 0);
|
||||
LOAD_A(0, x);
|
||||
LOAD_B(x, 0);
|
||||
MATMUL(0, 0);
|
||||
STORE_C(0, 0);
|
||||
ptr_b += 16 * k;
|
||||
ptr_c00 += 16;
|
||||
}
|
||||
if (n_count > 0) {
|
||||
int tail_n = (n_count > 16) ? 16: n_count;
|
||||
__mmask16 bmask = (1UL << tail_n) - 1;
|
||||
ptr_b0 = ptr_b + tail_n * k32;
|
||||
ldb = 2 * tail_n;
|
||||
TCONF(cfg, tail_m, tail_n, remain_k2);
|
||||
LOAD_C(0, 0);
|
||||
LOAD_A(0, x);
|
||||
LOAD_B(x, 0);
|
||||
MATMUL(0, 0);
|
||||
STORE_C(0, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (k2 != k) {
|
||||
m_count = m;
|
||||
ptr_a = A;
|
||||
ptr_c = C;
|
||||
for (; m_count > 0; m_count -= 16) {
|
||||
int tail_m = (m_count > 16) ? 16: m_count;
|
||||
__mmask16 amask = (1UL << tail_m) - 1;
|
||||
|
||||
ptr_a0 = ptr_a + tail_m * k2;
|
||||
ptr_a += tail_m * k;
|
||||
ptr_b = B;
|
||||
ptr_c00 = ptr_c;
|
||||
ptr_c += tail_m * ldc;
|
||||
n_count = n;
|
||||
TCONF(cfg, tail_m, 16, 2);
|
||||
for (; n_count > 15; n_count -= 16) {
|
||||
ptr_b0 = ptr_b + 16 * k2;
|
||||
LOAD_C(0, 0);
|
||||
MASK_LOAD_A_TAIL(0, x);
|
||||
LOAD_B_TAIL(x, 0);
|
||||
MATMUL(0, 0);
|
||||
STORE_C(0, 0);
|
||||
ptr_b += 16 * k;
|
||||
ptr_c00 += 16;
|
||||
}
|
||||
if (n_count > 0) {
|
||||
int tail_n = (n_count > 16) ? 16: n_count;
|
||||
__mmask16 bmask = (1UL << tail_n) - 1;
|
||||
ptr_b0 = ptr_b + tail_n * k2;
|
||||
TCONF(cfg, tail_m, tail_n, 2);
|
||||
LOAD_C(0, 0);
|
||||
MASK_LOAD_A_TAIL(0, x);
|
||||
MASK_LOAD_B_TAIL(x, 0);
|
||||
MATMUL(0, 0);
|
||||
STORE_C(0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue