Merge pull request #3348 from guowangy/skylakex-sgemv_t-fix

skylakex sgemv_t kernel fix
This commit is contained in:
Martin Kroeker 2021-08-25 22:43:45 +02:00 committed by GitHub
commit 6bb1805ed6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 17 deletions

View File

@@ -302,9 +302,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
FLOAT * xbuffer_align = x;
FLOAT * ybuffer_align = y;
FLOAT * xbuffer = NULL;
FLOAT * ybuffer = NULL;
if (inc_x != 1) {
xbuffer_align = buffer;
for(BLASLONG i=0; i<n; i++) {

View File

@@ -38,7 +38,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "sgemv_t_microk_haswell-4.c"
#elif defined (SKYLAKEX) || defined (COOPERLAKE)
#include "sgemv_t_microk_haswell-4.c"
/*#include "sgemv_t_microk_skylakex.c"*/
#include "sgemv_t_microk_skylakex.c"
#endif
#if defined(STEAMROLLER) || defined(EXCAVATOR)

View File

@@ -93,7 +93,7 @@ static int sgemv_kernel_t_1(BLASLONG m, float alpha, float *a, float *x, float *
}
if (tag_m_32x != m) {
for (BLASLONG idx_m = tag_m_64x; idx_m < tag_m_16x; idx_m+=32) {
for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) {
matrixArray_0 = _mm512_loadu_ps(&a[idx_m + 0]);
_mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(matrixArray_0, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
@@ -145,8 +145,8 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float *
}
if (tag_m_32x != m) {
for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) {
m0 = _mm512_loadu_ps(&a[idx_m]);
m1 = _mm512_loadu_ps(&a[idx_m + 16]);
m0 = _mm512_loadu_ps(&a[idx_m*2]);
m1 = _mm512_loadu_ps(&a[idx_m*2 + 16]);
col1_1 = _mm512_permutex2var_ps(m0, idx_base_0, m1);
col1_2 = _mm512_permutex2var_ps(m0, idx_base_1, m1);
_mm512_storeu_ps(&y[idx_m], _mm512_add_ps(_mm512_fmadd_ps(x2Array, col1_2, _mm512_mul_ps(col1_1, x1Array)), _mm512_loadu_ps(&y[idx_m])));
@@ -157,7 +157,7 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float *
__mmask8 load_mask = *((__mmask8*) &load_mask_value);
x1Array = _mm512_broadcast_f32x2(_mm_maskz_loadu_ps(load_mask, x));
for (BLASLONG idx_m = tag_m_16x; idx_m < tag_m_8x; idx_m+=8) {
m0 = _mm512_loadu_ps(&a[idx_m]);
m0 = _mm512_loadu_ps(&a[idx_m*2]);
m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR);
m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1);
__m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0));
@@ -166,12 +166,12 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float *
}
if (tag_m_8x != m) {
unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-((m-tag_m_8x)*2)&15));
unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-(((m-tag_m_8x)*2)&15)));
__mmask16 a_mask = *((__mmask16*) &tail_mask_value);
unsigned char y_mask_value = (((unsigned char)0xff) >> (8-(m-tag_m_8x)));
__mmask8 y_mask = *((__mmask8*) &y_mask_value);
m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x]);
m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x*2]);
m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR);
m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1);
__m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0));
@@ -322,7 +322,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float *
{
BLASLONG tag_m_4x = m & (~3);
BLASLONG tag_m_2x = m & (~1);
__m512 m0, m1, m2;
__m512 m0, m1;
__m256 m256_0, m256_1, c256_1, c256_2;
__m128 c1, c2, c3, c4, ret;
__m128 xarray = _mm_maskz_loadu_ps(0x0f, x);
@@ -346,7 +346,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float *
c3 = _mm256_extractf32x4_ps(c256_2, 0);
c4 = _mm256_extractf32x4_ps(c256_2, 1);
ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, y));
ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, &y[idx_m]));
_mm_mask_storeu_ps(&y[idx_m], 0xff, ret);
}
@@ -958,6 +958,7 @@ static int sgemv_kernel_t_7(BLASLONG m, float alpha, float *a, float *x, float *
c256_1 = _mm512_extractf32x8_ps(tmp0, 1);
c256_0 = _mm256_add_ps(c256_0, c256_1);
c256_0 = _mm256_mul_ps(c256_0, alpha256);
__m128 c128_0 = _mm256_extractf32x4_ps(c256_0, 0);
__m128 c128_1 = _mm256_extractf32x4_ps(c256_0, 1);
@@ -1016,9 +1017,10 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
__m512 m0, m1, m2, m3;
__m256 r0, r1, r2, r3, r4, r5, r6, r7, tmp0, tmp1, tmp2, tmp3;
__m128 c128_0, c128_1, c128_2, c128_3;
__m128 alpha128 = _mm_set1_ps(alpha);
__m256 alpha256 = _mm256_set1_ps(alpha);
__m256 x256 = _mm256_loadu_ps(x);
x256 = _mm256_mul_ps(x256, alpha256);
__m512 x512 = _mm512_broadcast_f32x8(x256);
for(BLASLONG idx_m=0; idx_m<tag_m_8x; idx_m+=8) {
@@ -1053,8 +1055,8 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
c128_0 = _mm_add_ps(c128_0, c128_1);
c128_2 = _mm_add_ps(c128_2, c128_3);
_mm_storeu_ps(&y[idx_m], _mm_fmadd_ps(c128_0, alpha128, _mm_loadu_ps(&y[idx_m])));
_mm_storeu_ps(&y[idx_m+4], _mm_fmadd_ps(c128_2, alpha128, _mm_loadu_ps(&y[idx_m+4])));
_mm_storeu_ps(&y[idx_m], _mm_add_ps(c128_0, _mm_loadu_ps(&y[idx_m])));
_mm_storeu_ps(&y[idx_m+4], _mm_add_ps(c128_2, _mm_loadu_ps(&y[idx_m+4])));
}
if (tag_m_8x !=m ){
@@ -1078,7 +1080,7 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
c128_1 = _mm256_extractf32x4_ps(tmp1, 1);
c128_0 = _mm_add_ps(c128_0, c128_1);
_mm_storeu_ps(&y[idx_m], _mm_fmadd_ps(c128_0, alpha128, _mm_loadu_ps(&y[idx_m])));
_mm_storeu_ps(&y[idx_m], _mm_add_ps(c128_0, _mm_loadu_ps(&y[idx_m])));
}
@@ -1094,7 +1096,6 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
c128_1 = _mm256_extractf32x4_ps(tmp0, 1);
c128_0 = _mm_add_ps(c128_0, c128_1);
c128_0 = _mm_mul_ps(c128_0, alpha128);
_mm_storeu_ps(ret, c128_0);
y[idx_m] += (ret[0]+ret[1]);