sdot_haswell: similar to ddot: turn into intrinsics based C code that supports AVX512
do the same thing for SDOT that the previous patches did for DDOT; the perf gain is in the 60% range so at least somewhat interesting
This commit is contained in:
parent
21c6220d63
commit
ef30a7239c
|
@ -25,74 +25,75 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
|||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#define HAVE_KERNEL_16 1
|
||||
static void sdot_kernel_16( BLASLONG n, FLOAT *x, FLOAT *y , FLOAT *dot) __attribute__ ((noinline));
|
||||
|
||||
static void sdot_kernel_16( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
|
||||
{
|
||||
|
||||
|
||||
BLASLONG register i = 0;
|
||||
|
||||
__asm__ __volatile__
|
||||
(
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4 \n\t"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5 \n\t"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6 \n\t"
|
||||
"vxorps %%ymm7, %%ymm7, %%ymm7 \n\t"
|
||||
|
||||
".p2align 4 \n\t"
|
||||
"1: \n\t"
|
||||
"vmovups (%2,%0,4), %%ymm12 \n\t" // 2 * x
|
||||
"vmovups 32(%2,%0,4), %%ymm13 \n\t" // 2 * x
|
||||
"vmovups 64(%2,%0,4), %%ymm14 \n\t" // 2 * x
|
||||
"vmovups 96(%2,%0,4), %%ymm15 \n\t" // 2 * x
|
||||
|
||||
"vfmadd231ps (%3,%0,4), %%ymm12, %%ymm4 \n\t" // 2 * y
|
||||
"vfmadd231ps 32(%3,%0,4), %%ymm13, %%ymm5 \n\t" // 2 * y
|
||||
"vfmadd231ps 64(%3,%0,4), %%ymm14, %%ymm6 \n\t" // 2 * y
|
||||
"vfmadd231ps 96(%3,%0,4), %%ymm15, %%ymm7 \n\t" // 2 * y
|
||||
|
||||
#ifndef DSDOT
|
||||
"addq $32 , %0 \n\t"
|
||||
"subq $32 , %1 \n\t"
|
||||
"jnz 1b \n\t"
|
||||
#ifndef __AVX512CD__
|
||||
#pragma GCC target("avx2,fma")
|
||||
#endif
|
||||
|
||||
"vextractf128 $1 , %%ymm4 , %%xmm12 \n\t"
|
||||
"vextractf128 $1 , %%ymm5 , %%xmm13 \n\t"
|
||||
"vextractf128 $1 , %%ymm6 , %%xmm14 \n\t"
|
||||
"vextractf128 $1 , %%ymm7 , %%xmm15 \n\t"
|
||||
#ifdef __AVX2__
|
||||
|
||||
"vaddps %%xmm4, %%xmm12, %%xmm4 \n\t"
|
||||
"vaddps %%xmm5, %%xmm13, %%xmm5 \n\t"
|
||||
"vaddps %%xmm6, %%xmm14, %%xmm6 \n\t"
|
||||
"vaddps %%xmm7, %%xmm15, %%xmm7 \n\t"
|
||||
#define HAVE_KERNEL_16 1
|
||||
|
||||
"vaddps %%xmm4, %%xmm5, %%xmm4 \n\t"
|
||||
"vaddps %%xmm6, %%xmm7, %%xmm6 \n\t"
|
||||
"vaddps %%xmm4, %%xmm6, %%xmm4 \n\t"
|
||||
#include <immintrin.h>
|
||||
|
||||
"vhaddps %%xmm4, %%xmm4, %%xmm4 \n\t"
|
||||
"vhaddps %%xmm4, %%xmm4, %%xmm4 \n\t"
|
||||
static void sdot_kernel_16( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
|
||||
|
||||
"vmovss %%xmm4, (%4) \n\t"
|
||||
"vzeroupper \n\t"
|
||||
{
|
||||
int i = 0;
|
||||
__m256 accum_0, accum_1, accum_2, accum_3;
|
||||
|
||||
:
|
||||
:
|
||||
"r" (i), // 0
|
||||
"r" (n), // 1
|
||||
"r" (x), // 2
|
||||
"r" (y), // 3
|
||||
"r" (dot) // 4
|
||||
: "cc",
|
||||
"%xmm4", "%xmm5",
|
||||
"%xmm6", "%xmm7",
|
||||
"%xmm12", "%xmm13", "%xmm14", "%xmm15",
|
||||
"memory"
|
||||
);
|
||||
accum_0 = _mm256_setzero_ps();
|
||||
accum_1 = _mm256_setzero_ps();
|
||||
accum_2 = _mm256_setzero_ps();
|
||||
accum_3 = _mm256_setzero_ps();
|
||||
|
||||
}
|
||||
#ifdef __AVX512CD__
|
||||
__m512 accum_05, accum_15, accum_25, accum_35;
|
||||
int n64;
|
||||
n64 = n & (~63);
|
||||
|
||||
accum_05 = _mm512_setzero_ps();
|
||||
accum_15 = _mm512_setzero_ps();
|
||||
accum_25 = _mm512_setzero_ps();
|
||||
accum_35 = _mm512_setzero_ps();
|
||||
|
||||
for (; i < n64; i += 64) {
|
||||
accum_05 += _mm512_loadu_ps(&x[i+ 0]) * _mm512_loadu_ps(&y[i+ 0]);
|
||||
accum_15 += _mm512_loadu_ps(&x[i+16]) * _mm512_loadu_ps(&y[i+16]);
|
||||
accum_25 += _mm512_loadu_ps(&x[i+32]) * _mm512_loadu_ps(&y[i+32]);
|
||||
accum_35 += _mm512_loadu_ps(&x[i+48]) * _mm512_loadu_ps(&y[i+48]);
|
||||
}
|
||||
|
||||
/*
|
||||
* we need to fold our 512 bit wide accumulator vectors into 256 bit wide vectors so that the AVX2 code
|
||||
* below can continue using the intermediate results in its loop
|
||||
*/
|
||||
accum_0 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_05, 0), _mm512_extractf32x8_ps(accum_05, 1));
|
||||
accum_1 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_15, 0), _mm512_extractf32x8_ps(accum_15, 1));
|
||||
accum_2 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_25, 0), _mm512_extractf32x8_ps(accum_25, 1));
|
||||
accum_3 = _mm256_add_ps(_mm512_extractf32x8_ps(accum_35, 0), _mm512_extractf32x8_ps(accum_35, 1));
|
||||
|
||||
#endif
|
||||
for (; i < n; i += 32) {
|
||||
accum_0 += _mm256_loadu_ps(&x[i+ 0]) * _mm256_loadu_ps(&y[i+ 0]);
|
||||
accum_1 += _mm256_loadu_ps(&x[i+ 8]) * _mm256_loadu_ps(&y[i+ 8]);
|
||||
accum_2 += _mm256_loadu_ps(&x[i+16]) * _mm256_loadu_ps(&y[i+16]);
|
||||
accum_3 += _mm256_loadu_ps(&x[i+24]) * _mm256_loadu_ps(&y[i+24]);
|
||||
}
|
||||
|
||||
/* we now have the partial sums of the dot product in the 4 accumulation vectors, time to consolidate */
|
||||
|
||||
accum_0 = accum_0 + accum_1 + accum_2 + accum_3;
|
||||
|
||||
__m128 half_accum0;
|
||||
|
||||
/* Add upper half to lower half of each of the 256 bit vector to get a 128 bit vector */
|
||||
half_accum0 = _mm_add_ps(_mm256_extractf128_ps(accum_0, 0), _mm256_extractf128_ps(accum_0, 1));
|
||||
|
||||
/* in 128 bit land there is a hadd operation to do the rest of the element-wise sum in one go */
|
||||
half_accum0 = _mm_hadd_ps(half_accum0, half_accum0);
|
||||
half_accum0 = _mm_hadd_ps(half_accum0, half_accum0);
|
||||
|
||||
*dot = half_accum0[0];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue