sbgemm: spr: only load A once in tail_k handling
This commit is contained in:
parent
9ab33228bb
commit
f2485352a6
|
@ -317,17 +317,19 @@ tail_k:
|
||||||
n_count = n;
|
n_count = n;
|
||||||
lda = remain_k2;
|
lda = remain_k2;
|
||||||
ldb = 32;
|
ldb = 32;
|
||||||
TCONF_TAIL(cfg, tail_m, 16, remain_k2);
|
if (n_count > 15) {
|
||||||
for (; n_count > 15; n_count -= 16) {
|
TCONF_TAIL(cfg, tail_m, 16, remain_k2);
|
||||||
ptr_b0 = ptr_b + 16 * k32;
|
|
||||||
ptr_b1 = ptr_b + 16 * k2;
|
|
||||||
LOAD_C(0, 0);
|
|
||||||
LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
|
LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
|
||||||
LOAD_B(x, 0); LOAD_B_TAIL(x, 1);
|
for (; n_count > 15; n_count -= 16) {
|
||||||
MATMUL(0, 0); MATMUL_TAIL(1, 1);
|
ptr_b0 = ptr_b + 16 * k32;
|
||||||
STORE_C(0, 0);
|
ptr_b1 = ptr_b + 16 * k2;
|
||||||
ptr_b += 16 * k;
|
LOAD_C(0, 0);
|
||||||
ptr_c00 += 16;
|
LOAD_B(x, 0); LOAD_B_TAIL(x, 1);
|
||||||
|
MATMUL(0, 0); MATMUL_TAIL(1, 1);
|
||||||
|
STORE_C(0, 0);
|
||||||
|
ptr_b += 16 * k;
|
||||||
|
ptr_c00 += 16;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (n_count > 0) {
|
if (n_count > 0) {
|
||||||
int tail_n = (n_count > 16) ? 16: n_count;
|
int tail_n = (n_count > 16) ? 16: n_count;
|
||||||
|
@ -356,16 +358,18 @@ tail_k:
|
||||||
n_count = n;
|
n_count = n;
|
||||||
lda = remain_k2;
|
lda = remain_k2;
|
||||||
ldb = 32;
|
ldb = 32;
|
||||||
TCONF(cfg, tail_m, 16, remain_k2);
|
if (n_count > 15) {
|
||||||
for (; n_count > 15; n_count -= 16) {
|
TCONF(cfg, tail_m, 16, remain_k2);
|
||||||
ptr_b0 = ptr_b + 16 * k32;
|
|
||||||
LOAD_C(0, 0);
|
|
||||||
LOAD_A(0, x);
|
LOAD_A(0, x);
|
||||||
LOAD_B(x, 0);
|
for (; n_count > 15; n_count -= 16) {
|
||||||
MATMUL(0, 0);
|
ptr_b0 = ptr_b + 16 * k32;
|
||||||
STORE_C(0, 0);
|
LOAD_C(0, 0);
|
||||||
ptr_b += 16 * k;
|
LOAD_B(x, 0);
|
||||||
ptr_c00 += 16;
|
MATMUL(0, 0);
|
||||||
|
STORE_C(0, 0);
|
||||||
|
ptr_b += 16 * k;
|
||||||
|
ptr_c00 += 16;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (n_count > 0) {
|
if (n_count > 0) {
|
||||||
int tail_n = (n_count > 16) ? 16: n_count;
|
int tail_n = (n_count > 16) ? 16: n_count;
|
||||||
|
@ -390,16 +394,18 @@ tail_k:
|
||||||
ptr_c00 = ptr_c;
|
ptr_c00 = ptr_c;
|
||||||
ptr_c += tail_m * ldc;
|
ptr_c += tail_m * ldc;
|
||||||
n_count = n;
|
n_count = n;
|
||||||
TCONF(cfg, tail_m, 16, 2);
|
if (n_count > 15) {
|
||||||
for (; n_count > 15; n_count -= 16) {
|
TCONF(cfg, tail_m, 16, 2);
|
||||||
ptr_b0 = ptr_b + 16 * k2;
|
|
||||||
LOAD_C(0, 0);
|
|
||||||
MASK_LOAD_A_TAIL(0, x);
|
MASK_LOAD_A_TAIL(0, x);
|
||||||
LOAD_B_TAIL(x, 0);
|
for (; n_count > 15; n_count -= 16) {
|
||||||
MATMUL(0, 0);
|
ptr_b0 = ptr_b + 16 * k2;
|
||||||
STORE_C(0, 0);
|
LOAD_C(0, 0);
|
||||||
ptr_b += 16 * k;
|
LOAD_B_TAIL(x, 0);
|
||||||
ptr_c00 += 16;
|
MATMUL(0, 0);
|
||||||
|
STORE_C(0, 0);
|
||||||
|
ptr_b += 16 * k;
|
||||||
|
ptr_c00 += 16;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (n_count > 0) {
|
if (n_count > 0) {
|
||||||
int tail_n = (n_count > 16) ? 16: n_count;
|
int tail_n = (n_count > 16) ? 16: n_count;
|
||||||
|
|
Loading…
Reference in New Issue