diff --git a/kernel/x86_64/sbgemm_tcopy_16_cooperlake.c b/kernel/x86_64/sbgemm_tcopy_16_cooperlake.c index ce4458d2c..88725f343 100644 --- a/kernel/x86_64/sbgemm_tcopy_16_cooperlake.c +++ b/kernel/x86_64/sbgemm_tcopy_16_cooperlake.c @@ -160,4 +160,5 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){ } } } + return 0; } diff --git a/kernel/x86_64/sbgemm_tcopy_4_cooperlake.c b/kernel/x86_64/sbgemm_tcopy_4_cooperlake.c index afcf6f647..74f30d44a 100644 --- a/kernel/x86_64/sbgemm_tcopy_4_cooperlake.c +++ b/kernel/x86_64/sbgemm_tcopy_4_cooperlake.c @@ -26,8 +26,94 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *****************************************************************************/ #include +#include #include "common.h" int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){ + BLASLONG i, j; + IFLOAT *boffset0, *boffset1; + + boffset0 = b; + + BLASLONG n8 = n & ~7; + BLASLONG m4 = m & ~3; + BLASLONG m2 = m & ~1; + + for (j = 0; j < n8; j += 8) { + boffset1 = boffset0 + m * 4; + for (i = 0; i < m4; i +=4) { + __m128i a0 = _mm_loadu_si128((void *)&a[(i + 0)*lda + j]); + __m128i a1 = _mm_loadu_si128((void *)&a[(i + 1)*lda + j]); + __m128i a2 = _mm_loadu_si128((void *)&a[(i + 2)*lda + j]); + __m128i a3 = _mm_loadu_si128((void *)&a[(i + 3)*lda + j]); + __m128i a00 = _mm_unpacklo_epi16(a0, a1); + __m128i a01 = _mm_unpackhi_epi16(a0, a1); + __m128i a10 = _mm_unpacklo_epi16(a2, a3); + __m128i a11 = _mm_unpackhi_epi16(a2, a3); + _mm_storeu_si128((void *)(boffset0 + 0), a00); + _mm_storeu_si128((void *)(boffset0 + 8), a10); + _mm_storeu_si128((void *)(boffset1 + 0), a01); + _mm_storeu_si128((void *)(boffset1 + 8), a11); + boffset0 += 16; + boffset1 += 16; + } + for (; i < m2; i+= 2) { + __m128i a0 = _mm_loadu_si128((void *)&a[(i + 0)*lda + j]); + __m128i a1 = _mm_loadu_si128((void *)&a[(i + 1)*lda + j]); + __m128i a00 = _mm_unpacklo_epi16(a0, a1); + __m128i a01 = _mm_unpackhi_epi16(a0, a1); + _mm_storeu_si128((void *)(boffset0 + 0), a00); + _mm_storeu_si128((void *)(boffset1 + 0), a01); + boffset0 += 8; + boffset1 += 8; + } + for (; i < m; i++) { + __m128d a0 = _mm_loadu_pd((void *)&a[(i + 0)*lda + j]); + _mm_store_sd((void *)boffset0, a0); + _mm_store_sd((void *)boffset1, _mm_permute_pd(a0, 0x1)); + boffset0 += 4; + boffset1 += 4; + } + boffset0 = boffset1; + } + if (j < n) { + uint32_t remains = n - j; + __mmask8 r_mask = (1UL << remains) - 1; + if (remains > 4) { + boffset1 = boffset0 + m * 4; + uint32_t tail1 = remains - 4; + __mmask8 w_mask1 = (1UL << tail1) - 1; + for (i = 0; i < m2; i += 2) { + __m128i a0 = _mm_maskz_loadu_epi16(r_mask, &a[(i + 0)*lda + j]); + __m128i a1 = _mm_maskz_loadu_epi16(r_mask, &a[(i + 1)*lda + j]); + __m128i a00 = _mm_unpacklo_epi16(a0, a1); + __m128i a01 = _mm_unpackhi_epi16(a0, a1); + _mm_storeu_si128((void *)boffset0, a00); + _mm_mask_storeu_epi32((void *)boffset1, w_mask1, a01); + boffset0 += 8; + boffset1 += 2 * tail1; + } + for (; i < m; i++) { + __m128i a0 = _mm_maskz_loadu_epi16(r_mask, &a[(i + 0)*lda + j]); + _mm_store_sd((void *)boffset0, (__m128d) a0); + _mm_mask_storeu_epi16((void *)boffset1, w_mask1, (__m128i) _mm_permute_pd((__m128d) a0, 0x1)); + boffset0 += 4; + boffset1 += tail1; + } + } else { + for (i = 0; i < m2; i += 2) { + __m128i a0 = _mm_maskz_loadu_epi16(r_mask, &a[(i + 0)*lda + j]); + __m128i a1 = _mm_maskz_loadu_epi16(r_mask, &a[(i + 1)*lda + j]); + __m128i a00 = _mm_unpacklo_epi16(a0, a1); + _mm_mask_storeu_epi32((void *)boffset0, r_mask, a00); + boffset0 += 2 * remains; + } + for (; i < m; i++) { + __m128i a0 = _mm_maskz_loadu_epi16(r_mask, &a[(i + 0)*lda + j]); + _mm_mask_storeu_epi16((void *)boffset0, r_mask, a0); + } + } + } + return 0; }