diff --git a/kernel/x86_64/sgemv_n_4.c b/kernel/x86_64/sgemv_n_4.c
index 06de28d97..90865c4b3 100644
--- a/kernel/x86_64/sgemv_n_4.c
+++ b/kernel/x86_64/sgemv_n_4.c
@@ -302,9 +302,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
     FLOAT * xbuffer_align = x;
     FLOAT * ybuffer_align = y;
 
-    FLOAT * xbuffer = NULL;
-    FLOAT * ybuffer = NULL;
-
     if (inc_x != 1) {
         xbuffer_align = buffer;
         for(BLASLONG i=0; i

-        unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-((m-tag_m_8x)*2)&15));
+        unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-(((m-tag_m_8x)*2)&15)));
         __mmask16 a_mask = *((__mmask16*) &tail_mask_value);
         unsigned char y_mask_value = (((unsigned char)0xff) >> (8-(m-tag_m_8x)));
         __mmask8 y_mask = *((__mmask8*) &y_mask_value);
-        m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x]);
+        m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x*2]);
         m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR);
         m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1);
         __m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0));
@@ -322,7 +322,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float *
 {
     BLASLONG tag_m_4x = m & (~3);
     BLASLONG tag_m_2x = m & (~1);
-    __m512 m0, m1, m2;
+    __m512 m0, m1;
     __m256 m256_0, m256_1, c256_1, c256_2;
     __m128 c1, c2, c3, c4, ret;
     __m128 xarray = _mm_maskz_loadu_ps(0x0f, x);
@@ -346,7 +346,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float *
         c3 = _mm256_extractf32x4_ps(c256_2, 0);
         c4 = _mm256_extractf32x4_ps(c256_2, 1);
 
-        ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, y));
+        ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, &y[idx_m]));
         _mm_mask_storeu_ps(&y[idx_m], 0xff, ret);
     }
 
@@ -958,6 +958,7 @@ static int sgemv_kernel_t_7(BLASLONG m, float alpha, float *a, float *x, float *
 
         c256_1 = _mm512_extractf32x8_ps(tmp0, 1);
         c256_0 = _mm256_add_ps(c256_0, c256_1);
+        c256_0 = _mm256_mul_ps(c256_0, alpha256);
 
         __m128 c128_0 = _mm256_extractf32x4_ps(c256_0, 0);
         __m128 c128_1 = _mm256_extractf32x4_ps(c256_0, 1);
@@ -1016,9 +1017,10 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
     __m512 m0, m1, m2, m3;
     __m256 r0, r1, r2, r3, r4, r5, r6, r7, tmp0, tmp1, tmp2, tmp3;
     __m128 c128_0, c128_1, c128_2, c128_3;
-    __m128 alpha128 = _mm_set1_ps(alpha);
+    __m256 alpha256 = _mm256_set1_ps(alpha);
 
     __m256 x256 = _mm256_loadu_ps(x);
+    x256 = _mm256_mul_ps(x256, alpha256);
     __m512 x512 = _mm512_broadcast_f32x8(x256);
     for(BLASLONG idx_m=0; idx_m
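/*
 * Note (not part of the patch): a minimal standalone sketch of what the
 * parenthesization change in the tail_mask_value hunk above does.  It assumes
 * tag_m_8x is the 8-row-aligned portion of m (i.e. m & ~7), so the leftover
 * row count is 0..7 and each row contributes two float lanes; that assumption
 * is not visible in the hunk itself.  In C the binary '-' binds tighter than
 * '&', so the old expression parses as (16 - 2*rem) & 15.  Under the
 * assumption above, the two forms only differ when rem == 0: the old one then
 * shifts by 0 and yields a full 0xffff mask instead of an empty one.
 */
#include <stdio.h>

int main(void)
{
    for (unsigned long rem = 0; rem < 8; rem++) {   /* rem stands in for m - tag_m_8x */
        /* old form: '&' applies to the whole difference because of precedence */
        unsigned short old_mask = ((unsigned int)0xffff) >> (16 - (rem * 2) & 15);
        /* new form: mask the lane count first, then subtract it from 16 */
        unsigned short new_mask = ((unsigned int)0xffff) >> (16 - ((rem * 2) & 15));
        printf("rem=%lu  old=0x%04x  new=0x%04x\n", rem, (unsigned)old_mask, (unsigned)new_mask);
    }
    return 0;
}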