Add AVX512 support to DDOT
now that it's written in C + intrinsics it's easy to add AVX512 support for DDOT
This commit is contained in:
parent
ae38fa55c3
commit
34d63df4b3
|
@ -26,7 +26,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
*****************************************************************************/
|
*****************************************************************************/
|
||||||
|
|
||||||
/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */
|
/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */
|
||||||
#ifndef __AVX512CD_
|
|
||||||
|
#ifndef __AVX512CD__
|
||||||
#pragma GCC target("avx2,fma")
|
#pragma GCC target("avx2,fma")
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -35,7 +36,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
#define HAVE_KERNEL_8 1
|
#define HAVE_KERNEL_8 1
|
||||||
|
|
||||||
#include <immintrin.h>
|
#include <immintrin.h>
|
||||||
static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y , FLOAT *dot) __attribute__ ((noinline));
|
|
||||||
|
|
||||||
static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
|
static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
|
||||||
{
|
{
|
||||||
|
@ -47,10 +47,37 @@ static void ddot_kernel_8( BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *dot)
|
||||||
accum_2 = _mm256_setzero_pd();
|
accum_2 = _mm256_setzero_pd();
|
||||||
accum_3 = _mm256_setzero_pd();
|
accum_3 = _mm256_setzero_pd();
|
||||||
|
|
||||||
|
#ifdef __AVX512CD__
|
||||||
|
__m512d accum_05, accum_15, accum_25, accum_35;
|
||||||
|
int n32;
|
||||||
|
n32 = n & (~31);
|
||||||
|
|
||||||
|
accum_05 = _mm512_setzero_pd();
|
||||||
|
accum_15 = _mm512_setzero_pd();
|
||||||
|
accum_25 = _mm512_setzero_pd();
|
||||||
|
accum_35 = _mm512_setzero_pd();
|
||||||
|
|
||||||
|
for (; i < n32; i += 32) {
|
||||||
|
accum_05 += _mm512_loadu_pd(&x[i+ 0]) * _mm512_loadu_pd(&y[i+ 0]);
|
||||||
|
accum_15 += _mm512_loadu_pd(&x[i+ 8]) * _mm512_loadu_pd(&y[i+ 8]);
|
||||||
|
accum_25 += _mm512_loadu_pd(&x[i+16]) * _mm512_loadu_pd(&y[i+16]);
|
||||||
|
accum_35 += _mm512_loadu_pd(&x[i+24]) * _mm512_loadu_pd(&y[i+24]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* we need to fold our 512 bit wide accumulator vectors into 256 bit wide vectors so that the AVX2 code
|
||||||
|
* below can continue using the intermediate results in its loop
|
||||||
|
*/
|
||||||
|
accum_0 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_05, 0), _mm512_extractf64x4_pd(accum_05, 1));
|
||||||
|
accum_1 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_15, 0), _mm512_extractf64x4_pd(accum_15, 1));
|
||||||
|
accum_2 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_25, 0), _mm512_extractf64x4_pd(accum_25, 1));
|
||||||
|
accum_3 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_35, 0), _mm512_extractf64x4_pd(accum_35, 1));
|
||||||
|
|
||||||
|
#endif
|
||||||
for (; i < n; i += 16) {
|
for (; i < n; i += 16) {
|
||||||
accum_0 += _mm256_loadu_pd(&x[i+ 0]) * _mm256_loadu_pd(&y[i+0]);
|
accum_0 += _mm256_loadu_pd(&x[i+ 0]) * _mm256_loadu_pd(&y[i+ 0]);
|
||||||
accum_1 += _mm256_loadu_pd(&x[i+ 4]) * _mm256_loadu_pd(&y[i+4]);
|
accum_1 += _mm256_loadu_pd(&x[i+ 4]) * _mm256_loadu_pd(&y[i+ 4]);
|
||||||
accum_2 += _mm256_loadu_pd(&x[i+ 8]) * _mm256_loadu_pd(&y[i+8]);
|
accum_2 += _mm256_loadu_pd(&x[i+ 8]) * _mm256_loadu_pd(&y[i+ 8]);
|
||||||
accum_3 += _mm256_loadu_pd(&x[i+12]) * _mm256_loadu_pd(&y[i+12]);
|
accum_3 += _mm256_loadu_pd(&x[i+12]) * _mm256_loadu_pd(&y[i+12]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue