Merge pull request #3348 from guowangy/skylakex-sgemv_t-fix
skylakex sgemv_t kernel fix
This commit is contained in:
commit
6bb1805ed6
|
@ -302,9 +302,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
|
|||
FLOAT * xbuffer_align = x;
|
||||
FLOAT * ybuffer_align = y;
|
||||
|
||||
FLOAT * xbuffer = NULL;
|
||||
FLOAT * ybuffer = NULL;
|
||||
|
||||
if (inc_x != 1) {
|
||||
xbuffer_align = buffer;
|
||||
for(BLASLONG i=0; i<n; i++) {
|
||||
|
|
|
@ -38,7 +38,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
#include "sgemv_t_microk_haswell-4.c"
|
||||
#elif defined (SKYLAKEX) || defined (COOPERLAKE)
|
||||
#include "sgemv_t_microk_haswell-4.c"
|
||||
/*#include "sgemv_t_microk_skylakex.c"*/
|
||||
#include "sgemv_t_microk_skylakex.c"
|
||||
#endif
|
||||
|
||||
#if defined(STEAMROLLER) || defined(EXCAVATOR)
|
||||
|
|
|
@ -93,7 +93,7 @@ static int sgemv_kernel_t_1(BLASLONG m, float alpha, float *a, float *x, float *
|
|||
}
|
||||
|
||||
if (tag_m_32x != m) {
|
||||
for (BLASLONG idx_m = tag_m_64x; idx_m < tag_m_16x; idx_m+=32) {
|
||||
for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) {
|
||||
matrixArray_0 = _mm512_loadu_ps(&a[idx_m + 0]);
|
||||
|
||||
_mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(matrixArray_0, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
|
||||
|
@ -145,8 +145,8 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float *
|
|||
}
|
||||
if (tag_m_32x != m) {
|
||||
for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) {
|
||||
m0 = _mm512_loadu_ps(&a[idx_m]);
|
||||
m1 = _mm512_loadu_ps(&a[idx_m + 16]);
|
||||
m0 = _mm512_loadu_ps(&a[idx_m*2]);
|
||||
m1 = _mm512_loadu_ps(&a[idx_m*2 + 16]);
|
||||
col1_1 = _mm512_permutex2var_ps(m0, idx_base_0, m1);
|
||||
col1_2 = _mm512_permutex2var_ps(m0, idx_base_1, m1);
|
||||
_mm512_storeu_ps(&y[idx_m], _mm512_add_ps(_mm512_fmadd_ps(x2Array, col1_2, _mm512_mul_ps(col1_1, x1Array)), _mm512_loadu_ps(&y[idx_m])));
|
||||
|
@ -157,7 +157,7 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float *
|
|||
__mmask8 load_mask = *((__mmask8*) &load_mask_value);
|
||||
x1Array = _mm512_broadcast_f32x2(_mm_maskz_loadu_ps(load_mask, x));
|
||||
for (BLASLONG idx_m = tag_m_16x; idx_m < tag_m_8x; idx_m+=8) {
|
||||
m0 = _mm512_loadu_ps(&a[idx_m]);
|
||||
m0 = _mm512_loadu_ps(&a[idx_m*2]);
|
||||
m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR);
|
||||
m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1);
|
||||
__m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0));
|
||||
|
@ -166,12 +166,12 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float *
|
|||
}
|
||||
|
||||
if (tag_m_8x != m) {
|
||||
unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-((m-tag_m_8x)*2)&15));
|
||||
unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-(((m-tag_m_8x)*2)&15)));
|
||||
__mmask16 a_mask = *((__mmask16*) &tail_mask_value);
|
||||
unsigned char y_mask_value = (((unsigned char)0xff) >> (8-(m-tag_m_8x)));
|
||||
__mmask8 y_mask = *((__mmask8*) &y_mask_value);
|
||||
|
||||
m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x]);
|
||||
m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x*2]);
|
||||
m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR);
|
||||
m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1);
|
||||
__m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0));
|
||||
|
@ -322,7 +322,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float *
|
|||
{
|
||||
BLASLONG tag_m_4x = m & (~3);
|
||||
BLASLONG tag_m_2x = m & (~1);
|
||||
__m512 m0, m1, m2;
|
||||
__m512 m0, m1;
|
||||
__m256 m256_0, m256_1, c256_1, c256_2;
|
||||
__m128 c1, c2, c3, c4, ret;
|
||||
__m128 xarray = _mm_maskz_loadu_ps(0x0f, x);
|
||||
|
@ -346,7 +346,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float *
|
|||
c3 = _mm256_extractf32x4_ps(c256_2, 0);
|
||||
c4 = _mm256_extractf32x4_ps(c256_2, 1);
|
||||
|
||||
ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, y));
|
||||
ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, &y[idx_m]));
|
||||
_mm_mask_storeu_ps(&y[idx_m], 0xff, ret);
|
||||
}
|
||||
|
||||
|
@ -958,6 +958,7 @@ static int sgemv_kernel_t_7(BLASLONG m, float alpha, float *a, float *x, float *
|
|||
c256_1 = _mm512_extractf32x8_ps(tmp0, 1);
|
||||
|
||||
c256_0 = _mm256_add_ps(c256_0, c256_1);
|
||||
c256_0 = _mm256_mul_ps(c256_0, alpha256);
|
||||
|
||||
__m128 c128_0 = _mm256_extractf32x4_ps(c256_0, 0);
|
||||
__m128 c128_1 = _mm256_extractf32x4_ps(c256_0, 1);
|
||||
|
@ -1016,9 +1017,10 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
|
|||
__m512 m0, m1, m2, m3;
|
||||
__m256 r0, r1, r2, r3, r4, r5, r6, r7, tmp0, tmp1, tmp2, tmp3;
|
||||
__m128 c128_0, c128_1, c128_2, c128_3;
|
||||
__m128 alpha128 = _mm_set1_ps(alpha);
|
||||
__m256 alpha256 = _mm256_set1_ps(alpha);
|
||||
|
||||
__m256 x256 = _mm256_loadu_ps(x);
|
||||
x256 = _mm256_mul_ps(x256, alpha256);
|
||||
__m512 x512 = _mm512_broadcast_f32x8(x256);
|
||||
|
||||
for(BLASLONG idx_m=0; idx_m<tag_m_8x; idx_m+=8) {
|
||||
|
@ -1053,8 +1055,8 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
|
|||
|
||||
c128_0 = _mm_add_ps(c128_0, c128_1);
|
||||
c128_2 = _mm_add_ps(c128_2, c128_3);
|
||||
_mm_storeu_ps(&y[idx_m], _mm_fmadd_ps(c128_0, alpha128, _mm_loadu_ps(&y[idx_m])));
|
||||
_mm_storeu_ps(&y[idx_m+4], _mm_fmadd_ps(c128_2, alpha128, _mm_loadu_ps(&y[idx_m+4])));
|
||||
_mm_storeu_ps(&y[idx_m], _mm_add_ps(c128_0, _mm_loadu_ps(&y[idx_m])));
|
||||
_mm_storeu_ps(&y[idx_m+4], _mm_add_ps(c128_2, _mm_loadu_ps(&y[idx_m+4])));
|
||||
}
|
||||
|
||||
if (tag_m_8x !=m ){
|
||||
|
@ -1078,7 +1080,7 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
|
|||
c128_1 = _mm256_extractf32x4_ps(tmp1, 1);
|
||||
|
||||
c128_0 = _mm_add_ps(c128_0, c128_1);
|
||||
_mm_storeu_ps(&y[idx_m], _mm_fmadd_ps(c128_0, alpha128, _mm_loadu_ps(&y[idx_m])));
|
||||
_mm_storeu_ps(&y[idx_m], _mm_add_ps(c128_0, _mm_loadu_ps(&y[idx_m])));
|
||||
|
||||
}
|
||||
|
||||
|
@ -1094,7 +1096,6 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
|
|||
c128_1 = _mm256_extractf32x4_ps(tmp0, 1);
|
||||
|
||||
c128_0 = _mm_add_ps(c128_0, c128_1);
|
||||
c128_0 = _mm_mul_ps(c128_0, alpha128);
|
||||
|
||||
_mm_storeu_ps(ret, c128_0);
|
||||
y[idx_m] += (ret[0]+ret[1]);
|
||||
|
|
Loading…
Reference in New Issue