diff --git a/kernel/x86_64/sbgemm_otcopy_16_spr.c b/kernel/x86_64/sbgemm_otcopy_16_spr.c
index 2f57ae7b6..b5d5d38fb 100644
--- a/kernel/x86_64/sbgemm_otcopy_16_spr.c
+++ b/kernel/x86_64/sbgemm_otcopy_16_spr.c
@@ -25,8 +25,278 @@
  * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  *****************************************************************************/
+#include <immintrin.h>
 #include "common.h"
 
+// load 8 full rows (16 bf16 elements each, stride lda) into r0-r7
+#define LOAD_A_8VEC(aptr) \
+  r0 = _mm256_loadu_si256((__m256i *)(aptr + lda*0)); \
+  r1 = _mm256_loadu_si256((__m256i *)(aptr + lda*1)); \
+  r2 = _mm256_loadu_si256((__m256i *)(aptr + lda*2)); \
+  r3 = _mm256_loadu_si256((__m256i *)(aptr + lda*3)); \
+  r4 = _mm256_loadu_si256((__m256i *)(aptr + lda*4)); \
+  r5 = _mm256_loadu_si256((__m256i *)(aptr + lda*5)); \
+  r6 = _mm256_loadu_si256((__m256i *)(aptr + lda*6)); \
+  r7 = _mm256_loadu_si256((__m256i *)(aptr + lda*7));
+
+// masked variant: nmask limits how many of the 16 columns are loaded per row
+#define MASK_LOAD_A_8VEC(aptr) \
+  r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*0)); \
+  r1 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*1)); \
+  r2 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*2)); \
+  r3 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*3)); \
+  r4 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*4)); \
+  r5 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*5)); \
+  r6 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*6)); \
+  r7 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*7));
+
+// load only (cond) rows; the switch falls through from case (cond) down to case 1
+#define SWITCH_LOAD_A_8VEC(aptr, cond) \
+  switch((cond)) { \
+    case 8: r7 = _mm256_loadu_si256((__m256i *)(aptr + lda*7)); \
+    case 7: r6 = _mm256_loadu_si256((__m256i *)(aptr + lda*6)); \
+    case 6: r5 = _mm256_loadu_si256((__m256i *)(aptr + lda*5)); \
+    case 5: r4 = _mm256_loadu_si256((__m256i *)(aptr + lda*4)); \
+    case 4: r3 = _mm256_loadu_si256((__m256i *)(aptr + lda*3)); \
+    case 3: r2 = _mm256_loadu_si256((__m256i *)(aptr + lda*2)); \
+    case 2: r1 = _mm256_loadu_si256((__m256i *)(aptr + lda*1)); \
+    case 1: r0 = _mm256_loadu_si256((__m256i *)(aptr + lda*0)); \
+  }
+
+// masked-column variant of SWITCH_LOAD_A_8VEC
+#define SWITCH_MASK_LOAD_A_8VEC(aptr, cond) \
+  switch((cond)) { \
+    case 8: r7 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*7)); \
+    case 7: r6 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*6)); \
+    case 6: r5 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*5)); \
+    case 5: r4 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*4)); \
+    case 4: r3 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*3)); \
+    case 3: r2 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*2)); \
+    case 2: r1 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*1)); \
+    case 1: r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*0)); \
+  }
+
+// 8x16 transpose of r0-r7: afterwards the low 128-bit lane of tN holds
+// column N (8 elements) and the high lane holds column N+8
+#define REORDER_8x16(t0, t1, t2, t3, t4, t5, t6, t7) \
+  t0 = _mm256_unpacklo_epi16(r0, r1); \
+  t1 = _mm256_unpackhi_epi16(r0, r1); \
+  t2 = _mm256_unpacklo_epi16(r2, r3); \
+  t3 = _mm256_unpackhi_epi16(r2, r3); \
+  t4 = _mm256_unpacklo_epi16(r4, r5); \
+  t5 = _mm256_unpackhi_epi16(r4, r5); \
+  t6 = _mm256_unpacklo_epi16(r6, r7); \
+  t7 = _mm256_unpackhi_epi16(r6, r7); \
+  r0 = _mm256_unpacklo_epi32(t0, t2); \
+  r1 = _mm256_unpacklo_epi32(t1, t3); \
+  r2 = _mm256_unpacklo_epi32(t4, t6); \
+  r3 = _mm256_unpacklo_epi32(t5, t7); \
+  r4 = _mm256_unpackhi_epi32(t0, t2); \
+  r5 = _mm256_unpackhi_epi32(t1, t3); \
+  r6 = _mm256_unpackhi_epi32(t4, t6); \
+  r7 = _mm256_unpackhi_epi32(t5, t7); \
+  t0 = _mm256_unpacklo_epi64(r0, r2); \
+  t1 = _mm256_unpackhi_epi64(r0, r2); \
+  t2 = _mm256_unpacklo_epi64(r4, r6); \
+  t3 = _mm256_unpackhi_epi64(r4, r6); \
+  t4 = _mm256_unpacklo_epi64(r1, r3); \
+  t5 = _mm256_unpackhi_epi64(r1, r3); \
+  t6 = _mm256_unpacklo_epi64(r5, r7); \
+  t7 = _mm256_unpackhi_epi64(r5, r7);
+
+// merge the two 8-row halves of column x (t0x: rows 0-7, t1x: rows 8-15)
+// and store the 16 packed elements; _HI does the same for column x+8
+#define STORE_256_LO(x) \
+  v = _mm256_permute2x128_si256(t0##x, t1##x, 0x20); \
+  _mm256_storeu_si256((__m256i *)(boffset + x*32), v);
+
+#define STORE_256_HI(x) \
+  v = _mm256_permute2x128_si256(t0##x, t1##x, 0x31); \
+  _mm256_storeu_si256((__m256i *)(boffset + (x + 8)*32), v);
+
+// masked-row variant used for the m tail; columns are packed with stride m_load
+#define MASK_STORE_256_LO(x) \
+  v = _mm256_permute2x128_si256(t0##x, t1##x, 0x20); \
+  _mm256_mask_storeu_epi16(boffset + x*m_load, mmask, v);
+
+#define MASK_STORE_256_HI(x) \
+  v = _mm256_permute2x128_si256(t0##x, t1##x, 0x31); \
+  _mm256_mask_storeu_epi16(boffset + (x + 8)*m_load, mmask, v);
+
+// x selects the low (columns 0-7) or high (columns 8-15) store, y the column within the group
+#define STORE_256(x, y) { \
+  __m256i v; \
+  if (x == 0) { STORE_256_LO(y); } \
+  else { STORE_256_HI(y); } \
+}
+
+#define MASK_STORE_256(x, y) { \
+  __m256i v; \
+  if (x == 0) { MASK_STORE_256_LO(y); } \
+  else { MASK_STORE_256_HI(y); } \
+}
+
+// store only (cond) columns, falling through from column (cond)-1 down to column 0
+#define SWITCH_STORE_16x(cond, func) \
+  switch((cond)) { \
+    case 15: func(1, 6); \
+    case 14: func(1, 5); \
+    case 13: func(1, 4); \
+    case 12: func(1, 3); \
+    case 11: func(1, 2); \
+    case 10: func(1, 1); \
+    case 9: func(1, 0); \
+    case 8: func(0, 7); \
+    case 7: func(0, 6); \
+    case 6: func(0, 5); \
+    case 5: func(0, 4); \
+    case 4: func(0, 3); \
+    case 3: func(0, 2); \
+    case 2: func(0, 1); \
+    case 1: func(0, 0); \
+  }
+
 int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) {
+  IFLOAT *aoffset, *boffset;
+  IFLOAT *aoffset00, *aoffset01, *aoffset10, *aoffset11;
+  IFLOAT *boffset0;
+
+  __m256i r0, r1, r2, r3, r4, r5, r6, r7;
+  __m256i t00, t01, t02, t03, t04, t05, t06, t07;
+  __m256i t10, t11, t12, t13, t14, t15, t16, t17;
+
+  aoffset = a;
+  boffset = b;
+  BLASLONG n_count = n;
+  BLASLONG m_count = m;
+  for (; n_count > 15; n_count -= 16) {
+    aoffset00 = aoffset;
+    aoffset01 = aoffset00 + 8 * lda;
+    aoffset10 = aoffset01 + 8 * lda;
+    aoffset11 = aoffset10 + 8 * lda;
+    aoffset += 16;
+    m_count = m;
+    for (; m_count > 31; m_count -= 32) {
+      // first 16 rows
+      LOAD_A_8VEC(aoffset00);
+      REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
+      LOAD_A_8VEC(aoffset01);
+      REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
+      STORE_256(0, 0); STORE_256(0, 1); STORE_256(0, 2); STORE_256(0, 3);
+      STORE_256(0, 4); STORE_256(0, 5); STORE_256(0, 6); STORE_256(0, 7);
+      STORE_256(1, 0); STORE_256(1, 1); STORE_256(1, 2); STORE_256(1, 3);
+      STORE_256(1, 4); STORE_256(1, 5); STORE_256(1, 6); STORE_256(1, 7);
+      // last 16 rows
+      boffset += 16;
+      LOAD_A_8VEC(aoffset10);
+      REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
+      LOAD_A_8VEC(aoffset11);
+      REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
+      STORE_256(0, 0); STORE_256(0, 1); STORE_256(0, 2); STORE_256(0, 3);
+      STORE_256(0, 4); STORE_256(0, 5); STORE_256(0, 6); STORE_256(0, 7);
+      STORE_256(1, 0); STORE_256(1, 1); STORE_256(1, 2); STORE_256(1, 3);
+      STORE_256(1, 4); STORE_256(1, 5); STORE_256(1, 6); STORE_256(1, 7);
+      aoffset00 += 32 * lda;
+      aoffset01 += 32 * lda;
+      aoffset10 += 32 * lda;
+      aoffset11 += 32 * lda;
+      boffset += 31 * 16;
+    }
+    // m tail: pack the even part (2-30 rows) of the remainder with masked stores
+    if (m_count > 1) {
+      int m_load = m_count & ~1;
+      m_count -= m_load;
+      __mmask16 mmask;
+      SWITCH_LOAD_A_8VEC(aoffset00, m_load > 8 ? 8: m_load);
+      REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
+      if (m_load > 8) {
+        SWITCH_LOAD_A_8VEC(aoffset01, m_load > 16 ? 8: m_load - 8);
+        REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
+      }
+      int this_load = m_load > 16 ? 16 : m_load;
+      mmask = (1UL << this_load) - 1;
+      MASK_STORE_256(0, 0); MASK_STORE_256(0, 1); MASK_STORE_256(0, 2); MASK_STORE_256(0, 3);
+      MASK_STORE_256(0, 4); MASK_STORE_256(0, 5); MASK_STORE_256(0, 6); MASK_STORE_256(0, 7);
+      MASK_STORE_256(1, 0); MASK_STORE_256(1, 1); MASK_STORE_256(1, 2); MASK_STORE_256(1, 3);
+      MASK_STORE_256(1, 4); MASK_STORE_256(1, 5); MASK_STORE_256(1, 6); MASK_STORE_256(1, 7);
+      boffset0 = boffset;
+      if (m_load > 16) {
+        boffset += this_load;
+        SWITCH_LOAD_A_8VEC(aoffset10, m_load > 24 ? 8: m_load - 16);
+        REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
+        if (m_load > 24) {
+          SWITCH_LOAD_A_8VEC(aoffset11, m_load - 24);
+          REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
+        }
+        this_load = m_load - 16;
+        mmask = (1UL << this_load) - 1;
+        MASK_STORE_256(0, 0); MASK_STORE_256(0, 1); MASK_STORE_256(0, 2); MASK_STORE_256(0, 3);
+        MASK_STORE_256(0, 4); MASK_STORE_256(0, 5); MASK_STORE_256(0, 6); MASK_STORE_256(0, 7);
+        MASK_STORE_256(1, 0); MASK_STORE_256(1, 1); MASK_STORE_256(1, 2); MASK_STORE_256(1, 3);
+        MASK_STORE_256(1, 4); MASK_STORE_256(1, 5); MASK_STORE_256(1, 6); MASK_STORE_256(1, 7);
+      }
+      boffset = boffset0 + 16 * m_load;
+      aoffset00 += m_load * lda;
+    }
+    if (m_count > 0) {
+      // just copy last K to B directly
+      r0 = _mm256_loadu_si256((__m256i *)(aoffset00));
+      _mm256_storeu_si256((__m256i *)(boffset), r0);
+      boffset += 16;
+    }
+  }
+  // n tail: fewer than 16 columns remain, use masked column loads
+  if (n_count > 0) {
+    __mmask16 nmask = (1UL << n_count) - 1;
+    aoffset00 = aoffset;
+    aoffset01 = aoffset00 + 8 * lda;
+    aoffset10 = aoffset01 + 8 * lda;
+    aoffset11 = aoffset10 + 8 * lda;
+    m_count = m;
+    for (; m_count > 31; m_count -= 32) {
+      // first 16 rows
+      MASK_LOAD_A_8VEC(aoffset00);
+      REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
+      MASK_LOAD_A_8VEC(aoffset01);
+      REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
+      SWITCH_STORE_16x(n_count, STORE_256);
+      // last 16 rows
+      boffset0 = boffset;
+      boffset += 16;
+      MASK_LOAD_A_8VEC(aoffset10);
+      REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
+      MASK_LOAD_A_8VEC(aoffset11);
+      REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
+      SWITCH_STORE_16x(n_count, STORE_256);
+      aoffset00 += 32 * lda;
+      aoffset01 += 32 * lda;
+      aoffset10 += 32 * lda;
+      aoffset11 += 32 * lda;
+      boffset = 32 * n_count + boffset0;
+    }
+    if (m_count > 1) {
+      int m_load = m_count & ~1;
+      m_count -= m_load;
+      __mmask16 mmask;
+      SWITCH_MASK_LOAD_A_8VEC(aoffset00, m_load > 8 ? 8: m_load);
+      REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
+      if (m_load > 8) {
+        SWITCH_MASK_LOAD_A_8VEC(aoffset01, m_load > 16 ? 8: m_load - 8);
+        REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
+      }
+      int this_load = m_load > 16 ? 16 : m_load;
+      mmask = (1UL << this_load) - 1;
+      SWITCH_STORE_16x(n_count, MASK_STORE_256);
+      boffset0 = boffset;
+      if (m_load > 16) {
+        boffset += this_load;
+        SWITCH_MASK_LOAD_A_8VEC(aoffset10, m_load > 24 ? 8: m_load - 16);
+        REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
+        if (m_load > 24) {
+          SWITCH_MASK_LOAD_A_8VEC(aoffset11, m_load - 24);
+          REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
+        }
+        this_load = m_load - 16;
+        mmask = (1UL << this_load) - 1;
+        SWITCH_STORE_16x(n_count, MASK_STORE_256);
+      }
+      boffset = boffset0 + n_count * m_load;
+      aoffset00 += m_load * lda;
+    }
+    if (m_count > 0) {
+      // just copy last K to B directly
+      r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aoffset00));
+      _mm256_mask_storeu_epi16((__m256i *)(boffset), nmask, r0);
+      boffset += 16;
+    }
+  }
   return 0;
 }
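The sketch below is not part of the patch; it is a standalone illustration of the unpack cascade that REORDER_8x16 performs. The function name transpose_8x16_u16 and the small test harness are invented for this example, and it only requires AVX2 (e.g. gcc -O2 -mavx2) because the AVX-512 masked load/store paths of the kernel are not exercised. It transposes a single 8x16 tile of 16-bit values; the kernel instead fuses two such tiles with _mm256_permute2x128_si256 in STORE_256_LO/HI to obtain 16-element packed columns.

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

/* Transpose one 8x16 tile of u16 values: src has 8 rows with stride lda,
 * dst receives 16 rows of 8 values (dst[j*8 + i] == src[i*lda + j]). */
static void transpose_8x16_u16(const uint16_t *src, int lda, uint16_t *dst) {
  __m256i r[8], t[8];
  for (int i = 0; i < 8; i++)   /* 8 rows of 16 u16 each */
    r[i] = _mm256_loadu_si256((const __m256i *)(src + i * lda));

  /* stage 1: 16-bit interleave of row pairs (same order as REORDER_8x16) */
  t[0] = _mm256_unpacklo_epi16(r[0], r[1]);
  t[1] = _mm256_unpackhi_epi16(r[0], r[1]);
  t[2] = _mm256_unpacklo_epi16(r[2], r[3]);
  t[3] = _mm256_unpackhi_epi16(r[2], r[3]);
  t[4] = _mm256_unpacklo_epi16(r[4], r[5]);
  t[5] = _mm256_unpackhi_epi16(r[4], r[5]);
  t[6] = _mm256_unpacklo_epi16(r[6], r[7]);
  t[7] = _mm256_unpackhi_epi16(r[6], r[7]);
  /* stage 2: 32-bit interleave */
  r[0] = _mm256_unpacklo_epi32(t[0], t[2]);
  r[1] = _mm256_unpacklo_epi32(t[1], t[3]);
  r[2] = _mm256_unpacklo_epi32(t[4], t[6]);
  r[3] = _mm256_unpacklo_epi32(t[5], t[7]);
  r[4] = _mm256_unpackhi_epi32(t[0], t[2]);
  r[5] = _mm256_unpackhi_epi32(t[1], t[3]);
  r[6] = _mm256_unpackhi_epi32(t[4], t[6]);
  r[7] = _mm256_unpackhi_epi32(t[5], t[7]);
  /* stage 3: 64-bit interleave; now the low 128-bit lane of t[c] holds
   * column c (8 elements) and the high lane holds column c + 8 */
  t[0] = _mm256_unpacklo_epi64(r[0], r[2]);
  t[1] = _mm256_unpackhi_epi64(r[0], r[2]);
  t[2] = _mm256_unpacklo_epi64(r[4], r[6]);
  t[3] = _mm256_unpackhi_epi64(r[4], r[6]);
  t[4] = _mm256_unpacklo_epi64(r[1], r[3]);
  t[5] = _mm256_unpackhi_epi64(r[1], r[3]);
  t[6] = _mm256_unpacklo_epi64(r[5], r[7]);
  t[7] = _mm256_unpackhi_epi64(r[5], r[7]);

  for (int c = 0; c < 8; c++) {
    _mm_storeu_si128((__m128i *)(dst + c * 8),
                     _mm256_castsi256_si128(t[c]));       /* column c     */
    _mm_storeu_si128((__m128i *)(dst + (c + 8) * 8),
                     _mm256_extracti128_si256(t[c], 1));  /* column c + 8 */
  }
}

int main(void) {
  uint16_t a[8 * 16], b[16 * 8];
  for (int i = 0; i < 8; i++)
    for (int j = 0; j < 16; j++)
      a[i * 16 + j] = (uint16_t)(i * 100 + j);

  transpose_8x16_u16(a, 16, b);

  /* each printed row j should read j, 100+j, 200+j, ..., 700+j */
  for (int j = 0; j < 16; j++) {
    for (int i = 0; i < 8; i++)
      printf("%4d ", b[j * 8 + i]);
    printf("\n");
  }
  return 0;
}

The three interleave stages double the run length of consecutive same-row elements each time (1 -> 2 -> 4 -> 8 within a 128-bit lane), which is why eight unpacks per stage suffice to transpose the tile without any cross-lane shuffle until the final 128-bit permute or extract.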