diff --git a/kernel/x86_64/ddot_microk_haswell-2.c b/kernel/x86_64/ddot_microk_haswell-2.c index 8feeb79a1..b9be83af6 100644 --- a/kernel/x86_64/ddot_microk_haswell-2.c +++ b/kernel/x86_64/ddot_microk_haswell-2.c @@ -26,7 +26,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *****************************************************************************/ /* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */ -#ifndef __AVX512CD_ + +#ifndef __AVX512CD__ #pragma GCC target("avx2,fma") #endif @@ -35,7 +36,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define HAVE_KERNEL_8 1 #include -static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y , FLOAT *dot) __attribute__ ((noinline)); static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot) { @@ -47,10 +47,37 @@ static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot) accum_2 = _mm256_setzero_pd(); accum_3 = _mm256_setzero_pd(); +#ifdef __AVX512CD__ + __m512d accum_05, accum_15, accum_25, accum_35; + int n32; + n32 = n & (~31); + + accum_05 = _mm512_setzero_pd(); + accum_15 = _mm512_setzero_pd(); + accum_25 = _mm512_setzero_pd(); + accum_35 = _mm512_setzero_pd(); + + for (; i < n32; i += 32) { + accum_05 += _mm512_loadu_pd(&x[i+ 0]) * _mm512_loadu_pd(&y[i+ 0]); + accum_15 += _mm512_loadu_pd(&x[i+ 8]) * _mm512_loadu_pd(&y[i+ 8]); + accum_25 += _mm512_loadu_pd(&x[i+16]) * _mm512_loadu_pd(&y[i+16]); + accum_35 += _mm512_loadu_pd(&x[i+24]) * _mm512_loadu_pd(&y[i+24]); + } + + /* + * we need to fold our 512 bit wide accumulator vectors into 256 bit wide vectors so that the AVX2 code + * below can continue using the intermediate results in its loop + */ + accum_0 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_05, 0), _mm512_extractf64x4_pd(accum_05, 1)); + accum_1 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_15, 0), _mm512_extractf64x4_pd(accum_15, 1)); + accum_2 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_25, 0), _mm512_extractf64x4_pd(accum_25, 1)); + accum_3 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_35, 0), _mm512_extractf64x4_pd(accum_35, 1)); + +#endif for (; i < n; i += 16) { - accum_0 += _mm256_loadu_pd(&x[i+ 0]) * _mm256_loadu_pd(&y[i+0]); - accum_1 += _mm256_loadu_pd(&x[i+ 4]) * _mm256_loadu_pd(&y[i+4]); - accum_2 += _mm256_loadu_pd(&x[i+ 8]) * _mm256_loadu_pd(&y[i+8]); + accum_0 += _mm256_loadu_pd(&x[i+ 0]) * _mm256_loadu_pd(&y[i+ 0]); + accum_1 += _mm256_loadu_pd(&x[i+ 4]) * _mm256_loadu_pd(&y[i+ 4]); + accum_2 += _mm256_loadu_pd(&x[i+ 8]) * _mm256_loadu_pd(&y[i+ 8]); accum_3 += _mm256_loadu_pd(&x[i+12]) * _mm256_loadu_pd(&y[i+12]); }