diff --git a/kernel/x86_64/dgemv_n_microk_haswell-4.c b/kernel/x86_64/dgemv_n_microk_haswell-4.c index 80879fdee..26c3d1ae8 100644 --- a/kernel/x86_64/dgemv_n_microk_haswell-4.c +++ b/kernel/x86_64/dgemv_n_microk_haswell-4.c @@ -25,7 +25,12 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *****************************************************************************/ +/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already; skipped when AVX512CD is enabled */ +#ifndef __AVX512CD__ +#pragma GCC target("avx2,fma") +#endif /* !__AVX512CD__ */ +#ifdef __AVX2__ #define HAVE_KERNEL_4x4 1 @@ -46,8 +51,39 @@ static void dgemv_kernel_4x4( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT __alpha = _mm256_broadcastsd_pd(_mm_load_sd(alpha)); +#ifdef __AVX512CD__ + int n5; + __m512d x05, x15, x25, x35; + __m512d __alpha5; + n5 = n & ~7; /* n rounded down to a multiple of 8 */ - for (i = 0; i < n; i+= 4) { + x05 = _mm512_broadcastsd_pd(_mm_load_sd(&x[0])); + x15 = _mm512_broadcastsd_pd(_mm_load_sd(&x[1])); + x25 = _mm512_broadcastsd_pd(_mm_load_sd(&x[2])); + x35 = _mm512_broadcastsd_pd(_mm_load_sd(&x[3])); + + __alpha5 = _mm512_broadcastsd_pd(_mm_load_sd(alpha)); + + for (; i < n5; i+= 8) { /* 8 doubles per 512-bit iteration */ + __m512d tempY; + __m512d sum; + + sum = _mm512_add_pd( + _mm512_add_pd( + _mm512_mul_pd(_mm512_loadu_pd(&ap[0][i]), x05), + _mm512_mul_pd(_mm512_loadu_pd(&ap[1][i]), x15)), + _mm512_add_pd( + _mm512_mul_pd(_mm512_loadu_pd(&ap[2][i]), x25), + _mm512_mul_pd(_mm512_loadu_pd(&ap[3][i]), x35)) + ); + + tempY = _mm512_loadu_pd(&y[i]); + tempY = _mm512_add_pd(tempY, _mm512_mul_pd(sum, __alpha5)); + _mm512_storeu_pd(&y[i], tempY); + } +#endif /* __AVX512CD__ */ + + for (; i < n; i+= 4) { /* elements not handled above, 4 doubles per iteration */ __m256d tempY; __m256d sum; @@ -99,3 +135,5 @@ static void dgemv_kernel_4x2( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT } } + +#endif /* AVX2 */