From e9acb464318618009d13ddcc7e30dc300e878052 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 25 Aug 2021 07:07:27 +0000 Subject: [PATCH 1/2] sgemv: skylakex: bug fix for sgemv_t kernel in corner case --- kernel/x86_64/sgemv_t_4.c | 2 +- .../x86_64/sgemv_t_microk_skylakex_template.c | 23 ++++++++++--------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/kernel/x86_64/sgemv_t_4.c b/kernel/x86_64/sgemv_t_4.c index 76236cd16..a36c8ace9 100644 --- a/kernel/x86_64/sgemv_t_4.c +++ b/kernel/x86_64/sgemv_t_4.c @@ -38,7 +38,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "sgemv_t_microk_haswell-4.c" #elif defined (SKYLAKEX) || defined (COOPERLAKE) #include "sgemv_t_microk_haswell-4.c" -/*#include "sgemv_t_microk_skylakex.c"*/ +#include "sgemv_t_microk_skylakex.c" #endif #if defined(STEAMROLLER) || defined(EXCAVATOR) diff --git a/kernel/x86_64/sgemv_t_microk_skylakex_template.c b/kernel/x86_64/sgemv_t_microk_skylakex_template.c index 34415054c..423413465 100644 --- a/kernel/x86_64/sgemv_t_microk_skylakex_template.c +++ b/kernel/x86_64/sgemv_t_microk_skylakex_template.c @@ -93,7 +93,7 @@ static int sgemv_kernel_t_1(BLASLONG m, float alpha, float *a, float *x, float * } if (tag_m_32x != m) { - for (BLASLONG idx_m = tag_m_64x; idx_m < tag_m_16x; idx_m+=32) { + for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) { matrixArray_0 = _mm512_loadu_ps(&a[idx_m + 0]); _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(matrixArray_0, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 0]))); @@ -145,8 +145,8 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float * } if (tag_m_32x != m) { for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) { - m0 = _mm512_loadu_ps(&a[idx_m]); - m1 = _mm512_loadu_ps(&a[idx_m + 16]); + m0 = _mm512_loadu_ps(&a[idx_m*2]); + m1 = _mm512_loadu_ps(&a[idx_m*2 + 16]); col1_1 = _mm512_permutex2var_ps(m0, idx_base_0, m1); col1_2 = _mm512_permutex2var_ps(m0, idx_base_1, m1); _mm512_storeu_ps(&y[idx_m], _mm512_add_ps(_mm512_fmadd_ps(x2Array, col1_2, _mm512_mul_ps(col1_1, x1Array)), _mm512_loadu_ps(&y[idx_m]))); @@ -157,7 +157,7 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float * __mmask8 load_mask = *((__mmask8*) &load_mask_value); x1Array = _mm512_broadcast_f32x2(_mm_maskz_loadu_ps(load_mask, x)); for (BLASLONG idx_m = tag_m_16x; idx_m < tag_m_8x; idx_m+=8) { - m0 = _mm512_loadu_ps(&a[idx_m]); + m0 = _mm512_loadu_ps(&a[idx_m*2]); m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR); m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1); __m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0)); @@ -171,7 +171,7 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float * unsigned char y_mask_value = (((unsigned char)0xff) >> (8-(m-tag_m_8x))); __mmask8 y_mask = *((__mmask8*) &y_mask_value); - m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x]); + m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x*2]); m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR); m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1); __m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0)); @@ -346,7 +346,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float * c3 = _mm256_extractf32x4_ps(c256_2, 0); c4 = _mm256_extractf32x4_ps(c256_2, 1); - ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, y)); + ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, &y[idx_m])); _mm_mask_storeu_ps(&y[idx_m], 0xff, ret); } @@ -958,6 +958,7 @@ static int sgemv_kernel_t_7(BLASLONG m, float alpha, float *a, float *x, float * c256_1 = _mm512_extractf32x8_ps(tmp0, 1); c256_0 = _mm256_add_ps(c256_0, c256_1); + c256_0 = _mm256_mul_ps(c256_0, alpha256); __m128 c128_0 = _mm256_extractf32x4_ps(c256_0, 0); __m128 c128_1 = _mm256_extractf32x4_ps(c256_0, 1); @@ -1016,9 +1017,10 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float * __m512 m0, m1, m2, m3; __m256 r0, r1, r2, r3, r4, r5, r6, r7, tmp0, tmp1, tmp2, tmp3; __m128 c128_0, c128_1, c128_2, c128_3; - __m128 alpha128 = _mm_set1_ps(alpha); + __m256 alpha256 = _mm256_set1_ps(alpha); __m256 x256 = _mm256_loadu_ps(x); + x256 = _mm256_mul_ps(x256, alpha256); __m512 x512 = _mm512_broadcast_f32x8(x256); for(BLASLONG idx_m=0; idx_m Date: Wed, 25 Aug 2021 07:13:00 +0000 Subject: [PATCH 2/2] sgemv: skylakex: fix build warning --- kernel/x86_64/sgemv_n_4.c | 3 --- kernel/x86_64/sgemv_t_microk_skylakex_template.c | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/kernel/x86_64/sgemv_n_4.c b/kernel/x86_64/sgemv_n_4.c index 06de28d97..90865c4b3 100644 --- a/kernel/x86_64/sgemv_n_4.c +++ b/kernel/x86_64/sgemv_n_4.c @@ -302,9 +302,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO FLOAT * xbuffer_align = x; FLOAT * ybuffer_align = y; - FLOAT * xbuffer = NULL; - FLOAT * ybuffer = NULL; - if (inc_x != 1) { xbuffer_align = buffer; for(BLASLONG i=0; i> (16-((m-tag_m_8x)*2)&15)); + unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-(((m-tag_m_8x)*2)&15))); __mmask16 a_mask = *((__mmask16*) &tail_mask_value); unsigned char y_mask_value = (((unsigned char)0xff) >> (8-(m-tag_m_8x))); __mmask8 y_mask = *((__mmask8*) &y_mask_value); @@ -322,7 +322,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float * { BLASLONG tag_m_4x = m & (~3); BLASLONG tag_m_2x = m & (~1); - __m512 m0, m1, m2; + __m512 m0, m1; __m256 m256_0, m256_1, c256_1, c256_2; __m128 c1, c2, c3, c4, ret; __m128 xarray = _mm_maskz_loadu_ps(0x0f, x);