diff --git a/kernel/x86_64/sbgemm_oncopy_16_spr.c b/kernel/x86_64/sbgemm_oncopy_16_spr.c index 593f2433d..ccb00ada1 100644 --- a/kernel/x86_64/sbgemm_oncopy_16_spr.c +++ b/kernel/x86_64/sbgemm_oncopy_16_spr.c @@ -28,18 +28,45 @@ #include #include "common.h" -#define COPY_32(N) _mm512_storeu_si512(boffset + 32 * N, _mm512_loadu_si512(aoffset##N + i)) -#define MASK_COPY_32(N) _mm512_mask_storeu_epi16(boffset + tail_m * N, mmask, _mm512_maskz_loadu_epi16(mmask, aoffset##N + i)) -#define COPY_ODD_TAIL(N) *(boffset + N) = *(aoffset##N + i); +typedef struct { + char palette_id; + char start_row; + char dummy0[14]; // bytes 2-15 reserved, must be zero + short tile_colsb[8]; + char dummy1[16]; // bytes 32-47 reserved, must be zero + char tile_rows[8]; + char dummy2[16]; // bytes 56-63 reserved, must be zero +} tilecfg; + +#define T_16x32 0 +#define T_16xm 1 +#define T_nx32 2 +#define T_nxm 3 + +#define TCONF(cfg, m, n) \ + memset(&cfg, 0, sizeof(tilecfg)); \ + cfg.palette_id = 1; \ + cfg.tile_rows[T_16x32] = 16; \ + cfg.tile_colsb[T_16x32] = 64; \ + if (m) { \ + cfg.tile_rows[T_16xm] = 16; \ + cfg.tile_colsb[T_16xm] = m * 2; \ + } \ + if (n) { \ + cfg.tile_rows[T_nx32] = n; \ + cfg.tile_colsb[T_nx32] = 64; \ + } \ + if (m && n) { \ + cfg.tile_rows[T_nxm] = n; \ + cfg.tile_colsb[T_nxm] = m * 2; \ + } \ + _tile_loadconfig(&cfg); int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { BLASLONG i, j; IFLOAT *aoffset, *boffset; - IFLOAT *aoffset0, *aoffset1, *aoffset2, *aoffset3; - IFLOAT *aoffset4, *aoffset5, *aoffset6, *aoffset7; - IFLOAT *aoffset8, *aoffset9, *aoffset10, *aoffset11; - IFLOAT *aoffset12, *aoffset13, *aoffset14, *aoffset15; + IFLOAT *aoffset0; aoffset = a; boffset = b; @@ -48,141 +75,52 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { BLASLONG m32 = m & ~31; BLASLONG m2 = m & ~1; + BLASLONG tail_m = m2 - m32; + BLASLONG tail_n = n - n16; + tilecfg cfg; + TCONF(cfg, tail_m, tail_n); + for (j = 0; j < n16; j += 16) { - IFLOAT *boffset0 = boffset; aoffset0 = aoffset; - aoffset1 = aoffset0 + lda; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; for (i = 0; i < m32; i += 32) { - COPY_32(0); COPY_32(1); COPY_32(2); COPY_32(3); - boffset += 32 * 16; - } - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - boffset = boffset0; - for (i = 0; i < m32; i += 32) { - COPY_32(4); COPY_32(5); COPY_32(6); COPY_32(7); - boffset += 32 * 16; - } - aoffset8 = aoffset7 + lda; - aoffset9 = aoffset8 + lda; - aoffset10 = aoffset9 + lda; - aoffset11 = aoffset10 + lda; - boffset = boffset0; - for (i = 0; i < m32; i += 32) { - COPY_32(8); COPY_32(9); COPY_32(10); COPY_32(11); - boffset += 32 * 16; - } - aoffset12 = aoffset11 + lda; - aoffset13 = aoffset12 + lda; - aoffset14 = aoffset13 + lda; - aoffset15 = aoffset14 + lda; - boffset = boffset0; - for (i = 0; i < m32; i += 32) { - COPY_32(12); COPY_32(13); COPY_32(14); COPY_32(15); + _tile_loadd(T_16x32, aoffset0, lda * 2); + _tile_stored(T_16x32, boffset, 32 * 2); + aoffset0 += 32; boffset += 32 * 16; } if (i < m2) { - int tail_m = m2 - i; - __mmask32 mmask = (1UL << tail_m) - 1; - MASK_COPY_32(0); MASK_COPY_32(1); MASK_COPY_32(2); MASK_COPY_32(3); - MASK_COPY_32(4); MASK_COPY_32(5); MASK_COPY_32(6); MASK_COPY_32(7); - MASK_COPY_32(8); MASK_COPY_32(9); MASK_COPY_32(10); MASK_COPY_32(11); - MASK_COPY_32(12); MASK_COPY_32(13); MASK_COPY_32(14); MASK_COPY_32(15); - i = m2; + _tile_loadd(T_16xm, aoffset0, lda * 2); + _tile_stored(T_16xm, boffset, tail_m * 2); + aoffset0 += tail_m; boffset += tail_m * 16; + i = m2; } if (i < m) { /* the tail odd k should put alone */ - COPY_ODD_TAIL(0); COPY_ODD_TAIL(1); COPY_ODD_TAIL(2); COPY_ODD_TAIL(3); - COPY_ODD_TAIL(4); COPY_ODD_TAIL(5); COPY_ODD_TAIL(6); COPY_ODD_TAIL(7); - COPY_ODD_TAIL(8); COPY_ODD_TAIL(9); COPY_ODD_TAIL(10); COPY_ODD_TAIL(11); - COPY_ODD_TAIL(12); COPY_ODD_TAIL(13); COPY_ODD_TAIL(14); COPY_ODD_TAIL(15); + for (int ii = 0; ii < 16; ii++) { + *(boffset + ii) = *(aoffset0 + lda * ii); + } boffset += 16; } aoffset += 16 * lda; } if (j < n) { - int remain_n = n - j; aoffset0 = aoffset; - aoffset1 = aoffset0 + lda; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; - aoffset9 = aoffset8 + lda; - aoffset10 = aoffset9 + lda; - aoffset11 = aoffset10 + lda; - aoffset12 = aoffset11 + lda; - aoffset13 = aoffset12 + lda; - aoffset14 = aoffset13 + lda; - aoffset15 = aoffset14 + lda; for (i = 0; i < m32; i += 32) { - switch(remain_n) { - case 15: COPY_32(14); - case 14: COPY_32(13); - case 13: COPY_32(12); - case 12: COPY_32(11); - case 11: COPY_32(10); - case 10: COPY_32(9); - case 9: COPY_32(8); - case 8: COPY_32(7); - case 7: COPY_32(6); - case 6: COPY_32(5); - case 5: COPY_32(4); - case 4: COPY_32(3); - case 3: COPY_32(2); - case 2: COPY_32(1); - case 1: COPY_32(0); - } - boffset += 32 * remain_n; + _tile_loadd(T_nx32, aoffset0, lda * 2); + _tile_stored(T_nx32, boffset, 32 * 2); + aoffset0 += 32; + boffset += 32 * tail_n; } if (i < m2) { - int tail_m = m2 - i; - __mmask32 mmask = (1UL << tail_m) - 1; - switch(remain_n) { - case 15: MASK_COPY_32(14); - case 14: MASK_COPY_32(13); - case 13: MASK_COPY_32(12); - case 12: MASK_COPY_32(11); - case 11: MASK_COPY_32(10); - case 10: MASK_COPY_32(9); - case 9: MASK_COPY_32(8); - case 8: MASK_COPY_32(7); - case 7: MASK_COPY_32(6); - case 6: MASK_COPY_32(5); - case 5: MASK_COPY_32(4); - case 4: MASK_COPY_32(3); - case 3: MASK_COPY_32(2); - case 2: MASK_COPY_32(1); - case 1: MASK_COPY_32(0); - } - i = m2; - boffset += tail_m * remain_n; + _tile_loadd(T_nxm, aoffset0, lda * 2); + _tile_stored(T_nxm, boffset, tail_m * 2); + aoffset0 += tail_m; + boffset += tail_m * tail_n; } if (i < m) { - switch(remain_n) { - case 15: COPY_ODD_TAIL(14); - case 14: COPY_ODD_TAIL(13); - case 13: COPY_ODD_TAIL(12); - case 12: COPY_ODD_TAIL(11); - case 11: COPY_ODD_TAIL(10); - case 10: COPY_ODD_TAIL(9); - case 9: COPY_ODD_TAIL(8); - case 8: COPY_ODD_TAIL(7); - case 7: COPY_ODD_TAIL(6); - case 6: COPY_ODD_TAIL(5); - case 5: COPY_ODD_TAIL(4); - case 4: COPY_ODD_TAIL(3); - case 3: COPY_ODD_TAIL(2); - case 2: COPY_ODD_TAIL(1); - case 1: COPY_ODD_TAIL(0); + for (int ii = 0; ii < tail_n; ii++) { + *(boffset + ii) = *(aoffset0 + lda * ii); } } }