diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr.c b/kernel/x86_64/sbgemm_kernel_16x16_spr.c index 41d2634d5..b7b4e36a3 100644 --- a/kernel/x86_64/sbgemm_kernel_16x16_spr.c +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr.c @@ -82,6 +82,12 @@ typedef struct { _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ } +#define MASK_LOAD_A_TAIL(M, N) {\ + __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \ + __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ + _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ + _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ +} #define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2) #define LOAD_B_TAIL(M, N) {\ __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \ @@ -111,7 +117,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA A = iB; B = iA; - printf("kernel: m %d, n %d, k %d, ldc: %d\n", m, n, k, ldc); IFLOAT *ptr_a = A, *ptr_b = B; IFLOAT *ptr_b0, *ptr_b1; IFLOAT *ptr_a0, *ptr_a1; @@ -279,5 +284,133 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA } ptr_a += 32 * k; } + for (; m_count > 0; m_count -= 16) { + // process at most 16 m at a time + int tail_m = (m_count > 16) ? 16: m_count; + __mmask16 amask = (1UL << tail_m) - 1; + + ptr_b = B; + + ptr_c00 = ptr_c; + ptr_c01 = ptr_c00 + 16; + ptr_c += tail_m * ldc; + n_count = n; + for (; n_count > 31; n_count -= 32) { + ptr_a0 = ptr_a; + + ptr_b0 = ptr_b; + ptr_b1 = ptr_b + 16 * k; + ptr_b += 32 * k; + + lda = 32; + ldb = 32; + TCONF(cfg, tail_m, 16, 32); + LOAD_C(0, 0); LOAD_C(0, 1); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); + ptr_a0 += tail_m * 32; + LOAD_B(x, 0); LOAD_B(x, 1); + ptr_b0 += 16 * 32; + ptr_b1 += 16 * 32; + + MATMUL(0, 0); MATMUL(0, 1); + } + STORE_C(0, 0); STORE_C(0, 1); + if (k_count > 1) { + /* still have more than 2*k */ + int remain_k2 = k_count & ~1; + k_count -= remain_k2; + lda = remain_k2; + TCONF(cfg, tail_m, 16, remain_k2); + /* reconfig will clear all tiles, + * need to store/load again + */ + LOAD_C(0, 0); LOAD_C(0, 1); + + LOAD_A(0, x); + ptr_a0 += tail_m * remain_k2; + LOAD_B(x, 0); LOAD_B(x, 1); + ptr_b0 += 16 * remain_k2; + ptr_b1 += 16 * remain_k2; + + MATMUL(0, 0); MATMUL(0, 1); + + STORE_C(0, 0); STORE_C(0, 1); + } + if (k_count > 0) { + /* still have odd tail k, need to transform into 2*k */ + TCONF(cfg, tail_m, 16, 2); + + LOAD_C(0, 0); LOAD_C(0, 1); + + MASK_LOAD_A_TAIL(0, x); + LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1); + + MATMUL(0, 0); MATMUL(0, 1); + + STORE_C(0, 0); STORE_C(0, 1); + } + ptr_c00 += 32; + ptr_c01 += 32; + } + for (; n_count > 0; n_count -= 16) { + int tail_n = (n_count > 16) ? 16: n_count; + __mmask16 bmask = (1UL << tail_n) - 1; + ptr_a0 = ptr_a; + + ptr_b0 = ptr_b; + ptr_b += tail_n * k; + + lda = 32; + ldb = 2 * tail_n; + TCONF(cfg, tail_m, tail_n, 32); + LOAD_C(0, 0); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); + ptr_a0 += tail_m * 32; + LOAD_B(x, 0); + ptr_b0 += tail_n * 32; + + MATMUL(0, 0); + } + STORE_C(0, 0); + if (k_count > 1) { + /* still have more than 2*k */ + int remain_k2 = k_count & ~1; + k_count -= remain_k2; + lda = remain_k2; + TCONF(cfg, tail_m, tail_n, remain_k2); + /* reconfig will clear all tiles, + * need to store/load again + */ + LOAD_C(0, 0); + + LOAD_A(0, x); + ptr_a0 += tail_m * remain_k2; + LOAD_B(x, 0); + ptr_b0 += tail_n * remain_k2; + + MATMUL(0, 0); + + STORE_C(0, 0); + } + if (k_count > 0) { + /* still have odd tail k, need to transform into 2*k */ + TCONF(cfg, tail_m, tail_n, 2); + + LOAD_C(0, 0); + + MASK_LOAD_A_TAIL(0, x); + MASK_LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + + STORE_C(0, 0); + } + ptr_c00 += tail_n; + } + ptr_a += tail_m * k; + } return 0; } diff --git a/kernel/x86_64/sbgemm_oncopy_16_spr.c b/kernel/x86_64/sbgemm_oncopy_16_spr.c index da353d2c7..f5668e26e 100644 --- a/kernel/x86_64/sbgemm_oncopy_16_spr.c +++ b/kernel/x86_64/sbgemm_oncopy_16_spr.c @@ -32,8 +32,8 @@ #define MASK_COPY_32(N) _mm512_mask_storeu_epi16(boffset + tail_m * N, mmask, _mm512_maskz_loadu_epi16(mmask, aoffset##N + i)) #define COPY_ODD_TAIL(N) *(boffset + N) = *(aoffset##N + i); + int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { - printf("ONCOPY: m %d, n %d, lda %d\n", m, n, lda); BLASLONG i, j; IFLOAT *aoffset, *boffset; IFLOAT *aoffset0, *aoffset1, *aoffset2, *aoffset3;