diff --git a/kernel/x86_64/sgemv_t_4.c b/kernel/x86_64/sgemv_t_4.c
index fe886f57f..a36c8ace9 100644
--- a/kernel/x86_64/sgemv_t_4.c
+++ b/kernel/x86_64/sgemv_t_4.c
@@ -34,8 +34,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #include "sgemv_t_microk_bulldozer-4.c"
 #elif defined(SANDYBRIDGE)
 #include "sgemv_t_microk_sandy-4.c"
-#elif defined(HASWELL) || defined(ZEN) || defined (SKYLAKEX) || defined (COOPERLAKE)
+#elif defined(HASWELL) || defined(ZEN)
 #include "sgemv_t_microk_haswell-4.c"
+#elif defined (SKYLAKEX) || defined (COOPERLAKE)
+#include "sgemv_t_microk_haswell-4.c"
+#include "sgemv_t_microk_skylakex.c"
 #endif
 
 #if defined(STEAMROLLER) || defined(EXCAVATOR)
@@ -305,6 +308,37 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
 	if ( m < 1 ) return(0);
 	if ( n < 1 ) return(0);
 
+#ifdef HAVE_SGEMV_T_SKYLAKE_KERNEL
+	if (lda == m && n <= 16384 && m <= 8)
+	{
+		FLOAT * xbuffer_align = x;
+		FLOAT * ybuffer_align = y;
+
+		if (inc_x != 1) {
+			xbuffer_align = buffer;
+			for(BLASLONG i=0; i<n; i++) {
+				xbuffer_align[i] = x[i*inc_x];
+			}
+		}
+
+		if (inc_y != 1) {
+			ybuffer_align = buffer + n;
+			for(BLASLONG i=0; i<m; i++) {
+				ybuffer_align[i] = y[i*inc_y];
+			}
+		}
+
+		sgemv_kernel_t(m, n, alpha, a, xbuffer_align, ybuffer_align);
+
+		if (inc_y != 1) {
+			for(BLASLONG i=0; i<m; i++) {
+				y[i*inc_y] = ybuffer_align[i];
+			}
+		}
+		return(0);
+	}
+#endif
+
 	xbuffer = buffer;
diff --git a/kernel/x86_64/sgemv_t_microk_skylakex.c b/kernel/x86_64/sgemv_t_microk_skylakex.c
new file mode 100644
--- /dev/null
+++ b/kernel/x86_64/sgemv_t_microk_skylakex.c
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#if (defined(__GNUC__) && __GNUC__ >= 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 6)
+
+#define HAVE_SGEMV_T_SKYLAKE_KERNEL 1
+#include "common.h"
+#include <immintrin.h>
+#include "sgemv_t_microk_skylakex_template.c"
+
+//sgemv_t:
+// ----- m -----
+// |<-----------
+// |<-----------
+// n
+// |<-----------
+// |<-----------
+
+static int sgemv_kernel_t(BLASLONG m, BLASLONG n, float alpha, float *a, float *x, float *y)
+{
+    switch(m) {
+        case 1: sgemv_kernel_t_1(n, alpha, a, x, y); break;
+        case 2: sgemv_kernel_t_2(n, alpha, a, x, y); break;
+        case 3: sgemv_kernel_t_3(n, alpha, a, x, y); break;
+        case 4: sgemv_kernel_t_4(n, alpha, a, x, y); break;
+        case 5: sgemv_kernel_t_5(n, alpha, a, x, y); break;
+        case 6: sgemv_kernel_t_6(n, alpha, a, x, y); break;
+        case 7: sgemv_kernel_t_7(n, alpha, a, x, y); break;
+        case 8: sgemv_kernel_t_8(n, alpha, a, x, y); break;
+        default: break;
+    }
+    return 0;
+}
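+
+/* Call sketch (hypothetical caller, not part of this patch): for a
+   column-major A with lda == m and m <= 8, sgemv_kernel_t(m, n, alpha, a, x, y)
+   computes y += alpha * A^T * x by dispatching on the row count m. */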
+
+#endif
diff --git a/kernel/x86_64/sgemv_t_microk_skylakex_template.c b/kernel/x86_64/sgemv_t_microk_skylakex_template.c
new file mode 100644
index 000000000..34415054c
--- /dev/null
+++ b/kernel/x86_64/sgemv_t_microk_skylakex_template.c
@@ -0,0 +1,1120 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+#include <immintrin.h>
+#include "common.h"
+
+//Here the m means n in sgemv_t:
+// ----- n -----
+// |
+// |
+// m
+// |
+// |
+static int sgemv_kernel_t_1(BLASLONG m, float alpha, float *a, float *x, float *y)
+{
+    //printf("enter into t_1 kernel\n");
+    //printf("m = %ld\n", m);
+    __m512 matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
+    float alphaX = alpha * (*x);
+    __m512 ALPHAXVECTOR = _mm512_set1_ps(alphaX);
+
+    BLASLONG tag_m_128x = m & (~127);
+    BLASLONG tag_m_64x = m & (~63);
+    BLASLONG tag_m_32x = m & (~31);
+    BLASLONG tag_m_16x = m & (~15);
+
+    for (BLASLONG idx_m = 0; idx_m < tag_m_128x; idx_m+=128) {
+        matrixArray_0 = _mm512_loadu_ps(&a[idx_m + 0]);
+        matrixArray_1 = _mm512_loadu_ps(&a[idx_m + 16]);
+        matrixArray_2 = _mm512_loadu_ps(&a[idx_m + 32]);
+        matrixArray_3 = _mm512_loadu_ps(&a[idx_m + 48]);
+        matrixArray_4 = _mm512_loadu_ps(&a[idx_m + 64]);
+        matrixArray_5 = _mm512_loadu_ps(&a[idx_m + 80]);
+        matrixArray_6 = _mm512_loadu_ps(&a[idx_m + 96]);
+        matrixArray_7 = _mm512_loadu_ps(&a[idx_m + 112]);
+
+        _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(matrixArray_0, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
+        _mm512_storeu_ps(&y[idx_m + 16], _mm512_fmadd_ps(matrixArray_1, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 16])));
+        _mm512_storeu_ps(&y[idx_m + 32], _mm512_fmadd_ps(matrixArray_2, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 32])));
+        _mm512_storeu_ps(&y[idx_m + 48], _mm512_fmadd_ps(matrixArray_3, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 48])));
+        _mm512_storeu_ps(&y[idx_m + 64], _mm512_fmadd_ps(matrixArray_4, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 64])));
+        _mm512_storeu_ps(&y[idx_m + 80], _mm512_fmadd_ps(matrixArray_5, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 80])));
+        _mm512_storeu_ps(&y[idx_m + 96], _mm512_fmadd_ps(matrixArray_6, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 96])));
+        _mm512_storeu_ps(&y[idx_m + 112], _mm512_fmadd_ps(matrixArray_7, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 112])));
+    }
+
+    if (tag_m_128x != m) {
+        for (BLASLONG idx_m = tag_m_128x; idx_m < tag_m_64x; idx_m+=64) {
+            matrixArray_0 = _mm512_loadu_ps(&a[idx_m + 0]);
+            matrixArray_1 = _mm512_loadu_ps(&a[idx_m + 16]);
+            matrixArray_2 = _mm512_loadu_ps(&a[idx_m + 32]);
+            matrixArray_3 = _mm512_loadu_ps(&a[idx_m + 48]);
+
+            _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(matrixArray_0, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
+            _mm512_storeu_ps(&y[idx_m + 16], _mm512_fmadd_ps(matrixArray_1, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 16])));
+            _mm512_storeu_ps(&y[idx_m + 32], _mm512_fmadd_ps(matrixArray_2, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 32])));
+            _mm512_storeu_ps(&y[idx_m + 48], _mm512_fmadd_ps(matrixArray_3, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 48])));
+        }
+
+        if (tag_m_64x != m) {
+            for (BLASLONG idx_m = tag_m_64x; idx_m < tag_m_32x; idx_m+=32) {
+                matrixArray_0 = _mm512_loadu_ps(&a[idx_m + 0]);
+                matrixArray_1 = _mm512_loadu_ps(&a[idx_m + 16]);
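+                /* One FMA per 16-float chunk: y_chunk += (alpha * x[0]) * a_chunk. */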
+
+                _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(matrixArray_0, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
+                _mm512_storeu_ps(&y[idx_m + 16], _mm512_fmadd_ps(matrixArray_1, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 16])));
+            }
+
+            if (tag_m_32x != m) {
+                for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) {
+                    matrixArray_0 = _mm512_loadu_ps(&a[idx_m + 0]);
+
+                    _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(matrixArray_0, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
+                }
+
+                if (tag_m_16x != m) {
+                    unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-(m&15)));
+                    __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+                    matrixArray_0 = _mm512_maskz_loadu_ps(tail_mask, &a[tag_m_16x]);
+
+                    _mm512_mask_storeu_ps(&y[tag_m_16x], tail_mask, _mm512_fmadd_ps(matrixArray_0, ALPHAXVECTOR, _mm512_maskz_loadu_ps(tail_mask, &y[tag_m_16x])));
+                }
+            }
+        }
+    }
+
+    return 0;
+}
+
+static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float *y)
+{
+    __m512 m0, m1, m2, m3, col0_1, col0_2, col1_1, col1_2, x1Array, x2Array;
+    float x1a = x[0] * alpha;
+    float x2a = x[1] * alpha;
+    x1Array = _mm512_set1_ps(x1a);
+    x2Array = _mm512_set1_ps(x2a);
+    BLASLONG tag_m_32x = m & (~31);
+    BLASLONG tag_m_16x = m & (~15);
+    BLASLONG tag_m_8x = m & (~7);
+    __m512i M512_EPI32_1 = _mm512_set1_epi32(1);
+    __m512i idx_base_0 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
+    __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_1);
+
+    for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m+=32) {
+        m0 = _mm512_loadu_ps(&a[idx_m*2]);
+        m1 = _mm512_loadu_ps(&a[idx_m*2 + 16]);
+        m2 = _mm512_loadu_ps(&a[idx_m*2 + 32]);
+        m3 = _mm512_loadu_ps(&a[idx_m*2 + 48]);
+        col0_1 = _mm512_permutex2var_ps(m0, idx_base_0, m1);
+        col0_2 = _mm512_permutex2var_ps(m0, idx_base_1, m1);
+        col1_1 = _mm512_permutex2var_ps(m2, idx_base_0, m3);
+        col1_2 = _mm512_permutex2var_ps(m2, idx_base_1, m3);
+
+        _mm512_storeu_ps(&y[idx_m], _mm512_add_ps(_mm512_fmadd_ps(x2Array, col0_2, _mm512_mul_ps(col0_1, x1Array)), _mm512_loadu_ps(&y[idx_m])));
+        _mm512_storeu_ps(&y[idx_m + 16], _mm512_add_ps(_mm512_fmadd_ps(x2Array, col1_2, _mm512_mul_ps(col1_1, x1Array)), _mm512_loadu_ps(&y[idx_m + 16])));
+    }
+    if (tag_m_32x != m) {
+        for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) {
+            m0 = _mm512_loadu_ps(&a[idx_m*2]);
+            m1 = _mm512_loadu_ps(&a[idx_m*2 + 16]);
+            col1_1 = _mm512_permutex2var_ps(m0, idx_base_0, m1);
+            col1_2 = _mm512_permutex2var_ps(m0, idx_base_1, m1);
+            _mm512_storeu_ps(&y[idx_m], _mm512_add_ps(_mm512_fmadd_ps(x2Array, col1_2, _mm512_mul_ps(col1_1, x1Array)), _mm512_loadu_ps(&y[idx_m])));
+        }
+        if (tag_m_16x != m) {
+            __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+            unsigned char load_mask_value = (((unsigned char)0xff) >> 6);
+            __mmask8 load_mask = *((__mmask8*) &load_mask_value);
+            x1Array = _mm512_broadcast_f32x2(_mm_maskz_loadu_ps(load_mask, x));
+            for (BLASLONG idx_m = tag_m_16x; idx_m < tag_m_8x; idx_m+=8) {
+                m0 = _mm512_loadu_ps(&a[idx_m*2]);
+                m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR);
+                m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1);
+                __m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0));
+                _mm256_storeu_ps(&y[idx_m], _mm256_add_ps(ret, _mm256_loadu_ps(&y[idx_m])));
+            }
+
+            if (tag_m_8x != m) {
+                unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-((m-tag_m_8x)*2)&15));
+                __mmask16 a_mask = *((__mmask16*) &tail_mask_value);
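+                /* a_mask covers the 2*(m - tag_m_8x) interleaved floats of A left in
+                   the tail; the narrower y_mask below covers the matching y elements. */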
+                unsigned char y_mask_value = (((unsigned char)0xff) >> (8-(m-tag_m_8x)));
+                __mmask8 y_mask = *((__mmask8*) &y_mask_value);
+
+                m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x*2]);
+                m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR);
+                m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1);
+                __m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0));
+                _mm256_mask_storeu_ps(&y[tag_m_8x], y_mask, _mm256_add_ps(ret, _mm256_maskz_loadu_ps(y_mask, &y[tag_m_8x])));
+            }
+        }
+    }
+    return 0;
+}
+
+static int sgemv_kernel_t_3(BLASLONG m, float alpha, float *a, float *x, float *y)
+{
+    __m512 m0, m1, m2, c1, c2, c3, tmp, x1Array, x2Array, x3Array;
+    float x1a = x[0] * alpha;
+    float x2a = x[1] * alpha;
+    float x3a = x[2] * alpha;
+    x1Array = _mm512_set1_ps(x1a);
+    x2Array = _mm512_set1_ps(x2a);
+    x3Array = _mm512_set1_ps(x3a);
+    BLASLONG tag_m_16x = m & (~15);
+    BLASLONG tag_m_8x = m & (~7);
+    BLASLONG tag_m_4x = m & (~3);
+    BLASLONG tag_m_2x = m & (~1);
+
+    __m512i M512_EPI32_1 = _mm512_set1_epi32(1);
+    __m512i M512_EPI32_s1 = _mm512_set1_epi32(-1);
+    __m512i idx_c1_1 = _mm512_set_epi32(0, 0, 0, 0, 0, 30, 27, 24, 21, 18, 15, 12, 9, 6, 3, 0);
+    __m512i idx_c2_1 = _mm512_add_epi32(idx_c1_1, M512_EPI32_1);
+    __m512i idx_c3_1 = _mm512_add_epi32(idx_c2_1, M512_EPI32_1);
+
+    __m512i idx_c3_2 = _mm512_set_epi32(31, 28, 25, 22, 19, 16, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
+    __m512i idx_c2_2 = _mm512_add_epi32(idx_c3_2, M512_EPI32_s1);
+    __m512i idx_c1_2 = _mm512_add_epi32(idx_c2_2, M512_EPI32_s1);
+
+    __mmask16 step_1 = 0x07ff;
+    __mmask16 step_2 = 0xf800;
+    __mmask16 c31 = 0x03ff;
+
+    for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
+        m0 = _mm512_loadu_ps(&a[idx_m*3]);
+        m1 = _mm512_loadu_ps(&a[idx_m*3 + 16]);
+        m2 = _mm512_loadu_ps(&a[idx_m*3 + 32]);
+
+        tmp = _mm512_mask_permutex2var_ps(m0, step_1, idx_c1_1, m1);
+        c1 = _mm512_mask_permutex2var_ps(tmp, step_2, idx_c1_2, m2);
+        tmp = _mm512_mask_permutex2var_ps(m0, step_1, idx_c2_1, m1);
+        c2 = _mm512_mask_permutex2var_ps(tmp, step_2, idx_c2_2, m2);
+        tmp = _mm512_mask_permutex2var_ps(m0, c31, idx_c3_1, m1);
+        c3 = _mm512_permutex2var_ps(tmp, idx_c3_2, m2);
+
+        tmp = _mm512_fmadd_ps(x2Array, c2, _mm512_mul_ps(c1, x1Array));
+        _mm512_storeu_ps(&y[idx_m], _mm512_add_ps(_mm512_fmadd_ps(x3Array, c3, tmp), _mm512_loadu_ps(&y[idx_m])));
+    }
+
+    if(tag_m_16x != m) {
+        __mmask8 a_mask = 0xff;
+        __m256i M256_EPI32_1 = _mm256_maskz_set1_epi32(a_mask, 1);
+        __m256i M256_EPI32_s1 = _mm256_maskz_set1_epi32(a_mask, -1);
+        __m256i idx_c1_1 = _mm256_set_epi32(0, 0, 15, 12, 9, 6, 3, 0);
+        __m256i idx_c2_1 = _mm256_add_epi32(idx_c1_1, M256_EPI32_1);
+        __m256i idx_c3_1 = _mm256_add_epi32(idx_c2_1, M256_EPI32_1);
+
+        __m256i idx_c3_2 = _mm256_set_epi32(15, 12, 9, 0, 0, 0, 0, 0);
+        __m256i idx_c2_2 = _mm256_add_epi32(idx_c3_2, M256_EPI32_s1);
+        __m256i idx_c1_2 = _mm256_add_epi32(idx_c2_2, M256_EPI32_s1);
+
+        __mmask8 step_1 = 0x1f;
+        __mmask8 step_2 = 0xe0;
+        __mmask8 c12 = 0xc0;
+
+        __m256 m256_0, m256_1, m256_2, tmp256, c256_1, c256_2, c256_3, x256_1, x256_2, x256_3;
+        x256_1 = _mm256_set1_ps(x1a);
+        x256_2 = _mm256_set1_ps(x2a);
+        x256_3 = _mm256_set1_ps(x3a);
+
+        for (BLASLONG idx_m = tag_m_16x; idx_m < tag_m_8x; idx_m+=8) {
+            m256_0 = _mm256_loadu_ps(&a[idx_m*3]);
+            m256_1 = _mm256_loadu_ps(&a[idx_m*3 + 8]);
+            m256_2 = _mm256_loadu_ps(&a[idx_m*3 + 16]);
+
+            tmp256 = _mm256_permutex2var_ps(m256_0, idx_c1_1, m256_1);
+            c256_1 = _mm256_mask_permutex2var_ps(tmp256, c12, idx_c1_2, m256_2);
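+            /* Each stride-3 element group is gathered in two permute steps: one over
+               (m256_0, m256_1), then a masked pass pulling the top lanes from m256_2. */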
+            tmp256 = _mm256_mask_permutex2var_ps(m256_0, step_1, idx_c2_1, m256_1);
+            c256_2 = _mm256_mask_permutex2var_ps(tmp256, step_2, idx_c2_2, m256_2);
+            tmp256 = _mm256_mask_permutex2var_ps(m256_0, step_1, idx_c3_1, m256_1);
+            c256_3 = _mm256_mask_permutex2var_ps(tmp256, step_2, idx_c3_2, m256_2);
+
+            tmp256 = _mm256_fmadd_ps(x256_2, c256_2, _mm256_mul_ps(c256_1, x256_1));
+            _mm256_storeu_ps(&y[idx_m], _mm256_maskz_add_ps(a_mask, _mm256_fmadd_ps(x256_3, c256_3, tmp256), _mm256_loadu_ps(&y[idx_m])));
+        }
+
+        if(tag_m_8x != m){
+            for (BLASLONG idx_m = tag_m_8x; idx_m < tag_m_4x; idx_m+=4){
+                m0 = _mm512_maskz_loadu_ps(0x0fff, &a[idx_m*3]);
+                m256_0 = _mm512_extractf32x8_ps(m0, 0);
+                m256_1 = _mm512_extractf32x8_ps(m0, 1);
+                __m256i idx1 = _mm256_set_epi32(10, 7, 4, 1, 9, 6, 3, 0);
+                __m256i M256_EPI32_2 = _mm256_maskz_set1_epi32(0x0f, 2);
+                __m256i idx2 = _mm256_add_epi32(idx1, M256_EPI32_2);
+
+                c256_1 = _mm256_mask_permutex2var_ps(m256_0, 0xff, idx1, m256_1);
+                c256_2 = _mm256_mask_permutex2var_ps(m256_0, 0x0f, idx2, m256_1);
+
+                __m128 c128_1 = _mm256_extractf32x4_ps(c256_1, 0);
+                __m128 c128_2 = _mm256_extractf32x4_ps(c256_1, 1);
+                __m128 c128_3 = _mm256_extractf32x4_ps(c256_2, 0);
+
+                __m128 x128_1 = _mm_set1_ps(x1a);
+                __m128 x128_2 = _mm_set1_ps(x2a);
+                __m128 x128_3 = _mm_set1_ps(x3a);
+
+                __m128 tmp128 = _mm_maskz_fmadd_ps(0x0f, c128_1, x128_1, _mm_maskz_mul_ps(0x0f, c128_2, x128_2));
+                _mm_mask_storeu_ps(&y[idx_m], 0x0f, _mm_maskz_add_ps(0x0f, _mm_maskz_fmadd_ps(0x0f, c128_3, x128_3, tmp128), _mm_maskz_loadu_ps(0x0f, &y[idx_m])));
+            }
+
+            if(tag_m_4x != m) {
+                for (BLASLONG idx_m = tag_m_4x; idx_m < tag_m_2x; idx_m+=2) {
+                    m256_0 = _mm256_maskz_loadu_ps(0x3f, &a[idx_m*3]);
+                    __m128 a128_1 = _mm256_extractf32x4_ps(m256_0, 0);
+                    __m128 a128_2 = _mm256_extractf32x4_ps(m256_0, 1);
+                    __m128 x128 = _mm_maskz_loadu_ps(0x07, x);
+
+                    __m128i idx128_1 = _mm_set_epi32(0, 2, 1, 0);
+                    __m128i M128_EPI32_3 = _mm_maskz_set1_epi32(0x07, 3);
+                    __m128i idx128_2 = _mm_add_epi32(idx128_1, M128_EPI32_3);
+
+                    __m128 c128_1 = _mm_maskz_permutex2var_ps(0x07, a128_1, idx128_1, a128_2);
+                    __m128 c128_2 = _mm_maskz_permutex2var_ps(0x07, a128_1, idx128_2, a128_2);
+
+                    __m128 tmp128 = _mm_hadd_ps(_mm_maskz_mul_ps(0x07, c128_1, x128), _mm_maskz_mul_ps(0x07, c128_2, x128));
+                    float ret[4];
+                    _mm_mask_storeu_ps(ret, 0x0f, tmp128);
+                    y[idx_m] += alpha * (ret[0] + ret[1]);
+                    y[idx_m+1] += alpha * (ret[2] + ret[3]);
+                }
+
+                if(tag_m_2x != m) {
+                    y[tag_m_2x] += alpha*(a[tag_m_2x*3]*x[0] + a[tag_m_2x*3+1]*x[1] + a[tag_m_2x*3+2]*x[2]);
+                }
+            }
+        }
+    }
+
+    return 0;
+}
+
+static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float *y)
+{
+    BLASLONG tag_m_4x = m & (~3);
+    BLASLONG tag_m_2x = m & (~1);
+    __m512 m0, m1;
+    __m256 m256_0, m256_1, c256_1, c256_2;
+    __m128 c1, c2, c3, c4, ret;
+    __m128 xarray = _mm_maskz_loadu_ps(0x0f, x);
+    __m512 x512 = _mm512_broadcast_f32x4(xarray);
+    __m512 alphavector = _mm512_set1_ps(alpha);
+    __m512 xa512 = _mm512_mul_ps(x512, alphavector);
+    __m256i idx1 = _mm256_set_epi32(13, 9, 5, 1, 12, 8, 4, 0);
+    __m256i idx2 = _mm256_set_epi32(15, 11, 7, 3, 14, 10, 6, 2);
+
+    for (BLASLONG idx_m = 0; idx_m < tag_m_4x; idx_m+=4) {
+        m0 = _mm512_loadu_ps(&a[idx_m*4]);
+        m1 = _mm512_mul_ps(m0, xa512);
+        m256_0 = _mm512_extractf32x8_ps(m1, 0);
+        m256_1 = _mm512_extractf32x8_ps(m1, 1);
+        c256_1 = _mm256_mask_permutex2var_ps(m256_0, 0xff, idx1, m256_1);
+        c256_2 = _mm256_mask_permutex2var_ps(m256_0, 0xff, idx2, m256_1);
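+
+        /* idx1/idx2 group the products by x component; the four 128-bit
+           quarters c1..c4 below sum to one dot product per output element. */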
+        c1 = _mm256_extractf32x4_ps(c256_1, 0);
+        c2 = _mm256_extractf32x4_ps(c256_1, 1);
+        c3 = _mm256_extractf32x4_ps(c256_2, 0);
+        c4 = _mm256_extractf32x4_ps(c256_2, 1);
+
+        ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, &y[idx_m]));
+        _mm_mask_storeu_ps(&y[idx_m], 0xff, ret);
+    }
+
+    if(tag_m_4x != m) {
+        float result[4];
+        for(BLASLONG idx_m = tag_m_4x; idx_m < tag_m_2x; idx_m+=2) {
+            m256_0 = _mm256_maskz_loadu_ps(0xff, &a[idx_m*4]);
+            c1 = _mm256_maskz_extractf32x4_ps(0xff, m256_0, 0);
+            c2 = _mm256_maskz_extractf32x4_ps(0xff, m256_0, 1);
+
+            c3 = _mm_maskz_mul_ps(0x0f, c1, xarray);
+            c4 = _mm_maskz_mul_ps(0x0f, c2, xarray);
+
+            ret = _mm_hadd_ps(c3, c4);
+            _mm_mask_storeu_ps(result, 0x0f, ret);
+            y[idx_m] += alpha * (result[0] + result[1]);
+            y[idx_m+1] += alpha * (result[2] + result[3]);
+        }
+
+        if(tag_m_2x != m) {
+            c1 = _mm_maskz_loadu_ps(0x0f, &a[tag_m_2x * 4]);
+            c2 = _mm_maskz_mul_ps(0x0f, c1, xarray);
+            _mm_mask_storeu_ps(result, 0x0f, c2);
+            y[tag_m_2x] += alpha * (result[0] + result[1] + result[2] + result[3]);
+        }
+    }
+
+    return 0;
+}
+
+static int sgemv_kernel_t_5(BLASLONG m, float alpha, float *a, float *x, float *y)
+{
+    BLASLONG tag_m_16x = m & (~15);
+    BLASLONG tag_m_8x = m & (~7);
+    BLASLONG tag_m_4x = m & (~3);
+    BLASLONG tag_m_2x = m & (~1);
+    __m512 m0, m1, m2, m3, m4, tmp0, tmp1, tmp2, accum, c0, c1, c2, c3, c4;
+    __m512 x0_512 = _mm512_set1_ps(x[0]);
+    __m512 x1_512 = _mm512_set1_ps(x[1]);
+    __m512 x2_512 = _mm512_set1_ps(x[2]);
+    __m512 x3_512 = _mm512_set1_ps(x[3]);
+    __m512 x4_512 = _mm512_set1_ps(x[4]);
+    __m512 alpha_512 = _mm512_set1_ps(alpha);
+
+    __m512i M512_EPI32_1 = _mm512_set1_epi32(1);
+    __m512i M512_EPI32_16 = _mm512_set1_epi32(16);
+    __m512i M512_EPI32_0 = _mm512_setzero_epi32();
+
+    __m512i idx_c0 = _mm512_set_epi32(27, 22, 17, 28, 23, 18, 13, 8, 3, 30, 25, 20, 15, 10, 5, 0);
+    __m512i idx_c1 = _mm512_add_epi32(idx_c0, M512_EPI32_1);
+    __m512i idx_c2 = _mm512_add_epi32(idx_c1, M512_EPI32_1);
+    idx_c2 = _mm512_mask_blend_epi32(0x0040, idx_c2, M512_EPI32_0);
+    __m512i idx_c3 = _mm512_add_epi32(idx_c2, M512_EPI32_1);
+    __m512i idx_c4 = _mm512_add_epi32(idx_c3, M512_EPI32_1);
+    idx_c4 = _mm512_mask_blend_epi32(0x1000, idx_c4, M512_EPI32_16);
+
+    for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
+        m0 = _mm512_loadu_ps(&a[idx_m*5]);
+        m1 = _mm512_loadu_ps(&a[idx_m*5 + 16]);
+        m2 = _mm512_loadu_ps(&a[idx_m*5 + 32]);
+        m3 = _mm512_loadu_ps(&a[idx_m*5 + 48]);
+        m4 = _mm512_loadu_ps(&a[idx_m*5 + 64]);
+
+        tmp0 = _mm512_maskz_permutex2var_ps(0x007f, m0, idx_c0, m1);
+        tmp1 = _mm512_maskz_permutex2var_ps(0x1f80, m2, idx_c0, m3);
+        c0 = _mm512_mask_blend_ps(0x1f80, tmp0, tmp1);
+        c0 = _mm512_mask_permutex2var_ps(c0, 0xe000, idx_c0, m4);
+
+        tmp0 = _mm512_maskz_permutex2var_ps(0x007f, m0, idx_c1, m1);
+        tmp1 = _mm512_maskz_permutex2var_ps(0x1f80, m2, idx_c1, m3);
+        c1 = _mm512_mask_blend_ps(0x1f80, tmp0, tmp1);
+        c1 = _mm512_mask_permutex2var_ps(c1, 0xe000, idx_c1, m4);
+
+        tmp0 = _mm512_maskz_permutex2var_ps(0x003f, m0, idx_c2, m1);
+        tmp1 = _mm512_maskz_permutex2var_ps(0x1fc0, m2, idx_c2, m3);
+        c2 = _mm512_mask_blend_ps(0x1fc0, tmp0, tmp1);
+        c2 = _mm512_mask_permutex2var_ps(c2, 0xe000, idx_c2, m4);
+
+        tmp0 = _mm512_maskz_permutex2var_ps(0x003f, m0, idx_c3, m1);
+        tmp1 = _mm512_maskz_permutex2var_ps(0x1fc0, m2, idx_c3, m3);
+        c3 = _mm512_mask_blend_ps(0x1fc0, tmp0, tmp1);
+        c3 = _mm512_mask_permutex2var_ps(c3, 0xe000, idx_c3, m4);
+
+        tmp0 = _mm512_maskz_permutex2var_ps(0x003f, m0, idx_c4, m1);
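+        /* c4 uses different masks: lane 12 of idx_c4 was pre-blended to 16 so
+           the final 0xf000 pass pulls lanes 12-15 from m4. */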
+        tmp1 = _mm512_maskz_permutex2var_ps(0x0fc0, m2, idx_c4, m3);
+        c4 = _mm512_mask_blend_ps(0x0fc0, tmp0, tmp1);
+        c4 = _mm512_mask_permutex2var_ps(c4, 0xf000, idx_c4, m4);
+
+        accum = _mm512_fmadd_ps(c1, x1_512, _mm512_mul_ps(c0, x0_512));
+        accum = _mm512_fmadd_ps(c2, x2_512, accum);
+        accum = _mm512_fmadd_ps(c3, x3_512, accum);
+        accum = _mm512_fmadd_ps(c4, x4_512, accum);
+        accum = _mm512_fmadd_ps(accum, alpha_512, _mm512_loadu_ps(&y[idx_m]));
+        _mm512_storeu_ps(&y[idx_m], accum);
+    }
+
+    if(tag_m_16x != m) {
+        __m512i idx_c0c2 = _mm512_set_epi32(0, 0, 27, 22, 17, 12, 7, 2, 0, 30, 25, 20, 15, 10, 5, 0);
+        __m512i idx_c1c3 = _mm512_add_epi32(idx_c0c2, M512_EPI32_1);
+        idx_c4 = _mm512_add_epi32(idx_c1c3, M512_EPI32_1);
+        __m256i idx_c0m4 = _mm256_set_epi32(11, 6, 0, 0, 0, 0, 0, 0);
+        __m256i M256_EPI32_1 = _mm256_set1_epi32(1);
+        __m256i idx_c1m4 = _mm256_add_epi32(idx_c0m4, M256_EPI32_1);
+        __m256i idx_c2m4 = _mm256_add_epi32(idx_c1m4, M256_EPI32_1);
+        __m256i idx_c3m4 = _mm256_add_epi32(idx_c2m4, M256_EPI32_1);
+        __m256i idx_c4m4 = _mm256_add_epi32(idx_c3m4, M256_EPI32_1);
+        //TODO: below can change to use extract to decrease the latency
+        __m256 x0_256 = _mm256_set1_ps(x[0]);
+        __m256 x1_256 = _mm256_set1_ps(x[1]);
+        __m256 x2_256 = _mm256_set1_ps(x[2]);
+        __m256 x3_256 = _mm256_set1_ps(x[3]);
+        __m256 x4_256 = _mm256_set1_ps(x[4]);
+        __m256 alpha256 = _mm256_set1_ps(alpha);
+        __m256 accum_256, m256_4;
+
+        for(BLASLONG idx_m = tag_m_16x; idx_m < tag_m_8x; idx_m+=8) {
+            m0 = _mm512_loadu_ps(&a[idx_m*5]);
+            m1 = _mm512_loadu_ps(&a[idx_m*5 + 16]);
+            m256_4 = _mm256_loadu_ps(&a[idx_m*5 + 32]);
+            tmp0 = _mm512_permutex2var_ps(m0, idx_c0c2, m1);
+            tmp1 = _mm512_permutex2var_ps(m0, idx_c1c3, m1);
+            tmp2 = _mm512_permutex2var_ps(m0, idx_c4, m1);
+
+            __m256 c256_0 = _mm512_extractf32x8_ps(tmp0, 0);
+            __m256 c256_2 = _mm512_extractf32x8_ps(tmp0, 1);
+            __m256 c256_1 = _mm512_extractf32x8_ps(tmp1, 0);
+            __m256 c256_3 = _mm512_extractf32x8_ps(tmp1, 1);
+            __m256 c256_4 = _mm512_extractf32x8_ps(tmp2, 1);
+
+            c256_0 = _mm256_mask_permutex2var_ps(c256_0, 0x80, idx_c0m4, m256_4);
+            c256_1 = _mm256_mask_permutex2var_ps(c256_1, 0x80, idx_c1m4, m256_4);
+            c256_2 = _mm256_mask_permutex2var_ps(c256_2, 0xc0, idx_c2m4, m256_4);
+            c256_3 = _mm256_mask_permutex2var_ps(c256_3, 0xc0, idx_c3m4, m256_4);
+            c256_4 = _mm256_mask_permutex2var_ps(c256_4, 0xc0, idx_c4m4, m256_4);
+
+            accum_256 = _mm256_fmadd_ps(c256_1, x1_256, _mm256_mul_ps(c256_0, x0_256));
+            accum_256 = _mm256_fmadd_ps(c256_2, x2_256, accum_256);
+            accum_256 = _mm256_fmadd_ps(c256_3, x3_256, accum_256);
+            accum_256 = _mm256_fmadd_ps(c256_4, x4_256, accum_256);
+            accum_256 = _mm256_fmadd_ps(accum_256, alpha256, _mm256_loadu_ps(&y[idx_m]));
+            _mm256_storeu_ps(&y[idx_m], accum_256);
+        }
+        if(tag_m_8x != m) {
+            __m256i idx_c02 = _mm256_set_epi32(17, 12, 7, 2, 15, 10, 5, 0);
+            __m256i idx_c13 = _mm256_add_epi32(idx_c02, M256_EPI32_1);
+            __m256i idx_4 = _mm256_add_epi32(idx_c13, M256_EPI32_1);
+            __m128 accum_128;
+            __m256 m256_0, m256_1, tmp256_0, tmp256_1;
+            for (BLASLONG idx_m = tag_m_8x; idx_m < tag_m_4x; idx_m+=4){
+                m256_0 = _mm256_loadu_ps(&a[idx_m*5]);
+                m256_1 = _mm256_loadu_ps(&a[idx_m*5 + 8]);
+                __m128 m128_4 = _mm_maskz_loadu_ps(0x0f, &a[idx_m*5 + 16]);
+
+                tmp256_0 = _mm256_permutex2var_ps(m256_0, idx_c02, m256_1);
+                tmp256_1 = _mm256_permutex2var_ps(m256_0, idx_c13, m256_1);
+                __m256 tmp256_2 = _mm256_maskz_permutex2var_ps(0xf0, m256_0, idx_4, m256_1);
+
+                __m128 c128_0 = _mm256_extractf32x4_ps(tmp256_0, 0);
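+                /* Split the 256-bit shuffles into 128-bit pieces; lane 3 of
+                   c128_1..c128_4 is patched from m128_4 by the masked permutes below. */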
+                __m128 c128_1 = _mm256_extractf32x4_ps(tmp256_1, 0);
+                __m128 c128_2 = _mm256_extractf32x4_ps(tmp256_0, 1);
+                __m128 c128_3 = _mm256_extractf32x4_ps(tmp256_1, 1);
+                __m128 c128_4 = _mm256_extractf32x4_ps(tmp256_2, 1);
+
+                __m128i idx_c14 = _mm_set_epi32(4, 0, 0, 0);
+                __m128i M128_EPI32_1 = _mm_set1_epi32(1);
+                __m128i idx_c24 = _mm_add_epi32(idx_c14, M128_EPI32_1);
+                __m128i idx_c34 = _mm_add_epi32(idx_c24, M128_EPI32_1);
+                __m128i idx_c44 = _mm_add_epi32(idx_c34, M128_EPI32_1);
+
+                c128_1 = _mm_mask_permutex2var_ps(c128_1, 0x08, idx_c14, m128_4);
+                c128_2 = _mm_mask_permutex2var_ps(c128_2, 0x08, idx_c24, m128_4);
+                c128_3 = _mm_mask_permutex2var_ps(c128_3, 0x08, idx_c34, m128_4);
+                c128_4 = _mm_mask_permutex2var_ps(c128_4, 0x08, idx_c44, m128_4);
+
+                __m128 x128_0 = _mm256_extractf32x4_ps(x0_256, 0);
+                __m128 x128_1 = _mm256_extractf32x4_ps(x1_256, 0);
+                __m128 x128_2 = _mm256_extractf32x4_ps(x2_256, 0);
+                __m128 x128_3 = _mm256_extractf32x4_ps(x3_256, 0);
+                __m128 x128_4 = _mm256_extractf32x4_ps(x4_256, 0);
+
+                __m128 alpha_128 = _mm256_extractf32x4_ps(alpha256, 0);
+                accum_128 = _mm_maskz_fmadd_ps(0x0f, c128_1, x128_1, _mm_maskz_mul_ps(0x0f, c128_0, x128_0));
+                accum_128 = _mm_maskz_fmadd_ps(0x0f, c128_2, x128_2, accum_128);
+                accum_128 = _mm_maskz_fmadd_ps(0x0f, c128_3, x128_3, accum_128);
+                accum_128 = _mm_maskz_fmadd_ps(0x0f, c128_4, x128_4, accum_128);
+                accum_128 = _mm_maskz_fmadd_ps(0x0f, accum_128, alpha_128, _mm_maskz_loadu_ps(0x0f, &y[idx_m]));
+                _mm_mask_storeu_ps(&y[idx_m], 0x0f, accum_128);
+            }
+
+            if(tag_m_4x != m){
+                x0_256 = _mm256_maskz_loadu_ps(0x1f, x);
+                x0_256 = _mm256_mul_ps(x0_256, alpha256);
+                float ret8[8];
+
+                for(BLASLONG idx_m = tag_m_4x; idx_m < tag_m_2x; idx_m+=2){
+                    m256_0 = _mm256_maskz_loadu_ps(0x1f, &a[idx_m*5]);
+                    m256_1 = _mm256_maskz_loadu_ps(0x1f, &a[idx_m*5 + 5]);
+
+                    m256_0 = _mm256_mul_ps(m256_0, x0_256);
+                    m256_1 = _mm256_mul_ps(m256_1, x0_256);
+
+                    _mm256_mask_storeu_ps(ret8, 0x1f, m256_0);
+                    y[idx_m] += ret8[0] + ret8[1] + ret8[2] + ret8[3] + ret8[4];
+                    _mm256_mask_storeu_ps(ret8, 0x1f, m256_1);
+                    y[idx_m+1] += ret8[0] + ret8[1] + ret8[2] + ret8[3] + ret8[4];
+                }
+
+                if(tag_m_2x != m){
+                    m256_0 = _mm256_maskz_loadu_ps(0x1f, &a[tag_m_2x*5]);
+                    m256_0 = _mm256_mul_ps(m256_0, x0_256);
+
+                    _mm256_mask_storeu_ps(ret8, 0x1f, m256_0);
+                    y[tag_m_2x] += ret8[0] + ret8[1] + ret8[2] + ret8[3] + ret8[4];
+                }
+            }
+        }
+    }
+    return 0;
+}
+
+static int sgemv_kernel_t_6(BLASLONG m, float alpha, float *a, float *x, float *y)
+{
+    BLASLONG tag_m_16x = m & (~15);
+    BLASLONG tag_m_8x = m & (~7);
+    BLASLONG tag_m_4x = m & (~3);
+    BLASLONG tag_m_2x = m & (~1);
+
+    __m512 m0, m1, m2, m3, m4, m5, c0, c1, c2, c3, c4, c5, tmp0, tmp1, tmp2, accum;
+    __m512i idx_c0 = _mm512_set_epi32(26, 20, 14, 8, 2, 28, 22, 16, 10, 4, 30, 24, 18, 12, 6, 0);
+    __m512i M512_EPI32_1 = _mm512_set1_epi32(1);
+    __m512i M512_EPI32_0 = _mm512_setzero_epi32();
+    __m512i M512_EPI32_16 = _mm512_set1_epi32(16);
+    __m512i idx_c1 = _mm512_add_epi32(idx_c0, M512_EPI32_1);
+    __m512i idx_c2 = _mm512_add_epi32(idx_c1, M512_EPI32_1);
+    idx_c2 = _mm512_mask_blend_epi32(0x0020, idx_c2, M512_EPI32_0);
+    __m512i idx_c3 = _mm512_add_epi32(idx_c2, M512_EPI32_1);
+    __m512i idx_c4 = _mm512_add_epi32(idx_c3, M512_EPI32_1);
+    idx_c4 = _mm512_mask_blend_epi32(0x0400, idx_c4, M512_EPI32_0);
+    __m512i idx_c5 = _mm512_add_epi32(idx_c4, M512_EPI32_1);
+
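+    /* Six zmm loads cover a 16-row block (96 floats); idx_c0..idx_c5 gather the
+       stride-6 elements for each x component, with blend fixups where an index
+       would step past the 32-entry two-register permute range. */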
+    __m512 x0_512 = _mm512_set1_ps(x[0]);
+    __m512 x1_512 = _mm512_set1_ps(x[1]);
+    __m512 x2_512 = _mm512_set1_ps(x[2]);
+    __m512 x3_512 = _mm512_set1_ps(x[3]);
+    __m512 x4_512 = _mm512_set1_ps(x[4]);
+    __m512 x5_512 = _mm512_set1_ps(x[5]);
+    __m512 alpha_512 = _mm512_set1_ps(alpha);
+
+    for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
+        m0 = _mm512_loadu_ps(&a[idx_m*6]);
+        m1 = _mm512_loadu_ps(&a[idx_m*6 + 16]);
+        m2 = _mm512_loadu_ps(&a[idx_m*6 + 32]);
+        m3 = _mm512_loadu_ps(&a[idx_m*6 + 48]);
+        m4 = _mm512_loadu_ps(&a[idx_m*6 + 64]);
+        m5 = _mm512_loadu_ps(&a[idx_m*6 + 80]);
+
+        tmp0 = _mm512_maskz_permutex2var_ps(0x003f, m0, idx_c0, m1);
+        tmp1 = _mm512_maskz_permutex2var_ps(0x07c0, m2, idx_c0, m3);
+        tmp2 = _mm512_maskz_permutex2var_ps(0xf800, m4, idx_c0, m5);
+        c0 = _mm512_mask_blend_ps(0x07c0, tmp0, tmp1);
+        c0 = _mm512_mask_blend_ps(0xf800, c0, tmp2);
+
+        tmp0 = _mm512_maskz_permutex2var_ps(0x003f, m0, idx_c1, m1);
+        tmp1 = _mm512_maskz_permutex2var_ps(0x07c0, m2, idx_c1, m3);
+        tmp2 = _mm512_maskz_permutex2var_ps(0xf800, m4, idx_c1, m5);
+        c1 = _mm512_mask_blend_ps(0x07c0, tmp0, tmp1);
+        c1 = _mm512_mask_blend_ps(0xf800, c1, tmp2);
+
+        tmp0 = _mm512_maskz_permutex2var_ps(0x001f, m0, idx_c2, m1);
+        tmp1 = _mm512_maskz_permutex2var_ps(0x07e0, m2, idx_c2, m3);
+        tmp2 = _mm512_maskz_permutex2var_ps(0xf800, m4, idx_c2, m5);
+        c2 = _mm512_mask_blend_ps(0x07e0, tmp0, tmp1);
+        c2 = _mm512_mask_blend_ps(0xf800, c2, tmp2);
+
+        tmp0 = _mm512_maskz_permutex2var_ps(0x001f, m0, idx_c3, m1);
+        tmp1 = _mm512_maskz_permutex2var_ps(0x07e0, m2, idx_c3, m3);
+        tmp2 = _mm512_maskz_permutex2var_ps(0xf800, m4, idx_c3, m5);
+        c3 = _mm512_mask_blend_ps(0x07e0, tmp0, tmp1);
+        c3 = _mm512_mask_blend_ps(0xf800, c3, tmp2);
+
+        tmp0 = _mm512_maskz_permutex2var_ps(0x001f, m0, idx_c4, m1);
+        tmp1 = _mm512_maskz_permutex2var_ps(0x03e0, m2, idx_c4, m3);
+        tmp2 = _mm512_maskz_permutex2var_ps(0xfc00, m4, idx_c4, m5);
+        c4 = _mm512_mask_blend_ps(0x03e0, tmp0, tmp1);
+        c4 = _mm512_mask_blend_ps(0xfc00, c4, tmp2);
+
+        tmp0 = _mm512_maskz_permutex2var_ps(0x001f, m0, idx_c5, m1);
+        tmp1 = _mm512_maskz_permutex2var_ps(0x03e0, m2, idx_c5, m3);
+        tmp2 = _mm512_maskz_permutex2var_ps(0xfc00, m4, idx_c5, m5);
+        c5 = _mm512_mask_blend_ps(0x03e0, tmp0, tmp1);
+        c5 = _mm512_mask_blend_ps(0xfc00, c5, tmp2);
+
+        accum = _mm512_fmadd_ps(c1, x1_512, _mm512_mul_ps(c0, x0_512));
+        accum = _mm512_fmadd_ps(c2, x2_512, accum);
+        accum = _mm512_fmadd_ps(c3, x3_512, accum);
+        accum = _mm512_fmadd_ps(c4, x4_512, accum);
+        accum = _mm512_fmadd_ps(c5, x5_512, accum);
+        accum = _mm512_fmadd_ps(accum, alpha_512, _mm512_loadu_ps(&y[idx_m]));
+        _mm512_storeu_ps(&y[idx_m], accum);
+    }
+
+    if(tag_m_16x != m) {
+        __m512i idx_c0c3 = _mm512_set_epi32(29, 23, 17, 27, 21, 15, 9, 3, 26, 20, 30, 24, 18, 12, 6, 0);
+        __m512i idx_c1c4 = _mm512_add_epi32(idx_c0c3, M512_EPI32_1);
+        __m512i idx_c2c5 = _mm512_add_epi32(idx_c1c4, M512_EPI32_1);
+        idx_c2c5 = _mm512_mask_blend_epi32(0x0020, idx_c2c5, M512_EPI32_16);
+        __m256 c256_0, c256_1, c256_2, c256_3, c256_4, c256_5;
+
+        __m256 x0_256 = _mm256_set1_ps(x[0]);
+        __m256 x1_256 = _mm256_set1_ps(x[1]);
+        __m256 x2_256 = _mm256_set1_ps(x[2]);
+        __m256 x3_256 = _mm256_set1_ps(x[3]);
+        __m256 x4_256 = _mm256_set1_ps(x[4]);
+        __m256 x5_256 = _mm256_set1_ps(x[5]);
+        __m256 alpha256 = _mm256_set1_ps(alpha);
+        __m256 accum_256;
+
+        for(BLASLONG idx_m = tag_m_16x; idx_m