Add AVX512 support to dsymv_L_microk_haswell-2.c

Now that the code is written in intrinsics, it's relatively easy to add AVX512 support.
This commit is contained in:
Arjan van de Ven 2018-08-04 23:56:06 +00:00
parent c202e06297
commit f2810beafb
1 changed file with 50 additions and 0 deletions

View File

@@ -46,6 +46,56 @@ static void dsymv_kernel_4x4(BLASLONG from, BLASLONG to, FLOAT **a, FLOAT *x, FL
temp1_2 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[2])); temp1_2 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[2]));
temp1_3 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[3])); temp1_3 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[3]));
#ifdef __AVX512CD__
__m512d temp2_05, temp2_15, temp2_25, temp2_35; // temp2_0 temp2_1 temp2_2 temp2_3
__m512d temp1_05, temp1_15, temp1_25, temp1_35;
BLASLONG to2;
int delta;
temp2_05 = _mm512_setzero_pd();
temp2_15 = _mm512_setzero_pd();
temp2_25 = _mm512_setzero_pd();
temp2_35 = _mm512_setzero_pd();
temp1_05 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[0]));
temp1_15 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[1]));
temp1_25 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[2]));
temp1_35 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[3]));
delta = (to - from) & ~7;
to2 = from + delta;
for (; from < to2; from += 8) {
__m512d _x, _y;
__m512d a0, a1, a2, a3;
_y = _mm512_loadu_pd(&y[from]);
_x = _mm512_loadu_pd(&x[from]);
a0 = _mm512_loadu_pd(&a[0][from]);
a1 = _mm512_loadu_pd(&a[1][from]);
a2 = _mm512_loadu_pd(&a[2][from]);
a3 = _mm512_loadu_pd(&a[3][from]);
_y += temp1_05 * a0 + temp1_15 * a1 + temp1_25 * a2 + temp1_35 * a3;
temp2_05 += _x * a0;
temp2_15 += _x * a1;
temp2_25 += _x * a2;
temp2_35 += _x * a3;
_mm512_storeu_pd(&y[from], _y);
};
temp2_0 = _mm256_add_pd(_mm512_extractf64x4_pd(temp2_05, 0), _mm512_extractf64x4_pd(temp2_05, 1));
temp2_1 = _mm256_add_pd(_mm512_extractf64x4_pd(temp2_15, 0), _mm512_extractf64x4_pd(temp2_15, 1));
temp2_2 = _mm256_add_pd(_mm512_extractf64x4_pd(temp2_25, 0), _mm512_extractf64x4_pd(temp2_25, 1));
temp2_3 = _mm256_add_pd(_mm512_extractf64x4_pd(temp2_35, 0), _mm512_extractf64x4_pd(temp2_35, 1));
#endif
for (; from != to; from += 4) { for (; from != to; from += 4) {
__m256d _x, _y; __m256d _x, _y;
__m256d a0, a1, a2, a3; __m256d a0, a1, a2, a3;