diff --git a/kernel/x86_64/daxpy_microk_haswell-2.c b/kernel/x86_64/daxpy_microk_haswell-2.c
index bbe8b9550..c2491ba9b 100644
--- a/kernel/x86_64/daxpy_microk_haswell-2.c
+++ b/kernel/x86_64/daxpy_microk_haswell-2.c
@@ -25,54 +25,49 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
+
+
+#ifndef __AVX512CD__
+#pragma GCC target("avx2,fma")
+#endif
+
+#ifdef __AVX2__
+
+#include <immintrin.h>
+
 #define HAVE_KERNEL_8 1
-static void daxpy_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y , FLOAT *alpha) __attribute__ ((noinline));
 
 static void daxpy_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *alpha)
 {
-
-
-    BLASLONG register i = 0;
-
-    __asm__  __volatile__
-    (
-    "vbroadcastsd    (%4), %%ymm0                \n\t"  // alpha
-
-    ".p2align 4                                  \n\t"
-    "1:                                          \n\t"
-
-    "vmovups      (%3,%0,8), %%ymm12             \n\t"  // 4 * y
-    "vmovups    32(%3,%0,8), %%ymm13             \n\t"  // 4 * y
-    "vmovups    64(%3,%0,8), %%ymm14             \n\t"  // 4 * y
-    "vmovups    96(%3,%0,8), %%ymm15             \n\t"  // 4 * y
-    "vfmadd231pd    (%2,%0,8), %%ymm0 , %%ymm12  \n\t"  // y += alpha * x
-    "vfmadd231pd  32(%2,%0,8), %%ymm0 , %%ymm13  \n\t"  // y += alpha * x
-    "vfmadd231pd  64(%2,%0,8), %%ymm0 , %%ymm14  \n\t"  // y += alpha * x
-    "vfmadd231pd  96(%2,%0,8), %%ymm0 , %%ymm15  \n\t"  // y += alpha * x
-    "vmovups    %%ymm12,   (%3,%0,8)             \n\t"
-    "vmovups    %%ymm13, 32(%3,%0,8)             \n\t"
-    "vmovups    %%ymm14, 64(%3,%0,8)             \n\t"
-    "vmovups    %%ymm15, 96(%3,%0,8)             \n\t"
-
-    "addq       $16, %0                          \n\t"
-    "subq       $16, %1                          \n\t"
-    "jnz        1b                               \n\t"
-    "vzeroupper                                  \n\t"
-
-    :
-    :
-      "r" (i),      // 0
-      "r" (n),      // 1
-      "r" (x),      // 2
-      "r" (y),      // 3
-      "r" (alpha)   // 4
-    : "cc",
-      "%xmm0",
-      "%xmm8", "%xmm9", "%xmm10", "%xmm11",
-      "%xmm12", "%xmm13", "%xmm14", "%xmm15",
-      "memory"
-    );
-
-}
+    BLASLONG i = 0;
+
+    __m256d __alpha;
+
+    __alpha = _mm256_broadcastsd_pd(_mm_load_sd(alpha));
+
+#ifdef __AVX512CD__
+    BLASLONG n32;
+    __m512d __alpha5;
+    __alpha5 = _mm512_broadcastsd_pd(_mm_load_sd(alpha));
+
+    n32 = n & ~31;
+
+    for (; i < n32; i+= 32) {
+        _mm512_storeu_pd(&y[i +  0], _mm512_loadu_pd(&y[i +  0]) + __alpha5 * _mm512_loadu_pd(&x[i +  0]));
+        _mm512_storeu_pd(&y[i +  8], _mm512_loadu_pd(&y[i +  8]) + __alpha5 * _mm512_loadu_pd(&x[i +  8]));
+        _mm512_storeu_pd(&y[i + 16], _mm512_loadu_pd(&y[i + 16]) + __alpha5 * _mm512_loadu_pd(&x[i + 16]));
+        _mm512_storeu_pd(&y[i + 24], _mm512_loadu_pd(&y[i + 24]) + __alpha5 * _mm512_loadu_pd(&x[i + 24]));
+    }
+
+#endif
+
+    for (; i < n; i+= 16) {
+        _mm256_storeu_pd(&y[i +  0], _mm256_loadu_pd(&y[i +  0]) + __alpha * _mm256_loadu_pd(&x[i +  0]));
+        _mm256_storeu_pd(&y[i +  4], _mm256_loadu_pd(&y[i +  4]) + __alpha * _mm256_loadu_pd(&x[i +  4]));
+        _mm256_storeu_pd(&y[i +  8], _mm256_loadu_pd(&y[i +  8]) + __alpha * _mm256_loadu_pd(&x[i +  8]));
+        _mm256_storeu_pd(&y[i + 12], _mm256_loadu_pd(&y[i + 12]) + __alpha * _mm256_loadu_pd(&x[i + 12]));
+    }
+}
+#endif
diff --git a/kernel/x86_64/ddot_microk_haswell-2.c b/kernel/x86_64/ddot_microk_haswell-2.c
index 365737363..b9be83af6 100644
--- a/kernel/x86_64/ddot_microk_haswell-2.c
+++ b/kernel/x86_64/ddot_microk_haswell-2.c
@@ -25,71 +25,75 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
+/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */
+
+#ifndef __AVX512CD__
+#pragma GCC target("avx2,fma")
+#endif
+
+#ifdef __AVX2__
+
 #define HAVE_KERNEL_8 1
-static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y , FLOAT *dot) __attribute__ ((noinline));
+
+#include <immintrin.h>
 
 static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
 {
+    int i = 0;
+    __m256d accum_0, accum_1, accum_2, accum_3;
+
+    accum_0 = _mm256_setzero_pd();
+    accum_1 = _mm256_setzero_pd();
+    accum_2 = _mm256_setzero_pd();
+    accum_3 = _mm256_setzero_pd();
+#ifdef __AVX512CD__
+    __m512d accum_05, accum_15, accum_25, accum_35;
+    int n32;
+    n32 = n & (~31);
-    BLASLONG register i = 0;
+    accum_05 = _mm512_setzero_pd();
+    accum_15 = _mm512_setzero_pd();
+    accum_25 = _mm512_setzero_pd();
+    accum_35 = _mm512_setzero_pd();
-    __asm__  __volatile__
-    (
-    "vxorpd    %%ymm4, %%ymm4, %%ymm4    \n\t"
-    "vxorpd    %%ymm5, %%ymm5, %%ymm5    \n\t"
-    "vxorpd    %%ymm6, %%ymm6, %%ymm6    \n\t"
-    "vxorpd    %%ymm7, %%ymm7, %%ymm7    \n\t"
+    for (; i < n32; i += 32) {
+        accum_05 += _mm512_loadu_pd(&x[i+ 0]) * _mm512_loadu_pd(&y[i+ 0]);
+        accum_15 += _mm512_loadu_pd(&x[i+ 8]) * _mm512_loadu_pd(&y[i+ 8]);
+        accum_25 += _mm512_loadu_pd(&x[i+16]) * _mm512_loadu_pd(&y[i+16]);
+        accum_35 += _mm512_loadu_pd(&x[i+24]) * _mm512_loadu_pd(&y[i+24]);
+    }
-    ".p2align 4                          \n\t"
-    "1:                                  \n\t"
-    "vmovups      (%2,%0,8), %%ymm12     \n\t"  // 2 * x
-    "vmovups    32(%2,%0,8), %%ymm13     \n\t"  // 2 * x
-    "vmovups    64(%2,%0,8), %%ymm14     \n\t"  // 2 * x
-    "vmovups    96(%2,%0,8), %%ymm15     \n\t"  // 2 * x
+    /*
+     * we need to fold our 512 bit wide accumulator vectors into 256 bit wide vectors so that the AVX2 code
+     * below can continue using the intermediate results in its loop
+     */
+    accum_0 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_05, 0), _mm512_extractf64x4_pd(accum_05, 1));
+    accum_1 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_15, 0), _mm512_extractf64x4_pd(accum_15, 1));
+    accum_2 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_25, 0), _mm512_extractf64x4_pd(accum_25, 1));
+    accum_3 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_35, 0), _mm512_extractf64x4_pd(accum_35, 1));
-    "vfmadd231pd    (%3,%0,8), %%ymm12, %%ymm4   \n\t"  // 2 * y
-    "vfmadd231pd  32(%3,%0,8), %%ymm13, %%ymm5   \n\t"  // 2 * y
-    "vfmadd231pd  64(%3,%0,8), %%ymm14, %%ymm6   \n\t"  // 2 * y
-    "vfmadd231pd  96(%3,%0,8), %%ymm15, %%ymm7   \n\t"  // 2 * y
+#endif
+    for (; i < n; i += 16) {
+        accum_0 += _mm256_loadu_pd(&x[i+ 0]) * _mm256_loadu_pd(&y[i+ 0]);
+        accum_1 += _mm256_loadu_pd(&x[i+ 4]) * _mm256_loadu_pd(&y[i+ 4]);
+        accum_2 += _mm256_loadu_pd(&x[i+ 8]) * _mm256_loadu_pd(&y[i+ 8]);
+        accum_3 += _mm256_loadu_pd(&x[i+12]) * _mm256_loadu_pd(&y[i+12]);
+    }
-    "addq       $16 , %0                 \n\t"
-    "subq       $16 , %1                 \n\t"
-    "jnz        1b                       \n\t"
+    /* we now have the partial sums of the dot product in the 4 accumulation vectors, time to consolidate */
-    "vextractf128 $1 , %%ymm4 , %%xmm12  \n\t"
-    "vextractf128 $1 , %%ymm5 , %%xmm13  \n\t"
-    "vextractf128 $1 , %%ymm6 , %%xmm14  \n\t"
-    "vextractf128 $1 , %%ymm7 , %%xmm15  \n\t"
+    accum_0 = accum_0 + accum_1 + accum_2 + accum_3;
-    "vaddpd    %%xmm4, %%xmm12, %%xmm4   \n\t"
-    "vaddpd    %%xmm5, %%xmm13, %%xmm5   \n\t"
-    "vaddpd    %%xmm6, %%xmm14, %%xmm6   \n\t"
-    "vaddpd    %%xmm7, %%xmm15, %%xmm7   \n\t"
+    __m128d half_accum0;
-    "vaddpd    %%xmm4, %%xmm5, %%xmm4    \n\t"
-    "vaddpd    %%xmm6, %%xmm7, %%xmm6    \n\t"
-    "vaddpd    %%xmm4, %%xmm6, %%xmm4    \n\t"
+    /* Add the upper half to the lower half of the 256 bit vector to get a 128 bit vector */
+    half_accum0 = _mm_add_pd(_mm256_extractf128_pd(accum_0, 0), _mm256_extractf128_pd(accum_0, 1));
-    "vhaddpd   %%xmm4, %%xmm4, %%xmm4    \n\t"
-
-    "vmovsd    %%xmm4, (%4)              \n\t"
-    "vzeroupper                          \n\t"
-
-    :
-    :
-      "r" (i),    // 0
-      "r" (n),    // 1
-      "r" (x),    // 2
-      "r" (y),    // 3
-      "r" (dot)   // 4
-    : "cc",
-      "%xmm4", "%xmm5",
-      "%xmm6", "%xmm7",
-      "%xmm12", "%xmm13", "%xmm14", "%xmm15",
-      "memory"
-    );
-
-}
+    /* in 128 bit land there is a hadd operation to do the rest of the element-wise sum in one go */
+    half_accum0 = _mm_hadd_pd(half_accum0, half_accum0);
+    *dot = half_accum0[0];
+}
+#endif
diff --git a/kernel/x86_64/dgemv_n_microk_haswell-4.c b/kernel/x86_64/dgemv_n_microk_haswell-4.c
index 584a6c6b5..69e93a1fb 100644
--- a/kernel/x86_64/dgemv_n_microk_haswell-4.c
+++ b/kernel/x86_64/dgemv_n_microk_haswell-4.c
@@ -25,167 +25,104 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
+/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */
+#ifndef __AVX512CD__
+#pragma GCC target("avx2,fma")
+#endif
+#ifdef __AVX2__
 #define HAVE_KERNEL_4x4 1
-static void dgemv_kernel_4x4( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT *alpha) __attribute__ ((noinline));
+
+#include <immintrin.h>
 
 static void dgemv_kernel_4x4( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT *alpha)
 {
-    BLASLONG register i = 0;
+    int i = 0;
-    __asm__  __volatile__
-    (
-    "vbroadcastsd    (%2), %%ymm12         \n\t"  // x0
-    "vbroadcastsd   8(%2), %%ymm13         \n\t"  // x1
-    "vbroadcastsd  16(%2), %%ymm14         \n\t"  // x2
-    "vbroadcastsd  24(%2), %%ymm15         \n\t"  // x3
+    __m256d x0, x1, x2, x3;
+    __m256d __alpha;
-    "vmovups    (%4,%0,8), %%ymm0          \n\t"
-    "vmovups    (%5,%0,8), %%ymm1          \n\t"
-    "vmovups    (%6,%0,8), %%ymm2          \n\t"
-    "vmovups    (%7,%0,8), %%ymm3          \n\t"
-    "vbroadcastsd    (%8), %%ymm6          \n\t"  // alpha
+    x0 = _mm256_broadcastsd_pd(_mm_load_sd(&x[0]));
+    x1 = _mm256_broadcastsd_pd(_mm_load_sd(&x[1]));
+    x2 = _mm256_broadcastsd_pd(_mm_load_sd(&x[2]));
+    x3 = _mm256_broadcastsd_pd(_mm_load_sd(&x[3]));
-    "addq       $4 , %0                    \n\t"
-    "subq       $4 , %1                    \n\t"
-    "jz         2f                         \n\t"
+    __alpha = _mm256_broadcastsd_pd(_mm_load_sd(alpha));
-    //  ".align 16                         \n\t"
-    "1:                                    \n\t"
+#ifdef __AVX512CD__
+    int n5;
+    __m512d x05, x15, x25, x35;
+    __m512d __alpha5;
+    n5 = n & ~7;
-    "vmulpd        %%ymm0 , %%ymm12, %%ymm4      \n\t"
-    "vmulpd        %%ymm1 , %%ymm13, %%ymm5      \n\t"
-    "vmovups    (%4,%0,8), %%ymm0                \n\t"
-    "vmovups    (%5,%0,8), %%ymm1                \n\t"
-    "vfmadd231pd   %%ymm2 , %%ymm14, %%ymm4      \n\t"
-    "vfmadd231pd   %%ymm3 , %%ymm15, %%ymm5      \n\t"
-    "vmovups    (%6,%0,8), %%ymm2                \n\t"
-    "vmovups    (%7,%0,8), %%ymm3                \n\t"
+    x05 = _mm512_broadcastsd_pd(_mm_load_sd(&x[0]));
+    x15 = _mm512_broadcastsd_pd(_mm_load_sd(&x[1]));
+    x25 = _mm512_broadcastsd_pd(_mm_load_sd(&x[2]));
+    x35 = _mm512_broadcastsd_pd(_mm_load_sd(&x[3]));
-    "vmovups    -32(%3,%0,8), %%ymm8             \n\t"  // 4 * y
-    "vaddpd        %%ymm4 , %%ymm5 , %%ymm4      \n\t"
-    "vfmadd231pd   %%ymm6 , %%ymm4 , %%ymm8      \n\t"
+    __alpha5 = _mm512_broadcastsd_pd(_mm_load_sd(alpha));
-    "vmovups    %%ymm8, -32(%3,%0,8)             \n\t"  // 4 * y
+    for (; i < n5; i+= 8) {
+        __m512d tempY;
+        __m512d sum;
-    "addq       $4 , %0                    \n\t"
-    "subq       $4 , %1                    \n\t"
-    "jnz        1b                         \n\t"
-
+        sum = _mm512_loadu_pd(&ap[0][i]) * x05 +
+              _mm512_loadu_pd(&ap[1][i]) * x15 +
+              _mm512_loadu_pd(&ap[2][i]) * x25 +
+              _mm512_loadu_pd(&ap[3][i]) * x35;
-    "2:                                    \n\t"
+        tempY = _mm512_loadu_pd(&y[i]);
+        tempY += sum * __alpha5;
+        _mm512_storeu_pd(&y[i], tempY);
+    }
+#endif
-    "vmulpd        %%ymm0 , %%ymm12, %%ymm4      \n\t"
-    "vmulpd        %%ymm1 , %%ymm13, %%ymm5      \n\t"
-    "vfmadd231pd   %%ymm2 , %%ymm14, %%ymm4      \n\t"
-    "vfmadd231pd   %%ymm3 , %%ymm15, %%ymm5      \n\t"
+    for (; i < n; i+= 4) {
+        __m256d tempY;
+        __m256d sum;
+        sum = _mm256_loadu_pd(&ap[0][i]) * x0 +
+              _mm256_loadu_pd(&ap[1][i]) * x1 +
+              _mm256_loadu_pd(&ap[2][i]) * x2 +
+              _mm256_loadu_pd(&ap[3][i]) * x3;
-    "vmovups    -32(%3,%0,8), %%ymm8             \n\t"  // 4 * y
-    "vaddpd        %%ymm4 , %%ymm5 , %%ymm4      \n\t"
-    "vfmadd231pd   %%ymm6 , %%ymm4 , %%ymm8      \n\t"
-
-    "vmovups    %%ymm8, -32(%3,%0,8)             \n\t"  // 4 * y
-
-
-    "vzeroupper                            \n\t"
-
-    :
-      "+r" (i),     // 0
-      "+r" (n)      // 1
-    :
-      "r" (x),      // 2
-      "r" (y),      // 3
-      "r" (ap[0]),  // 4
-      "r" (ap[1]),  // 5
-      "r" (ap[2]),  // 6
-      "r" (ap[3]),  // 7
-      "r" (alpha)   // 8
-    : "cc",
-      "%xmm4", "%xmm5",
-      "%xmm6", "%xmm7",
-      "%xmm8", "%xmm9",
-      "%xmm12", "%xmm13", "%xmm14", "%xmm15",
-      "memory"
-    );
+        tempY = _mm256_loadu_pd(&y[i]);
+        tempY += sum * __alpha;
+        _mm256_storeu_pd(&y[i], tempY);
+    }
 }
 
 #define HAVE_KERNEL_4x2
-static void dgemv_kernel_4x2( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT *alpha) __attribute__ ((noinline));
-
 static void dgemv_kernel_4x2( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT *alpha)
 {
-    BLASLONG register i = 0;
+    int i = 0;
-    __asm__  __volatile__
-    (
-    "vbroadcastsd    (%2), %%ymm12         \n\t"  // x0
-    "vbroadcastsd   8(%2), %%ymm13         \n\t"  // x1
+    __m256d x0, x1;
+    __m256d __alpha;
-    "vmovups    (%4,%0,8), %%ymm0          \n\t"
-    "vmovups    (%5,%0,8), %%ymm1          \n\t"
+    x0 = _mm256_broadcastsd_pd(_mm_load_sd(&x[0]));
+    x1 = _mm256_broadcastsd_pd(_mm_load_sd(&x[1]));
-    "vbroadcastsd    (%6), %%ymm6          \n\t"  // alpha
-
-    "addq       $4 , %0                    \n\t"
-    "subq       $4 , %1                    \n\t"
-    "jz         2f                         \n\t"
-
-    "1:                                    \n\t"
-
-    "vmulpd        %%ymm0 , %%ymm12, %%ymm4      \n\t"
-    "vmulpd        %%ymm1 , %%ymm13, %%ymm5      \n\t"
-    "vmovups    (%4,%0,8), %%ymm0                \n\t"
-    "vmovups    (%5,%0,8), %%ymm1                \n\t"
-
-    "vmovups    -32(%3,%0,8), %%ymm8             \n\t"  // 4 * y
-    "vaddpd        %%ymm4 , %%ymm5 , %%ymm4      \n\t"
-    "vfmadd231pd   %%ymm6 , %%ymm4 , %%ymm8      \n\t"
-
-    "vmovups    %%ymm8, -32(%3,%0,8)             \n\t"  // 4 * y
-
-    "addq       $4 , %0                    \n\t"
-    "subq       $4 , %1                    \n\t"
-    "jnz        1b                         \n\t"
-
-
-    "2:                                    \n\t"
-
-    "vmulpd        %%ymm0 , %%ymm12, %%ymm4      \n\t"
-    "vmulpd        %%ymm1 , %%ymm13, %%ymm5      \n\t"
+    __alpha = _mm256_broadcastsd_pd(_mm_load_sd(alpha));
-    "vmovups    -32(%3,%0,8), %%ymm8             \n\t"  // 4 * y
-    "vaddpd        %%ymm4 , %%ymm5 , %%ymm4      \n\t"
-    "vfmadd231pd   %%ymm6 , %%ymm4 , %%ymm8      \n\t"
+    for (i = 0; i < n; i+= 4) {
+        __m256d tempY;
+        __m256d sum;
-    "vmovups    %%ymm8, -32(%3,%0,8)             \n\t"  // 4 * y
+        sum = _mm256_loadu_pd(&ap[0][i]) * x0 + _mm256_loadu_pd(&ap[1][i]) * x1;
+        tempY = _mm256_loadu_pd(&y[i]);
+        tempY += sum * __alpha;
+        _mm256_storeu_pd(&y[i], tempY);
+    }
-    "vzeroupper                            \n\t"
-
-
-    :
-      "+r" (i),     // 0
-      "+r" (n)      // 1
-    :
-      "r" (x),      // 2
-      "r" (y),      // 3
-      "r" (ap[0]),  // 4
-      "r" (ap[1]),  // 5
-      "r" (alpha)   // 6
-    : "cc",
-      "%xmm0", "%xmm1",
-      "%xmm4", "%xmm5",
-      "%xmm6",
-      "%xmm8",
-      "%xmm12", "%xmm13",
-      "memory"
-    );
 }
+
+#endif /* AVX2 */
diff --git a/kernel/x86_64/dscal_microk_haswell-2.c b/kernel/x86_64/dscal_microk_haswell-2.c
index e732a2718..20fb2b283 100644
--- a/kernel/x86_64/dscal_microk_haswell-2.c
+++ b/kernel/x86_64/dscal_microk_haswell-2.c
@@ -25,182 +25,55 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
-#define HAVE_KERNEL_8 1
-static void dscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x) __attribute__ ((noinline));
+#ifndef __AVX512CD__
+#pragma GCC target("avx2,fma")
+#endif
+
+#ifdef __AVX2__
+
+#include <immintrin.h>
+
+#define HAVE_KERNEL_8 1
 
 static void dscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x)
 {
+    int i = 0;
-
-    BLASLONG n1 = n >> 4 ;
-    BLASLONG n2 = n & 8 ;
-
-    __asm__  __volatile__
-    (
-    "vmovddup        (%2), %%xmm0            \n\t"  // alpha
-
-    "addq    $128, %1                        \n\t"
-
-    "cmpq    $0, %0                          \n\t"
-    "je      4f                              \n\t"
-
-    "vmulpd  -128(%1), %%xmm0, %%xmm4        \n\t"
-    "vmulpd  -112(%1), %%xmm0, %%xmm5        \n\t"
-    "vmulpd   -96(%1), %%xmm0, %%xmm6        \n\t"
-    "vmulpd   -80(%1), %%xmm0, %%xmm7        \n\t"
-
-    "vmulpd   -64(%1), %%xmm0, %%xmm8        \n\t"
-    "vmulpd   -48(%1), %%xmm0, %%xmm9        \n\t"
-    "vmulpd   -32(%1), %%xmm0, %%xmm10       \n\t"
-    "vmulpd   -16(%1), %%xmm0, %%xmm11       \n\t"
-
-    "subq    $1 , %0                         \n\t"
-    "jz      2f                              \n\t"
-
-    ".p2align 4                              \n\t"
-    "1:                                      \n\t"
-    // "prefetcht0     640(%1)               \n\t"
-
-    "vmovups  %%xmm4 ,-128(%1)               \n\t"
-    "vmovups  %%xmm5 ,-112(%1)               \n\t"
-    "vmulpd     0(%1), %%xmm0, %%xmm4        \n\t"
-    "vmovups  %%xmm6 , -96(%1)               \n\t"
-    "vmulpd    16(%1), %%xmm0, %%xmm5        \n\t"
-    "vmovups  %%xmm7 , -80(%1)               \n\t"
-    "vmulpd    32(%1), %%xmm0, %%xmm6        \n\t"
-
-    // "prefetcht0     704(%1)               \n\t"
-
-    "vmovups  %%xmm8 , -64(%1)               \n\t"
-    "vmulpd    48(%1), %%xmm0, %%xmm7        \n\t"
-    "vmovups  %%xmm9 , -48(%1)               \n\t"
-    "vmulpd    64(%1), %%xmm0, %%xmm8        \n\t"
-    "vmovups  %%xmm10 , -32(%1)              \n\t"
-    "vmulpd    80(%1), %%xmm0, %%xmm9        \n\t"
-    "vmovups  %%xmm11 , -16(%1)              \n\t"
-
-    "vmulpd    96(%1), %%xmm0, %%xmm10       \n\t"
-    "vmulpd   112(%1), %%xmm0, %%xmm11       \n\t"
-
-
-    "addq    $128, %1                        \n\t"
-    "subq    $1 , %0                         \n\t"
-    "jnz     1b                              \n\t"
-
-    "2:                                      \n\t"
-
-    "vmovups  %%xmm4 ,-128(%1)               \n\t"
-    "vmovups  %%xmm5 ,-112(%1)               \n\t"
-    "vmovups  %%xmm6 , -96(%1)               \n\t"
-    "vmovups  %%xmm7 , -80(%1)               \n\t"
-
-    "vmovups  %%xmm8 , -64(%1)               \n\t"
-    "vmovups  %%xmm9 , -48(%1)               \n\t"
-    "vmovups  %%xmm10 , -32(%1)              \n\t"
-    "vmovups  %%xmm11 , -16(%1)              \n\t"
-
-    "addq    $128, %1                        \n\t"
-
-    "4:                                      \n\t"
-
-    "cmpq    $8 ,%3                          \n\t"
-    "jne     5f                              \n\t"
-
-    "vmulpd  -128(%1), %%xmm0, %%xmm4        \n\t"
-    "vmulpd  -112(%1), %%xmm0, %%xmm5        \n\t"
-    "vmulpd   -96(%1), %%xmm0, %%xmm6        \n\t"
-    "vmulpd   -80(%1), %%xmm0, %%xmm7        \n\t"
-
-    "vmovups  %%xmm4 ,-128(%1)               \n\t"
-    "vmovups  %%xmm5 ,-112(%1)               \n\t"
-    "vmovups  %%xmm6 , -96(%1)               \n\t"
-    "vmovups  %%xmm7 , -80(%1)               \n\t"
-
-    "5:                                      \n\t"
-
-    "vzeroupper                              \n\t"
-
-    :
-    :
-      "r" (n1),     // 0
-      "r" (x),      // 1
-      "r" (alpha),  // 2
-      "r" (n2)      // 3
-    : "cc",
-      "%xmm0", "%xmm1", "%xmm2", "%xmm3",
-      "%xmm4", "%xmm5", "%xmm6", "%xmm7",
-      "%xmm8", "%xmm9", "%xmm10", "%xmm11",
-      "%xmm12", "%xmm13", "%xmm14", "%xmm15",
-      "memory"
-    );
-
+#ifdef __AVX512CD__
+    __m512d __alpha5 = _mm512_broadcastsd_pd(_mm_load_sd(alpha));
+    for (; i < n; i += 8) {
+        _mm512_storeu_pd(&x[i + 0], __alpha5 * _mm512_loadu_pd(&x[i + 0]));
+    }
+#else
+    __m256d __alpha = _mm256_broadcastsd_pd(_mm_load_sd(alpha));
+    for (; i < n; i += 8) {
+        _mm256_storeu_pd(&x[i + 0], __alpha * _mm256_loadu_pd(&x[i + 0]));
+        _mm256_storeu_pd(&x[i + 4], __alpha * _mm256_loadu_pd(&x[i + 4]));
+    }
+#endif
 }
 
-static void dscal_kernel_8_zero( BLASLONG n, FLOAT *alpha, FLOAT *x) __attribute__ ((noinline));
-
 static void dscal_kernel_8_zero( BLASLONG n, FLOAT *alpha, FLOAT *x)
 {
+    int i = 0;
+    /* question to self: Why is this not just memset() */
-    BLASLONG n1 = n >> 4 ;
-    BLASLONG n2 = n & 8 ;
-
-    __asm__  __volatile__
-    (
-    "vxorpd  %%xmm0, %%xmm0 , %%xmm0         \n\t"
-
-    "addq    $128, %1                        \n\t"
-
-    "cmpq    $0, %0                          \n\t"
-    "je      2f                              \n\t"
-
-    ".p2align 4                              \n\t"
-    "1:                                      \n\t"
-
-    "vmovups  %%xmm0 ,-128(%1)               \n\t"
-    "vmovups  %%xmm0 ,-112(%1)               \n\t"
-    "vmovups  %%xmm0 , -96(%1)               \n\t"
-    "vmovups  %%xmm0 , -80(%1)               \n\t"
-
-    "vmovups  %%xmm0 , -64(%1)               \n\t"
-    "vmovups  %%xmm0 , -48(%1)               \n\t"
-    "vmovups  %%xmm0 , -32(%1)               \n\t"
-    "vmovups  %%xmm0 , -16(%1)               \n\t"
-
-    "addq    $128, %1                        \n\t"
-    "subq    $1 , %0                         \n\t"
-    "jnz     1b                              \n\t"
-
-    "2:                                      \n\t"
-
-    "cmpq    $8 ,%3                          \n\t"
-    "jne     4f                              \n\t"
-
-    "vmovups  %%xmm0 ,-128(%1)               \n\t"
-    "vmovups  %%xmm0 ,-112(%1)               \n\t"
-    "vmovups  %%xmm0 , -96(%1)               \n\t"
-    "vmovups  %%xmm0 , -80(%1)               \n\t"
-
-    "4:                                      \n\t"
-
-    "vzeroupper                              \n\t"
-
-    :
-    :
-      "r" (n1),     // 0
-      "r" (x),      // 1
-      "r" (alpha),  // 2
-      "r" (n2)      // 3
-    : "cc",
-      "%xmm0", "%xmm1", "%xmm2", "%xmm3",
-      "%xmm4", "%xmm5", "%xmm6", "%xmm7",
-      "%xmm8", "%xmm9", "%xmm10", "%xmm11",
-      "%xmm12", "%xmm13", "%xmm14", "%xmm15",
-      "memory"
-    );
+#ifdef __AVX512CD__
+    __m512d zero = _mm512_setzero_pd();
+    for (; i < n; i += 8) {
+        _mm512_storeu_pd(&x[i], zero);
+    }
+#else
+    __m256d zero = _mm256_setzero_pd();
+    for (; i < n; i += 8) {
+        _mm256_storeu_pd(&x[i + 0], zero);
+        _mm256_storeu_pd(&x[i + 4], zero);
+    }
+#endif
 }
-
+#endif
diff --git a/kernel/x86_64/dsymv_L_microk_haswell-2.c b/kernel/x86_64/dsymv_L_microk_haswell-2.c
index 866782ee6..268692091 100644
--- a/kernel/x86_64/dsymv_L_microk_haswell-2.c
+++ b/kernel/x86_64/dsymv_L_microk_haswell-2.c
@@ -25,105 +25,140 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
+
+/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */
+#ifndef __AVX512CD__
+#pragma GCC target("avx2,fma")
+#endif
+
+#ifdef __AVX2__
+
+#include <immintrin.h>
+
 #define HAVE_KERNEL_4x4 1
-static void dsymv_kernel_4x4( BLASLONG from, BLASLONG to, FLOAT **a, FLOAT *x, FLOAT *y, FLOAT *temp1, FLOAT *temp2) __attribute__ ((noinline));
 
 static void dsymv_kernel_4x4(BLASLONG from, BLASLONG to, FLOAT **a, FLOAT *x, FLOAT *y, FLOAT *temp1, FLOAT *temp2)
 {
-    __asm__  __volatile__
-    (
-    "vzeroupper                              \n\t"
-    "vxorpd    %%ymm0 , %%ymm0 , %%ymm0      \n\t"  // temp2[0]
-    "vxorpd    %%ymm1 , %%ymm1 , %%ymm1      \n\t"  // temp2[1]
-    "vxorpd    %%ymm2 , %%ymm2 , %%ymm2      \n\t"  // temp2[2]
-    "vxorpd    %%ymm3 , %%ymm3 , %%ymm3      \n\t"  // temp2[3]
-    "vbroadcastsd    (%8), %%ymm4            \n\t"  // temp1[0]
-    "vbroadcastsd   8(%8), %%ymm5            \n\t"  // temp1[1]
-    "vbroadcastsd  16(%8), %%ymm6            \n\t"  // temp1[1]
-    "vbroadcastsd  24(%8), %%ymm7            \n\t"  // temp1[1]
+    __m256d accum_0, accum_1, accum_2, accum_3;
+    __m256d temp1_0, temp1_1, temp1_2, temp1_3;
-    ".p2align 4                              \n\t"
-    "1:                                      \n\t"
+    /* the 256 bit wide accumulator vectors start out as zero */
+    accum_0 = _mm256_setzero_pd();
+    accum_1 = _mm256_setzero_pd();
+    accum_2 = _mm256_setzero_pd();
+    accum_3 = _mm256_setzero_pd();
-    "vmovups    (%3,%0,8), %%ymm9            \n\t"  // 2 * y
-    "vmovups    (%2,%0,8), %%ymm8            \n\t"  // 2 * x
+    temp1_0 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[0]));
+    temp1_1 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[1]));
+    temp1_2 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[2]));
+    temp1_3 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[3]));
-    "vmovups    (%4,%0,8), %%ymm12           \n\t"  // 2 * a
-    "vmovups    (%5,%0,8), %%ymm13           \n\t"  // 2 * a
-    "vmovups    (%6,%0,8), %%ymm14           \n\t"  // 2 * a
-    "vmovups    (%7,%0,8), %%ymm15           \n\t"  // 2 * a
+#ifdef __AVX512CD__
+    __m512d accum_05, accum_15, accum_25, accum_35;
+    __m512d temp1_05, temp1_15, temp1_25, temp1_35;
+    BLASLONG to2;
+    int delta;
-    "vfmadd231pd   %%ymm4, %%ymm12 , %%ymm9  \n\t"  // y     += temp1 * a
\n\t" // y += temp1 * a - "vfmadd231pd %%ymm8, %%ymm12 , %%ymm0 \n\t" // temp2 += x * a + /* the 512 bit wide accumulator vectors start out as zero */ + accum_05 = _mm512_setzero_pd(); + accum_15 = _mm512_setzero_pd(); + accum_25 = _mm512_setzero_pd(); + accum_35 = _mm512_setzero_pd(); - "vfmadd231pd %%ymm5, %%ymm13 , %%ymm9 \n\t" // y += temp1 * a - "vfmadd231pd %%ymm8, %%ymm13 , %%ymm1 \n\t" // temp2 += x * a + temp1_05 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[0])); + temp1_15 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[1])); + temp1_25 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[2])); + temp1_35 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[3])); - "vfmadd231pd %%ymm6, %%ymm14 , %%ymm9 \n\t" // y += temp1 * a - "vfmadd231pd %%ymm8, %%ymm14 , %%ymm2 \n\t" // temp2 += x * a + delta = (to - from) & ~7; + to2 = from + delta; - "vfmadd231pd %%ymm7, %%ymm15 , %%ymm9 \n\t" // y += temp1 * a - "vfmadd231pd %%ymm8, %%ymm15 , %%ymm3 \n\t" // temp2 += x * a - "addq $4 , %0 \n\t" - "vmovups %%ymm9 , -32(%3,%0,8) \n\t" + for (; from < to2; from += 8) { + __m512d _x, _y; + __m512d a0, a1, a2, a3; - "cmpq %0 , %1 \n\t" - "jnz 1b \n\t" + _y = _mm512_loadu_pd(&y[from]); + _x = _mm512_loadu_pd(&x[from]); - "vmovsd (%9), %%xmm4 \n\t" - "vmovsd 8(%9), %%xmm5 \n\t" - "vmovsd 16(%9), %%xmm6 \n\t" - "vmovsd 24(%9), %%xmm7 \n\t" + a0 = _mm512_loadu_pd(&a[0][from]); + a1 = _mm512_loadu_pd(&a[1][from]); + a2 = _mm512_loadu_pd(&a[2][from]); + a3 = _mm512_loadu_pd(&a[3][from]); - "vextractf128 $0x01, %%ymm0 , %%xmm12 \n\t" - "vextractf128 $0x01, %%ymm1 , %%xmm13 \n\t" - "vextractf128 $0x01, %%ymm2 , %%xmm14 \n\t" - "vextractf128 $0x01, %%ymm3 , %%xmm15 \n\t" + _y += temp1_05 * a0 + temp1_15 * a1 + temp1_25 * a2 + temp1_35 * a3; - "vaddpd %%xmm0, %%xmm12, %%xmm0 \n\t" - "vaddpd %%xmm1, %%xmm13, %%xmm1 \n\t" - "vaddpd %%xmm2, %%xmm14, %%xmm2 \n\t" - "vaddpd %%xmm3, %%xmm15, %%xmm3 \n\t" + accum_05 += _x * a0; + accum_15 += _x * a1; + accum_25 += _x * a2; + accum_35 += _x * a3; - "vhaddpd %%xmm0, %%xmm0, %%xmm0 \n\t" - "vhaddpd %%xmm1, %%xmm1, %%xmm1 \n\t" - "vhaddpd %%xmm2, %%xmm2, %%xmm2 \n\t" - "vhaddpd %%xmm3, %%xmm3, %%xmm3 \n\t" + _mm512_storeu_pd(&y[from], _y); - "vaddsd %%xmm4, %%xmm0, %%xmm0 \n\t" - "vaddsd %%xmm5, %%xmm1, %%xmm1 \n\t" - "vaddsd %%xmm6, %%xmm2, %%xmm2 \n\t" - "vaddsd %%xmm7, %%xmm3, %%xmm3 \n\t" + }; - "vmovsd %%xmm0 , (%9) \n\t" // save temp2 - "vmovsd %%xmm1 , 8(%9) \n\t" // save temp2 - "vmovsd %%xmm2 ,16(%9) \n\t" // save temp2 - "vmovsd %%xmm3 ,24(%9) \n\t" // save temp2 - "vzeroupper \n\t" + /* + * we need to fold our 512 bit wide accumulator vectors into 256 bit wide vectors so that the AVX2 code + * below can continue using the intermediate results in its loop + */ + accum_0 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_05, 0), _mm512_extractf64x4_pd(accum_05, 1)); + accum_1 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_15, 0), _mm512_extractf64x4_pd(accum_15, 1)); + accum_2 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_25, 0), _mm512_extractf64x4_pd(accum_25, 1)); + accum_3 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_35, 0), _mm512_extractf64x4_pd(accum_35, 1)); - : - : - "r" (from), // 0 - "r" (to), // 1 - "r" (x), // 2 - "r" (y), // 3 - "r" (a[0]), // 4 - "r" (a[1]), // 5 - "r" (a[2]), // 6 - "r" (a[3]), // 8 - "r" (temp1), // 8 - "r" (temp2) // 9 - : "cc", - "%xmm0", "%xmm1", "%xmm2", "%xmm3", - "%xmm4", "%xmm5", "%xmm6", "%xmm7", - "%xmm8", "%xmm9", "%xmm10", "%xmm11", - "%xmm12", "%xmm13", "%xmm14", "%xmm15", - "memory" - ); +#endif + for (; from != to; from += 4) { + __m256d _x, 
_y; + __m256d a0, a1, a2, a3; + + _y = _mm256_loadu_pd(&y[from]); + _x = _mm256_loadu_pd(&x[from]); + + /* load 4 rows of matrix data */ + a0 = _mm256_loadu_pd(&a[0][from]); + a1 = _mm256_loadu_pd(&a[1][from]); + a2 = _mm256_loadu_pd(&a[2][from]); + a3 = _mm256_loadu_pd(&a[3][from]); + + _y += temp1_0 * a0 + temp1_1 * a1 + temp1_2 * a2 + temp1_3 * a3; + + accum_0 += _x * a0; + accum_1 += _x * a1; + accum_2 += _x * a2; + accum_3 += _x * a3; + + _mm256_storeu_pd(&y[from], _y); + + }; + + /* + * we now have 4 accumulator vectors. Each vector needs to be summed up element wise and stored in the temp2 + * output array. There is no direct instruction for this in 256 bit space, only in 128 space. + */ + + __m128d half_accum0, half_accum1, half_accum2, half_accum3; + + + /* Add upper half to lower half of each of the four 256 bit vectors to get to four 128 bit vectors */ + half_accum0 = _mm_add_pd(_mm256_extractf128_pd(accum_0, 0), _mm256_extractf128_pd(accum_0, 1)); + half_accum1 = _mm_add_pd(_mm256_extractf128_pd(accum_1, 0), _mm256_extractf128_pd(accum_1, 1)); + half_accum2 = _mm_add_pd(_mm256_extractf128_pd(accum_2, 0), _mm256_extractf128_pd(accum_2, 1)); + half_accum3 = _mm_add_pd(_mm256_extractf128_pd(accum_3, 0), _mm256_extractf128_pd(accum_3, 1)); + + /* in 128 bit land there is a hadd operation to do the rest of the element-wise sum in one go */ + half_accum0 = _mm_hadd_pd(half_accum0, half_accum0); + half_accum1 = _mm_hadd_pd(half_accum1, half_accum1); + half_accum2 = _mm_hadd_pd(half_accum2, half_accum2); + half_accum3 = _mm_hadd_pd(half_accum3, half_accum3); + + /* and store the lowest double value from each of these vectors in the temp2 output */ + temp2[0] += half_accum0[0]; + temp2[1] += half_accum1[0]; + temp2[2] += half_accum2[0]; + temp2[3] += half_accum3[0]; } - +#endif \ No newline at end of file diff --git a/kernel/x86_64/saxpy_microk_haswell-2.c b/kernel/x86_64/saxpy_microk_haswell-2.c index 3a743d64c..0f6587367 100644 --- a/kernel/x86_64/saxpy_microk_haswell-2.c +++ b/kernel/x86_64/saxpy_microk_haswell-2.c @@ -25,54 +25,47 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 *****************************************************************************/
+
+#ifndef __AVX512CD__
+#pragma GCC target("avx2,fma")
+#endif
+
+#ifdef __AVX2__
+
 #define HAVE_KERNEL_16 1
-static void saxpy_kernel_16( BLASLONG n, FLOAT *x, FLOAT *y , FLOAT *alpha) __attribute__ ((noinline));
+
+#include <immintrin.h>
 
 static void saxpy_kernel_16( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *alpha)
 {
+    BLASLONG i = 0;
+    __m256 __alpha;
-    BLASLONG register i = 0;
+    __alpha = _mm256_broadcastss_ps(_mm_load_ss(alpha));
-    __asm__  __volatile__
-    (
-    "vbroadcastss    (%4), %%ymm0                \n\t"  // alpha
+#ifdef __AVX512CD__
+    BLASLONG n64;
+    __m512 __alpha5;
+    __alpha5 = _mm512_broadcastss_ps(_mm_load_ss(alpha));
-    ".p2align 4                                  \n\t"
-    "1:                                          \n\t"
+    n64 = n & ~63;
-    "vmovups      (%3,%0,4), %%ymm12             \n\t"  // 8 * y
-    "vmovups    32(%3,%0,4), %%ymm13             \n\t"  // 8 * y
-    "vmovups    64(%3,%0,4), %%ymm14             \n\t"  // 8 * y
-    "vmovups    96(%3,%0,4), %%ymm15             \n\t"  // 8 * y
-    "vfmadd231ps    (%2,%0,4), %%ymm0 , %%ymm12  \n\t"  // y += alpha * x
-    "vfmadd231ps  32(%2,%0,4), %%ymm0 , %%ymm13  \n\t"  // y += alpha * x
-    "vfmadd231ps  64(%2,%0,4), %%ymm0 , %%ymm14  \n\t"  // y += alpha * x
-    "vfmadd231ps  96(%2,%0,4), %%ymm0 , %%ymm15  \n\t"  // y += alpha * x
-    "vmovups    %%ymm12,   (%3,%0,4)             \n\t"
-    "vmovups    %%ymm13, 32(%3,%0,4)             \n\t"
-    "vmovups    %%ymm14, 64(%3,%0,4)             \n\t"
-    "vmovups    %%ymm15, 96(%3,%0,4)             \n\t"
+    for (; i < n64; i+= 64) {
+        _mm512_storeu_ps(&y[i +  0], _mm512_loadu_ps(&y[i +  0]) + __alpha5 * _mm512_loadu_ps(&x[i +  0]));
+        _mm512_storeu_ps(&y[i + 16], _mm512_loadu_ps(&y[i + 16]) + __alpha5 * _mm512_loadu_ps(&x[i + 16]));
+        _mm512_storeu_ps(&y[i + 32], _mm512_loadu_ps(&y[i + 32]) + __alpha5 * _mm512_loadu_ps(&x[i + 32]));
+        _mm512_storeu_ps(&y[i + 48], _mm512_loadu_ps(&y[i + 48]) + __alpha5 * _mm512_loadu_ps(&x[i + 48]));
+    }
-    "addq       $32, %0                          \n\t"
-    "subq       $32, %1                          \n\t"
-    "jnz        1b                               \n\t"
-    "vzeroupper                                  \n\t"
-
-    :
-    :
-      "r" (i),      // 0
-      "r" (n),      // 1
-      "r" (x),      // 2
-      "r" (y),      // 3
-      "r" (alpha)   // 4
-    : "cc",
-      "%xmm0",
-      "%xmm8", "%xmm9", "%xmm10", "%xmm11",
-      "%xmm12", "%xmm13", "%xmm14", "%xmm15",
-      "memory"
-    );
-
-}
+#endif
+    for (; i < n; i+= 32) {
+        _mm256_storeu_ps(&y[i +  0], _mm256_loadu_ps(&y[i +  0]) + __alpha * _mm256_loadu_ps(&x[i +  0]));
+        _mm256_storeu_ps(&y[i +  8], _mm256_loadu_ps(&y[i +  8]) + __alpha * _mm256_loadu_ps(&x[i +  8]));
+        _mm256_storeu_ps(&y[i + 16], _mm256_loadu_ps(&y[i + 16]) + __alpha * _mm256_loadu_ps(&x[i + 16]));
+        _mm256_storeu_ps(&y[i + 24], _mm256_loadu_ps(&y[i + 24]) + __alpha * _mm256_loadu_ps(&x[i + 24]));
+    }
+}
+#endif
diff --git a/kernel/x86_64/sdot_microk_haswell-2.c b/kernel/x86_64/sdot_microk_haswell-2.c
index df367b61f..070171792 100644
--- a/kernel/x86_64/sdot_microk_haswell-2.c
+++ b/kernel/x86_64/sdot_microk_haswell-2.c
@@ -25,74 +25,75 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
-#define HAVE_KERNEL_16 1
-static void sdot_kernel_16( BLASLONG n, FLOAT *x, FLOAT *y , FLOAT *dot) __attribute__ ((noinline));
-
-static void sdot_kernel_16( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
-{
-
-
-    BLASLONG register i = 0;
-
-    __asm__  __volatile__
-    (
-    "vxorps    %%ymm4, %%ymm4, %%ymm4    \n\t"
-    "vxorps    %%ymm5, %%ymm5, %%ymm5    \n\t"
-    "vxorps    %%ymm6, %%ymm6, %%ymm6    \n\t"
-    "vxorps    %%ymm7, %%ymm7, %%ymm7    \n\t"
-
-    ".p2align 4                          \n\t"
-    "1:                                  \n\t"
-    "vmovups      (%2,%0,4), %%ymm12     \n\t"  // 2 * x
-    "vmovups    32(%2,%0,4), %%ymm13     \n\t"  // 2 * x
-    "vmovups    64(%2,%0,4), %%ymm14     \n\t"  // 2 * x
-    "vmovups    96(%2,%0,4), %%ymm15     \n\t"  // 2 * x
-
-    "vfmadd231ps    (%3,%0,4), %%ymm12, %%ymm4   \n\t"  // 2 * y
-    "vfmadd231ps  32(%3,%0,4), %%ymm13, %%ymm5   \n\t"  // 2 * y
-    "vfmadd231ps  64(%3,%0,4), %%ymm14, %%ymm6   \n\t"  // 2 * y
-    "vfmadd231ps  96(%3,%0,4), %%ymm15, %%ymm7   \n\t"  // 2 * y
-
-#ifndef DSDOT
-    "addq       $32 , %0                 \n\t"
-    "subq       $32 , %1                 \n\t"
-    "jnz        1b                       \n\t"
+#ifndef __AVX512CD__
+#pragma GCC target("avx2,fma")
 #endif
-    "vextractf128 $1 , %%ymm4 , %%xmm12  \n\t"
-    "vextractf128 $1 , %%ymm5 , %%xmm13  \n\t"
-    "vextractf128 $1 , %%ymm6 , %%xmm14  \n\t"
-    "vextractf128 $1 , %%ymm7 , %%xmm15  \n\t"
+#ifdef __AVX2__
-    "vaddps    %%xmm4, %%xmm12, %%xmm4   \n\t"
-    "vaddps    %%xmm5, %%xmm13, %%xmm5   \n\t"
-    "vaddps    %%xmm6, %%xmm14, %%xmm6   \n\t"
-    "vaddps    %%xmm7, %%xmm15, %%xmm7   \n\t"
+#define HAVE_KERNEL_16 1
-    "vaddps    %%xmm4, %%xmm5, %%xmm4    \n\t"
-    "vaddps    %%xmm6, %%xmm7, %%xmm6    \n\t"
-    "vaddps    %%xmm4, %%xmm6, %%xmm4    \n\t"
+#include <immintrin.h>
-    "vhaddps   %%xmm4, %%xmm4, %%xmm4    \n\t"
-    "vhaddps   %%xmm4, %%xmm4, %%xmm4    \n\t"
+static void sdot_kernel_16( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
-    "vmovss    %%xmm4, (%4)              \n\t"
-    "vzeroupper                          \n\t"
+{
+    int i = 0;
+    __m256 accum_0, accum_1, accum_2, accum_3;
-    :
-    :
-      "r" (i),    // 0
-      "r" (n),    // 1
-      "r" (x),    // 2
-      "r" (y),    // 3
-      "r" (dot)   // 4
-    : "cc",
-      "%xmm4", "%xmm5",
-      "%xmm6", "%xmm7",
-      "%xmm12", "%xmm13", "%xmm14", "%xmm15",
-      "memory"
-    );
+    accum_0 = _mm256_setzero_ps();
+    accum_1 = _mm256_setzero_ps();
+    accum_2 = _mm256_setzero_ps();
+    accum_3 = _mm256_setzero_ps();
-}
+#ifdef __AVX512CD__
+    __m512 accum_05, accum_15, accum_25, accum_35;
+    int n64;
+    n64 = n & (~63);
+    accum_05 = _mm512_setzero_ps();
+    accum_15 = _mm512_setzero_ps();
+    accum_25 = _mm512_setzero_ps();
+    accum_35 = _mm512_setzero_ps();
+    for (; i < n64; i += 64) {
+        accum_05 += _mm512_loadu_ps(&x[i+ 0]) * _mm512_loadu_ps(&y[i+ 0]);
+        accum_15 += _mm512_loadu_ps(&x[i+16]) * _mm512_loadu_ps(&y[i+16]);
+        accum_25 += _mm512_loadu_ps(&x[i+32]) * _mm512_loadu_ps(&y[i+32]);
+        accum_35 += _mm512_loadu_ps(&x[i+48]) * _mm512_loadu_ps(&y[i+48]);
+    }
+
+    /*
+     * we need to fold our 512 bit wide accumulator vectors into 256 bit wide vectors so that the AVX2 code
+     * below can continue using the intermediate results in its loop
+     */
+    accum_0 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_05, 0), _mm512_extractf32x8_ps(accum_05, 1));
+    accum_1 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_15, 0), _mm512_extractf32x8_ps(accum_15, 1));
+    accum_2 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_25, 0), _mm512_extractf32x8_ps(accum_25, 1));
+    accum_3 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_35, 0), _mm512_extractf32x8_ps(accum_35, 1));
+
+#endif
+    for (; i < n; i += 32) {
+        accum_0 += _mm256_loadu_ps(&x[i+ 0]) * _mm256_loadu_ps(&y[i+ 0]);
+        accum_1 += _mm256_loadu_ps(&x[i+ 8]) * _mm256_loadu_ps(&y[i+ 8]);
+        accum_2 += _mm256_loadu_ps(&x[i+16]) * _mm256_loadu_ps(&y[i+16]);
+        accum_3 += _mm256_loadu_ps(&x[i+24]) * _mm256_loadu_ps(&y[i+24]);
+    }
+
+    /* we now have the partial sums of the dot product in the 4 accumulation vectors, time to consolidate */
+
+    accum_0 = accum_0 + accum_1 + accum_2 + accum_3;
+
+    __m128 half_accum0;
+
+    /* Add the upper half to the lower half of the 256 bit vector to get a 128 bit vector */
+    half_accum0 = _mm_add_ps(_mm256_extractf128_ps(accum_0, 0), _mm256_extractf128_ps(accum_0, 1));
+
+    /* in 128 bit land there is a hadd operation to do the rest of the element-wise sum in one go */
+    half_accum0 = _mm_hadd_ps(half_accum0, half_accum0);
+    half_accum0 = _mm_hadd_ps(half_accum0, half_accum0);
+
+    *dot = half_accum0[0];
+}
+
+#endif
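
The ddot, sdot and dsymv kernels in this patch all finish with the same horizontal-reduction idiom described in their comments: fold the 512-bit accumulators into 256-bit vectors, fold those into 128-bit vectors, then let a hadd finish the sum. The standalone sketch below (not part of the patch; the file name, function name and compile flags are illustrative) shows that idiom in isolation for a double-precision accumulator; it reads the scalar out with _mm_cvtsd_f64() where the kernels above use the GCC vector subscript half_accum0[0].

/* reduce_sketch.c - illustrative only; compile with e.g. gcc -O2 -mavx512f reduce_sketch.c */
#include <immintrin.h>
#include <stdio.h>

static double reduce_add_pd512(__m512d v)
{
    /* 512 -> 256: add the upper four doubles onto the lower four */
    __m256d r256 = _mm256_add_pd(_mm512_extractf64x4_pd(v, 0),
                                 _mm512_extractf64x4_pd(v, 1));
    /* 256 -> 128: add the upper two doubles onto the lower two */
    __m128d r128 = _mm_add_pd(_mm256_extractf128_pd(r256, 0),
                              _mm256_extractf128_pd(r256, 1));
    /* 128-bit horizontal add sums the remaining pair */
    r128 = _mm_hadd_pd(r128, r128);
    return _mm_cvtsd_f64(r128);
}

int main(void)
{
    __m512d v = _mm512_set_pd(8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0);
    printf("%f\n", reduce_add_pd512(v));  /* prints 36.000000 */
    return 0;
}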