Add AVX512 support to the dgemv_n_microk_haswell-4.c kernel

Now that the kernel is written in C-with-intrinsics, adding
AVX512 support to this kernel is trivial and yields a pretty significant
performance increase
This commit is contained in:
Arjan van de Ven 2018-08-04 20:48:59 +00:00
parent e52d01cfe7
commit df31ec064e
1 changed files with 39 additions and 1 deletions

View File

@ -25,7 +25,12 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/ *****************************************************************************/
/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already*/
#ifndef __AVX512CD_
#pragma GCC target("avx2,fma")
#endif
#ifdef __AVX2__
#define HAVE_KERNEL_4x4 1 #define HAVE_KERNEL_4x4 1
@ -46,8 +51,39 @@ static void dgemv_kernel_4x4( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT
__alpha = _mm256_broadcastsd_pd(_mm_load_sd(alpha)); __alpha = _mm256_broadcastsd_pd(_mm_load_sd(alpha));
#ifdef __AVX512CD__
int n5;
__m512d x05, x15, x25, x35;
__m512d __alpha5;
n5 = n & ~7;
for (i = 0; i < n; i+= 4) { x05 = _mm512_broadcastsd_pd(_mm_load_sd(&x[0]));
x15 = _mm512_broadcastsd_pd(_mm_load_sd(&x[1]));
x25 = _mm512_broadcastsd_pd(_mm_load_sd(&x[2]));
x35 = _mm512_broadcastsd_pd(_mm_load_sd(&x[3]));
__alpha5 = _mm512_broadcastsd_pd(_mm_load_sd(alpha));
for (; i < n5; i+= 8) {
__m512d tempY;
__m512d sum;
sum = _mm512_add_pd(
_mm512_add_pd(
_mm512_mul_pd(_mm512_loadu_pd(&ap[0][i]), x05),
_mm512_mul_pd(_mm512_loadu_pd(&ap[1][i]), x15)),
_mm512_add_pd(
_mm512_mul_pd(_mm512_loadu_pd(&ap[2][i]), x25),
_mm512_mul_pd(_mm512_loadu_pd(&ap[3][i]), x35))
);
tempY = _mm512_loadu_pd(&y[i]);
tempY = _mm512_add_pd(tempY, _mm512_mul_pd(sum, __alpha5));
_mm512_storeu_pd(&y[i], tempY);
}
#endif
for (; i < n; i+= 4) {
__m256d tempY; __m256d tempY;
__m256d sum; __m256d sum;
@ -99,3 +135,5 @@ static void dgemv_kernel_4x2( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT
} }
} }
#endif /* AVX2 */