Add AVX512 support to DDOT

now that it's written in C + intrinsics it's easy to add AVX512 support
for DDOT
This commit is contained in:
Arjan van de Ven 2018-08-05 15:16:20 +00:00
parent ae38fa55c3
commit 34d63df4b3
1 changed files with 32 additions and 5 deletions

View File

@ -26,7 +26,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */
#ifndef __AVX512CD_
#ifndef __AVX512CD__
#pragma GCC target("avx2,fma")
#endif
@ -35,7 +36,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define HAVE_KERNEL_8 1
#include <immintrin.h>
static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y , FLOAT *dot) __attribute__ ((noinline));
static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
{
@ -47,10 +47,37 @@ static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
accum_2 = _mm256_setzero_pd();
accum_3 = _mm256_setzero_pd();
#ifdef __AVX512CD__
__m512d accum_05, accum_15, accum_25, accum_35;
int n32;
n32 = n & (~31);
accum_05 = _mm512_setzero_pd();
accum_15 = _mm512_setzero_pd();
accum_25 = _mm512_setzero_pd();
accum_35 = _mm512_setzero_pd();
for (; i < n32; i += 32) {
accum_05 += _mm512_loadu_pd(&x[i+ 0]) * _mm512_loadu_pd(&y[i+ 0]);
accum_15 += _mm512_loadu_pd(&x[i+ 8]) * _mm512_loadu_pd(&y[i+ 8]);
accum_25 += _mm512_loadu_pd(&x[i+16]) * _mm512_loadu_pd(&y[i+16]);
accum_35 += _mm512_loadu_pd(&x[i+24]) * _mm512_loadu_pd(&y[i+24]);
}
/*
* we need to fold our 512 bit wide accumulator vectors into 256 bit wide vectors so that the AVX2 code
* below can continue using the intermediate results in its loop
*/
accum_0 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_05, 0), _mm512_extractf64x4_pd(accum_05, 1));
accum_1 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_15, 0), _mm512_extractf64x4_pd(accum_15, 1));
accum_2 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_25, 0), _mm512_extractf64x4_pd(accum_25, 1));
accum_3 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_35, 0), _mm512_extractf64x4_pd(accum_35, 1));
#endif
for (; i < n; i += 16) {
accum_0 += _mm256_loadu_pd(&x[i+ 0]) * _mm256_loadu_pd(&y[i+0]);
accum_1 += _mm256_loadu_pd(&x[i+ 4]) * _mm256_loadu_pd(&y[i+4]);
accum_2 += _mm256_loadu_pd(&x[i+ 8]) * _mm256_loadu_pd(&y[i+8]);
accum_0 += _mm256_loadu_pd(&x[i+ 0]) * _mm256_loadu_pd(&y[i+ 0]);
accum_1 += _mm256_loadu_pd(&x[i+ 4]) * _mm256_loadu_pd(&y[i+ 4]);
accum_2 += _mm256_loadu_pd(&x[i+ 8]) * _mm256_loadu_pd(&y[i+ 8]);
accum_3 += _mm256_loadu_pd(&x[i+12]) * _mm256_loadu_pd(&y[i+12]);
}