From 2463938879a93ff8b8207b112d03bbeb4cbabae2 Mon Sep 17 00:00:00 2001 From: Guillaume Horel Date: Wed, 11 Sep 2019 10:33:35 -0400 Subject: [PATCH 01/12] fix error message --- interface/lapack/gesv.c | 14 +++++++------- interface/lapack/getf2.c | 2 +- interface/lapack/getrf.c | 2 +- interface/lapack/getrs.c | 2 +- interface/lapack/lauu2.c | 2 +- interface/lapack/lauum.c | 2 +- interface/lapack/potf2.c | 2 +- interface/lapack/potrf.c | 2 +- interface/lapack/potri.c | 2 +- interface/lapack/trti2.c | 2 +- interface/lapack/trtri.c | 2 +- interface/lapack/zgetf2.c | 2 +- interface/lapack/zgetrf.c | 2 +- interface/lapack/zgetrs.c | 2 +- interface/lapack/zlauu2.c | 2 +- interface/lapack/zpotf2.c | 2 +- interface/lapack/zpotrf.c | 2 +- interface/lapack/zpotri.c | 2 +- interface/lapack/ztrti2.c | 2 +- interface/lapack/ztrtri.c | 2 +- 20 files changed, 26 insertions(+), 26 deletions(-) diff --git a/interface/lapack/gesv.c b/interface/lapack/gesv.c index 721da970d..175350329 100644 --- a/interface/lapack/gesv.c +++ b/interface/lapack/gesv.c @@ -44,19 +44,19 @@ #ifndef COMPLEX #ifdef XDOUBLE -#define ERROR_NAME "QGESV " +#define ERROR_NAME "QGESV" #elif defined(DOUBLE) -#define ERROR_NAME "DGESV " +#define ERROR_NAME "DGESV" #else -#define ERROR_NAME "SGESV " +#define ERROR_NAME "SGESV" #endif #else #ifdef XDOUBLE -#define ERROR_NAME "XGESV " +#define ERROR_NAME "XGESV" #elif defined(DOUBLE) -#define ERROR_NAME "ZGESV " +#define ERROR_NAME "ZGESV" #else -#define ERROR_NAME "CGESV " +#define ERROR_NAME "CGESV" #endif #endif @@ -89,7 +89,7 @@ int NAME(blasint *N, blasint *NRHS, FLOAT *a, blasint *ldA, blasint *ipiv, if (args.m < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/getf2.c b/interface/lapack/getf2.c index 3e66c0403..8506feca9 100644 --- a/interface/lapack/getf2.c +++ b/interface/lapack/getf2.c @@ -74,7 +74,7 @@ int NAME(blasint *M, blasint *N, FLOAT *a, blasint *ldA, blasint *ipiv, blasint if (args.n < 0) info = 2; if (args.m < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/getrf.c b/interface/lapack/getrf.c index 44a92ddc4..02bb124b3 100644 --- a/interface/lapack/getrf.c +++ b/interface/lapack/getrf.c @@ -74,7 +74,7 @@ int NAME(blasint *M, blasint *N, FLOAT *a, blasint *ldA, blasint *ipiv, blasint if (args.n < 0) info = 2; if (args.m < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/getrs.c b/interface/lapack/getrs.c index 1b8c83aca..c2a9eb882 100644 --- a/interface/lapack/getrs.c +++ b/interface/lapack/getrs.c @@ -102,7 +102,7 @@ int NAME(char *TRANS, blasint *N, blasint *NRHS, FLOAT *a, blasint *ldA, if (trans < 0) info = 1; if (info != 0) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); return 0; } diff --git a/interface/lapack/lauu2.c b/interface/lapack/lauu2.c index 3599a4791..e581e3c15 100644 --- a/interface/lapack/lauu2.c +++ b/interface/lapack/lauu2.c @@ -90,7 +90,7 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){ if (args.n < 0) info = 2; if (uplo < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/lauum.c b/interface/lapack/lauum.c index 2c49eb0b0..70f6a0ec5 100644 --- a/interface/lapack/lauum.c +++ b/interface/lapack/lauum.c @@ -90,7 +90,7 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){ if (args.n < 0) info = 2; if (uplo < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/potf2.c b/interface/lapack/potf2.c index 837192265..1537b6ee4 100644 --- a/interface/lapack/potf2.c +++ b/interface/lapack/potf2.c @@ -90,7 +90,7 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){ if (args.n < 0) info = 2; if (uplo < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/potrf.c b/interface/lapack/potrf.c index 092272225..dbd55f62f 100644 --- a/interface/lapack/potrf.c +++ b/interface/lapack/potrf.c @@ -90,7 +90,7 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){ if (args.n < 0) info = 2; if (uplo < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/potri.c b/interface/lapack/potri.c index d6230621f..2c0c64b6f 100644 --- a/interface/lapack/potri.c +++ b/interface/lapack/potri.c @@ -99,7 +99,7 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){ if (uplo < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/trti2.c b/interface/lapack/trti2.c index 42c4c4815..47f04f06f 100644 --- a/interface/lapack/trti2.c +++ b/interface/lapack/trti2.c @@ -96,7 +96,7 @@ int NAME(char *UPLO, char *DIAG, blasint *N, FLOAT *a, blasint *ldA, blasint *In if (diag < 0) info = 2; if (uplo < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/trtri.c b/interface/lapack/trtri.c index 6724a678a..028529389 100644 --- a/interface/lapack/trtri.c +++ b/interface/lapack/trtri.c @@ -99,7 +99,7 @@ int NAME(char *UPLO, char *DIAG, blasint *N, FLOAT *a, blasint *ldA, blasint *In if (diag < 0) info = 2; if (uplo < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/zgetf2.c b/interface/lapack/zgetf2.c index 59ec4874e..68b9a7e4b 100644 --- a/interface/lapack/zgetf2.c +++ b/interface/lapack/zgetf2.c @@ -74,7 +74,7 @@ int NAME(blasint *M, blasint *N, FLOAT *a, blasint *ldA, blasint *ipiv, blasint if (args.n < 0) info = 2; if (args.m < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/zgetrf.c b/interface/lapack/zgetrf.c index 5031f587b..7f8db94f6 100644 --- a/interface/lapack/zgetrf.c +++ b/interface/lapack/zgetrf.c @@ -74,7 +74,7 @@ int NAME(blasint *M, blasint *N, FLOAT *a, blasint *ldA, blasint *ipiv, blasint if (args.n < 0) info = 2; if (args.m < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/zgetrs.c b/interface/lapack/zgetrs.c index 54d4b0905..0add909ca 100644 --- a/interface/lapack/zgetrs.c +++ b/interface/lapack/zgetrs.c @@ -102,7 +102,7 @@ int NAME(char *TRANS, blasint *N, blasint *NRHS, FLOAT *a, blasint *ldA, if (trans < 0) info = 1; if (info != 0) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); return 0; } diff --git a/interface/lapack/zlauu2.c b/interface/lapack/zlauu2.c index b0698ef2e..ae972543c 100644 --- a/interface/lapack/zlauu2.c +++ b/interface/lapack/zlauu2.c @@ -91,7 +91,7 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){ if (args.n < 0) info = 2; if (uplo < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/zpotf2.c b/interface/lapack/zpotf2.c index 27ee0891a..c74b66728 100644 --- a/interface/lapack/zpotf2.c +++ b/interface/lapack/zpotf2.c @@ -91,7 +91,7 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){ if (args.n < 0) info = 2; if (uplo < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/zpotrf.c b/interface/lapack/zpotrf.c index 8cd3980d5..c4cd99bf6 100644 --- a/interface/lapack/zpotrf.c +++ b/interface/lapack/zpotrf.c @@ -90,7 +90,7 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){ if (args.n < 0) info = 2; if (uplo < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/zpotri.c b/interface/lapack/zpotri.c index 7c72a7e62..8da211683 100644 --- a/interface/lapack/zpotri.c +++ b/interface/lapack/zpotri.c @@ -99,7 +99,7 @@ int NAME(char *UPLO, blasint *N, FLOAT *a, blasint *ldA, blasint *Info){ if (uplo < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/ztrti2.c b/interface/lapack/ztrti2.c index a25476677..cb9c0d557 100644 --- a/interface/lapack/ztrti2.c +++ b/interface/lapack/ztrti2.c @@ -96,7 +96,7 @@ int NAME(char *UPLO, char *DIAG, blasint *N, FLOAT *a, blasint *ldA, blasint *In if (diag < 0) info = 2; if (uplo < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } diff --git a/interface/lapack/ztrtri.c b/interface/lapack/ztrtri.c index b3ce85b9f..dda4a9e4b 100644 --- a/interface/lapack/ztrtri.c +++ b/interface/lapack/ztrtri.c @@ -96,7 +96,7 @@ int NAME(char *UPLO, char *DIAG, blasint *N, FLOAT *a, blasint *ldA, blasint *In if (diag < 0) info = 2; if (uplo < 0) info = 1; if (info) { - BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME) - 1); *Info = - info; return 0; } From daa4310db5c8abf82233810c240eacd640e78206 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sun, 12 Jan 2020 22:00:50 +0100 Subject: [PATCH 02/12] Install new lapack.h new file in LAPACK 3.9.0, split off from lapacke.h --- Makefile.install | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile.install b/Makefile.install index 8070b4729..e01d866c9 100644 --- a/Makefile.install +++ b/Makefile.install @@ -51,6 +51,7 @@ endif ifneq ($(OSNAME), AIX) ifndef NO_LAPACKE @echo Copying LAPACKE header files to $(DESTDIR)$(OPENBLAS_INCLUDE_DIR) + @-install -pm644 $(NETLIB_LAPACK_DIR)/LAPACKE/include/lapack.h "$(DESTDIR)$(OPENBLAS_INCLUDE_DIR)/lapack.h" @-install -pm644 $(NETLIB_LAPACK_DIR)/LAPACKE/include/lapacke.h "$(DESTDIR)$(OPENBLAS_INCLUDE_DIR)/lapacke.h" @-install -pm644 $(NETLIB_LAPACK_DIR)/LAPACKE/include/lapacke_config.h "$(DESTDIR)$(OPENBLAS_INCLUDE_DIR)/lapacke_config.h" @-install -pm644 $(NETLIB_LAPACK_DIR)/LAPACKE/include/lapacke_mangling_with_flags.h.in "$(DESTDIR)$(OPENBLAS_INCLUDE_DIR)/lapacke_mangling.h" @@ -100,6 +101,7 @@ else #install on AIX has different options syntax ifndef NO_LAPACKE @echo Copying LAPACKE header files to $(DESTDIR)$(OPENBLAS_INCLUDE_DIR) + @-installbsd -c -m 644 $(NETLIB_LAPACK_DIR)/LAPACKE/include/lapack.h "$(DESTDIR)$(OPENBLAS_INCLUDE_DIR)/lapack.h" @-installbsd -c -m 644 $(NETLIB_LAPACK_DIR)/LAPACKE/include/lapacke.h "$(DESTDIR)$(OPENBLAS_INCLUDE_DIR)/lapacke.h" @-installbsd -c -m 644 $(NETLIB_LAPACK_DIR)/LAPACKE/include/lapacke_config.h "$(DESTDIR)$(OPENBLAS_INCLUDE_DIR)/lapacke_config.h" @-installbsd -c -m 644 $(NETLIB_LAPACK_DIR)/LAPACKE/include/lapacke_mangling_with_flags.h.in "$(DESTDIR)$(OPENBLAS_INCLUDE_DIR)/lapacke_mangling.h" From 1c675670081422b8a3d7f0998dfd7d1454c0d2bd Mon Sep 17 00:00:00 2001 From: wjc404 <52632443+wjc404@users.noreply.github.com> Date: Mon, 13 Jan 2020 16:26:03 +0800 Subject: [PATCH 03/12] improve skylakex paralleled sgemm performance --- param.h | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/param.h b/param.h index 70c5945ae..3baae31cf 100644 --- a/param.h +++ b/param.h @@ -1690,18 +1690,13 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #else -#define SGEMM_DEFAULT_P 768 +#define SGEMM_DEFAULT_P 640 #define DGEMM_DEFAULT_P 384 #define CGEMM_DEFAULT_P 384 #define ZGEMM_DEFAULT_P 256 -#ifdef WINDOWS_ABI -#define SGEMM_DEFAULT_Q 192 +#define SGEMM_DEFAULT_Q 320 #define DGEMM_DEFAULT_Q 168 -#else -#define SGEMM_DEFAULT_Q 192 -#define DGEMM_DEFAULT_Q 168 -#endif #define CGEMM_DEFAULT_Q 192 #define ZGEMM_DEFAULT_Q 128 From feaafbedd347871b3f25a018e6655fa9af6d141c Mon Sep 17 00:00:00 2001 From: wjc404 <52632443+wjc404@users.noreply.github.com> Date: Mon, 13 Jan 2020 16:28:41 +0800 Subject: [PATCH 04/12] make skylakex sgemm code more friendly for readers BTW some kernels were adjusted to improve performance --- kernel/x86_64/sgemm_direct_skylakex.c | 467 ++++++++++++ kernel/x86_64/sgemm_kernel_16x4_skylakex.c | 465 +----------- kernel/x86_64/sgemm_kernel_16x4_skylakex_2.c | 713 +++---------------- 3 files changed, 575 insertions(+), 1070 deletions(-) create mode 100644 kernel/x86_64/sgemm_direct_skylakex.c diff --git a/kernel/x86_64/sgemm_direct_skylakex.c b/kernel/x86_64/sgemm_direct_skylakex.c new file mode 100644 index 000000000..4f9af6e57 --- /dev/null +++ b/kernel/x86_64/sgemm_direct_skylakex.c @@ -0,0 +1,467 @@ + +/* the direct sgemm code written by Arjan van der Ven */ +#include + +/* + * "Direct sgemm" code. This code operates directly on the inputs and outputs + * of the sgemm call, avoiding the copies, memory realignments and threading, + * and only supports alpha = 1 and beta = 0. + * This is a common case and provides value for relatively small matrixes. + * For larger matrixes the "regular" sgemm code is superior, there the cost of + * copying/shuffling the B matrix really pays off. + */ + + + +#define DECLARE_RESULT_512(N,M) __m512 result##N##M = _mm512_setzero_ps() +#define BROADCAST_LOAD_A_512(N,M) __m512 Aval##M = _mm512_broadcastss_ps(_mm_load_ss(&A[k + strideA * (i+M)])) +#define LOAD_B_512(N,M) __m512 Bval##N = _mm512_loadu_ps(&B[strideB * k + j + (N*16)]) +#define MATMUL_512(N,M) result##N##M = _mm512_fmadd_ps(Aval##M, Bval##N , result##N##M) +#define STORE_512(N,M) _mm512_storeu_ps(&R[(i+M) * strideR + j+(N*16)], result##N##M) + + +#define DECLARE_RESULT_256(N,M) __m256 result##N##M = _mm256_setzero_ps() +#define BROADCAST_LOAD_A_256(N,M) __m256 Aval##M = _mm256_broadcastss_ps(_mm_load_ss(&A[k + strideA * (i+M)])) +#define LOAD_B_256(N,M) __m256 Bval##N = _mm256_loadu_ps(&B[strideB * k + j + (N*8)]) +#define MATMUL_256(N,M) result##N##M = _mm256_fmadd_ps(Aval##M, Bval##N , result##N##M) +#define STORE_256(N,M) _mm256_storeu_ps(&R[(i+M) * strideR + j+(N*8)], result##N##M) + +#define DECLARE_RESULT_128(N,M) __m128 result##N##M = _mm_setzero_ps() +#define BROADCAST_LOAD_A_128(N,M) __m128 Aval##M = _mm_broadcastss_ps(_mm_load_ss(&A[k + strideA * (i+M)])) +#define LOAD_B_128(N,M) __m128 Bval##N = _mm_loadu_ps(&B[strideB * k + j + (N*4)]) +#define MATMUL_128(N,M) result##N##M = _mm_fmadd_ps(Aval##M, Bval##N , result##N##M) +#define STORE_128(N,M) _mm_storeu_ps(&R[(i+M) * strideR + j+(N*4)], result##N##M) + +#define DECLARE_RESULT_SCALAR(N,M) float result##N##M = 0; +#define BROADCAST_LOAD_A_SCALAR(N,M) float Aval##M = A[k + strideA * (i + M)]; +#define LOAD_B_SCALAR(N,M) float Bval##N = B[k * strideB + j + N]; +#define MATMUL_SCALAR(N,M) result##N##M += Aval##M * Bval##N; +#define STORE_SCALAR(N,M) R[(i+M) * strideR + j + N] = result##N##M; + +int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K) +{ + unsigned long long mnk = M * N * K; + /* large matrixes -> not performant */ + if (mnk >= 28 * 512 * 512) + return 0; + + /* + * if the B matrix is not a nice multiple if 4 we get many unaligned accesses, + * and the regular sgemm copy/realignment of data pays off much quicker + */ + if ((N & 3) != 0 && (mnk >= 8 * 512 * 512)) + return 0; + +#ifdef SMP + /* if we can run multithreaded, the threading changes the based threshold */ + if (mnk > 2 * 350 * 512 && num_cpu_avail(3)> 1) + return 0; +#endif + + return 1; +} + + + +void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR) +{ + int i, j, k; + + int m4 = M & ~3; + int m2 = M & ~1; + + int n64 = N & ~63; + int n32 = N & ~31; + int n16 = N & ~15; + int n8 = N & ~7; + int n4 = N & ~3; + int n2 = N & ~1; + + i = 0; + + for (i = 0; i < m4; i+=4) { + + for (j = 0; j < n64; j+= 64) { + k = 0; + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(x, 0); + BROADCAST_LOAD_A_512(x, 1); + BROADCAST_LOAD_A_512(x, 2); + BROADCAST_LOAD_A_512(x, 3); + + LOAD_B_512(0, x); LOAD_B_512(1, x); LOAD_B_512(2, x); LOAD_B_512(3, x); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); + STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); + } + + for (; j < n32; j+= 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(x, 0); + BROADCAST_LOAD_A_512(x, 1); + BROADCAST_LOAD_A_512(x, 2); + BROADCAST_LOAD_A_512(x, 3); + + LOAD_B_512(0, x); LOAD_B_512(1, x); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + } + + for (; j < n16; j+= 16) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(x, 0); + BROADCAST_LOAD_A_512(x, 1); + BROADCAST_LOAD_A_512(x, 2); + BROADCAST_LOAD_A_512(x, 3); + + LOAD_B_512(0, x); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + } + + for (; j < n8; j+= 8) { + DECLARE_RESULT_256(0, 0); + DECLARE_RESULT_256(0, 1); + DECLARE_RESULT_256(0, 2); + DECLARE_RESULT_256(0, 3); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_256(x, 0); + BROADCAST_LOAD_A_256(x, 1); + BROADCAST_LOAD_A_256(x, 2); + BROADCAST_LOAD_A_256(x, 3); + + LOAD_B_256(0, x); + + MATMUL_256(0, 0); + MATMUL_256(0, 1); + MATMUL_256(0, 2); + MATMUL_256(0, 3); + } + STORE_256(0, 0); + STORE_256(0, 1); + STORE_256(0, 2); + STORE_256(0, 3); + } + + for (; j < n4; j+= 4) { + DECLARE_RESULT_128(0, 0); + DECLARE_RESULT_128(0, 1); + DECLARE_RESULT_128(0, 2); + DECLARE_RESULT_128(0, 3); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_128(x, 0); + BROADCAST_LOAD_A_128(x, 1); + BROADCAST_LOAD_A_128(x, 2); + BROADCAST_LOAD_A_128(x, 3); + + LOAD_B_128(0, x); + + MATMUL_128(0, 0); + MATMUL_128(0, 1); + MATMUL_128(0, 2); + MATMUL_128(0, 3); + } + STORE_128(0, 0); + STORE_128(0, 1); + STORE_128(0, 2); + STORE_128(0, 3); + } + + for (; j < n2; j+= 2) { + DECLARE_RESULT_SCALAR(0, 0); DECLARE_RESULT_SCALAR(1, 0); + DECLARE_RESULT_SCALAR(0, 1); DECLARE_RESULT_SCALAR(1, 1); + DECLARE_RESULT_SCALAR(0, 2); DECLARE_RESULT_SCALAR(1, 2); + DECLARE_RESULT_SCALAR(0, 3); DECLARE_RESULT_SCALAR(1, 3); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_SCALAR(x, 0); + BROADCAST_LOAD_A_SCALAR(x, 1); + BROADCAST_LOAD_A_SCALAR(x, 2); + BROADCAST_LOAD_A_SCALAR(x, 3); + + LOAD_B_SCALAR(0, x); LOAD_B_SCALAR(1, x); + + MATMUL_SCALAR(0, 0); MATMUL_SCALAR(1, 0); + MATMUL_SCALAR(0, 1); MATMUL_SCALAR(1, 1); + MATMUL_SCALAR(0, 2); MATMUL_SCALAR(1, 2); + MATMUL_SCALAR(0, 3); MATMUL_SCALAR(1, 3); + } + STORE_SCALAR(0, 0); STORE_SCALAR(1, 0); + STORE_SCALAR(0, 1); STORE_SCALAR(1, 1); + STORE_SCALAR(0, 2); STORE_SCALAR(1, 2); + STORE_SCALAR(0, 3); STORE_SCALAR(1, 3); + } + + for (; j < N; j++) { + DECLARE_RESULT_SCALAR(0, 0) + DECLARE_RESULT_SCALAR(0, 1) + DECLARE_RESULT_SCALAR(0, 2) + DECLARE_RESULT_SCALAR(0, 3) + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_SCALAR(0, 0); + BROADCAST_LOAD_A_SCALAR(0, 1); + BROADCAST_LOAD_A_SCALAR(0, 2); + BROADCAST_LOAD_A_SCALAR(0, 3); + + LOAD_B_SCALAR(0, 0); + + MATMUL_SCALAR(0, 0); + MATMUL_SCALAR(0, 1); + MATMUL_SCALAR(0, 2); + MATMUL_SCALAR(0, 3); + } + STORE_SCALAR(0, 0); + STORE_SCALAR(0, 1); + STORE_SCALAR(0, 2); + STORE_SCALAR(0, 3); + } + } + + for (; i < m2; i+=2) { + j = 0; + + for (; j < n64; j+= 64) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(x, 0); + BROADCAST_LOAD_A_512(x, 1); + + LOAD_B_512(0, x); LOAD_B_512(1, x); LOAD_B_512(2, x); LOAD_B_512(3, x); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + } + + for (; j < n32; j+= 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(x, 0); + BROADCAST_LOAD_A_512(x, 1); + + LOAD_B_512(0, x); LOAD_B_512(1, x); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + } + + + for (; j < n16; j+= 16) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(x, 0); + BROADCAST_LOAD_A_512(x, 1); + + LOAD_B_512(0, x); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_512(0, 0); + STORE_512(0, 1); + } + + for (; j < n8; j+= 8) { + DECLARE_RESULT_256(0, 0); + DECLARE_RESULT_256(0, 1); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_256(x, 0); + BROADCAST_LOAD_A_256(x, 1); + + LOAD_B_256(0, x); + + MATMUL_256(0, 0); + MATMUL_256(0, 1); + } + STORE_256(0, 0); + STORE_256(0, 1); + } + + for (; j < n4; j+= 4) { + DECLARE_RESULT_128(0, 0); + DECLARE_RESULT_128(0, 1); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_128(x, 0); + BROADCAST_LOAD_A_128(x, 1); + + LOAD_B_128(0, x); + + MATMUL_128(0, 0); + MATMUL_128(0, 1); + } + STORE_128(0, 0); + STORE_128(0, 1); + } + for (; j < n2; j+= 2) { + DECLARE_RESULT_SCALAR(0, 0); DECLARE_RESULT_SCALAR(1, 0); + DECLARE_RESULT_SCALAR(0, 1); DECLARE_RESULT_SCALAR(1, 1); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_SCALAR(x, 0); + BROADCAST_LOAD_A_SCALAR(x, 1); + + LOAD_B_SCALAR(0, x); LOAD_B_SCALAR(1, x); + + MATMUL_SCALAR(0, 0); MATMUL_SCALAR(1, 0); + MATMUL_SCALAR(0, 1); MATMUL_SCALAR(1, 1); + } + STORE_SCALAR(0, 0); STORE_SCALAR(1, 0); + STORE_SCALAR(0, 1); STORE_SCALAR(1, 1); + } + + for (; j < N; j++) { + DECLARE_RESULT_SCALAR(0, 0); + DECLARE_RESULT_SCALAR(0, 1); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_SCALAR(0, 0); + BROADCAST_LOAD_A_SCALAR(0, 1); + + LOAD_B_SCALAR(0, 0); + + MATMUL_SCALAR(0, 0); + MATMUL_SCALAR(0, 1); + } + STORE_SCALAR(0, 0); + STORE_SCALAR(0, 1); + } + } + + for (; i < M; i+=1) { + j = 0; + for (; j < n64; j+= 64) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(x, 0); + LOAD_B_512(0, x); LOAD_B_512(1, x); LOAD_B_512(2, x); LOAD_B_512(3, x); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + } + for (; j < n32; j+= 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(x, 0); + LOAD_B_512(0, x); LOAD_B_512(1, x); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_512(0, 0); STORE_512(1, 0); + } + + + for (; j < n16; j+= 16) { + DECLARE_RESULT_512(0, 0); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(x, 0); + + LOAD_B_512(0, x); + + MATMUL_512(0, 0); + } + STORE_512(0, 0); + } + + for (; j < n8; j+= 8) { + DECLARE_RESULT_256(0, 0); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_256(x, 0); + LOAD_B_256(0, x); + MATMUL_256(0, 0); + } + STORE_256(0, 0); + } + + for (; j < n4; j+= 4) { + DECLARE_RESULT_128(0, 0); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_128(x, 0); + LOAD_B_128(0, x); + MATMUL_128(0, 0); + } + STORE_128(0, 0); + } + + for (; j < n2; j+= 2) { + DECLARE_RESULT_SCALAR(0, 0); DECLARE_RESULT_SCALAR(1, 0); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_SCALAR(x, 0); + LOAD_B_SCALAR(0, 0); LOAD_B_SCALAR(1, 0); + MATMUL_SCALAR(0, 0); MATMUL_SCALAR(1, 0); + } + STORE_SCALAR(0, 0); STORE_SCALAR(1, 0); + } + + for (; j < N; j++) { + DECLARE_RESULT_SCALAR(0, 0); + + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_SCALAR(0, 0); + LOAD_B_SCALAR(0, 0); + MATMUL_SCALAR(0, 0); + } + STORE_SCALAR(0, 0); + } + } +} diff --git a/kernel/x86_64/sgemm_kernel_16x4_skylakex.c b/kernel/x86_64/sgemm_kernel_16x4_skylakex.c index 76b82e65b..d174bbcc3 100644 --- a/kernel/x86_64/sgemm_kernel_16x4_skylakex.c +++ b/kernel/x86_64/sgemm_kernel_16x4_skylakex.c @@ -1176,467 +1176,4 @@ CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float * __restrict A, flo return 0; } - -/* - * "Direct sgemm" code. This code operates directly on the inputs and outputs - * of the sgemm call, avoiding the copies, memory realignments and threading, - * and only supports alpha = 1 and beta = 0. - * This is a common case and provides value for relatively small matrixes. - * For larger matrixes the "regular" sgemm code is superior, there the cost of - * copying/shuffling the B matrix really pays off. - */ - - - -#define DECLARE_RESULT_512(N,M) __m512 result##N##M = _mm512_setzero_ps() -#define BROADCAST_LOAD_A_512(N,M) __m512 Aval##M = _mm512_broadcastss_ps(_mm_load_ss(&A[k + strideA * (i+M)])) -#define LOAD_B_512(N,M) __m512 Bval##N = _mm512_loadu_ps(&B[strideB * k + j + (N*16)]) -#define MATMUL_512(N,M) result##N##M = _mm512_fmadd_ps(Aval##M, Bval##N , result##N##M) -#define STORE_512(N,M) _mm512_storeu_ps(&R[(i+M) * strideR + j+(N*16)], result##N##M) - - -#define DECLARE_RESULT_256(N,M) __m256 result##N##M = _mm256_setzero_ps() -#define BROADCAST_LOAD_A_256(N,M) __m256 Aval##M = _mm256_broadcastss_ps(_mm_load_ss(&A[k + strideA * (i+M)])) -#define LOAD_B_256(N,M) __m256 Bval##N = _mm256_loadu_ps(&B[strideB * k + j + (N*8)]) -#define MATMUL_256(N,M) result##N##M = _mm256_fmadd_ps(Aval##M, Bval##N , result##N##M) -#define STORE_256(N,M) _mm256_storeu_ps(&R[(i+M) * strideR + j+(N*8)], result##N##M) - -#define DECLARE_RESULT_128(N,M) __m128 result##N##M = _mm_setzero_ps() -#define BROADCAST_LOAD_A_128(N,M) __m128 Aval##M = _mm_broadcastss_ps(_mm_load_ss(&A[k + strideA * (i+M)])) -#define LOAD_B_128(N,M) __m128 Bval##N = _mm_loadu_ps(&B[strideB * k + j + (N*4)]) -#define MATMUL_128(N,M) result##N##M = _mm_fmadd_ps(Aval##M, Bval##N , result##N##M) -#define STORE_128(N,M) _mm_storeu_ps(&R[(i+M) * strideR + j+(N*4)], result##N##M) - -#define DECLARE_RESULT_SCALAR(N,M) float result##N##M = 0; -#define BROADCAST_LOAD_A_SCALAR(N,M) float Aval##M = A[k + strideA * (i + M)]; -#define LOAD_B_SCALAR(N,M) float Bval##N = B[k * strideB + j + N]; -#define MATMUL_SCALAR(N,M) result##N##M += Aval##M * Bval##N; -#define STORE_SCALAR(N,M) R[(i+M) * strideR + j + N] = result##N##M; - -int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K) -{ - unsigned long long mnk = M * N * K; - /* large matrixes -> not performant */ - if (mnk >= 28 * 512 * 512) - return 0; - - /* - * if the B matrix is not a nice multiple if 4 we get many unaligned accesses, - * and the regular sgemm copy/realignment of data pays off much quicker - */ - if ((N & 3) != 0 && (mnk >= 8 * 512 * 512)) - return 0; - -#ifdef SMP - /* if we can run multithreaded, the threading changes the based threshold */ - if (mnk > 2 * 350 * 512 && num_cpu_avail(3)> 1) - return 0; -#endif - - return 1; -} - - - -void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR) -{ - int i, j, k; - - int m4 = M & ~3; - int m2 = M & ~1; - - int n64 = N & ~63; - int n32 = N & ~31; - int n16 = N & ~15; - int n8 = N & ~7; - int n4 = N & ~3; - int n2 = N & ~1; - - i = 0; - - for (i = 0; i < m4; i+=4) { - - for (j = 0; j < n64; j+= 64) { - k = 0; - DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); - DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); - DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); - DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); - - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - BROADCAST_LOAD_A_512(x, 1); - BROADCAST_LOAD_A_512(x, 2); - BROADCAST_LOAD_A_512(x, 3); - - LOAD_B_512(0, x); LOAD_B_512(1, x); LOAD_B_512(2, x); LOAD_B_512(3, x); - - MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); - MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); - MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); - MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); - } - STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); - STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); - STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); - STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); - } - - for (; j < n32; j+= 32) { - DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); - DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); - DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); - DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - BROADCAST_LOAD_A_512(x, 1); - BROADCAST_LOAD_A_512(x, 2); - BROADCAST_LOAD_A_512(x, 3); - - LOAD_B_512(0, x); LOAD_B_512(1, x); - - MATMUL_512(0, 0); MATMUL_512(1, 0); - MATMUL_512(0, 1); MATMUL_512(1, 1); - MATMUL_512(0, 2); MATMUL_512(1, 2); - MATMUL_512(0, 3); MATMUL_512(1, 3); - } - STORE_512(0, 0); STORE_512(1, 0); - STORE_512(0, 1); STORE_512(1, 1); - STORE_512(0, 2); STORE_512(1, 2); - STORE_512(0, 3); STORE_512(1, 3); - } - - for (; j < n16; j+= 16) { - DECLARE_RESULT_512(0, 0); - DECLARE_RESULT_512(0, 1); - DECLARE_RESULT_512(0, 2); - DECLARE_RESULT_512(0, 3); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - BROADCAST_LOAD_A_512(x, 1); - BROADCAST_LOAD_A_512(x, 2); - BROADCAST_LOAD_A_512(x, 3); - - LOAD_B_512(0, x); - - MATMUL_512(0, 0); - MATMUL_512(0, 1); - MATMUL_512(0, 2); - MATMUL_512(0, 3); - } - STORE_512(0, 0); - STORE_512(0, 1); - STORE_512(0, 2); - STORE_512(0, 3); - } - - for (; j < n8; j+= 8) { - DECLARE_RESULT_256(0, 0); - DECLARE_RESULT_256(0, 1); - DECLARE_RESULT_256(0, 2); - DECLARE_RESULT_256(0, 3); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_256(x, 0); - BROADCAST_LOAD_A_256(x, 1); - BROADCAST_LOAD_A_256(x, 2); - BROADCAST_LOAD_A_256(x, 3); - - LOAD_B_256(0, x); - - MATMUL_256(0, 0); - MATMUL_256(0, 1); - MATMUL_256(0, 2); - MATMUL_256(0, 3); - } - STORE_256(0, 0); - STORE_256(0, 1); - STORE_256(0, 2); - STORE_256(0, 3); - } - - for (; j < n4; j+= 4) { - DECLARE_RESULT_128(0, 0); - DECLARE_RESULT_128(0, 1); - DECLARE_RESULT_128(0, 2); - DECLARE_RESULT_128(0, 3); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_128(x, 0); - BROADCAST_LOAD_A_128(x, 1); - BROADCAST_LOAD_A_128(x, 2); - BROADCAST_LOAD_A_128(x, 3); - - LOAD_B_128(0, x); - - MATMUL_128(0, 0); - MATMUL_128(0, 1); - MATMUL_128(0, 2); - MATMUL_128(0, 3); - } - STORE_128(0, 0); - STORE_128(0, 1); - STORE_128(0, 2); - STORE_128(0, 3); - } - - for (; j < n2; j+= 2) { - DECLARE_RESULT_SCALAR(0, 0); DECLARE_RESULT_SCALAR(1, 0); - DECLARE_RESULT_SCALAR(0, 1); DECLARE_RESULT_SCALAR(1, 1); - DECLARE_RESULT_SCALAR(0, 2); DECLARE_RESULT_SCALAR(1, 2); - DECLARE_RESULT_SCALAR(0, 3); DECLARE_RESULT_SCALAR(1, 3); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_SCALAR(x, 0); - BROADCAST_LOAD_A_SCALAR(x, 1); - BROADCAST_LOAD_A_SCALAR(x, 2); - BROADCAST_LOAD_A_SCALAR(x, 3); - - LOAD_B_SCALAR(0, x); LOAD_B_SCALAR(1, x); - - MATMUL_SCALAR(0, 0); MATMUL_SCALAR(1, 0); - MATMUL_SCALAR(0, 1); MATMUL_SCALAR(1, 1); - MATMUL_SCALAR(0, 2); MATMUL_SCALAR(1, 2); - MATMUL_SCALAR(0, 3); MATMUL_SCALAR(1, 3); - } - STORE_SCALAR(0, 0); STORE_SCALAR(1, 0); - STORE_SCALAR(0, 1); STORE_SCALAR(1, 1); - STORE_SCALAR(0, 2); STORE_SCALAR(1, 2); - STORE_SCALAR(0, 3); STORE_SCALAR(1, 3); - } - - for (; j < N; j++) { - DECLARE_RESULT_SCALAR(0, 0) - DECLARE_RESULT_SCALAR(0, 1) - DECLARE_RESULT_SCALAR(0, 2) - DECLARE_RESULT_SCALAR(0, 3) - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_SCALAR(0, 0); - BROADCAST_LOAD_A_SCALAR(0, 1); - BROADCAST_LOAD_A_SCALAR(0, 2); - BROADCAST_LOAD_A_SCALAR(0, 3); - - LOAD_B_SCALAR(0, 0); - - MATMUL_SCALAR(0, 0); - MATMUL_SCALAR(0, 1); - MATMUL_SCALAR(0, 2); - MATMUL_SCALAR(0, 3); - } - STORE_SCALAR(0, 0); - STORE_SCALAR(0, 1); - STORE_SCALAR(0, 2); - STORE_SCALAR(0, 3); - } - } - - for (; i < m2; i+=2) { - j = 0; - - for (; j < n64; j+= 64) { - DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); - DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); - - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - BROADCAST_LOAD_A_512(x, 1); - - LOAD_B_512(0, x); LOAD_B_512(1, x); LOAD_B_512(2, x); LOAD_B_512(3, x); - - MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); - MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); - } - STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); - STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); - } - - for (; j < n32; j+= 32) { - DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); - DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - BROADCAST_LOAD_A_512(x, 1); - - LOAD_B_512(0, x); LOAD_B_512(1, x); - - MATMUL_512(0, 0); MATMUL_512(1, 0); - MATMUL_512(0, 1); MATMUL_512(1, 1); - } - STORE_512(0, 0); STORE_512(1, 0); - STORE_512(0, 1); STORE_512(1, 1); - } - - - for (; j < n16; j+= 16) { - DECLARE_RESULT_512(0, 0); - DECLARE_RESULT_512(0, 1); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - BROADCAST_LOAD_A_512(x, 1); - - LOAD_B_512(0, x); - - MATMUL_512(0, 0); - MATMUL_512(0, 1); - } - STORE_512(0, 0); - STORE_512(0, 1); - } - - for (; j < n8; j+= 8) { - DECLARE_RESULT_256(0, 0); - DECLARE_RESULT_256(0, 1); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_256(x, 0); - BROADCAST_LOAD_A_256(x, 1); - - LOAD_B_256(0, x); - - MATMUL_256(0, 0); - MATMUL_256(0, 1); - } - STORE_256(0, 0); - STORE_256(0, 1); - } - - for (; j < n4; j+= 4) { - DECLARE_RESULT_128(0, 0); - DECLARE_RESULT_128(0, 1); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_128(x, 0); - BROADCAST_LOAD_A_128(x, 1); - - LOAD_B_128(0, x); - - MATMUL_128(0, 0); - MATMUL_128(0, 1); - } - STORE_128(0, 0); - STORE_128(0, 1); - } - for (; j < n2; j+= 2) { - DECLARE_RESULT_SCALAR(0, 0); DECLARE_RESULT_SCALAR(1, 0); - DECLARE_RESULT_SCALAR(0, 1); DECLARE_RESULT_SCALAR(1, 1); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_SCALAR(x, 0); - BROADCAST_LOAD_A_SCALAR(x, 1); - - LOAD_B_SCALAR(0, x); LOAD_B_SCALAR(1, x); - - MATMUL_SCALAR(0, 0); MATMUL_SCALAR(1, 0); - MATMUL_SCALAR(0, 1); MATMUL_SCALAR(1, 1); - } - STORE_SCALAR(0, 0); STORE_SCALAR(1, 0); - STORE_SCALAR(0, 1); STORE_SCALAR(1, 1); - } - - for (; j < N; j++) { - DECLARE_RESULT_SCALAR(0, 0); - DECLARE_RESULT_SCALAR(0, 1); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_SCALAR(0, 0); - BROADCAST_LOAD_A_SCALAR(0, 1); - - LOAD_B_SCALAR(0, 0); - - MATMUL_SCALAR(0, 0); - MATMUL_SCALAR(0, 1); - } - STORE_SCALAR(0, 0); - STORE_SCALAR(0, 1); - } - } - - for (; i < M; i+=1) { - j = 0; - for (; j < n64; j+= 64) { - DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - LOAD_B_512(0, x); LOAD_B_512(1, x); LOAD_B_512(2, x); LOAD_B_512(3, x); - MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); - } - STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); - } - for (; j < n32; j+= 32) { - DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - LOAD_B_512(0, x); LOAD_B_512(1, x); - MATMUL_512(0, 0); MATMUL_512(1, 0); - } - STORE_512(0, 0); STORE_512(1, 0); - } - - - for (; j < n16; j+= 16) { - DECLARE_RESULT_512(0, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - - LOAD_B_512(0, x); - - MATMUL_512(0, 0); - } - STORE_512(0, 0); - } - - for (; j < n8; j+= 8) { - DECLARE_RESULT_256(0, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_256(x, 0); - LOAD_B_256(0, x); - MATMUL_256(0, 0); - } - STORE_256(0, 0); - } - - for (; j < n4; j+= 4) { - DECLARE_RESULT_128(0, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_128(x, 0); - LOAD_B_128(0, x); - MATMUL_128(0, 0); - } - STORE_128(0, 0); - } - - for (; j < n2; j+= 2) { - DECLARE_RESULT_SCALAR(0, 0); DECLARE_RESULT_SCALAR(1, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_SCALAR(x, 0); - LOAD_B_SCALAR(0, 0); LOAD_B_SCALAR(1, 0); - MATMUL_SCALAR(0, 0); MATMUL_SCALAR(1, 0); - } - STORE_SCALAR(0, 0); STORE_SCALAR(1, 0); - } - - for (; j < N; j++) { - DECLARE_RESULT_SCALAR(0, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_SCALAR(0, 0); - LOAD_B_SCALAR(0, 0); - MATMUL_SCALAR(0, 0); - } - STORE_SCALAR(0, 0); - } - } -} +#include "sgemm_direct_skylakex.c" diff --git a/kernel/x86_64/sgemm_kernel_16x4_skylakex_2.c b/kernel/x86_64/sgemm_kernel_16x4_skylakex_2.c index ee3417505..e4ca6b1bd 100644 --- a/kernel/x86_64/sgemm_kernel_16x4_skylakex_2.c +++ b/kernel/x86_64/sgemm_kernel_16x4_skylakex_2.c @@ -1,5 +1,5 @@ /* %0 = "+r"(a_pointer), %1 = "+r"(b_pointer), %2 = "+r"(c_pointer), %3 = "+r"(ldc_in_bytes), %4 for k_count, %5 for c_store */ -/* r12 = k << 4(const), r13 = k(const), r14 = b_head_pos(const), r15 = %1 + 3r12 */ +/* r10 to assist prefetch, r12 = k << 4(const), r13 = k(const), r14 = b_head_pos(const), r15 = %1 + 3r12 */ #include "common.h" #include @@ -53,26 +53,25 @@ #define SAVE_m16(ndim) SAVE_h_m16n##ndim "addq $64,%2;" #define COMPUTE_m16(ndim) \ INIT_m16n##ndim\ - "movq %%r13,%4; movq %%r14,%1; leaq (%1,%%r12,2),%%r15; addq %%r12,%%r15; movq %2,%5;"\ - "cmpq $18,%4; jb "#ndim"016162f;"\ + "movq %%r13,%4; movq %%r14,%1; leaq (%1,%%r12,2),%%r15; addq %%r12,%%r15; movq %2,%5; xorq %%r10,%%r10;"\ + "cmpq $16,%4; jb "#ndim"016162f;"\ #ndim"016161:\n\t"\ + "cmpq $126,%%r10; movq $126,%%r10; cmoveq %3,%%r10;"\ KERNEL_k1m16n##ndim\ KERNEL_k1m16n##ndim\ - KERNEL_k1m16n##ndim\ - "prefetcht1 (%5); prefetcht1 63(%5); addq %3,%5;"\ + "prefetcht1 (%5); subq $63,%5; addq %%r10,%5;"\ KERNEL_k1m16n##ndim\ KERNEL_k1m16n##ndim\ - KERNEL_k1m16n##ndim\ - "prefetcht1 (%8); addq $32,%8;"\ - "subq $6,%4; cmpq $18,%4; jnb "#ndim"016161b;"\ + "prefetcht1 (%6); addq $32,%6;"\ + "subq $4,%4; cmpq $16,%4; jnb "#ndim"016161b;"\ "movq %2,%5;"\ #ndim"016162:\n\t"\ - "testq %4,%4; jz "#ndim"016163f;"\ + "testq %4,%4; jz "#ndim"016164f;"\ + #ndim"016163:\n\t"\ "prefetcht0 (%5); prefetcht0 63(%5); prefetcht0 (%5,%3,1); prefetcht0 63(%5,%3,1);"\ KERNEL_k1m16n##ndim\ - "leaq (%5,%3,2),%5;"\ - "decq %4; jmp "#ndim"016162b;"\ - #ndim"016163:\n\t"\ + "leaq (%5,%3,2),%5; decq %4; jnz "#ndim"016163b;"\ + #ndim"016164:\n\t"\ "prefetcht0 (%%r14); prefetcht0 64(%%r14);"\ SAVE_m16(ndim) @@ -212,185 +211,152 @@ #define COMPUTE_m4_n24 COMPUTE_L_m4(12,55555) COMPUTE_R_m4(12,55955) #define COMPUTE_m4(ndim) COMPUTE_m4_n##ndim -/* m = 2 *//* xmm0 for alpha, xmm1-xmm3 and xmm10 for temporary use, xmm4-xmm9 for accumulators */ +/* m = 2 *//* xmm0 for alpha, xmm1-xmm3 for temporary use, xmm4-xmm15 for accumulators */ #define INIT_m2n1 "vpxor %%xmm4,%%xmm4,%%xmm4;" -#define KERNEL_k1m2n1(b_addr) \ +#define KERNEL_k1m2n1 \ "vmovsd (%0),%%xmm1; addq $8,%0;"\ - "vbroadcastss ("#b_addr"),%%xmm2; vfmadd231ps %%xmm1,%%xmm2,%%xmm4;"\ - "addq $4,"#b_addr";" -#define SAVE_L_m2n1 "vmovsd (%2),%%xmm1; vfmadd213ps %%xmm1,%%xmm0,%%xmm4; vmovsd %%xmm4,(%2);" + "vbroadcastss (%1),%%xmm2; vfmadd231ps %%xmm1,%%xmm2,%%xmm4;"\ + "addq $4,%1;" +#define SAVE_h_m2n1 "vmovsd (%2),%%xmm1; vfmadd213ps %%xmm1,%%xmm0,%%xmm4; vmovsd %%xmm4,(%2);" #define INIT_m2n2 INIT_m2n1 "vpxor %%xmm5,%%xmm5,%%xmm5;" -#define KERNEL_k1m2n2(b_addr) \ +#define KERNEL_k1m2n2 \ "vmovsd (%0),%%xmm1; addq $8,%0;"\ - "vbroadcastss ("#b_addr"),%%xmm2; vfmadd231ps %%xmm1,%%xmm2,%%xmm4;"\ - "vbroadcastss 4("#b_addr"),%%xmm3; vfmadd231ps %%xmm1,%%xmm3,%%xmm5;"\ - "addq $8,"#b_addr";" -#define SAVE_L_m2n2 SAVE_L_m2n1 "vmovsd (%2,%3,1),%%xmm1; vfmadd213ps %%xmm1,%%xmm0,%%xmm5; vmovsd %%xmm5,(%2,%3,1);" + "vbroadcastss (%1),%%xmm2; vfmadd231ps %%xmm1,%%xmm2,%%xmm4;"\ + "vbroadcastss 4(%1),%%xmm3; vfmadd231ps %%xmm1,%%xmm3,%%xmm5;"\ + "addq $8,%1;" +#define SAVE_h_m2n2 SAVE_h_m2n1 "vmovsd (%2,%3,1),%%xmm1; vfmadd213ps %%xmm1,%%xmm0,%%xmm5; vmovsd %%xmm5,(%2,%3,1);" #define INIT_m2n4 INIT_m2n2 #define INIT_m2n8 INIT_m2n4 "vpxor %%xmm6,%%xmm6,%%xmm6; vpxor %%xmm7,%%xmm7,%%xmm7;" #define INIT_m2n12 INIT_m2n8 "vpxor %%xmm8,%%xmm8,%%xmm8; vpxor %%xmm9,%%xmm9,%%xmm9;" -#define KERNEL_k1m2n4(b_addr) \ - "vmovups ("#b_addr"),%%xmm3; addq $16,"#b_addr";"\ - "vbroadcastss (%0),%%xmm1; vfmadd231ps %%xmm3,%%xmm1,%%xmm4;"\ - "vbroadcastss 4(%0),%%xmm2; vfmadd231ps %%xmm3,%%xmm2,%%xmm5;"\ - "addq $8,%0;" -#define KERNEL_k1m2n8(b_addr) \ - "vmovups ("#b_addr"),%%xmm3; vmovups ("#b_addr",%%r12,1),%%xmm2; addq $16,"#b_addr";"\ - "vbroadcastss (%0),%%xmm1; vfmadd231ps %%xmm3,%%xmm1,%%xmm4; vfmadd231ps %%xmm2,%%xmm1,%%xmm6;"\ - "vbroadcastss 4(%0),%%xmm1; vfmadd231ps %%xmm3,%%xmm1,%%xmm5; vfmadd231ps %%xmm2,%%xmm1,%%xmm7;"\ - "addq $8,%0;" -#define KERNEL_k1m2n12(b_addr) \ - "vmovups ("#b_addr"),%%xmm3; vmovups ("#b_addr",%%r12,1),%%xmm2; vmovups ("#b_addr",%%r12,2),%%xmm1; addq $16,"#b_addr";"\ - "vbroadcastss (%0),%%xmm10; vfmadd231ps %%xmm3,%%xmm10,%%xmm4; vfmadd231ps %%xmm2,%%xmm10,%%xmm6; vfmadd231ps %%xmm1,%%xmm10,%%xmm8;"\ - "vbroadcastss 4(%0),%%xmm10; vfmadd231ps %%xmm3,%%xmm10,%%xmm5; vfmadd231ps %%xmm2,%%xmm10,%%xmm7; vfmadd231ps %%xmm1,%%xmm10,%%xmm9;"\ - "addq $8,%0;" +#define INIT_m2n16 INIT_m2n12 "vpxor %%xmm10,%%xmm10,%%xmm10; vpxor %%xmm11,%%xmm11,%%xmm11;" +#define INIT_m2n20 INIT_m2n16 "vpxor %%xmm12,%%xmm12,%%xmm12; vpxor %%xmm13,%%xmm13,%%xmm13;" +#define INIT_m2n24 INIT_m2n20 "vpxor %%xmm14,%%xmm14,%%xmm14; vpxor %%xmm15,%%xmm15,%%xmm15;" +#define KERNEL_h_k1m2n4 \ + "vbroadcastss (%0),%%xmm1; vbroadcastss 4(%0),%%xmm2; addq $8,%0;"\ + "vmovups (%1),%%xmm3; vfmadd231ps %%xmm1,%%xmm3,%%xmm4; vfmadd231ps %%xmm2,%%xmm3,%%xmm5;" +#define KERNEL_k1m2n4 KERNEL_h_k1m2n4 "addq $16,%1;" +#define KERNEL_h_k1m2n8 KERNEL_h_k1m2n4 "vmovups (%1,%%r12,1),%%xmm3; vfmadd231ps %%xmm1,%%xmm3,%%xmm6; vfmadd231ps %%xmm2,%%xmm3,%%xmm7;" +#define KERNEL_k1m2n8 KERNEL_h_k1m2n8 "addq $16,%1;" +#define KERNEL_k1m2n12 KERNEL_h_k1m2n8 \ + "vmovups (%1,%%r12,2),%%xmm3; vfmadd231ps %%xmm1,%%xmm3,%%xmm8; vfmadd231ps %%xmm2,%%xmm3,%%xmm9; addq $16,%1;" +#define KERNEL_h_k1m2n16 KERNEL_k1m2n12 "vmovups (%%r15),%%xmm3; vfmadd231ps %%xmm1,%%xmm3,%%xmm10; vfmadd231ps %%xmm2,%%xmm3,%%xmm11;" +#define KERNEL_k1m2n16 KERNEL_h_k1m2n16 "addq $16,%%r15;" +#define KERNEL_h_k1m2n20 KERNEL_h_k1m2n16 "vmovups (%%r15,%%r12,1),%%xmm3; vfmadd231ps %%xmm1,%%xmm3,%%xmm12; vfmadd231ps %%xmm2,%%xmm3,%%xmm13;" +#define KERNEL_k1m2n20 KERNEL_h_k1m2n20 "addq $16,%%r15;" +#define KERNEL_h_k1m2n24 KERNEL_h_k1m2n20 "vmovups (%%r15,%%r12,2),%%xmm3; vfmadd231ps %%xmm1,%%xmm3,%%xmm14; vfmadd231ps %%xmm2,%%xmm3,%%xmm15;" +#define KERNEL_k1m2n24 KERNEL_h_k1m2n24 "addq $16,%%r15;" #define unit_save_m2n4(c1,c2) \ "vunpcklps "#c2","#c1",%%xmm1; vunpckhps "#c2","#c1",%%xmm2;"\ "vmovsd (%5),%%xmm3; vmovhpd (%5,%3,1),%%xmm3,%%xmm3; vfmadd213ps %%xmm3,%%xmm0,%%xmm1; vmovsd %%xmm1,(%5); vmovhpd %%xmm1,(%5,%3,1);"\ "leaq (%5,%3,2),%5;"\ "vmovsd (%5),%%xmm3; vmovhpd (%5,%3,1),%%xmm3,%%xmm3; vfmadd213ps %%xmm3,%%xmm0,%%xmm2; vmovsd %%xmm2,(%5); vmovhpd %%xmm2,(%5,%3,1);"\ "leaq (%5,%3,2),%5;" -#define SAVE_L_m2n4 "movq %2,%5;" unit_save_m2n4(%%xmm4,%%xmm5) -#define SAVE_L_m2n8 SAVE_L_m2n4 unit_save_m2n4(%%xmm6,%%xmm7) -#define SAVE_L_m2n12 SAVE_L_m2n8 unit_save_m2n4(%%xmm8,%%xmm9) -#define SAVE_R_m2n4 unit_save_m2n4(%%xmm4,%%xmm5) -#define SAVE_R_m2n8 SAVE_R_m2n4 unit_save_m2n4(%%xmm6,%%xmm7) -#define SAVE_R_m2n12 SAVE_R_m2n8 unit_save_m2n4(%%xmm8,%%xmm9) -#define COMPUTE_L_m2(ndim,sim) \ +#define SAVE_h_m2n4 "movq %2,%5;" unit_save_m2n4(%%xmm4,%%xmm5) +#define SAVE_h_m2n8 SAVE_h_m2n4 unit_save_m2n4(%%xmm6,%%xmm7) +#define SAVE_h_m2n12 SAVE_h_m2n8 unit_save_m2n4(%%xmm8,%%xmm9) +#define SAVE_h_m2n16 SAVE_h_m2n12 unit_save_m2n4(%%xmm10,%%xmm11) +#define SAVE_h_m2n20 SAVE_h_m2n16 unit_save_m2n4(%%xmm12,%%xmm13) +#define SAVE_h_m2n24 SAVE_h_m2n20 unit_save_m2n4(%%xmm14,%%xmm15) +#define SAVE_m2(ndim) SAVE_h_m2n##ndim "addq $8,%2;" +#define COMPUTE_m2(ndim) \ INIT_m2n##ndim\ - "movq %%r13,%4; movq %%r14,%1;"\ - #ndim""#sim"222:\n\t"\ - "testq %4,%4; jz "#ndim""#sim"223f;"\ - KERNEL_k1m2n##ndim(%1)\ - "decq %4; jmp "#ndim""#sim"222b;"\ - #ndim""#sim"223:\n\t"\ - SAVE_L_m2n##ndim "addq $8,%2;" -#define COMPUTE_R_m2(ndim,sim) \ - "salq $3,%%r13;subq %%r13,%0;sarq $3,%%r13;"\ - INIT_m2n##ndim\ - "movq %%r13,%4; leaq (%%r14,%%r12,2),%%r15; addq %%r12,%%r15;"\ - #ndim""#sim"222:\n\t"\ - "testq %4,%4; jz "#ndim""#sim"223f;"\ - KERNEL_k1m2n##ndim(%%r15)\ - "decq %4; jmp "#ndim""#sim"222b;"\ - #ndim""#sim"223:\n\t"\ - SAVE_R_m2n##ndim -#define COMPUTE_m2_n1 COMPUTE_L_m2(1,77877) -#define COMPUTE_m2_n2 COMPUTE_L_m2(2,77877) -#define COMPUTE_m2_n4 COMPUTE_L_m2(4,77877) -#define COMPUTE_m2_n8 COMPUTE_L_m2(8,77877) -#define COMPUTE_m2_n12 COMPUTE_L_m2(12,77877) -#define COMPUTE_m2_n16 COMPUTE_L_m2(12,77777) COMPUTE_R_m2(4,77977) -#define COMPUTE_m2_n20 COMPUTE_L_m2(12,77677) COMPUTE_R_m2(8,77977) -#define COMPUTE_m2_n24 COMPUTE_L_m2(12,77577) COMPUTE_R_m2(12,77977) -#define COMPUTE_m2(ndim) COMPUTE_m2_n##ndim + "movq %%r13,%4; movq %%r14,%1; leaq (%1,%%r12,2),%%r15; addq %%r12,%%r15;"\ + "testq %4,%4; jz "#ndim"002022f;"\ + #ndim"002021:\n\t"\ + KERNEL_k1m2n##ndim "decq %4; jnz "#ndim"002021b;"\ + #ndim"002022:\n\t"\ + SAVE_m2(ndim) -/* m = 1 *//* xmm0 for alpha, xmm1-xmm3 and xmm10 for temporary use, xmm4-xmm6 for accumulators */ +/* m = 1 *//* xmm0 for alpha, xmm1-xmm3 and xmm10 for temporary use, xmm4-xmm9 for accumulators */ #define INIT_m1n1 "vpxor %%xmm4,%%xmm4,%%xmm4;" -#define KERNEL_k1m1n1(b_addr) \ - "vmovss ("#b_addr"),%%xmm3; addq $4,"#b_addr";"\ +#define KERNEL_k1m1n1 \ + "vmovss (%1),%%xmm3; addq $4,%1;"\ "vmovss (%0),%%xmm1; vfmadd231ss %%xmm3,%%xmm1,%%xmm4;"\ "addq $4,%0;" -#define SAVE_L_m1n1 "vfmadd213ss (%2),%%xmm0,%%xmm4; vmovss %%xmm4,(%2);" +#define SAVE_h_m1n1 "vfmadd213ss (%2),%%xmm0,%%xmm4; vmovss %%xmm4,(%2);" #define INIT_m1n2 INIT_m1n1 -#define KERNEL_k1m1n2(b_addr) \ - "vmovsd ("#b_addr"),%%xmm3; addq $8,"#b_addr";"\ +#define KERNEL_k1m1n2 \ + "vmovsd (%1),%%xmm3; addq $8,%1;"\ "vbroadcastss (%0),%%xmm1; vfmadd231ps %%xmm3,%%xmm1,%%xmm4;"\ "addq $4,%0;" -#define SAVE_L_m1n2 \ +#define SAVE_h_m1n2 \ "vmovss (%2),%%xmm3; vinsertps $16,(%2,%3,1),%%xmm3,%%xmm3; vfmadd213ps %%xmm3,%%xmm0,%%xmm4;"\ "vmovss %%xmm4,(%2); vextractps $1,%%xmm4,(%2,%3,1);" #define INIT_m1n4 INIT_m1n2 #define INIT_m1n8 INIT_m1n4 "vpxor %%xmm5,%%xmm5,%%xmm5;" #define INIT_m1n12 INIT_m1n8 "vpxor %%xmm6,%%xmm6,%%xmm6;" -#define KERNEL_k1m1n4(b_addr) \ - "vmovups ("#b_addr"),%%xmm3; addq $16,"#b_addr";"\ - "vbroadcastss (%0),%%xmm1; vfmadd231ps %%xmm3,%%xmm1,%%xmm4;"\ - "addq $4,%0;" -#define KERNEL_k1m1n8(b_addr) \ - "vmovups ("#b_addr"),%%xmm3; vmovups ("#b_addr",%%r12,1),%%xmm2; addq $16,"#b_addr";"\ - "vbroadcastss (%0),%%xmm1; vfmadd231ps %%xmm3,%%xmm1,%%xmm4; vfmadd231ps %%xmm2,%%xmm1,%%xmm5;"\ - "addq $4,%0;" -#define KERNEL_k1m1n12(b_addr) \ - "vmovups ("#b_addr"),%%xmm3; vmovups ("#b_addr",%%r12,1),%%xmm2; vmovups ("#b_addr",%%r12,2),%%xmm1; addq $16,"#b_addr";"\ - "vbroadcastss (%0),%%xmm10; vfmadd231ps %%xmm3,%%xmm10,%%xmm4; vfmadd231ps %%xmm2,%%xmm10,%%xmm5; vfmadd231ps %%xmm1,%%xmm10,%%xmm6;"\ - "addq $4,%0;" +#define INIT_m1n16 INIT_m1n12 "vpxor %%xmm7,%%xmm7,%%xmm7;" +#define INIT_m1n20 INIT_m1n16 "vpxor %%xmm8,%%xmm8,%%xmm8;" +#define INIT_m1n24 INIT_m1n20 "vpxor %%xmm9,%%xmm9,%%xmm9;" +#define KERNEL_h_k1m1n4 \ + "vbroadcastss (%0),%%xmm1; addq $4,%0; vfmadd231ps (%1),%%xmm1,%%xmm4;" +#define KERNEL_k1m1n4 KERNEL_h_k1m1n4 "addq $16,%1;" +#define KERNEL_h_k1m1n8 KERNEL_h_k1m1n4 "vfmadd231ps (%1,%%r12,1),%%xmm1,%%xmm5;" +#define KERNEL_k1m1n8 KERNEL_h_k1m1n8 "addq $16,%1;" +#define KERNEL_k1m1n12 KERNEL_h_k1m1n8 "vfmadd231ps (%1,%%r12,2),%%xmm1,%%xmm6; addq $16,%1;" +#define KERNEL_h_k1m1n16 KERNEL_k1m1n12 "vfmadd231ps (%%r15),%%xmm1,%%xmm7;" +#define KERNEL_k1m1n16 KERNEL_h_k1m1n16 "addq $16,%%r15;" +#define KERNEL_h_k1m1n20 KERNEL_h_k1m1n16 "vfmadd231ps (%%r15,%%r12,1),%%xmm1,%%xmm8;" +#define KERNEL_k1m1n20 KERNEL_h_k1m1n20 "addq $16,%%r15;" +#define KERNEL_h_k1m1n24 KERNEL_h_k1m1n20 "vfmadd231ps (%%r15,%%r12,2),%%xmm1,%%xmm9;" +#define KERNEL_k1m1n24 KERNEL_h_k1m1n24 "addq $16,%%r15;" #define unit_save_m1n4(c1) \ "vpxor %%xmm10,%%xmm10,%%xmm10; vmovsd "#c1",%%xmm10,%%xmm2; vmovhlps "#c1",%%xmm10,%%xmm1;"\ "vmovss (%5),%%xmm3; vinsertps $16,(%5,%3,1),%%xmm3,%%xmm3; vfmadd213ps %%xmm3,%%xmm0,%%xmm2;"\ "vmovss %%xmm2,(%5); vextractps $1,%%xmm2,(%5,%3,1); leaq (%5,%3,2),%5;"\ "vmovss (%5),%%xmm3; vinsertps $16,(%5,%3,1),%%xmm3,%%xmm3; vfmadd213ps %%xmm3,%%xmm0,%%xmm1;"\ "vmovss %%xmm1,(%5); vextractps $1,%%xmm1,(%5,%3,1); leaq (%5,%3,2),%5;" -#define SAVE_L_m1n4 "movq %2,%5;" unit_save_m1n4(%%xmm4) -#define SAVE_L_m1n8 SAVE_L_m1n4 unit_save_m1n4(%%xmm5) -#define SAVE_L_m1n12 SAVE_L_m1n8 unit_save_m1n4(%%xmm6) -#define SAVE_R_m1n4 unit_save_m1n4(%%xmm4) -#define SAVE_R_m1n8 SAVE_R_m1n4 unit_save_m1n4(%%xmm5) -#define SAVE_R_m1n12 SAVE_R_m1n8 unit_save_m1n4(%%xmm6) -#define COMPUTE_L_m1(ndim,sim) \ +#define SAVE_h_m1n4 "movq %2,%5;" unit_save_m1n4(%%xmm4) +#define SAVE_h_m1n8 SAVE_h_m1n4 unit_save_m1n4(%%xmm5) +#define SAVE_h_m1n12 SAVE_h_m1n8 unit_save_m1n4(%%xmm6) +#define SAVE_h_m1n16 SAVE_h_m1n12 unit_save_m1n4(%%xmm7) +#define SAVE_h_m1n20 SAVE_h_m1n16 unit_save_m1n4(%%xmm8) +#define SAVE_h_m1n24 SAVE_h_m1n20 unit_save_m1n4(%%xmm9) +#define SAVE_m1(ndim) SAVE_h_m1n##ndim "addq $4,%2;" +#define COMPUTE_m1(ndim) \ INIT_m1n##ndim\ - "movq %%r13,%4; movq %%r14,%1;"\ - #ndim""#sim"112:\n\t"\ - "testq %4,%4; jz "#ndim""#sim"113f;"\ - KERNEL_k1m1n##ndim(%1)\ - "decq %4; jmp "#ndim""#sim"112b;"\ - #ndim""#sim"113:\n\t"\ - SAVE_L_m1n##ndim "addq $4,%2;" -#define COMPUTE_R_m1(ndim,sim) \ - "salq $2,%%r13;subq %%r13,%0;sarq $2,%%r13;"\ - INIT_m1n##ndim\ - "movq %%r13,%4; leaq (%%r14,%%r12,2),%%r15; addq %%r12,%%r15;"\ - #ndim""#sim"112:\n\t"\ - "testq %4,%4; jz "#ndim""#sim"113f;"\ - KERNEL_k1m1n##ndim(%%r15)\ - "decq %4; jmp "#ndim""#sim"112b;"\ - #ndim""#sim"113:\n\t"\ - SAVE_R_m1n##ndim -#define COMPUTE_m1_n1 COMPUTE_L_m1(1,99899) -#define COMPUTE_m1_n2 COMPUTE_L_m1(2,99899) -#define COMPUTE_m1_n4 COMPUTE_L_m1(4,99899) -#define COMPUTE_m1_n8 COMPUTE_L_m1(8,99899) -#define COMPUTE_m1_n12 COMPUTE_L_m1(12,99899) -#define COMPUTE_m1_n16 COMPUTE_L_m1(12,99799) COMPUTE_R_m1(4,99999) -#define COMPUTE_m1_n20 COMPUTE_L_m1(12,99699) COMPUTE_R_m1(8,99999) -#define COMPUTE_m1_n24 COMPUTE_L_m1(12,99599) COMPUTE_R_m1(12,99999) -#define COMPUTE_m1(ndim) COMPUTE_m1_n##ndim + "movq %%r13,%4; movq %%r14,%1; leaq (%1,%%r12,2),%%r15; addq %%r12,%%r15;"\ + "testq %4,%4; jz "#ndim"001012f;"\ + #ndim"001011:\n\t"\ + KERNEL_k1m1n##ndim "decq %4; jnz "#ndim"001011b;"\ + #ndim"001012:\n\t"\ + SAVE_m1(ndim) /* %0 = "+r"(a_pointer), %1 = "+r"(b_pointer), %2 = "+r"(c_pointer), %3 = "+r"(ldc_in_bytes), %4 = "+r"(K), %5 = "+r"(ctemp) */ -/* %6 = "+r"(&alpha), %7 = "+r"(M), %8 = "+r"(next_b) */ -/* r11 = m(const), r12 = k << 4(const), r13 = k(const), r14 = b_head_pos(const), r15 = %1 + 3r12 */ +/* %6 = "+r"(next_b), %7 = "m"(ALPHA), %8 = "m"(M) */ +/* r11 = m_counter, r12 = k << 4(const), r13 = k(const), r14 = b_head_pos(const), r15 = %1 + 3r12 */ #define COMPUTE(ndim) {\ next_b = b_pointer + ndim * K;\ __asm__ __volatile__(\ - "vbroadcastss (%6),%%zmm0;"\ - "movq %4,%%r13; movq %4,%%r12; salq $4,%%r12; movq %1,%%r14; movq %7,%%r11;"\ - "cmpq $16,%7;jb 33101"#ndim"f;"\ + "vbroadcastss %7,%%zmm0;"\ + "movq %4,%%r13; movq %4,%%r12; salq $4,%%r12; movq %1,%%r14; movq %8,%%r11;"\ + "cmpq $16,%%r11;jb 33101"#ndim"f;"\ "33109"#ndim":\n\t"\ COMPUTE_m16(ndim)\ - "subq $16,%7;cmpq $16,%7;jnb 33109"#ndim"b;"\ + "subq $16,%%r11;cmpq $16,%%r11;jnb 33109"#ndim"b;"\ "33101"#ndim":\n\t"\ - "cmpq $8,%7;jb 33102"#ndim"f;"\ + "cmpq $8,%%r11;jb 33102"#ndim"f;"\ COMPUTE_m8(ndim)\ - "subq $8,%7;"\ + "subq $8,%%r11;"\ "33102"#ndim":\n\t"\ - "cmpq $4,%7;jb 33103"#ndim"f;"\ + "cmpq $4,%%r11;jb 33103"#ndim"f;"\ COMPUTE_m4(ndim)\ - "subq $4,%7;"\ + "subq $4,%%r11;"\ "33103"#ndim":\n\t"\ - "cmpq $2,%7;jb 33104"#ndim"f;"\ + "cmpq $2,%%r11;jb 33104"#ndim"f;"\ COMPUTE_m2(ndim)\ - "subq $2,%7;"\ + "subq $2,%%r11;"\ "33104"#ndim":\n\t"\ - "testq %7,%7;jz 33105"#ndim"f;"\ + "testq %%r11,%%r11;jz 33105"#ndim"f;"\ COMPUTE_m1(ndim)\ "33105"#ndim":\n\t"\ - "movq %%r13,%4; movq %%r14,%1; movq %%r11,%7;"\ - :"+r"(a_pointer),"+r"(b_pointer),"+r"(c_pointer),"+r"(ldc_in_bytes),"+r"(K),"+r"(ctemp),"+r"(alp),"+r"(M),"+r"(next_b)\ - ::"r11","r12","r13","r14","r15","zmm0","zmm1","zmm2","zmm3","zmm4","zmm5","zmm6","zmm7","zmm8","zmm9","zmm10","zmm11","zmm12","zmm13","zmm14",\ + "movq %%r13,%4; movq %%r14,%1; vzeroupper;"\ + :"+r"(a_pointer),"+r"(b_pointer),"+r"(c_pointer),"+r"(ldc_in_bytes),"+r"(K),"+r"(ctemp),"+r"(next_b):"m"(ALPHA),"m"(M)\ + :"r10","r11","r12","r13","r14","r15","zmm0","zmm1","zmm2","zmm3","zmm4","zmm5","zmm6","zmm7","zmm8","zmm9","zmm10","zmm11","zmm12","zmm13","zmm14",\ "zmm15","zmm16","zmm17","zmm18","zmm19","zmm20","zmm21","zmm22","zmm23","zmm24","zmm25","zmm26","zmm27","zmm28","zmm29","zmm30","zmm31",\ "cc","memory");\ - a_pointer -= M * K; b_pointer += ndim * K;c_pointer += LDC * ndim - M;\ + a_pointer -= M * K; b_pointer += ndim * K; c_pointer += LDC * ndim - M;\ } int __attribute__ ((noinline)) CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float * __restrict__ A, float * __restrict__ B, float * __restrict__ C, BLASLONG LDC) @@ -399,7 +365,7 @@ CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float * __restrict__ A, f int64_t ldc_in_bytes = (int64_t)LDC * sizeof(float);float ALPHA = alpha; int64_t M = (int64_t)m, K = (int64_t)k; BLASLONG n_count = n; - float *a_pointer = A,*b_pointer = B,*c_pointer = C,*ctemp = C,*alp = &ALPHA,*next_b = B; + float *a_pointer = A,*b_pointer = B,*c_pointer = C,*ctemp = C,*next_b = B; for(;n_count>23;n_count-=24) COMPUTE(24) for(;n_count>19;n_count-=20) COMPUTE(20) for(;n_count>15;n_count-=16) COMPUTE(16) @@ -411,469 +377,4 @@ CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float * __restrict__ A, f return 0; } -#include -/* codes below are copied from the sgemm kernel written by Arjan van der Ven */ - -/* - * "Direct sgemm" code. This code operates directly on the inputs and outputs - * of the sgemm call, avoiding the copies, memory realignments and threading, - * and only supports alpha = 1 and beta = 0. - * This is a common case and provides value for relatively small matrixes. - * For larger matrixes the "regular" sgemm code is superior, there the cost of - * copying/shuffling the B matrix really pays off. - */ - - - -#define DECLARE_RESULT_512(N,M) __m512 result##N##M = _mm512_setzero_ps() -#define BROADCAST_LOAD_A_512(N,M) __m512 Aval##M = _mm512_broadcastss_ps(_mm_load_ss(&A[k + strideA * (i+M)])) -#define LOAD_B_512(N,M) __m512 Bval##N = _mm512_loadu_ps(&B[strideB * k + j + (N*16)]) -#define MATMUL_512(N,M) result##N##M = _mm512_fmadd_ps(Aval##M, Bval##N , result##N##M) -#define STORE_512(N,M) _mm512_storeu_ps(&R[(i+M) * strideR + j+(N*16)], result##N##M) - - -#define DECLARE_RESULT_256(N,M) __m256 result##N##M = _mm256_setzero_ps() -#define BROADCAST_LOAD_A_256(N,M) __m256 Aval##M = _mm256_broadcastss_ps(_mm_load_ss(&A[k + strideA * (i+M)])) -#define LOAD_B_256(N,M) __m256 Bval##N = _mm256_loadu_ps(&B[strideB * k + j + (N*8)]) -#define MATMUL_256(N,M) result##N##M = _mm256_fmadd_ps(Aval##M, Bval##N , result##N##M) -#define STORE_256(N,M) _mm256_storeu_ps(&R[(i+M) * strideR + j+(N*8)], result##N##M) - -#define DECLARE_RESULT_128(N,M) __m128 result##N##M = _mm_setzero_ps() -#define BROADCAST_LOAD_A_128(N,M) __m128 Aval##M = _mm_broadcastss_ps(_mm_load_ss(&A[k + strideA * (i+M)])) -#define LOAD_B_128(N,M) __m128 Bval##N = _mm_loadu_ps(&B[strideB * k + j + (N*4)]) -#define MATMUL_128(N,M) result##N##M = _mm_fmadd_ps(Aval##M, Bval##N , result##N##M) -#define STORE_128(N,M) _mm_storeu_ps(&R[(i+M) * strideR + j+(N*4)], result##N##M) - -#define DECLARE_RESULT_SCALAR(N,M) float result##N##M = 0; -#define BROADCAST_LOAD_A_SCALAR(N,M) float Aval##M = A[k + strideA * (i + M)]; -#define LOAD_B_SCALAR(N,M) float Bval##N = B[k * strideB + j + N]; -#define MATMUL_SCALAR(N,M) result##N##M += Aval##M * Bval##N; -#define STORE_SCALAR(N,M) R[(i+M) * strideR + j + N] = result##N##M; - -int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K) -{ - unsigned long long mnk = M * N * K; - /* large matrixes -> not performant */ - if (mnk >= 28 * 512 * 512) - return 0; - - /* - * if the B matrix is not a nice multiple if 4 we get many unaligned accesses, - * and the regular sgemm copy/realignment of data pays off much quicker - */ - if ((N & 3) != 0 && (mnk >= 8 * 512 * 512)) - return 0; - -#ifdef SMP - /* if we can run multithreaded, the threading changes the based threshold */ - if (mnk > 2 * 350 * 512 && num_cpu_avail(3)> 1) - return 0; -#endif - - return 1; -} - - - -void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR) -{ - int i, j, k; - - int m4 = M & ~3; - int m2 = M & ~1; - - int n64 = N & ~63; - int n32 = N & ~31; - int n16 = N & ~15; - int n8 = N & ~7; - int n4 = N & ~3; - int n2 = N & ~1; - - i = 0; - - for (i = 0; i < m4; i+=4) { - - for (j = 0; j < n64; j+= 64) { - k = 0; - DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); - DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); - DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); - DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); - - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - BROADCAST_LOAD_A_512(x, 1); - BROADCAST_LOAD_A_512(x, 2); - BROADCAST_LOAD_A_512(x, 3); - - LOAD_B_512(0, x); LOAD_B_512(1, x); LOAD_B_512(2, x); LOAD_B_512(3, x); - - MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); - MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); - MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); - MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); - } - STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); - STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); - STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); - STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); - } - - for (; j < n32; j+= 32) { - DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); - DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); - DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); - DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - BROADCAST_LOAD_A_512(x, 1); - BROADCAST_LOAD_A_512(x, 2); - BROADCAST_LOAD_A_512(x, 3); - - LOAD_B_512(0, x); LOAD_B_512(1, x); - - MATMUL_512(0, 0); MATMUL_512(1, 0); - MATMUL_512(0, 1); MATMUL_512(1, 1); - MATMUL_512(0, 2); MATMUL_512(1, 2); - MATMUL_512(0, 3); MATMUL_512(1, 3); - } - STORE_512(0, 0); STORE_512(1, 0); - STORE_512(0, 1); STORE_512(1, 1); - STORE_512(0, 2); STORE_512(1, 2); - STORE_512(0, 3); STORE_512(1, 3); - } - - for (; j < n16; j+= 16) { - DECLARE_RESULT_512(0, 0); - DECLARE_RESULT_512(0, 1); - DECLARE_RESULT_512(0, 2); - DECLARE_RESULT_512(0, 3); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - BROADCAST_LOAD_A_512(x, 1); - BROADCAST_LOAD_A_512(x, 2); - BROADCAST_LOAD_A_512(x, 3); - - LOAD_B_512(0, x); - - MATMUL_512(0, 0); - MATMUL_512(0, 1); - MATMUL_512(0, 2); - MATMUL_512(0, 3); - } - STORE_512(0, 0); - STORE_512(0, 1); - STORE_512(0, 2); - STORE_512(0, 3); - } - - for (; j < n8; j+= 8) { - DECLARE_RESULT_256(0, 0); - DECLARE_RESULT_256(0, 1); - DECLARE_RESULT_256(0, 2); - DECLARE_RESULT_256(0, 3); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_256(x, 0); - BROADCAST_LOAD_A_256(x, 1); - BROADCAST_LOAD_A_256(x, 2); - BROADCAST_LOAD_A_256(x, 3); - - LOAD_B_256(0, x); - - MATMUL_256(0, 0); - MATMUL_256(0, 1); - MATMUL_256(0, 2); - MATMUL_256(0, 3); - } - STORE_256(0, 0); - STORE_256(0, 1); - STORE_256(0, 2); - STORE_256(0, 3); - } - - for (; j < n4; j+= 4) { - DECLARE_RESULT_128(0, 0); - DECLARE_RESULT_128(0, 1); - DECLARE_RESULT_128(0, 2); - DECLARE_RESULT_128(0, 3); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_128(x, 0); - BROADCAST_LOAD_A_128(x, 1); - BROADCAST_LOAD_A_128(x, 2); - BROADCAST_LOAD_A_128(x, 3); - - LOAD_B_128(0, x); - - MATMUL_128(0, 0); - MATMUL_128(0, 1); - MATMUL_128(0, 2); - MATMUL_128(0, 3); - } - STORE_128(0, 0); - STORE_128(0, 1); - STORE_128(0, 2); - STORE_128(0, 3); - } - - for (; j < n2; j+= 2) { - DECLARE_RESULT_SCALAR(0, 0); DECLARE_RESULT_SCALAR(1, 0); - DECLARE_RESULT_SCALAR(0, 1); DECLARE_RESULT_SCALAR(1, 1); - DECLARE_RESULT_SCALAR(0, 2); DECLARE_RESULT_SCALAR(1, 2); - DECLARE_RESULT_SCALAR(0, 3); DECLARE_RESULT_SCALAR(1, 3); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_SCALAR(x, 0); - BROADCAST_LOAD_A_SCALAR(x, 1); - BROADCAST_LOAD_A_SCALAR(x, 2); - BROADCAST_LOAD_A_SCALAR(x, 3); - - LOAD_B_SCALAR(0, x); LOAD_B_SCALAR(1, x); - - MATMUL_SCALAR(0, 0); MATMUL_SCALAR(1, 0); - MATMUL_SCALAR(0, 1); MATMUL_SCALAR(1, 1); - MATMUL_SCALAR(0, 2); MATMUL_SCALAR(1, 2); - MATMUL_SCALAR(0, 3); MATMUL_SCALAR(1, 3); - } - STORE_SCALAR(0, 0); STORE_SCALAR(1, 0); - STORE_SCALAR(0, 1); STORE_SCALAR(1, 1); - STORE_SCALAR(0, 2); STORE_SCALAR(1, 2); - STORE_SCALAR(0, 3); STORE_SCALAR(1, 3); - } - - for (; j < N; j++) { - DECLARE_RESULT_SCALAR(0, 0) - DECLARE_RESULT_SCALAR(0, 1) - DECLARE_RESULT_SCALAR(0, 2) - DECLARE_RESULT_SCALAR(0, 3) - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_SCALAR(0, 0); - BROADCAST_LOAD_A_SCALAR(0, 1); - BROADCAST_LOAD_A_SCALAR(0, 2); - BROADCAST_LOAD_A_SCALAR(0, 3); - - LOAD_B_SCALAR(0, 0); - - MATMUL_SCALAR(0, 0); - MATMUL_SCALAR(0, 1); - MATMUL_SCALAR(0, 2); - MATMUL_SCALAR(0, 3); - } - STORE_SCALAR(0, 0); - STORE_SCALAR(0, 1); - STORE_SCALAR(0, 2); - STORE_SCALAR(0, 3); - } - } - - for (; i < m2; i+=2) { - j = 0; - - for (; j < n64; j+= 64) { - DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); - DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); - - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - BROADCAST_LOAD_A_512(x, 1); - - LOAD_B_512(0, x); LOAD_B_512(1, x); LOAD_B_512(2, x); LOAD_B_512(3, x); - - MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); - MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); - } - STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); - STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); - } - - for (; j < n32; j+= 32) { - DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); - DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - BROADCAST_LOAD_A_512(x, 1); - - LOAD_B_512(0, x); LOAD_B_512(1, x); - - MATMUL_512(0, 0); MATMUL_512(1, 0); - MATMUL_512(0, 1); MATMUL_512(1, 1); - } - STORE_512(0, 0); STORE_512(1, 0); - STORE_512(0, 1); STORE_512(1, 1); - } - - - for (; j < n16; j+= 16) { - DECLARE_RESULT_512(0, 0); - DECLARE_RESULT_512(0, 1); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - BROADCAST_LOAD_A_512(x, 1); - - LOAD_B_512(0, x); - - MATMUL_512(0, 0); - MATMUL_512(0, 1); - } - STORE_512(0, 0); - STORE_512(0, 1); - } - - for (; j < n8; j+= 8) { - DECLARE_RESULT_256(0, 0); - DECLARE_RESULT_256(0, 1); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_256(x, 0); - BROADCAST_LOAD_A_256(x, 1); - - LOAD_B_256(0, x); - - MATMUL_256(0, 0); - MATMUL_256(0, 1); - } - STORE_256(0, 0); - STORE_256(0, 1); - } - - for (; j < n4; j+= 4) { - DECLARE_RESULT_128(0, 0); - DECLARE_RESULT_128(0, 1); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_128(x, 0); - BROADCAST_LOAD_A_128(x, 1); - - LOAD_B_128(0, x); - - MATMUL_128(0, 0); - MATMUL_128(0, 1); - } - STORE_128(0, 0); - STORE_128(0, 1); - } - for (; j < n2; j+= 2) { - DECLARE_RESULT_SCALAR(0, 0); DECLARE_RESULT_SCALAR(1, 0); - DECLARE_RESULT_SCALAR(0, 1); DECLARE_RESULT_SCALAR(1, 1); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_SCALAR(x, 0); - BROADCAST_LOAD_A_SCALAR(x, 1); - - LOAD_B_SCALAR(0, x); LOAD_B_SCALAR(1, x); - - MATMUL_SCALAR(0, 0); MATMUL_SCALAR(1, 0); - MATMUL_SCALAR(0, 1); MATMUL_SCALAR(1, 1); - } - STORE_SCALAR(0, 0); STORE_SCALAR(1, 0); - STORE_SCALAR(0, 1); STORE_SCALAR(1, 1); - } - - for (; j < N; j++) { - DECLARE_RESULT_SCALAR(0, 0); - DECLARE_RESULT_SCALAR(0, 1); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_SCALAR(0, 0); - BROADCAST_LOAD_A_SCALAR(0, 1); - - LOAD_B_SCALAR(0, 0); - - MATMUL_SCALAR(0, 0); - MATMUL_SCALAR(0, 1); - } - STORE_SCALAR(0, 0); - STORE_SCALAR(0, 1); - } - } - - for (; i < M; i+=1) { - j = 0; - for (; j < n64; j+= 64) { - DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - LOAD_B_512(0, x); LOAD_B_512(1, x); LOAD_B_512(2, x); LOAD_B_512(3, x); - MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); - } - STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); - } - for (; j < n32; j+= 32) { - DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - LOAD_B_512(0, x); LOAD_B_512(1, x); - MATMUL_512(0, 0); MATMUL_512(1, 0); - } - STORE_512(0, 0); STORE_512(1, 0); - } - - - for (; j < n16; j+= 16) { - DECLARE_RESULT_512(0, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_512(x, 0); - - LOAD_B_512(0, x); - - MATMUL_512(0, 0); - } - STORE_512(0, 0); - } - - for (; j < n8; j+= 8) { - DECLARE_RESULT_256(0, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_256(x, 0); - LOAD_B_256(0, x); - MATMUL_256(0, 0); - } - STORE_256(0, 0); - } - - for (; j < n4; j+= 4) { - DECLARE_RESULT_128(0, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_128(x, 0); - LOAD_B_128(0, x); - MATMUL_128(0, 0); - } - STORE_128(0, 0); - } - - for (; j < n2; j+= 2) { - DECLARE_RESULT_SCALAR(0, 0); DECLARE_RESULT_SCALAR(1, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_SCALAR(x, 0); - LOAD_B_SCALAR(0, 0); LOAD_B_SCALAR(1, 0); - MATMUL_SCALAR(0, 0); MATMUL_SCALAR(1, 0); - } - STORE_SCALAR(0, 0); STORE_SCALAR(1, 0); - } - - for (; j < N; j++) { - DECLARE_RESULT_SCALAR(0, 0); - - for (k = 0; k < K; k++) { - BROADCAST_LOAD_A_SCALAR(0, 0); - LOAD_B_SCALAR(0, 0); - MATMUL_SCALAR(0, 0); - } - STORE_SCALAR(0, 0); - } - } -} +#include "sgemm_direct_skylakex.c" From 952cc2ba3860419defed3c27af1c3becca9e40e9 Mon Sep 17 00:00:00 2001 From: wjc404 <52632443+wjc404@users.noreply.github.com> Date: Mon, 13 Jan 2020 16:58:54 +0800 Subject: [PATCH 05/12] Update sgemm_kernel_16x4_skylakex_2.c --- kernel/x86_64/sgemm_kernel_16x4_skylakex_2.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernel/x86_64/sgemm_kernel_16x4_skylakex_2.c b/kernel/x86_64/sgemm_kernel_16x4_skylakex_2.c index e4ca6b1bd..6ca822b91 100644 --- a/kernel/x86_64/sgemm_kernel_16x4_skylakex_2.c +++ b/kernel/x86_64/sgemm_kernel_16x4_skylakex_2.c @@ -376,5 +376,5 @@ CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float * __restrict__ A, f if(n_count>0) COMPUTE(1) return 0; } - +#include #include "sgemm_direct_skylakex.c" From e5dcdeb5506a8e0ab26e0956c5b8e7fed7e80e9a Mon Sep 17 00:00:00 2001 From: wjc404 <52632443+wjc404@users.noreply.github.com> Date: Mon, 13 Jan 2020 16:59:23 +0800 Subject: [PATCH 06/12] Update sgemm_direct_skylakex.c --- kernel/x86_64/sgemm_direct_skylakex.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernel/x86_64/sgemm_direct_skylakex.c b/kernel/x86_64/sgemm_direct_skylakex.c index 4f9af6e57..0e8f1318f 100644 --- a/kernel/x86_64/sgemm_direct_skylakex.c +++ b/kernel/x86_64/sgemm_direct_skylakex.c @@ -1,6 +1,6 @@ /* the direct sgemm code written by Arjan van der Ven */ -#include +//#include /* * "Direct sgemm" code. This code operates directly on the inputs and outputs From 78100b80935753a7a86c6a5380e2a53bc9469b7f Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sat, 18 Jan 2020 15:06:39 +0100 Subject: [PATCH 07/12] Free Windows thread memory with MEM_RELEASE rather than MEM_DECOMMIT as suggested by hjmndv in #2370 --- driver/others/memory.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/driver/others/memory.c b/driver/others/memory.c index 55dce72b8..62a5a0214 100644 --- a/driver/others/memory.c +++ b/driver/others/memory.c @@ -822,7 +822,7 @@ static void *alloc_qalloc(void *address){ static void alloc_windows_free(struct alloc_t *alloc_info){ - VirtualFree(alloc_info, allocation_block_size, MEM_DECOMMIT); + VirtualFree(alloc_info, 0, MEM_RELEASE); } @@ -935,7 +935,7 @@ static void alloc_hugetlb_free(struct alloc_t *alloc_info){ #ifdef OS_WINDOWS - VirtualFree(alloc_info, allocation_block_size, MEM_LARGE_PAGES | MEM_DECOMMIT); + VirtualFree(alloc_info, 0, MEM_LARGE_PAGES | MEM_RELEASE); #endif @@ -2310,7 +2310,7 @@ static void *alloc_qalloc(void *address){ static void alloc_windows_free(struct release_t *release){ - VirtualFree(release -> address, BUFFER_SIZE, MEM_DECOMMIT); + VirtualFree(release -> address, 0, MEM_RELEASE); } @@ -2432,7 +2432,7 @@ static void alloc_hugetlb_free(struct release_t *release){ #ifdef OS_WINDOWS - VirtualFree(release -> address, BUFFER_SIZE, MEM_LARGE_PAGES | MEM_DECOMMIT); + VirtualFree(release -> address, 0, MEM_LARGE_PAGES | MEM_RELEASE); #endif From 23f322f997c8b018977be24122c56fb62d728a05 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sun, 19 Jan 2020 13:28:27 +0100 Subject: [PATCH 08/12] Do not run any cleanup if the program is exiting anyway From keno's PR #2350 - this avoids the potential hang in blas_thread_shutdown where we may wait for threads to exit while they are waiting on the loader lock from DllMain --- exports/dllinit.c | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/exports/dllinit.c b/exports/dllinit.c index 4a05c0e14..88f9af658 100644 --- a/exports/dllinit.c +++ b/exports/dllinit.c @@ -50,7 +50,10 @@ BOOL APIENTRY DllMain(HINSTANCE hInst, DWORD reason, LPVOID reserved) { gotoblas_init(); break; case DLL_PROCESS_DETACH: - gotoblas_quit(); + // If the process is about to exit, don't bother releasing any resources + // The kernel is much better at bulk releasing then. + if (!reserved) + gotoblas_quit(); break; case DLL_THREAD_ATTACH: break; From ff42e68652fbba58936c9c66d0b060c3a6d694e7 Mon Sep 17 00:00:00 2001 From: Qiyu8 Date: Mon, 20 Jan 2020 11:49:42 +0800 Subject: [PATCH 09/12] Optimize genenal Gemm Beta --- kernel/generic/gemm_beta.c | 132 ++++++++++++------------------------- 1 file changed, 42 insertions(+), 90 deletions(-) diff --git a/kernel/generic/gemm_beta.c b/kernel/generic/gemm_beta.c index c4e4f7abe..fa9d7680d 100644 --- a/kernel/generic/gemm_beta.c +++ b/kernel/generic/gemm_beta.c @@ -42,101 +42,53 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, FLOAT *dummy2, BLASLONG dummy3, FLOAT *dummy4, BLASLONG dummy5, FLOAT *c, BLASLONG ldc){ + BLASLONG i, j; + BLASLONG chunk, remain; FLOAT *c_offset1, *c_offset; - FLOAT ctemp1, ctemp2, ctemp3, ctemp4; - FLOAT ctemp5, ctemp6, ctemp7, ctemp8; - c_offset = c; - + chunk = m >> 3; + remain = m & 7; if (beta == ZERO){ - - j = n; - do { - c_offset1 = c_offset; - c_offset += ldc; - - i = (m >> 3); - if (i > 0){ - do { - *(c_offset1 + 0) = ZERO; - *(c_offset1 + 1) = ZERO; - *(c_offset1 + 2) = ZERO; - *(c_offset1 + 3) = ZERO; - *(c_offset1 + 4) = ZERO; - *(c_offset1 + 5) = ZERO; - *(c_offset1 + 6) = ZERO; - *(c_offset1 + 7) = ZERO; - c_offset1 += 8; - i --; - } while (i > 0); - } - - i = (m & 7); - if (i > 0){ - do { - *c_offset1 = ZERO; - c_offset1 ++; - i --; - } while (i > 0); - } - j --; - } while (j > 0); - + for(j=n; j>0; j--){ + c_offset1 = c_offset; + c_offset += ldc; + for(i=chunk; i>0; i--){ + *(c_offset1 + 0) = ZERO; + *(c_offset1 + 1) = ZERO; + *(c_offset1 + 2) = ZERO; + *(c_offset1 + 3) = ZERO; + *(c_offset1 + 4) = ZERO; + *(c_offset1 + 5) = ZERO; + *(c_offset1 + 6) = ZERO; + *(c_offset1 + 7) = ZERO; + c_offset1 += 8; + } + for(i=remain; i>0; i--){ + *c_offset1 = ZERO; + c_offset1 ++; + } + } } else { - - j = n; - do { - c_offset1 = c_offset; - c_offset += ldc; - - i = (m >> 3); - if (i > 0){ - do { - ctemp1 = *(c_offset1 + 0); - ctemp2 = *(c_offset1 + 1); - ctemp3 = *(c_offset1 + 2); - ctemp4 = *(c_offset1 + 3); - ctemp5 = *(c_offset1 + 4); - ctemp6 = *(c_offset1 + 5); - ctemp7 = *(c_offset1 + 6); - ctemp8 = *(c_offset1 + 7); - - ctemp1 *= beta; - ctemp2 *= beta; - ctemp3 *= beta; - ctemp4 *= beta; - ctemp5 *= beta; - ctemp6 *= beta; - ctemp7 *= beta; - ctemp8 *= beta; - - *(c_offset1 + 0) = ctemp1; - *(c_offset1 + 1) = ctemp2; - *(c_offset1 + 2) = ctemp3; - *(c_offset1 + 3) = ctemp4; - *(c_offset1 + 4) = ctemp5; - *(c_offset1 + 5) = ctemp6; - *(c_offset1 + 6) = ctemp7; - *(c_offset1 + 7) = ctemp8; - c_offset1 += 8; - i --; - } while (i > 0); - } - - i = (m & 7); - if (i > 0){ - do { - ctemp1 = *c_offset1; - ctemp1 *= beta; - *c_offset1 = ctemp1; - c_offset1 ++; - i --; - } while (i > 0); - } - j --; - } while (j > 0); - + for(j=n; j>0; j--){ + c_offset1 = c_offset; + c_offset += ldc; + for(i=chunk; i>0; i--){ + *(c_offset1 + 0) *= beta; + *(c_offset1 + 1) *= beta; + *(c_offset1 + 2) *= beta; + *(c_offset1 + 3) *= beta; + *(c_offset1 + 4) *= beta; + *(c_offset1 + 5) *= beta; + *(c_offset1 + 6) *= beta; + *(c_offset1 + 7) *= beta; + c_offset1 += 8; + } + for(i=remain; i>0; i--){ + *c_offset1 *= beta; + c_offset1 ++; + } + } } return 0; }; From fbf4f48f4a3d324dd268aaad51624022ee4f0ea2 Mon Sep 17 00:00:00 2001 From: "Wang,Long" Date: Wed, 22 Jan 2020 15:07:50 +0000 Subject: [PATCH 10/12] fix a few performance drop in some matrix size per data type Signed-off-by: Wang,Long --- param.h | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/param.h b/param.h index 3baae31cf..075c12ca2 100644 --- a/param.h +++ b/param.h @@ -1507,8 +1507,13 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define SYMV_P 8 -#define SWITCH_RATIO 32 -#define GEMM_PREFERED_SIZE 16 +#if defined(XDOUBLE) || defined(DOUBLE) +#define SWITCH_RATIO 4 +#define GEMM_PREFERED_SIZE 4 +#else +#define SWITCH_RATIO 8 +#define GEMM_PREFERED_SIZE 8 +#endif #ifdef ARCH_X86 @@ -1627,8 +1632,13 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define SYMV_P 8 -#define SWITCH_RATIO 32 -#define GEMM_PREFERED_SIZE 32 +#if defined(XDOUBLE) || defined(DOUBLE) +#define SWITCH_RATIO 8 +#define GEMM_PREFERED_SIZE 8 +#else +#define SWITCH_RATIO 16 +#define GEMM_PREFERED_SIZE 16 +#endif #define USE_SGEMM_KERNEL_DIRECT 1 #ifdef ARCH_X86 From e9fb8f62b1822c456ccc0b9db23f49aa66dd6801 Mon Sep 17 00:00:00 2001 From: wjc404 <52632443+wjc404@users.noreply.github.com> Date: Wed, 22 Jan 2020 17:40:03 +0000 Subject: [PATCH 11/12] Update level3_gemm3m_thread.c --- driver/level3/level3_gemm3m_thread.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/driver/level3/level3_gemm3m_thread.c b/driver/level3/level3_gemm3m_thread.c index 21d431b60..9216daaed 100644 --- a/driver/level3/level3_gemm3m_thread.c +++ b/driver/level3/level3_gemm3m_thread.c @@ -104,7 +104,7 @@ typedef struct { #define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \ GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \ BETA[0], BETA[1], NULL, 0, NULL, 0, \ - (FLOAT *)(C) + (M_FROM) + (N_FROM) * (LDC) * COMPSIZE, LDC) + (FLOAT *)(C) + ((M_FROM) + (N_FROM) * (LDC)) * COMPSIZE, LDC) #endif #ifndef ICOPYB_OPERATION @@ -414,7 +414,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, for(jjs = xxx; jjs < MIN(n_to, xxx + div_n); jjs += min_jj){ min_jj = MIN(n_to, xxx + div_n) - jjs; - if (min_jj > GEMM3M_UNROLL_N) min_jj = GEMM3M_UNROLL_N; + if (min_jj > GEMM3M_UNROLL_N*3) min_jj = GEMM3M_UNROLL_N*3; START_RPCC(); @@ -550,7 +550,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, for(jjs = xxx; jjs < MIN(n_to, xxx + div_n); jjs += min_jj){ min_jj = MIN(n_to, xxx + div_n) - jjs; - if (min_jj > GEMM3M_UNROLL_N) min_jj = GEMM3M_UNROLL_N; + if (min_jj > GEMM3M_UNROLL_N*3) min_jj = GEMM3M_UNROLL_N*3; START_RPCC(); @@ -687,7 +687,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, for(jjs = xxx; jjs < MIN(n_to, xxx + div_n); jjs += min_jj){ min_jj = MIN(n_to, xxx + div_n) - jjs; - if (min_jj > GEMM3M_UNROLL_N) min_jj = GEMM3M_UNROLL_N; + if (min_jj > GEMM3M_UNROLL_N*3) min_jj = GEMM3M_UNROLL_N*3; START_RPCC(); From 8dc9fd4dfeb894d8b7553c8e5fcc991917335557 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Thu, 30 Jan 2020 12:41:18 +0100 Subject: [PATCH 12/12] Add -march option for AVX512 --- cmake/cc.cmake | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cmake/cc.cmake b/cmake/cc.cmake index 37da0d6ed..22217575c 100644 --- a/cmake/cc.cmake +++ b/cmake/cc.cmake @@ -96,3 +96,10 @@ if (${CMAKE_C_COMPILER_ID} STREQUAL "SUN") endif () endif () +if (${CORE} STREQUAL "SKYLAKEX") + if (NOT DYNAMIC_ARCH) + if (NOT NO_AVX512) + set (CCOMMON_OPT = "${CCOMMON_OPT} -march=skylake-avx512") + endif () + endif () +endif ()