sbgemm: spr: reduce tile conf loading by seperate tail k handling
This commit is contained in:
parent
0abbcd19c1
commit
88154ed02d
|
@ -127,10 +127,14 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
||||||
BLASLONG m_count = m;
|
BLASLONG m_count = m;
|
||||||
BLASLONG n_count, k_count;
|
BLASLONG n_count, k_count;
|
||||||
|
|
||||||
|
|
||||||
IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64)));
|
IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64)));
|
||||||
IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64)));
|
IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64)));
|
||||||
tilecfg cfg;
|
tilecfg cfg;
|
||||||
|
|
||||||
|
if (k < 32)
|
||||||
|
goto tail_k;
|
||||||
|
|
||||||
for (; m_count > 31; m_count -= 32) {
|
for (; m_count > 31; m_count -= 32) {
|
||||||
ptr_b = B;
|
ptr_b = B;
|
||||||
|
|
||||||
|
@ -140,6 +144,7 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
||||||
ptr_c11 = ptr_c10 + 16;
|
ptr_c11 = ptr_c10 + 16;
|
||||||
ptr_c += 32 * ldc;
|
ptr_c += 32 * ldc;
|
||||||
n_count = n;
|
n_count = n;
|
||||||
|
TCONF(cfg, 16, 16, 32);
|
||||||
for (; n_count > 31; n_count -= 32) {
|
for (; n_count > 31; n_count -= 32) {
|
||||||
ptr_a0 = ptr_a;
|
ptr_a0 = ptr_a;
|
||||||
ptr_a1 = ptr_a + 16 * k;
|
ptr_a1 = ptr_a + 16 * k;
|
||||||
|
@ -150,7 +155,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
||||||
|
|
||||||
lda = 32;
|
lda = 32;
|
||||||
ldb = 32;
|
ldb = 32;
|
||||||
TCONF(cfg, 16, 16, 32);
|
|
||||||
LOAD_C(0, 0); LOAD_C(0, 1);
|
LOAD_C(0, 0); LOAD_C(0, 1);
|
||||||
LOAD_C(1, 0); LOAD_C(1, 1);
|
LOAD_C(1, 0); LOAD_C(1, 1);
|
||||||
k_count = k;
|
k_count = k;
|
||||||
|
@ -167,47 +171,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
||||||
}
|
}
|
||||||
STORE_C(0, 0); STORE_C(0, 1);
|
STORE_C(0, 0); STORE_C(0, 1);
|
||||||
STORE_C(1, 0); STORE_C(1, 1);
|
STORE_C(1, 0); STORE_C(1, 1);
|
||||||
if (k_count > 1) {
|
|
||||||
/* still have more than 2*k */
|
|
||||||
int remain_k2 = k_count & ~1;
|
|
||||||
k_count -= remain_k2;
|
|
||||||
lda = remain_k2;
|
|
||||||
TCONF(cfg, 16, 16, remain_k2);
|
|
||||||
/* reconfig will clear all tiles,
|
|
||||||
* need to store/load again
|
|
||||||
*/
|
|
||||||
LOAD_C(0, 0); LOAD_C(0, 1);
|
|
||||||
LOAD_C(1, 0); LOAD_C(1, 1);
|
|
||||||
|
|
||||||
LOAD_A(0, x); LOAD_A(1, x);
|
|
||||||
ptr_a0 += 16 * remain_k2;
|
|
||||||
ptr_a1 += 16 * remain_k2;
|
|
||||||
LOAD_B(x, 0); LOAD_B(x, 1);
|
|
||||||
ptr_b0 += 16 * remain_k2;
|
|
||||||
ptr_b1 += 16 * remain_k2;
|
|
||||||
|
|
||||||
MATMUL(0, 0); MATMUL(0, 1);
|
|
||||||
MATMUL(1, 0); MATMUL(1, 1);
|
|
||||||
|
|
||||||
STORE_C(0, 0); STORE_C(0, 1);
|
|
||||||
STORE_C(1, 0); STORE_C(1, 1);
|
|
||||||
}
|
|
||||||
if (k_count > 0) {
|
|
||||||
/* still have odd tail k, need to transform into 2*k */
|
|
||||||
TCONF(cfg, 16, 16, 2);
|
|
||||||
|
|
||||||
LOAD_C(0, 0); LOAD_C(0, 1);
|
|
||||||
LOAD_C(1, 0); LOAD_C(1, 1);
|
|
||||||
|
|
||||||
LOAD_A_TAIL(0, x); LOAD_A_TAIL(1, x);
|
|
||||||
LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1);
|
|
||||||
|
|
||||||
MATMUL(0, 0); MATMUL(0, 1);
|
|
||||||
MATMUL(1, 0); MATMUL(1, 1);
|
|
||||||
|
|
||||||
STORE_C(0, 0); STORE_C(0, 1);
|
|
||||||
STORE_C(1, 0); STORE_C(1, 1);
|
|
||||||
}
|
|
||||||
ptr_c00 += 32;
|
ptr_c00 += 32;
|
||||||
ptr_c01 += 32;
|
ptr_c01 += 32;
|
||||||
ptr_c10 += 32;
|
ptr_c10 += 32;
|
||||||
|
@ -240,45 +203,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
||||||
}
|
}
|
||||||
STORE_C(0, 0);
|
STORE_C(0, 0);
|
||||||
STORE_C(1, 0);
|
STORE_C(1, 0);
|
||||||
if (k_count > 1) {
|
|
||||||
/* still have more than 2*k */
|
|
||||||
int remain_k2 = k_count & ~1;
|
|
||||||
k_count -= remain_k2;
|
|
||||||
lda = remain_k2;
|
|
||||||
TCONF(cfg, 16, tail_n, remain_k2);
|
|
||||||
/* reconfig will clear all tiles,
|
|
||||||
* need to store/load again
|
|
||||||
*/
|
|
||||||
LOAD_C(0, 0);
|
|
||||||
LOAD_C(1, 0);
|
|
||||||
|
|
||||||
LOAD_A(0, x); LOAD_A(1, x);
|
|
||||||
ptr_a0 += 16 * remain_k2;
|
|
||||||
ptr_a1 += 16 * remain_k2;
|
|
||||||
LOAD_B(x, 0);
|
|
||||||
ptr_b0 += tail_n * remain_k2;
|
|
||||||
|
|
||||||
MATMUL(0, 0);
|
|
||||||
MATMUL(1, 0);
|
|
||||||
|
|
||||||
STORE_C(0, 0);
|
|
||||||
STORE_C(1, 0);
|
|
||||||
}
|
|
||||||
if (k_count > 0) {
|
|
||||||
/* still have odd tail k, need to transform into 2*k */
|
|
||||||
TCONF(cfg, 16, tail_n, 2);
|
|
||||||
|
|
||||||
LOAD_C(0, 0);
|
|
||||||
LOAD_C(1, 0);
|
|
||||||
|
|
||||||
LOAD_A_TAIL(0, x); LOAD_A_TAIL(1, x);
|
|
||||||
MASK_LOAD_B_TAIL(x, 0);
|
|
||||||
MATMUL(0, 0);
|
|
||||||
MATMUL(1, 0);
|
|
||||||
|
|
||||||
STORE_C(0, 0);
|
|
||||||
STORE_C(1, 0);
|
|
||||||
}
|
|
||||||
ptr_c00 += tail_n;
|
ptr_c00 += tail_n;
|
||||||
ptr_c10 += tail_n;
|
ptr_c10 += tail_n;
|
||||||
}
|
}
|
||||||
|
@ -295,6 +219,7 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
||||||
ptr_c01 = ptr_c00 + 16;
|
ptr_c01 = ptr_c00 + 16;
|
||||||
ptr_c += tail_m * ldc;
|
ptr_c += tail_m * ldc;
|
||||||
n_count = n;
|
n_count = n;
|
||||||
|
TCONF(cfg, tail_m, 16, 32);
|
||||||
for (; n_count > 31; n_count -= 32) {
|
for (; n_count > 31; n_count -= 32) {
|
||||||
ptr_a0 = ptr_a;
|
ptr_a0 = ptr_a;
|
||||||
|
|
||||||
|
@ -304,7 +229,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
||||||
|
|
||||||
lda = 32;
|
lda = 32;
|
||||||
ldb = 32;
|
ldb = 32;
|
||||||
TCONF(cfg, tail_m, 16, 32);
|
|
||||||
LOAD_C(0, 0); LOAD_C(0, 1);
|
LOAD_C(0, 0); LOAD_C(0, 1);
|
||||||
k_count = k;
|
k_count = k;
|
||||||
for (; k_count > 31; k_count -= 32) {
|
for (; k_count > 31; k_count -= 32) {
|
||||||
|
@ -317,40 +241,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
||||||
MATMUL(0, 0); MATMUL(0, 1);
|
MATMUL(0, 0); MATMUL(0, 1);
|
||||||
}
|
}
|
||||||
STORE_C(0, 0); STORE_C(0, 1);
|
STORE_C(0, 0); STORE_C(0, 1);
|
||||||
if (k_count > 1) {
|
|
||||||
/* still have more than 2*k */
|
|
||||||
int remain_k2 = k_count & ~1;
|
|
||||||
k_count -= remain_k2;
|
|
||||||
lda = remain_k2;
|
|
||||||
TCONF(cfg, tail_m, 16, remain_k2);
|
|
||||||
/* reconfig will clear all tiles,
|
|
||||||
* need to store/load again
|
|
||||||
*/
|
|
||||||
LOAD_C(0, 0); LOAD_C(0, 1);
|
|
||||||
|
|
||||||
LOAD_A(0, x);
|
|
||||||
ptr_a0 += tail_m * remain_k2;
|
|
||||||
LOAD_B(x, 0); LOAD_B(x, 1);
|
|
||||||
ptr_b0 += 16 * remain_k2;
|
|
||||||
ptr_b1 += 16 * remain_k2;
|
|
||||||
|
|
||||||
MATMUL(0, 0); MATMUL(0, 1);
|
|
||||||
|
|
||||||
STORE_C(0, 0); STORE_C(0, 1);
|
|
||||||
}
|
|
||||||
if (k_count > 0) {
|
|
||||||
/* still have odd tail k, need to transform into 2*k */
|
|
||||||
TCONF(cfg, tail_m, 16, 2);
|
|
||||||
|
|
||||||
LOAD_C(0, 0); LOAD_C(0, 1);
|
|
||||||
|
|
||||||
MASK_LOAD_A_TAIL(0, x);
|
|
||||||
LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1);
|
|
||||||
|
|
||||||
MATMUL(0, 0); MATMUL(0, 1);
|
|
||||||
|
|
||||||
STORE_C(0, 0); STORE_C(0, 1);
|
|
||||||
}
|
|
||||||
ptr_c00 += 32;
|
ptr_c00 += 32;
|
||||||
ptr_c01 += 32;
|
ptr_c01 += 32;
|
||||||
}
|
}
|
||||||
|
@ -376,41 +266,95 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA
|
||||||
MATMUL(0, 0);
|
MATMUL(0, 0);
|
||||||
}
|
}
|
||||||
STORE_C(0, 0);
|
STORE_C(0, 0);
|
||||||
if (k_count > 1) {
|
|
||||||
/* still have more than 2*k */
|
|
||||||
int remain_k2 = k_count & ~1;
|
|
||||||
k_count -= remain_k2;
|
|
||||||
lda = remain_k2;
|
|
||||||
TCONF(cfg, tail_m, tail_n, remain_k2);
|
|
||||||
/* reconfig will clear all tiles,
|
|
||||||
* need to store/load again
|
|
||||||
*/
|
|
||||||
LOAD_C(0, 0);
|
|
||||||
|
|
||||||
LOAD_A(0, x);
|
|
||||||
ptr_a0 += tail_m * remain_k2;
|
|
||||||
LOAD_B(x, 0);
|
|
||||||
ptr_b0 += tail_n * remain_k2;
|
|
||||||
|
|
||||||
MATMUL(0, 0);
|
|
||||||
|
|
||||||
STORE_C(0, 0);
|
|
||||||
}
|
|
||||||
if (k_count > 0) {
|
|
||||||
/* still have odd tail k, need to transform into 2*k */
|
|
||||||
TCONF(cfg, tail_m, tail_n, 2);
|
|
||||||
|
|
||||||
LOAD_C(0, 0);
|
|
||||||
|
|
||||||
MASK_LOAD_A_TAIL(0, x);
|
|
||||||
MASK_LOAD_B_TAIL(x, 0);
|
|
||||||
MATMUL(0, 0);
|
|
||||||
|
|
||||||
STORE_C(0, 0);
|
|
||||||
}
|
|
||||||
ptr_c00 += tail_n;
|
ptr_c00 += tail_n;
|
||||||
}
|
}
|
||||||
ptr_a += tail_m * k;
|
ptr_a += tail_m * k;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tail_k:
|
||||||
|
// process for k < 32
|
||||||
|
BLASLONG k32 = k & ~31;
|
||||||
|
BLASLONG k2 = k & ~1;
|
||||||
|
int remain_k2 = k2 - k32;
|
||||||
|
if (remain_k2 > 0) {
|
||||||
|
m_count = m;
|
||||||
|
ptr_a = A;
|
||||||
|
ptr_c = C;
|
||||||
|
for (; m_count > 0; m_count -= 16) {
|
||||||
|
int tail_m = (m_count > 16) ? 16: m_count;
|
||||||
|
__mmask16 amask = (1UL << tail_m) - 1;
|
||||||
|
|
||||||
|
ptr_a0 = ptr_a + tail_m * k32;
|
||||||
|
ptr_a += tail_m * k;
|
||||||
|
ptr_b = B;
|
||||||
|
ptr_c00 = ptr_c;
|
||||||
|
ptr_c += tail_m * ldc;
|
||||||
|
n_count = n;
|
||||||
|
lda = remain_k2;
|
||||||
|
ldb = 32;
|
||||||
|
TCONF(cfg, tail_m, 16, remain_k2);
|
||||||
|
for (; n_count > 15; n_count -= 16) {
|
||||||
|
ptr_b0 = ptr_b + 16 * k32;
|
||||||
|
LOAD_C(0, 0);
|
||||||
|
LOAD_A(0, x);
|
||||||
|
LOAD_B(x, 0);
|
||||||
|
MATMUL(0, 0);
|
||||||
|
STORE_C(0, 0);
|
||||||
|
ptr_b += 16 * k;
|
||||||
|
ptr_c00 += 16;
|
||||||
|
}
|
||||||
|
if (n_count > 0) {
|
||||||
|
int tail_n = (n_count > 16) ? 16: n_count;
|
||||||
|
__mmask16 bmask = (1UL << tail_n) - 1;
|
||||||
|
ptr_b0 = ptr_b + tail_n * k32;
|
||||||
|
ldb = 2 * tail_n;
|
||||||
|
TCONF(cfg, tail_m, tail_n, remain_k2);
|
||||||
|
LOAD_C(0, 0);
|
||||||
|
LOAD_A(0, x);
|
||||||
|
LOAD_B(x, 0);
|
||||||
|
MATMUL(0, 0);
|
||||||
|
STORE_C(0, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (k2 != k) {
|
||||||
|
m_count = m;
|
||||||
|
ptr_a = A;
|
||||||
|
ptr_c = C;
|
||||||
|
for (; m_count > 0; m_count -= 16) {
|
||||||
|
int tail_m = (m_count > 16) ? 16: m_count;
|
||||||
|
__mmask16 amask = (1UL << tail_m) - 1;
|
||||||
|
|
||||||
|
ptr_a0 = ptr_a + tail_m * k2;
|
||||||
|
ptr_a += tail_m * k;
|
||||||
|
ptr_b = B;
|
||||||
|
ptr_c00 = ptr_c;
|
||||||
|
ptr_c += tail_m * ldc;
|
||||||
|
n_count = n;
|
||||||
|
TCONF(cfg, tail_m, 16, 2);
|
||||||
|
for (; n_count > 15; n_count -= 16) {
|
||||||
|
ptr_b0 = ptr_b + 16 * k2;
|
||||||
|
LOAD_C(0, 0);
|
||||||
|
MASK_LOAD_A_TAIL(0, x);
|
||||||
|
LOAD_B_TAIL(x, 0);
|
||||||
|
MATMUL(0, 0);
|
||||||
|
STORE_C(0, 0);
|
||||||
|
ptr_b += 16 * k;
|
||||||
|
ptr_c00 += 16;
|
||||||
|
}
|
||||||
|
if (n_count > 0) {
|
||||||
|
int tail_n = (n_count > 16) ? 16: n_count;
|
||||||
|
__mmask16 bmask = (1UL << tail_n) - 1;
|
||||||
|
ptr_b0 = ptr_b + tail_n * k2;
|
||||||
|
TCONF(cfg, tail_m, tail_n, 2);
|
||||||
|
LOAD_C(0, 0);
|
||||||
|
MASK_LOAD_A_TAIL(0, x);
|
||||||
|
MASK_LOAD_B_TAIL(x, 0);
|
||||||
|
MATMUL(0, 0);
|
||||||
|
STORE_C(0, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue