sdot_haswell: similar to ddot: turn into intrinsics based C code that supports AVX512

Do the same thing for SDOT that the previous patches did for DDOT; the performance gain
is in the 60% range, so it is at least somewhat interesting.
This commit is contained in:
Arjan van de Ven 2018-08-05 16:38:19 +00:00
parent 21c6220d63
commit ef30a7239c
1 changed file with 62 additions and 61 deletions

View File

@ -25,74 +25,75 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/ *****************************************************************************/
/*
 * Single-precision dot-product kernel (SDOT) for Haswell and newer,
 * written with intrinsics instead of inline asm so the same source can
 * also emit an AVX512 path when the compiler targets it.
 *
 * If we are not being compiled for an AVX512 target, force AVX2+FMA
 * code generation for this translation unit so the __m256 path below
 * is always available on Haswell-class CPUs.
 */
#ifndef __AVX512CD__
#pragma GCC target("avx2,fma")
#endif

#ifdef __AVX2__

#define HAVE_KERNEL_16 1

#include <immintrin.h>

/*
 * Compute dot = sum(x[i] * y[i]) for i in [0, n).
 *
 * Contract (unchanged from the asm version this replaces):
 *   - n is assumed to be a positive multiple of 32 (the caller's tail
 *     handling covers the remainder) -- NOTE(review): inferred from the
 *     32-wide loop with no scalar tail; confirm against the caller.
 *   - x and y are contiguous arrays of FLOAT (presumably float for SDOT).
 *   - The result is stored through dot; nothing is returned.
 */
static void sdot_kernel_16( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
{
	/* Use BLASLONG (not int) for the index so very large n cannot
	 * truncate or overflow the loop counter. */
	BLASLONG i = 0;
	__m256 accum_0, accum_1, accum_2, accum_3;

	/* Four independent accumulators hide FMA latency. */
	accum_0 = _mm256_setzero_ps();
	accum_1 = _mm256_setzero_ps();
	accum_2 = _mm256_setzero_ps();
	accum_3 = _mm256_setzero_ps();

#ifdef __AVX512CD__
	__m512 accum_05, accum_15, accum_25, accum_35;
	BLASLONG n64;

	/* Largest multiple of 64 <= n: the AVX512 loop consumes 64
	 * floats (4 x 16) per iteration. */
	n64 = n & (~63);

	accum_05 = _mm512_setzero_ps();
	accum_15 = _mm512_setzero_ps();
	accum_25 = _mm512_setzero_ps();
	accum_35 = _mm512_setzero_ps();

	for (; i < n64; i += 64) {
		accum_05 += _mm512_loadu_ps(&x[i+ 0]) * _mm512_loadu_ps(&y[i+ 0]);
		accum_15 += _mm512_loadu_ps(&x[i+16]) * _mm512_loadu_ps(&y[i+16]);
		accum_25 += _mm512_loadu_ps(&x[i+32]) * _mm512_loadu_ps(&y[i+32]);
		accum_35 += _mm512_loadu_ps(&x[i+48]) * _mm512_loadu_ps(&y[i+48]);
	}

	/*
	 * we need to fold our 512 bit wide accumulator vectors into 256 bit wide vectors so that the AVX2 code
	 * below can continue using the intermediate results in its loop
	 */
	accum_0 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_05, 0), _mm512_extractf32x8_ps(accum_05, 1));
	accum_1 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_15, 0), _mm512_extractf32x8_ps(accum_15, 1));
	accum_2 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_25, 0), _mm512_extractf32x8_ps(accum_25, 1));
	accum_3 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_35, 0), _mm512_extractf32x8_ps(accum_35, 1));
#endif

	/* AVX2 loop: 32 floats (4 x 8) per iteration; also mops up the
	 * sub-64 remainder left by the AVX512 loop above. */
	for (; i < n; i += 32) {
		accum_0 += _mm256_loadu_ps(&x[i+ 0]) * _mm256_loadu_ps(&y[i+ 0]);
		accum_1 += _mm256_loadu_ps(&x[i+ 8]) * _mm256_loadu_ps(&y[i+ 8]);
		accum_2 += _mm256_loadu_ps(&x[i+16]) * _mm256_loadu_ps(&y[i+16]);
		accum_3 += _mm256_loadu_ps(&x[i+24]) * _mm256_loadu_ps(&y[i+24]);
	}

	/* we now have the partial sums of the dot product in the 4 accumulation vectors, time to consolidate */
	accum_0 = accum_0 + accum_1 + accum_2 + accum_3;

	__m128 half_accum0;

	/* Add upper half to lower half of each of the 256 bit vector to get a 128 bit vector */
	half_accum0 = _mm_add_ps(_mm256_extractf128_ps(accum_0, 0), _mm256_extractf128_ps(accum_0, 1));

	/* in 128 bit land there is a hadd operation to do the rest of the element-wise sum in one go */
	half_accum0 = _mm_hadd_ps(half_accum0, half_accum0);
	half_accum0 = _mm_hadd_ps(half_accum0, half_accum0);

	*dot = half_accum0[0];
}

#endif