Use FMA for cscal and zscal Haswell microkernels.

This patch has two benefits:
1. Using vfmaddsub231p[sd] instead of vaddsubp[sd] eliminates a vmulp[sd]
instruction, giving a ~10% speedup: from ~33 to ~36 Gflops for cscal and from
~17 to ~19 Gflops for zscal with 4096 elements on my Kaby Lake laptop, measured
with e.g. OPENBLAS_LOOPS=10000 benchmark/cscal.goto 4096 4096. (A scalar sketch
of the old and new per-element arithmetic follows this list.)

2. Using it for both the main loop and the tail ensures that the same FMA
instruction is used for every element. That is not guaranteed at present: the
tail loop is implemented in C, so whether it uses FMA depends on the compiler
and its flags. This matters for some LAPACK eigenvalue testcases that rely on
bitwise identical results regardless of how many loop iterations are used.
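
For reference, here is a scalar model of what each complex element sees. This is an illustrative sketch (plain C, not part of the patch), with fma() standing in for the fused add/sub lanes of vfmaddsub231p[sd]; the function names are invented for the example:

#include <math.h>

/* x = (xr, xi) scaled by alpha = (ar, ai); the even lane subtracts, the odd lane adds. */

/* Old kernels: separate multiply, then vaddsubp[sd]. */
static void scal_mul_addsub(double *xr, double *xi, double ar, double ai)
{
	double sr = ai * (*xi), si = ai * (*xr); /* vmulp[sd] on the swapped vector    */
	double tr = ar * (*xr), ti = ar * (*xi); /* the vmulp[sd] that the patch drops */
	*xr = tr - sr;                           /* vaddsubp[sd], even lane            */
	*xi = ti + si;                           /* vaddsubp[sd], odd lane             */
}

/* New kernels: the da_r multiply is fused into the add/sub. */
static void scal_fmaddsub(double *xr, double *xi, double ar, double ai)
{
	double sr = ai * (*xi), si = ai * (*xr); /* the one remaining vmulp[sd]        */
	*xr = fma(ar, *xr, -sr);                 /* vfmaddsub231p[sd], even lane       */
	*xi = fma(ar, *xi,  si);                 /* vfmaddsub231p[sd], odd lane        */
}

Because fma() rounds once where the multiply/addsub pair rounds twice, the two variants can differ in the last bit, which is why point 2 keeps a single FMA-based code path for every element.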
Bart Oldeman 2022-10-21 13:52:34 -04:00
parent 8c10f0abba
commit c956271c2e
4 changed files with 168 additions and 64 deletions


@@ -337,6 +337,15 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
 	}
+#ifdef HAVE_KERNEL
+	if ( da_r != 0.0 && da_i != 0.0 ) {
+		alpha[0] = da_r;
+		alpha[1] = da_i;
+		cscal_kernel(n , alpha, x);
+		return(0);
+	}
+#endif
 	BLASLONG n1 = n & -16;
 	if ( n1 > 0 )
 	{
@@ -352,8 +361,10 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
 		else
 			if ( da_i == 0 )
 				cscal_kernel_16_zero_i(n1 , alpha , x);
+#ifndef HAVE_KERNEL
 			else
 				cscal_kernel_16(n1 , alpha , x);
+#endif
 		i = n1 << 1;
 		j = n1;
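
When HAVE_KERNEL is not defined, the elements beyond n1 fall through to the driver's plain C tail loop, roughly the sketch below (shown for context; whether the compiler contracts these expressions into FMAs depends on its flags, which is exactly the reproducibility issue point 2 above removes):

	/* C tail for the case where both da_r and da_i are nonzero;
	   x holds interleaved (real, imag) pairs and i == 2*j.       */
	while ( j < n )
	{
		FLOAT temp0 = da_r * x[i]   - da_i * x[i+1];
		x[i+1]      = da_r * x[i+1] + da_i * x[i];
		x[i]        = temp0;
		i += 2;
		j++;
	}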


@@ -26,11 +26,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
+#define HAVE_KERNEL 1
 #define HAVE_KERNEL_16 1
-static void cscal_kernel_16( BLASLONG n, FLOAT *alpha, FLOAT *x) __attribute__ ((noinline));
+static void cscal_kernel( BLASLONG n, FLOAT *alpha, FLOAT *x) __attribute__ ((noinline));
-static void cscal_kernel_16( BLASLONG n, FLOAT *alpha, FLOAT *x)
+static void cscal_kernel( BLASLONG n, FLOAT *alpha, FLOAT *x)
 {
@@ -39,6 +40,9 @@ static void cscal_kernel_16( BLASLONG n, FLOAT *alpha, FLOAT *x)
 	"vbroadcastss (%2), %%ymm0 \n\t" // da_r
 	"vbroadcastss 4(%2), %%ymm1 \n\t" // da_i
+	"cmpq $16, %0 \n\t"
+	"jb 3f \n\t"
 	"addq $128, %1 \n\t"
 	"vmovups -128(%1), %%ymm4 \n\t"
@@ -52,7 +56,8 @@ static void cscal_kernel_16( BLASLONG n, FLOAT *alpha, FLOAT *x)
 	"vpermilps $0xb1 , %%ymm7, %%ymm15 \n\t"
 	"subq $16, %0 \n\t"
-	"jz 2f \n\t"
+	"cmpq $16, %0 \n\t"
+	"jb 2f \n\t"
 	".p2align 4 \n\t"
 	"1: \n\t"
@@ -60,23 +65,19 @@ static void cscal_kernel_16( BLASLONG n, FLOAT *alpha, FLOAT *x)
 	//"prefetcht0 128(%1) \n\t"
 	// ".align 2 \n\t"
-	"vmulps %%ymm0, %%ymm4 , %%ymm8 \n\t" // da_r*x0 , da_r *x1
-	"vmovups 0(%1), %%ymm4 \n\t"
-	"vmulps %%ymm0, %%ymm5 , %%ymm9 \n\t"
-	"vmovups 32(%1), %%ymm5 \n\t"
-	"vmulps %%ymm0, %%ymm6 , %%ymm10 \n\t"
-	"vmovups 64(%1), %%ymm6 \n\t"
-	"vmulps %%ymm0, %%ymm7 , %%ymm11 \n\t"
-	"vmovups 96(%1), %%ymm7 \n\t"
+	"vmulps %%ymm1, %%ymm12, %%ymm8 \n\t" // da_i*x1 , da_i *x0
+	"vmulps %%ymm1, %%ymm13, %%ymm9 \n\t"
+	"vmulps %%ymm1, %%ymm14, %%ymm10 \n\t"
+	"vmulps %%ymm1, %%ymm15, %%ymm11 \n\t"
-	"vmulps %%ymm1, %%ymm12, %%ymm12 \n\t" // da_i*x1 , da_i *x0
-	"vaddsubps %%ymm12 , %%ymm8 , %%ymm8 \n\t"
-	"vmulps %%ymm1, %%ymm13, %%ymm13 \n\t"
-	"vaddsubps %%ymm13 , %%ymm9 , %%ymm9 \n\t"
-	"vmulps %%ymm1, %%ymm14, %%ymm14 \n\t"
-	"vaddsubps %%ymm14 , %%ymm10, %%ymm10 \n\t"
-	"vmulps %%ymm1, %%ymm15, %%ymm15 \n\t"
-	"vaddsubps %%ymm15 , %%ymm11, %%ymm11 \n\t"
+	"vfmaddsub231ps %%ymm0, %%ymm4 , %%ymm8 \n\t" // da_r*x0 , da_r *x1
+	"vmovups 0(%1), %%ymm4 \n\t"
+	"vfmaddsub231ps %%ymm0, %%ymm5 , %%ymm9 \n\t"
+	"vmovups 32(%1), %%ymm5 \n\t"
+	"vfmaddsub231ps %%ymm0, %%ymm6 , %%ymm10 \n\t"
+	"vmovups 64(%1), %%ymm6 \n\t"
+	"vfmaddsub231ps %%ymm0, %%ymm7 , %%ymm11 \n\t"
+	"vmovups 96(%1), %%ymm7 \n\t"
 	"vmovups %%ymm8 , -128(%1) \n\t"
 	"vpermilps $0xb1 , %%ymm4, %%ymm12 \n\t"
@@ -89,30 +90,75 @@ static void cscal_kernel_16( BLASLONG n, FLOAT *alpha, FLOAT *x)
 	"addq $128 ,%1 \n\t"
 	"subq $16, %0 \n\t"
-	"jnz 1b \n\t"
+	"cmpq $16, %0 \n\t"
+	"jae 1b \n\t"
 	"2: \n\t"
-	"vmulps %%ymm0, %%ymm4 , %%ymm8 \n\t" // da_r*x0 , da_r *x1
-	"vmulps %%ymm0, %%ymm5 , %%ymm9 \n\t"
-	"vmulps %%ymm0, %%ymm6 , %%ymm10 \n\t"
-	"vmulps %%ymm0, %%ymm7 , %%ymm11 \n\t"
+	"vmulps %%ymm1, %%ymm12, %%ymm8 \n\t" // da_i*x1 , da_i *x0
+	"vmulps %%ymm1, %%ymm13, %%ymm9 \n\t"
+	"vmulps %%ymm1, %%ymm14, %%ymm10 \n\t"
+	"vmulps %%ymm1, %%ymm15, %%ymm11 \n\t"
-	"vmulps %%ymm1, %%ymm12, %%ymm12 \n\t" // da_i*x1 , da_i *x0
-	"vaddsubps %%ymm12 , %%ymm8 , %%ymm8 \n\t"
-	"vmulps %%ymm1, %%ymm13, %%ymm13 \n\t"
-	"vaddsubps %%ymm13 , %%ymm9 , %%ymm9 \n\t"
-	"vmulps %%ymm1, %%ymm14, %%ymm14 \n\t"
-	"vaddsubps %%ymm14 , %%ymm10, %%ymm10 \n\t"
-	"vmulps %%ymm1, %%ymm15, %%ymm15 \n\t"
-	"vaddsubps %%ymm15 , %%ymm11, %%ymm11 \n\t"
+	"vfmaddsub231ps %%ymm0, %%ymm4 , %%ymm8 \n\t" // da_r*x0 , da_r *x1
+	"vfmaddsub231ps %%ymm0, %%ymm5 , %%ymm9 \n\t"
+	"vfmaddsub231ps %%ymm0, %%ymm6 , %%ymm10 \n\t"
+	"vfmaddsub231ps %%ymm0, %%ymm7 , %%ymm11 \n\t"
 	"vmovups %%ymm8 , -128(%1) \n\t"
 	"vmovups %%ymm9 , -96(%1) \n\t"
 	"vmovups %%ymm10, -64(%1) \n\t"
 	"vmovups %%ymm11, -32(%1) \n\t"
+	"testq $15, %0 \n\t"
+	"jz 7f \n\t"
+	"3: \n\t"
+	"testq $8, %0 \n\t"
+	"jz 4f \n\t"
+	"vmovups 0(%1), %%ymm4 \n\t"
+	"vmovups 32(%1), %%ymm5 \n\t"
+	"vpermilps $0xb1 , %%ymm4 , %%ymm12 \n\t"
+	"vpermilps $0xb1 , %%ymm5 , %%ymm13 \n\t"
+	"vmulps %%ymm1, %%ymm12, %%ymm8 \n\t" // da_i*x1 , da_i *x0
+	"vmulps %%ymm1, %%ymm13, %%ymm9 \n\t"
+	"vfmaddsub231ps %%ymm0, %%ymm4 , %%ymm8 \n\t" // da_r*x0 , da_r *x1
+	"vfmaddsub231ps %%ymm0, %%ymm5 , %%ymm9 \n\t"
+	"vmovups %%ymm8, 0(%1) \n\t"
+	"vmovups %%ymm9, 32(%1) \n\t"
+	"addq $64, %1 \n\t"
+	"4: \n\t"
+	"testq $4, %0 \n\t"
+	"jz 5f \n\t"
+	"vmovups 0(%1), %%ymm4 \n\t"
+	"vpermilps $0xb1 , %%ymm4 , %%ymm12 \n\t"
+	"vmulps %%ymm1, %%ymm12, %%ymm8 \n\t" // da_i*x1 , da_i *x0
+	"vfmaddsub231ps %%ymm0, %%ymm4 , %%ymm8 \n\t" // da_r*x0 , da_r *x1
+	"vmovups %%ymm8, 0(%1) \n\t"
+	"addq $32, %1 \n\t"
+	"5: \n\t"
+	"testq $2, %0 \n\t"
+	"jz 6f \n\t"
+	"vmovups 0(%1), %%xmm4 \n\t"
+	"vpermilps $0xb1 , %%xmm4 , %%xmm12 \n\t"
+	"vmulps %%xmm1, %%xmm12, %%xmm8 \n\t" // da_i*x1 , da_i *x0
+	"vfmaddsub231ps %%xmm0, %%xmm4 , %%xmm8 \n\t" // da_r*x0 , da_r *x1
+	"vmovups %%xmm8, 0(%1) \n\t"
+	"addq $16, %1 \n\t"
+	"6: \n\t"
+	"testq $1, %0 \n\t"
+	"jz 7f \n\t"
+	"vmovsd 0(%1), %%xmm4 \n\t"
+	"vpermilps $0x1 , %%xmm4 , %%xmm12 \n\t"
+	"vmulps %%xmm1, %%xmm12, %%xmm8 \n\t" // da_i*x1 , da_i *x0
+	"vfmaddsub231ps %%xmm0, %%xmm4 , %%xmm8 \n\t" // da_r*x0 , da_r *x1
+	"vmovsd %%xmm8, 0(%1) \n\t"
+	"7: \n\t"
 	"vzeroupper \n\t"
 	:
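
The new tail code (labels 3: through 7:) peels the remainder in chunks of 8, 4, 2 and 1 complex elements so that every element still goes through vfmaddsub231ps. A C model of that control flow follows; it is illustrative only, and cscal1/cscal_tail are invented names, not part of the patch:

#include <math.h>

/* One complex float element, mirroring one lane pair of vfmaddsub231ps. */
static void cscal1(float *x, float ar, float ai)
{
	float s = ai * x[1];           /* da_i * imag              */
	float t = ai * x[0];           /* da_i * real              */
	x[0] = fmaf(ar, x[0], -s);     /* even lane: da_r*real - s */
	x[1] = fmaf(ar, x[1],  t);     /* odd  lane: da_r*imag + t */
}

/* Remaining element count after the 16-element main loop (0 <= n < 16). */
static void cscal_tail(long n, float *x, float ar, float ai)
{
	long k;
	if (n & 8) { for (k = 0; k < 8; k++) cscal1(x + 2*k, ar, ai); x += 16; } /* label 3: two ymm    */
	if (n & 4) { for (k = 0; k < 4; k++) cscal1(x + 2*k, ar, ai); x += 8;  } /* label 4: one ymm    */
	if (n & 2) { cscal1(x, ar, ai); cscal1(x + 2, ar, ai); x += 4; }         /* label 5: one xmm    */
	if (n & 1) { cscal1(x, ar, ai); }                                        /* label 6: one vmovsd */
}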


@@ -335,6 +335,15 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
 	}
+#ifdef HAVE_KERNEL
+	if ( da_r != 0.0 && da_i != 0.0 ) {
+		alpha[0] = da_r;
+		alpha[1] = da_i;
+		zscal_kernel(n , alpha, x);
+		return(0);
+	}
+#endif
 	BLASLONG n1 = n & -8;
 	if ( n1 > 0 )
 	{
@@ -350,8 +359,10 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
 		else
 			if ( da_i == 0 )
 				zscal_kernel_8_zero_i(n1 , alpha , x);
+#ifndef HAVE_KERNEL
 			else
 				zscal_kernel_8(n1 , alpha , x);
+#endif
 		i = n1 << 1;
 		j = n1;


@@ -26,11 +26,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
+#define HAVE_KERNEL 1
 #define HAVE_KERNEL_8 1
-static void zscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x) __attribute__ ((noinline));
+static void zscal_kernel( BLASLONG n, FLOAT *alpha, FLOAT *x) __attribute__ ((noinline));
-static void zscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x)
+static void zscal_kernel( BLASLONG n, FLOAT *alpha, FLOAT *x)
 {
@@ -39,6 +40,9 @@ static void zscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x)
 	"vbroadcastsd (%2), %%ymm0 \n\t" // da_r
 	"vbroadcastsd 8(%2), %%ymm1 \n\t" // da_i
+	"cmpq $8 , %0 \n\t"
+	"jb 3f \n\t"
 	"addq $128, %1 \n\t"
 	"vmovups -128(%1), %%ymm4 \n\t"
@@ -52,7 +56,8 @@ static void zscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x)
 	"vpermilpd $0x05 , %%ymm7, %%ymm15 \n\t"
 	"subq $8 , %0 \n\t"
-	"jz 2f \n\t"
+	"cmpq $8 , %0 \n\t"
+	"jb 2f \n\t"
 	".p2align 4 \n\t"
 	"1: \n\t"
@@ -60,23 +65,19 @@ static void zscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x)
 	//"prefetcht0 128(%1) \n\t"
 	// ".align 2 \n\t"
-	"vmulpd %%ymm0, %%ymm4 , %%ymm8 \n\t" // da_r*x0 , da_r *x1
-	"vmovups 0(%1), %%ymm4 \n\t"
-	"vmulpd %%ymm0, %%ymm5 , %%ymm9 \n\t"
-	"vmovups 32(%1), %%ymm5 \n\t"
-	"vmulpd %%ymm0, %%ymm6 , %%ymm10 \n\t"
-	"vmovups 64(%1), %%ymm6 \n\t"
-	"vmulpd %%ymm0, %%ymm7 , %%ymm11 \n\t"
-	"vmovups 96(%1), %%ymm7 \n\t"
+	"vmulpd %%ymm1, %%ymm12, %%ymm8 \n\t" // da_i*x1 , da_i *x0
+	"vmulpd %%ymm1, %%ymm13, %%ymm9 \n\t"
+	"vmulpd %%ymm1, %%ymm14, %%ymm10 \n\t"
+	"vmulpd %%ymm1, %%ymm15, %%ymm11 \n\t"
-	"vmulpd %%ymm1, %%ymm12, %%ymm12 \n\t" // da_i*x1 , da_i *x0
-	"vaddsubpd %%ymm12 , %%ymm8 , %%ymm8 \n\t"
-	"vmulpd %%ymm1, %%ymm13, %%ymm13 \n\t"
-	"vaddsubpd %%ymm13 , %%ymm9 , %%ymm9 \n\t"
-	"vmulpd %%ymm1, %%ymm14, %%ymm14 \n\t"
-	"vaddsubpd %%ymm14 , %%ymm10, %%ymm10 \n\t"
-	"vmulpd %%ymm1, %%ymm15, %%ymm15 \n\t"
-	"vaddsubpd %%ymm15 , %%ymm11, %%ymm11 \n\t"
+	"vfmaddsub231pd %%ymm0, %%ymm4 , %%ymm8 \n\t" // da_r*x0 , da_r *x1
+	"vmovups 0(%1), %%ymm4 \n\t"
+	"vfmaddsub231pd %%ymm0, %%ymm5 , %%ymm9 \n\t"
+	"vmovups 32(%1), %%ymm5 \n\t"
+	"vfmaddsub231pd %%ymm0, %%ymm6 , %%ymm10 \n\t"
+	"vmovups 64(%1), %%ymm6 \n\t"
+	"vfmaddsub231pd %%ymm0, %%ymm7 , %%ymm11 \n\t"
+	"vmovups 96(%1), %%ymm7 \n\t"
 	"vmovups %%ymm8 , -128(%1) \n\t"
 	"vpermilpd $0x05 , %%ymm4, %%ymm12 \n\t"
@@ -89,30 +90,65 @@ static void zscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x)
 	"addq $128 ,%1 \n\t"
 	"subq $8 , %0 \n\t"
-	"jnz 1b \n\t"
+	"cmpq $8 , %0 \n\t"
+	"jae 1b \n\t"
 	"2: \n\t"
-	"vmulpd %%ymm0, %%ymm4 , %%ymm8 \n\t" // da_r*x0 , da_r *x1
-	"vmulpd %%ymm0, %%ymm5 , %%ymm9 \n\t"
-	"vmulpd %%ymm0, %%ymm6 , %%ymm10 \n\t"
-	"vmulpd %%ymm0, %%ymm7 , %%ymm11 \n\t"
+	"vmulpd %%ymm1, %%ymm12, %%ymm8 \n\t" // da_i*x1 , da_i *x0
+	"vmulpd %%ymm1, %%ymm13, %%ymm9 \n\t"
+	"vmulpd %%ymm1, %%ymm14, %%ymm10 \n\t"
+	"vmulpd %%ymm1, %%ymm15, %%ymm11 \n\t"
-	"vmulpd %%ymm1, %%ymm12, %%ymm12 \n\t" // da_i*x1 , da_i *x0
-	"vaddsubpd %%ymm12 , %%ymm8 , %%ymm8 \n\t"
-	"vmulpd %%ymm1, %%ymm13, %%ymm13 \n\t"
-	"vaddsubpd %%ymm13 , %%ymm9 , %%ymm9 \n\t"
-	"vmulpd %%ymm1, %%ymm14, %%ymm14 \n\t"
-	"vaddsubpd %%ymm14 , %%ymm10, %%ymm10 \n\t"
-	"vmulpd %%ymm1, %%ymm15, %%ymm15 \n\t"
-	"vaddsubpd %%ymm15 , %%ymm11, %%ymm11 \n\t"
+	"vfmaddsub231pd %%ymm0, %%ymm4 , %%ymm8 \n\t" // da_r*x0 , da_r *x1
+	"vfmaddsub231pd %%ymm0, %%ymm5 , %%ymm9 \n\t"
+	"vfmaddsub231pd %%ymm0, %%ymm6 , %%ymm10 \n\t"
+	"vfmaddsub231pd %%ymm0, %%ymm7 , %%ymm11 \n\t"
 	"vmovups %%ymm8 , -128(%1) \n\t"
 	"vmovups %%ymm9 , -96(%1) \n\t"
 	"vmovups %%ymm10, -64(%1) \n\t"
 	"vmovups %%ymm11, -32(%1) \n\t"
+	"testq $7, %0 \n\t"
+	"jz 6f \n\t"
+	"3: \n\t"
+	"testq $4, %0 \n\t"
+	"jz 4f \n\t"
+	"vmovups 0(%1), %%ymm4 \n\t"
+	"vmovups 32(%1), %%ymm5 \n\t"
+	"vpermilpd $0x05 , %%ymm4 , %%ymm12 \n\t"
+	"vpermilpd $0x05 , %%ymm5 , %%ymm13 \n\t"
+	"vmulpd %%ymm1, %%ymm12, %%ymm8 \n\t" // da_i*x1 , da_i *x0
+	"vmulpd %%ymm1, %%ymm13, %%ymm9 \n\t"
+	"vfmaddsub231pd %%ymm0, %%ymm4 , %%ymm8 \n\t" // da_r*x0 , da_r *x1
+	"vfmaddsub231pd %%ymm0, %%ymm5 , %%ymm9 \n\t"
+	"vmovups %%ymm8, 0(%1) \n\t"
+	"vmovups %%ymm9, 32(%1) \n\t"
+	"addq $64, %1 \n\t"
+	"4: \n\t"
+	"testq $2, %0 \n\t"
+	"jz 5f \n\t"
+	"vmovups 0(%1), %%ymm4 \n\t"
+	"vpermilpd $0x05 , %%ymm4 , %%ymm12 \n\t"
+	"vmulpd %%ymm1, %%ymm12, %%ymm8 \n\t" // da_i*x1 , da_i *x0
+	"vfmaddsub231pd %%ymm0, %%ymm4 , %%ymm8 \n\t" // da_r*x0 , da_r *x1
+	"vmovups %%ymm8, 0(%1) \n\t"
+	"addq $32, %1 \n\t"
+	"5: \n\t"
+	"testq $1, %0 \n\t"
+	"jz 6f \n\t"
+	"vmovups 0(%1), %%xmm4 \n\t"
+	"vpermilpd $0x01 , %%xmm4 , %%xmm12 \n\t"
+	"vmulpd %%xmm1, %%xmm12, %%xmm8 \n\t" // da_i*x1 , da_i *x0
+	"vfmaddsub231pd %%xmm0, %%xmm4 , %%xmm8 \n\t" // da_r*x0 , da_r *x1
+	"vmovups %%xmm8, 0(%1) \n\t"
+	"6: \n\t"
 	"vzeroupper \n\t"
 	:
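
A quick way to exercise the new zscal path and the reproducibility property from point 2 is sketched below. It assumes an OpenBLAS build with the CBLAS header and library available; since each complex element is scaled independently, the first elements of a long call should be bitwise identical to a short call on the same data once a single FMA-based kernel handles every length:

#include <stdio.h>
#include <string.h>
#include <cblas.h>

int main(void)
{
	double alpha[2] = { 0.3, 0.7 };          /* da_r, da_i, both nonzero   */
	double x[2 * 100], y[2 * 100];
	int i;

	for (i = 0; i < 2 * 100; i++)
		x[i] = y[i] = 1.0 / (i + 3.0);

	cblas_zscal(100, alpha, x, 1);           /* long call: main loop + tail */
	cblas_zscal(3,   alpha, y, 1);           /* short call: tail path only  */

	/* With one FMA-based kernel for every length, the first three complex
	   elements should be bitwise identical between the two calls.          */
	printf("first 3 elements %s\n",
	       memcmp(x, y, 6 * sizeof(double)) == 0 ? "match" : "differ");
	return 0;
}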