sbgemm: spr: process k2 and odd k at the same time
This commit is contained in:
parent
7b2f5cb3b7
commit
9ab33228bb
|
@ -64,6 +64,28 @@ typedef struct {
|
|||
cfg.tile_colsb[7] = n * 4; \
|
||||
_tile_loadconfig(&cfg);
|
||||
|
||||
/* CONFIG for handling k2 and odd tail at the same time
|
||||
* tile0 -- A (m x 2k)
|
||||
* tile1 -- A (m x 1)
|
||||
* tile2 -- B (2k x n)
|
||||
* tile3 -- B (1 x n)
|
||||
* tile4 -- C (m x n)
|
||||
*/
|
||||
#define TCONF_TAIL(cfg, m, n, k2) \
|
||||
memset(&cfg, 0, sizeof(tilecfg)); \
|
||||
cfg.palette_id = 1; \
|
||||
cfg.tile_rows[0] = m; \
|
||||
cfg.tile_rows[1] = m; \
|
||||
cfg.tile_rows[2] = k2>>1; \
|
||||
cfg.tile_rows[3] = 1; \
|
||||
cfg.tile_rows[4] = m; \
|
||||
cfg.tile_colsb[0] = k2<<1; \
|
||||
cfg.tile_colsb[1] = 4; \
|
||||
cfg.tile_colsb[2] = n * 4; \
|
||||
cfg.tile_colsb[3] = n * 4; \
|
||||
cfg.tile_colsb[4] = n * 4; \
|
||||
_tile_loadconfig(&cfg);
|
||||
|
||||
#define T_A0 0
|
||||
#define T_A1 1
|
||||
#define T_B0 2
|
||||
|
@ -104,6 +126,7 @@ typedef struct {
|
|||
|
||||
#define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4)
|
||||
#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N)
|
||||
#define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N)
|
||||
#define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4)
|
||||
|
||||
|
||||
|
@ -275,86 +298,123 @@ tail_k:
|
|||
// process for k < 32
|
||||
BLASLONG k32 = k & ~31;
|
||||
BLASLONG k2 = k & ~1;
|
||||
int remain_k2 = k2 - k32;
|
||||
if (remain_k2 > 0) {
|
||||
if (k32 != k) {
|
||||
int remain_k2 = k2 - k32;
|
||||
m_count = m;
|
||||
ptr_a = A;
|
||||
ptr_c = C;
|
||||
for (; m_count > 0; m_count -= 16) {
|
||||
int tail_m = (m_count > 16) ? 16: m_count;
|
||||
__mmask16 amask = (1UL << tail_m) - 1;
|
||||
if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0)
|
||||
for (; m_count > 0; m_count -= 16) {
|
||||
int tail_m = (m_count > 16) ? 16: m_count;
|
||||
__mmask16 amask = (1UL << tail_m) - 1;
|
||||
|
||||
ptr_a0 = ptr_a + tail_m * k32;
|
||||
ptr_a += tail_m * k;
|
||||
ptr_b = B;
|
||||
ptr_c00 = ptr_c;
|
||||
ptr_c += tail_m * ldc;
|
||||
n_count = n;
|
||||
lda = remain_k2;
|
||||
ldb = 32;
|
||||
TCONF(cfg, tail_m, 16, remain_k2);
|
||||
for (; n_count > 15; n_count -= 16) {
|
||||
ptr_b0 = ptr_b + 16 * k32;
|
||||
LOAD_C(0, 0);
|
||||
LOAD_A(0, x);
|
||||
LOAD_B(x, 0);
|
||||
MATMUL(0, 0);
|
||||
STORE_C(0, 0);
|
||||
ptr_b += 16 * k;
|
||||
ptr_c00 += 16;
|
||||
ptr_a0 = ptr_a + tail_m * k32;
|
||||
ptr_a1 = ptr_a + tail_m * k2;
|
||||
ptr_a += tail_m * k;
|
||||
ptr_b = B;
|
||||
ptr_c00 = ptr_c;
|
||||
ptr_c += tail_m * ldc;
|
||||
n_count = n;
|
||||
lda = remain_k2;
|
||||
ldb = 32;
|
||||
TCONF_TAIL(cfg, tail_m, 16, remain_k2);
|
||||
for (; n_count > 15; n_count -= 16) {
|
||||
ptr_b0 = ptr_b + 16 * k32;
|
||||
ptr_b1 = ptr_b + 16 * k2;
|
||||
LOAD_C(0, 0);
|
||||
LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
|
||||
LOAD_B(x, 0); LOAD_B_TAIL(x, 1);
|
||||
MATMUL(0, 0); MATMUL_TAIL(1, 1);
|
||||
STORE_C(0, 0);
|
||||
ptr_b += 16 * k;
|
||||
ptr_c00 += 16;
|
||||
}
|
||||
if (n_count > 0) {
|
||||
int tail_n = (n_count > 16) ? 16: n_count;
|
||||
__mmask16 bmask = (1UL << tail_n) - 1;
|
||||
ptr_b0 = ptr_b + tail_n * k32;
|
||||
ptr_b1 = ptr_b + tail_n * k2;
|
||||
ldb = 2 * tail_n;
|
||||
TCONF_TAIL(cfg, tail_m, tail_n, remain_k2);
|
||||
LOAD_C(0, 0);
|
||||
LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
|
||||
LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1);
|
||||
MATMUL(0, 0); MATMUL_TAIL(1, 1);
|
||||
STORE_C(0, 0);
|
||||
}
|
||||
}
|
||||
if (n_count > 0) {
|
||||
int tail_n = (n_count > 16) ? 16: n_count;
|
||||
__mmask16 bmask = (1UL << tail_n) - 1;
|
||||
ptr_b0 = ptr_b + tail_n * k32;
|
||||
ldb = 2 * tail_n;
|
||||
TCONF(cfg, tail_m, tail_n, remain_k2);
|
||||
LOAD_C(0, 0);
|
||||
LOAD_A(0, x);
|
||||
LOAD_B(x, 0);
|
||||
MATMUL(0, 0);
|
||||
STORE_C(0, 0);
|
||||
|
||||
} else if (remain_k2 > 0) { // k%32 = 2x
|
||||
for (; m_count > 0; m_count -= 16) {
|
||||
int tail_m = (m_count > 16) ? 16: m_count;
|
||||
|
||||
ptr_a0 = ptr_a + tail_m * k32;
|
||||
ptr_a += tail_m * k;
|
||||
ptr_b = B;
|
||||
ptr_c00 = ptr_c;
|
||||
ptr_c += tail_m * ldc;
|
||||
n_count = n;
|
||||
lda = remain_k2;
|
||||
ldb = 32;
|
||||
TCONF(cfg, tail_m, 16, remain_k2);
|
||||
for (; n_count > 15; n_count -= 16) {
|
||||
ptr_b0 = ptr_b + 16 * k32;
|
||||
LOAD_C(0, 0);
|
||||
LOAD_A(0, x);
|
||||
LOAD_B(x, 0);
|
||||
MATMUL(0, 0);
|
||||
STORE_C(0, 0);
|
||||
ptr_b += 16 * k;
|
||||
ptr_c00 += 16;
|
||||
}
|
||||
if (n_count > 0) {
|
||||
int tail_n = (n_count > 16) ? 16: n_count;
|
||||
ptr_b0 = ptr_b + tail_n * k32;
|
||||
ldb = 2 * tail_n;
|
||||
TCONF(cfg, tail_m, tail_n, remain_k2);
|
||||
LOAD_C(0, 0);
|
||||
LOAD_A(0, x);
|
||||
LOAD_B(x, 0);
|
||||
MATMUL(0, 0);
|
||||
STORE_C(0, 0);
|
||||
}
|
||||
}
|
||||
} else { // k%32 = 1
|
||||
for (; m_count > 0; m_count -= 16) {
|
||||
int tail_m = (m_count > 16) ? 16: m_count;
|
||||
__mmask16 amask = (1UL << tail_m) - 1;
|
||||
|
||||
ptr_a0 = ptr_a + tail_m * k2;
|
||||
ptr_a += tail_m * k;
|
||||
ptr_b = B;
|
||||
ptr_c00 = ptr_c;
|
||||
ptr_c += tail_m * ldc;
|
||||
n_count = n;
|
||||
TCONF(cfg, tail_m, 16, 2);
|
||||
for (; n_count > 15; n_count -= 16) {
|
||||
ptr_b0 = ptr_b + 16 * k2;
|
||||
LOAD_C(0, 0);
|
||||
MASK_LOAD_A_TAIL(0, x);
|
||||
LOAD_B_TAIL(x, 0);
|
||||
MATMUL(0, 0);
|
||||
STORE_C(0, 0);
|
||||
ptr_b += 16 * k;
|
||||
ptr_c00 += 16;
|
||||
}
|
||||
if (n_count > 0) {
|
||||
int tail_n = (n_count > 16) ? 16: n_count;
|
||||
__mmask16 bmask = (1UL << tail_n) - 1;
|
||||
ptr_b0 = ptr_b + tail_n * k2;
|
||||
TCONF(cfg, tail_m, tail_n, 2);
|
||||
LOAD_C(0, 0);
|
||||
MASK_LOAD_A_TAIL(0, x);
|
||||
MASK_LOAD_B_TAIL(x, 0);
|
||||
MATMUL(0, 0);
|
||||
STORE_C(0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
if (k2 != k) {
|
||||
m_count = m;
|
||||
ptr_a = A;
|
||||
ptr_c = C;
|
||||
for (; m_count > 0; m_count -= 16) {
|
||||
int tail_m = (m_count > 16) ? 16: m_count;
|
||||
__mmask16 amask = (1UL << tail_m) - 1;
|
||||
|
||||
ptr_a0 = ptr_a + tail_m * k2;
|
||||
ptr_a += tail_m * k;
|
||||
ptr_b = B;
|
||||
ptr_c00 = ptr_c;
|
||||
ptr_c += tail_m * ldc;
|
||||
n_count = n;
|
||||
TCONF(cfg, tail_m, 16, 2);
|
||||
for (; n_count > 15; n_count -= 16) {
|
||||
ptr_b0 = ptr_b + 16 * k2;
|
||||
LOAD_C(0, 0);
|
||||
MASK_LOAD_A_TAIL(0, x);
|
||||
LOAD_B_TAIL(x, 0);
|
||||
MATMUL(0, 0);
|
||||
STORE_C(0, 0);
|
||||
ptr_b += 16 * k;
|
||||
ptr_c00 += 16;
|
||||
}
|
||||
if (n_count > 0) {
|
||||
int tail_n = (n_count > 16) ? 16: n_count;
|
||||
__mmask16 bmask = (1UL << tail_n) - 1;
|
||||
ptr_b0 = ptr_b + tail_n * k2;
|
||||
TCONF(cfg, tail_m, tail_n, 2);
|
||||
LOAD_C(0, 0);
|
||||
MASK_LOAD_A_TAIL(0, x);
|
||||
MASK_LOAD_B_TAIL(x, 0);
|
||||
MATMUL(0, 0);
|
||||
STORE_C(0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue