This commit is contained in:
Arjan van de Ven 2018-08-09 14:28:34 +00:00 committed by GitHub
commit 9b56e815e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 36 additions and 55 deletions

View File

@ -25,71 +25,52 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/ *****************************************************************************/
/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */
#ifndef __AVX512CD_
#if )defined(__GNUC__) && __GNUC__ < 6)
#pragma GCC target("avx")
#else
#pragma GCC target("avx2,fma")
#endif
#endif
#ifdef __AVX__
#define HAVE_KERNEL_8 1 #define HAVE_KERNEL_8 1
#include <immintrin.h>
static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y , FLOAT *dot) __attribute__ ((noinline)); static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y , FLOAT *dot) __attribute__ ((noinline));
static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot) static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
{ {
int i = 0;
__m256d accum_0, accum_1, accum_2, accum_3;
accum_0 = _mm256_setzero_pd();
accum_1 = _mm256_setzero_pd();
accum_2 = _mm256_setzero_pd();
accum_3 = _mm256_setzero_pd();
for (; i < n; i += 16) {
accum_0 += _mm256_loadu_pd(&x[i+ 0]) * _mm256_loadu_pd(&y[i+0]);
accum_1 += _mm256_loadu_pd(&x[i+ 4]) * _mm256_loadu_pd(&y[i+4]);
accum_2 += _mm256_loadu_pd(&x[i+ 8]) * _mm256_loadu_pd(&y[i+8]);
accum_3 += _mm256_loadu_pd(&x[i+12]) * _mm256_loadu_pd(&y[i+12]);
}
BLASLONG register i = 0; /* we now have the partial sums of the dot product in the 4 accumulation vectors, time to consolidate */
__asm__ __volatile__ accum_0 = accum_0 + accum_1 + accum_2 + accum_3;
(
"vxorpd %%ymm4, %%ymm4, %%ymm4 \n\t"
"vxorpd %%ymm5, %%ymm5, %%ymm5 \n\t"
"vxorpd %%ymm6, %%ymm6, %%ymm6 \n\t"
"vxorpd %%ymm7, %%ymm7, %%ymm7 \n\t"
".p2align 4 \n\t" __m128d half_accum0;
"1: \n\t"
"vmovups (%2,%0,8), %%ymm12 \n\t" // 2 * x
"vmovups 32(%2,%0,8), %%ymm13 \n\t" // 2 * x
"vmovups 64(%2,%0,8), %%ymm14 \n\t" // 2 * x
"vmovups 96(%2,%0,8), %%ymm15 \n\t" // 2 * x
"vfmadd231pd (%3,%0,8), %%ymm12, %%ymm4 \n\t" // 2 * y /* Add upper half to lower half of each of the 256 bit vector to get a 128 bit vector */
"vfmadd231pd 32(%3,%0,8), %%ymm13, %%ymm5 \n\t" // 2 * y half_accum0 = _mm256_extractf128_pd(accum_0, 0) + _mm256_extractf128_pd(accum_0, 1);
"vfmadd231pd 64(%3,%0,8), %%ymm14, %%ymm6 \n\t" // 2 * y
"vfmadd231pd 96(%3,%0,8), %%ymm15, %%ymm7 \n\t" // 2 * y
"addq $16 , %0 \n\t" /* in 128 bit land there is a hadd operation to do the rest of the element-wise sum in one go */
"subq $16 , %1 \n\t" half_accum0 = _mm_hadd_pd(half_accum0, half_accum0);
"jnz 1b \n\t"
"vextractf128 $1 , %%ymm4 , %%xmm12 \n\t"
"vextractf128 $1 , %%ymm5 , %%xmm13 \n\t"
"vextractf128 $1 , %%ymm6 , %%xmm14 \n\t"
"vextractf128 $1 , %%ymm7 , %%xmm15 \n\t"
"vaddpd %%xmm4, %%xmm12, %%xmm4 \n\t"
"vaddpd %%xmm5, %%xmm13, %%xmm5 \n\t"
"vaddpd %%xmm6, %%xmm14, %%xmm6 \n\t"
"vaddpd %%xmm7, %%xmm15, %%xmm7 \n\t"
"vaddpd %%xmm4, %%xmm5, %%xmm4 \n\t"
"vaddpd %%xmm6, %%xmm7, %%xmm6 \n\t"
"vaddpd %%xmm4, %%xmm6, %%xmm4 \n\t"
"vhaddpd %%xmm4, %%xmm4, %%xmm4 \n\t"
"vmovsd %%xmm4, (%4) \n\t"
"vzeroupper \n\t"
:
:
"r" (i), // 0
"r" (n), // 1
"r" (x), // 2
"r" (y), // 3
"r" (dot) // 4
: "cc",
"%xmm4", "%xmm5",
"%xmm6", "%xmm7",
"%xmm12", "%xmm13", "%xmm14", "%xmm15",
"memory"
);
}
*dot = half_accum0[0];
}
#endif