commit b30b82ce46
Arjan van de Ven, 2018-08-09 14:30:05 +00:00 (committed via GitHub)
7 changed files with 403 additions and 565 deletions

View File: daxpy kernel

@@ -25,54 +25,49 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/

The noinline AVX2 inline-assembly implementation of daxpy_kernel_8 (a vfmadd231pd loop over four ymm registers with explicit clobber lists) is removed and replaced with the intrinsics version below, which also adds an AVX-512 path:

#ifndef __AVX512CD__
#pragma GCC target("avx2,fma")
#endif

#ifdef __AVX2__

#include <immintrin.h>

#define HAVE_KERNEL_8 1

static void daxpy_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *alpha)
{
    BLASLONG i = 0;
    __m256d __alpha;

    __alpha = _mm256_broadcastsd_pd(_mm_load_sd(alpha));

#ifdef __AVX512CD__
    BLASLONG n32;
    __m512d __alpha5;
    __alpha5 = _mm512_broadcastsd_pd(_mm_load_sd(alpha));

    n32 = n & ~31;

    for (; i < n32; i += 32) {
        _mm512_storeu_pd(&y[i +  0], _mm512_loadu_pd(&y[i +  0]) + __alpha5 * _mm512_loadu_pd(&x[i +  0]));
        _mm512_storeu_pd(&y[i +  8], _mm512_loadu_pd(&y[i +  8]) + __alpha5 * _mm512_loadu_pd(&x[i +  8]));
        _mm512_storeu_pd(&y[i + 16], _mm512_loadu_pd(&y[i + 16]) + __alpha5 * _mm512_loadu_pd(&x[i + 16]));
        _mm512_storeu_pd(&y[i + 24], _mm512_loadu_pd(&y[i + 24]) + __alpha5 * _mm512_loadu_pd(&x[i + 24]));
    }
#endif

    for (; i < n; i += 16) {
        _mm256_storeu_pd(&y[i +  0], _mm256_loadu_pd(&y[i +  0]) + __alpha * _mm256_loadu_pd(&x[i +  0]));
        _mm256_storeu_pd(&y[i +  4], _mm256_loadu_pd(&y[i +  4]) + __alpha * _mm256_loadu_pd(&x[i +  4]));
        _mm256_storeu_pd(&y[i +  8], _mm256_loadu_pd(&y[i +  8]) + __alpha * _mm256_loadu_pd(&x[i +  8]));
        _mm256_storeu_pd(&y[i + 12], _mm256_loadu_pd(&y[i + 12]) + __alpha * _mm256_loadu_pd(&x[i + 12]));
    }
}
#endif
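Note on the loop bounds: the AVX2 loop advances 16 doubles per iteration (the AVX-512 loop 32, with n32 = n & ~31 handling the split internally), so the kernel only handles an n that the caller has already rounded down to the unroll width, exactly as the removed assembly did (it counted n down by 16 per iteration). A minimal sketch of the scalar semantics and of the tail handling left to the caller; daxpy_tail is a hypothetical helper, not part of this commit:

    /* Reference semantics of daxpy_kernel_8: y[i] += alpha * x[i].
     * Hypothetical tail helper for the elements the vector kernel does not cover. */
    static void daxpy_tail(long n, long n_done, const double *x, double alpha, double *y)
    {
        long i;
        for (i = n_done; i < n; i++)
            y[i] += alpha * x[i];
    }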

View File: ddot kernel

@@ -25,71 +25,75 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/

The inline-assembly ddot_kernel_8 (noinline, four ymm accumulators reduced via vextractf128/vhaddpd) is removed; the intrinsics replacement below adds an AVX-512 path and folds its 512-bit partial sums back into the AVX2 accumulators:

/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */
#ifndef __AVX512CD__
#pragma GCC target("avx2,fma")
#endif

#ifdef __AVX2__

#define HAVE_KERNEL_8 1

#include <immintrin.h>

static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
{
    int i = 0;
    __m256d accum_0, accum_1, accum_2, accum_3;

    accum_0 = _mm256_setzero_pd();
    accum_1 = _mm256_setzero_pd();
    accum_2 = _mm256_setzero_pd();
    accum_3 = _mm256_setzero_pd();

#ifdef __AVX512CD__
    __m512d accum_05, accum_15, accum_25, accum_35;
    int n32;
    n32 = n & (~31);

    accum_05 = _mm512_setzero_pd();
    accum_15 = _mm512_setzero_pd();
    accum_25 = _mm512_setzero_pd();
    accum_35 = _mm512_setzero_pd();

    for (; i < n32; i += 32) {
        accum_05 += _mm512_loadu_pd(&x[i +  0]) * _mm512_loadu_pd(&y[i +  0]);
        accum_15 += _mm512_loadu_pd(&x[i +  8]) * _mm512_loadu_pd(&y[i +  8]);
        accum_25 += _mm512_loadu_pd(&x[i + 16]) * _mm512_loadu_pd(&y[i + 16]);
        accum_35 += _mm512_loadu_pd(&x[i + 24]) * _mm512_loadu_pd(&y[i + 24]);
    }

    /*
     * fold the 512-bit wide accumulator vectors into 256-bit wide vectors so that
     * the AVX2 code below can continue using the intermediate results in its loop
     */
    accum_0 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_05, 0), _mm512_extractf64x4_pd(accum_05, 1));
    accum_1 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_15, 0), _mm512_extractf64x4_pd(accum_15, 1));
    accum_2 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_25, 0), _mm512_extractf64x4_pd(accum_25, 1));
    accum_3 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_35, 0), _mm512_extractf64x4_pd(accum_35, 1));
#endif

    for (; i < n; i += 16) {
        accum_0 += _mm256_loadu_pd(&x[i +  0]) * _mm256_loadu_pd(&y[i +  0]);
        accum_1 += _mm256_loadu_pd(&x[i +  4]) * _mm256_loadu_pd(&y[i +  4]);
        accum_2 += _mm256_loadu_pd(&x[i +  8]) * _mm256_loadu_pd(&y[i +  8]);
        accum_3 += _mm256_loadu_pd(&x[i + 12]) * _mm256_loadu_pd(&y[i + 12]);
    }

    /* we now have the partial sums of the dot product in the 4 accumulation vectors, time to consolidate */
    accum_0 = accum_0 + accum_1 + accum_2 + accum_3;

    __m128d half_accum0;

    /* add the upper half to the lower half of the 256-bit vector to get a 128-bit vector */
    half_accum0 = _mm_add_pd(_mm256_extractf128_pd(accum_0, 0), _mm256_extractf128_pd(accum_0, 1));

    /* in 128-bit land there is a hadd operation to do the rest of the element-wise sum in one go */
    half_accum0 = _mm_hadd_pd(half_accum0, half_accum0);

    *dot = half_accum0[0];
}
#endif
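The reduction at the end leans on _mm_hadd_pd, which for _mm_hadd_pd(a, a) produces a[0] + a[1] in both lanes, so the scalar result can be read from lane 0. A scalar sketch of the whole consolidation step, assuming the four 4-wide accumulators are laid out as a plain array (illustrative only, not from the commit):

    /* Scalar picture of ddot's final reduction: sum four 4-wide partial
     * accumulators into one dot value (extractf128 + add + hadd in the intrinsics). */
    static double reduce_dot_accumulators(const double acc[4][4])
    {
        double sum = 0.0;
        int v, lane;
        for (v = 0; v < 4; v++)
            for (lane = 0; lane < 4; lane++)
                sum += acc[v][lane];
        return sum;
    }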

View File: dgemv kernel

@@ -25,167 +25,104 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/

The inline-assembly dgemv_kernel_4x4 and dgemv_kernel_4x2 kernels (noinline, register-blocked ymm loops with software pipelining) are removed and replaced with the intrinsics versions below; the 4x4 kernel gains an AVX-512 path:

/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */
#ifndef __AVX512CD__
#pragma GCC target("avx2,fma")
#endif

#ifdef __AVX2__

#define HAVE_KERNEL_4x4 1

#include <immintrin.h>

static void dgemv_kernel_4x4( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT *alpha)
{
    int i = 0;
    __m256d x0, x1, x2, x3;
    __m256d __alpha;

    x0 = _mm256_broadcastsd_pd(_mm_load_sd(&x[0]));
    x1 = _mm256_broadcastsd_pd(_mm_load_sd(&x[1]));
    x2 = _mm256_broadcastsd_pd(_mm_load_sd(&x[2]));
    x3 = _mm256_broadcastsd_pd(_mm_load_sd(&x[3]));

    __alpha = _mm256_broadcastsd_pd(_mm_load_sd(alpha));

#ifdef __AVX512CD__
    int n5;
    __m512d x05, x15, x25, x35;
    __m512d __alpha5;
    n5 = n & ~7;

    x05 = _mm512_broadcastsd_pd(_mm_load_sd(&x[0]));
    x15 = _mm512_broadcastsd_pd(_mm_load_sd(&x[1]));
    x25 = _mm512_broadcastsd_pd(_mm_load_sd(&x[2]));
    x35 = _mm512_broadcastsd_pd(_mm_load_sd(&x[3]));

    __alpha5 = _mm512_broadcastsd_pd(_mm_load_sd(alpha));

    for (; i < n5; i += 8) {
        __m512d tempY;
        __m512d sum;

        sum = _mm512_loadu_pd(&ap[0][i]) * x05 +
              _mm512_loadu_pd(&ap[1][i]) * x15 +
              _mm512_loadu_pd(&ap[2][i]) * x25 +
              _mm512_loadu_pd(&ap[3][i]) * x35;

        tempY = _mm512_loadu_pd(&y[i]);
        tempY += sum * __alpha5;
        _mm512_storeu_pd(&y[i], tempY);
    }
#endif

    for (; i < n; i += 4) {
        __m256d tempY;
        __m256d sum;

        sum = _mm256_loadu_pd(&ap[0][i]) * x0 +
              _mm256_loadu_pd(&ap[1][i]) * x1 +
              _mm256_loadu_pd(&ap[2][i]) * x2 +
              _mm256_loadu_pd(&ap[3][i]) * x3;

        tempY = _mm256_loadu_pd(&y[i]);
        tempY += sum * __alpha;
        _mm256_storeu_pd(&y[i], tempY);
    }
}

#define HAVE_KERNEL_4x2

static void dgemv_kernel_4x2( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT *alpha)
{
    int i = 0;
    __m256d x0, x1;
    __m256d __alpha;

    x0 = _mm256_broadcastsd_pd(_mm_load_sd(&x[0]));
    x1 = _mm256_broadcastsd_pd(_mm_load_sd(&x[1]));

    __alpha = _mm256_broadcastsd_pd(_mm_load_sd(alpha));

    for (i = 0; i < n; i += 4) {
        __m256d tempY;
        __m256d sum;

        sum = _mm256_loadu_pd(&ap[0][i]) * x0 + _mm256_loadu_pd(&ap[1][i]) * x1;

        tempY = _mm256_loadu_pd(&y[i]);
        tempY += sum * __alpha;
        _mm256_storeu_pd(&y[i], tempY);
    }
}
#endif /* AVX2 */
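Each iteration of dgemv_kernel_4x4 computes four (or, in the 4x2 variant, two) column contributions for a block of rows and folds alpha in once at the end. The scalar equivalent, written per row (a sketch, not part of the commit):

    /* Scalar reference for dgemv_kernel_4x4: y += alpha * (A0*x0 + A1*x1 + A2*x2 + A3*x3). */
    static void dgemv_4x4_ref(long n, double **ap, const double *x, double *y, double alpha)
    {
        long i;
        for (i = 0; i < n; i++) {
            double sum = ap[0][i] * x[0] + ap[1][i] * x[1]
                       + ap[2][i] * x[2] + ap[3][i] * x[3];
            y[i] += alpha * sum;
        }
    }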

View File: dscal kernel

@@ -25,182 +25,55 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/

The inline-assembly dscal_kernel_8 and dscal_kernel_8_zero (software-pipelined xmm loops with separate remainder handling) are removed and replaced with the much shorter intrinsics versions below:

#ifndef __AVX512CD__
#pragma GCC target("avx2,fma")
#endif

#ifdef __AVX2__

#include <immintrin.h>

#define HAVE_KERNEL_8 1

static void dscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x)
{
    int i = 0;

#ifdef __AVX512CD__
    __m512d __alpha5 = _mm512_broadcastsd_pd(_mm_load_sd(alpha));
    for (; i < n; i += 8) {
        _mm512_storeu_pd(&x[i + 0], __alpha5 * _mm512_loadu_pd(&x[i + 0]));
    }
#else
    __m256d __alpha = _mm256_broadcastsd_pd(_mm_load_sd(alpha));
    for (; i < n; i += 8) {
        _mm256_storeu_pd(&x[i + 0], __alpha * _mm256_loadu_pd(&x[i + 0]));
        _mm256_storeu_pd(&x[i + 4], __alpha * _mm256_loadu_pd(&x[i + 4]));
    }
#endif
}

static void dscal_kernel_8_zero( BLASLONG n, FLOAT *alpha, FLOAT *x)
{
    int i = 0;

    /* question to self: why is this not just memset()? */
#ifdef __AVX512CD__
    __m512d zero = _mm512_setzero_pd();
    for (; i < n; i += 8) {
        _mm512_storeu_pd(&x[i], zero);
    }
#else
    __m256d zero = _mm256_setzero_pd();
    for (; i < n; i += 8) {
        _mm256_storeu_pd(&x[i + 0], zero);
        _mm256_storeu_pd(&x[i + 4], zero);
    }
#endif
}
#endif
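On the "why is this not just memset()?" comment: for IEEE-754 doubles the value +0.0 is the all-zero bit pattern, so the zero kernel is observably equivalent to a memset over the same range (alpha is ignored by this variant). A sketch of that equivalence, assuming n counts doubles:

    #include <string.h>

    /* Equivalent of dscal_kernel_8_zero under IEEE-754: +0.0 is the all-zero bit pattern. */
    static void dscal_zero_ref(long n, double *x)
    {
        memset(x, 0, (size_t)n * sizeof(double));
    }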

View File: dsymv kernel

@@ -25,105 +25,140 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/

The inline-assembly dsymv_kernel_4x4 is removed and replaced with the intrinsics version below; an AVX-512 block handles eight doubles per iteration and folds its accumulators back into the AVX2 loop:

/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */
#ifndef __AVX512CD__
#pragma GCC target("avx2,fma")
#endif

#ifdef __AVX2__

#include <immintrin.h>

#define HAVE_KERNEL_4x4 1

static void dsymv_kernel_4x4(BLASLONG from, BLASLONG to, FLOAT **a, FLOAT *x, FLOAT *y, FLOAT *temp1, FLOAT *temp2)
{
    __m256d accum_0, accum_1, accum_2, accum_3;
    __m256d temp1_0, temp1_1, temp1_2, temp1_3;

    /* the 256-bit wide accumulator vectors start out as zero */
    accum_0 = _mm256_setzero_pd();
    accum_1 = _mm256_setzero_pd();
    accum_2 = _mm256_setzero_pd();
    accum_3 = _mm256_setzero_pd();

    temp1_0 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[0]));
    temp1_1 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[1]));
    temp1_2 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[2]));
    temp1_3 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[3]));

#ifdef __AVX512CD__
    __m512d accum_05, accum_15, accum_25, accum_35;
    __m512d temp1_05, temp1_15, temp1_25, temp1_35;
    BLASLONG to2;
    int delta;

    /* the 512-bit wide accumulator vectors start out as zero */
    accum_05 = _mm512_setzero_pd();
    accum_15 = _mm512_setzero_pd();
    accum_25 = _mm512_setzero_pd();
    accum_35 = _mm512_setzero_pd();

    temp1_05 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[0]));
    temp1_15 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[1]));
    temp1_25 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[2]));
    temp1_35 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[3]));

    delta = (to - from) & ~7;
    to2 = from + delta;

    for (; from < to2; from += 8) {
        __m512d _x, _y;
        __m512d a0, a1, a2, a3;

        _y = _mm512_loadu_pd(&y[from]);
        _x = _mm512_loadu_pd(&x[from]);

        a0 = _mm512_loadu_pd(&a[0][from]);
        a1 = _mm512_loadu_pd(&a[1][from]);
        a2 = _mm512_loadu_pd(&a[2][from]);
        a3 = _mm512_loadu_pd(&a[3][from]);

        _y += temp1_05 * a0 + temp1_15 * a1 + temp1_25 * a2 + temp1_35 * a3;

        accum_05 += _x * a0;
        accum_15 += _x * a1;
        accum_25 += _x * a2;
        accum_35 += _x * a3;

        _mm512_storeu_pd(&y[from], _y);
    }

    /*
     * fold the 512-bit wide accumulator vectors into 256-bit wide vectors so that
     * the AVX2 code below can continue using the intermediate results in its loop
     */
    accum_0 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_05, 0), _mm512_extractf64x4_pd(accum_05, 1));
    accum_1 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_15, 0), _mm512_extractf64x4_pd(accum_15, 1));
    accum_2 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_25, 0), _mm512_extractf64x4_pd(accum_25, 1));
    accum_3 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_35, 0), _mm512_extractf64x4_pd(accum_35, 1));
#endif

    for (; from != to; from += 4) {
        __m256d _x, _y;
        __m256d a0, a1, a2, a3;

        _y = _mm256_loadu_pd(&y[from]);
        _x = _mm256_loadu_pd(&x[from]);

        /* load 4 rows of matrix data */
        a0 = _mm256_loadu_pd(&a[0][from]);
        a1 = _mm256_loadu_pd(&a[1][from]);
        a2 = _mm256_loadu_pd(&a[2][from]);
        a3 = _mm256_loadu_pd(&a[3][from]);

        _y += temp1_0 * a0 + temp1_1 * a1 + temp1_2 * a2 + temp1_3 * a3;

        accum_0 += _x * a0;
        accum_1 += _x * a1;
        accum_2 += _x * a2;
        accum_3 += _x * a3;

        _mm256_storeu_pd(&y[from], _y);
    }

    /*
     * we now have 4 accumulator vectors. Each vector needs to be summed up element-wise
     * and stored in the temp2 output array. There is no direct instruction for this in
     * 256-bit space, only in 128-bit space.
     */
    __m128d half_accum0, half_accum1, half_accum2, half_accum3;

    /* add the upper half to the lower half of each of the four 256-bit vectors to get four 128-bit vectors */
    half_accum0 = _mm_add_pd(_mm256_extractf128_pd(accum_0, 0), _mm256_extractf128_pd(accum_0, 1));
    half_accum1 = _mm_add_pd(_mm256_extractf128_pd(accum_1, 0), _mm256_extractf128_pd(accum_1, 1));
    half_accum2 = _mm_add_pd(_mm256_extractf128_pd(accum_2, 0), _mm256_extractf128_pd(accum_2, 1));
    half_accum3 = _mm_add_pd(_mm256_extractf128_pd(accum_3, 0), _mm256_extractf128_pd(accum_3, 1));

    /* in 128-bit land there is a hadd operation to do the rest of the element-wise sum in one go */
    half_accum0 = _mm_hadd_pd(half_accum0, half_accum0);
    half_accum1 = _mm_hadd_pd(half_accum1, half_accum1);
    half_accum2 = _mm_hadd_pd(half_accum2, half_accum2);
    half_accum3 = _mm_hadd_pd(half_accum3, half_accum3);

    /* and store the lowest double value from each of these vectors in the temp2 output */
    temp2[0] += half_accum0[0];
    temp2[1] += half_accum1[0];
    temp2[2] += half_accum2[0];
    temp2[3] += half_accum3[0];
}
#endif
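dsymv_kernel_4x4 fuses two traversals that a naive symv would do separately: while streaming through the rows it both updates y with the temp1-weighted columns and accumulates the x-weighted dot products destined for temp2. A scalar sketch of that fused loop (illustrative, not from the commit):

    /* Scalar form of dsymv_kernel_4x4: one pass updates y and accumulates temp2. */
    static void dsymv_4x4_ref(long from, long to, double **a, const double *x,
                              double *y, const double *temp1, double *temp2)
    {
        long i;
        int j;
        for (i = from; i < to; i++) {
            for (j = 0; j < 4; j++) {
                y[i]     += temp1[j] * a[j][i];   /* y += temp1 * a */
                temp2[j] += x[i] * a[j][i];       /* temp2 += x * a */
            }
        }
    }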

View File: saxpy kernel

@@ -25,54 +25,47 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/

As with daxpy, the noinline inline-assembly saxpy_kernel_16 is removed in favor of the intrinsics version below, with an AVX-512 path processing 64 floats per iteration:

#ifndef __AVX512CD__
#pragma GCC target("avx2,fma")
#endif

#ifdef __AVX2__

#define HAVE_KERNEL_16 1

#include <immintrin.h>

static void saxpy_kernel_16( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *alpha)
{
    BLASLONG i = 0;
    __m256 __alpha;

    __alpha = _mm256_broadcastss_ps(_mm_load_ss(alpha));

#ifdef __AVX512CD__
    BLASLONG n64;
    __m512 __alpha5;

    __alpha5 = _mm512_broadcastss_ps(_mm_load_ss(alpha));
    n64 = n & ~63;

    for (; i < n64; i += 64) {
        _mm512_storeu_ps(&y[i +  0], _mm512_loadu_ps(&y[i +  0]) + __alpha5 * _mm512_loadu_ps(&x[i +  0]));
        _mm512_storeu_ps(&y[i + 16], _mm512_loadu_ps(&y[i + 16]) + __alpha5 * _mm512_loadu_ps(&x[i + 16]));
        _mm512_storeu_ps(&y[i + 32], _mm512_loadu_ps(&y[i + 32]) + __alpha5 * _mm512_loadu_ps(&x[i + 32]));
        _mm512_storeu_ps(&y[i + 48], _mm512_loadu_ps(&y[i + 48]) + __alpha5 * _mm512_loadu_ps(&x[i + 48]));
    }
#endif

    for (; i < n; i += 32) {
        _mm256_storeu_ps(&y[i +  0], _mm256_loadu_ps(&y[i +  0]) + __alpha * _mm256_loadu_ps(&x[i +  0]));
        _mm256_storeu_ps(&y[i +  8], _mm256_loadu_ps(&y[i +  8]) + __alpha * _mm256_loadu_ps(&x[i +  8]));
        _mm256_storeu_ps(&y[i + 16], _mm256_loadu_ps(&y[i + 16]) + __alpha * _mm256_loadu_ps(&x[i + 16]));
        _mm256_storeu_ps(&y[i + 24], _mm256_loadu_ps(&y[i + 24]) + __alpha * _mm256_loadu_ps(&x[i + 24]));
    }
}
#endif
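The mask n & ~63 clears the low six bits, rounding n down to a multiple of 64 (the AVX-512 loop consumes 4 x 16 floats per iteration); the AVX2 loop then finishes [n64, n) in 32-float steps. A hypothetical one-liner just to make the arithmetic concrete:

    /* e.g. round_down_to_64(200) == 192; the AVX-512 loop covers [0, 192). */
    static long round_down_to_64(long n)
    {
        return n & ~63L;
    }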

View File: sdot kernel

@@ -25,74 +25,75 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/

The inline-assembly sdot_kernel_16 is removed; the intrinsics replacement mirrors the double-precision ddot kernel, with an AVX-512 path folded back into four AVX2 accumulators:

#ifndef __AVX512CD__
#pragma GCC target("avx2,fma")
#endif

#ifdef __AVX2__

#define HAVE_KERNEL_16 1

#include <immintrin.h>

static void sdot_kernel_16( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
{
    int i = 0;
    __m256 accum_0, accum_1, accum_2, accum_3;

    accum_0 = _mm256_setzero_ps();
    accum_1 = _mm256_setzero_ps();
    accum_2 = _mm256_setzero_ps();
    accum_3 = _mm256_setzero_ps();

#ifdef __AVX512CD__
    __m512 accum_05, accum_15, accum_25, accum_35;
    int n64;
    n64 = n & (~63);

    accum_05 = _mm512_setzero_ps();
    accum_15 = _mm512_setzero_ps();
    accum_25 = _mm512_setzero_ps();
    accum_35 = _mm512_setzero_ps();

    for (; i < n64; i += 64) {
        accum_05 += _mm512_loadu_ps(&x[i +  0]) * _mm512_loadu_ps(&y[i +  0]);
        accum_15 += _mm512_loadu_ps(&x[i + 16]) * _mm512_loadu_ps(&y[i + 16]);
        accum_25 += _mm512_loadu_ps(&x[i + 32]) * _mm512_loadu_ps(&y[i + 32]);
        accum_35 += _mm512_loadu_ps(&x[i + 48]) * _mm512_loadu_ps(&y[i + 48]);
    }

    /*
     * fold the 512-bit wide accumulator vectors into 256-bit wide vectors so that
     * the AVX2 code below can continue using the intermediate results in its loop
     */
    accum_0 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_05, 0), _mm512_extractf32x8_ps(accum_05, 1));
    accum_1 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_15, 0), _mm512_extractf32x8_ps(accum_15, 1));
    accum_2 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_25, 0), _mm512_extractf32x8_ps(accum_25, 1));
    accum_3 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_35, 0), _mm512_extractf32x8_ps(accum_35, 1));
#endif

    for (; i < n; i += 32) {
        accum_0 += _mm256_loadu_ps(&x[i +  0]) * _mm256_loadu_ps(&y[i +  0]);
        accum_1 += _mm256_loadu_ps(&x[i +  8]) * _mm256_loadu_ps(&y[i +  8]);
        accum_2 += _mm256_loadu_ps(&x[i + 16]) * _mm256_loadu_ps(&y[i + 16]);
        accum_3 += _mm256_loadu_ps(&x[i + 24]) * _mm256_loadu_ps(&y[i + 24]);
    }

    /* we now have the partial sums of the dot product in the 4 accumulation vectors, time to consolidate */
    accum_0 = accum_0 + accum_1 + accum_2 + accum_3;

    __m128 half_accum0;

    /* add the upper half to the lower half of the 256-bit vector to get a 128-bit vector */
    half_accum0 = _mm_add_ps(_mm256_extractf128_ps(accum_0, 0), _mm256_extractf128_ps(accum_0, 1));

    /* in 128-bit land there is a hadd operation to do the rest of the element-wise sum in one go */
    half_accum0 = _mm_hadd_ps(half_accum0, half_accum0);
    half_accum0 = _mm_hadd_ps(half_accum0, half_accum0);

    *dot = half_accum0[0];
}
#endif
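Both dot kernels keep four independent accumulators rather than one: floating-point adds and FMAs have multi-cycle latency, and independent partial sums let successive loop iterations overlap in the pipeline, at the cost of a different (but equally valid) summation order. A scalar sketch of the same idea, not part of the commit:

    /* Four independent partial sums avoid serializing on FP add latency;
     * note the summation order differs from a single-accumulator loop. */
    static float sdot_ref(long n, const float *x, const float *y)
    {
        float a0 = 0.0f, a1 = 0.0f, a2 = 0.0f, a3 = 0.0f;
        long i;
        for (i = 0; i + 4 <= n; i += 4) {
            a0 += x[i + 0] * y[i + 0];
            a1 += x[i + 1] * y[i + 1];
            a2 += x[i + 2] * y[i + 2];
            a3 += x[i + 3] * y[i + 3];
        }
        for (; i < n; i++)
            a0 += x[i] * y[i];
        return (a0 + a1) + (a2 + a3);
    }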