From c956271c2e2d9196872f58f09d6ee3187fa0b718 Mon Sep 17 00:00:00 2001
From: Bart Oldeman
Date: Fri, 21 Oct 2022 13:52:34 -0400
Subject: [PATCH] Use FMA for cscal and zscal Haswell microkernels.

This patch has two benefits:

1. Using vfmaddsub231p[sd] instead of vaddsubp[sd] eliminates a vmulp[sd]
   instruction, giving a ~10% speedup, measured from ~33 to ~36 Gflops for
   cscal with 4096 elements and from ~17 to ~19 Gflops for zscal on my
   Kaby Lake laptop, see e.g.
   OPENBLAS_LOOPS=10000 benchmark/cscal.goto 4096 4096.

2. Using it for both the main loop and the tail end ensures that the same
   FMA instruction is used for all loop iterations. That is not guaranteed
   today, where the tail loop is implemented in C and the compiler may or
   may not emit FMA instructions for it. This is important for some LAPACK
   eigenvalue test cases that rely on bitwise identical results independent
   of how many loop iterations are used.
---
 kernel/x86_64/cscal.c                  |  11 +++
 kernel/x86_64/cscal_microk_haswell-2.c | 110 ++++++++++++++++++-------
 kernel/x86_64/zscal.c                  |  11 +++
 kernel/x86_64/zscal_microk_haswell-2.c | 100 +++++++++++++++-------
 4 files changed, 168 insertions(+), 64 deletions(-)

diff --git a/kernel/x86_64/cscal.c b/kernel/x86_64/cscal.c
index dc3f688c6..bc5856267 100644
--- a/kernel/x86_64/cscal.c
+++ b/kernel/x86_64/cscal.c
@@ -337,6 +337,15 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
     }
 
+#ifdef HAVE_KERNEL
+    if ( da_r != 0.0 && da_i != 0.0 ) {
+        alpha[0] = da_r;
+        alpha[1] = da_i;
+        cscal_kernel(n , alpha, x);
+        return(0);
+    }
+#endif
+
     BLASLONG n1 = n & -16;
     if ( n1 > 0 )
     {
 
@@ -352,8 +361,10 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
         else
             if ( da_i == 0 )
                 cscal_kernel_16_zero_i(n1 , alpha , x);
+#ifndef HAVE_KERNEL
             else
                 cscal_kernel_16(n1 , alpha , x);
+#endif
 
         i = n1 << 1;
         j = n1;
diff --git a/kernel/x86_64/cscal_microk_haswell-2.c b/kernel/x86_64/cscal_microk_haswell-2.c
index a04a4c4ab..098b12b61 100644
--- a/kernel/x86_64/cscal_microk_haswell-2.c
+++ b/kernel/x86_64/cscal_microk_haswell-2.c
@@ -26,11 +26,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
 
+#define HAVE_KERNEL 1
 #define HAVE_KERNEL_16 1
 
-static void cscal_kernel_16( BLASLONG n, FLOAT *alpha, FLOAT *x) __attribute__ ((noinline));
+static void cscal_kernel( BLASLONG n, FLOAT *alpha, FLOAT *x) __attribute__ ((noinline));
 
-static void cscal_kernel_16( BLASLONG n, FLOAT *alpha, FLOAT *x)
+static void cscal_kernel( BLASLONG n, FLOAT *alpha, FLOAT *x)
 {
 
@@ -39,6 +40,9 @@ static void cscal_kernel_16( BLASLONG n, FLOAT *alpha, FLOAT *x)
     "vbroadcastss    (%2), %%ymm0                 \n\t"  // da_r
     "vbroadcastss   4(%2), %%ymm1                 \n\t"  // da_i
 
+    "cmpq           $16, %0                       \n\t"
+    "jb             3f                            \n\t"
+
     "addq           $128, %1                      \n\t"
 
     "vmovups        -128(%1), %%ymm4              \n\t"
@@ -52,7 +56,8 @@ static void cscal_kernel_16( BLASLONG n, FLOAT *alpha, FLOAT *x)
     "vpermilps      $0xb1 , %%ymm7, %%ymm15       \n\t"
 
     "subq           $16, %0                       \n\t"
-    "jz             2f                            \n\t"
+    "cmpq           $16, %0                       \n\t"
+    "jb             2f                            \n\t"
 
     ".p2align 4                                   \n\t"
     "1:                                           \n\t"
@@ -60,23 +65,19 @@ static void cscal_kernel_16( BLASLONG n, FLOAT *alpha, FLOAT *x)
     //"prefetcht0   128(%1)                       \n\t"
     // ".align 2                                  \n\t"
 
-    "vmulps         %%ymm0, %%ymm4 , %%ymm8       \n\t"  // da_r*x0 , da_r *x1
-    "vmovups          0(%1), %%ymm4               \n\t"
-    "vmulps         %%ymm0, %%ymm5 , %%ymm9       \n\t"
-    "vmovups         32(%1), %%ymm5               \n\t"
-    "vmulps         %%ymm0, %%ymm6 , %%ymm10      \n\t"
-    "vmovups         64(%1), %%ymm6               \n\t"
-    "vmulps         %%ymm0, %%ymm7 , %%ymm11      \n\t"
-    "vmovups         96(%1), %%ymm7               \n\t"
+    "vmulps         %%ymm1, %%ymm12, %%ymm8       \n\t"  // da_i*x1 , da_i *x0
+    "vmulps         %%ymm1, %%ymm13, %%ymm9       \n\t"
+    "vmulps         %%ymm1, %%ymm14, %%ymm10      \n\t"
+    "vmulps         %%ymm1, %%ymm15, %%ymm11      \n\t"
 
-    "vmulps         %%ymm1, %%ymm12, %%ymm12      \n\t"  // da_i*x1 , da_i *x0
-    "vaddsubps      %%ymm12 , %%ymm8 , %%ymm8     \n\t"
-    "vmulps         %%ymm1, %%ymm13, %%ymm13      \n\t"
-    "vaddsubps      %%ymm13 , %%ymm9 , %%ymm9     \n\t"
-    "vmulps         %%ymm1, %%ymm14, %%ymm14      \n\t"
-    "vaddsubps      %%ymm14 , %%ymm10, %%ymm10    \n\t"
-    "vmulps         %%ymm1, %%ymm15, %%ymm15      \n\t"
-    "vaddsubps      %%ymm15 , %%ymm11, %%ymm11    \n\t"
+    "vfmaddsub231ps %%ymm0, %%ymm4 , %%ymm8       \n\t"  // da_r*x0 , da_r *x1
+    "vmovups          0(%1), %%ymm4               \n\t"
+    "vfmaddsub231ps %%ymm0, %%ymm5 , %%ymm9       \n\t"
+    "vmovups         32(%1), %%ymm5               \n\t"
+    "vfmaddsub231ps %%ymm0, %%ymm6 , %%ymm10      \n\t"
+    "vmovups         64(%1), %%ymm6               \n\t"
+    "vfmaddsub231ps %%ymm0, %%ymm7 , %%ymm11      \n\t"
+    "vmovups         96(%1), %%ymm7               \n\t"
 
     "vmovups        %%ymm8 , -128(%1)             \n\t"
     "vpermilps      $0xb1 , %%ymm4, %%ymm12       \n\t"
@@ -89,30 +90,75 @@ static void cscal_kernel_16( BLASLONG n, FLOAT *alpha, FLOAT *x)
     "addq           $128 ,%1                      \n\t"
     "subq           $16, %0                       \n\t"
-    "jnz            1b                            \n\t"
+    "cmpq           $16, %0                       \n\t"
+    "jae            1b                            \n\t"
 
     "2:                                           \n\t"
 
-    "vmulps         %%ymm0, %%ymm4 , %%ymm8       \n\t"  // da_r*x0 , da_r *x1
-    "vmulps         %%ymm0, %%ymm5 , %%ymm9       \n\t"
-    "vmulps         %%ymm0, %%ymm6 , %%ymm10      \n\t"
-    "vmulps         %%ymm0, %%ymm7 , %%ymm11      \n\t"
+    "vmulps         %%ymm1, %%ymm12, %%ymm8       \n\t"  // da_i*x1 , da_i *x0
+    "vmulps         %%ymm1, %%ymm13, %%ymm9       \n\t"
+    "vmulps         %%ymm1, %%ymm14, %%ymm10      \n\t"
+    "vmulps         %%ymm1, %%ymm15, %%ymm11      \n\t"
 
-    "vmulps         %%ymm1, %%ymm12, %%ymm12      \n\t"  // da_i*x1 , da_i *x0
-    "vaddsubps      %%ymm12 , %%ymm8 , %%ymm8     \n\t"
-    "vmulps         %%ymm1, %%ymm13, %%ymm13      \n\t"
-    "vaddsubps      %%ymm13 , %%ymm9 , %%ymm9     \n\t"
-    "vmulps         %%ymm1, %%ymm14, %%ymm14      \n\t"
-    "vaddsubps      %%ymm14 , %%ymm10, %%ymm10    \n\t"
-    "vmulps         %%ymm1, %%ymm15, %%ymm15      \n\t"
-    "vaddsubps      %%ymm15 , %%ymm11, %%ymm11    \n\t"
+    "vfmaddsub231ps %%ymm0, %%ymm4 , %%ymm8       \n\t"  // da_r*x0 , da_r *x1
+    "vfmaddsub231ps %%ymm0, %%ymm5 , %%ymm9       \n\t"
+    "vfmaddsub231ps %%ymm0, %%ymm6 , %%ymm10      \n\t"
+    "vfmaddsub231ps %%ymm0, %%ymm7 , %%ymm11      \n\t"
 
     "vmovups        %%ymm8 , -128(%1)             \n\t"
     "vmovups        %%ymm9 , -96(%1)              \n\t"
     "vmovups        %%ymm10, -64(%1)              \n\t"
     "vmovups        %%ymm11, -32(%1)              \n\t"
 
+    "testq          $15, %0                       \n\t"
+    "jz             7f                            \n\t"
+
+    "3:                                           \n\t"
+    "testq          $8, %0                        \n\t"
+    "jz             4f                            \n\t"
+    "vmovups          0(%1), %%ymm4               \n\t"
+    "vmovups         32(%1), %%ymm5               \n\t"
+    "vpermilps      $0xb1 , %%ymm4 , %%ymm12      \n\t"
+    "vpermilps      $0xb1 , %%ymm5 , %%ymm13      \n\t"
+    "vmulps         %%ymm1, %%ymm12, %%ymm8       \n\t"  // da_i*x1 , da_i *x0
+    "vmulps         %%ymm1, %%ymm13, %%ymm9       \n\t"
+    "vfmaddsub231ps %%ymm0, %%ymm4 , %%ymm8       \n\t"  // da_r*x0 , da_r *x1
+    "vfmaddsub231ps %%ymm0, %%ymm5 , %%ymm9       \n\t"
+    "vmovups        %%ymm8,   0(%1)               \n\t"
+    "vmovups        %%ymm9,  32(%1)               \n\t"
+    "addq           $64, %1                       \n\t"
+
+    "4:                                           \n\t"
+    "testq          $4, %0                        \n\t"
+    "jz             5f                            \n\t"
+    "vmovups          0(%1), %%ymm4               \n\t"
+    "vpermilps      $0xb1 , %%ymm4 , %%ymm12      \n\t"
+    "vmulps         %%ymm1, %%ymm12, %%ymm8       \n\t"  // da_i*x1 , da_i *x0
+    "vfmaddsub231ps %%ymm0, %%ymm4 , %%ymm8       \n\t"  // da_r*x0 , da_r *x1
+    "vmovups        %%ymm8,   0(%1)               \n\t"
+    "addq           $32, %1                       \n\t"
+
+    "5:                                           \n\t"
+    "testq          $2, %0                        \n\t"
+    "jz             6f                            \n\t"
+    "vmovups          0(%1), %%xmm4               \n\t"
+    "vpermilps      $0xb1 , %%xmm4 , %%xmm12      \n\t"
+    "vmulps         %%xmm1, %%xmm12, %%xmm8       \n\t"  // da_i*x1 , da_i *x0
+    "vfmaddsub231ps %%xmm0, %%xmm4 , %%xmm8       \n\t"  // da_r*x0 , da_r *x1
+    "vmovups        %%xmm8,   0(%1)               \n\t"
+    "addq           $16, %1                       \n\t"
+
+    "6:                                           \n\t"
+    "testq          $1, %0                        \n\t"
+    "jz             7f                            \n\t"
+    "vmovsd           0(%1), %%xmm4               \n\t"
+    "vpermilps      $0x1 , %%xmm4 , %%xmm12       \n\t"
+    "vmulps         %%xmm1, %%xmm12, %%xmm8       \n\t"  // da_i*x1 , da_i *x0
+    "vfmaddsub231ps %%xmm0, %%xmm4 , %%xmm8       \n\t"  // da_r*x0 , da_r *x1
+    "vmovsd         %%xmm8,   0(%1)               \n\t"
+
+    "7:                                           \n\t"
 
     "vzeroupper                                   \n\t"
 
     :
diff --git a/kernel/x86_64/zscal.c b/kernel/x86_64/zscal.c
index 3744c98bb..b4de5034d 100644
--- a/kernel/x86_64/zscal.c
+++ b/kernel/x86_64/zscal.c
@@ -335,6 +335,15 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
     }
 
+#ifdef HAVE_KERNEL
+    if ( da_r != 0.0 && da_i != 0.0 ) {
+        alpha[0] = da_r;
+        alpha[1] = da_i;
+        zscal_kernel(n , alpha, x);
+        return(0);
+    }
+#endif
+
     BLASLONG n1 = n & -8;
     if ( n1 > 0 )
     {
 
@@ -350,8 +359,10 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
         else
             if ( da_i == 0 )
                 zscal_kernel_8_zero_i(n1 , alpha , x);
+#ifndef HAVE_KERNEL
             else
                 zscal_kernel_8(n1 , alpha , x);
+#endif
 
         i = n1 << 1;
         j = n1;
diff --git a/kernel/x86_64/zscal_microk_haswell-2.c b/kernel/x86_64/zscal_microk_haswell-2.c
index 8c8f5b75c..19b49643c 100644
--- a/kernel/x86_64/zscal_microk_haswell-2.c
+++ b/kernel/x86_64/zscal_microk_haswell-2.c
@@ -26,11 +26,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
 
+#define HAVE_KERNEL 1
 #define HAVE_KERNEL_8 1
 
-static void zscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x) __attribute__ ((noinline));
+static void zscal_kernel( BLASLONG n, FLOAT *alpha, FLOAT *x) __attribute__ ((noinline));
 
-static void zscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x)
+static void zscal_kernel( BLASLONG n, FLOAT *alpha, FLOAT *x)
 {
 
@@ -39,6 +40,9 @@ static void zscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x)
     "vbroadcastsd    (%2), %%ymm0                 \n\t"  // da_r
    "vbroadcastsd   8(%2), %%ymm1                 \n\t"  // da_i
 
+    "cmpq           $8 , %0                       \n\t"
+    "jb             3f                            \n\t"
+
     "addq           $128, %1                      \n\t"
 
     "vmovups        -128(%1), %%ymm4              \n\t"
@@ -52,7 +56,8 @@ static void zscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x)
     "vpermilpd      $0x05 , %%ymm7, %%ymm15       \n\t"
 
     "subq           $8 , %0                       \n\t"
-    "jz             2f                            \n\t"
+    "cmpq           $8 , %0                       \n\t"
+    "jb             2f                            \n\t"
 
     ".p2align 4                                   \n\t"
     "1:                                           \n\t"
@@ -60,23 +65,19 @@ static void zscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x)
     //"prefetcht0   128(%1)                       \n\t"
     // ".align 2                                  \n\t"
 
-    "vmulpd         %%ymm0, %%ymm4 , %%ymm8       \n\t"  // da_r*x0 , da_r *x1
-    "vmovups          0(%1), %%ymm4               \n\t"
-    "vmulpd         %%ymm0, %%ymm5 , %%ymm9       \n\t"
-    "vmovups         32(%1), %%ymm5               \n\t"
-    "vmulpd         %%ymm0, %%ymm6 , %%ymm10      \n\t"
-    "vmovups         64(%1), %%ymm6               \n\t"
-    "vmulpd         %%ymm0, %%ymm7 , %%ymm11      \n\t"
-    "vmovups         96(%1), %%ymm7               \n\t"
+    "vmulpd         %%ymm1, %%ymm12, %%ymm8       \n\t"  // da_i*x1 , da_i *x0
+    "vmulpd         %%ymm1, %%ymm13, %%ymm9       \n\t"
+    "vmulpd         %%ymm1, %%ymm14, %%ymm10      \n\t"
+    "vmulpd         %%ymm1, %%ymm15, %%ymm11      \n\t"
 
-    "vmulpd         %%ymm1, %%ymm12, %%ymm12      \n\t"  // da_i*x1 , da_i *x0
-    "vaddsubpd      %%ymm12 , %%ymm8 , %%ymm8     \n\t"
-    "vmulpd         %%ymm1, %%ymm13, %%ymm13      \n\t"
-    "vaddsubpd      %%ymm13 , %%ymm9 , %%ymm9     \n\t"
-    "vmulpd         %%ymm1, %%ymm14, %%ymm14      \n\t"
-    "vaddsubpd      %%ymm14 , %%ymm10, %%ymm10    \n\t"
-    "vmulpd         %%ymm1, %%ymm15, %%ymm15      \n\t"
-    "vaddsubpd      %%ymm15 , %%ymm11, %%ymm11    \n\t"
+    "vfmaddsub231pd %%ymm0, %%ymm4 , %%ymm8       \n\t"  // da_r*x0 , da_r *x1
+    "vmovups          0(%1), %%ymm4               \n\t"
+    "vfmaddsub231pd %%ymm0, %%ymm5 , %%ymm9       \n\t"
+    "vmovups         32(%1), %%ymm5               \n\t"
+    "vfmaddsub231pd %%ymm0, %%ymm6 , %%ymm10      \n\t"
+    "vmovups         64(%1), %%ymm6               \n\t"
+    "vfmaddsub231pd %%ymm0, %%ymm7 , %%ymm11      \n\t"
+    "vmovups         96(%1), %%ymm7               \n\t"
 
     "vmovups        %%ymm8 , -128(%1)             \n\t"
     "vpermilpd      $0x05 , %%ymm4, %%ymm12       \n\t"
@@ -89,30 +90,65 @@ static void zscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x)
     "addq           $128 ,%1                      \n\t"
     "subq           $8 , %0                       \n\t"
-    "jnz            1b                            \n\t"
+    "cmpq           $8 , %0                       \n\t"
+    "jae            1b                            \n\t"
 
     "2:                                           \n\t"
 
-    "vmulpd         %%ymm0, %%ymm4 , %%ymm8       \n\t"  // da_r*x0 , da_r *x1
-    "vmulpd         %%ymm0, %%ymm5 , %%ymm9       \n\t"
-    "vmulpd         %%ymm0, %%ymm6 , %%ymm10      \n\t"
-    "vmulpd         %%ymm0, %%ymm7 , %%ymm11      \n\t"
+    "vmulpd         %%ymm1, %%ymm12, %%ymm8       \n\t"  // da_i*x1 , da_i *x0
+    "vmulpd         %%ymm1, %%ymm13, %%ymm9       \n\t"
+    "vmulpd         %%ymm1, %%ymm14, %%ymm10      \n\t"
+    "vmulpd         %%ymm1, %%ymm15, %%ymm11      \n\t"
 
-    "vmulpd         %%ymm1, %%ymm12, %%ymm12      \n\t"  // da_i*x1 , da_i *x0
-    "vaddsubpd      %%ymm12 , %%ymm8 , %%ymm8     \n\t"
-    "vmulpd         %%ymm1, %%ymm13, %%ymm13      \n\t"
-    "vaddsubpd      %%ymm13 , %%ymm9 , %%ymm9     \n\t"
-    "vmulpd         %%ymm1, %%ymm14, %%ymm14      \n\t"
-    "vaddsubpd      %%ymm14 , %%ymm10, %%ymm10    \n\t"
-    "vmulpd         %%ymm1, %%ymm15, %%ymm15      \n\t"
-    "vaddsubpd      %%ymm15 , %%ymm11, %%ymm11    \n\t"
+    "vfmaddsub231pd %%ymm0, %%ymm4 , %%ymm8       \n\t"  // da_r*x0 , da_r *x1
+    "vfmaddsub231pd %%ymm0, %%ymm5 , %%ymm9       \n\t"
+    "vfmaddsub231pd %%ymm0, %%ymm6 , %%ymm10      \n\t"
+    "vfmaddsub231pd %%ymm0, %%ymm7 , %%ymm11      \n\t"
 
     "vmovups        %%ymm8 , -128(%1)             \n\t"
     "vmovups        %%ymm9 , -96(%1)              \n\t"
     "vmovups        %%ymm10, -64(%1)              \n\t"
     "vmovups        %%ymm11, -32(%1)              \n\t"
 
+    "testq          $7, %0                        \n\t"
+    "jz             6f                            \n\t"
+
+    "3:                                           \n\t"
+    "testq          $4, %0                        \n\t"
+    "jz             4f                            \n\t"
+    "vmovups          0(%1), %%ymm4               \n\t"
+    "vmovups         32(%1), %%ymm5               \n\t"
+    "vpermilpd      $0x05 , %%ymm4 , %%ymm12      \n\t"
+    "vpermilpd      $0x05 , %%ymm5 , %%ymm13      \n\t"
+    "vmulpd         %%ymm1, %%ymm12, %%ymm8       \n\t"  // da_i*x1 , da_i *x0
+    "vmulpd         %%ymm1, %%ymm13, %%ymm9       \n\t"
+    "vfmaddsub231pd %%ymm0, %%ymm4 , %%ymm8       \n\t"  // da_r*x0 , da_r *x1
+    "vfmaddsub231pd %%ymm0, %%ymm5 , %%ymm9       \n\t"
+    "vmovups        %%ymm8,   0(%1)               \n\t"
+    "vmovups        %%ymm9,  32(%1)               \n\t"
+    "addq           $64, %1                       \n\t"
+
+    "4:                                           \n\t"
+    "testq          $2, %0                        \n\t"
+    "jz             5f                            \n\t"
+    "vmovups          0(%1), %%ymm4               \n\t"
+    "vpermilpd      $0x05 , %%ymm4 , %%ymm12      \n\t"
+    "vmulpd         %%ymm1, %%ymm12, %%ymm8       \n\t"  // da_i*x1 , da_i *x0
+    "vfmaddsub231pd %%ymm0, %%ymm4 , %%ymm8       \n\t"  // da_r*x0 , da_r *x1
+    "vmovups        %%ymm8,   0(%1)               \n\t"
+    "addq           $32, %1                       \n\t"
+
+    "5:                                           \n\t"
+    "testq          $1, %0                        \n\t"
+    "jz             6f                            \n\t"
+    "vmovups          0(%1), %%xmm4               \n\t"
+    "vpermilpd      $0x01 , %%xmm4 , %%xmm12      \n\t"
+    "vmulpd         %%xmm1, %%xmm12, %%xmm8       \n\t"  // da_i*x1 , da_i *x0
+    "vfmaddsub231pd %%xmm0, %%xmm4 , %%xmm8       \n\t"  // da_r*x0 , da_r *x1
+    "vmovups        %%xmm8,   0(%1)               \n\t"
+
+    "6:                                           \n\t"
 
     "vzeroupper                                   \n\t"
 
     :