From d9ba49165af15d535d9b9955bd248eab4d259f06 Mon Sep 17 00:00:00 2001 From: Gengxin Xie Date: Sun, 27 Sep 2020 10:38:19 +0800 Subject: [PATCH] Improve the performance of rot by using AVX512 and AVX2 intrinsic --- driver/others/blas_l1_thread.c | 2 +- driver/others/blas_server_win32.c | 11 +- kernel/x86_64/KERNEL.HASWELL | 3 + kernel/x86_64/drot.c | 139 +++++++++++++++++++++++++ kernel/x86_64/drot_microk_haswell-2.c | 87 ++++++++++++++++ kernel/x86_64/drot_microk_skylakex-2.c | 94 +++++++++++++++++ kernel/x86_64/srot.c | 139 +++++++++++++++++++++++++ kernel/x86_64/srot_microk_haswell-2.c | 87 ++++++++++++++++ kernel/x86_64/srot_microk_skylakex-2.c | 91 ++++++++++++++++ 9 files changed, 648 insertions(+), 5 deletions(-) create mode 100644 kernel/x86_64/drot.c create mode 100644 kernel/x86_64/drot_microk_haswell-2.c create mode 100644 kernel/x86_64/drot_microk_skylakex-2.c create mode 100644 kernel/x86_64/srot.c create mode 100644 kernel/x86_64/srot_microk_haswell-2.c create mode 100644 kernel/x86_64/srot_microk_skylakex-2.c diff --git a/driver/others/blas_l1_thread.c b/driver/others/blas_l1_thread.c index 04acbcc5f..06039c952 100644 --- a/driver/others/blas_l1_thread.c +++ b/driver/others/blas_l1_thread.c @@ -80,7 +80,7 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha break; } - mode |= BLAS_LEGACY; + if(!(mode & BLAS_PTHREAD)) mode |= BLAS_LEGACY; for (i = 0; i < nthreads; i++) blas_queue_init(&queue[i]); diff --git a/driver/others/blas_server_win32.c b/driver/others/blas_server_win32.c index d2cc91757..f47908c70 100644 --- a/driver/others/blas_server_win32.c +++ b/driver/others/blas_server_win32.c @@ -476,12 +476,15 @@ int exec_blas(BLASLONG num, blas_queue_t *queue){ routine = queue -> routine; - if (!(queue -> mode & BLAS_LEGACY)) { + if (queue -> mode & BLAS_LEGACY) { + legacy_exec(routine, queue -> mode, queue -> args, queue -> sb); + } else + if (queue -> mode & BLAS_PTHREAD) { + void (*pthreadcompat)(void *) = queue -> routine; + (pthreadcompat)(queue -> args); + } else (routine)(queue -> args, queue -> range_m, queue -> range_n, queue -> sa, queue -> sb, 0); - } else { - legacy_exec(routine, queue -> mode, queue -> args, queue -> sb); - } if ((num > 1) && queue -> next) exec_blas_async_wait(num - 1, queue -> next); diff --git a/kernel/x86_64/KERNEL.HASWELL b/kernel/x86_64/KERNEL.HASWELL index b979fc0ae..81eaf96ac 100644 --- a/kernel/x86_64/KERNEL.HASWELL +++ b/kernel/x86_64/KERNEL.HASWELL @@ -102,3 +102,6 @@ ZGEMM3MKERNEL = zgemm3m_kernel_4x4_haswell.c SASUMKERNEL = sasum.c DASUMKERNEL = dasum.c + +SROTKERNEL = srot.c +DROTKERNEL = drot.c diff --git a/kernel/x86_64/drot.c b/kernel/x86_64/drot.c new file mode 100644 index 000000000..a312b7ff9 --- /dev/null +++ b/kernel/x86_64/drot.c @@ -0,0 +1,139 @@ +#include "common.h" + +#if defined(SKYLAKEX) +#include "drot_microk_skylakex-2.c" +#elif defined(HASWELL) +#include "drot_microk_haswell-2.c" +#endif + +#ifndef HAVE_DROT_KERNEL + +static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s) +{ + BLASLONG i = 0; + FLOAT f0, f1, f2, f3; + FLOAT x0, x1, x2, x3; + FLOAT g0, g1, g2, g3; + FLOAT y0, y1, y2, y3; + + FLOAT* xp = x; + FLOAT* yp = y; + + BLASLONG n1 = n & (~7); + + while (i < n1) { + x0 = xp[0]; + y0 = yp[0]; + x1 = xp[1]; + y1 = yp[1]; + x2 = xp[2]; + y2 = yp[2]; + x3 = xp[3]; + y3 = yp[3]; + + f0 = c*x0 + s*y0; + g0 = c*y0 - s*x0; + f1 = c*x1 + s*y1; + g1 = c*y1 - s*x1; + f2 = c*x2 + s*y2; + g2 = c*y2 - s*x2; + f3 = c*x3 + s*y3; + g3 = c*y3 - s*x3; + + xp[0] = f0; + yp[0] = g0; + xp[1] = f1; + yp[1] = g1; + xp[2] = f2; + yp[2] = g2; + xp[3] = f3; + yp[3] = g3; + + xp += 4; + yp += 4; + i += 4; + } + + while (i < n) { + FLOAT temp = c*x[i] + s*y[i]; + y[i] = c*y[i] - s*x[i]; + x[i] = temp; + + i++; + } +} + +#endif +static void rot_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT c, FLOAT s) +{ + BLASLONG i = 0; + BLASLONG ix = 0, iy = 0; + + FLOAT temp; + + if (n <= 0) + return; + if ((inc_x == 1) && (inc_y == 1)) { + drot_kernel(n, x, y, c, s); + } + else { + while (i < n) { + temp = c * x[ix] + s * y[iy]; + y[iy] = c * y[iy] - s * x[ix]; + x[ix] = temp; + + ix += inc_x; + iy += inc_y; + i++; + } + } + return; +} + + +#if defined(SMP) +static int rot_thread_function(blas_arg_t *args) +{ + + rot_compute(args->m, + args->a, args->lda, + args->b, args->ldb, + ((FLOAT *)args->alpha)[0], + ((FLOAT *)args->alpha)[1]); + return 0; +} + +extern int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha, void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc, int (*function)(), int nthreads); +#endif +int CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT c, FLOAT s) +{ +#if defined(SMP) + int nthreads; + FLOAT alpha[2]={c, s}; + FLOAT dummy_c; +#endif + +#if defined(SMP) + if (inc_x == 0 || inc_y == 0 || n <= 100000) { + nthreads = 1; + } + else { + nthreads = num_cpu_avail(1); + } + + if (nthreads == 1) { + rot_compute(n, x, inc_x, y, inc_y, c, s); + } + else { +#if defined(DOUBLE) + int mode = BLAS_DOUBLE | BLAS_REAL | BLAS_PTHREAD; +#else + int mode = BLAS_SINGLE | BLAS_REAL | BLAS_PTHREAD; +#endif + blas_level1_thread(mode, n, 0, 0, alpha, x, inc_x, y, inc_y, &dummy_c, 0, (void *)rot_thread_function, nthreads); + } +#else + rot_compute(n, x, inc_x, y, inc_y, c, s); +#endif + return 0; +} diff --git a/kernel/x86_64/drot_microk_haswell-2.c b/kernel/x86_64/drot_microk_haswell-2.c new file mode 100644 index 000000000..72a87696e --- /dev/null +++ b/kernel/x86_64/drot_microk_haswell-2.c @@ -0,0 +1,87 @@ +/* need a new enough GCC for avx512 support */ +#if (( defined(__GNUC__) && __GNUC__ > 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 9)) + +#define HAVE_DROT_KERNEL 1 + +#include +#include + +static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s) +{ + BLASLONG i = 0; + + BLASLONG tail_index_4 = n&(~3); + BLASLONG tail_index_16 = n&(~15); + + __m256d c_256, s_256; + if (n >= 4) { + c_256 = _mm256_set1_pd(c); + s_256 = _mm256_set1_pd(s); + } + + __m256d x0, x1, x2, x3; + __m256d y0, y1, y2, y3; + __m256d t0, t1, t2, t3; + + for (i = 0; i < tail_index_16; i += 16) { + x0 = _mm256_loadu_pd(&x[i + 0]); + x1 = _mm256_loadu_pd(&x[i + 4]); + x2 = _mm256_loadu_pd(&x[i + 8]); + x3 = _mm256_loadu_pd(&x[i +12]); + y0 = _mm256_loadu_pd(&y[i + 0]); + y1 = _mm256_loadu_pd(&y[i + 4]); + y2 = _mm256_loadu_pd(&y[i + 8]); + y3 = _mm256_loadu_pd(&y[i +12]); + + t0 = _mm256_mul_pd(s_256, y0); + t1 = _mm256_mul_pd(s_256, y1); + t2 = _mm256_mul_pd(s_256, y2); + t3 = _mm256_mul_pd(s_256, y3); + + t0 = _mm256_fmadd_pd(c_256, x0, t0); + t1 = _mm256_fmadd_pd(c_256, x1, t1); + t2 = _mm256_fmadd_pd(c_256, x2, t2); + t3 = _mm256_fmadd_pd(c_256, x3, t3); + + _mm256_storeu_pd(&x[i + 0], t0); + _mm256_storeu_pd(&x[i + 4], t1); + _mm256_storeu_pd(&x[i + 8], t2); + _mm256_storeu_pd(&x[i +12], t3); + + t0 = _mm256_mul_pd(s_256, x0); + t1 = _mm256_mul_pd(s_256, x1); + t2 = _mm256_mul_pd(s_256, x2); + t3 = _mm256_mul_pd(s_256, x3); + + t0 = _mm256_fmsub_pd(c_256, y0, t0); + t1 = _mm256_fmsub_pd(c_256, y1, t1); + t2 = _mm256_fmsub_pd(c_256, y2, t2); + t3 = _mm256_fmsub_pd(c_256, y3, t3); + + _mm256_storeu_pd(&y[i + 0], t0); + _mm256_storeu_pd(&y[i + 4], t1); + _mm256_storeu_pd(&y[i + 8], t2); + _mm256_storeu_pd(&y[i +12], t3); + + } + + for (i = tail_index_16; i < tail_index_4; i += 4) { + x0 = _mm256_loadu_pd(&x[i]); + y0 = _mm256_loadu_pd(&y[i]); + + t0 = _mm256_mul_pd(s_256, y0); + t0 = _mm256_fmadd_pd(c_256, x0, t0); + _mm256_storeu_pd(&x[i], t0); + + t0 = _mm256_mul_pd(s_256, x0); + t0 = _mm256_fmsub_pd(c_256, y0, t0); + _mm256_storeu_pd(&y[i], t0); + } + + for (i = tail_index_4; i < n; ++i) { + FLOAT temp = c * x[i] + s * y[i]; + y[i] = c * y[i] - s * x[i]; + x[i] = temp; + } +} +#endif diff --git a/kernel/x86_64/drot_microk_skylakex-2.c b/kernel/x86_64/drot_microk_skylakex-2.c new file mode 100644 index 000000000..4e862e663 --- /dev/null +++ b/kernel/x86_64/drot_microk_skylakex-2.c @@ -0,0 +1,94 @@ +/* need a new enough GCC for avx512 support */ +#if (( defined(__GNUC__) && __GNUC__ > 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 9)) + +#define HAVE_DROT_KERNEL 1 + +#include +#include + +static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s) +{ + BLASLONG i = 0; + BLASLONG n1 = n; + + BLASLONG tail_index_8 = 0; + BLASLONG tail_index_32 = 0; + + __m512d c_512 = _mm512_set1_pd(c); + __m512d s_512 = _mm512_set1_pd(s); + + tail_index_8 = n1 & (~7); + tail_index_32 = n1 & (~31); + + + __m512d x0, x1, x2, x3; + __m512d y0, y1, y2, y3; + __m512d t0, t1, t2, t3; + + for (i = 0; i < tail_index_32; i += 32) { + x0 = _mm512_loadu_pd(&x[i + 0]); + x1 = _mm512_loadu_pd(&x[i + 8]); + x2 = _mm512_loadu_pd(&x[i +16]); + x3 = _mm512_loadu_pd(&x[i +24]); + y0 = _mm512_loadu_pd(&y[i + 0]); + y1 = _mm512_loadu_pd(&y[i + 8]); + y2 = _mm512_loadu_pd(&y[i +16]); + y3 = _mm512_loadu_pd(&y[i +24]); + + t0 = _mm512_mul_pd(s_512, y0); + t1 = _mm512_mul_pd(s_512, y1); + t2 = _mm512_mul_pd(s_512, y2); + t3 = _mm512_mul_pd(s_512, y3); + + t0 = _mm512_fmadd_pd(c_512, x0, t0); + t1 = _mm512_fmadd_pd(c_512, x1, t1); + t2 = _mm512_fmadd_pd(c_512, x2, t2); + t3 = _mm512_fmadd_pd(c_512, x3, t3); + + _mm512_storeu_pd(&x[i + 0], t0); + _mm512_storeu_pd(&x[i + 8], t1); + _mm512_storeu_pd(&x[i +16], t2); + _mm512_storeu_pd(&x[i +24], t3); + + t0 = _mm512_mul_pd(s_512, x0); + t1 = _mm512_mul_pd(s_512, x1); + t2 = _mm512_mul_pd(s_512, x2); + t3 = _mm512_mul_pd(s_512, x3); + + t0 = _mm512_fmsub_pd(c_512, y0, t0); + t1 = _mm512_fmsub_pd(c_512, y1, t1); + t2 = _mm512_fmsub_pd(c_512, y2, t2); + t3 = _mm512_fmsub_pd(c_512, y3, t3); + + _mm512_storeu_pd(&y[i + 0], t0); + _mm512_storeu_pd(&y[i + 8], t1); + _mm512_storeu_pd(&y[i +16], t2); + _mm512_storeu_pd(&y[i +24], t3); + } + + for (i = tail_index_32; i < tail_index_8; i += 8) { + x0 = _mm512_loadu_pd(&x[i]); + y0 = _mm512_loadu_pd(&y[i]); + + t0 = _mm512_mul_pd(s_512, y0); + t0 = _mm512_fmadd_pd(c_512, x0, t0); + _mm512_storeu_pd(&x[i], t0); + + t0 = _mm512_mul_pd(s_512, x0); + t0 = _mm512_fmsub_pd(c_512, y0, t0); + _mm512_storeu_pd(&y[i], t0); + } + + if ((n1&7) > 0) { + unsigned char tail_mask8 = (((unsigned char) 0xff) >> (8 -(n1&7))); + __m512d tail_x = _mm512_maskz_loadu_pd(*((__mmask8*) &tail_mask8), &x[tail_index_8]); + __m512d tail_y = _mm512_maskz_loadu_pd(*((__mmask8*) &tail_mask8), &y[tail_index_8]); + __m512d temp = _mm512_mul_pd(s_512, tail_y); + temp = _mm512_fmadd_pd(c_512, tail_x, temp); + _mm512_mask_storeu_pd(&x[tail_index_8],*((__mmask8*)&tail_mask8), temp); + temp = _mm512_mul_pd(s_512, tail_x); + temp = _mm512_fmsub_pd(c_512, tail_y, temp); + _mm512_mask_storeu_pd(&y[tail_index_8], *((__mmask8*)&tail_mask8), temp); + } +} +#endif diff --git a/kernel/x86_64/srot.c b/kernel/x86_64/srot.c new file mode 100644 index 000000000..021c20d82 --- /dev/null +++ b/kernel/x86_64/srot.c @@ -0,0 +1,139 @@ +#include "common.h" + +#if defined(SKYLAKEX) +#include "srot_microk_skylakex-2.c" +#elif defined(HASWELL) +#include "srot_microk_haswell-2.c" +#endif + +#ifndef HAVE_SROT_KERNEL + +static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s) +{ + BLASLONG i = 0; + FLOAT f0, f1, f2, f3; + FLOAT x0, x1, x2, x3; + FLOAT g0, g1, g2, g3; + FLOAT y0, y1, y2, y3; + + FLOAT* xp = x; + FLOAT* yp = y; + + BLASLONG n1 = n & (~7); + + while (i < n1) { + x0 = xp[0]; + y0 = yp[0]; + x1 = xp[1]; + y1 = yp[1]; + x2 = xp[2]; + y2 = yp[2]; + x3 = xp[3]; + y3 = yp[3]; + + f0 = c*x0 + s*y0; + g0 = c*y0 - s*x0; + f1 = c*x1 + s*y1; + g1 = c*y1 - s*x1; + f2 = c*x2 + s*y2; + g2 = c*y2 - s*x2; + f3 = c*x3 + s*y3; + g3 = c*y3 - s*x3; + + xp[0] = f0; + yp[0] = g0; + xp[1] = f1; + yp[1] = g1; + xp[2] = f2; + yp[2] = g2; + xp[3] = f3; + yp[3] = g3; + + xp += 4; + yp += 4; + i += 4; + } + + while (i < n) { + FLOAT temp = c*x[i] + s*y[i]; + y[i] = c*y[i] - s*x[i]; + x[i] = temp; + + i++; + } +} + +#endif +static void rot_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT c, FLOAT s) +{ + BLASLONG i = 0; + BLASLONG ix = 0, iy = 0; + + FLOAT temp; + + if (n <= 0) + return; + if ((inc_x == 1) && (inc_y == 1)) { + srot_kernel(n, x, y, c, s); + } + else { + while (i < n) { + temp = c * x[ix] + s * y[iy]; + y[iy] = c * y[iy] - s * x[ix]; + x[ix] = temp; + + ix += inc_x; + iy += inc_y; + i++; + } + } + return; +} + + +#if defined(SMP) +static int rot_thread_function(blas_arg_t *args) +{ + + rot_compute(args->m, + args->a, args->lda, + args->b, args->ldb, + ((float *)args->alpha)[0], + ((float *)args->alpha)[1]); + return 0; +} + +extern int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha, void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc, int (*function)(), int nthreads); +#endif +int CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT c, FLOAT s) +{ +#if defined(SMP) + int nthreads; + FLOAT alpha[2]={c, s}; + FLOAT dummy_c; +#endif + +#if defined(SMP) + if (inc_x == 0 || inc_y == 0 || n <= 100000) { + nthreads = 1; + } + else { + nthreads = num_cpu_avail(1); + } + + if (nthreads == 1) { + rot_compute(n, x, inc_x, y, inc_y, c, s); + } + else { +#if defined(DOUBLE) + int mode = BLAS_DOUBLE | BLAS_REAL | BLAS_PTHREAD; +#else + int mode = BLAS_SINGLE | BLAS_REAL | BLAS_PTHREAD; +#endif + blas_level1_thread(mode, n, 0, 0, alpha, x, inc_x, y, inc_y, &dummy_c, 0, (void *)rot_thread_function, nthreads); + } +#else + rot_compute(n, x, inc_x, y, inc_y, c, s); +#endif + return 0; +} diff --git a/kernel/x86_64/srot_microk_haswell-2.c b/kernel/x86_64/srot_microk_haswell-2.c new file mode 100644 index 000000000..cba962042 --- /dev/null +++ b/kernel/x86_64/srot_microk_haswell-2.c @@ -0,0 +1,87 @@ +/* need a new enough GCC for avx512 support */ +#if (( defined(__GNUC__) && __GNUC__ > 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 9)) + +#define HAVE_SROT_KERNEL 1 + +#include +#include + +static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s) +{ + BLASLONG i = 0; + + BLASLONG tail_index_8 = n&(~7); + BLASLONG tail_index_32 = n&(~31); + + __m256 c_256, s_256; + if (n >= 8) { + c_256 = _mm256_set1_ps(c); + s_256 = _mm256_set1_ps(s); + } + + __m256 x0, x1, x2, x3; + __m256 y0, y1, y2, y3; + __m256 t0, t1, t2, t3; + + for (i = 0; i < tail_index_32; i += 32) { + x0 = _mm256_loadu_ps(&x[i + 0]); + x1 = _mm256_loadu_ps(&x[i + 8]); + x2 = _mm256_loadu_ps(&x[i +16]); + x3 = _mm256_loadu_ps(&x[i +24]); + y0 = _mm256_loadu_ps(&y[i + 0]); + y1 = _mm256_loadu_ps(&y[i + 8]); + y2 = _mm256_loadu_ps(&y[i +16]); + y3 = _mm256_loadu_ps(&y[i +24]); + + t0 = _mm256_mul_ps(s_256, y0); + t1 = _mm256_mul_ps(s_256, y1); + t2 = _mm256_mul_ps(s_256, y2); + t3 = _mm256_mul_ps(s_256, y3); + + t0 = _mm256_fmadd_ps(c_256, x0, t0); + t1 = _mm256_fmadd_ps(c_256, x1, t1); + t2 = _mm256_fmadd_ps(c_256, x2, t2); + t3 = _mm256_fmadd_ps(c_256, x3, t3); + + _mm256_storeu_ps(&x[i + 0], t0); + _mm256_storeu_ps(&x[i + 8], t1); + _mm256_storeu_ps(&x[i +16], t2); + _mm256_storeu_ps(&x[i +24], t3); + + t0 = _mm256_mul_ps(s_256, x0); + t1 = _mm256_mul_ps(s_256, x1); + t2 = _mm256_mul_ps(s_256, x2); + t3 = _mm256_mul_ps(s_256, x3); + + t0 = _mm256_fmsub_ps(c_256, y0, t0); + t1 = _mm256_fmsub_ps(c_256, y1, t1); + t2 = _mm256_fmsub_ps(c_256, y2, t2); + t3 = _mm256_fmsub_ps(c_256, y3, t3); + + _mm256_storeu_ps(&y[i + 0], t0); + _mm256_storeu_ps(&y[i + 8], t1); + _mm256_storeu_ps(&y[i +16], t2); + _mm256_storeu_ps(&y[i +24], t3); + + } + + for (i = tail_index_32; i < tail_index_8; i += 8) { + x0 = _mm256_loadu_ps(&x[i]); + y0 = _mm256_loadu_ps(&y[i]); + + t0 = _mm256_mul_ps(s_256, y0); + t0 = _mm256_fmadd_ps(c_256, s0, t0); + _mm256_storeu_ps(&x[i], t0); + + t0 = _mm256_mul_ps(s_256, x0); + t0 = _mm256_fmsub_ps(c_256, y0, t0); + _mm256_storeu_ps(&y[i], t0); + } + + for (i = tail_index_8; i < n; ++i) { + FLOAT temp = c * x[i] + s * y[i]; + y[i] = c * y[i] - s * x[i]; + x[i] = temp; + } +} +#endif diff --git a/kernel/x86_64/srot_microk_skylakex-2.c b/kernel/x86_64/srot_microk_skylakex-2.c new file mode 100644 index 000000000..a21d1cf64 --- /dev/null +++ b/kernel/x86_64/srot_microk_skylakex-2.c @@ -0,0 +1,91 @@ +/* need a new enough GCC for avx512 support */ +#if (( defined(__GNUC__) && __GNUC__ > 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 9)) + +#define HAVE_SROT_KERNEL 1 + +#include +#include + +static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s) +{ + BLASLONG i = 0; + __m512 c_512, s_512; + c_512 = _mm512_set1_ps(c); + s_512 = _mm512_set1_ps(s); + + BLASLONG tail_index_16 = n&(~15); + BLASLONG tail_index_64 = n&(~63); + + + __m512 x0, x1, x2, x3; + __m512 y0, y1, y2, y3; + __m512 t0, t1, t2, t3; + + for (i = 0; i < tail_index_64; i += 64) { + x0 = _mm512_loadu_ps(&x[i + 0]); + x1 = _mm512_loadu_ps(&x[i +16]); + x2 = _mm512_loadu_ps(&x[i +32]); + x3 = _mm512_loadu_ps(&x[i +48]); + y0 = _mm512_loadu_ps(&y[i + 0]); + y1 = _mm512_loadu_ps(&y[i +16]); + y2 = _mm512_loadu_ps(&y[i +32]); + y3 = _mm512_loadu_ps(&y[i +48]); + + t0 = _mm512_mul_ps(s_512, y0); + t1 = _mm512_mul_ps(s_512, y1); + t2 = _mm512_mul_ps(s_512, y2); + t3 = _mm512_mul_ps(s_512, y3); + + t0 = _mm512_fmadd_ps(c_512, x0, t0); + t1 = _mm512_fmadd_ps(c_512, x1, t1); + t2 = _mm512_fmadd_ps(c_512, x2, t2); + t3 = _mm512_fmadd_ps(c_512, x3, t3); + + _mm512_storeu_ps(&x[i + 0], t0); + _mm512_storeu_ps(&x[i +16], t1); + _mm512_storeu_ps(&x[i +32], t2); + _mm512_storeu_ps(&x[i +48], t3); + + t0 = _mm512_mul_ps(s_512, x0); + t1 = _mm512_mul_ps(s_512, x1); + t2 = _mm512_mul_ps(s_512, x2); + t3 = _mm512_mul_ps(s_512, x3); + + t0 = _mm512_fmsub_ps(c_512, y0, t0); + t1 = _mm512_fmsub_ps(c_512, y1, t1); + t2 = _mm512_fmsub_ps(c_512, y2, t2); + t3 = _mm512_fmsub_ps(c_512, y3, t3); + + _mm512_storeu_ps(&y[i + 0], t0); + _mm512_storeu_ps(&y[i +16], t1); + _mm512_storeu_ps(&y[i +32], t2); + _mm512_storeu_ps(&y[i +48], t3); + } + + for (i = tail_index_64; i < tail_index_16; i += 16) { + x0 = _mm512_loadu_ps(&x[i]); + y0 = _mm512_loadu_ps(&y[i]); + + t0 = _mm512_mul_ps(s_512, y0); + t0 = _mm512_fmadd_ps(c_512, x0, t0); + _mm512_storeu_ps(&x[i], t0); + + t0 = _mm512_mul_ps(s_512, x0); + t0 = _mm512_fmsub_ps(c_512, y0, t0); + _mm512_storeu_ps(&y[i], t0); + } + + + if ((n & 15) > 0) { + uint16_t tail_mask16 = (((uint16_t) 0xffff) >> (16-(n&15))); + __m512 tail_x = _mm512_maskz_loadu_ps(*((__mmask16*)&tail_mask16), &x[tail_index_16]); + __m512 tail_y = _mm512_maskz_loadu_ps(*((__mmask16*)&tail_mask16), &y[tail_index_16]); + __m512 temp = _mm512_mul_ps(s_512, tail_y); + temp = _mm512_fmadd_ps(c_512, tail_x, temp); + _mm512_mask_storeu_ps(&x[tail_index_16], *((__mmask16*)&tail_mask16), temp); + temp = _mm512_mul_ps(s_512, tail_x); + temp = _mm512_fmsub_ps(c_512, tail_y, temp); + _mm512_mask_storeu_ps(&y[tail_index_16], *((__mmask16*)&tail_mask16), temp); + } +} +#endif