From 0a2077901cf94877f6173f6b580762b68b2fd2e0 Mon Sep 17 00:00:00 2001 From: Xianyi Zhang Date: Tue, 28 Apr 2020 19:01:36 +0800 Subject: [PATCH 01/37] Add small marix optimization kernel interface. make SMALL_MATRIX_OPT=1 --- Makefile.system | 5 ++ common_d.h | 6 ++ common_level3.h | 12 ++++ common_macro.h | 16 +++++ common_s.h | 5 ++ interface/gemm.c | 28 +++++++- kernel/Makefile.L3 | 73 ++++++++++++++++++++ kernel/generic/gemm_small_matrix_kernel_nn.c | 49 +++++++++++++ kernel/generic/gemm_small_matrix_kernel_nt.c | 49 +++++++++++++ kernel/generic/gemm_small_matrix_kernel_tn.c | 49 +++++++++++++ kernel/generic/gemm_small_matrix_kernel_tt.c | 49 +++++++++++++ 11 files changed, 340 insertions(+), 1 deletion(-) create mode 100644 kernel/generic/gemm_small_matrix_kernel_nn.c create mode 100644 kernel/generic/gemm_small_matrix_kernel_nt.c create mode 100644 kernel/generic/gemm_small_matrix_kernel_tn.c create mode 100644 kernel/generic/gemm_small_matrix_kernel_tt.c diff --git a/Makefile.system b/Makefile.system index 13c946ba1..20d8d2f2a 100644 --- a/Makefile.system +++ b/Makefile.system @@ -244,6 +244,11 @@ else ONLY_CBLAS = 0 endif +#For small matrix optimization +ifeq ($(SMALL_MATRIX_OPT), 1) +CCOMMON_OPT += -DSMALL_MATRIX_OPT +endif + # This operation is expensive, so execution should be once. ifndef GOTOBLAS_MAKEFILE export GOTOBLAS_MAKEFILE = 1 diff --git a/common_d.h b/common_d.h index 94dc3eea8..dad304a5f 100644 --- a/common_d.h +++ b/common_d.h @@ -157,6 +157,12 @@ #define DIMATCOPY_K_RT dimatcopy_k_rt #define DGEADD_K dgeadd_k + +#define DGEMM_SMALL_KERNEL_NN dgemm_small_kernel_nn +#define DGEMM_SMALL_KERNEL_NT dgemm_small_kernel_nt +#define DGEMM_SMALL_KERNEL_TN dgemm_small_kernel_tn +#define DGEMM_SMALL_KERNEL_TT dgemm_small_kernel_tt + #else #define DAMAX_K gotoblas -> damax_k diff --git a/common_level3.h b/common_level3.h index c4f9435a9..751592b67 100644 --- a/common_level3.h +++ b/common_level3.h @@ -515,6 +515,18 @@ int qgemm_kernel(BLASLONG, BLASLONG, BLASLONG, xidouble *, xidouble *, xidouble int qgemm_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG); #endif +#ifdef SMALL_MATRIX_OPT +int sgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); +int sgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); +int sgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); +int sgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + +int dgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); +int dgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); +int dgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); +int dgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); +#endif + int cgemm_kernel_n(BLASLONG, 
BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG); int cgemm_kernel_l(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG); int cgemm_kernel_r(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG); diff --git a/common_macro.h b/common_macro.h index 0136f18ab..eb2abcdc0 100644 --- a/common_macro.h +++ b/common_macro.h @@ -644,6 +644,11 @@ #define GEADD_K DGEADD_K +#define GEMM_SMALL_KERNEL_NN DGEMM_SMALL_KERNEL_NN +#define GEMM_SMALL_KERNEL_NT DGEMM_SMALL_KERNEL_NT +#define GEMM_SMALL_KERNEL_TN DGEMM_SMALL_KERNEL_TN +#define GEMM_SMALL_KERNEL_TT DGEMM_SMALL_KERNEL_TT + #elif defined(BFLOAT16) #define D_TO_BF16_K SBDTOBF16_K @@ -931,6 +936,11 @@ #define GEADD_K SGEADD_K +#define GEMM_SMALL_KERNEL_NN SGEMM_SMALL_KERNEL_NN +#define GEMM_SMALL_KERNEL_NT SGEMM_SMALL_KERNEL_NT +#define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN +#define GEMM_SMALL_KERNEL_TT SGEMM_SMALL_KERNEL_TT + #endif #else @@ -1236,6 +1246,12 @@ #define IMATCOPY_K_RT SIMATCOPY_K_RT #define GEADD_K SGEADD_K + +#define GEMM_SMALL_KERNEL_NN SGEMM_SMALL_KERNEL_NN +#define GEMM_SMALL_KERNEL_NT SGEMM_SMALL_KERNEL_NT +#define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN +#define GEMM_SMALL_KERNEL_TT SGEMM_SMALL_KERNEL_TT + #endif #else #ifdef XDOUBLE diff --git a/common_s.h b/common_s.h index 34903ec49..6ad98ba8b 100644 --- a/common_s.h +++ b/common_s.h @@ -164,6 +164,11 @@ #define SGEADD_K sgeadd_k +#define SGEMM_SMALL_KERNEL_NN sgemm_small_kernel_nn +#define SGEMM_SMALL_KERNEL_NT sgemm_small_kernel_nt +#define SGEMM_SMALL_KERNEL_TN sgemm_small_kernel_tn +#define SGEMM_SMALL_KERNEL_TT sgemm_small_kernel_tt + #else #define SAMAX_K gotoblas -> samax_k diff --git a/interface/gemm.c b/interface/gemm.c index 10426fd8f..d2fb42ff7 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -105,6 +105,18 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, B #endif }; +#ifdef SMALL_MATRIX_OPT +//Only support s/dgemm small matrix optimiztion so far. +static int (*gemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG) = { +#ifndef GEMM3M +#ifndef COMPLEX + GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, NULL, NULL, + GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, NULL, NULL, +#endif +#endif +}; +#endif + #ifndef CBLAS void NAME(char *TRANSA, char *TRANSB, @@ -417,6 +429,20 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS FUNCTION_PROFILE_START(); + MNK = (double) args.m * (double) args.n * (double) args.k; + +#ifdef SMALL_MATRIX_OPT +#if !defined(COMPLEX) + //need to tune small matrices cases. 
+ if(MNK <= 100.0*100.0*100.0){ + (gemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, + args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); + return; + } +#endif +#endif + + buffer = (XFLOAT *)blas_memory_alloc(0); sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A); @@ -428,7 +454,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS mode |= (transb << BLAS_TRANSB_SHIFT); #endif - MNK = (double) args.m * (double) args.n * (double) args.k; + if ( MNK <= (SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD) ) args.nthreads = 1; else diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 2d9e3ec36..88e5eb2d6 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -447,6 +447,19 @@ XBLASOBJS += \ endif +###### BLAS small matrix optimization ##### +ifeq ($(SMALL_MATRIX_OPT), 1) + +SBLASOBJS += \ + sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ + sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) + +DBLASOBJS += \ + dgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ + dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) + +endif + ###### BLAS extensions ##### ifeq ($(BUILD_SINGLE),1) @@ -4237,3 +4250,63 @@ endif $(KDIR)zgeadd_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEADD_K) $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -UROWM $< -o $@ + + +###### BLAS small matrix optimization ##### + +ifndef DGEMM_SAMLL_K_NN +DGEMM_SAMLL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c +endif + +ifndef DGEMM_SAMLL_K_NT +DGEMM_SAMLL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c +endif + +ifndef DGEMM_SAMLL_K_TN +DGEMM_SAMLL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c +endif + +ifndef DGEMM_SAMLL_K_TT +DGEMM_SAMLL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c +endif + +$(KDIR)dgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_NN) + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)dgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_NT) + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_TN) + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_TT) + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + + +ifndef SGEMM_SAMLL_K_NN +SGEMM_SAMLL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c +endif + +ifndef SGEMM_SAMLL_K_NT +SGEMM_SAMLL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c +endif + +ifndef SGEMM_SAMLL_K_TN +SGEMM_SAMLL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c +endif + +ifndef SGEMM_SAMLL_K_TT +SGEMM_SAMLL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c +endif + +$(KDIR)sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_NN) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_NT) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_TN) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_TT) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ diff --git a/kernel/generic/gemm_small_matrix_kernel_nn.c b/kernel/generic/gemm_small_matrix_kernel_nn.c new file mode 100644 index 000000000..efcc27cba --- /dev/null 
+++ b/kernel/generic/gemm_small_matrix_kernel_nn.c @@ -0,0 +1,49 @@ +/*************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +{ + //naive implemtation + //Column major + + BLASLONG i,j,k; + FLOAT result=0.0; + + for(i=0; i Date: Tue, 28 Apr 2020 22:35:36 +0800 Subject: [PATCH 02/37] Add alpha=1.0 beta=0.0 for small gemm. 
--- common_d.h | 5 ++ common_level3.h | 11 ++++ common_macro.h | 14 ++++ common_s.h | 5 ++ interface/gemm.c | 18 +++++- kernel/Makefile.L3 | 64 ++++++++++++++++++- .../gemm_small_matrix_kernel_a1b0_nn.c | 49 ++++++++++++++ .../gemm_small_matrix_kernel_a1b0_nt.c | 49 ++++++++++++++ .../gemm_small_matrix_kernel_a1b0_tn.c | 49 ++++++++++++++ .../gemm_small_matrix_kernel_a1b0_tt.c | 49 ++++++++++++++ 10 files changed, 309 insertions(+), 4 deletions(-) create mode 100644 kernel/generic/gemm_small_matrix_kernel_a1b0_nn.c create mode 100644 kernel/generic/gemm_small_matrix_kernel_a1b0_nt.c create mode 100644 kernel/generic/gemm_small_matrix_kernel_a1b0_tn.c create mode 100644 kernel/generic/gemm_small_matrix_kernel_a1b0_tt.c diff --git a/common_d.h b/common_d.h index dad304a5f..f5d7935fa 100644 --- a/common_d.h +++ b/common_d.h @@ -163,6 +163,11 @@ #define DGEMM_SMALL_KERNEL_TN dgemm_small_kernel_tn #define DGEMM_SMALL_KERNEL_TT dgemm_small_kernel_tt +#define DGEMM_SMALL_KERNEL_A1B0_NN dgemm_small_kernel_a1b0_nn +#define DGEMM_SMALL_KERNEL_A1B0_NT dgemm_small_kernel_a1b0_nt +#define DGEMM_SMALL_KERNEL_A1B0_TN dgemm_small_kernel_a1b0_tn +#define DGEMM_SMALL_KERNEL_A1B0_TT dgemm_small_kernel_a1b0_tt + #else #define DAMAX_K gotoblas -> damax_k diff --git a/common_level3.h b/common_level3.h index 751592b67..31d514cd5 100644 --- a/common_level3.h +++ b/common_level3.h @@ -525,6 +525,17 @@ int dgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLO int dgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); int dgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); int dgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); + +int sgemm_small_kernel_a1b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int sgemm_small_kernel_a1b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int sgemm_small_kernel_a1b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int sgemm_small_kernel_a1b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + +int dgemm_small_kernel_a1b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int dgemm_small_kernel_a1b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int dgemm_small_kernel_a1b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int dgemm_small_kernel_a1b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + #endif int cgemm_kernel_n(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG); diff --git a/common_macro.h b/common_macro.h index eb2abcdc0..2f7263023 100644 --- a/common_macro.h +++ b/common_macro.h @@ -648,6 +648,10 @@ #define GEMM_SMALL_KERNEL_NT DGEMM_SMALL_KERNEL_NT #define GEMM_SMALL_KERNEL_TN DGEMM_SMALL_KERNEL_TN #define GEMM_SMALL_KERNEL_TT 
DGEMM_SMALL_KERNEL_TT +#define GEMM_SMALL_KERNEL_A1B0_NN DGEMM_SMALL_KERNEL_A1B0_NN +#define GEMM_SMALL_KERNEL_A1B0_NT DGEMM_SMALL_KERNEL_A1B0_NT +#define GEMM_SMALL_KERNEL_A1B0_TN DGEMM_SMALL_KERNEL_A1B0_TN +#define GEMM_SMALL_KERNEL_A1B0_TT DGEMM_SMALL_KERNEL_A1B0_TT #elif defined(BFLOAT16) @@ -941,6 +945,11 @@ #define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN #define GEMM_SMALL_KERNEL_TT SGEMM_SMALL_KERNEL_TT +#define GEMM_SMALL_KERNEL_A1B0_NN SGEMM_SMALL_KERNEL_A1B0_NN +#define GEMM_SMALL_KERNEL_A1B0_NT SGEMM_SMALL_KERNEL_A1B0_NT +#define GEMM_SMALL_KERNEL_A1B0_TN SGEMM_SMALL_KERNEL_A1B0_TN +#define GEMM_SMALL_KERNEL_A1B0_TT SGEMM_SMALL_KERNEL_A1B0_TT + #endif #else @@ -1252,6 +1261,11 @@ #define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN #define GEMM_SMALL_KERNEL_TT SGEMM_SMALL_KERNEL_TT +#define GEMM_SMALL_KERNEL_A1B0_NN SGEMM_SMALL_KERNEL_A1B0_NN +#define GEMM_SMALL_KERNEL_A1B0_NT SGEMM_SMALL_KERNEL_A1B0_NT +#define GEMM_SMALL_KERNEL_A1B0_TN SGEMM_SMALL_KERNEL_A1B0_TN +#define GEMM_SMALL_KERNEL_A1B0_TT SGEMM_SMALL_KERNEL_A1B0_TT + #endif #else #ifdef XDOUBLE diff --git a/common_s.h b/common_s.h index 6ad98ba8b..440b78723 100644 --- a/common_s.h +++ b/common_s.h @@ -169,6 +169,11 @@ #define SGEMM_SMALL_KERNEL_TN sgemm_small_kernel_tn #define SGEMM_SMALL_KERNEL_TT sgemm_small_kernel_tt +#define SGEMM_SMALL_KERNEL_A1B0_NN sgemm_small_kernel_a1b0_nn +#define SGEMM_SMALL_KERNEL_A1B0_NT sgemm_small_kernel_a1b0_nt +#define SGEMM_SMALL_KERNEL_A1B0_TN sgemm_small_kernel_a1b0_tn +#define SGEMM_SMALL_KERNEL_A1B0_TT sgemm_small_kernel_a1b0_tt + #else #define SAMAX_K gotoblas -> samax_k diff --git a/interface/gemm.c b/interface/gemm.c index d2fb42ff7..da602f7a9 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -115,6 +115,15 @@ static int (*gemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLON #endif #endif }; + +static int (*gemm_small_kernel_a1b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { +#ifndef GEMM3M +#ifndef COMPLEX + GEMM_SMALL_KERNEL_A1B0_NN, GEMM_SMALL_KERNEL_A1B0_TN, NULL, NULL, + GEMM_SMALL_KERNEL_A1B0_NT, GEMM_SMALL_KERNEL_A1B0_TT, NULL, NULL, +#endif +#endif +}; #endif #ifndef CBLAS @@ -435,8 +444,13 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS #if !defined(COMPLEX) //need to tune small matrices cases. 
if(MNK <= 100.0*100.0*100.0){ - (gemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, - args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); + + if(*(FLOAT *)(args.alpha) == 1.0 && *(FLOAT *)(args.beta) == 0.0){ + (gemm_small_kernel_a1b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda,args.b, args.ldb, args.c, args.ldc); + }else{ + (gemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); + } + return; } #endif diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 88e5eb2d6..448d22e4e 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -452,11 +452,15 @@ ifeq ($(SMALL_MATRIX_OPT), 1) SBLASOBJS += \ sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ - sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) + sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ + sgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) \ + sgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) DBLASOBJS += \ dgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ - dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) + dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ + dgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) \ + dgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) endif @@ -4282,6 +4286,34 @@ $(KDIR)dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_ $(KDIR)dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_TT) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ +ifndef DGEMM_SAMLL_K_A1B0_NN +DGEMM_SAMLL_K_A1B0_NN = ../generic/gemm_small_matrix_kernel_a1b0_nn.c +endif + +ifndef DGEMM_SAMLL_K_A1B0_NT +DGEMM_SAMLL_K_A1B0_NT = ../generic/gemm_small_matrix_kernel_a1b0_nt.c +endif + +ifndef DGEMM_SAMLL_K_A1B0_TN +DGEMM_SAMLL_K_A1B0_TN = ../generic/gemm_small_matrix_kernel_a1b0_tn.c +endif + +ifndef DGEMM_SAMLL_K_A1B0_TT +DGEMM_SAMLL_K_A1B0_TT = ../generic/gemm_small_matrix_kernel_a1b0_tt.c +endif + +$(KDIR)dgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_A1B0_NN) + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)dgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_A1B0_NT) + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)dgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_A1B0_TN) + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)dgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_A1B0_TT) + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + ifndef SGEMM_SAMLL_K_NN SGEMM_SAMLL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c @@ -4310,3 +4342,31 @@ $(KDIR)sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_ $(KDIR)sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_TT) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + +ifndef SGEMM_SAMLL_K_A1B0_NN +SGEMM_SAMLL_K_A1B0_NN = ../generic/gemm_small_matrix_kernel_a1b0_nn.c +endif + +ifndef SGEMM_SAMLL_K_A1B0_NT +SGEMM_SAMLL_K_A1B0_NT = ../generic/gemm_small_matrix_kernel_a1b0_nt.c +endif + +ifndef 
SGEMM_SAMLL_K_A1B0_TN +SGEMM_SAMLL_K_A1B0_TN = ../generic/gemm_small_matrix_kernel_a1b0_tn.c +endif + +ifndef SGEMM_SAMLL_K_A1B0_TT +SGEMM_SAMLL_K_A1B0_TT = ../generic/gemm_small_matrix_kernel_a1b0_tt.c +endif + +$(KDIR)sgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_A1B0_NN) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)sgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_A1B0_NT) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)sgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_A1B0_TN) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)sgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_A1B0_TT) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ diff --git a/kernel/generic/gemm_small_matrix_kernel_a1b0_nn.c b/kernel/generic/gemm_small_matrix_kernel_a1b0_nn.c new file mode 100644 index 000000000..8e3417027 --- /dev/null +++ b/kernel/generic/gemm_small_matrix_kernel_a1b0_nn.c @@ -0,0 +1,49 @@ +/*************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT * B, BLASLONG ldb,FLOAT * C, BLASLONG ldc) +{ + //naive implemtation + //Column major + + BLASLONG i,j,k; + FLOAT result=0.0; + + for(i=0; i Date: Tue, 28 Apr 2020 23:15:20 +0800 Subject: [PATCH 03/37] Fix gemm interface bug for small matrix. 
--- interface/gemm.c | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/interface/gemm.c b/interface/gemm.c index da602f7a9..4f1bbfd1c 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -145,7 +145,7 @@ void NAME(char *TRANSA, char *TRANSB, IFLOAT *buffer; IFLOAT *sa, *sb; -#ifdef SMP +#if defined (SMP) || defined(SMALL_MATRIX_OPT) double MNK; #if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY) #ifndef COMPLEX @@ -269,8 +269,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS XFLOAT *buffer; XFLOAT *sa, *sb; -#ifdef SMP +#if defined (SMP) || defined(SMALL_MATRIX_OPT) double MNK; +#endif + +#ifdef SMP #if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY) #ifndef COMPLEX #ifdef XDOUBLE @@ -438,7 +441,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS FUNCTION_PROFILE_START(); +#if defined(SMP) || defined(SMALL_MATRIX_OPT) MNK = (double) args.m * (double) args.n * (double) args.k; +#endif #ifdef SMALL_MATRIX_OPT #if !defined(COMPLEX) From 59cb5de46b89a080d1190e89bed543fd32f924c7 Mon Sep 17 00:00:00 2001 From: Xianyi Zhang Date: Wed, 29 Apr 2020 00:19:19 +0800 Subject: [PATCH 04/37] Refs #2587 Fix typos. --- kernel/Makefile.L3 | 96 +++++++++++++++++++++++----------------------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 448d22e4e..6476334e9 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -4258,115 +4258,115 @@ $(KDIR)zgeadd_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEADD_K) ###### BLAS small matrix optimization ##### -ifndef DGEMM_SAMLL_K_NN -DGEMM_SAMLL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c +ifndef DGEMM_SMALL_K_NN +DGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c endif -ifndef DGEMM_SAMLL_K_NT -DGEMM_SAMLL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c +ifndef DGEMM_SMALL_K_NT +DGEMM_SMALL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c endif -ifndef DGEMM_SAMLL_K_TN -DGEMM_SAMLL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c +ifndef DGEMM_SMALL_K_TN +DGEMM_SMALL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c endif -ifndef DGEMM_SAMLL_K_TT -DGEMM_SAMLL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c +ifndef DGEMM_SMALL_K_TT +DGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c endif -$(KDIR)dgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_NN) +$(KDIR)dgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_NN) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)dgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_NT) +$(KDIR)dgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_NT) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_TN) +$(KDIR)dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_TN) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_TT) +$(KDIR)dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_TT) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ -ifndef DGEMM_SAMLL_K_A1B0_NN -DGEMM_SAMLL_K_A1B0_NN = ../generic/gemm_small_matrix_kernel_a1b0_nn.c +ifndef DGEMM_SMALL_K_A1B0_NN +DGEMM_SMALL_K_A1B0_NN = ../generic/gemm_small_matrix_kernel_a1b0_nn.c endif -ifndef DGEMM_SAMLL_K_A1B0_NT -DGEMM_SAMLL_K_A1B0_NT = ../generic/gemm_small_matrix_kernel_a1b0_nt.c +ifndef DGEMM_SMALL_K_A1B0_NT 
+DGEMM_SMALL_K_A1B0_NT = ../generic/gemm_small_matrix_kernel_a1b0_nt.c endif -ifndef DGEMM_SAMLL_K_A1B0_TN -DGEMM_SAMLL_K_A1B0_TN = ../generic/gemm_small_matrix_kernel_a1b0_tn.c +ifndef DGEMM_SMALL_K_A1B0_TN +DGEMM_SMALL_K_A1B0_TN = ../generic/gemm_small_matrix_kernel_a1b0_tn.c endif -ifndef DGEMM_SAMLL_K_A1B0_TT -DGEMM_SAMLL_K_A1B0_TT = ../generic/gemm_small_matrix_kernel_a1b0_tt.c +ifndef DGEMM_SMALL_K_A1B0_TT +DGEMM_SMALL_K_A1B0_TT = ../generic/gemm_small_matrix_kernel_a1b0_tt.c endif -$(KDIR)dgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_A1B0_NN) +$(KDIR)dgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_A1B0_NN) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)dgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_A1B0_NT) +$(KDIR)dgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_A1B0_NT) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)dgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_A1B0_TN) +$(KDIR)dgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_A1B0_TN) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)dgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_A1B0_TT) +$(KDIR)dgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_A1B0_TT) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ -ifndef SGEMM_SAMLL_K_NN -SGEMM_SAMLL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c +ifndef SGEMM_SMALL_K_NN +SGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c endif -ifndef SGEMM_SAMLL_K_NT -SGEMM_SAMLL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c +ifndef SGEMM_SMALL_K_NT +SGEMM_SMALL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c endif -ifndef SGEMM_SAMLL_K_TN -SGEMM_SAMLL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c +ifndef SGEMM_SMALL_K_TN +SGEMM_SMALL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c endif -ifndef SGEMM_SAMLL_K_TT -SGEMM_SAMLL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c +ifndef SGEMM_SMALL_K_TT +SGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c endif -$(KDIR)sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_NN) +$(KDIR)sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_NN) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_NT) +$(KDIR)sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_NT) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_TN) +$(KDIR)sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_TN) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_TT) +$(KDIR)sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_TT) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ -ifndef SGEMM_SAMLL_K_A1B0_NN -SGEMM_SAMLL_K_A1B0_NN = ../generic/gemm_small_matrix_kernel_a1b0_nn.c +ifndef SGEMM_SMALL_K_A1B0_NN +SGEMM_SMALL_K_A1B0_NN = ../generic/gemm_small_matrix_kernel_a1b0_nn.c endif -ifndef SGEMM_SAMLL_K_A1B0_NT -SGEMM_SAMLL_K_A1B0_NT = ../generic/gemm_small_matrix_kernel_a1b0_nt.c +ifndef SGEMM_SMALL_K_A1B0_NT +SGEMM_SMALL_K_A1B0_NT = ../generic/gemm_small_matrix_kernel_a1b0_nt.c endif -ifndef SGEMM_SAMLL_K_A1B0_TN -SGEMM_SAMLL_K_A1B0_TN = ../generic/gemm_small_matrix_kernel_a1b0_tn.c +ifndef 
SGEMM_SMALL_K_A1B0_TN +SGEMM_SMALL_K_A1B0_TN = ../generic/gemm_small_matrix_kernel_a1b0_tn.c endif -ifndef SGEMM_SAMLL_K_A1B0_TT -SGEMM_SAMLL_K_A1B0_TT = ../generic/gemm_small_matrix_kernel_a1b0_tt.c +ifndef SGEMM_SMALL_K_A1B0_TT +SGEMM_SMALL_K_A1B0_TT = ../generic/gemm_small_matrix_kernel_a1b0_tt.c endif -$(KDIR)sgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_A1B0_NN) +$(KDIR)sgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_A1B0_NN) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)sgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_A1B0_NT) +$(KDIR)sgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_A1B0_NT) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)sgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_A1B0_TN) +$(KDIR)sgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_A1B0_TN) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)sgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_A1B0_TT) +$(KDIR)sgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_A1B0_TT) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ From 17d32a4a8271141be2fb96c8c767ac1ed2e60a36 Mon Sep 17 00:00:00 2001 From: Xianyi Zhang Date: Fri, 28 Aug 2020 07:55:27 +0800 Subject: [PATCH 05/37] Change a1b0 gemm to b0 gemm. --- common_d.h | 8 +-- common_level3.h | 16 +++--- common_macro.h | 24 ++++---- common_s.h | 8 +-- interface/gemm.c | 10 ++-- kernel/Makefile.L3 | 56 +++++++++---------- ..._nn.c => gemm_small_matrix_kernel_b0_nn.c} | 4 +- ..._nt.c => gemm_small_matrix_kernel_b0_nt.c} | 4 +- ..._tn.c => gemm_small_matrix_kernel_b0_tn.c} | 4 +- ..._tt.c => gemm_small_matrix_kernel_b0_tt.c} | 4 +- 10 files changed, 69 insertions(+), 69 deletions(-) rename kernel/generic/{gemm_small_matrix_kernel_a1b0_nn.c => gemm_small_matrix_kernel_b0_nn.c} (95%) rename kernel/generic/{gemm_small_matrix_kernel_a1b0_nt.c => gemm_small_matrix_kernel_b0_nt.c} (95%) rename kernel/generic/{gemm_small_matrix_kernel_a1b0_tn.c => gemm_small_matrix_kernel_b0_tn.c} (95%) rename kernel/generic/{gemm_small_matrix_kernel_a1b0_tt.c => gemm_small_matrix_kernel_b0_tt.c} (95%) diff --git a/common_d.h b/common_d.h index f5d7935fa..42c14e828 100644 --- a/common_d.h +++ b/common_d.h @@ -163,10 +163,10 @@ #define DGEMM_SMALL_KERNEL_TN dgemm_small_kernel_tn #define DGEMM_SMALL_KERNEL_TT dgemm_small_kernel_tt -#define DGEMM_SMALL_KERNEL_A1B0_NN dgemm_small_kernel_a1b0_nn -#define DGEMM_SMALL_KERNEL_A1B0_NT dgemm_small_kernel_a1b0_nt -#define DGEMM_SMALL_KERNEL_A1B0_TN dgemm_small_kernel_a1b0_tn -#define DGEMM_SMALL_KERNEL_A1B0_TT dgemm_small_kernel_a1b0_tt +#define DGEMM_SMALL_KERNEL_B0_NN dgemm_small_kernel_b0_nn +#define DGEMM_SMALL_KERNEL_B0_NT dgemm_small_kernel_b0_nt +#define DGEMM_SMALL_KERNEL_B0_TN dgemm_small_kernel_b0_tn +#define DGEMM_SMALL_KERNEL_B0_TT dgemm_small_kernel_b0_tt #else diff --git a/common_level3.h b/common_level3.h index 31d514cd5..7be7ab06b 100644 --- a/common_level3.h +++ b/common_level3.h @@ -526,15 +526,15 @@ int dgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLO int dgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); int dgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, 
BLASLONG ldc); -int sgemm_small_kernel_a1b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int sgemm_small_kernel_a1b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int sgemm_small_kernel_a1b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int sgemm_small_kernel_a1b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int sgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int sgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int sgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int sgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int dgemm_small_kernel_a1b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int dgemm_small_kernel_a1b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int dgemm_small_kernel_a1b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int dgemm_small_kernel_a1b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int dgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int dgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int dgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int dgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); #endif diff --git a/common_macro.h b/common_macro.h index 2f7263023..fa7884180 100644 --- a/common_macro.h +++ b/common_macro.h @@ -648,10 +648,10 @@ #define GEMM_SMALL_KERNEL_NT DGEMM_SMALL_KERNEL_NT #define GEMM_SMALL_KERNEL_TN DGEMM_SMALL_KERNEL_TN #define GEMM_SMALL_KERNEL_TT DGEMM_SMALL_KERNEL_TT -#define GEMM_SMALL_KERNEL_A1B0_NN DGEMM_SMALL_KERNEL_A1B0_NN -#define GEMM_SMALL_KERNEL_A1B0_NT DGEMM_SMALL_KERNEL_A1B0_NT -#define GEMM_SMALL_KERNEL_A1B0_TN DGEMM_SMALL_KERNEL_A1B0_TN -#define GEMM_SMALL_KERNEL_A1B0_TT DGEMM_SMALL_KERNEL_A1B0_TT +#define GEMM_SMALL_KERNEL_B0_NN DGEMM_SMALL_KERNEL_B0_NN +#define GEMM_SMALL_KERNEL_B0_NT DGEMM_SMALL_KERNEL_B0_NT +#define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN +#define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT #elif defined(BFLOAT16) @@ -945,10 +945,10 @@ #define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN #define GEMM_SMALL_KERNEL_TT SGEMM_SMALL_KERNEL_TT -#define GEMM_SMALL_KERNEL_A1B0_NN SGEMM_SMALL_KERNEL_A1B0_NN -#define GEMM_SMALL_KERNEL_A1B0_NT SGEMM_SMALL_KERNEL_A1B0_NT -#define GEMM_SMALL_KERNEL_A1B0_TN SGEMM_SMALL_KERNEL_A1B0_TN 
-#define GEMM_SMALL_KERNEL_A1B0_TT SGEMM_SMALL_KERNEL_A1B0_TT +#define GEMM_SMALL_KERNEL_B0_NN SGEMM_SMALL_KERNEL_B0_NN +#define GEMM_SMALL_KERNEL_B0_NT SGEMM_SMALL_KERNEL_B0_NT +#define GEMM_SMALL_KERNEL_B0_TN SGEMM_SMALL_KERNEL_B0_TN +#define GEMM_SMALL_KERNEL_B0_TT SGEMM_SMALL_KERNEL_B0_TT #endif @@ -1261,10 +1261,10 @@ #define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN #define GEMM_SMALL_KERNEL_TT SGEMM_SMALL_KERNEL_TT -#define GEMM_SMALL_KERNEL_A1B0_NN SGEMM_SMALL_KERNEL_A1B0_NN -#define GEMM_SMALL_KERNEL_A1B0_NT SGEMM_SMALL_KERNEL_A1B0_NT -#define GEMM_SMALL_KERNEL_A1B0_TN SGEMM_SMALL_KERNEL_A1B0_TN -#define GEMM_SMALL_KERNEL_A1B0_TT SGEMM_SMALL_KERNEL_A1B0_TT +#define GEMM_SMALL_KERNEL_B0_NN SGEMM_SMALL_KERNEL_B0_NN +#define GEMM_SMALL_KERNEL_B0_NT SGEMM_SMALL_KERNEL_B0_NT +#define GEMM_SMALL_KERNEL_B0_TN SGEMM_SMALL_KERNEL_B0_TN +#define GEMM_SMALL_KERNEL_B0_TT SGEMM_SMALL_KERNEL_B0_TT #endif #else diff --git a/common_s.h b/common_s.h index 440b78723..685d73062 100644 --- a/common_s.h +++ b/common_s.h @@ -169,10 +169,10 @@ #define SGEMM_SMALL_KERNEL_TN sgemm_small_kernel_tn #define SGEMM_SMALL_KERNEL_TT sgemm_small_kernel_tt -#define SGEMM_SMALL_KERNEL_A1B0_NN sgemm_small_kernel_a1b0_nn -#define SGEMM_SMALL_KERNEL_A1B0_NT sgemm_small_kernel_a1b0_nt -#define SGEMM_SMALL_KERNEL_A1B0_TN sgemm_small_kernel_a1b0_tn -#define SGEMM_SMALL_KERNEL_A1B0_TT sgemm_small_kernel_a1b0_tt +#define SGEMM_SMALL_KERNEL_B0_NN sgemm_small_kernel_b0_nn +#define SGEMM_SMALL_KERNEL_B0_NT sgemm_small_kernel_b0_nt +#define SGEMM_SMALL_KERNEL_B0_TN sgemm_small_kernel_b0_tn +#define SGEMM_SMALL_KERNEL_B0_TT sgemm_small_kernel_b0_tt #else diff --git a/interface/gemm.c b/interface/gemm.c index 4f1bbfd1c..3730f37fa 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -116,11 +116,11 @@ static int (*gemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLON #endif }; -static int (*gemm_small_kernel_a1b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { +static int (*gemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { #ifndef GEMM3M #ifndef COMPLEX - GEMM_SMALL_KERNEL_A1B0_NN, GEMM_SMALL_KERNEL_A1B0_TN, NULL, NULL, - GEMM_SMALL_KERNEL_A1B0_NT, GEMM_SMALL_KERNEL_A1B0_TT, NULL, NULL, + GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, NULL, NULL, + GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, NULL, NULL, #endif #endif }; @@ -450,8 +450,8 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS //need to tune small matrices cases. 
if(MNK <= 100.0*100.0*100.0){ - if(*(FLOAT *)(args.alpha) == 1.0 && *(FLOAT *)(args.beta) == 0.0){ - (gemm_small_kernel_a1b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda,args.b, args.ldb, args.c, args.ldc); + if(*(FLOAT *)(args.beta) == 0.0){ + (gemm_small_kernel_b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc); }else{ (gemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); } diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 6476334e9..c9544086a 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -453,14 +453,14 @@ ifeq ($(SMALL_MATRIX_OPT), 1) SBLASOBJS += \ sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ - sgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) \ - sgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) + sgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \ + sgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) DBLASOBJS += \ dgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ - dgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) \ - dgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) + dgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \ + dgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) endif @@ -4286,32 +4286,32 @@ $(KDIR)dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_ $(KDIR)dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_TT) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ -ifndef DGEMM_SMALL_K_A1B0_NN -DGEMM_SMALL_K_A1B0_NN = ../generic/gemm_small_matrix_kernel_a1b0_nn.c +ifndef DGEMM_SMALL_K_B0_NN +DGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_b0_nn.c endif -ifndef DGEMM_SMALL_K_A1B0_NT -DGEMM_SMALL_K_A1B0_NT = ../generic/gemm_small_matrix_kernel_a1b0_nt.c +ifndef DGEMM_SMALL_K_B0_NT +DGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_b0_nt.c endif -ifndef DGEMM_SMALL_K_A1B0_TN -DGEMM_SMALL_K_A1B0_TN = ../generic/gemm_small_matrix_kernel_a1b0_tn.c +ifndef DGEMM_SMALL_K_B0_TN +DGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_b0_tn.c endif -ifndef DGEMM_SMALL_K_A1B0_TT -DGEMM_SMALL_K_A1B0_TT = ../generic/gemm_small_matrix_kernel_a1b0_tt.c +ifndef DGEMM_SMALL_K_B0_TT +DGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_b0_tt.c endif -$(KDIR)dgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_A1B0_NN) +$(KDIR)dgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_NN) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)dgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_A1B0_NT) +$(KDIR)dgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_NT) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)dgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_A1B0_TN) 
+$(KDIR)dgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_TN) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)dgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_A1B0_TT) +$(KDIR)dgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_TT) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ @@ -4343,30 +4343,30 @@ $(KDIR)sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_ $(KDIR)sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_TT) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ -ifndef SGEMM_SMALL_K_A1B0_NN -SGEMM_SMALL_K_A1B0_NN = ../generic/gemm_small_matrix_kernel_a1b0_nn.c +ifndef SGEMM_SMALL_K_B0_NN +SGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_b0_nn.c endif -ifndef SGEMM_SMALL_K_A1B0_NT -SGEMM_SMALL_K_A1B0_NT = ../generic/gemm_small_matrix_kernel_a1b0_nt.c +ifndef SGEMM_SMALL_K_B0_NT +SGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_b0_nt.c endif -ifndef SGEMM_SMALL_K_A1B0_TN -SGEMM_SMALL_K_A1B0_TN = ../generic/gemm_small_matrix_kernel_a1b0_tn.c +ifndef SGEMM_SMALL_K_B0_TN +SGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_b0_tn.c endif -ifndef SGEMM_SMALL_K_A1B0_TT -SGEMM_SMALL_K_A1B0_TT = ../generic/gemm_small_matrix_kernel_a1b0_tt.c +ifndef SGEMM_SMALL_K_B0_TT +SGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_b0_tt.c endif -$(KDIR)sgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_A1B0_NN) +$(KDIR)sgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_NN) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)sgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_A1B0_NT) +$(KDIR)sgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_NT) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)sgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_A1B0_TN) +$(KDIR)sgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_TN) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ -$(KDIR)sgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_A1B0_TT) +$(KDIR)sgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_TT) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ diff --git a/kernel/generic/gemm_small_matrix_kernel_a1b0_nn.c b/kernel/generic/gemm_small_matrix_kernel_b0_nn.c similarity index 95% rename from kernel/generic/gemm_small_matrix_kernel_a1b0_nn.c rename to kernel/generic/gemm_small_matrix_kernel_b0_nn.c index 8e3417027..3be918017 100644 --- a/kernel/generic/gemm_small_matrix_kernel_a1b0_nn.c +++ b/kernel/generic/gemm_small_matrix_kernel_b0_nn.c @@ -27,7 +27,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common.h" -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT * B, BLASLONG ldb,FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb,FLOAT * C, BLASLONG ldc) { //naive implemtation //Column major @@ -41,7 +41,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT * B for(k=0; k Date: Fri, 28 Aug 2020 21:00:54 +0800 Subject: [PATCH 06/37] Refs #2587 Add small matrix optimization reference kernel for c/zgemm. 
--- common_c.h | 40 +++ common_level3.h | 80 +++++ common_macro.h | 80 +++++ common_z.h | 40 +++ interface/gemm.c | 35 ++- kernel/Makefile.L3 | 293 ++++++++++++++++++ .../generic/zgemm_small_matrix_kernel_b0_nn.c | 74 +++++ .../generic/zgemm_small_matrix_kernel_b0_nt.c | 77 +++++ .../generic/zgemm_small_matrix_kernel_b0_tn.c | 77 +++++ .../generic/zgemm_small_matrix_kernel_b0_tt.c | 77 +++++ kernel/generic/zgemm_small_matrix_kernel_nn.c | 78 +++++ kernel/generic/zgemm_small_matrix_kernel_nt.c | 82 +++++ kernel/generic/zgemm_small_matrix_kernel_tn.c | 82 +++++ kernel/generic/zgemm_small_matrix_kernel_tt.c | 82 +++++ 14 files changed, 1193 insertions(+), 4 deletions(-) create mode 100644 kernel/generic/zgemm_small_matrix_kernel_b0_nn.c create mode 100644 kernel/generic/zgemm_small_matrix_kernel_b0_nt.c create mode 100644 kernel/generic/zgemm_small_matrix_kernel_b0_tn.c create mode 100644 kernel/generic/zgemm_small_matrix_kernel_b0_tt.c create mode 100644 kernel/generic/zgemm_small_matrix_kernel_nn.c create mode 100644 kernel/generic/zgemm_small_matrix_kernel_nt.c create mode 100644 kernel/generic/zgemm_small_matrix_kernel_tn.c create mode 100644 kernel/generic/zgemm_small_matrix_kernel_tt.c diff --git a/common_c.h b/common_c.h index 40ecf5b8b..9388ece93 100644 --- a/common_c.h +++ b/common_c.h @@ -232,6 +232,46 @@ #define CGEADD_K cgeadd_k +#define CGEMM_SMALL_KERNEL_NN cgemm_small_kernel_nn +#define CGEMM_SMALL_KERNEL_NT cgemm_small_kernel_nt +#define CGEMM_SMALL_KERNEL_NR cgemm_small_kernel_nr +#define CGEMM_SMALL_KERNEL_NC cgemm_small_kernel_nc + +#define CGEMM_SMALL_KERNEL_TN cgemm_small_kernel_tn +#define CGEMM_SMALL_KERNEL_TT cgemm_small_kernel_tt +#define CGEMM_SMALL_KERNEL_TR cgemm_small_kernel_tr +#define CGEMM_SMALL_KERNEL_TC cgemm_small_kernel_tc + +#define CGEMM_SMALL_KERNEL_RN cgemm_small_kernel_rn +#define CGEMM_SMALL_KERNEL_RT cgemm_small_kernel_rt +#define CGEMM_SMALL_KERNEL_RR cgemm_small_kernel_rr +#define CGEMM_SMALL_KERNEL_RC cgemm_small_kernel_rc + +#define CGEMM_SMALL_KERNEL_CN cgemm_small_kernel_cn +#define CGEMM_SMALL_KERNEL_CT cgemm_small_kernel_ct +#define CGEMM_SMALL_KERNEL_CR cgemm_small_kernel_cr +#define CGEMM_SMALL_KERNEL_CC cgemm_small_kernel_cc + +#define CGEMM_SMALL_KERNEL_B0_NN cgemm_small_kernel_b0_nn +#define CGEMM_SMALL_KERNEL_B0_NT cgemm_small_kernel_b0_nt +#define CGEMM_SMALL_KERNEL_B0_NR cgemm_small_kernel_b0_nr +#define CGEMM_SMALL_KERNEL_B0_NC cgemm_small_kernel_b0_nc + +#define CGEMM_SMALL_KERNEL_B0_TN cgemm_small_kernel_b0_tn +#define CGEMM_SMALL_KERNEL_B0_TT cgemm_small_kernel_b0_tt +#define CGEMM_SMALL_KERNEL_B0_TR cgemm_small_kernel_b0_tr +#define CGEMM_SMALL_KERNEL_B0_TC cgemm_small_kernel_b0_tc + +#define CGEMM_SMALL_KERNEL_B0_RN cgemm_small_kernel_b0_rn +#define CGEMM_SMALL_KERNEL_B0_RT cgemm_small_kernel_b0_rt +#define CGEMM_SMALL_KERNEL_B0_RR cgemm_small_kernel_b0_rr +#define CGEMM_SMALL_KERNEL_B0_RC cgemm_small_kernel_b0_rc + +#define CGEMM_SMALL_KERNEL_B0_CN cgemm_small_kernel_b0_cn +#define CGEMM_SMALL_KERNEL_B0_CT cgemm_small_kernel_b0_ct +#define CGEMM_SMALL_KERNEL_B0_CR cgemm_small_kernel_b0_cr +#define CGEMM_SMALL_KERNEL_B0_CC cgemm_small_kernel_b0_cc + #else #define CAMAX_K gotoblas -> camax_k diff --git a/common_level3.h b/common_level3.h index 7be7ab06b..5741f56d5 100644 --- a/common_level3.h +++ b/common_level3.h @@ -536,6 +536,86 @@ int dgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLA int dgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, 
BLASLONG ldb, double * C, BLASLONG ldc); int dgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int cgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_nr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_nc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); + +int cgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_tr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_tc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); + +int cgemm_small_kernel_rn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_rt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_rr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_rc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); + +int cgemm_small_kernel_cn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_ct(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_cr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_cc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); + +int zgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_nr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_nc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * 
beta, double * C, BLASLONG ldc); + +int zgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_tr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_tc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); + +int zgemm_small_kernel_rn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_rt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_rr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_rc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); + +int zgemm_small_kernel_cn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_ct(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_cr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_cc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); + +int cgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_nr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_nc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + +int cgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_tr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_tc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + +int cgemm_small_kernel_b0_rn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int 
cgemm_small_kernel_b0_rt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_rr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_rc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + +int cgemm_small_kernel_b0_cn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_ct(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_cr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_cc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + +int zgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_nr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_nc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + +int zgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_tr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_tc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + +int zgemm_small_kernel_b0_rn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_rt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_rr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_rc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + +int zgemm_small_kernel_b0_cn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_ct(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_cr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int 
zgemm_small_kernel_b0_cc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + #endif int cgemm_kernel_n(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG); diff --git a/common_macro.h b/common_macro.h index fa7884180..2cccf9b39 100644 --- a/common_macro.h +++ b/common_macro.h @@ -2093,6 +2093,46 @@ #define GEADD_K ZGEADD_K +#define GEMM_SMALL_KERNEL_NN ZGEMM_SMALL_KERNEL_NN +#define GEMM_SMALL_KERNEL_NT ZGEMM_SMALL_KERNEL_NT +#define GEMM_SMALL_KERNEL_NR ZGEMM_SMALL_KERNEL_NR +#define GEMM_SMALL_KERNEL_NC ZGEMM_SMALL_KERNEL_NC + +#define GEMM_SMALL_KERNEL_TN ZGEMM_SMALL_KERNEL_TN +#define GEMM_SMALL_KERNEL_TT ZGEMM_SMALL_KERNEL_TT +#define GEMM_SMALL_KERNEL_TR ZGEMM_SMALL_KERNEL_TR +#define GEMM_SMALL_KERNEL_TC ZGEMM_SMALL_KERNEL_TC + +#define GEMM_SMALL_KERNEL_RN ZGEMM_SMALL_KERNEL_RN +#define GEMM_SMALL_KERNEL_RT ZGEMM_SMALL_KERNEL_RT +#define GEMM_SMALL_KERNEL_RR ZGEMM_SMALL_KERNEL_RR +#define GEMM_SMALL_KERNEL_RC ZGEMM_SMALL_KERNEL_RC + +#define GEMM_SMALL_KERNEL_CN ZGEMM_SMALL_KERNEL_CN +#define GEMM_SMALL_KERNEL_CT ZGEMM_SMALL_KERNEL_CT +#define GEMM_SMALL_KERNEL_CR ZGEMM_SMALL_KERNEL_CR +#define GEMM_SMALL_KERNEL_CC ZGEMM_SMALL_KERNEL_CC + +#define GEMM_SMALL_KERNEL_B0_NN ZGEMM_SMALL_KERNEL_B0_NN +#define GEMM_SMALL_KERNEL_B0_NT ZGEMM_SMALL_KERNEL_B0_NT +#define GEMM_SMALL_KERNEL_B0_NR ZGEMM_SMALL_KERNEL_B0_NR +#define GEMM_SMALL_KERNEL_B0_NC ZGEMM_SMALL_KERNEL_B0_NC + +#define GEMM_SMALL_KERNEL_B0_TN ZGEMM_SMALL_KERNEL_B0_TN +#define GEMM_SMALL_KERNEL_B0_TT ZGEMM_SMALL_KERNEL_B0_TT +#define GEMM_SMALL_KERNEL_B0_TR ZGEMM_SMALL_KERNEL_B0_TR +#define GEMM_SMALL_KERNEL_B0_TC ZGEMM_SMALL_KERNEL_B0_TC + +#define GEMM_SMALL_KERNEL_B0_RN ZGEMM_SMALL_KERNEL_B0_RN +#define GEMM_SMALL_KERNEL_B0_RT ZGEMM_SMALL_KERNEL_B0_RT +#define GEMM_SMALL_KERNEL_B0_RR ZGEMM_SMALL_KERNEL_B0_RR +#define GEMM_SMALL_KERNEL_B0_RC ZGEMM_SMALL_KERNEL_B0_RC + +#define GEMM_SMALL_KERNEL_B0_CN ZGEMM_SMALL_KERNEL_B0_CN +#define GEMM_SMALL_KERNEL_B0_CT ZGEMM_SMALL_KERNEL_B0_CT +#define GEMM_SMALL_KERNEL_B0_CR ZGEMM_SMALL_KERNEL_B0_CR +#define GEMM_SMALL_KERNEL_B0_CC ZGEMM_SMALL_KERNEL_B0_CC + #else #define AMAX_K CAMAX_K @@ -2516,6 +2556,46 @@ #define GEADD_K CGEADD_K +#define GEMM_SMALL_KERNEL_NN CGEMM_SMALL_KERNEL_NN +#define GEMM_SMALL_KERNEL_NT CGEMM_SMALL_KERNEL_NT +#define GEMM_SMALL_KERNEL_NR CGEMM_SMALL_KERNEL_NR +#define GEMM_SMALL_KERNEL_NC CGEMM_SMALL_KERNEL_NC + +#define GEMM_SMALL_KERNEL_TN CGEMM_SMALL_KERNEL_TN +#define GEMM_SMALL_KERNEL_TT CGEMM_SMALL_KERNEL_TT +#define GEMM_SMALL_KERNEL_TR CGEMM_SMALL_KERNEL_TR +#define GEMM_SMALL_KERNEL_TC CGEMM_SMALL_KERNEL_TC + +#define GEMM_SMALL_KERNEL_RN CGEMM_SMALL_KERNEL_RN +#define GEMM_SMALL_KERNEL_RT CGEMM_SMALL_KERNEL_RT +#define GEMM_SMALL_KERNEL_RR CGEMM_SMALL_KERNEL_RR +#define GEMM_SMALL_KERNEL_RC CGEMM_SMALL_KERNEL_RC + +#define GEMM_SMALL_KERNEL_CN CGEMM_SMALL_KERNEL_CN +#define GEMM_SMALL_KERNEL_CT CGEMM_SMALL_KERNEL_CT +#define GEMM_SMALL_KERNEL_CR CGEMM_SMALL_KERNEL_CR +#define GEMM_SMALL_KERNEL_CC CGEMM_SMALL_KERNEL_CC + +#define GEMM_SMALL_KERNEL_B0_NN CGEMM_SMALL_KERNEL_B0_NN +#define GEMM_SMALL_KERNEL_B0_NT CGEMM_SMALL_KERNEL_B0_NT +#define GEMM_SMALL_KERNEL_B0_NR CGEMM_SMALL_KERNEL_B0_NR +#define GEMM_SMALL_KERNEL_B0_NC CGEMM_SMALL_KERNEL_B0_NC + +#define GEMM_SMALL_KERNEL_B0_TN CGEMM_SMALL_KERNEL_B0_TN +#define GEMM_SMALL_KERNEL_B0_TT CGEMM_SMALL_KERNEL_B0_TT +#define GEMM_SMALL_KERNEL_B0_TR CGEMM_SMALL_KERNEL_B0_TR +#define 
GEMM_SMALL_KERNEL_B0_TC CGEMM_SMALL_KERNEL_B0_TC + +#define GEMM_SMALL_KERNEL_B0_RN CGEMM_SMALL_KERNEL_B0_RN +#define GEMM_SMALL_KERNEL_B0_RT CGEMM_SMALL_KERNEL_B0_RT +#define GEMM_SMALL_KERNEL_B0_RR CGEMM_SMALL_KERNEL_B0_RR +#define GEMM_SMALL_KERNEL_B0_RC CGEMM_SMALL_KERNEL_B0_RC + +#define GEMM_SMALL_KERNEL_B0_CN CGEMM_SMALL_KERNEL_B0_CN +#define GEMM_SMALL_KERNEL_B0_CT CGEMM_SMALL_KERNEL_B0_CT +#define GEMM_SMALL_KERNEL_B0_CR CGEMM_SMALL_KERNEL_B0_CR +#define GEMM_SMALL_KERNEL_B0_CC CGEMM_SMALL_KERNEL_B0_CC + #endif #endif diff --git a/common_z.h b/common_z.h index f1e78dd08..8594ec74d 100644 --- a/common_z.h +++ b/common_z.h @@ -232,6 +232,46 @@ #define ZGEADD_K zgeadd_k +#define ZGEMM_SMALL_KERNEL_NN zgemm_small_kernel_nn +#define ZGEMM_SMALL_KERNEL_NT zgemm_small_kernel_nt +#define ZGEMM_SMALL_KERNEL_NR zgemm_small_kernel_nr +#define ZGEMM_SMALL_KERNEL_NC zgemm_small_kernel_nc + +#define ZGEMM_SMALL_KERNEL_TN zgemm_small_kernel_tn +#define ZGEMM_SMALL_KERNEL_TT zgemm_small_kernel_tt +#define ZGEMM_SMALL_KERNEL_TR zgemm_small_kernel_tr +#define ZGEMM_SMALL_KERNEL_TC zgemm_small_kernel_tc + +#define ZGEMM_SMALL_KERNEL_RN zgemm_small_kernel_rn +#define ZGEMM_SMALL_KERNEL_RT zgemm_small_kernel_rt +#define ZGEMM_SMALL_KERNEL_RR zgemm_small_kernel_rr +#define ZGEMM_SMALL_KERNEL_RC zgemm_small_kernel_rc + +#define ZGEMM_SMALL_KERNEL_CN zgemm_small_kernel_cn +#define ZGEMM_SMALL_KERNEL_CT zgemm_small_kernel_ct +#define ZGEMM_SMALL_KERNEL_CR zgemm_small_kernel_cr +#define ZGEMM_SMALL_KERNEL_CC zgemm_small_kernel_cc + +#define ZGEMM_SMALL_KERNEL_B0_NN zgemm_small_kernel_b0_nn +#define ZGEMM_SMALL_KERNEL_B0_NT zgemm_small_kernel_b0_nt +#define ZGEMM_SMALL_KERNEL_B0_NR zgemm_small_kernel_b0_nr +#define ZGEMM_SMALL_KERNEL_B0_NC zgemm_small_kernel_b0_nc + +#define ZGEMM_SMALL_KERNEL_B0_TN zgemm_small_kernel_b0_tn +#define ZGEMM_SMALL_KERNEL_B0_TT zgemm_small_kernel_b0_tt +#define ZGEMM_SMALL_KERNEL_B0_TR zgemm_small_kernel_b0_tr +#define ZGEMM_SMALL_KERNEL_B0_TC zgemm_small_kernel_b0_tc + +#define ZGEMM_SMALL_KERNEL_B0_RN zgemm_small_kernel_b0_rn +#define ZGEMM_SMALL_KERNEL_B0_RT zgemm_small_kernel_b0_rt +#define ZGEMM_SMALL_KERNEL_B0_RR zgemm_small_kernel_b0_rr +#define ZGEMM_SMALL_KERNEL_B0_RC zgemm_small_kernel_b0_rc + +#define ZGEMM_SMALL_KERNEL_B0_CN zgemm_small_kernel_b0_cn +#define ZGEMM_SMALL_KERNEL_B0_CT zgemm_small_kernel_b0_ct +#define ZGEMM_SMALL_KERNEL_B0_CR zgemm_small_kernel_b0_cr +#define ZGEMM_SMALL_KERNEL_B0_CC zgemm_small_kernel_b0_cc + #else #define ZAMAX_K gotoblas -> zamax_k diff --git a/interface/gemm.c b/interface/gemm.c index 3730f37fa..b73baa9bd 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -124,6 +124,28 @@ static int (*gemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLAS #endif #endif }; + +static int (*zgemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT *,FLOAT *, BLASLONG, FLOAT *, FLOAT *, BLASLONG) = { +#ifndef GEMM3M +#ifdef COMPLEX + GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN, + GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT, + GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR, + GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC, +#endif +#endif +}; + +static int (*zgemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT *, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { +#ifndef GEMM3M +#ifdef COMPLEX + GEMM_SMALL_KERNEL_B0_NN, 
GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN, + GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT, + GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR, + GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC, +#endif +#endif +}; #endif #ifndef CBLAS @@ -446,20 +468,25 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS #endif #ifdef SMALL_MATRIX_OPT -#if !defined(COMPLEX) //need to tune small matrices cases. if(MNK <= 100.0*100.0*100.0){ - + +#if !defined(COMPLEX) if(*(FLOAT *)(args.beta) == 0.0){ (gemm_small_kernel_b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc); }else{ (gemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); } - +#else + if(beta[0] == 0.0 && beta[1] == 0.0){ + (zgemm_small_kernel_b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, (FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc); + }else{ + (zgemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, (FLOAT *)(args.alpha), args.b, args.ldb, (FLOAT *)(args.beta), args.c, args.ldc); + } +#endif return; } #endif -#endif buffer = (XFLOAT *)blas_memory_alloc(0); diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index c9544086a..1c4a00158 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -462,6 +462,42 @@ DBLASOBJS += \ dgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \ dgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) +CBLASOBJS += \ + cgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_tr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_tc$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_rn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_rt$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_rr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_rc$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_cn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_ct$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_cr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) \ + cgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) + +ZBLASOBJS += \ + zgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) \ + 
zgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_tr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_tc$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_rn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_rt$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_rr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_rc$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_cn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_ct$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_cr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) \ + zgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) + endif ###### BLAS extensions ##### @@ -4370,3 +4406,260 @@ $(KDIR)sgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL $(KDIR)sgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_TT) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + + +ifndef CGEMM_SMALL_K_NN +CGEMM_SMALL_K_NN = ../generic/zgemm_small_matrix_kernel_nn.c +endif + +ifndef CGEMM_SMALL_K_NT +CGEMM_SMALL_K_NT = ../generic/zgemm_small_matrix_kernel_nt.c +endif + +ifndef CGEMM_SMALL_K_TN +CGEMM_SMALL_K_TN = ../generic/zgemm_small_matrix_kernel_tn.c +endif + +ifndef CGEMM_SMALL_K_TT +CGEMM_SMALL_K_TT = ../generic/zgemm_small_matrix_kernel_tt.c +endif + +$(KDIR)cgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNN $< -o $@ + +$(KDIR)cgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNR $< -o $@ + +$(KDIR)cgemm_small_kernel_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRN $< -o $@ + +$(KDIR)cgemm_small_kernel_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRR $< -o $@ + +$(KDIR)cgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNT $< -o $@ + +$(KDIR)cgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNC $< -o $@ + +$(KDIR)cgemm_small_kernel_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRT $< -o $@ + +$(KDIR)cgemm_small_kernel_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRC $< -o $@ + +$(KDIR)cgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTN $< -o $@ + +$(KDIR)cgemm_small_kernel_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTR $< -o $@ + +$(KDIR)cgemm_small_kernel_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCN $< -o $@ + +$(KDIR)cgemm_small_kernel_cr$(TSUFFIX).$(SUFFIX) : 
$(KERNELDIR)/$(CGEMM_SMALL_K_TN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCR $< -o $@ + +$(KDIR)cgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTT $< -o $@ + +$(KDIR)cgemm_small_kernel_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTC $< -o $@ + +$(KDIR)cgemm_small_kernel_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCT $< -o $@ + +$(KDIR)cgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCC $< -o $@ + +ifndef CGEMM_SMALL_K_B0_NN +CGEMM_SMALL_K_B0_NN = ../generic/zgemm_small_matrix_kernel_b0_nn.c +endif + +ifndef CGEMM_SMALL_K_B0_NT +CGEMM_SMALL_K_B0_NT = ../generic/zgemm_small_matrix_kernel_b0_nt.c +endif + +ifndef CGEMM_SMALL_K_B0_TN +CGEMM_SMALL_K_B0_TN = ../generic/zgemm_small_matrix_kernel_b0_tn.c +endif + +ifndef CGEMM_SMALL_K_B0_TT +CGEMM_SMALL_K_B0_TT = ../generic/zgemm_small_matrix_kernel_b0_tt.c +endif + +$(KDIR)cgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNN $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNR $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRN $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRR $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNT $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNC $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRT $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRC $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTN $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTR $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCN $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCR $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTT $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTC $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCT $< -o $@ + +$(KDIR)cgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCC $< -o $@ + +ifndef ZGEMM_SMALL_K_NN +ZGEMM_SMALL_K_NN = ../generic/zgemm_small_matrix_kernel_nn.c +endif + +ifndef ZGEMM_SMALL_K_NT 
+ZGEMM_SMALL_K_NT = ../generic/zgemm_small_matrix_kernel_nt.c +endif + +ifndef ZGEMM_SMALL_K_TN +ZGEMM_SMALL_K_TN = ../generic/zgemm_small_matrix_kernel_tn.c +endif + +ifndef ZGEMM_SMALL_K_TT +ZGEMM_SMALL_K_TT = ../generic/zgemm_small_matrix_kernel_tt.c +endif + +$(KDIR)zgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNN $< -o $@ + +$(KDIR)zgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNR $< -o $@ + +$(KDIR)zgemm_small_kernel_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRN $< -o $@ + +$(KDIR)zgemm_small_kernel_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRR $< -o $@ + +$(KDIR)zgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNT $< -o $@ + +$(KDIR)zgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNC $< -o $@ + +$(KDIR)zgemm_small_kernel_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRT $< -o $@ + +$(KDIR)zgemm_small_kernel_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRC $< -o $@ + +$(KDIR)zgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTN $< -o $@ + +$(KDIR)zgemm_small_kernel_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTR $< -o $@ + +$(KDIR)zgemm_small_kernel_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCN $< -o $@ + +$(KDIR)zgemm_small_kernel_cr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCR $< -o $@ + +$(KDIR)zgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTT $< -o $@ + +$(KDIR)zgemm_small_kernel_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTC $< -o $@ + +$(KDIR)zgemm_small_kernel_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCT $< -o $@ + +$(KDIR)zgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCC $< -o $@ + +ifndef ZGEMM_SMALL_K_B0_NN +ZGEMM_SMALL_K_B0_NN = ../generic/zgemm_small_matrix_kernel_b0_nn.c +endif + +ifndef ZGEMM_SMALL_K_B0_NT +ZGEMM_SMALL_K_B0_NT = ../generic/zgemm_small_matrix_kernel_b0_nt.c +endif + +ifndef ZGEMM_SMALL_K_B0_TN +ZGEMM_SMALL_K_B0_TN = ../generic/zgemm_small_matrix_kernel_b0_tn.c +endif + +ifndef ZGEMM_SMALL_K_B0_TT +ZGEMM_SMALL_K_B0_TT = ../generic/zgemm_small_matrix_kernel_b0_tt.c +endif + +$(KDIR)zgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNN $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNR $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRN $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRR $< -o $@ + 
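+# Note: each generic complex source builds several conjugation variants of the
+# same layout -- e.g. zgemm_small_matrix_kernel_b0_nn.c compiled with -DNN,
+# -DNR, -DRN or -DRR selects the sign pattern applied to the imaginary parts.
+# The *_b0_* objects cover the beta == 0 case and never read the old C values.
+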
+$(KDIR)zgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNT $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNC $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRT $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRC $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTN $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTR $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCN $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCR $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTT $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTC $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCT $< -o $@ + +$(KDIR)zgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCC $< -o $@ diff --git a/kernel/generic/zgemm_small_matrix_kernel_b0_nn.c b/kernel/generic/zgemm_small_matrix_kernel_b0_nn.c new file mode 100644 index 000000000..11e746e52 --- /dev/null +++ b/kernel/generic/zgemm_small_matrix_kernel_b0_nn.c @@ -0,0 +1,74 @@ +/*************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +{ + FLOAT real, imag; + + int i, j, l; + for(i = 0; i < M; i++){ + for(j = 0; j < N; j++){ + real=0; + imag=0; + + for(l = 0; l < K; l++){ +#if defined(NN) + real += (A[l*2*lda + 2*i]*B[j*2*ldb + 2*l] + -A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(A[l*2*lda + 2*i] * B[j*2*ldb + 2*l + 1] + + A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l]); +#elif defined(NR) + real += (A[l*2*lda + 2*i]*B[j*2*ldb + 2*l] + +A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(-A[l*2*lda + 2*i] * B[j*2*ldb + 2*l + 1] + + A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l]); +#elif defined(RN) + real += (A[l*2*lda + 2*i]*B[j*2*ldb + 2*l] + +A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(A[l*2*lda + 2*i] * B[j*2*ldb + 2*l + 1] + - A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l]); +#elif defined(RR) + real += (A[l*2*lda + 2*i]*B[j*2*ldb + 2*l] + -A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(-A[l*2*lda + 2*i] * B[j*2*ldb + 2*l + 1] + - A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l]); +#endif + } + + C[j*2*ldc + 2*i] = alpha[0]*real - alpha[1]*imag; + C[j*2*ldc+ 2*i + 1] = alpha[0]*imag + real*alpha[1]; + } + } + + return 0; +} diff --git a/kernel/generic/zgemm_small_matrix_kernel_b0_nt.c b/kernel/generic/zgemm_small_matrix_kernel_b0_nt.c new file mode 100644 index 000000000..1ef743017 --- /dev/null +++ b/kernel/generic/zgemm_small_matrix_kernel_b0_nt.c @@ -0,0 +1,77 @@ +/*************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +{ + FLOAT real, imag; + int i, j, l; + for(i = 0; i < M; i++){ + for(j = 0; j < N; j++){ + real=0; + imag=0; + + for(l = 0; l < K; l++){ +#if defined(NT) + real += (A[l*2*lda + 2*i]*B[l*2*ldb + 2*j] + -A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(A[l*2*lda + 2*i] * B[l*2*ldb + 2*j + 1] + + A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j]); + +#elif defined(NC) + real += (A[l*2*lda + 2*i]*B[l*2*ldb + 2*j] + +A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(-A[l*2*lda + 2*i] * B[l*2*ldb + 2*j + 1] + + A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j]); + +#elif defined(RT) + real += (A[l*2*lda + 2*i]*B[l*2*ldb + 2*j] + +A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(A[l*2*lda + 2*i] * B[l*2*ldb + 2*j + 1] + - A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j]); + +#elif defined(RC) + real += (A[l*2*lda + 2*i]*B[l*2*ldb + 2*j] + -A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(-A[l*2*lda + 2*i] * B[l*2*ldb + 2*j + 1] + - A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j]); + +#endif + } + + C[j*2*ldc + 2*i] = alpha[0]*real - alpha[1]*imag; + C[j*2*ldc+ 2*i + 1] = alpha[0]*imag + real*alpha[1]; + } + } + + return 0; +} diff --git a/kernel/generic/zgemm_small_matrix_kernel_b0_tn.c b/kernel/generic/zgemm_small_matrix_kernel_b0_tn.c new file mode 100644 index 000000000..2cd3ebcf2 --- /dev/null +++ b/kernel/generic/zgemm_small_matrix_kernel_b0_tn.c @@ -0,0 +1,77 @@ +/*************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +{ + FLOAT real, imag; + int i, j, l; + for(i = 0; i < M; i++){ + for(j = 0; j < N; j++){ + real=0; + imag=0; + + for(l = 0; l < K; l++){ +#if defined(TN) + real += (A[i*2*lda + 2*l]*B[j*2*ldb + 2*l] + -A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(A[i*2*lda + 2*l] * B[j*2*ldb + 2*l + 1] + + A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l]); + +#elif defined(TR) + real += (A[i*2*lda + 2*l]*B[j*2*ldb + 2*l] + +A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(-A[i*2*lda + 2*l] * B[j*2*ldb + 2*l + 1] + + A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l]); + +#elif defined(CN) + real += (A[i*2*lda + 2*l]*B[j*2*ldb + 2*l] + +A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(A[i*2*lda + 2*l] * B[j*2*ldb + 2*l + 1] + - A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l]); + +#elif defined(CR) + real += (A[i*2*lda + 2*l]*B[j*2*ldb + 2*l] + -A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(-A[i*2*lda + 2*l] * B[j*2*ldb + 2*l + 1] + - A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l]); + +#endif + } + + C[j*2*ldc + 2*i] = alpha[0]*real - alpha[1]*imag; + C[j*2*ldc+ 2*i + 1] = alpha[0]*imag + real*alpha[1]; + } + } + + return 0; +} diff --git a/kernel/generic/zgemm_small_matrix_kernel_b0_tt.c b/kernel/generic/zgemm_small_matrix_kernel_b0_tt.c new file mode 100644 index 000000000..25b05b4aa --- /dev/null +++ b/kernel/generic/zgemm_small_matrix_kernel_b0_tt.c @@ -0,0 +1,77 @@ +/*************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +{ + FLOAT real, imag; + int i, j, l; + for(i = 0; i < M; i++){ + for(j = 0; j < N; j++){ + real=0; + imag=0; + + for(l = 0; l < K; l++){ +#if defined(TT) + real += (A[i*2*lda + 2*l]*B[l*2*ldb + 2*j] + -A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(A[i*2*lda + 2*l] * B[l*2*ldb + 2*j + 1] + + A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j]); + +#elif defined(TC) + real += (A[i*2*lda + 2*l]*B[l*2*ldb + 2*j] + +A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(-A[i*2*lda + 2*l] * B[l*2*ldb + 2*j + 1] + + A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j]); + +#elif defined(CT) + real += (A[i*2*lda + 2*l]*B[l*2*ldb + 2*j] + +A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(A[i*2*lda + 2*l] * B[l*2*ldb + 2*j + 1] + - A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j]); + +#elif defined(CC) + real += (A[i*2*lda + 2*l]*B[l*2*ldb + 2*j] + -A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(-A[i*2*lda + 2*l] * B[l*2*ldb + 2*j + 1] + - A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j]); + +#endif + } + + C[j*2*ldc + 2*i] = alpha[0]*real - alpha[1]*imag; + C[j*2*ldc+ 2*i + 1] = alpha[0]*imag + real*alpha[1]; + } + } + + return 0; +} diff --git a/kernel/generic/zgemm_small_matrix_kernel_nn.c b/kernel/generic/zgemm_small_matrix_kernel_nn.c new file mode 100644 index 000000000..6ef1b9655 --- /dev/null +++ b/kernel/generic/zgemm_small_matrix_kernel_nn.c @@ -0,0 +1,78 @@ +/*************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT* beta, FLOAT * C, BLASLONG ldc) +{ + FLOAT real, imag; + FLOAT tmp0, tmp1; + int i, j, l; + for(i = 0; i < M; i++){ + for(j = 0; j < N; j++){ + real=0; + imag=0; + + for(l = 0; l < K; l++){ +#if defined(NN) + real += (A[l*2*lda + 2*i]*B[j*2*ldb + 2*l] + -A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(A[l*2*lda + 2*i] * B[j*2*ldb + 2*l + 1] + + A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l]); +#elif defined(NR) + real += (A[l*2*lda + 2*i]*B[j*2*ldb + 2*l] + +A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(-A[l*2*lda + 2*i] * B[j*2*ldb + 2*l + 1] + + A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l]); +#elif defined(RN) + real += (A[l*2*lda + 2*i]*B[j*2*ldb + 2*l] + +A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(A[l*2*lda + 2*i] * B[j*2*ldb + 2*l + 1] + - A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l]); +#elif defined(RR) + real += (A[l*2*lda + 2*i]*B[j*2*ldb + 2*l] + -A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(-A[l*2*lda + 2*i] * B[j*2*ldb + 2*l + 1] + - A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l]); +#endif + } + + tmp0 = beta[0]*C[j*2*ldc + 2*i] - beta[1]*C[j*2*ldc+ 2*i + 1]; + tmp1 = beta[0]*C[j*2*ldc+ 2*i + 1] + beta[1]*C[j*2*ldc + 2*i]; + + + C[j*2*ldc + 2*i] =tmp0+ alpha[0]*real - alpha[1]*imag; + C[j*2*ldc+ 2*i + 1] = tmp1+ alpha[0]*imag + real*alpha[1]; + } + } + + return 0; +} diff --git a/kernel/generic/zgemm_small_matrix_kernel_nt.c b/kernel/generic/zgemm_small_matrix_kernel_nt.c new file mode 100644 index 000000000..3c81ad79e --- /dev/null +++ b/kernel/generic/zgemm_small_matrix_kernel_nt.c @@ -0,0 +1,82 @@ +/*************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT* beta, FLOAT * C, BLASLONG ldc) +{ + FLOAT real, imag; + FLOAT tmp0, tmp1; + int i, j, l; + for(i = 0; i < M; i++){ + for(j = 0; j < N; j++){ + real=0; + imag=0; + + for(l = 0; l < K; l++){ +#if defined(NT) + real += (A[l*2*lda + 2*i]*B[l*2*ldb + 2*j] + -A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(A[l*2*lda + 2*i] * B[l*2*ldb + 2*j + 1] + + A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j]); + +#elif defined(NC) + real += (A[l*2*lda + 2*i]*B[l*2*ldb + 2*j] + +A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(-A[l*2*lda + 2*i] * B[l*2*ldb + 2*j + 1] + + A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j]); + +#elif defined(RT) + real += (A[l*2*lda + 2*i]*B[l*2*ldb + 2*j] + +A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(A[l*2*lda + 2*i] * B[l*2*ldb + 2*j + 1] + - A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j]); + +#elif defined(RC) + real += (A[l*2*lda + 2*i]*B[l*2*ldb + 2*j] + -A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(-A[l*2*lda + 2*i] * B[l*2*ldb + 2*j + 1] + - A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j]); + +#endif + } + + tmp0 = beta[0]*C[j*2*ldc + 2*i] - beta[1]*C[j*2*ldc+ 2*i + 1]; + tmp1 = beta[0]*C[j*2*ldc+ 2*i + 1] + beta[1]*C[j*2*ldc + 2*i]; + + + C[j*2*ldc + 2*i] =tmp0+ alpha[0]*real - alpha[1]*imag; + C[j*2*ldc+ 2*i + 1] = tmp1+ alpha[0]*imag + real*alpha[1]; + } + } + + return 0; +} diff --git a/kernel/generic/zgemm_small_matrix_kernel_tn.c b/kernel/generic/zgemm_small_matrix_kernel_tn.c new file mode 100644 index 000000000..143190bb1 --- /dev/null +++ b/kernel/generic/zgemm_small_matrix_kernel_tn.c @@ -0,0 +1,82 @@ +/*************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT* beta, FLOAT * C, BLASLONG ldc) +{ + FLOAT real, imag; + FLOAT tmp0, tmp1; + int i, j, l; + for(i = 0; i < M; i++){ + for(j = 0; j < N; j++){ + real=0; + imag=0; + + for(l = 0; l < K; l++){ +#if defined(TN) + real += (A[i*2*lda + 2*l]*B[j*2*ldb + 2*l] + -A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(A[i*2*lda + 2*l] * B[j*2*ldb + 2*l + 1] + + A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l]); + +#elif defined(TR) + real += (A[i*2*lda + 2*l]*B[j*2*ldb + 2*l] + +A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(-A[i*2*lda + 2*l] * B[j*2*ldb + 2*l + 1] + + A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l]); + +#elif defined(CN) + real += (A[i*2*lda + 2*l]*B[j*2*ldb + 2*l] + +A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(A[i*2*lda + 2*l] * B[j*2*ldb + 2*l + 1] + - A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l]); + +#elif defined(CR) + real += (A[i*2*lda + 2*l]*B[j*2*ldb + 2*l] + -A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l + 1]); + + imag+=(-A[i*2*lda + 2*l] * B[j*2*ldb + 2*l + 1] + - A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l]); + +#endif + } + + tmp0 = beta[0]*C[j*2*ldc + 2*i] - beta[1]*C[j*2*ldc+ 2*i + 1]; + tmp1 = beta[0]*C[j*2*ldc+ 2*i + 1] + beta[1]*C[j*2*ldc + 2*i]; + + + C[j*2*ldc + 2*i] =tmp0+ alpha[0]*real - alpha[1]*imag; + C[j*2*ldc+ 2*i + 1] = tmp1+ alpha[0]*imag + real*alpha[1]; + } + } + + return 0; +} diff --git a/kernel/generic/zgemm_small_matrix_kernel_tt.c b/kernel/generic/zgemm_small_matrix_kernel_tt.c new file mode 100644 index 000000000..246e26e84 --- /dev/null +++ b/kernel/generic/zgemm_small_matrix_kernel_tt.c @@ -0,0 +1,82 @@ +/*************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT* beta, FLOAT * C, BLASLONG ldc) +{ + FLOAT real, imag; + FLOAT tmp0, tmp1; + int i, j, l; + for(i = 0; i < M; i++){ + for(j = 0; j < N; j++){ + real=0; + imag=0; + + for(l = 0; l < K; l++){ +#if defined(TT) + real += (A[i*2*lda + 2*l]*B[l*2*ldb + 2*j] + -A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(A[i*2*lda + 2*l] * B[l*2*ldb + 2*j + 1] + + A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j]); + +#elif defined(TC) + real += (A[i*2*lda + 2*l]*B[l*2*ldb + 2*j] + +A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(-A[i*2*lda + 2*l] * B[l*2*ldb + 2*j + 1] + + A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j]); + +#elif defined(CT) + real += (A[i*2*lda + 2*l]*B[l*2*ldb + 2*j] + +A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(A[i*2*lda + 2*l] * B[l*2*ldb + 2*j + 1] + - A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j]); + +#elif defined(CC) + real += (A[i*2*lda + 2*l]*B[l*2*ldb + 2*j] + -A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j + 1]); + + imag+=(-A[i*2*lda + 2*l] * B[l*2*ldb + 2*j + 1] + - A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j]); + +#endif + } + + tmp0 = beta[0]*C[j*2*ldc + 2*i] - beta[1]*C[j*2*ldc+ 2*i + 1]; + tmp1 = beta[0]*C[j*2*ldc+ 2*i + 1] + beta[1]*C[j*2*ldc + 2*i]; + + + C[j*2*ldc + 2*i] =tmp0+ alpha[0]*real - alpha[1]*imag; + C[j*2*ldc+ 2*i + 1] = tmp1+ alpha[0]*imag + real*alpha[1]; + } + } + + return 0; +} From 6022e5629c7708b114a3c2387e652ebd32122300 Mon Sep 17 00:00:00 2001 From: Xianyi Zhang Date: Fri, 28 Aug 2020 22:36:36 +0800 Subject: [PATCH 07/37] Refs #2587 fix small matrix c/zgemm bug. 
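
The c/zgemm small-matrix kernels took alpha and beta as FLOAT pointers; they
now receive the real and imaginary components by value (alpha0/alpha1 and
beta0/beta1), and the prototypes in common_level3.h, the dispatch code in
interface/gemm.c and the generic kernels are updated to match. A kernel such
as cgemm_small_kernel_nn is now declared as:

  int cgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k,
                            float * A, BLASLONG lda,
                            float alpha0, float alpha1,
                            float * B, BLASLONG ldb,
                            float beta0, float beta1,
                            float * C, BLASLONG ldc);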
--- common_level3.h | 128 +++++++++--------- interface/gemm.c | 22 ++- .../generic/zgemm_small_matrix_kernel_b0_nn.c | 6 +- .../generic/zgemm_small_matrix_kernel_b0_nt.c | 6 +- .../generic/zgemm_small_matrix_kernel_b0_tn.c | 6 +- .../generic/zgemm_small_matrix_kernel_b0_tt.c | 6 +- kernel/generic/zgemm_small_matrix_kernel_nn.c | 10 +- kernel/generic/zgemm_small_matrix_kernel_nt.c | 10 +- kernel/generic/zgemm_small_matrix_kernel_tn.c | 10 +- kernel/generic/zgemm_small_matrix_kernel_tt.c | 10 +- 10 files changed, 105 insertions(+), 109 deletions(-) diff --git a/common_level3.h b/common_level3.h index 5741f56d5..a3a487dab 100644 --- a/common_level3.h +++ b/common_level3.h @@ -536,85 +536,85 @@ int dgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLA int dgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); int dgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int cgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); -int cgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); -int cgemm_small_kernel_nr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); -int cgemm_small_kernel_nc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); +int cgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); +int cgemm_small_kernel_nr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); +int cgemm_small_kernel_nc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); -int cgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); -int cgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); -int cgemm_small_kernel_tr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); -int cgemm_small_kernel_tc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); +int cgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float 
beta0, float beta1, float * C, BLASLONG ldc); +int cgemm_small_kernel_tr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); +int cgemm_small_kernel_tc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); -int cgemm_small_kernel_rn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); -int cgemm_small_kernel_rt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); -int cgemm_small_kernel_rr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); -int cgemm_small_kernel_rc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_rn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); +int cgemm_small_kernel_rt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); +int cgemm_small_kernel_rr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); +int cgemm_small_kernel_rc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); -int cgemm_small_kernel_cn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); -int cgemm_small_kernel_ct(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); -int cgemm_small_kernel_cr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); -int cgemm_small_kernel_cc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * beta, float * C, BLASLONG ldc); +int cgemm_small_kernel_cn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); +int cgemm_small_kernel_ct(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); +int cgemm_small_kernel_cr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); +int cgemm_small_kernel_cc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); -int zgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); -int zgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG 
k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); -int zgemm_small_kernel_nr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); -int zgemm_small_kernel_nc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); +int zgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); +int zgemm_small_kernel_nr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); +int zgemm_small_kernel_nc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); -int zgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); -int zgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); -int zgemm_small_kernel_tr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); -int zgemm_small_kernel_tc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); +int zgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); +int zgemm_small_kernel_tr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); +int zgemm_small_kernel_tc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); -int zgemm_small_kernel_rn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); -int zgemm_small_kernel_rt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); -int zgemm_small_kernel_rr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); -int zgemm_small_kernel_rc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_rn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double 
alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); +int zgemm_small_kernel_rt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); +int zgemm_small_kernel_rr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); +int zgemm_small_kernel_rc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); -int zgemm_small_kernel_cn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); -int zgemm_small_kernel_ct(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); -int zgemm_small_kernel_cr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); -int zgemm_small_kernel_cc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * beta, double * C, BLASLONG ldc); +int zgemm_small_kernel_cn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); +int zgemm_small_kernel_ct(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); +int zgemm_small_kernel_cr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); +int zgemm_small_kernel_cc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); -int cgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_nr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_nc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_nr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_nc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, 
float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_tr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_tc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_tr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_tc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_rn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_rt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_rr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_rc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_rn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_rt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_rr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_rc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_cn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_ct(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_cr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int cgemm_small_kernel_b0_cc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_cn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_ct(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int 
cgemm_small_kernel_b0_cr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int cgemm_small_kernel_b0_cc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); -int zgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_nr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_nc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_nr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_nc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_tr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_tc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_tr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_tc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_rn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_rt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_rr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_rc(BLASLONG m, 
BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_rn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_rt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_rr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_rc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_cn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_ct(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_cr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); -int zgemm_small_kernel_b0_cc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_cn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_ct(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_cr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int zgemm_small_kernel_b0_cc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); #endif diff --git a/interface/gemm.c b/interface/gemm.c index b73baa9bd..7251993ee 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -106,47 +106,43 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, B }; #ifdef SMALL_MATRIX_OPT -//Only support s/dgemm small matrix optimiztion so far. 
+ +#ifndef COMPLEX static int (*gemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG) = { #ifndef GEMM3M -#ifndef COMPLEX GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, NULL, NULL, GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, NULL, NULL, #endif -#endif }; static int (*gemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { #ifndef GEMM3M -#ifndef COMPLEX GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, NULL, NULL, GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, NULL, NULL, #endif -#endif }; -static int (*zgemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT *,FLOAT *, BLASLONG, FLOAT *, FLOAT *, BLASLONG) = { +#else + +static int (*zgemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG) = { #ifndef GEMM3M -#ifdef COMPLEX GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN, GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT, GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR, GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC, #endif -#endif }; -static int (*zgemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT *, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { +static int (*zgemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { #ifndef GEMM3M -#ifdef COMPLEX GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN, GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT, GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR, GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC, #endif -#endif }; #endif +#endif #ifndef CBLAS @@ -479,9 +475,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS } #else if(beta[0] == 0.0 && beta[1] == 0.0){ - (zgemm_small_kernel_b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, (FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc); + (zgemm_small_kernel_b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, args.c, args.ldc); }else{ - (zgemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, (FLOAT *)(args.alpha), args.b, args.ldb, (FLOAT *)(args.beta), args.c, args.ldc); + (zgemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, beta[0], beta[1], args.c, args.ldc); } #endif return; diff --git a/kernel/generic/zgemm_small_matrix_kernel_b0_nn.c b/kernel/generic/zgemm_small_matrix_kernel_b0_nn.c index 11e746e52..3ab057fef 100644 --- a/kernel/generic/zgemm_small_matrix_kernel_b0_nn.c +++ b/kernel/generic/zgemm_small_matrix_kernel_b0_nn.c @@ -27,7 +27,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#include "common.h" -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) { FLOAT real, imag; @@ -65,8 +65,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* al #endif } - C[j*2*ldc + 2*i] = alpha[0]*real - alpha[1]*imag; - C[j*2*ldc+ 2*i + 1] = alpha[0]*imag + real*alpha[1]; + C[j*2*ldc + 2*i] = alpha0*real - alpha1*imag; + C[j*2*ldc+ 2*i + 1] = alpha0*imag + real*alpha1; } } diff --git a/kernel/generic/zgemm_small_matrix_kernel_b0_nt.c b/kernel/generic/zgemm_small_matrix_kernel_b0_nt.c index 1ef743017..dc35f4a6d 100644 --- a/kernel/generic/zgemm_small_matrix_kernel_b0_nt.c +++ b/kernel/generic/zgemm_small_matrix_kernel_b0_nt.c @@ -27,7 +27,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common.h" -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) { FLOAT real, imag; int i, j, l; @@ -68,8 +68,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* al #endif } - C[j*2*ldc + 2*i] = alpha[0]*real - alpha[1]*imag; - C[j*2*ldc+ 2*i + 1] = alpha[0]*imag + real*alpha[1]; + C[j*2*ldc + 2*i] = alpha0*real - alpha1*imag; + C[j*2*ldc+ 2*i + 1] = alpha0*imag + real*alpha1; } } diff --git a/kernel/generic/zgemm_small_matrix_kernel_b0_tn.c b/kernel/generic/zgemm_small_matrix_kernel_b0_tn.c index 2cd3ebcf2..479a56e8f 100644 --- a/kernel/generic/zgemm_small_matrix_kernel_b0_tn.c +++ b/kernel/generic/zgemm_small_matrix_kernel_b0_tn.c @@ -27,7 +27,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common.h" -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) { FLOAT real, imag; int i, j, l; @@ -68,8 +68,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* al #endif } - C[j*2*ldc + 2*i] = alpha[0]*real - alpha[1]*imag; - C[j*2*ldc+ 2*i + 1] = alpha[0]*imag + real*alpha[1]; + C[j*2*ldc + 2*i] = alpha0*real - alpha1*imag; + C[j*2*ldc+ 2*i + 1] = alpha0*imag + real*alpha1; } } diff --git a/kernel/generic/zgemm_small_matrix_kernel_b0_tt.c b/kernel/generic/zgemm_small_matrix_kernel_b0_tt.c index 25b05b4aa..b698973dd 100644 --- a/kernel/generic/zgemm_small_matrix_kernel_b0_tt.c +++ b/kernel/generic/zgemm_small_matrix_kernel_b0_tt.c @@ -27,7 +27,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#include "common.h" -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) { FLOAT real, imag; int i, j, l; @@ -68,8 +68,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* al #endif } - C[j*2*ldc + 2*i] = alpha[0]*real - alpha[1]*imag; - C[j*2*ldc+ 2*i + 1] = alpha[0]*imag + real*alpha[1]; + C[j*2*ldc + 2*i] = alpha0*real - alpha1*imag; + C[j*2*ldc+ 2*i + 1] = alpha0*imag + real*alpha1; } } diff --git a/kernel/generic/zgemm_small_matrix_kernel_nn.c b/kernel/generic/zgemm_small_matrix_kernel_nn.c index 6ef1b9655..4bf6bf7ee 100644 --- a/kernel/generic/zgemm_small_matrix_kernel_nn.c +++ b/kernel/generic/zgemm_small_matrix_kernel_nn.c @@ -27,7 +27,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common.h" -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT* beta, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT beta0, FLOAT beta1, FLOAT * C, BLASLONG ldc) { FLOAT real, imag; FLOAT tmp0, tmp1; @@ -65,12 +65,12 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* al #endif } - tmp0 = beta[0]*C[j*2*ldc + 2*i] - beta[1]*C[j*2*ldc+ 2*i + 1]; - tmp1 = beta[0]*C[j*2*ldc+ 2*i + 1] + beta[1]*C[j*2*ldc + 2*i]; + tmp0 = beta0*C[j*2*ldc + 2*i] - beta1*C[j*2*ldc+ 2*i + 1]; + tmp1 = beta0*C[j*2*ldc+ 2*i + 1] + beta1*C[j*2*ldc + 2*i]; - C[j*2*ldc + 2*i] =tmp0+ alpha[0]*real - alpha[1]*imag; - C[j*2*ldc+ 2*i + 1] = tmp1+ alpha[0]*imag + real*alpha[1]; + C[j*2*ldc + 2*i] =tmp0+ alpha0*real - alpha1*imag; + C[j*2*ldc+ 2*i + 1] = tmp1+ alpha0*imag + real*alpha1; } } diff --git a/kernel/generic/zgemm_small_matrix_kernel_nt.c b/kernel/generic/zgemm_small_matrix_kernel_nt.c index 3c81ad79e..288e49c13 100644 --- a/kernel/generic/zgemm_small_matrix_kernel_nt.c +++ b/kernel/generic/zgemm_small_matrix_kernel_nt.c @@ -27,7 +27,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#include "common.h" -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT* beta, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT beta0, FLOAT beta1, FLOAT * C, BLASLONG ldc) { FLOAT real, imag; FLOAT tmp0, tmp1; @@ -69,12 +69,12 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* al #endif } - tmp0 = beta[0]*C[j*2*ldc + 2*i] - beta[1]*C[j*2*ldc+ 2*i + 1]; - tmp1 = beta[0]*C[j*2*ldc+ 2*i + 1] + beta[1]*C[j*2*ldc + 2*i]; + tmp0 = beta0*C[j*2*ldc + 2*i] - beta1*C[j*2*ldc+ 2*i + 1]; + tmp1 = beta0*C[j*2*ldc+ 2*i + 1] + beta1*C[j*2*ldc + 2*i]; - C[j*2*ldc + 2*i] =tmp0+ alpha[0]*real - alpha[1]*imag; - C[j*2*ldc+ 2*i + 1] = tmp1+ alpha[0]*imag + real*alpha[1]; + C[j*2*ldc + 2*i] =tmp0+ alpha0*real - alpha1*imag; + C[j*2*ldc+ 2*i + 1] = tmp1+ alpha0*imag + real*alpha1; } } diff --git a/kernel/generic/zgemm_small_matrix_kernel_tn.c b/kernel/generic/zgemm_small_matrix_kernel_tn.c index 143190bb1..1e2a5aed4 100644 --- a/kernel/generic/zgemm_small_matrix_kernel_tn.c +++ b/kernel/generic/zgemm_small_matrix_kernel_tn.c @@ -27,7 +27,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common.h" -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT* beta, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT beta0, FLOAT beta1, FLOAT * C, BLASLONG ldc) { FLOAT real, imag; FLOAT tmp0, tmp1; @@ -69,12 +69,12 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* al #endif } - tmp0 = beta[0]*C[j*2*ldc + 2*i] - beta[1]*C[j*2*ldc+ 2*i + 1]; - tmp1 = beta[0]*C[j*2*ldc+ 2*i + 1] + beta[1]*C[j*2*ldc + 2*i]; + tmp0 = beta0*C[j*2*ldc + 2*i] - beta1*C[j*2*ldc+ 2*i + 1]; + tmp1 = beta0*C[j*2*ldc+ 2*i + 1] + beta1*C[j*2*ldc + 2*i]; - C[j*2*ldc + 2*i] =tmp0+ alpha[0]*real - alpha[1]*imag; - C[j*2*ldc+ 2*i + 1] = tmp1+ alpha[0]*imag + real*alpha[1]; + C[j*2*ldc + 2*i] =tmp0+ alpha0*real - alpha1*imag; + C[j*2*ldc+ 2*i + 1] = tmp1+ alpha0*imag + real*alpha1; } } diff --git a/kernel/generic/zgemm_small_matrix_kernel_tt.c b/kernel/generic/zgemm_small_matrix_kernel_tt.c index 246e26e84..180043539 100644 --- a/kernel/generic/zgemm_small_matrix_kernel_tt.c +++ b/kernel/generic/zgemm_small_matrix_kernel_tt.c @@ -27,7 +27,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#include "common.h" -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* alpha, FLOAT * B, BLASLONG ldb, FLOAT* beta, FLOAT * C, BLASLONG ldc) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT beta0, FLOAT beta1, FLOAT * C, BLASLONG ldc) { FLOAT real, imag; FLOAT tmp0, tmp1; @@ -69,12 +69,12 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT* al #endif } - tmp0 = beta[0]*C[j*2*ldc + 2*i] - beta[1]*C[j*2*ldc+ 2*i + 1]; - tmp1 = beta[0]*C[j*2*ldc+ 2*i + 1] + beta[1]*C[j*2*ldc + 2*i]; + tmp0 = beta0*C[j*2*ldc + 2*i] - beta1*C[j*2*ldc+ 2*i + 1]; + tmp1 = beta0*C[j*2*ldc+ 2*i + 1] + beta1*C[j*2*ldc + 2*i]; - C[j*2*ldc + 2*i] =tmp0+ alpha[0]*real - alpha[1]*imag; - C[j*2*ldc+ 2*i + 1] = tmp1+ alpha[0]*imag + real*alpha[1]; + C[j*2*ldc + 2*i] =tmp0+ alpha0*real - alpha1*imag; + C[j*2*ldc+ 2*i + 1] = tmp1+ alpha0*imag + real*alpha1; } } From 9186456a1297f7ee97bae56370c404114933a5ee Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Sat, 8 May 2021 10:45:10 +0000 Subject: [PATCH 08/37] small matrix: SkylakeX: add SGEMM NN kernel --- kernel/x86_64/KERNEL.SKYLAKEX | 2 + .../sgemm_small_kernel_b0_nn_skylakex.c | 2 + .../x86_64/sgemm_small_kernel_nn_skylakex.c | 424 ++++++++++++++++++ 3 files changed, 428 insertions(+) create mode 100644 kernel/x86_64/sgemm_small_kernel_b0_nn_skylakex.c create mode 100644 kernel/x86_64/sgemm_small_kernel_nn_skylakex.c diff --git a/kernel/x86_64/KERNEL.SKYLAKEX b/kernel/x86_64/KERNEL.SKYLAKEX index 3d71584fe..1a2e67b52 100644 --- a/kernel/x86_64/KERNEL.SKYLAKEX +++ b/kernel/x86_64/KERNEL.SKYLAKEX @@ -10,6 +10,8 @@ STRSMKERNEL_LN = ../generic/trsm_kernel_LN.c STRSMKERNEL_LT = ../generic/trsm_kernel_LT.c STRSMKERNEL_RN = ../generic/trsm_kernel_RN.c STRSMKERNEL_RT = ../generic/trsm_kernel_RT.c +SGEMM_SMALL_K_NN = sgemm_small_kernel_nn_skylakex.c +SGEMM_SMALL_K_B0_NN = sgemm_small_kernel_b0_nn_skylakex.c DGEMMKERNEL = dgemm_kernel_16x2_skylakex.c DTRMMKERNEL = dgemm_kernel_16x2_skylakex.c diff --git a/kernel/x86_64/sgemm_small_kernel_b0_nn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_b0_nn_skylakex.c new file mode 100644 index 000000000..704e964b8 --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_b0_nn_skylakex.c @@ -0,0 +1,2 @@ +#define B0 1 +#include "./sgemm_small_kernel_nn_skylakex.c" diff --git a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c new file mode 100644 index 000000000..f2c79873e --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c @@ -0,0 +1,424 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. 
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include <immintrin.h>
+#include "common.h"
+#include <stdio.h>
+
+#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps()
+#define LOAD_A_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[lda * k + i + (M*16)])
+#define BROADCAST_LOAD_B_512(M, N) __m512 Bval##N = _mm512_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)]))
+#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N)
+#if defined(B0)
+#define STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
+	_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
+#else
+#define STORE_512(M, N) \
+	BLASLONG offset##M##N = (j+N)*ldc + i + (M*16); \
+	result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
+	asm("vfmadd231ps (%1, %2, 4), %3, %0": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_512)); \
+	_mm512_storeu_ps(&C[offset##M##N], result##M##N)
+#endif
+
+#define DECLARE_RESULT_256(M, N) __m256 result##M##N = _mm256_setzero_ps()
+#define LOAD_A_256(M, N) __m256 Aval##M = _mm256_loadu_ps(&A[lda * k + i + (M*8)])
+#define BROADCAST_LOAD_B_256(M, N) __m256 Bval##N = _mm256_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)]))
+#define MATMUL_256(M, N) result##M##N = _mm256_fmadd_ps(Aval##M, Bval##N, result##M##N)
+#if defined(B0)
+#define STORE_256(M, N) result##M##N = _mm256_mul_ps(result##M##N, alpha_256); \
+	_mm256_storeu_ps(&C[(j+N)*ldc + i + (M*8)], result##M##N)
+#else
+#define STORE_256(M, N) \
+	BLASLONG offset##M##N = (j+N)*ldc + i + (M*8); \
+	result##M##N = _mm256_mul_ps(result##M##N, alpha_256); \
+	asm("vfmadd231ps (%1, %2, 4), %3, %0": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_256)); \
+	_mm256_storeu_ps(&C[offset##M##N], result##M##N)
+#endif
+
+#define DECLARE_RESULT_128(M, N) __m128 result##M##N; asm("vpxorq %0, %0, %0": "+v"(result##M##N):)
+#define LOAD_A_128(M, N) __m128 Aval##M = _mm_maskz_loadu_ps(mask, &A[lda * k + i + (M*4)])
+#define BROADCAST_LOAD_B_128(M, N) __m128 Bval##N = _mm_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)]))
+#define MATMUL_128(M, N) result##M##N = _mm_fmadd_ps(Aval##M, Bval##N, result##M##N)
+#if defined(B0)
+#define STORE_128(M, N) result##M##N = _mm_maskz_mul_ps(mask, result##M##N, alpha_128); \
+	_mm_mask_storeu_ps(&C[(j+N)*ldc + i + (M*4)], mask, result##M##N)
+#else
+#define STORE_128(M, N) \
+	BLASLONG offset##M##N = (j+N)*ldc + i + (M*4); \
+	result##M##N = _mm_maskz_mul_ps(mask, result##M##N, alpha_128); \
+	asm("vfmadd231ps (%1, %2, 4), %3, %0": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_128)); \
+	_mm_mask_storeu_ps(&C[offset##M##N], mask, result##M##N)
+#endif
+
+#define DECLARE_RESULT_S(M, N) float result##M##N = 0;
+#define LOAD_A_S(M, N) float Aval##M = A[lda * k + i + M]
+#define BROADCAST_LOAD_B_S(M, N) float Bval##N = B[k + ldb * (j+N)] +#define MATMUL_S(M, N) result##M##N += Aval##M * Bval##N +#if defined(B0) +#define STORE_S(M, N) C[(j+N)*ldc + i + M] = result##M##N * alpha +#else +#define STORE_S(M, N) C[(j+N)*ldc + i + M] = result##M##N * alpha + C[(j+N)*ldc + i + M] * beta +#endif + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m64 = M & ~63; + BLASLONG m32 = M & ~31; + BLASLONG m16 = M & ~15; + BLASLONG m8 = M & ~7; + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n4 = N & ~3; + BLASLONG n2 = N & ~1; + + __mmask8 mask = 0xff; // just use to avoid SSE instruction + + __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); +#if !defined(B0) + __m512 beta_512 = _mm512_broadcastss_ps(_mm_load_ss(&beta)); +#endif + + for (i = 0; i < m64; i += 64) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); + STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + } + } + for (; i < m32; i += 32) { + for (j = 
0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_512(0, 0); STORE_512(1, 0); + } + } + for (; i < m16; i += 16) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_512(0, 0); + STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + STORE_512(0, 0); + } + } + __m256 alpha_256 = _mm256_broadcastss_ps(_mm_load_ss(&alpha)); +#if !defined(B0) + __m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta)); +#endif + for (; i < m8; i += 8) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_256(0, 0); + DECLARE_RESULT_256(0, 1); + DECLARE_RESULT_256(0, 2); + DECLARE_RESULT_256(0, 3); + for (k = 0; k < K; k++) { + LOAD_A_256(0, x); + BROADCAST_LOAD_B_256(x, 0); BROADCAST_LOAD_B_256(x, 1); + BROADCAST_LOAD_B_256(x, 2); BROADCAST_LOAD_B_256(x, 3); + + MATMUL_256(0, 0); + MATMUL_256(0, 1); + MATMUL_256(0, 2); + MATMUL_256(0, 3); + } + STORE_256(0, 0); + STORE_256(0, 1); + STORE_256(0, 2); + STORE_256(0, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_256(0, 0); + DECLARE_RESULT_256(0, 1); + for (k = 0; k < K; k++) { + LOAD_A_256(0, x); + BROADCAST_LOAD_B_256(x, 0); BROADCAST_LOAD_B_256(x, 1); + MATMUL_256(0, 0); + MATMUL_256(0, 1); + } + STORE_256(0, 0); + STORE_256(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_256(0, 0); + for (k = 0; k < K; k++) { + LOAD_A_256(0, x); + BROADCAST_LOAD_B_256(x, 0); + MATMUL_256(0, 0); + } + STORE_256(0, 0); + } + } + __m128 alpha_128 = _mm_broadcastss_ps(_mm_load_ss(&alpha)); +#if !defined(B0) + 
__m128 beta_128 = _mm_broadcastss_ps(_mm_load_ss(&beta));
+#endif
+	for (; i < m4; i += 4) {
+		for (j = 0; j < n4; j += 4) {
+			DECLARE_RESULT_128(0, 0);
+			DECLARE_RESULT_128(0, 1);
+			DECLARE_RESULT_128(0, 2);
+			DECLARE_RESULT_128(0, 3);
+			for (k = 0; k < K; k++) {
+				LOAD_A_128(0, x);
+				BROADCAST_LOAD_B_128(x, 0); BROADCAST_LOAD_B_128(x, 1);
+				BROADCAST_LOAD_B_128(x, 2); BROADCAST_LOAD_B_128(x, 3);
+
+				MATMUL_128(0, 0);
+				MATMUL_128(0, 1);
+				MATMUL_128(0, 2);
+				MATMUL_128(0, 3);
+			}
+			STORE_128(0, 0);
+			STORE_128(0, 1);
+			STORE_128(0, 2);
+			STORE_128(0, 3);
+		}
+		for (; j < n2; j += 2) {
+			DECLARE_RESULT_128(0, 0);
+			DECLARE_RESULT_128(0, 1);
+			for (k = 0; k < K; k++) {
+				LOAD_A_128(0, x);
+				BROADCAST_LOAD_B_128(x, 0); BROADCAST_LOAD_B_128(x, 1);
+				MATMUL_128(0, 0);
+				MATMUL_128(0, 1);
+			}
+			STORE_128(0, 0);
+			STORE_128(0, 1);
+		}
+		for (; j < N; j++) {
+			DECLARE_RESULT_128(0, 0);
+			for (k = 0; k < K; k++) {
+				LOAD_A_128(0, x);
+				BROADCAST_LOAD_B_128(x, 0);
+				MATMUL_128(0, 0);
+			}
+			STORE_128(0, 0);
+		}
+	}
+	for (; i < m2; i += 2) {
+		for (j = 0; j < n4; j += 4) {
+			DECLARE_RESULT_S(0, 0); DECLARE_RESULT_S(1, 0);
+			DECLARE_RESULT_S(0, 1); DECLARE_RESULT_S(1, 1);
+			DECLARE_RESULT_S(0, 2); DECLARE_RESULT_S(1, 2);
+			DECLARE_RESULT_S(0, 3); DECLARE_RESULT_S(1, 3);
+			for (k = 0; k < K; k++) {
+				LOAD_A_S(0, x); LOAD_A_S(1, x);
+				BROADCAST_LOAD_B_S(x, 0); BROADCAST_LOAD_B_S(x, 1);
+				BROADCAST_LOAD_B_S(x, 2); BROADCAST_LOAD_B_S(x, 3);
+
+				MATMUL_S(0, 0); MATMUL_S(1, 0);
+				MATMUL_S(0, 1); MATMUL_S(1, 1);
+				MATMUL_S(0, 2); MATMUL_S(1, 2);
+				MATMUL_S(0, 3); MATMUL_S(1, 3);
+			}
+			STORE_S(0, 0); STORE_S(1, 0);
+			STORE_S(0, 1); STORE_S(1, 1);
+			STORE_S(0, 2); STORE_S(1, 2);
+			STORE_S(0, 3); STORE_S(1, 3);
+		}
+		for (; j < n2; j += 2) {
+			DECLARE_RESULT_S(0, 0); DECLARE_RESULT_S(1, 0);
+			DECLARE_RESULT_S(0, 1); DECLARE_RESULT_S(1, 1);
+			for (k = 0; k < K; k++) {
+				LOAD_A_S(0, x); LOAD_A_S(1, x);
+				BROADCAST_LOAD_B_S(x, 0); BROADCAST_LOAD_B_S(x, 1);
+				MATMUL_S(0, 0); MATMUL_S(1, 0);
+				MATMUL_S(0, 1); MATMUL_S(1, 1);
+			}
+			STORE_S(0, 0); STORE_S(1, 0);
+			STORE_S(0, 1); STORE_S(1, 1);
+		}
+		for (; j < N; j++) {
+			DECLARE_RESULT_S(0, 0); DECLARE_RESULT_S(1, 0);
+			for (k = 0; k < K; k++) {
+				LOAD_A_S(0, x); LOAD_A_S(1, x);
+				BROADCAST_LOAD_B_S(x, 0);
+				MATMUL_S(0, 0); MATMUL_S(1, 0);
+			}
+			STORE_S(0, 0); STORE_S(1, 0);
+		}
+	}
+	for (; i < M; i += 1) {
+		for (j = 0; j < n4; j += 4) {
+			DECLARE_RESULT_S(0, 0);
+			DECLARE_RESULT_S(0, 1);
+			DECLARE_RESULT_S(0, 2);
+			DECLARE_RESULT_S(0, 3);
+			for (k = 0; k < K; k++) {
+				LOAD_A_S(0, x);
+				BROADCAST_LOAD_B_S(x, 0); BROADCAST_LOAD_B_S(x, 1);
+				BROADCAST_LOAD_B_S(x, 2); BROADCAST_LOAD_B_S(x, 3);
+
+				MATMUL_S(0, 0);
+				MATMUL_S(0, 1);
+				MATMUL_S(0, 2);
+				MATMUL_S(0, 3);
+			}
+			STORE_S(0, 0);
+			STORE_S(0, 1);
+			STORE_S(0, 2);
+			STORE_S(0, 3);
+		}
+		for (; j < n2; j += 2) {
+			DECLARE_RESULT_S(0, 0);
+			DECLARE_RESULT_S(0, 1);
+			for (k = 0; k < K; k++) {
+				LOAD_A_S(0, x);
+				BROADCAST_LOAD_B_S(x, 0); BROADCAST_LOAD_B_S(x, 1);
+				MATMUL_S(0, 0);
+				MATMUL_S(0, 1);
+			}
+			STORE_S(0, 0);
+			STORE_S(0, 1);
+		}
+		for (; j < N; j++) {
+			DECLARE_RESULT_S(0, 0);
+			for (k = 0; k < K; k++) {
+				LOAD_A_S(0, x);
+				BROADCAST_LOAD_B_S(x, 0);
+				MATMUL_S(0, 0);
+			}
+			STORE_S(0, 0);
+		}
+	}
+}

From f88470323bdb72a1e3ac54717606810699319d3b Mon Sep 17 00:00:00 2001
From: Wangyang Guo
Date: Sat, 8 May 2021 15:59:14 +0000
Subject: [PATCH 09/37] Optimize M < 16 using AVX512 mask

---
 .../x86_64/sgemm_small_kernel_nn_skylakex.c | 53 +++++++++++++++++++
 1 file changed, 53 insertions(+)
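
Illustrative note (not part of the patch): the hunks below make the last
M % 16 rows go through the same 512-bit code path as the full tiles, using
an AVX-512 write mask instead of the narrower 256/128-bit and scalar
fallbacks. A minimal standalone sketch of the masking idea, assuming
AVX-512F; the function name and arguments here are hypothetical, not
OpenBLAS API:

	#include <immintrin.h>

	/* C[0..m) = alpha*acc[0..m) + beta*C[0..m) for 0 < m <= 16, with no
	 * scalar cleanup loop: lanes >= m are disabled in the loads and the
	 * store, so nothing past the tail is read or written. */
	static void tail_update_f32(int m, float alpha, float beta,
	                            const float *acc, float *C)
	{
		__mmask16 mask = (__mmask16)((1U << m) - 1);  /* low m lanes on */
		__m512 r = _mm512_maskz_loadu_ps(mask, acc);  /* masked lanes read as 0 */
		__m512 c = _mm512_maskz_loadu_ps(mask, C);
		/* r = alpha*r + beta*c */
		r = _mm512_fmadd_ps(_mm512_set1_ps(beta), c,
		                    _mm512_mul_ps(r, _mm512_set1_ps(alpha)));
		_mm512_mask_storeu_ps(C, mask, r);            /* only m lanes written */
	}

The patch additionally pins the mask to register k1 (register ... asm("k1"))
so the inline-asm vfmadd231ps in MASK_STORE_512 can reference it directly.
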
diff --git a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c index f2c79873e..f0b6d63a6 100644 --- a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c +++ b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c @@ -31,17 +31,25 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps() #define LOAD_A_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[lda * k + i + (M*16)]) +#define MASK_LOAD_A_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &A[lda * k + i + (M*16)]) #define BROADCAST_LOAD_B_512(M, N) __m512 Bval##N = _mm512_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)])) #define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) #if defined(B0) #define STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ _mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N) +#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + _mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) #else #define STORE_512(M, N) \ BLASLONG offset##M##N = (j+N)*ldc + i + (M*16); \ result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ asm("vfmadd231ps (%1, %2, 4), %3, %0": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_512)); \ _mm512_storeu_ps(&C[offset##M##N], result##M##N) +#define MASK_STORE_512(M, N) \ + BLASLONG offset##M##N = (j+N)*ldc + i + (M*16); \ + result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + asm("vfmadd231ps (%1, %2, 4), %3, %0 %{%4%}": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_512), "k"(mask)); \ + _mm512_mask_storeu_ps(&C[offset##M##N], mask, result##M##N) #endif #define DECLARE_RESULT_256(M, N) __m256 result##M##N = _mm256_setzero_ps() @@ -241,6 +249,51 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp STORE_512(0, 0); } } + if (M - i > 0) { + register __mmask16 mask asm("k1") = (1UL << (M - i)) - 1; + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_STORE_512(0, 0); + } + return; + } __m256 alpha_256 = _mm256_broadcastss_ps(_mm_load_ss(&alpha)); #if !defined(B0) __m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta)); From 49b61a3f3027e24f19e78e573e50c86432aec574 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Tue, 11 May 2021 10:24:10 +0000 Subject: [PATCH 10/37] Small Matrix: skylakex: sgemm_nn: optimize for M <= 8 --- .../x86_64/sgemm_small_kernel_nn_skylakex.c | 302 +++++++++++++++++- 1 file changed, 301 insertions(+), 1 deletion(-) diff --git a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c 
b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c
index f0b6d63a6..ae4a9daa3 100644
--- a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c
+++ b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c
@@ -28,6 +28,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #include <immintrin.h>
 #include "common.h"
 #include <stdio.h>
+#include <memory.h>
 
 #define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps()
 #define LOAD_A_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[lda * k + i + (M*16)])
@@ -52,6 +53,18 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 	_mm512_mask_storeu_ps(&C[offset##M##N], mask, result##M##N)
 #endif
+#define LOAD_KA_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&mbuf[(mi + M)*K + k]);
+#define LOAD_KB_512(M, N) __m512 Bval##N = _mm512_loadu_ps(&B[(j + N)*ldb + k])
+#define MASK_LOAD_KA_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &mbuf[(mi + M)*K + k])
+#define MASK_LOAD_KB_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[(j + N)*ldb + k])
+#if defined(B0)
+#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N);
+#else
+#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N) + beta * C[(j+N)*ldc + i + M];
+#endif
+
+
+
 
 #define DECLARE_RESULT_256(M, N) __m256 result##M##N = _mm256_setzero_ps()
 #define LOAD_A_256(M, N) __m256 Aval##M = _mm256_loadu_ps(&A[lda * k + i + (M*8)])
 #define BROADCAST_LOAD_B_256(M, N) __m256 Bval##N = _mm256_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)]))
@@ -249,7 +262,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
 			STORE_512(0, 0);
 		}
 	}
-	if (M - i > 0) {
+	if (M - i > 8) {
 		register __mmask16 mask asm("k1") = (1UL << (M - i)) - 1;
 		for (j = 0; j < n4; j += 4) {
 			DECLARE_RESULT_512(0, 0);
@@ -294,6 +307,293 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
 		}
 		return;
 	}
+	int mm = M - i;
+	if (mm) {
+		FLOAT *mbuf = (FLOAT *) malloc(sizeof(FLOAT)*mm*K);
+		__mmask8 mask8 = (1UL << mm) - 1;
+		__mmask16 mask;
+		BLASLONG k16 = K & ~15;
+		BLASLONG k8 = K & ~7;
+		for (k = 0; k < k8; k += 8) {
+			__m256 r0, r1, r2, r3, r4, r5, r6, r7;
+			__m256 t0, t1, t2, t3, t4, t5, t6, t7;
+			r0 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(0 + k)]);
+			r1 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(1 + k)]);
+			r2 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(2 + k)]);
+			r3 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(3 + k)]);
+			r4 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(4 + k)]);
+			r5 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(5 + k)]);
+			r6 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(6 + k)]);
+			r7 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(7 + k)]);
+
+			t0 = _mm256_unpacklo_ps(r0, r1);
+			t1 = _mm256_unpackhi_ps(r0, r1);
+			t2 = _mm256_unpacklo_ps(r2, r3);
+			t3 = _mm256_unpackhi_ps(r2, r3);
+			t4 = _mm256_unpacklo_ps(r4, r5);
+			t5 = _mm256_unpackhi_ps(r4, r5);
+			t6 = _mm256_unpacklo_ps(r6, r7);
+			t7 = _mm256_unpackhi_ps(r6, r7);
+
+			r0 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(1,0,1,0));
+			r1 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(3,2,3,2));
+			r2 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(1,0,1,0));
+			r3 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(3,2,3,2));
+			r4 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(1,0,1,0));
+			r5 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(3,2,3,2));
+			r6 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(1,0,1,0));
+			r7 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(3,2,3,2));
+
+			t0 = _mm256_permute2f128_ps(r0, r4, 0x20);
+			t1 = _mm256_permute2f128_ps(r1, r5, 0x20);
+			t2 = _mm256_permute2f128_ps(r2, r6, 0x20);
+			t3 = 
_mm256_permute2f128_ps(r3, r7, 0x20); + t4 = _mm256_permute2f128_ps(r0, r4, 0x31); + t5 = _mm256_permute2f128_ps(r1, r5, 0x31); + t6 = _mm256_permute2f128_ps(r2, r6, 0x31); + t7 = _mm256_permute2f128_ps(r3, r7, 0x31); + + switch (mm) { + case 8: _mm256_storeu_ps(&mbuf[k + 7*K], t7); + case 7: _mm256_storeu_ps(&mbuf[k + 6*K], t6); + case 6: _mm256_storeu_ps(&mbuf[k + 5*K], t5); + case 5: _mm256_storeu_ps(&mbuf[k + 4*K], t4); + case 4: _mm256_storeu_ps(&mbuf[k + 3*K], t3); + case 3: _mm256_storeu_ps(&mbuf[k + 2*K], t2); + case 2: _mm256_storeu_ps(&mbuf[k + 1*K], t1); + case 1: _mm256_storeu_ps(&mbuf[k + 0*K], t0); + } + } + for (; k < K; k++) { + for (int ii = 0; ii < mm; ii++) { + mbuf[k + ii*K] = A[i + lda*k + ii]; + } + } + int mi = 0; + for (; i < m4; i += 4, mi += 4) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); STORE_REDUCE(2, 0); STORE_REDUCE(3, 0); + STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); STORE_REDUCE(2, 1); STORE_REDUCE(3, 1); + STORE_REDUCE(0, 2); STORE_REDUCE(1, 2); STORE_REDUCE(2, 2); STORE_REDUCE(3, 2); + STORE_REDUCE(0, 3); STORE_REDUCE(1, 3); STORE_REDUCE(2, 3); STORE_REDUCE(3, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + 
STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); STORE_REDUCE(2, 0); STORE_REDUCE(3, 0); + STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); STORE_REDUCE(2, 1); STORE_REDUCE(3, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); STORE_REDUCE(2, 0); STORE_REDUCE(3, 0); + } + + } + for (; i < m2; i += 2, mi += 2) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); + STORE_REDUCE(0, 2); STORE_REDUCE(1, 2); + STORE_REDUCE(0, 3); STORE_REDUCE(1, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + } + } + for (; i < M; i += 1, mi += 1) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } 
+ int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_REDUCE(0, 0); + STORE_REDUCE(0, 1); + STORE_REDUCE(0, 2); + STORE_REDUCE(0, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_REDUCE(0, 0); + STORE_REDUCE(0, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + STORE_REDUCE(0, 0); + } + } + free(mbuf); + return; + } __m256 alpha_256 = _mm256_broadcastss_ps(_mm_load_ss(&alpha)); #if !defined(B0) __m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta)); From 3d8c6d9607c82a999ad8661834d0d78605a5f321 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Tue, 11 May 2021 10:33:07 +0000 Subject: [PATCH 11/37] Small Matrix: skylakex: sgemm nn: clean up unused code --- .../x86_64/sgemm_small_kernel_nn_skylakex.c | 222 ------------------ 1 file changed, 222 deletions(-) diff --git a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c index ae4a9daa3..a5c530593 100644 --- a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c +++ b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c @@ -63,48 +63,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
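
Every loop body in the dot-style tail kernel that just ended splits K the same way: a full-width main loop over k16 = K & ~15, then a single masked iteration covering the K % 16 remainder, where the maskz loads zero the unused lanes so the final FMA contributes nothing there. A self-contained sketch of that pattern (a hypothetical helper, not code from the patches):

#include <immintrin.h>

/* dot product of two length-K float vectors: 16-wide FMAs over the
   aligned part, one masked FMA for the tail, then a horizontal sum */
static float dot_avx512(const float *x, const float *y, long K)
{
    long k, k16 = K & ~15;
    __m512 acc = _mm512_setzero_ps();
    for (k = 0; k < k16; k += 16)
        acc = _mm512_fmadd_ps(_mm512_loadu_ps(&x[k]),
                              _mm512_loadu_ps(&y[k]), acc);
    if (K - k) {
        __mmask16 mask = (1UL << (K - k)) - 1;
        /* maskz loads zero the inactive lanes, so the FMA adds nothing there */
        acc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask, &x[k]),
                              _mm512_maskz_loadu_ps(mask, &y[k]), acc);
    }
    return _mm512_reduce_add_ps(acc);
}

The cleanup patch that follows can delete the old AVX2/SSE/scalar remainder paths precisely because this masked-tail kernel now handles every small-M case.
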
#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N) + beta * C[(j+N)*ldc + i + M]; #endif - - -#define DECLARE_RESULT_256(M, N) __m256 result##M##N = _mm256_setzero_ps() -#define LOAD_A_256(M, N) __m256 Aval##M = _mm256_loadu_ps(&A[lda * k + i + (M*8)]) -#define BROADCAST_LOAD_B_256(M, N) __m256 Bval##N = _mm256_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)])) -#define MATMUL_256(M, N) result##M##N = _mm256_fmadd_ps(Aval##M, Bval##N, result##M##N) -#if defined(B0) -#define STORE_256(M, N) result##M##N = _mm256_mul_ps(result##M##N, alpha_256); \ - _mm256_storeu_ps(&C[(j+N)*ldc + i + (M*8)], result##M##N) -#else -#define STORE_256(M, N) \ - BLASLONG offset##M##N = (j+N)*ldc + i + (M*8); \ - result##M##N = _mm256_mul_ps(result##M##N, alpha_256); \ - asm("vfmadd231ps (%1, %2, 4), %3, %0": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_256)); \ - _mm256_storeu_ps(&C[offset##M##N], result##M##N) -#endif - -#define DECLARE_RESULT_128(M, N) __m128 result##M##N; asm("vpxorq %0, %0, %0": "+v"(result##M##N):) -#define LOAD_A_128(M, N) __m128 Aval##M = _mm_maskz_loadu_ps(mask, &A[lda * k + i + (M*4)]) -#define BROADCAST_LOAD_B_128(M, N) __m128 Bval##N = _mm_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)])) -#define MATMUL_128(M, N) result##M##N = _mm_fmadd_ps(Aval##M, Bval##N, result##M##N) -#if defined(B0) -#define STORE_128(M, N) result##M##N = _mm_maskz_mul_ps(mask, result##M##N, alpha_128); \ - _mm_mask_storeu_ps(&C[(j+N)*ldc + i + (M*4)], mask, result##M##N) -#else -#define STORE_128(M, N) \ - BLASLONG offset##M##N = (j+N)*ldc + i + (M*4); \ - result##M##N = _mm_maskz_mul_ps(mask, result##M##N, alpha_128); \ - asm("vfmadd231ps (%1, %2, 4), %3, %0": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_128)); \ - _mm_mask_storeu_ps(&C[offset##M##N], mask, result##M##N) -#endif - -#define DECLARE_RESULT_S(M, N) float result##M##N = 0; -#define LOAD_A_S(M, N) float Aval##M = A[lda * k + i + M] -#define BROADCAST_LOAD_B_S(M, N) float Bval##N = B[k + ldb * (j+N)] -#define MATMUL_S(M, N) result##M##N += Aval##M * Bval##N -#if defined(B0) -#define STORE_S(M, N) C[(j+N)*ldc + i + M] = result##M##N * alpha -#else -#define STORE_S(M, N) C[(j+N)*ldc + i + M] = result##M##N * alpha + C[(j+N)*ldc + i + M] * beta -#endif - #if defined(B0) int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) #else @@ -594,184 +552,4 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp free(mbuf); return; } - __m256 alpha_256 = _mm256_broadcastss_ps(_mm_load_ss(&alpha)); -#if !defined(B0) - __m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta)); -#endif - for (; i < m8; i += 8) { - for (j = 0; j < n4; j += 4) { - DECLARE_RESULT_256(0, 0); - DECLARE_RESULT_256(0, 1); - DECLARE_RESULT_256(0, 2); - DECLARE_RESULT_256(0, 3); - for (k = 0; k < K; k++) { - LOAD_A_256(0, x); - BROADCAST_LOAD_B_256(x, 0); BROADCAST_LOAD_B_256(x, 1); - BROADCAST_LOAD_B_256(x, 2); BROADCAST_LOAD_B_256(x, 3); - - MATMUL_256(0, 0); - MATMUL_256(0, 1); - MATMUL_256(0, 2); - MATMUL_256(0, 3); - } - STORE_256(0, 0); - STORE_256(0, 1); - STORE_256(0, 2); - STORE_256(0, 3); - } - for (; j < n2; j += 2) { - DECLARE_RESULT_256(0, 0); - DECLARE_RESULT_256(0, 1); - for (k = 0; k < K; k++) { - LOAD_A_256(0, x); - BROADCAST_LOAD_B_256(x, 0); BROADCAST_LOAD_B_256(x, 1); - MATMUL_256(0, 0); - MATMUL_256(0, 1); - } - STORE_256(0, 0); - STORE_256(0, 1); - } - for (; j < N; j++) { - DECLARE_RESULT_256(0, 0); 
- for (k = 0; k < K; k++) { - LOAD_A_256(0, x); - BROADCAST_LOAD_B_256(x, 0); - MATMUL_256(0, 0); - } - STORE_256(0, 0); - } - } - __m128 alpha_128 = _mm_broadcastss_ps(_mm_load_ss(&alpha)); -#if !defined(B0) - __m128 beta_128 = _mm_broadcastss_ps(_mm_load_ss(&beta)); -#endif - for (; i < m4; i += 4) { - for (j = 0; j < n4; j += 4) { - DECLARE_RESULT_128(0, 0); - DECLARE_RESULT_128(0, 1); - DECLARE_RESULT_128(0, 2); - DECLARE_RESULT_128(0, 3); - for (k = 0; k < K; k++) { - LOAD_A_128(0, x); - BROADCAST_LOAD_B_128(x, 0); BROADCAST_LOAD_B_128(x, 1); - BROADCAST_LOAD_B_128(x, 2); BROADCAST_LOAD_B_128(x, 3); - - MATMUL_128(0, 0); - MATMUL_128(0, 1); - MATMUL_128(0, 2); - MATMUL_128(0, 3); - } - STORE_128(0, 0); - STORE_128(0, 1); - STORE_128(0, 2); - STORE_128(0, 3); - } - for (; j < n2; j += 2) { - DECLARE_RESULT_128(0, 0); - DECLARE_RESULT_128(0, 1); - for (k = 0; k < K; k++) { - LOAD_A_128(0, x); - BROADCAST_LOAD_B_128(x, 0); BROADCAST_LOAD_B_128(x, 1); - MATMUL_128(0, 0); - MATMUL_128(0, 1); - } - STORE_128(0, 0); - STORE_128(0, 1); - } - for (; j < N; j++) { - DECLARE_RESULT_128(0, 0); - for (k = 0; k < K; k++) { - LOAD_A_128(0, x); - BROADCAST_LOAD_B_128(x, 0); - MATMUL_128(0, 0); - } - STORE_128(0, 0); - } - } - for (; i < m2; i += 2) { - for (j = 0; j < n4; j += 4) { - DECLARE_RESULT_S(0, 0); DECLARE_RESULT_S(1, 0); - DECLARE_RESULT_S(0, 1); DECLARE_RESULT_S(1, 1); - DECLARE_RESULT_S(0, 2); DECLARE_RESULT_S(1, 2); - DECLARE_RESULT_S(0, 3); DECLARE_RESULT_S(1, 3); - for (k = 0; k < K; k++) { - LOAD_A_S(0, x); LOAD_A_S(1, x); - BROADCAST_LOAD_B_S(x, 0); BROADCAST_LOAD_B_S(x, 1); - BROADCAST_LOAD_B_S(x, 2); BROADCAST_LOAD_B_S(x, 3); - - MATMUL_S(0, 0); MATMUL_S(1, 0); - MATMUL_S(0, 1); MATMUL_S(1, 1); - MATMUL_S(0, 2); MATMUL_S(1, 2); - MATMUL_S(0, 3); MATMUL_S(1, 3); - } - STORE_S(0, 0); STORE_S(1, 0); - STORE_S(0, 1); STORE_S(1, 1); - STORE_S(0, 2); STORE_S(1, 2); - STORE_S(0, 3); STORE_S(1, 3); - } - for (; j < n2; j += 2) { - DECLARE_RESULT_S(0, 0); DECLARE_RESULT_S(1, 0); - DECLARE_RESULT_S(0, 1); DECLARE_RESULT_S(1, 1); - for (k = 0; k < K; k++) { - LOAD_A_S(0, x); LOAD_A_S(1, x); - BROADCAST_LOAD_B_S(x, 0); BROADCAST_LOAD_B_S(x, 1); - MATMUL_S(0, 0); MATMUL_S(1, 0); - MATMUL_S(0, 1); MATMUL_S(1, 1); - } - STORE_S(0, 0); STORE_S(1, 0); - STORE_S(0, 1); STORE_S(1, 1); - } - for (; j < N; j++) { - DECLARE_RESULT_S(0, 0); DECLARE_RESULT_S(1, 0); - for (k = 0; k < K; k++) { - LOAD_A_S(0, x); LOAD_A_S(1, x); - BROADCAST_LOAD_B_S(x, 0); - MATMUL_S(0, 0); MATMUL_S(1, 0); - } - STORE_S(0, 0); STORE_S(1, 0); - } - } - for (; i < M; i += 1) { - for (j = 0; j < n4; j += 4) { - DECLARE_RESULT_S(0, 0); - DECLARE_RESULT_S(0, 1); - DECLARE_RESULT_S(0, 2); - DECLARE_RESULT_S(0, 3); - for (k = 0; k < K; k++) { - LOAD_A_S(0, x); - BROADCAST_LOAD_B_S(x, 0); BROADCAST_LOAD_B_S(x, 1); - BROADCAST_LOAD_B_S(x, 2); BROADCAST_LOAD_B_S(x, 3); - - MATMUL_S(0, 0); - MATMUL_S(0, 1); - MATMUL_S(0, 2); - MATMUL_S(0, 3); - } - STORE_S(0, 0); - STORE_S(0, 1); - STORE_S(0, 2); - STORE_S(0, 3); - } - for (; j < n2; j += 2) { - DECLARE_RESULT_S(0, 0); - DECLARE_RESULT_S(0, 1); - for (k = 0; k < K; k++) { - LOAD_A_S(0, x); - BROADCAST_LOAD_B_S(x, 0); BROADCAST_LOAD_B_S(x, 1); - MATMUL_S(0, 0); - MATMUL_S(0, 1); - } - STORE_S(0, 0); - STORE_S(0, 1); - } - for (; j < N; j++) { - DECLARE_RESULT_S(0, 0); - for (k = 0; k < K; k++) { - LOAD_A_S(0, x); LOAD_A_S(1, x); - BROADCAST_LOAD_B_S(x, 0); - MATMUL_S(0, 0); - } - STORE_S(0, 0); - } - } } From 13b32f69b78b15e7d95978011ea6c2bb3d9e3642 Mon Sep 17 00:00:00 2001 From: 
Wangyang Guo Date: Wed, 12 May 2021 17:08:18 +0000 Subject: [PATCH 12/37] Small Matrix: skylakex: sgemm nn: reduce store 4 M at a time --- .../x86_64/sgemm_small_kernel_nn_skylakex.c | 64 ++++++++++++++----- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c index a5c530593..be9f085c0 100644 --- a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c +++ b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c @@ -57,10 +57,30 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define LOAD_KB_512(M, N) __m512 Bval##N = _mm512_loadu_ps(&B[(j + N)*ldb + k]) #define MASK_LOAD_KA_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &mbuf[(mi + M)*K + k]) #define MASK_LOAD_KB_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[(j + N)*ldb + k]) +#define REDUCE_M4(N) \ + __m512 r0, r1, r2, r3, t0, t1, t2, t3;\ + r0 = _mm512_unpacklo_ps(result0##N, result1##N); r1 = _mm512_unpackhi_ps(result0##N, result1##N); \ + r2 = _mm512_unpacklo_ps(result2##N, result3##N); r3 = _mm512_unpackhi_ps(result2##N, result3##N); \ + t0 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(1, 0, 1, 0)); t1 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(3, 2, 3, 2)); \ + t2 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(1, 0, 1, 0)); t3 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(3, 2, 3, 2)); \ + r0 = _mm512_add_ps(t0, t1); r1 = _mm512_add_ps(t2, t3); t0 = _mm512_add_ps(r0, r1); \ + __m128 s0, s1, s2, s3; \ + s0 = _mm512_extractf32x4_ps(t0, 0); s1 = _mm512_extractf32x4_ps(t0, 1); s2 = _mm512_extractf32x4_ps(t0, 2); s3 = _mm512_extractf32x4_ps(t0, 3); \ + s0 = _mm_maskz_add_ps(mask8, s0, s1); s2 = _mm_maskz_add_ps(mask8, s2, s3); s0 = _mm_maskz_add_ps(mask8, s0, s2); \ + s0 = _mm_maskz_mul_ps(mask8, alpha_128, s0); #if defined(B0) #define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N); +#define STORE_REDUCE_M4(N) {\ + REDUCE_M4(N) \ + _mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); \ +} #else #define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N) + beta * C[(j+N)*ldc + i + M]; +#define STORE_REDUCE_M4(N) {\ + REDUCE_M4(N) \ + asm("vfmadd231ps (%1), %2, %0": "+v"(s0):"r"(&C[(j + N)*ldc + i]), "v"(beta_128)); \ + _mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); \ +} #endif #if defined(B0) @@ -75,14 +95,12 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp BLASLONG m64 = M & ~63; BLASLONG m32 = M & ~31; BLASLONG m16 = M & ~15; - BLASLONG m8 = M & ~7; BLASLONG m4 = M & ~3; BLASLONG m2 = M & ~1; BLASLONG n4 = N & ~3; BLASLONG n2 = N & ~1; - __mmask8 mask = 0xff; // just use to avoid SSE instruction __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); #if !defined(B0) @@ -220,8 +238,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp STORE_512(0, 0); } } - if (M - i > 8) { - register __mmask16 mask asm("k1") = (1UL << (M - i)) - 1; + int mm = M - i; + if (!mm) return 0; + if (mm > 8 || K < 32) { + register __mmask16 mask asm("k1") = (1UL << mm) - 1; for (j = 0; j < n4; j += 4) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(0, 1); @@ -263,10 +283,20 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp } MASK_STORE_512(0, 0); } - return; - } - int mm = M - i; - if (mm) { + } else { + /* M in [1, 8] + * + * This kernel uses a dot-product style to calculate a value C(x, y): + * C(x, y) = A(x, 0)*B(0, y) + A(x, 1)*B(1, y) + ... + A(x, K-1)*B(K-1, y) + * + * Allocate a buffer and copy the rest of A into it in row-major order, + * so that memory access from 0 to K is contiguous for both A and B. + * + * Load into zmm registers and FMA 16 elements of k per loop iteration, + * then reduce_add the accumulator down to a single float result for C(x, y). + * + * Note: performance is bad when K is small. + */ FLOAT *mbuf = (FLOAT *) malloc(sizeof(FLOAT)*mm*K); __mmask8 mask8 = (1UL << mm) - 1; __mmask16 mask; @@ -328,6 +358,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp } } int mi = 0; + mask8 = 0xff; // just used to avoid SSE instructions + __m128 alpha_128 = _mm_broadcast_ss(&alpha); +#if !defined(B0) + __m128 beta_128 = _mm_broadcast_ss(&beta); +#endif for (; i < m4; i += 4, mi += 4) { for (j = 0; j < n4; j += 4) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); @@ -354,10 +389,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); } - STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); STORE_REDUCE(2, 0); STORE_REDUCE(3, 0); - STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); STORE_REDUCE(2, 1); STORE_REDUCE(3, 1); - STORE_REDUCE(0, 2); STORE_REDUCE(1, 2); STORE_REDUCE(2, 2); STORE_REDUCE(3, 2); - STORE_REDUCE(0, 3); STORE_REDUCE(1, 3); STORE_REDUCE(2, 3); STORE_REDUCE(3, 3); + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3); } for (; j < n2; j += 2) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); @@ -378,9 +410,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); } - STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); STORE_REDUCE(2, 0); STORE_REDUCE(3, 0); - STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); STORE_REDUCE(2, 1); STORE_REDUCE(3, 1); - + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); } for (; j < N; j += 1) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); @@ -398,7 +428,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); } - STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); STORE_REDUCE(2, 0); STORE_REDUCE(3, 0); + STORE_REDUCE_M4(0); } } @@ -550,6 +580,6 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp } } free(mbuf); - return; } + return 0; } From 4c9d9940fdd6a458289a02e850afd65d5b9689ba Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 13 May 2021 09:41:51 +0000 Subject: [PATCH 13/37] Small Matrix: skylakex: sgemm nn: reduce store 4 N at a time --- .../x86_64/sgemm_small_kernel_nn_skylakex.c | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c index be9f085c0..c9f43f9a2 100644 --- a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c +++ b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c @@ -57,10 +57,10 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
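
The hunk below renames REDUCE_M4 to a generic REDUCE_4 so the same reduction can later serve both four-M and four-N groups of accumulators. What the macro computes is a 4x4 transpose-and-add: unpack/shuffle interleave the four accumulators so that each lane group holds one element from each register, the adds collapse them, and the four 128-bit lanes are finally extracted and summed. The same idea at 128-bit width, where it is easiest to follow (a standalone sketch, not code from the patch):

#include <immintrin.h>

/* transpose-and-add four __m128 accumulators r0..r3 so that lane m of
   the result is the horizontal sum of rm */
static inline __m128 reduce4_128(__m128 r0, __m128 r1, __m128 r2, __m128 r3)
{
    __m128 lo01 = _mm_unpacklo_ps(r0, r1);   /* r0[0] r1[0] r0[1] r1[1] */
    __m128 hi01 = _mm_unpackhi_ps(r0, r1);   /* r0[2] r1[2] r0[3] r1[3] */
    __m128 lo23 = _mm_unpacklo_ps(r2, r3);
    __m128 hi23 = _mm_unpackhi_ps(r2, r3);
    __m128 t0 = _mm_movelh_ps(lo01, lo23);   /* column 0 of the 4x4 tile */
    __m128 t1 = _mm_movehl_ps(lo23, lo01);   /* column 1 */
    __m128 t2 = _mm_movelh_ps(hi01, hi23);   /* column 2 */
    __m128 t3 = _mm_movehl_ps(hi23, hi01);   /* column 3 */
    return _mm_add_ps(_mm_add_ps(t0, t1), _mm_add_ps(t2, t3));
}

Because lane m of the result holds the full sum for row m, one masked 128-bit store writes four C entries at once; this is what lets STORE_REDUCE_M4 replace four scalar _mm512_reduce_add_ps stores.
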
#define LOAD_KB_512(M, N) __m512 Bval##N = _mm512_loadu_ps(&B[(j + N)*ldb + k]) #define MASK_LOAD_KA_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &mbuf[(mi + M)*K + k]) #define MASK_LOAD_KB_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[(j + N)*ldb + k]) -#define REDUCE_M4(N) \ +#define REDUCE_4(rr0, rr1, rr2, rr3) \ __m512 r0, r1, r2, r3, t0, t1, t2, t3;\ - r0 = _mm512_unpacklo_ps(result0##N, result1##N); r1 = _mm512_unpackhi_ps(result0##N, result1##N); \ - r2 = _mm512_unpacklo_ps(result2##N, result3##N); r3 = _mm512_unpackhi_ps(result2##N, result3##N); \ + r0 = _mm512_unpacklo_ps(rr0, rr1); r1 = _mm512_unpackhi_ps(rr0, rr1); \ + r2 = _mm512_unpacklo_ps(rr2, rr3); r3 = _mm512_unpackhi_ps(rr2, rr3); \ t0 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(1, 0, 1, 0)); t1 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(3, 2, 3, 2)); \ t2 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(1, 0, 1, 0)); t3 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(3, 2, 3, 2)); \ r0 = _mm512_add_ps(t0, t1); r1 = _mm512_add_ps(t2, t3); t0 = _mm512_add_ps(r0, r1); \ @@ -68,12 +68,18 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. s0 = _mm512_extractf32x4_ps(t0, 0); s1 = _mm512_extractf32x4_ps(t0, 1); s2 = _mm512_extractf32x4_ps(t0, 2); s3 = _mm512_extractf32x4_ps(t0, 3); \ s0 = _mm_maskz_add_ps(mask8, s0, s1); s2 = _mm_maskz_add_ps(mask8, s2, s3); s0 = _mm_maskz_add_ps(mask8, s0, s2); \ s0 = _mm_maskz_mul_ps(mask8, alpha_128, s0); +#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N) +#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3) #if defined(B0) #define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N); #define STORE_REDUCE_M4(N) {\ REDUCE_M4(N) \ _mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); \ } +#define STORE_REDUCE_N4(M) {\ + REDUCE_N4(M) \ + _mm_i32scatter_ps(&C[j*ldc + i + M], vindex_n, s0, 4); \ +} #else #define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N) + beta * C[(j+N)*ldc + i + M]; #define STORE_REDUCE_M4(N) {\ @@ -81,6 +87,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
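
The new STORE_REDUCE_N4 writes four results that share one row of C but sit in four consecutive columns, so the destinations are ldc floats apart and a contiguous store cannot be used. The B0 variant above scatters through vindex_n = {0, ldc, 2*ldc, 3*ldc}; the beta variant in the hunk below first gathers the old C values, FMAs them with beta, and scatters back. A minimal sketch of that gather-FMA-scatter round trip (a hypothetical helper; needs AVX2 for the gather and AVX-512VL for the 128-bit scatter):

#include <immintrin.h>

/* read-modify-write four C elements that sit one column apart
   (stride ldc floats): s = beta * C[n*ldc] + s, then write back */
static void axpby4_strided(float *c, long ldc, __m128 s, float beta)
{
    __m128i vindex = _mm_set_epi32((int)(ldc * 3), (int)(ldc * 2), (int)ldc, 0);
    __m128  old    = _mm_i32gather_ps(c, vindex, 4);  /* C[0], C[ldc], ...  */
    s = _mm_fmadd_ps(old, _mm_set1_ps(beta), s);      /* s = beta*old + s   */
    _mm_i32scatter_ps(c, vindex, s, 4);               /* strided write-back */
}
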
asm("vfmadd231ps (%1), %2, %0": "+v"(s0):"r"(&C[(j + N)*ldc + i]), "v"(beta_128)); \ _mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); \ } +#define STORE_REDUCE_N4(M) {\ + REDUCE_N4(M) \ + s1 = _mm_i32gather_ps(&C[j*ldc + i + M], vindex_n, 4); \ + s0 = _mm_fmadd_ps(s1, beta_128, s0); \ + _mm_i32scatter_ps(&C[j*ldc + i + M], vindex_n, s0, 4); \ +} #endif #if defined(B0) @@ -363,6 +375,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp #if !defined(B0) __m128 beta_128 = _mm_broadcast_ss(&beta); #endif + __m128i vindex_n = _mm_set_epi32(ldc*3, ldc*2, ldc, 0); for (; i < m4; i += 4, mi += 4) { for (j = 0; j < n4; j += 4) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); @@ -458,10 +471,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(0, 3); MATMUL_512(1, 3); } - STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); - STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); - STORE_REDUCE(0, 2); STORE_REDUCE(1, 2); - STORE_REDUCE(0, 3); STORE_REDUCE(1, 3); + STORE_REDUCE_N4(0); STORE_REDUCE_N4(1); } for (; j < n2; j += 2) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); @@ -532,10 +542,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp MATMUL_512(0, 2); MATMUL_512(0, 3); } - STORE_REDUCE(0, 0); - STORE_REDUCE(0, 1); - STORE_REDUCE(0, 2); - STORE_REDUCE(0, 3); + STORE_REDUCE_N4(0); } for (; j < n2; j += 2) { DECLARE_RESULT_512(0, 0); From a87736346fd3988618c0d8895827566fce5a5487 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 13 May 2021 10:16:54 +0000 Subject: [PATCH 14/37] Small Matrix: skylakex: sgemm nn: add n6 to improve performance --- .../x86_64/sgemm_small_kernel_nn_skylakex.c | 90 ++++++++++++++++++- 1 file changed, 87 insertions(+), 3 deletions(-) diff --git a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c index c9f43f9a2..a67541161 100644 --- a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c +++ b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c @@ -110,6 +110,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp BLASLONG m4 = M & ~3; BLASLONG m2 = M & ~1; + BLASLONG n6 = N - (N % 6); BLASLONG n4 = N & ~3; BLASLONG n2 = N & ~1; @@ -165,7 +166,34 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp } } for (; i < m32; i += 32) { - for (j = 0; j < n4; j += 4) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); + DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + MATMUL_512(0, 4); MATMUL_512(1, 4); + MATMUL_512(0, 5); MATMUL_512(1, 5); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + STORE_512(0, 4); STORE_512(1, 4); + STORE_512(0, 5); STORE_512(1, 5); + } 
+ for (;j < n4; j += 4) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); @@ -208,7 +236,34 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp } } for (; i < m16; i += 16) { - for (j = 0; j < n4; j += 4) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + STORE_512(0, 4); + STORE_512(0, 5); + } + for (; j < n4; j += 4) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(0, 2); @@ -228,6 +283,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp STORE_512(0, 2); STORE_512(0, 3); } + for (; j < n2; j += 2) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(0, 1); @@ -254,7 +310,34 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp if (!mm) return 0; if (mm > 8 || K < 32) { register __mmask16 mask asm("k1") = (1UL << mm) - 1; - for (j = 0; j < n4; j += 4) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + MASK_STORE_512(0, 4); + MASK_STORE_512(0, 5); + } + for (; j < n4; j += 4) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(0, 2); @@ -274,6 +357,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp MASK_STORE_512(0, 2); MASK_STORE_512(0, 3); } + for (; j < n2; j += 2) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(0, 1); From 9967e61abb3ba0b87a043662382c515ed9d220bb Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 19 May 2021 10:50:03 +0000 Subject: [PATCH 15/37] Small Matrix: skylakex: sgemm nn: fix error when beta not zero --- kernel/x86_64/sgemm_small_kernel_nn_skylakex.c | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c index a67541161..99856d0af 100644 --- a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c +++ b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c @@ -42,15 +42,13 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) #else #define STORE_512(M, N) \ - BLASLONG offset##M##N = (j+N)*ldc + i + (M*16); \ result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ - asm("vfmadd231ps (%1, %2, 4), %3, %0": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_512)); \ - _mm512_storeu_ps(&C[offset##M##N], result##M##N) + asm("vfmadd231ps (%1), %2, %0": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512)); \ + _mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N) #define MASK_STORE_512(M, N) \ - BLASLONG offset##M##N = (j+N)*ldc + i + (M*16); \ result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ - asm("vfmadd231ps (%1, %2, 4), %3, %0 %{%4%}": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_512), "k"(mask)); \ - _mm512_mask_storeu_ps(&C[offset##M##N], mask, result##M##N) + asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \ + _mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) #endif #define LOAD_KA_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&mbuf[(mi + M)*K + k]); From ca7682e3a3dceeb52ba1ad554f384388ffb24c9a Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 20 May 2021 11:24:31 +0000 Subject: [PATCH 16/37] Small Matrix: skylakex: sgemm nn: fix n6 conflicts with n4 --- .../x86_64/sgemm_small_kernel_nn_skylakex.c | 62 ------------------- 1 file changed, 62 deletions(-) diff --git a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c index 99856d0af..9bc7a7c58 100644 --- a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c +++ b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c @@ -191,26 +191,6 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp STORE_512(0, 4); STORE_512(1, 4); STORE_512(0, 5); STORE_512(1, 5); } - for (;j < n4; j += 4) { - DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); - DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); - DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); - DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); - for (k = 0; k < K; k++) { - LOAD_A_512(0, x); LOAD_A_512(1, x); - BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); - BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); - - MATMUL_512(0, 0); MATMUL_512(1, 0); - MATMUL_512(0, 1); MATMUL_512(1, 1); - MATMUL_512(0, 2); MATMUL_512(1, 2); - MATMUL_512(0, 3); MATMUL_512(1, 3); - } - STORE_512(0, 0); STORE_512(1, 0); - STORE_512(0, 1); STORE_512(1, 1); - STORE_512(0, 2); STORE_512(1, 2); - STORE_512(0, 3); STORE_512(1, 3); - } for (; j < n2; j += 2) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); @@ -261,27 +241,6 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp STORE_512(0, 4); STORE_512(0, 5); } - for (; j < n4; j += 4) { - DECLARE_RESULT_512(0, 0); - DECLARE_RESULT_512(0, 1); - DECLARE_RESULT_512(0, 2); - DECLARE_RESULT_512(0, 3); - for (k = 0; k < K; k++) { - LOAD_A_512(0, x); - BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); - BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); - - MATMUL_512(0, 0); - MATMUL_512(0, 1); - MATMUL_512(0, 2); - MATMUL_512(0, 3); - } - STORE_512(0, 0); - STORE_512(0, 1); - STORE_512(0, 2); - STORE_512(0, 3); - } - for (; j < n2; j += 2) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(0, 1); @@ -335,27 +294,6 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp MASK_STORE_512(0, 4); 
MASK_STORE_512(0, 5); } - for (; j < n4; j += 4) { - DECLARE_RESULT_512(0, 0); - DECLARE_RESULT_512(0, 1); - DECLARE_RESULT_512(0, 2); - DECLARE_RESULT_512(0, 3); - for (k = 0; k < K; k++) { - MASK_LOAD_A_512(0, x); - BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); - BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); - - MATMUL_512(0, 0); - MATMUL_512(0, 1); - MATMUL_512(0, 2); - MATMUL_512(0, 3); - } - MASK_STORE_512(0, 0); - MASK_STORE_512(0, 1); - MASK_STORE_512(0, 2); - MASK_STORE_512(0, 3); - } - for (; j < n2; j += 2) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(0, 1); From 0d72d75bf9455c91b6f0c4ecf5b7555845dccf6f Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 20 May 2021 11:47:10 +0000 Subject: [PATCH 17/37] Small Matrix: skylakex: add sgemm nt kernel --- kernel/x86_64/KERNEL.SKYLAKEX | 2 + .../sgemm_small_kernel_b0_nt_skylakex.c | 2 + .../x86_64/sgemm_small_kernel_nt_skylakex.c | 366 ++++++++++++++++++ 3 files changed, 370 insertions(+) create mode 100644 kernel/x86_64/sgemm_small_kernel_b0_nt_skylakex.c create mode 100644 kernel/x86_64/sgemm_small_kernel_nt_skylakex.c diff --git a/kernel/x86_64/KERNEL.SKYLAKEX b/kernel/x86_64/KERNEL.SKYLAKEX index 1a2e67b52..d3560bf80 100644 --- a/kernel/x86_64/KERNEL.SKYLAKEX +++ b/kernel/x86_64/KERNEL.SKYLAKEX @@ -12,6 +12,8 @@ STRSMKERNEL_RN = ../generic/trsm_kernel_RN.c STRSMKERNEL_RT = ../generic/trsm_kernel_RT.c SGEMM_SMALL_K_NN = sgemm_small_kernel_nn_skylakex.c SGEMM_SMALL_K_B0_NN = sgemm_small_kernel_b0_nn_skylakex.c +SGEMM_SMALL_K_NT = sgemm_small_kernel_nt_skylakex.c +SGEMM_SMALL_K_B0_NT = sgemm_small_kernel_b0_nt_skylakex.c DGEMMKERNEL = dgemm_kernel_16x2_skylakex.c DTRMMKERNEL = dgemm_kernel_16x2_skylakex.c diff --git a/kernel/x86_64/sgemm_small_kernel_b0_nt_skylakex.c b/kernel/x86_64/sgemm_small_kernel_b0_nt_skylakex.c new file mode 100644 index 000000000..6d7934be1 --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_b0_nt_skylakex.c @@ -0,0 +1,2 @@ +#define B0 1 +#include "./sgemm_small_kernel_nt_skylakex.c" diff --git a/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c new file mode 100644 index 000000000..3fc842669 --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c @@ -0,0 +1,366 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include +#include "common.h" +#include +#include + +#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps() +#define LOAD_A_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[lda * k + i + (M*16)]) +#define MASK_LOAD_A_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &A[lda * k + i + (M*16)]) +#define BROADCAST_LOAD_B_512(M, N) __m512 Bval##N = _mm512_broadcastss_ps(_mm_load_ss(&B[ldb * k + j + N])) +#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) +#if defined(B0) +#define STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + _mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N) +#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + _mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) +#else +#define STORE_512(M, N) \ + result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + asm("vfmadd231ps (%1), %2, %0": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512)); \ + _mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N) +#define MASK_STORE_512(M, N) \ + result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \ + _mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) +#endif + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m64 = M & ~63; + BLASLONG m32 = M & ~31; + BLASLONG m16 = M & ~15; + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n8 = N & ~7; + BLASLONG n6 = N - (N % 6); + BLASLONG n4 = N & ~3; + BLASLONG n2 = N & ~1; + + + __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); +#if !defined(B0) + __m512 beta_512 = _mm512_broadcastss_ps(_mm_load_ss(&beta)); +#endif + + for (i = 0; i < m64; i += 64) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); DECLARE_RESULT_512(2, 4); DECLARE_RESULT_512(3, 4); + DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); DECLARE_RESULT_512(2, 5); DECLARE_RESULT_512(3, 5); + + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + + 
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + MATMUL_512(0, 4); MATMUL_512(1, 4); MATMUL_512(2, 4); MATMUL_512(3, 4); + MATMUL_512(0, 5); MATMUL_512(1, 5); MATMUL_512(2, 5); MATMUL_512(3, 5); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); + STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); + STORE_512(0, 4); STORE_512(1, 4); STORE_512(2, 4); STORE_512(3, 4); + STORE_512(0, 5); STORE_512(1, 5); STORE_512(2, 5); STORE_512(3, 5); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + } + } + for (; i < m32; i += 32) { + for (j = 0; j < n8; j += 8) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); + DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); + DECLARE_RESULT_512(0, 6); DECLARE_RESULT_512(1, 6); + DECLARE_RESULT_512(0, 7); DECLARE_RESULT_512(1, 7); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + MATMUL_512(0, 4); MATMUL_512(1, 4); + MATMUL_512(0, 5); MATMUL_512(1, 5); + MATMUL_512(0, 6); MATMUL_512(1, 6); + MATMUL_512(0, 7); MATMUL_512(1, 7); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + STORE_512(0, 4); STORE_512(1, 4); + STORE_512(0, 5); STORE_512(1, 5); + STORE_512(0, 6); STORE_512(1, 6); + STORE_512(0, 7); STORE_512(1, 7); + } + for 
(;j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_512(0, 0); STORE_512(1, 0); + } + } + for (; i < m16; i += 16) { + for (j = 0; j < n8; j += 8) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + DECLARE_RESULT_512(0, 6); + DECLARE_RESULT_512(0, 7); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + MATMUL_512(0, 6); + MATMUL_512(0, 7); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + STORE_512(0, 4); + STORE_512(0, 5); + STORE_512(0, 6); + STORE_512(0, 7); + } + for (; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + } + + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_512(0, 0); + STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + STORE_512(0, 0); + } + } + int mm = M - i; + if (mm > 0) { + register __mmask16 mask asm("k1") = (1UL << mm) - 1; + for (j = 0; j < n8; j += 8) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + DECLARE_RESULT_512(0, 6); + DECLARE_RESULT_512(0, 
7); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + MATMUL_512(0, 6); + MATMUL_512(0, 7); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + MASK_STORE_512(0, 4); + MASK_STORE_512(0, 5); + MASK_STORE_512(0, 6); + MASK_STORE_512(0, 7); + } + for (; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + } + + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_STORE_512(0, 0); + } + } +} From ae3f5c737c24e6fdb7de4559969bee5631aa1683 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Fri, 21 May 2021 13:31:31 +0000 Subject: [PATCH 18/37] Small Matrix: skylakex: sgemm nt: optimize for M < 12 --- .../x86_64/sgemm_small_kernel_nt_skylakex.c | 171 +++++++++++++++++- 1 file changed, 170 insertions(+), 1 deletion(-) diff --git a/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c index 3fc842669..f293bf9f9 100644 --- a/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c +++ b/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c @@ -35,11 +35,19 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
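
For the nt kernel, the M < 12 remainder added below flips the roles of the operands: instead of 16-wide vectors of A times broadcast B scalars, it broadcasts single A elements and runs 16-wide along N through the rows of the transposed B, then scatters each 16-result vector into column-major C at stride ldc. A minimal sketch of that shape for one row of A (illustrative names, B0/alpha-only path, masked N-tail omitted):

#include <immintrin.h>

/* C(i, j..j+15) += alpha * A(i, :) * B, with B stored so that
   B[ldb*k + j] walks along N; results land ldc apart in C */
static void row_times_b(const float *A, long lda, const float *B, long ldb,
                        float *C, long ldc, long K, long N, long i, float alpha)
{
    int idx[16];
    for (int n = 0; n < 16; n++) idx[n] = (int)(n * ldc);
    __m512i vindex = _mm512_loadu_si512(idx);       /* 0, ldc, 2*ldc, ... */
    for (long j = 0; j + 16 <= N; j += 16) {
        __m512 acc = _mm512_setzero_ps();
        for (long k = 0; k < K; k++) {
            __m512 a = _mm512_set1_ps(A[lda * k + i]);   /* one A element  */
            __m512 b = _mm512_loadu_ps(&B[ldb * k + j]); /* 16 along N     */
            acc = _mm512_fmadd_ps(a, b, acc);
        }
        acc = _mm512_mul_ps(acc, _mm512_set1_ps(alpha));
        _mm512_i32scatter_ps(&C[j * ldc + i], vindex, acc, 4);
    }
}
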
#define MASK_LOAD_A_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &A[lda * k + i + (M*16)]) #define BROADCAST_LOAD_B_512(M, N) __m512 Bval##N = _mm512_broadcastss_ps(_mm_load_ss(&B[ldb * k + j + N])) #define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) + +#define BROADCAST_LOAD_A_512(M, N) __m512 Aval##M = _mm512_broadcastss_ps(_mm_load_ss(&A[lda * k + i + M])) +#define LOAD_B_512(M, N) __m512 Bval##N = _mm512_loadu_ps(&B[ldb * k + j + (N*16)]) +#define MASK_LOAD_B_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[ldb * k + j + (N*16)]) #if defined(B0) #define STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ _mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N) #define MASK_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ _mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) +#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + _mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4); +#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + _mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4) #else #define STORE_512(M, N) \ result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ @@ -49,6 +57,14 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \ _mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) +#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + __m512 tmp##M##N = _mm512_i32gather_ps(vindex_n, &C[(j + N*16)*ldc + i + M], 4); \ + result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \ + _mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4); +#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + __m512 tmp##M##N = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, vindex_n, &C[(j + N*16)*ldc + i + M], 4); \ + result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \ + _mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4); #endif #if defined(B0) @@ -66,6 +82,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp BLASLONG m4 = M & ~3; BLASLONG m2 = M & ~1; + BLASLONG n64 = N & ~63; + BLASLONG n32 = N & ~31; BLASLONG n8 = N & ~7; BLASLONG n6 = N - (N % 6); BLASLONG n4 = N & ~3; @@ -284,7 +302,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp } } int mm = M - i; - if (mm > 0) { + if (mm >= 12) { register __mmask16 mask asm("k1") = (1UL << mm) - 1; for (j = 0; j < n8; j += 8) { DECLARE_RESULT_512(0, 0); @@ -362,5 +380,156 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp } MASK_STORE_512(0, 0); } + } else if (mm > 0) { + int index_n[16]; + for (int ii = 0; ii < 16; ii++) { + index_n[ii] = ii * ldc; + } + __m512i vindex_n = _mm512_loadu_epi32(index_n); + for (; i < m4; i += 4) { + for (j = 0; j < n64; j += 64) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); 
DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + LOAD_B_512(x, 2); + LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); SCATTER_STORE_512(2, 0); SCATTER_STORE_512(3, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); SCATTER_STORE_512(2, 1); SCATTER_STORE_512(3, 1); + SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); SCATTER_STORE_512(2, 2); SCATTER_STORE_512(3, 2); + SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); SCATTER_STORE_512(2, 3); SCATTER_STORE_512(3, 3); + } + for (; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); SCATTER_STORE_512(2, 0); SCATTER_STORE_512(3, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); SCATTER_STORE_512(2, 1); SCATTER_STORE_512(3, 1); + } + __mmask16 mask = 0xffff; + for (; j < N; j += 16) { + int remains = N - j; + if (remains < 16) mask = (1UL << remains) - 1; + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); MASK_SCATTER_STORE_512(2, 0); MASK_SCATTER_STORE_512(3, 0); + } + } + for (; i < m2; i += 2) { + for (j = 0; j < n64; j += 64) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + LOAD_B_512(x, 2); + LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); + SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); + } + for (; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 
1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + } + __mmask16 mask = 0xffff; + for (; j < N; j += 16) { + int remains = N - j; + if (remains < 16) mask = (1UL << remains) - 1; + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); + } + } + for (; i < M; i += 1) { + for (j = 0; j < n64; j += 64) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + LOAD_B_512(x, 2); + LOAD_B_512(x, 3); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + SCATTER_STORE_512(0, 2); + SCATTER_STORE_512(0, 3); + } + for (; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + } + __mmask16 mask = 0xffff; + for (; j < N; j += 16) { + int remains = N - j; + if (remains < 16) mask = (1UL << remains) - 1; + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_SCATTER_STORE_512(0, 0); + } + } } + return 0; } From 642c3938790b45606dea7450a6fbc23b6c9b9b9c Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 26 May 2021 16:30:57 +0000 Subject: [PATCH 19/37] Small Matrix: skylakex: add sgemm tn kernel --- kernel/x86_64/KERNEL.SKYLAKEX | 2 + .../sgemm_small_kernel_b0_tn_skylakex.c | 2 + .../x86_64/sgemm_small_kernel_tn_skylakex.c | 316 ++++++++++++++++++ 3 files changed, 320 insertions(+) create mode 100644 kernel/x86_64/sgemm_small_kernel_b0_tn_skylakex.c create mode 100644 kernel/x86_64/sgemm_small_kernel_tn_skylakex.c diff --git a/kernel/x86_64/KERNEL.SKYLAKEX b/kernel/x86_64/KERNEL.SKYLAKEX index d3560bf80..5e0d9e5b4 100644 --- a/kernel/x86_64/KERNEL.SKYLAKEX +++ b/kernel/x86_64/KERNEL.SKYLAKEX @@ -14,6 +14,8 @@ SGEMM_SMALL_K_NN = sgemm_small_kernel_nn_skylakex.c SGEMM_SMALL_K_B0_NN = sgemm_small_kernel_b0_nn_skylakex.c SGEMM_SMALL_K_NT = sgemm_small_kernel_nt_skylakex.c SGEMM_SMALL_K_B0_NT = sgemm_small_kernel_b0_nt_skylakex.c +SGEMM_SMALL_K_TN = sgemm_small_kernel_tn_skylakex.c +SGEMM_SMALL_K_B0_TN = sgemm_small_kernel_b0_tn_skylakex.c DGEMMKERNEL = dgemm_kernel_16x2_skylakex.c DTRMMKERNEL = dgemm_kernel_16x2_skylakex.c diff --git a/kernel/x86_64/sgemm_small_kernel_b0_tn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_b0_tn_skylakex.c new file mode 100644 index 000000000..0f9745b72 --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_b0_tn_skylakex.c @@ -0,0 +1,2 @@ +#define B0 1 +#include "./sgemm_small_kernel_tn_skylakex.c" diff --git a/kernel/x86_64/sgemm_small_kernel_tn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_tn_skylakex.c new file mode 100644 index 000000000..5a9a4ea32 --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_tn_skylakex.c @@ -0,0 +1,316 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. 
+Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include +#include "common.h" +#include +#include + +#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps() +#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) + +#define LOAD_KA_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[(i + M)*lda + k]); +#define LOAD_KB_512(M, N) __m512 Bval##N = _mm512_loadu_ps(&B[(j + N)*ldb + k]) +#define MASK_LOAD_KA_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &A[(i + M)*lda + k]) +#define MASK_LOAD_KB_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[(j + N)*ldb + k]) + +#define REDUCE_4(rr0, rr1, rr2, rr3) \ + __m512 r0, r1, r2, r3, t0, t1, t2, t3;\ + r0 = _mm512_unpacklo_ps(rr0, rr1); r1 = _mm512_unpackhi_ps(rr0, rr1); \ + r2 = _mm512_unpacklo_ps(rr2, rr3); r3 = _mm512_unpackhi_ps(rr2, rr3); \ + t0 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(1, 0, 1, 0)); t1 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(3, 2, 3, 2)); \ + t2 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(1, 0, 1, 0)); t3 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(3, 2, 3, 2)); \ + r0 = _mm512_add_ps(t0, t1); r1 = _mm512_add_ps(t2, t3); t0 = _mm512_add_ps(r0, r1); \ + __m128 s0, s1, s2, s3; \ + s0 = _mm512_extractf32x4_ps(t0, 0); s1 = _mm512_extractf32x4_ps(t0, 1); s2 = _mm512_extractf32x4_ps(t0, 2); s3 = _mm512_extractf32x4_ps(t0, 3); \ + s0 = _mm_maskz_add_ps(mask8, s0, s1); s2 = _mm_maskz_add_ps(mask8, s2, s3); s0 = _mm_maskz_add_ps(mask8, s0, s2); \ + s0 = _mm_maskz_mul_ps(mask8, alpha_128, s0); + +#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N) +#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3) + +#if defined(B0) +#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N) +#define STORE_M4(N, s0) _mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); +#define STORE_N4(M, s0) _mm_i32scatter_ps(&C[j*ldc + i + M], vindex_n, s0, 4); +#else +#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N) + beta * 
C[(j+N)*ldc + i + M] +#define STORE_M4(N, s0) \ + asm("vfmadd231ps (%1), %2, %0": "+v"(s0):"r"(&C[(j + N)*ldc + i]), "v"(beta_128)); \ + _mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); + +#define STORE_N4(M, s0) \ + s0 = _mm_fmadd_ps(_mm_i32gather_ps(&C[j*ldc + i + M], vindex_n, 4), beta_128, s0); \ + _mm_i32scatter_ps(&C[j*ldc + i + M], vindex_n, s0, 4); +#endif +#define STORE_REDUCE_M4(N) {\ + REDUCE_M4(N) \ + STORE_M4(N, s0) \ +} +#define STORE_REDUCE_N4(M) {\ + REDUCE_N4(M) \ + STORE_N4(M, s0) \ +} + + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n4 = N & ~3; + BLASLONG n2 = N & ~1; + + BLASLONG k16 = K & ~15; + + __mmask16 mask; + __mmask8 mask8 = 0xff; // just use to avoid SSE instruction + + __m128i vindex_n = _mm_set_epi32(ldc*3, ldc*2, ldc, 0); + __m128 alpha_128 = _mm_broadcast_ss(&alpha); +#if !defined(B0) + __m128 beta_128 = _mm_broadcast_ss(&beta); +#endif + for (i = 0; i < m4; i += 4) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); 
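+ // K tail: remains = K - k is in [1,15] here, so the mask keeps only the leftover lanes and the masked loads of A and B stay inside their panels.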
MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_REDUCE_M4(0); + } + + } + for (; i < m2; i += 2) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_REDUCE_N4(0); STORE_REDUCE_N4(1); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + } + } + for (; i < M; i += 1) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + int remains = K - k; + if (remains) { + mask = 
(1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_REDUCE_N4(0); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_REDUCE(0, 0); + STORE_REDUCE(0, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + STORE_REDUCE(0, 0); + } + } + return 0; +} From 5dc7c3c8e572c1760cd9aba40dde1db54bb3f2e3 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 27 May 2021 11:03:56 +0000 Subject: [PATCH 20/37] Small Matrix: add GEMM_SMALL_MATRIX_PERMIT to tune small matrics case --- common_c.h | 2 ++ common_d.h | 1 + common_level3.h | 8 +++++ common_macro.h | 10 ++++++ common_s.h | 2 ++ common_z.h | 2 ++ interface/gemm.c | 9 +++--- kernel/Makefile.L3 | 31 ++++++++++++++++++ kernel/generic/gemm_small_matrix_permit.c | 37 ++++++++++++++++++++++ kernel/generic/zgemm_small_matrix_permit.c | 37 ++++++++++++++++++++++ 10 files changed, 135 insertions(+), 4 deletions(-) create mode 100644 kernel/generic/gemm_small_matrix_permit.c create mode 100644 kernel/generic/zgemm_small_matrix_permit.c diff --git a/common_c.h b/common_c.h index 9388ece93..dc273eef0 100644 --- a/common_c.h +++ b/common_c.h @@ -232,6 +232,8 @@ #define CGEADD_K cgeadd_k +#define CGEMM_SMALL_MATRIX_PERMIT cgemm_small_matrix_permit + #define CGEMM_SMALL_KERNEL_NN cgemm_small_kernel_nn #define CGEMM_SMALL_KERNEL_NT cgemm_small_kernel_nt #define CGEMM_SMALL_KERNEL_NR cgemm_small_kernel_nr diff --git a/common_d.h b/common_d.h index 42c14e828..bb85f1232 100644 --- a/common_d.h +++ b/common_d.h @@ -157,6 +157,7 @@ #define DIMATCOPY_K_RT dimatcopy_k_rt #define DGEADD_K dgeadd_k +#define DGEMM_SMALL_MATRIX_PERMIT dgemm_small_matrix_permit #define DGEMM_SMALL_KERNEL_NN dgemm_small_kernel_nn #define DGEMM_SMALL_KERNEL_NT dgemm_small_kernel_nt diff --git a/common_level3.h b/common_level3.h index a3a487dab..187402a9a 100644 --- a/common_level3.h +++ b/common_level3.h @@ -516,11 +516,15 @@ int qgemm_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, xdouble *, xd #endif #ifdef SMALL_MATRIX_OPT +int sgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta); + int sgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); int sgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); int sgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); int sgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, 
BLASLONG ldc); +int dgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, double alpha, double beta); + int dgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); int dgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); int dgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); @@ -536,6 +540,8 @@ int dgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLA int dgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); int dgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int cgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha0, float alpha1, float beta0, float beta1); + int cgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); int cgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); int cgemm_small_kernel_nr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); @@ -556,6 +562,8 @@ int cgemm_small_kernel_ct(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLON int cgemm_small_kernel_cr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); int cgemm_small_kernel_cc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); +int zgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, double alpha0, double alpha1, double beta0, double beta1); + int zgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); int zgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); int zgemm_small_kernel_nr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); diff --git a/common_macro.h b/common_macro.h index 2cccf9b39..aeb9a205b 100644 --- a/common_macro.h +++ b/common_macro.h @@ -644,6 +644,8 @@ #define GEADD_K DGEADD_K +#define GEMM_SMALL_MATRIX_PERMIT DGEMM_SMALL_MATRIX_PERMIT + #define GEMM_SMALL_KERNEL_NN DGEMM_SMALL_KERNEL_NN #define GEMM_SMALL_KERNEL_NT DGEMM_SMALL_KERNEL_NT #define GEMM_SMALL_KERNEL_TN DGEMM_SMALL_KERNEL_TN @@ -940,6 +942,8 @@ #define GEADD_K SGEADD_K +#define GEMM_SMALL_MATRIX_PERMIT SGEMM_SMALL_MATRIX_PERMIT + #define GEMM_SMALL_KERNEL_NN SGEMM_SMALL_KERNEL_NN 
#define GEMM_SMALL_KERNEL_NT SGEMM_SMALL_KERNEL_NT #define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN @@ -1256,6 +1260,8 @@ #define GEADD_K SGEADD_K +#define GEMM_SMALL_MATRIX_PERMIT SGEMM_SMALL_MATRIX_PERMIT + #define GEMM_SMALL_KERNEL_NN SGEMM_SMALL_KERNEL_NN #define GEMM_SMALL_KERNEL_NT SGEMM_SMALL_KERNEL_NT #define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN @@ -2093,6 +2099,8 @@ #define GEADD_K ZGEADD_K +#define GEMM_SMALL_MATRIX_PERMIT ZGEMM_SMALL_MATRIX_PERMIT + #define GEMM_SMALL_KERNEL_NN ZGEMM_SMALL_KERNEL_NN #define GEMM_SMALL_KERNEL_NT ZGEMM_SMALL_KERNEL_NT #define GEMM_SMALL_KERNEL_NR ZGEMM_SMALL_KERNEL_NR @@ -2556,6 +2564,8 @@ #define GEADD_K CGEADD_K +#define GEMM_SMALL_MATRIX_PERMIT CGEMM_SMALL_MATRIX_PERMIT + #define GEMM_SMALL_KERNEL_NN CGEMM_SMALL_KERNEL_NN #define GEMM_SMALL_KERNEL_NT CGEMM_SMALL_KERNEL_NT #define GEMM_SMALL_KERNEL_NR CGEMM_SMALL_KERNEL_NR diff --git a/common_s.h b/common_s.h index 685d73062..5851014cf 100644 --- a/common_s.h +++ b/common_s.h @@ -164,6 +164,8 @@ #define SGEADD_K sgeadd_k +#define SGEMM_SMALL_MATRIX_PERMIT sgemm_small_matrix_permit + #define SGEMM_SMALL_KERNEL_NN sgemm_small_kernel_nn #define SGEMM_SMALL_KERNEL_NT sgemm_small_kernel_nt #define SGEMM_SMALL_KERNEL_TN sgemm_small_kernel_tn diff --git a/common_z.h b/common_z.h index 8594ec74d..6088260a1 100644 --- a/common_z.h +++ b/common_z.h @@ -232,6 +232,8 @@ #define ZGEADD_K zgeadd_k +#define ZGEMM_SMALL_MATRIX_PERMIT zgemm_small_matrix_permit + #define ZGEMM_SMALL_KERNEL_NN zgemm_small_kernel_nn #define ZGEMM_SMALL_KERNEL_NT zgemm_small_kernel_nt #define ZGEMM_SMALL_KERNEL_NR zgemm_small_kernel_nr diff --git a/interface/gemm.c b/interface/gemm.c index 7251993ee..ad8780668 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -464,25 +464,26 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS #endif #ifdef SMALL_MATRIX_OPT - //need to tune small matrices cases. 
- if(MNK <= 100.0*100.0*100.0){ - #if !defined(COMPLEX) + if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, *(FLOAT *)(args.alpha), *(FLOAT *)(args.beta))){ if(*(FLOAT *)(args.beta) == 0.0){ (gemm_small_kernel_b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc); }else{ (gemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); } + return; + } #else + if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, alpha[0], alpha[1], beta[0], beta[1])){ if(beta[0] == 0.0 && beta[1] == 0.0){ (zgemm_small_kernel_b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, args.c, args.ldc); }else{ (zgemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, beta[0], beta[1], args.c, args.ldc); } -#endif return; } #endif +#endif buffer = (XFLOAT *)blas_memory_alloc(0); diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 1c4a00158..f977793a0 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -451,18 +451,21 @@ endif ifeq ($(SMALL_MATRIX_OPT), 1) SBLASOBJS += \ + sgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \ sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ sgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \ sgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) DBLASOBJS += \ + dgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \ dgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ dgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \ dgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) CBLASOBJS += \ + cgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \ cgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ cgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) \ cgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ @@ -481,6 +484,7 @@ CBLASOBJS += \ cgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) ZBLASOBJS += \ + zgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \ zgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ zgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) \ zgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ @@ -4294,6 +4298,10 @@ $(KDIR)zgeadd_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEADD_K) ###### BLAS small matrix optimization ##### +ifndef DGEMM_SMALL_M_PERMIT +DGEMM_SMALL_M_PERMIT = ../generic/gemm_small_matrix_permit.c +endif + ifndef DGEMM_SMALL_K_NN DGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c endif @@ -4310,6 +4318,9 @@ ifndef DGEMM_SMALL_K_TT DGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c endif +$(KDIR)dgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_M_PERMIT) + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + $(KDIR)dgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : 
$(KERNELDIR)/$(DGEMM_SMALL_K_NN) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ @@ -4350,6 +4361,9 @@ $(KDIR)dgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL $(KDIR)dgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_TT) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ +ifndef SGEMM_SMALL_M_PERMIT +SGEMM_SMALL_M_PERMIT = ../generic/gemm_small_matrix_permit.c +endif ifndef SGEMM_SMALL_K_NN SGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c @@ -4367,6 +4381,9 @@ ifndef SGEMM_SMALL_K_TT SGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c endif +$(KDIR)sgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_M_PERMIT) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + $(KDIR)sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_NN) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ @@ -4407,6 +4424,9 @@ $(KDIR)sgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL $(KDIR)sgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_TT) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ +ifndef CGEMM_SMALL_M_PERMIT +CGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c +endif ifndef CGEMM_SMALL_K_NN CGEMM_SMALL_K_NN = ../generic/zgemm_small_matrix_kernel_nn.c @@ -4424,6 +4444,9 @@ ifndef CGEMM_SMALL_K_TT CGEMM_SMALL_K_TT = ../generic/zgemm_small_matrix_kernel_tt.c endif +$(KDIR)cgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_M_PERMIT) + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX $< -o $@ + $(KDIR)cgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NN) $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNN $< -o $@ @@ -4536,6 +4559,10 @@ $(KDIR)cgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL $(KDIR)cgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT) $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCC $< -o $@ +ifndef ZGEMM_SMALL_M_PERMIT +ZGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c +endif + ifndef ZGEMM_SMALL_K_NN ZGEMM_SMALL_K_NN = ../generic/zgemm_small_matrix_kernel_nn.c endif @@ -4552,6 +4579,10 @@ ifndef ZGEMM_SMALL_K_TT ZGEMM_SMALL_K_TT = ../generic/zgemm_small_matrix_kernel_tt.c endif +$(KDIR)zgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_M_PERMIT) + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX $< -o $@ + + $(KDIR)zgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NN) $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNN $< -o $@ diff --git a/kernel/generic/gemm_small_matrix_permit.c b/kernel/generic/gemm_small_matrix_permit.c new file mode 100644 index 000000000..6e1ab1fc1 --- /dev/null +++ b/kernel/generic/gemm_small_matrix_permit.c @@ -0,0 +1,37 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. 
Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta) +{ + double MNK = (double) M * (double) N * (double) K; + if (MNK <= 100.0*100.0*100.0) + return 1; + else + return 0; +} diff --git a/kernel/generic/zgemm_small_matrix_permit.c b/kernel/generic/zgemm_small_matrix_permit.c new file mode 100644 index 000000000..288937256 --- /dev/null +++ b/kernel/generic/zgemm_small_matrix_permit.c @@ -0,0 +1,37 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#include "common.h" + +int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha0, FLOAT alpha1, FLOAT beta0, FLOAT beta1) +{ + double MNK = (double) M * (double) N * (double) K; + if (MNK <= 100.0*100.0*100.0) + return 1; + else + return 0; +} From 02c6e764f2e94779ae5699ca2ea8c2189aa9fa02 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 27 May 2021 11:26:49 +0000 Subject: [PATCH 21/37] Small Matrix: skylakex: add SGEMM_SMALL_M_PERMIT and tune for TN kernel --- kernel/x86_64/KERNEL.SKYLAKEX | 1 + .../sgemm_small_kernel_permit_skylakex.c | 50 +++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 kernel/x86_64/sgemm_small_kernel_permit_skylakex.c diff --git a/kernel/x86_64/KERNEL.SKYLAKEX b/kernel/x86_64/KERNEL.SKYLAKEX index 5e0d9e5b4..264e3a9f4 100644 --- a/kernel/x86_64/KERNEL.SKYLAKEX +++ b/kernel/x86_64/KERNEL.SKYLAKEX @@ -10,6 +10,7 @@ STRSMKERNEL_LN = ../generic/trsm_kernel_LN.c STRSMKERNEL_LT = ../generic/trsm_kernel_LT.c STRSMKERNEL_RN = ../generic/trsm_kernel_RN.c STRSMKERNEL_RT = ../generic/trsm_kernel_RT.c +SGEMM_SMALL_M_PERMIT = sgemm_small_kernel_permit_skylakex.c SGEMM_SMALL_K_NN = sgemm_small_kernel_nn_skylakex.c SGEMM_SMALL_K_B0_NN = sgemm_small_kernel_b0_nn_skylakex.c SGEMM_SMALL_K_NT = sgemm_small_kernel_nt_skylakex.c diff --git a/kernel/x86_64/sgemm_small_kernel_permit_skylakex.c b/kernel/x86_64/sgemm_small_kernel_permit_skylakex.c new file mode 100644 index 000000000..159ae10b5 --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_permit_skylakex.c @@ -0,0 +1,50 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#include "common.h" + +int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta) +{ + double MNK = (double) M * (double) N * (double) K; + if (MNK > 100.0*100.0*100.0) // disable for big size matrix + return 0; + // tuning for A transpose + if (transa) { + if (transb) { + return 0; // TT kernel not support yet + } else { // TN kernel + /* TN kernel perform not good when: + * 1. C matrix is too big + * 2. K is too small + */ + if (M * N > 1200 || K < 32) + return 0; + } + } + + return 1; +} From 72e070539cd13364c8a02ac34e3dfcd65b657c7a Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Mon, 31 May 2021 14:53:03 +0000 Subject: [PATCH 22/37] Small Matrix: skylakex: add sgemm tt kernel --- kernel/x86_64/KERNEL.SKYLAKEX | 2 + .../sgemm_small_kernel_b0_tt_skylakex.c | 3 + .../sgemm_small_kernel_permit_skylakex.c | 7 +- .../x86_64/sgemm_small_kernel_tt_skylakex.c | 414 ++++++++++++++++++ 4 files changed, 424 insertions(+), 2 deletions(-) create mode 100644 kernel/x86_64/sgemm_small_kernel_b0_tt_skylakex.c create mode 100644 kernel/x86_64/sgemm_small_kernel_tt_skylakex.c diff --git a/kernel/x86_64/KERNEL.SKYLAKEX b/kernel/x86_64/KERNEL.SKYLAKEX index 264e3a9f4..0f58a4d46 100644 --- a/kernel/x86_64/KERNEL.SKYLAKEX +++ b/kernel/x86_64/KERNEL.SKYLAKEX @@ -17,6 +17,8 @@ SGEMM_SMALL_K_NT = sgemm_small_kernel_nt_skylakex.c SGEMM_SMALL_K_B0_NT = sgemm_small_kernel_b0_nt_skylakex.c SGEMM_SMALL_K_TN = sgemm_small_kernel_tn_skylakex.c SGEMM_SMALL_K_B0_TN = sgemm_small_kernel_b0_tn_skylakex.c +SGEMM_SMALL_K_TT = sgemm_small_kernel_tt_skylakex.c +SGEMM_SMALL_K_B0_TT = sgemm_small_kernel_b0_tt_skylakex.c DGEMMKERNEL = dgemm_kernel_16x2_skylakex.c DTRMMKERNEL = dgemm_kernel_16x2_skylakex.c diff --git a/kernel/x86_64/sgemm_small_kernel_b0_tt_skylakex.c b/kernel/x86_64/sgemm_small_kernel_b0_tt_skylakex.c new file mode 100644 index 000000000..27d9e0afd --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_b0_tt_skylakex.c @@ -0,0 +1,3 @@ +#define B0 1 +#define TT 1 +#include "./sgemm_small_kernel_tt_skylakex.c" diff --git a/kernel/x86_64/sgemm_small_kernel_permit_skylakex.c b/kernel/x86_64/sgemm_small_kernel_permit_skylakex.c index 159ae10b5..cbf2374bd 100644 --- a/kernel/x86_64/sgemm_small_kernel_permit_skylakex.c +++ b/kernel/x86_64/sgemm_small_kernel_permit_skylakex.c @@ -35,8 +35,11 @@ int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alph // tuning for A transpose if (transa) { if (transb) { - return 0; // TT kernel not support yet - } else { // TN kernel + /* TT kernel perform not good when: + * 1. K is too small. + */ + if (K < 4) return 0; + } else { /* TN kernel perform not good when: * 1. C matrix is too big * 2. K is too small diff --git a/kernel/x86_64/sgemm_small_kernel_tt_skylakex.c b/kernel/x86_64/sgemm_small_kernel_tt_skylakex.c new file mode 100644 index 000000000..8da560ef7 --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_tt_skylakex.c @@ -0,0 +1,414 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. 
Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include +#include "common.h" +#include + +#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps() +#define BROADCAST_LOAD_A_512(M, N) __m512 Aval##M = _mm512_broadcastss_ps(_mm_load_ss(&A[k + lda * (i+M)])) +#define LOAD_B_512(M,N) __m512 Bval##N = _mm512_loadu_ps(&B[ldb * k + j + (N*16)]) +#define MASK_LOAD_B_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[ldb * k + j + (N*16)]) +#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) + +#if defined(B0) +#define STORE_8xy(v, N, x, y) _mm256_storeu_ps(&C[(j + N*16 + x + y*8)*ldc + i], v) +#define STORE_4xy(v, N, x, y) _mm_mask_storeu_ps(&C[(j + N*16 + x + y*4)*ldc + i], mask8, v) +#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + _mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4); +#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + _mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4); +#else +#define STORE_8xy(v, N, x, y) \ + asm("vfmadd231ps (%1), %2, %0": "+v"(v): "r"(&C[(j + N*16 + x + y*8)*ldc + i]), "v"(beta_256)); \ + _mm256_storeu_ps(&C[(j + N*16 + x + y*8)*ldc + i], v) +#define STORE_4xy(v, N, x, y) \ + asm("vfmadd231ps (%1), %2, %0": "+v"(v): "r"(&C[(j + N*16 + x + y*4)*ldc + i]), "v"(beta_128)); \ + _mm_mask_storeu_ps(&C[(j + N*16 + x + y*4)*ldc + i], mask8, v) +#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + __m512 tmp##M##N = _mm512_i32gather_ps(vindex_n, &C[(j + N*16)*ldc + i + M], 4); \ + result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \ + _mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4); +#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + __m512 tmp##M##N = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, vindex_n, &C[(j + N*16)*ldc + i + M], 4); \ + result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \ + _mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4); +#endif + +#define REORDER_8x16(r0, r1, r2, r3, r4, r5, r6, r7) \ + __m512 t0, t1, t2, t3, t4, t5, t6, t7, v; \ + t0 = _mm512_unpacklo_ps(r0, r1); \ + t1 = 
_mm512_unpackhi_ps(r0, r1); \ + t2 = _mm512_unpacklo_ps(r2, r3); \ + t3 = _mm512_unpackhi_ps(r2, r3); \ + t4 = _mm512_unpacklo_ps(r4, r5); \ + t5 = _mm512_unpackhi_ps(r4, r5); \ + t6 = _mm512_unpacklo_ps(r6, r7); \ + t7 = _mm512_unpackhi_ps(r6, r7); \ + v = _mm512_shuffle_ps(t0, t2, 0x4E); \ + r0 = _mm512_mask_blend_ps(kc, t0, v); \ + r1 = _mm512_mask_blend_ps(k3, t2, v); \ + v = _mm512_shuffle_ps(t1, t3, 0x4E); \ + r2 = _mm512_mask_blend_ps(kc, t1, v); \ + r3 = _mm512_mask_blend_ps(k3, t3, v); \ + v = _mm512_shuffle_ps(t4, t6, 0x4E); \ + r4 = _mm512_mask_blend_ps(kc, t4, v); \ + r5 = _mm512_mask_blend_ps(k3, t6, v); \ + v = _mm512_shuffle_ps(t5, t7, 0x4E); \ + r6 = _mm512_mask_blend_ps(kc, t5, v); \ + r7 = _mm512_mask_blend_ps(k3, t7, v); \ + t0 = _mm512_permutex2var_ps(r0, idx_lo, r4); \ + t1 = _mm512_permutex2var_ps(r1, idx_lo, r5); \ + t2 = _mm512_permutex2var_ps(r2, idx_lo, r6); \ + t3 = _mm512_permutex2var_ps(r3, idx_lo, r7); \ + t4 = _mm512_permutex2var_ps(r0, idx_hi, r4); \ + t5 = _mm512_permutex2var_ps(r1, idx_hi, r5); \ + t6 = _mm512_permutex2var_ps(r2, idx_hi, r6); \ + t7 = _mm512_permutex2var_ps(r3, idx_hi, r7); \ + t0 = _mm512_mul_ps(t0, alpha_512); \ + t1 = _mm512_mul_ps(t1, alpha_512); \ + t2 = _mm512_mul_ps(t2, alpha_512); \ + t3 = _mm512_mul_ps(t3, alpha_512); \ + t4 = _mm512_mul_ps(t4, alpha_512); \ + t5 = _mm512_mul_ps(t5, alpha_512); \ + t6 = _mm512_mul_ps(t6, alpha_512); \ + t7 = _mm512_mul_ps(t7, alpha_512); + +#define SAVE_8(N, x, y) {\ + __m256 v8 = _mm512_extractf32x8_ps(t##x, y); \ + STORE_8xy(v8, N, x, y); \ +} + +#define REORDER_STORE_8x16(N) {\ + REORDER_8x16(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \ + SAVE_8(N, 0, 0); SAVE_8(N, 1, 0); SAVE_8(N, 2, 0); SAVE_8(N, 3, 0); SAVE_8(N, 4, 0); SAVE_8(N, 5, 0); SAVE_8(N, 6, 0); SAVE_8(N, 7, 0); \ + SAVE_8(N, 0, 1); SAVE_8(N, 1, 1); SAVE_8(N, 2, 1); SAVE_8(N, 3, 1); SAVE_8(N, 4, 1); SAVE_8(N, 5, 1); SAVE_8(N, 6, 1); SAVE_8(N, 7, 1); \ +} + +#define MASK_SAVE_8() \ + switch (nn) { \ + case 16: SAVE_8(0, 7, 1); \ + case 15: SAVE_8(0, 6, 1); \ + case 14: SAVE_8(0, 5, 1); \ + case 13: SAVE_8(0, 4, 1); \ + case 12: SAVE_8(0, 3, 1); \ + case 11: SAVE_8(0, 2, 1); \ + case 10: SAVE_8(0, 1, 1); \ + case 9: SAVE_8(0, 0, 1); \ + case 8: SAVE_8(0, 7, 0); \ + case 7: SAVE_8(0, 6, 0); \ + case 6: SAVE_8(0, 5, 0); \ + case 5: SAVE_8(0, 4, 0); \ + case 4: SAVE_8(0, 3, 0); \ + case 3: SAVE_8(0, 2, 0); \ + case 2: SAVE_8(0, 1, 0); \ + case 1: SAVE_8(0, 0, 0); \ + } + +#define MASK_REORDER_STORE_8x16(N) {\ + REORDER_8x16(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \ + MASK_SAVE_8(); \ +} + +#define REORDER_4x16(r0, r1, r2, r3) \ + __m512 t0, t1, t2, t3, v; \ + t0 = _mm512_unpacklo_ps(r0, r1); \ + t1 = _mm512_unpackhi_ps(r0, r1); \ + t2 = _mm512_unpacklo_ps(r2, r3); \ + t3 = _mm512_unpackhi_ps(r2, r3); \ + v = _mm512_shuffle_ps(t0, t2, 0x4E); \ + r0 = _mm512_mask_blend_ps(kc, t0, v); \ + r1 = _mm512_mask_blend_ps(k3, t2, v); \ + v = _mm512_shuffle_ps(t1, t3, 0x4E); \ + r2 = _mm512_mask_blend_ps(kc, t1, v); \ + r3 = _mm512_mask_blend_ps(k3, t3, v); \ + t0 = _mm512_mul_ps(r0, alpha_512); \ + t1 = _mm512_mul_ps(r1, alpha_512); \ + t2 = _mm512_mul_ps(r2, alpha_512); \ + t3 = _mm512_mul_ps(r3, alpha_512); + +#define SAVE_4(N, x, y) {\ + __m128 v4 = _mm512_extractf32x4_ps(t##x, y); \ + STORE_4xy(v4, N, x, y); \ +} + +#define REORDER_STORE_4x16(N) {\ + REORDER_4x16(result0##N, result1##N, result2##N, result3##N); \ + SAVE_4(N, 0, 0); 
SAVE_4(N, 1, 0); SAVE_4(N, 2, 0); SAVE_4(N, 3, 0); \ + SAVE_4(N, 0, 1); SAVE_4(N, 1, 1); SAVE_4(N, 2, 1); SAVE_4(N, 3, 1); \ + SAVE_4(N, 0, 2); SAVE_4(N, 1, 2); SAVE_4(N, 2, 2); SAVE_4(N, 3, 2); \ + SAVE_4(N, 0, 3); SAVE_4(N, 1, 3); SAVE_4(N, 2, 3); SAVE_4(N, 3, 3); \ +} + +#define MASK_SAVE_4() \ + switch (nn) { \ + case 16: SAVE_4(0, 3, 3); \ + case 15: SAVE_4(0, 2, 3); \ + case 14: SAVE_4(0, 1, 3); \ + case 13: SAVE_4(0, 0, 3); \ + case 12: SAVE_4(0, 3, 2); \ + case 11: SAVE_4(0, 2, 2); \ + case 10: SAVE_4(0, 1, 2); \ + case 9: SAVE_4(0, 0, 2); \ + case 8: SAVE_4(0, 3, 1); \ + case 7: SAVE_4(0, 2, 1); \ + case 6: SAVE_4(0, 1, 1); \ + case 5: SAVE_4(0, 0, 1); \ + case 4: SAVE_4(0, 3, 0); \ + case 3: SAVE_4(0, 2, 0); \ + case 2: SAVE_4(0, 1, 0); \ + case 1: SAVE_4(0, 0, 0); \ + } + +#define MASK_REORDER_STORE_4x16(N) {\ + REORDER_4x16(result0##N, result1##N, result2##N, result3##N); \ + MASK_SAVE_4(); \ +} + + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m8 = M & ~7; + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n64 = N & ~63; + BLASLONG n32 = N & ~31; + + __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); +#if !defined(B0) + __m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta)); + __m128 beta_128 = _mm_broadcastss_ps(_mm_load_ss(&beta)); +#endif + int permute_table[] = { + 0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b, + 0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f, + }; + __m512i idx_lo = _mm512_loadu_epi32(permute_table); + __m512i idx_hi = _mm512_loadu_epi32(permute_table + 16); + __mmask16 kc = 0xcccc; + __mmask16 k3 = 0x3333; + __mmask8 mask8 = 0xff; // force use AVX128 instead of SSE + + for (i = 0; i < m8; i += 8) { + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0); + + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(4, 1); DECLARE_RESULT_512(5, 1); DECLARE_RESULT_512(6, 1); DECLARE_RESULT_512(7, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(4, 1); MATMUL_512(5, 1); MATMUL_512(6, 1); MATMUL_512(7, 1); + } + REORDER_STORE_8x16(0); + REORDER_STORE_8x16(1); + } + __mmask16 mask = 0xffff; + int nn = 16; + for (; j < N; j += 16) { + if (N - j < 16) { + nn = N - j; + mask = (1UL << nn) - 1; + } + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0); + for (k 
= 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0); + } + MASK_REORDER_STORE_8x16(0); + } + } + for (; i < m4; i += 4) { + for (j = 0; j < n64; j += 64) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + REORDER_STORE_4x16(0); + REORDER_STORE_4x16(1); + REORDER_STORE_4x16(2); + REORDER_STORE_4x16(3); + } + for (; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + REORDER_STORE_4x16(0); + REORDER_STORE_4x16(1); + } + __mmask16 mask = 0xffff; + int nn = 16; + for (; j < N; j += 16) { + if (N - j < 16) { + nn = N - j; + mask = (1UL << nn) - 1; + } + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + MASK_REORDER_STORE_4x16(0); + } + } + if (i < M) { + int index_n[16]; + for (int ii = 0; ii < 16; ii++) { + index_n[ii] = ii * ldc; + } + __m512i vindex_n = _mm512_loadu_epi32(index_n); +#if !defined(B0) + __m512 beta_512 = _mm512_broadcastss_ps(_mm_load_ss(&beta)); +#endif + for (; i < m2; i += 2) { + for (j = 0; j < n64; j += 64) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + 
SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); + SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); + } + for (; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + } + __mmask16 mask = 0xffff; + int nn = 16; + for (; j < N; j += 16) { + if (N - j < 16) { + nn = N - j; + mask = (1UL << nn) - 1; + } + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); + } + } + for (; i < M; i += 1) { + for (j = 0; j < n64; j += 64) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + SCATTER_STORE_512(0, 2); + SCATTER_STORE_512(0, 3); + } + for (; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + } + __mmask16 mask = 0xffff; + int nn = 16; + for (; j < N; j += 16) { + if (N - j < 16) { + nn = N - j; + mask = (1UL << nn) - 1; + } + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_SCATTER_STORE_512(0, 0); + } + } + } + return 0; +} From 91ec21202bd8ae81f15dae79e004b2f00d20e559 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Tue, 1 Jun 2021 11:31:50 +0000 Subject: [PATCH 23/37] Small Matrix: skylakex: add dgemm nn kernel --- kernel/x86_64/KERNEL.SKYLAKEX | 2 + .../dgemm_small_kernel_b0_nn_skylakex.c | 2 + .../x86_64/dgemm_small_kernel_nn_skylakex.c | 590 ++++++++++++++++++ 3 files changed, 594 insertions(+) create mode 100644 kernel/x86_64/dgemm_small_kernel_b0_nn_skylakex.c create mode 100644 kernel/x86_64/dgemm_small_kernel_nn_skylakex.c diff --git a/kernel/x86_64/KERNEL.SKYLAKEX b/kernel/x86_64/KERNEL.SKYLAKEX index 0f58a4d46..a3c6f0556 100644 --- a/kernel/x86_64/KERNEL.SKYLAKEX +++ b/kernel/x86_64/KERNEL.SKYLAKEX @@ -27,6 +27,8 @@ DGEMMITCOPY = dgemm_tcopy_16_skylakex.c DGEMMONCOPY = ../generic/gemm_ncopy_2.c DGEMMOTCOPY = ../generic/gemm_tcopy_2.c DTRSMKERNEL_RN = ../generic/trsm_kernel_RN.c +DGEMM_SMALL_K_NN = dgemm_small_kernel_nn_skylakex.c +DGEMM_SMALL_K_B0_NN = dgemm_small_kernel_b0_nn_skylakex.c SGEMM_BETA = sgemm_beta_skylakex.c DGEMM_BETA = dgemm_beta_skylakex.c diff --git a/kernel/x86_64/dgemm_small_kernel_b0_nn_skylakex.c b/kernel/x86_64/dgemm_small_kernel_b0_nn_skylakex.c new file mode 100644 index 000000000..a58738a25 --- /dev/null +++ b/kernel/x86_64/dgemm_small_kernel_b0_nn_skylakex.c @@ -0,0 +1,2 @@ +#define B0 1 +#include "./dgemm_small_kernel_nn_skylakex.c" 
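The two-line b0 wrapper above is the pattern used for every beta == 0 variant in this series: the shared kernel source is compiled a second time with B0 defined, which drops beta from the signature and lets the stores skip reading C. A minimal sketch of the mechanism, with illustrative names (scale_impl.c and scale/scale_b0 are not files or symbols from this patch):

/* scale_impl.c -- included twice, the second time with -DB0 */
#if defined(B0)
int scale_b0(long n, double alpha, const double *a, double *c)
#else
int scale(long n, double alpha, const double *a, double beta, double *c)
#endif
{
    for (long i = 0; i < n; i++) {
#if defined(B0)
        c[i] = alpha * a[i];                /* beta == 0: C is write-only */
#else
        c[i] = alpha * a[i] + beta * c[i];  /* general case reads C */
#endif
    }
    return 0;
}

A two-line wrapper, exactly as in dgemm_small_kernel_b0_nn_skylakex.c above, then selects the specialization: define B0 and include the shared source, and the preprocessor emits the beta-free entry point.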
diff --git a/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c new file mode 100644 index 000000000..8ffb899c8 --- /dev/null +++ b/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c @@ -0,0 +1,590 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#include +#include "common.h" +#include +#include + +#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd() +#define LOAD_A_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&A[lda * k + i + (M*8)]) +#define MASK_LOAD_A_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &A[lda * k + i + (M*8)]) +#define BROADCAST_LOAD_B_512(M, N) __m512d Bval##N = _mm512_broadcastsd_pd(_mm_load_pd1(&B[k + ldb * (j+N)])) +#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N) +#if defined(B0) +#define STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + _mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N) +#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + _mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N) +#else +#define STORE_512(M, N) \ + result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + asm("vfmadd231pd (%1), %2, %0": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512)); \ + _mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N) +#define MASK_STORE_512(M, N) \ + result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \ + _mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N) +#endif + +#define LOAD_KA_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&mbuf[(mi + M)*K + k]); +#define LOAD_KB_512(M, N) __m512d Bval##N = _mm512_loadu_pd(&B[(j + N)*ldb + k]) +#define MASK_LOAD_KA_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &mbuf[(mi + M)*K + k]) +#define MASK_LOAD_KB_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[(j + N)*ldb + k]) +#define REDUCE_4(rr0, rr1, rr2, rr3) \ + __m512d r0, r1, r2, r3, t0, t1, t2, t3;\ + r0 = _mm512_unpacklo_pd(rr0, rr1); r1 = _mm512_unpackhi_pd(rr0, rr1); \ + r2 = _mm512_unpacklo_pd(rr2, rr3); r3 = _mm512_unpackhi_pd(rr2, rr3); \ + t0 = _mm512_permutex2var_pd(r0, idx_lo, r2); t1 = _mm512_permutex2var_pd(r1, idx_lo, r3); \ + t2 = _mm512_permutex2var_pd(r0, idx_hi, r2); t3 = _mm512_permutex2var_pd(r1, idx_hi, r3); \ + r0 = _mm512_add_pd(t0, t1); r1 = _mm512_add_pd(t2, t3); t0 = _mm512_add_pd(r0, r1); \ + __m256d s0, s1; \ + s0 = _mm512_extractf64x4_pd(t0, 0); s1 = _mm512_extractf64x4_pd(t0, 1); \ + s0 = _mm256_add_pd(s0, s1); s0 = _mm256_mul_pd(alpha_256, s0); +#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N) +#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3) +#if defined(B0) +#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N); +#define STORE_REDUCE_M4(N) {\ + REDUCE_M4(N) \ + _mm256_storeu_pd(&C[(j + N)*ldc + i], s0); \ +} +#define STORE_REDUCE_N4(M) {\ + REDUCE_N4(M) \ + _mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8); \ +} +#else +#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N) + beta * C[(j+N)*ldc + i + M]; +#define STORE_REDUCE_M4(N) {\ + REDUCE_M4(N) \ + asm("vfmadd231pd (%1), %2, %0": "+v"(s0):"r"(&C[(j + N)*ldc + i]), "v"(beta_256)); \ + _mm256_storeu_pd(&C[(j + N)*ldc + i], s0); \ +} +#define STORE_REDUCE_N4(M) {\ + REDUCE_N4(M) \ + s1 = _mm256_i64gather_pd(&C[j*ldc + i + M], vindex_n, 8); \ + s0 = _mm256_fmadd_pd(s1, beta_256, s0); \ + _mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8); \ +} +#endif + +#if defined(B0) 
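+// B0 build: beta is known to be zero at compile time, so this prototype omits beta and the stores write alpha*(A*B) without first loading C.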
+int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m32 = M & ~31; + BLASLONG m16 = M & ~15; + BLASLONG m8 = M & ~7; + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n6 = N - (N % 6); + BLASLONG n4 = N & ~3; + BLASLONG n2 = N & ~1; + + + __m512d alpha_512 = _mm512_broadcastsd_pd(_mm_load_pd1(&alpha)); +#if !defined(B0) + __m512d beta_512 = _mm512_broadcastsd_pd(_mm_load_pd1(&beta)); +#endif + + for (i = 0; i < m32; i += 32) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); + STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + } + } + for (; i < m16; i += 16) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); + DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); 
LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + MATMUL_512(0, 4); MATMUL_512(1, 4); + MATMUL_512(0, 5); MATMUL_512(1, 5); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + STORE_512(0, 4); STORE_512(1, 4); + STORE_512(0, 5); STORE_512(1, 5); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_512(0, 0); STORE_512(1, 0); + } + } + for (; i < m8; i += 8) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + STORE_512(0, 4); + STORE_512(0, 5); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_512(0, 0); + STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + STORE_512(0, 0); + } + } + int mm = M - i; + if (!mm) return 0; + if (mm > 4 || K < 16) { + register __mmask8 mask asm("k1") = (1UL << mm) - 1; + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + MASK_STORE_512(0, 4); + MASK_STORE_512(0, 5); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + MASK_STORE_512(0, 0); + 
MASK_STORE_512(0, 1);
+                }
+                for (; j < N; j++) {
+                        DECLARE_RESULT_512(0, 0);
+                        for (k = 0; k < K; k++) {
+                                MASK_LOAD_A_512(0, x);
+                                BROADCAST_LOAD_B_512(x, 0);
+                                MATMUL_512(0, 0);
+                        }
+                        MASK_STORE_512(0, 0);
+                }
+        } else {
+                /* M => [1, 4]
+                 *
+                 * This kernel uses a dot-product style to compute each value C(x, y):
+                 * C(x, y) = A(x, 0)*B(0, y) + A(x, 1)*B(1, y) + ... + A(x, K-1)*B(K-1, y)
+                 *
+                 * Allocate a buffer and copy the remaining rows of A into it in
+                 * row-major order, so that memory access from 0 to K is contiguous
+                 * for both A and B.
+                 *
+                 * Each iteration loads 8 values of k into a zmm register and FMAs
+                 * them; a final reduce-add collapses the zmm into the single scalar
+                 * result for C(x, y).
+                 *
+                 * Note: performance is poor when K is small.
+                 */
+                FLOAT *mbuf = (FLOAT *) malloc(sizeof(FLOAT)*mm*K);
+                __mmask8 mask = (1UL << mm) - 1;
+                BLASLONG k8 = K & ~7;
+                BLASLONG k4 = K & ~3;
+                for (k = 0; k < k4; k += 4) {
+                        __m256d r0, r1, r2, r3;
+                        __m256d t0, t1, t2, t3;
+                        r0 = _mm256_maskz_loadu_pd(mask, &A[i + lda*(0 + k)]);
+                        r1 = _mm256_maskz_loadu_pd(mask, &A[i + lda*(1 + k)]);
+                        r2 = _mm256_maskz_loadu_pd(mask, &A[i + lda*(2 + k)]);
+                        r3 = _mm256_maskz_loadu_pd(mask, &A[i + lda*(3 + k)]);
+
+                        t0 = _mm256_unpacklo_pd(r0, r1);
+                        t1 = _mm256_unpackhi_pd(r0, r1);
+                        t2 = _mm256_unpacklo_pd(r2, r3);
+                        t3 = _mm256_unpackhi_pd(r2, r3);
+
+                        r0 = _mm256_permute2f128_pd(t0, t2, 0x20);
+                        r1 = _mm256_permute2f128_pd(t1, t3, 0x20);
+                        r2 = _mm256_permute2f128_pd(t0, t2, 0x31);
+                        r3 = _mm256_permute2f128_pd(t1, t3, 0x31);
+
+                        switch (mm) {
+                                case 4: _mm256_storeu_pd(&mbuf[k + 3*K], r3);
+                                case 3: _mm256_storeu_pd(&mbuf[k + 2*K], r2);
+                                case 2: _mm256_storeu_pd(&mbuf[k + 1*K], r1);
+                                case 1: _mm256_storeu_pd(&mbuf[k + 0*K], r0);
+                        }
+                }
+                for (; k < K; k++) {
+                        for (int ii = 0; ii < mm; ii++) {
+                                mbuf[k + ii*K] = A[i + lda*k + ii];
+                        }
+                }
+                int mi = 0;
+                __m256d alpha_256 = _mm256_broadcast_sd(&alpha);
+#if !defined(B0)
+                __m256d beta_256 = _mm256_broadcast_sd(&beta);
+#endif
+                __m256i vindex_n = _mm256_set_epi64x(ldc*3, ldc*2, ldc*1, 0);
+                long long permute_table[] = {
+                        0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8,
+                        2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8,
+                };
+                __m512i idx_lo = _mm512_loadu_epi32(permute_table);
+                __m512i idx_hi = _mm512_loadu_epi32(permute_table + 8);
+                for (; i < m4; i += 4, mi += 4) {
+                        for (j = 0; j < n4; j += 4) {
+                                DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
+                                DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1);
+                                DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2);
+                                DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3);
+                                for (k = 0; k < k8; k += 8) {
+                                        LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x);
+                                        LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3);
+
+                                        MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
+                                        MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
+                                        MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2);
+                                        MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3);
+                                }
+                                int remains = K - k;
+                                if (remains) {
+                                        mask = (1UL << remains) - 1;
+                                        MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x);
+                                        MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3);
+
+                                        MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
+                                        MATMUL_512(0, 1); MATMUL_512(1, 1);
MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_REDUCE_M4(0); + } + + } + for (; i < m2; i += 2, mi += 2) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_REDUCE_N4(0); STORE_REDUCE_N4(1); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + STORE_REDUCE(0, 1); 
STORE_REDUCE(1, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + } + } + for (; i < M; i += 1, mi += 1) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_REDUCE_N4(0); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_REDUCE(0, 0); + STORE_REDUCE(0, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + STORE_REDUCE(0, 0); + } + } + free(mbuf); + } + return 0; +} From f57fc932ac39c394e8f89bf7b6df3f1bddd315fd Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Tue, 1 Jun 2021 14:23:56 +0000 Subject: [PATCH 24/37] Small Matrix: skylakex: add dgemm nt kernel --- kernel/x86_64/KERNEL.SKYLAKEX | 2 + .../dgemm_small_kernel_b0_nt_skylakex.c | 2 + .../x86_64/dgemm_small_kernel_nt_skylakex.c | 535 ++++++++++++++++++ 3 files changed, 539 insertions(+) create mode 100644 kernel/x86_64/dgemm_small_kernel_b0_nt_skylakex.c create mode 100644 kernel/x86_64/dgemm_small_kernel_nt_skylakex.c diff --git a/kernel/x86_64/KERNEL.SKYLAKEX b/kernel/x86_64/KERNEL.SKYLAKEX index a3c6f0556..db1e6cbff 100644 --- a/kernel/x86_64/KERNEL.SKYLAKEX +++ b/kernel/x86_64/KERNEL.SKYLAKEX @@ -29,6 +29,8 @@ DGEMMOTCOPY = ../generic/gemm_tcopy_2.c DTRSMKERNEL_RN = ../generic/trsm_kernel_RN.c DGEMM_SMALL_K_NN = dgemm_small_kernel_nn_skylakex.c DGEMM_SMALL_K_B0_NN = dgemm_small_kernel_b0_nn_skylakex.c +DGEMM_SMALL_K_NT = dgemm_small_kernel_nt_skylakex.c +DGEMM_SMALL_K_B0_NT = dgemm_small_kernel_b0_nt_skylakex.c SGEMM_BETA = sgemm_beta_skylakex.c DGEMM_BETA = dgemm_beta_skylakex.c diff --git a/kernel/x86_64/dgemm_small_kernel_b0_nt_skylakex.c b/kernel/x86_64/dgemm_small_kernel_b0_nt_skylakex.c new file mode 100644 index 000000000..eafe2ce49 --- /dev/null +++ b/kernel/x86_64/dgemm_small_kernel_b0_nt_skylakex.c @@ -0,0 +1,2 @@ +#define B0 1 +#include "./dgemm_small_kernel_nt_skylakex.c" diff --git a/kernel/x86_64/dgemm_small_kernel_nt_skylakex.c b/kernel/x86_64/dgemm_small_kernel_nt_skylakex.c new file mode 100644 
index 000000000..0a95a68e2
--- /dev/null
+++ b/kernel/x86_64/dgemm_small_kernel_nt_skylakex.c
@@ -0,0 +1,535 @@
+/***************************************************************************
+Copyright (c) 2021, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include <immintrin.h>
+#include "common.h"
+#include <stdio.h>
+#include <memory.h>
+
+#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd()
+#define LOAD_A_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&A[lda * k + i + (M*8)])
+#define MASK_LOAD_A_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &A[lda * k + i + (M*8)])
+#define BROADCAST_LOAD_B_512(M, N) __m512d Bval##N = _mm512_broadcastsd_pd(_mm_load_sd(&B[ldb * k + j + N]))
+#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N)
+
+#define BROADCAST_LOAD_A_512(M, N) __m512d Aval##M = _mm512_broadcastsd_pd(_mm_load_sd(&A[lda * k + i + M]))
+#define LOAD_B_512(M, N) __m512d Bval##N = _mm512_loadu_pd(&B[ldb * k + j + (N*8)])
+#define MASK_LOAD_B_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[ldb * k + j + (N*8)])
+#if defined(B0)
+#define STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
+        _mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N)
+#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
+        _mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N)
+#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
+        _mm512_i64scatter_pd(&C[(j + N*8)*ldc + i + M], vindex_n, result##M##N, 8);
+#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
+        _mm512_mask_i64scatter_pd(&C[(j + N*8)*ldc + i + M], mask, vindex_n, result##M##N, 8)
+#else
+#define STORE_512(M, N) \
+        result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
+        asm("vfmadd231pd (%1), %2, %0": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512)); \
+        _mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N)
+#define MASK_STORE_512(M, N) \
+
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \ + _mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N) +#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + __m512d tmp##M##N = _mm512_i64gather_pd(vindex_n, &C[(j + N*8)*ldc + i + M], 8); \ + result##M##N = _mm512_fmadd_pd(tmp##M##N, beta_512, result##M##N); \ + _mm512_i64scatter_pd(&C[(j + N*8)*ldc + i + M], vindex_n, result##M##N, 8); +#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + __m512d tmp##M##N = _mm512_mask_i64gather_pd(_mm512_setzero_pd(), mask, vindex_n, &C[(j + N*8)*ldc + i + M], 8); \ + result##M##N = _mm512_fmadd_pd(tmp##M##N, beta_512, result##M##N); \ + _mm512_mask_i64scatter_pd(&C[(j + N*8)*ldc + i + M], mask, vindex_n, result##M##N, 8); +#endif + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m32 = M & ~31; + BLASLONG m16 = M & ~15; + BLASLONG m8 = M & ~7; + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n32 = N & ~31; + BLASLONG n16 = N & ~15; + BLASLONG n8 = N & ~7; + BLASLONG n6 = N - (N % 6); + BLASLONG n4 = N & ~3; + BLASLONG n2 = N & ~1; + + + __m512d alpha_512 = _mm512_broadcastsd_pd(_mm_load_sd(&alpha)); +#if !defined(B0) + __m512d beta_512 = _mm512_broadcastsd_pd(_mm_load_sd(&beta)); +#endif + + for (i = 0; i < m32; i += 32) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); DECLARE_RESULT_512(2, 4); DECLARE_RESULT_512(3, 4); + DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); DECLARE_RESULT_512(2, 5); DECLARE_RESULT_512(3, 5); + + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + MATMUL_512(0, 4); MATMUL_512(1, 4); MATMUL_512(2, 4); MATMUL_512(3, 4); + MATMUL_512(0, 5); MATMUL_512(1, 5); MATMUL_512(2, 5); MATMUL_512(3, 5); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); + STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); + STORE_512(0, 4); STORE_512(1, 4); STORE_512(2, 4); STORE_512(3, 4); + STORE_512(0, 5); 
STORE_512(1, 5); STORE_512(2, 5); STORE_512(3, 5); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + } + } + for (; i < m16; i += 16) { + for (j = 0; j < n8; j += 8) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); + DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); + DECLARE_RESULT_512(0, 6); DECLARE_RESULT_512(1, 6); + DECLARE_RESULT_512(0, 7); DECLARE_RESULT_512(1, 7); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + MATMUL_512(0, 4); MATMUL_512(1, 4); + MATMUL_512(0, 5); MATMUL_512(1, 5); + MATMUL_512(0, 6); MATMUL_512(1, 6); + MATMUL_512(0, 7); MATMUL_512(1, 7); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + STORE_512(0, 4); STORE_512(1, 4); + STORE_512(0, 5); STORE_512(1, 5); + STORE_512(0, 6); STORE_512(1, 6); + STORE_512(0, 7); STORE_512(1, 7); + } + for (;j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + 
MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_512(0, 0); STORE_512(1, 0); + } + } + for (; i < m8; i += 8) { + for (j = 0; j < n8; j += 8) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + DECLARE_RESULT_512(0, 6); + DECLARE_RESULT_512(0, 7); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + MATMUL_512(0, 6); + MATMUL_512(0, 7); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + STORE_512(0, 4); + STORE_512(0, 5); + STORE_512(0, 6); + STORE_512(0, 7); + } + for (; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + } + + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_512(0, 0); + STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + STORE_512(0, 0); + } + } + int mm = M - i; + if (mm >= 6) { + register __mmask16 mask asm("k1") = (1UL << mm) - 1; + for (j = 0; j < n8; j += 8) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + DECLARE_RESULT_512(0, 6); + DECLARE_RESULT_512(0, 7); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + MATMUL_512(0, 6); + MATMUL_512(0, 7); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + MASK_STORE_512(0, 4); + MASK_STORE_512(0, 5); + MASK_STORE_512(0, 6); + MASK_STORE_512(0, 7); + } + for (; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); + 
MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + } + + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_STORE_512(0, 0); + } + } else if (mm > 0) { + long long index_n[8]; + for (int ii = 0; ii < 8; ii++) { + index_n[ii] = ii * ldc; + } + __m512i vindex_n = _mm512_loadu_epi64(index_n); + for (; i < m4; i += 4) { + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + LOAD_B_512(x, 2); + LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); SCATTER_STORE_512(2, 0); SCATTER_STORE_512(3, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); SCATTER_STORE_512(2, 1); SCATTER_STORE_512(3, 1); + SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); SCATTER_STORE_512(2, 2); SCATTER_STORE_512(3, 2); + SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); SCATTER_STORE_512(2, 3); SCATTER_STORE_512(3, 3); + } + for (; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); SCATTER_STORE_512(2, 0); SCATTER_STORE_512(3, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); SCATTER_STORE_512(2, 1); SCATTER_STORE_512(3, 1); + } + __mmask8 mask = 0xff; + for (; j < N; j += 8) { + int remains = N - j; + if (remains < 8) mask = (1UL << remains) - 1; + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); 
MASK_SCATTER_STORE_512(2, 0); MASK_SCATTER_STORE_512(3, 0); + } + } + for (; i < m2; i += 2) { + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + LOAD_B_512(x, 2); + LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); + SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); + } + for (; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + } + __mmask8 mask = 0xff; + for (; j < N; j += 8) { + int remains = N - j; + if (remains < 8) mask = (1UL << remains) - 1; + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); + } + } + for (; i < M; i += 1) { + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + LOAD_B_512(x, 2); + LOAD_B_512(x, 3); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + SCATTER_STORE_512(0, 2); + SCATTER_STORE_512(0, 3); + } + for (; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + } + __mmask8 mask = 0xff; + for (; j < N; j += 8) { + int remains = N - j; + if (remains < 8) mask = (1UL << remains) - 1; + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_SCATTER_STORE_512(0, 0); + } + } + } + return 0; +} From 323d7da4f7c21b0a285af1527a47799c4adf69f4 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 2 Jun 2021 11:45:44 +0000 Subject: [PATCH 25/37] Small Matrix: skylakex: add dgemm tt kernel --- kernel/x86_64/KERNEL.SKYLAKEX | 2 + .../dgemm_small_kernel_b0_tt_skylakex.c | 2 + .../x86_64/dgemm_small_kernel_tt_skylakex.c | 392 ++++++++++++++++++ 3 files changed, 396 insertions(+) create mode 100644 kernel/x86_64/dgemm_small_kernel_b0_tt_skylakex.c create mode 100644 kernel/x86_64/dgemm_small_kernel_tt_skylakex.c diff --git a/kernel/x86_64/KERNEL.SKYLAKEX b/kernel/x86_64/KERNEL.SKYLAKEX index db1e6cbff..3e84e794e 100644 --- a/kernel/x86_64/KERNEL.SKYLAKEX +++ b/kernel/x86_64/KERNEL.SKYLAKEX @@ -31,6 +31,8 @@ 
DGEMM_SMALL_K_NN = dgemm_small_kernel_nn_skylakex.c DGEMM_SMALL_K_B0_NN = dgemm_small_kernel_b0_nn_skylakex.c DGEMM_SMALL_K_NT = dgemm_small_kernel_nt_skylakex.c DGEMM_SMALL_K_B0_NT = dgemm_small_kernel_b0_nt_skylakex.c +DGEMM_SMALL_K_TT = dgemm_small_kernel_tt_skylakex.c +DGEMM_SMALL_K_B0_TT = dgemm_small_kernel_b0_tt_skylakex.c SGEMM_BETA = sgemm_beta_skylakex.c DGEMM_BETA = dgemm_beta_skylakex.c diff --git a/kernel/x86_64/dgemm_small_kernel_b0_tt_skylakex.c b/kernel/x86_64/dgemm_small_kernel_b0_tt_skylakex.c new file mode 100644 index 000000000..93fab1836 --- /dev/null +++ b/kernel/x86_64/dgemm_small_kernel_b0_tt_skylakex.c @@ -0,0 +1,2 @@ +#define B0 1 +#include "./dgemm_small_kernel_tt_skylakex.c" diff --git a/kernel/x86_64/dgemm_small_kernel_tt_skylakex.c b/kernel/x86_64/dgemm_small_kernel_tt_skylakex.c new file mode 100644 index 000000000..8ff79d2c8 --- /dev/null +++ b/kernel/x86_64/dgemm_small_kernel_tt_skylakex.c @@ -0,0 +1,392 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/
+
+#include <immintrin.h>
+#include "common.h"
+#include <stdio.h>
+
+#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd()
+#define BROADCAST_LOAD_A_512(M, N) __m512d Aval##M = _mm512_broadcastsd_pd(_mm_load_sd(&A[k + lda * (i+M)]))
+#define LOAD_B_512(M,N) __m512d Bval##N = _mm512_loadu_pd(&B[ldb * k + j + (N*8)])
+#define MASK_LOAD_B_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[ldb * k + j + (N*8)])
+#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N)
+
+#if defined(B0)
+#define STORE_8xy(v, N, x, y) _mm512_storeu_pd(&C[(j + N*8 + x + y*8)*ldc + i], v)
+#define STORE_4xy(v, N, x, y) _mm256_storeu_pd(&C[(j + N*8 + x + y*4)*ldc + i], v)
+#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
+        _mm512_i64scatter_pd(&C[(j + N*8)*ldc + i + M], vindex_n, result##M##N, 8);
+#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
+        _mm512_mask_i64scatter_pd(&C[(j + N*8)*ldc + i + M], mask, vindex_n, result##M##N, 8);
+#else
+#define STORE_8xy(v, N, x, y) \
+        asm("vfmadd231pd (%1), %2, %0": "+v"(v): "r"(&C[(j + N*8 + x + y*8)*ldc + i]), "v"(beta_512)); \
+        _mm512_storeu_pd(&C[(j + N*8 + x + y*8)*ldc + i], v)
+#define STORE_4xy(v, N, x, y) \
+        asm("vfmadd231pd (%1), %2, %0": "+v"(v): "r"(&C[(j + N*8 + x + y*4)*ldc + i]), "v"(beta_256)); \
+        _mm256_storeu_pd(&C[(j + N*8 + x + y*4)*ldc + i], v)
+#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
+        __m512d tmp##M##N = _mm512_i64gather_pd(vindex_n, &C[(j + N*8)*ldc + i + M], 8); \
+        result##M##N = _mm512_fmadd_pd(tmp##M##N, beta_512, result##M##N); \
+        _mm512_i64scatter_pd(&C[(j + N*8)*ldc + i + M], vindex_n, result##M##N, 8);
+#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
+        __m512d tmp##M##N = _mm512_mask_i64gather_pd(_mm512_setzero_pd(), mask, vindex_n, &C[(j + N*8)*ldc + i + M], 8); \
+        result##M##N = _mm512_fmadd_pd(tmp##M##N, beta_512, result##M##N); \
+        _mm512_mask_i64scatter_pd(&C[(j + N*8)*ldc + i + M], mask, vindex_n, result##M##N, 8);
+#endif
+
+#define REORDER_8x8(r0, r1, r2, r3, r4, r5, r6, r7) \
+        __m512d t0, t1, t2, t3, t4, t5, t6, t7; \
+        t0 = _mm512_unpacklo_pd(r0, r1); \
+        t1 = _mm512_unpackhi_pd(r0, r1); \
+        t2 = _mm512_unpacklo_pd(r2, r3); \
+        t3 = _mm512_unpackhi_pd(r2, r3); \
+        t4 = _mm512_unpacklo_pd(r4, r5); \
+        t5 = _mm512_unpackhi_pd(r4, r5); \
+        t6 = _mm512_unpacklo_pd(r6, r7); \
+        t7 = _mm512_unpackhi_pd(r6, r7); \
+        r0 = _mm512_shuffle_f64x2(t0, t2, 0x88); \
+        r1 = _mm512_shuffle_f64x2(t1, t3, 0x88); \
+        r2 = _mm512_shuffle_f64x2(t0, t2, 0xdd); \
+        r3 = _mm512_shuffle_f64x2(t1, t3, 0xdd); \
+        r4 = _mm512_shuffle_f64x2(t4, t6, 0x88); \
+        r5 = _mm512_shuffle_f64x2(t5, t7, 0x88); \
+        r6 = _mm512_shuffle_f64x2(t4, t6, 0xdd); \
+        r7 = _mm512_shuffle_f64x2(t5, t7, 0xdd); \
+        t0 = _mm512_permutex2var_pd(r0, idx_lo, r4); \
+        t1 = _mm512_permutex2var_pd(r1, idx_lo, r5); \
+        t2 = _mm512_permutex2var_pd(r2, idx_lo, r6); \
+        t3 = _mm512_permutex2var_pd(r3, idx_lo, r7); \
+        t4 = _mm512_permutex2var_pd(r0, idx_hi, r4); \
+        t5 = _mm512_permutex2var_pd(r1, idx_hi, r5); \
+        t6 = _mm512_permutex2var_pd(r2, idx_hi, r6); \
+        t7 = _mm512_permutex2var_pd(r3, idx_hi, r7); \
+        t0 = _mm512_mul_pd(t0, alpha_512); \
+        t1 = _mm512_mul_pd(t1, alpha_512); \
+        t2 = _mm512_mul_pd(t2, alpha_512); \
+        t3 = _mm512_mul_pd(t3, alpha_512); \
+        t4 = _mm512_mul_pd(t4, alpha_512); \
+ t5 = _mm512_mul_pd(t5, alpha_512); \ + t6 = _mm512_mul_pd(t6, alpha_512); \ + t7 = _mm512_mul_pd(t7, alpha_512); + +#define SAVE_8(N, x) {\ + STORE_8xy(t##x, N, x, 0); \ +} + +#define REORDER_STORE_8x8(N) {\ + REORDER_8x8(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \ + SAVE_8(N, 0); SAVE_8(N, 1); SAVE_8(N, 2); SAVE_8(N, 3); SAVE_8(N, 4); SAVE_8(N, 5); SAVE_8(N, 6); SAVE_8(N, 7); \ +} + +#define MASK_SAVE_8() \ + switch (nn) { \ + case 8: SAVE_8(0, 7); \ + case 7: SAVE_8(0, 6); \ + case 6: SAVE_8(0, 5); \ + case 5: SAVE_8(0, 4); \ + case 4: SAVE_8(0, 3); \ + case 3: SAVE_8(0, 2); \ + case 2: SAVE_8(0, 1); \ + case 1: SAVE_8(0, 0); \ + } + +#define MASK_REORDER_STORE_8x8(N) {\ + REORDER_8x8(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \ + MASK_SAVE_8(); \ +} + +#define REORDER_4x8(r0, r1, r2, r3) \ + __m512d t0, t1, t2, t3; \ + t0 = _mm512_unpacklo_pd(r0, r1); \ + t1 = _mm512_unpackhi_pd(r0, r1); \ + t2 = _mm512_unpacklo_pd(r2, r3); \ + t3 = _mm512_unpackhi_pd(r2, r3); \ + r0 = _mm512_permutex2var_pd(t0, idx_lo, t2); \ + r1 = _mm512_permutex2var_pd(t1, idx_lo, t3); \ + r2 = _mm512_permutex2var_pd(t0, idx_hi, t2); \ + r3 = _mm512_permutex2var_pd(t1, idx_hi, t3); \ + t0 = _mm512_mul_pd(r0, alpha_512); \ + t1 = _mm512_mul_pd(r1, alpha_512); \ + t2 = _mm512_mul_pd(r2, alpha_512); \ + t3 = _mm512_mul_pd(r3, alpha_512); + +#define SAVE_4(N, x, y) {\ + __m256d v4 = _mm512_extractf64x4_pd(t##x, y); \ + STORE_4xy(v4, N, x, y); \ +} + +#define REORDER_STORE_4x8(N) {\ + REORDER_4x8(result0##N, result1##N, result2##N, result3##N); \ + SAVE_4(N, 0, 0); SAVE_4(N, 1, 0); SAVE_4(N, 2, 0); SAVE_4(N, 3, 0); \ + SAVE_4(N, 0, 1); SAVE_4(N, 1, 1); SAVE_4(N, 2, 1); SAVE_4(N, 3, 1); \ +} + +#define MASK_SAVE_4() \ + switch (nn) { \ + case 8: SAVE_4(0, 3, 1); \ + case 7: SAVE_4(0, 2, 1); \ + case 6: SAVE_4(0, 1, 1); \ + case 5: SAVE_4(0, 0, 1); \ + case 4: SAVE_4(0, 3, 0); \ + case 3: SAVE_4(0, 2, 0); \ + case 2: SAVE_4(0, 1, 0); \ + case 1: SAVE_4(0, 0, 0); \ + } + +#define MASK_REORDER_STORE_4x8(N) {\ + REORDER_4x8(result0##N, result1##N, result2##N, result3##N); \ + MASK_SAVE_4(); \ +} + + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m8 = M & ~7; + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n32 = N & ~31; + BLASLONG n16 = N & ~15; + + __m512d alpha_512 = _mm512_broadcastsd_pd(_mm_load_sd(&alpha)); +#if !defined(B0) + __m512d beta_512 = _mm512_broadcastsd_pd(_mm_load_sd(&beta)); + __m256d beta_256 = _mm256_broadcastsd_pd(_mm_load_sd(&beta)); +#endif + long long permute_table[] = { + 0, 1, 4, 5, 0|8, 1|8, 4|8, 5|8, + 2, 3, 6, 7, 2|8, 3|8, 6|8, 7|8, + }; + __m512i idx_lo = _mm512_loadu_epi64(permute_table); + __m512i idx_hi = _mm512_loadu_epi64(permute_table + 8); + + for (i = 0; i < m8; i += 8) { + for (j = 0; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0); + + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(4, 1); 
DECLARE_RESULT_512(5, 1); DECLARE_RESULT_512(6, 1); DECLARE_RESULT_512(7, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(4, 1); MATMUL_512(5, 1); MATMUL_512(6, 1); MATMUL_512(7, 1); + } + REORDER_STORE_8x8(0); + REORDER_STORE_8x8(1); + } + __mmask8 mask = 0xff; + int nn = 8; + for (; j < N; j += 8) { + if (N - j < 8) { + nn = N - j; + mask = (1UL << nn) - 1; + } + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0); + } + MASK_REORDER_STORE_8x8(0); + } + } + for (; i < m4; i += 4) { + long long permute_table2[] = { + 0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8, + 2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8, + }; + idx_lo = _mm512_loadu_epi64(permute_table2); + idx_hi = _mm512_loadu_epi64(permute_table2 + 8); + + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + REORDER_STORE_4x8(0); + REORDER_STORE_4x8(1); + REORDER_STORE_4x8(2); + REORDER_STORE_4x8(3); + } + for (; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + REORDER_STORE_4x8(0); + REORDER_STORE_4x8(1); + } + __mmask8 mask = 0xff; + int nn = 8; + for (; j < N; j += 8) { + if (N - j < 8) { + nn = 
N - j; + mask = (1UL << nn) - 1; + } + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + MASK_REORDER_STORE_4x8(0); + } + } + if (i < M) { + long long index_n[8]; + for (int ii = 0; ii < 8; ii++) { + index_n[ii] = ii * ldc; + } + __m512i vindex_n = _mm512_loadu_epi64(index_n); +#if !defined(B0) + __m512d beta_512 = _mm512_broadcastsd_pd(_mm_load_sd(&beta)); +#endif + for (; i < m2; i += 2) { + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); + SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); + } + for (; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + } + __mmask8 mask = 0xff; + int nn = 8; + for (; j < N; j += 8) { + if (N - j < 8) { + nn = N - j; + mask = (1UL << nn) - 1; + } + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); + } + } + for (; i < M; i += 1) { + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + SCATTER_STORE_512(0, 2); + SCATTER_STORE_512(0, 3); + } + for (; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + } + __mmask8 mask = 0xff; + int nn = 8; + for (; j < N; j += 8) { + if (N - j < 8) { + nn = N - j; + mask = (1UL << nn) - 1; + } + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_SCATTER_STORE_512(0, 0); + } + } + } + return 0; +} From 3e79f6d89abe60b75a4a504670a676472b2d0918 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 2 Jun 2021 13:56:40 
+0000 Subject: [PATCH 26/37] Small Matrix: skylakex: add dgemm tn kernel --- kernel/x86_64/KERNEL.SKYLAKEX | 2 + .../dgemm_small_kernel_b0_tn_skylakex.c | 2 + .../x86_64/dgemm_small_kernel_tn_skylakex.c | 322 ++++++++++++++++++ 3 files changed, 326 insertions(+) create mode 100644 kernel/x86_64/dgemm_small_kernel_b0_tn_skylakex.c create mode 100644 kernel/x86_64/dgemm_small_kernel_tn_skylakex.c diff --git a/kernel/x86_64/KERNEL.SKYLAKEX b/kernel/x86_64/KERNEL.SKYLAKEX index 3e84e794e..c1d8f8e89 100644 --- a/kernel/x86_64/KERNEL.SKYLAKEX +++ b/kernel/x86_64/KERNEL.SKYLAKEX @@ -31,6 +31,8 @@ DGEMM_SMALL_K_NN = dgemm_small_kernel_nn_skylakex.c DGEMM_SMALL_K_B0_NN = dgemm_small_kernel_b0_nn_skylakex.c DGEMM_SMALL_K_NT = dgemm_small_kernel_nt_skylakex.c DGEMM_SMALL_K_B0_NT = dgemm_small_kernel_b0_nt_skylakex.c +DGEMM_SMALL_K_TN = dgemm_small_kernel_tn_skylakex.c +DGEMM_SMALL_K_B0_TN = dgemm_small_kernel_b0_tn_skylakex.c DGEMM_SMALL_K_TT = dgemm_small_kernel_tt_skylakex.c DGEMM_SMALL_K_B0_TT = dgemm_small_kernel_b0_tt_skylakex.c diff --git a/kernel/x86_64/dgemm_small_kernel_b0_tn_skylakex.c b/kernel/x86_64/dgemm_small_kernel_b0_tn_skylakex.c new file mode 100644 index 000000000..1dfa0aaf1 --- /dev/null +++ b/kernel/x86_64/dgemm_small_kernel_b0_tn_skylakex.c @@ -0,0 +1,2 @@ +#define B0 1 +#include "./dgemm_small_kernel_tn_skylakex.c" diff --git a/kernel/x86_64/dgemm_small_kernel_tn_skylakex.c b/kernel/x86_64/dgemm_small_kernel_tn_skylakex.c new file mode 100644 index 000000000..0881f35b2 --- /dev/null +++ b/kernel/x86_64/dgemm_small_kernel_tn_skylakex.c @@ -0,0 +1,322 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/
+
+#include <immintrin.h>
+#include "common.h"
+#include <stdio.h>
+#include <memory.h>
+
+#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd()
+#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N)
+
+#define LOAD_KA_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&A[(i + M)*lda + k]);
+#define LOAD_KB_512(M, N) __m512d Bval##N = _mm512_loadu_pd(&B[(j + N)*ldb + k])
+#define MASK_LOAD_KA_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &A[(i + M)*lda + k])
+#define MASK_LOAD_KB_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[(j + N)*ldb + k])
+
+#define REDUCE_4(rr0, rr1, rr2, rr3) \
+        __m512d r0, r1, r2, r3, t0, t1, t2, t3;\
+        r0 = _mm512_unpacklo_pd(rr0, rr1); r1 = _mm512_unpackhi_pd(rr0, rr1); \
+        r2 = _mm512_unpacklo_pd(rr2, rr3); r3 = _mm512_unpackhi_pd(rr2, rr3); \
+        t0 = _mm512_permutex2var_pd(r0, idx_lo, r2); t1 = _mm512_permutex2var_pd(r1, idx_lo, r3); \
+        t2 = _mm512_permutex2var_pd(r0, idx_hi, r2); t3 = _mm512_permutex2var_pd(r1, idx_hi, r3); \
+        r0 = _mm512_add_pd(t0, t1); r1 = _mm512_add_pd(t2, t3); t0 = _mm512_add_pd(r0, r1); \
+        __m256d s0, s1; \
+        s0 = _mm512_extractf64x4_pd(t0, 0); s1 = _mm512_extractf64x4_pd(t0, 1); \
+        s0 = _mm256_add_pd(s0, s1); s0 = _mm256_mul_pd(alpha_256, s0);
+
+#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N)
+#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3)
+
+#if defined(B0)
+#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N)
+#define STORE_M4(N, s0) _mm256_storeu_pd(&C[(j + N)*ldc + i], s0);
+#define STORE_N4(M, s0) _mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8);
+#else
+#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N) + beta * C[(j+N)*ldc + i + M]
+#define STORE_M4(N, s0) \
+        asm("vfmadd231pd (%1), %2, %0": "+v"(s0):"r"(&C[(j + N)*ldc + i]), "v"(beta_256)); \
+        _mm256_storeu_pd(&C[(j + N)*ldc + i], s0);
+
+#define STORE_N4(M, s0) \
+        s0 = _mm256_fmadd_pd(_mm256_i64gather_pd(&C[j*ldc + i + M], vindex_n, 8), beta_256, s0); \
+        _mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8);
+#endif
+#define STORE_REDUCE_M4(N) {\
+        REDUCE_M4(N) \
+        STORE_M4(N, s0) \
+}
+#define STORE_REDUCE_N4(M) {\
+        REDUCE_N4(M) \
+        STORE_N4(M, s0) \
+}
+
+
+#if defined(B0)
+int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc)
+#else
+int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
+#endif
+{
+        // column major
+        BLASLONG i, j, k;
+
+        BLASLONG m4 = M & ~3;
+        BLASLONG m2 = M & ~1;
+
+        BLASLONG n4 = N & ~3;
+        BLASLONG n2 = N & ~1;
+
+        BLASLONG k8 = K & ~7;
+
+        __mmask8 mask;
+
+        __m256i vindex_n = _mm256_set_epi64x(ldc*3, ldc*2, ldc, 0);
+        __m256d alpha_256 = _mm256_broadcast_sd(&alpha);
+#if !defined(B0)
+        __m256d beta_256 = _mm256_broadcast_sd(&beta);
+#endif
+
+        long long permute_table[] = {
+                0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8,
+                2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8,
+        };
+        __m512i idx_lo = _mm512_loadu_epi64(permute_table);
+        __m512i idx_hi = _mm512_loadu_epi64(permute_table + 8);
+
+        for (i = 0; i < m4; i += 4) {
+                for (j = 0; j < n4; j += 4) {
+                        DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
+                        DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1);
DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_REDUCE_M4(0); + } + + } + for (; i < m2; i += 2) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 
3); MATMUL_512(1, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_REDUCE_N4(0); STORE_REDUCE_N4(1); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + } + } + for (; i < M; i += 1) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_REDUCE_N4(0); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_REDUCE(0, 0); + STORE_REDUCE(0, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + STORE_REDUCE(0, 0); + } + } + return 0; +} From 8592c21af4d6328068b87f402a6801b30e2aebec Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 2 Jun 2021 13:57:39 +0000 Subject: [PATCH 27/37] Small Matrix: skylakex: dgemm nn: fix typo in idx load --- kernel/x86_64/dgemm_small_kernel_nn_skylakex.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c index 8ffb899c8..ff2a04beb 100644 --- 
a/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c +++ b/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c @@ -372,8 +372,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp 0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8, 2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8, }; - __m512i idx_lo = _mm512_loadu_epi32(permute_table); - __m512i idx_hi = _mm512_loadu_epi32(permute_table + 8); + __m512i idx_lo = _mm512_loadu_epi64(permute_table); + __m512i idx_hi = _mm512_loadu_epi64(permute_table + 8); for (; i < m4; i += 4, mi += 4) { for (j = 0; j < n4; j += 4) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); From fa777f5517d4b43acfda8b8a58649af94c1e40b4 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 2 Jun 2021 14:55:54 +0000 Subject: [PATCH 28/37] Small Matrix: skylakex: add DGEMM_SMALL_M_PERMIT and tune for TN kernel --- kernel/x86_64/KERNEL.SKYLAKEX | 1 + .../dgemm_small_kernel_permit_skylakex.c | 44 +++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 kernel/x86_64/dgemm_small_kernel_permit_skylakex.c diff --git a/kernel/x86_64/KERNEL.SKYLAKEX b/kernel/x86_64/KERNEL.SKYLAKEX index c1d8f8e89..eb0cbaf98 100644 --- a/kernel/x86_64/KERNEL.SKYLAKEX +++ b/kernel/x86_64/KERNEL.SKYLAKEX @@ -27,6 +27,7 @@ DGEMMITCOPY = dgemm_tcopy_16_skylakex.c DGEMMONCOPY = ../generic/gemm_ncopy_2.c DGEMMOTCOPY = ../generic/gemm_tcopy_2.c DTRSMKERNEL_RN = ../generic/trsm_kernel_RN.c +DGEMM_SMALL_M_PERMIT = dgemm_small_kernel_permit_skylakex.c DGEMM_SMALL_K_NN = dgemm_small_kernel_nn_skylakex.c DGEMM_SMALL_K_B0_NN = dgemm_small_kernel_b0_nn_skylakex.c DGEMM_SMALL_K_NT = dgemm_small_kernel_nt_skylakex.c diff --git a/kernel/x86_64/dgemm_small_kernel_permit_skylakex.c b/kernel/x86_64/dgemm_small_kernel_permit_skylakex.c new file mode 100644 index 000000000..9cca08e71 --- /dev/null +++ b/kernel/x86_64/dgemm_small_kernel_permit_skylakex.c @@ -0,0 +1,44 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/
+
+#include "common.h"
+
+int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta)
+{
+    double MNK = (double) M * (double) N * (double) K;
+    if (MNK > 100.0*100.0*100.0) // disable for large matrices
+        return 0;
+    if (transa && !transb) {
+        /* The TN kernel performs poorly when:
+         * 1. the C matrix is too big
+         * 2. K is too small
+         */
+        if (M * N > 1200 || K < 32)
+            return 0;
+    }
+    return 1;
+}
From 6b58bca18b427a0c149d25542a5eb7c5ada6a19f Mon Sep 17 00:00:00 2001
From: Wangyang Guo
Date: Tue, 15 Jun 2021 16:09:51 +0000
Subject: [PATCH 29/37] Small Matrix: disable low-performance default kernel

---
 kernel/generic/gemm_small_matrix_permit.c  | 3 +++
 kernel/generic/zgemm_small_matrix_permit.c | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/kernel/generic/gemm_small_matrix_permit.c b/kernel/generic/gemm_small_matrix_permit.c
index 6e1ab1fc1..1ae6d2520 100644
--- a/kernel/generic/gemm_small_matrix_permit.c
+++ b/kernel/generic/gemm_small_matrix_permit.c
@@ -29,9 +29,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta)
 {
+    return 0;
+/*
     double MNK = (double) M * (double) N * (double) K;
     if (MNK <= 100.0*100.0*100.0)
        return 1;
     else
        return 0;
+*/
 }
diff --git a/kernel/generic/zgemm_small_matrix_permit.c b/kernel/generic/zgemm_small_matrix_permit.c
index 288937256..940ff5dc8 100644
--- a/kernel/generic/zgemm_small_matrix_permit.c
+++ b/kernel/generic/zgemm_small_matrix_permit.c
@@ -29,9 +29,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha0, FLOAT alpha1, FLOAT beta0, FLOAT beta1)
 {
+    return 0;
+/*
     double MNK = (double) M * (double) N * (double) K;
     if (MNK <= 100.0*100.0*100.0)
        return 1;
     else
        return 0;
+*/
 }
From 478d1086c11f28903395bd13050dbca62aec81ef Mon Sep 17 00:00:00 2001
From: Wangyang Guo
Date: Wed, 4 Aug 2021 03:12:41 +0000
Subject: [PATCH 30/37] Small Matrix: support DYNAMIC_ARCH build

---
 common_c.h            |  83 +++++++++++++++--------------
 common_d.h            |  23 ++++----
 common_param.h        | 119 ++++++++++++++++++++++++++++++++++++++++++
 common_s.h            |  23 ++++----
 common_z.h            |  83 +++++++++++++++--------------
 interface/gemm.c      |  50 ++++++++++--------
 kernel/setparam-ref.c |  37 +++++++++++++
 7 files changed, 295 insertions(+), 123 deletions(-)

diff --git a/common_c.h b/common_c.h
index dc273eef0..6cff610bb 100644
--- a/common_c.h
+++ b/common_c.h
@@ -234,46 +234,6 @@
 
 #define CGEMM_SMALL_MATRIX_PERMIT cgemm_small_matrix_permit
 
-#define CGEMM_SMALL_KERNEL_NN cgemm_small_kernel_nn
-#define CGEMM_SMALL_KERNEL_NT cgemm_small_kernel_nt
-#define CGEMM_SMALL_KERNEL_NR cgemm_small_kernel_nr
-#define CGEMM_SMALL_KERNEL_NC cgemm_small_kernel_nc
-
-#define CGEMM_SMALL_KERNEL_TN cgemm_small_kernel_tn
-#define CGEMM_SMALL_KERNEL_TT cgemm_small_kernel_tt
-#define CGEMM_SMALL_KERNEL_TR cgemm_small_kernel_tr
-#define CGEMM_SMALL_KERNEL_TC cgemm_small_kernel_tc
-
-#define CGEMM_SMALL_KERNEL_RN cgemm_small_kernel_rn
-#define CGEMM_SMALL_KERNEL_RT cgemm_small_kernel_rt
-#define CGEMM_SMALL_KERNEL_RR cgemm_small_kernel_rr
-#define CGEMM_SMALL_KERNEL_RC cgemm_small_kernel_rc
-
-#define CGEMM_SMALL_KERNEL_CN cgemm_small_kernel_cn
-#define CGEMM_SMALL_KERNEL_CT cgemm_small_kernel_ct
-#define CGEMM_SMALL_KERNEL_CR cgemm_small_kernel_cr
-#define CGEMM_SMALL_KERNEL_CC cgemm_small_kernel_cc
- -#define CGEMM_SMALL_KERNEL_B0_NN cgemm_small_kernel_b0_nn -#define CGEMM_SMALL_KERNEL_B0_NT cgemm_small_kernel_b0_nt -#define CGEMM_SMALL_KERNEL_B0_NR cgemm_small_kernel_b0_nr -#define CGEMM_SMALL_KERNEL_B0_NC cgemm_small_kernel_b0_nc - -#define CGEMM_SMALL_KERNEL_B0_TN cgemm_small_kernel_b0_tn -#define CGEMM_SMALL_KERNEL_B0_TT cgemm_small_kernel_b0_tt -#define CGEMM_SMALL_KERNEL_B0_TR cgemm_small_kernel_b0_tr -#define CGEMM_SMALL_KERNEL_B0_TC cgemm_small_kernel_b0_tc - -#define CGEMM_SMALL_KERNEL_B0_RN cgemm_small_kernel_b0_rn -#define CGEMM_SMALL_KERNEL_B0_RT cgemm_small_kernel_b0_rt -#define CGEMM_SMALL_KERNEL_B0_RR cgemm_small_kernel_b0_rr -#define CGEMM_SMALL_KERNEL_B0_RC cgemm_small_kernel_b0_rc - -#define CGEMM_SMALL_KERNEL_B0_CN cgemm_small_kernel_b0_cn -#define CGEMM_SMALL_KERNEL_B0_CT cgemm_small_kernel_b0_ct -#define CGEMM_SMALL_KERNEL_B0_CR cgemm_small_kernel_b0_cr -#define CGEMM_SMALL_KERNEL_B0_CC cgemm_small_kernel_b0_cc - #else #define CAMAX_K gotoblas -> camax_k @@ -468,8 +428,51 @@ #define CGEADD_K gotoblas -> cgeadd_k +#define CGEMM_SMALL_MATRIX_PERMIT gotoblas -> cgemm_small_matrix_permit + #endif +#define CGEMM_SMALL_KERNEL_NN FUNC_OFFSET(cgemm_small_kernel_nn) +#define CGEMM_SMALL_KERNEL_NT FUNC_OFFSET(cgemm_small_kernel_nt) +#define CGEMM_SMALL_KERNEL_NR FUNC_OFFSET(cgemm_small_kernel_nr) +#define CGEMM_SMALL_KERNEL_NC FUNC_OFFSET(cgemm_small_kernel_nc) + +#define CGEMM_SMALL_KERNEL_TN FUNC_OFFSET(cgemm_small_kernel_tn) +#define CGEMM_SMALL_KERNEL_TT FUNC_OFFSET(cgemm_small_kernel_tt) +#define CGEMM_SMALL_KERNEL_TR FUNC_OFFSET(cgemm_small_kernel_tr) +#define CGEMM_SMALL_KERNEL_TC FUNC_OFFSET(cgemm_small_kernel_tc) + +#define CGEMM_SMALL_KERNEL_RN FUNC_OFFSET(cgemm_small_kernel_rn) +#define CGEMM_SMALL_KERNEL_RT FUNC_OFFSET(cgemm_small_kernel_rt) +#define CGEMM_SMALL_KERNEL_RR FUNC_OFFSET(cgemm_small_kernel_rr) +#define CGEMM_SMALL_KERNEL_RC FUNC_OFFSET(cgemm_small_kernel_rc) + +#define CGEMM_SMALL_KERNEL_CN FUNC_OFFSET(cgemm_small_kernel_cn) +#define CGEMM_SMALL_KERNEL_CT FUNC_OFFSET(cgemm_small_kernel_ct) +#define CGEMM_SMALL_KERNEL_CR FUNC_OFFSET(cgemm_small_kernel_cr) +#define CGEMM_SMALL_KERNEL_CC FUNC_OFFSET(cgemm_small_kernel_cc) + +#define CGEMM_SMALL_KERNEL_B0_NN FUNC_OFFSET(cgemm_small_kernel_b0_nn) +#define CGEMM_SMALL_KERNEL_B0_NT FUNC_OFFSET(cgemm_small_kernel_b0_nt) +#define CGEMM_SMALL_KERNEL_B0_NR FUNC_OFFSET(cgemm_small_kernel_b0_nr) +#define CGEMM_SMALL_KERNEL_B0_NC FUNC_OFFSET(cgemm_small_kernel_b0_nc) + +#define CGEMM_SMALL_KERNEL_B0_TN FUNC_OFFSET(cgemm_small_kernel_b0_tn) +#define CGEMM_SMALL_KERNEL_B0_TT FUNC_OFFSET(cgemm_small_kernel_b0_tt) +#define CGEMM_SMALL_KERNEL_B0_TR FUNC_OFFSET(cgemm_small_kernel_b0_tr) +#define CGEMM_SMALL_KERNEL_B0_TC FUNC_OFFSET(cgemm_small_kernel_b0_tc) + +#define CGEMM_SMALL_KERNEL_B0_RN FUNC_OFFSET(cgemm_small_kernel_b0_rn) +#define CGEMM_SMALL_KERNEL_B0_RT FUNC_OFFSET(cgemm_small_kernel_b0_rt) +#define CGEMM_SMALL_KERNEL_B0_RR FUNC_OFFSET(cgemm_small_kernel_b0_rr) +#define CGEMM_SMALL_KERNEL_B0_RC FUNC_OFFSET(cgemm_small_kernel_b0_rc) + +#define CGEMM_SMALL_KERNEL_B0_CN FUNC_OFFSET(cgemm_small_kernel_b0_cn) +#define CGEMM_SMALL_KERNEL_B0_CT FUNC_OFFSET(cgemm_small_kernel_b0_ct) +#define CGEMM_SMALL_KERNEL_B0_CR FUNC_OFFSET(cgemm_small_kernel_b0_cr) +#define CGEMM_SMALL_KERNEL_B0_CC FUNC_OFFSET(cgemm_small_kernel_b0_cc) + + #define CGEMM_NN cgemm_nn #define CGEMM_CN cgemm_cn #define CGEMM_TN cgemm_tn diff --git a/common_d.h b/common_d.h index bb85f1232..6f4bb2ded 100644 --- a/common_d.h +++ b/common_d.h @@ 
-159,16 +159,6 @@ #define DGEMM_SMALL_MATRIX_PERMIT dgemm_small_matrix_permit -#define DGEMM_SMALL_KERNEL_NN dgemm_small_kernel_nn -#define DGEMM_SMALL_KERNEL_NT dgemm_small_kernel_nt -#define DGEMM_SMALL_KERNEL_TN dgemm_small_kernel_tn -#define DGEMM_SMALL_KERNEL_TT dgemm_small_kernel_tt - -#define DGEMM_SMALL_KERNEL_B0_NN dgemm_small_kernel_b0_nn -#define DGEMM_SMALL_KERNEL_B0_NT dgemm_small_kernel_b0_nt -#define DGEMM_SMALL_KERNEL_B0_TN dgemm_small_kernel_b0_tn -#define DGEMM_SMALL_KERNEL_B0_TT dgemm_small_kernel_b0_tt - #else #define DAMAX_K gotoblas -> damax_k @@ -293,8 +283,21 @@ #define DGEADD_K gotoblas -> dgeadd_k +#define DGEMM_SMALL_MATRIX_PERMIT gotoblas -> dgemm_small_matrix_permit + #endif +#define DGEMM_SMALL_KERNEL_NN FUNC_OFFSET(dgemm_small_kernel_nn) +#define DGEMM_SMALL_KERNEL_NT FUNC_OFFSET(dgemm_small_kernel_nt) +#define DGEMM_SMALL_KERNEL_TN FUNC_OFFSET(dgemm_small_kernel_tn) +#define DGEMM_SMALL_KERNEL_TT FUNC_OFFSET(dgemm_small_kernel_tt) + +#define DGEMM_SMALL_KERNEL_B0_NN FUNC_OFFSET(dgemm_small_kernel_b0_nn) +#define DGEMM_SMALL_KERNEL_B0_NT FUNC_OFFSET(dgemm_small_kernel_b0_nt) +#define DGEMM_SMALL_KERNEL_B0_TN FUNC_OFFSET(dgemm_small_kernel_b0_tn) +#define DGEMM_SMALL_KERNEL_B0_TT FUNC_OFFSET(dgemm_small_kernel_b0_tt) + + #define DGEMM_NN dgemm_nn #define DGEMM_CN dgemm_tn #define DGEMM_TN dgemm_tn diff --git a/common_param.h b/common_param.h index 3e3ae06f8..7e8bea4fe 100644 --- a/common_param.h +++ b/common_param.h @@ -207,6 +207,20 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); int (*sgemm_otcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); #endif #ifdef BUILD_SINGLE +#ifdef SMALL_MATRIX_OPT + int (*sgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta); + + int (*sgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + int (*sgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + int (*sgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + int (*sgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + + int (*sgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*sgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*sgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*sgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +#endif + int (*strsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); int (*strsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); int (*strsm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); @@ -314,6 +328,19 @@ BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG); int (*dgemm_otcopy )(BLASLONG, BLASLONG, 
double *, BLASLONG, double *); #endif #ifdef BUILD_DOUBLE +#ifdef SMALL_MATRIX_OPT + int (*dgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, double alpha, double beta); + + int (*dgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); + int (*dgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); + int (*dgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); + int (*dgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); + + int (*dgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*dgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*dgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*dgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +#endif int (*dtrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG, BLASLONG); int (*dtrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG, BLASLONG); int (*dtrsm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG, BLASLONG); @@ -513,6 +540,50 @@ BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG); int (*cgemm_oncopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*cgemm_otcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); +#ifdef SMALL_MATRIX_OPT + int (*cgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha0, float alpha1, float beta0, float beta1); + + int (*cgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_nr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_nc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_tr )(BLASLONG m, BLASLONG n, BLASLONG k, 
float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_tc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_rn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_rt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_rr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_rc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_cn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_ct )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_cr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_cc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_nr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_nc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_tr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_tc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_b0_rn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_rt )(BLASLONG m, BLASLONG 
n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_rr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_rc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_b0_cn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_ct )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_cr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_cc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +#endif + int (*ctrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG, BLASLONG); int (*ctrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG, BLASLONG); int (*ctrsm_kernel_LR)(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG, BLASLONG); @@ -679,6 +750,50 @@ BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG); int (*zgemm_oncopy )(BLASLONG, BLASLONG, double *, BLASLONG, double *); int (*zgemm_otcopy )(BLASLONG, BLASLONG, double *, BLASLONG, double *); +#ifdef SMALL_MATRIX_OPT + int (*zgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, double alpha0, double alpha1, double beta0, double beta1); + + int (*zgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_nr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_nc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + + int (*zgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_tr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_tc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, 
BLASLONG ldc); + + int (*zgemm_small_kernel_rn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_rt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_rr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_rc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + + int (*zgemm_small_kernel_cn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_ct )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_cr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_cc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + + int (*zgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_nr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_nc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + + int (*zgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_tr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_tc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + + int (*zgemm_small_kernel_b0_rn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_rt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_rr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, 
BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_rc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + + int (*zgemm_small_kernel_b0_cn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_ct )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_cr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_cc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +#endif + int (*ztrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG, BLASLONG); int (*ztrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG, BLASLONG); int (*ztrsm_kernel_LR)(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG, BLASLONG); @@ -1069,6 +1184,8 @@ BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG); extern gotoblas_t *gotoblas; +#define FUNC_OFFSET(func) (size_t)(&((gotoblas_t *)NULL)->func) + #define DTB_ENTRIES gotoblas -> dtb_entries #define GEMM_OFFSET_A gotoblas -> offsetA #define GEMM_OFFSET_B gotoblas -> offsetB @@ -1174,6 +1291,8 @@ extern gotoblas_t *gotoblas; #else +#define FUNC_OFFSET(func) (size_t)(func) + #define DTB_ENTRIES DTB_DEFAULT_ENTRIES #define GEMM_OFFSET_A GEMM_DEFAULT_OFFSET_A diff --git a/common_s.h b/common_s.h index 5851014cf..fdd80b62f 100644 --- a/common_s.h +++ b/common_s.h @@ -166,16 +166,6 @@ #define SGEMM_SMALL_MATRIX_PERMIT sgemm_small_matrix_permit -#define SGEMM_SMALL_KERNEL_NN sgemm_small_kernel_nn -#define SGEMM_SMALL_KERNEL_NT sgemm_small_kernel_nt -#define SGEMM_SMALL_KERNEL_TN sgemm_small_kernel_tn -#define SGEMM_SMALL_KERNEL_TT sgemm_small_kernel_tt - -#define SGEMM_SMALL_KERNEL_B0_NN sgemm_small_kernel_b0_nn -#define SGEMM_SMALL_KERNEL_B0_NT sgemm_small_kernel_b0_nt -#define SGEMM_SMALL_KERNEL_B0_TN sgemm_small_kernel_b0_tn -#define SGEMM_SMALL_KERNEL_B0_TT sgemm_small_kernel_b0_tt - #else #define SAMAX_K gotoblas -> samax_k @@ -311,8 +301,21 @@ #define SGEADD_K gotoblas -> sgeadd_k +#define SGEMM_SMALL_MATRIX_PERMIT gotoblas -> sgemm_small_matrix_permit + #endif +#define SGEMM_SMALL_KERNEL_NN FUNC_OFFSET(sgemm_small_kernel_nn) +#define SGEMM_SMALL_KERNEL_NT FUNC_OFFSET(sgemm_small_kernel_nt) +#define SGEMM_SMALL_KERNEL_TN FUNC_OFFSET(sgemm_small_kernel_tn) +#define SGEMM_SMALL_KERNEL_TT FUNC_OFFSET(sgemm_small_kernel_tt) + +#define SGEMM_SMALL_KERNEL_B0_NN FUNC_OFFSET(sgemm_small_kernel_b0_nn) +#define SGEMM_SMALL_KERNEL_B0_NT FUNC_OFFSET(sgemm_small_kernel_b0_nt) +#define SGEMM_SMALL_KERNEL_B0_TN FUNC_OFFSET(sgemm_small_kernel_b0_tn) +#define SGEMM_SMALL_KERNEL_B0_TT FUNC_OFFSET(sgemm_small_kernel_b0_tt) + + #define SGEMM_NN sgemm_nn #define SGEMM_CN sgemm_tn #define SGEMM_TN sgemm_tn diff --git a/common_z.h b/common_z.h index 6088260a1..c12d71b39 100644 --- a/common_z.h +++ b/common_z.h @@ -234,46 +234,6 @@ #define ZGEMM_SMALL_MATRIX_PERMIT zgemm_small_matrix_permit -#define ZGEMM_SMALL_KERNEL_NN zgemm_small_kernel_nn 
-#define ZGEMM_SMALL_KERNEL_NT zgemm_small_kernel_nt -#define ZGEMM_SMALL_KERNEL_NR zgemm_small_kernel_nr -#define ZGEMM_SMALL_KERNEL_NC zgemm_small_kernel_nc - -#define ZGEMM_SMALL_KERNEL_TN zgemm_small_kernel_tn -#define ZGEMM_SMALL_KERNEL_TT zgemm_small_kernel_tt -#define ZGEMM_SMALL_KERNEL_TR zgemm_small_kernel_tr -#define ZGEMM_SMALL_KERNEL_TC zgemm_small_kernel_tc - -#define ZGEMM_SMALL_KERNEL_RN zgemm_small_kernel_rn -#define ZGEMM_SMALL_KERNEL_RT zgemm_small_kernel_rt -#define ZGEMM_SMALL_KERNEL_RR zgemm_small_kernel_rr -#define ZGEMM_SMALL_KERNEL_RC zgemm_small_kernel_rc - -#define ZGEMM_SMALL_KERNEL_CN zgemm_small_kernel_cn -#define ZGEMM_SMALL_KERNEL_CT zgemm_small_kernel_ct -#define ZGEMM_SMALL_KERNEL_CR zgemm_small_kernel_cr -#define ZGEMM_SMALL_KERNEL_CC zgemm_small_kernel_cc - -#define ZGEMM_SMALL_KERNEL_B0_NN zgemm_small_kernel_b0_nn -#define ZGEMM_SMALL_KERNEL_B0_NT zgemm_small_kernel_b0_nt -#define ZGEMM_SMALL_KERNEL_B0_NR zgemm_small_kernel_b0_nr -#define ZGEMM_SMALL_KERNEL_B0_NC zgemm_small_kernel_b0_nc - -#define ZGEMM_SMALL_KERNEL_B0_TN zgemm_small_kernel_b0_tn -#define ZGEMM_SMALL_KERNEL_B0_TT zgemm_small_kernel_b0_tt -#define ZGEMM_SMALL_KERNEL_B0_TR zgemm_small_kernel_b0_tr -#define ZGEMM_SMALL_KERNEL_B0_TC zgemm_small_kernel_b0_tc - -#define ZGEMM_SMALL_KERNEL_B0_RN zgemm_small_kernel_b0_rn -#define ZGEMM_SMALL_KERNEL_B0_RT zgemm_small_kernel_b0_rt -#define ZGEMM_SMALL_KERNEL_B0_RR zgemm_small_kernel_b0_rr -#define ZGEMM_SMALL_KERNEL_B0_RC zgemm_small_kernel_b0_rc - -#define ZGEMM_SMALL_KERNEL_B0_CN zgemm_small_kernel_b0_cn -#define ZGEMM_SMALL_KERNEL_B0_CT zgemm_small_kernel_b0_ct -#define ZGEMM_SMALL_KERNEL_B0_CR zgemm_small_kernel_b0_cr -#define ZGEMM_SMALL_KERNEL_B0_CC zgemm_small_kernel_b0_cc - #else #define ZAMAX_K gotoblas -> zamax_k @@ -468,8 +428,51 @@ #define ZGEADD_K gotoblas -> zgeadd_k +#define ZGEMM_SMALL_MATRIX_PERMIT gotoblas -> zgemm_small_matrix_permit + #endif +#define ZGEMM_SMALL_KERNEL_NN FUNC_OFFSET(zgemm_small_kernel_nn) +#define ZGEMM_SMALL_KERNEL_NT FUNC_OFFSET(zgemm_small_kernel_nt) +#define ZGEMM_SMALL_KERNEL_NR FUNC_OFFSET(zgemm_small_kernel_nr) +#define ZGEMM_SMALL_KERNEL_NC FUNC_OFFSET(zgemm_small_kernel_nc) + +#define ZGEMM_SMALL_KERNEL_TN FUNC_OFFSET(zgemm_small_kernel_tn) +#define ZGEMM_SMALL_KERNEL_TT FUNC_OFFSET(zgemm_small_kernel_tt) +#define ZGEMM_SMALL_KERNEL_TR FUNC_OFFSET(zgemm_small_kernel_tr) +#define ZGEMM_SMALL_KERNEL_TC FUNC_OFFSET(zgemm_small_kernel_tc) + +#define ZGEMM_SMALL_KERNEL_RN FUNC_OFFSET(zgemm_small_kernel_rn) +#define ZGEMM_SMALL_KERNEL_RT FUNC_OFFSET(zgemm_small_kernel_rt) +#define ZGEMM_SMALL_KERNEL_RR FUNC_OFFSET(zgemm_small_kernel_rr) +#define ZGEMM_SMALL_KERNEL_RC FUNC_OFFSET(zgemm_small_kernel_rc) + +#define ZGEMM_SMALL_KERNEL_CN FUNC_OFFSET(zgemm_small_kernel_cn) +#define ZGEMM_SMALL_KERNEL_CT FUNC_OFFSET(zgemm_small_kernel_ct) +#define ZGEMM_SMALL_KERNEL_CR FUNC_OFFSET(zgemm_small_kernel_cr) +#define ZGEMM_SMALL_KERNEL_CC FUNC_OFFSET(zgemm_small_kernel_cc) + +#define ZGEMM_SMALL_KERNEL_B0_NN FUNC_OFFSET(zgemm_small_kernel_b0_nn) +#define ZGEMM_SMALL_KERNEL_B0_NT FUNC_OFFSET(zgemm_small_kernel_b0_nt) +#define ZGEMM_SMALL_KERNEL_B0_NR FUNC_OFFSET(zgemm_small_kernel_b0_nr) +#define ZGEMM_SMALL_KERNEL_B0_NC FUNC_OFFSET(zgemm_small_kernel_b0_nc) + +#define ZGEMM_SMALL_KERNEL_B0_TN FUNC_OFFSET(zgemm_small_kernel_b0_tn) +#define ZGEMM_SMALL_KERNEL_B0_TT FUNC_OFFSET(zgemm_small_kernel_b0_tt) +#define ZGEMM_SMALL_KERNEL_B0_TR FUNC_OFFSET(zgemm_small_kernel_b0_tr) +#define ZGEMM_SMALL_KERNEL_B0_TC 
FUNC_OFFSET(zgemm_small_kernel_b0_tc) + +#define ZGEMM_SMALL_KERNEL_B0_RN FUNC_OFFSET(zgemm_small_kernel_b0_rn) +#define ZGEMM_SMALL_KERNEL_B0_RT FUNC_OFFSET(zgemm_small_kernel_b0_rt) +#define ZGEMM_SMALL_KERNEL_B0_RR FUNC_OFFSET(zgemm_small_kernel_b0_rr) +#define ZGEMM_SMALL_KERNEL_B0_RC FUNC_OFFSET(zgemm_small_kernel_b0_rc) + +#define ZGEMM_SMALL_KERNEL_B0_CN FUNC_OFFSET(zgemm_small_kernel_b0_cn) +#define ZGEMM_SMALL_KERNEL_B0_CT FUNC_OFFSET(zgemm_small_kernel_b0_ct) +#define ZGEMM_SMALL_KERNEL_B0_CR FUNC_OFFSET(zgemm_small_kernel_b0_cr) +#define ZGEMM_SMALL_KERNEL_B0_CC FUNC_OFFSET(zgemm_small_kernel_b0_cc) + + #define ZGEMM_NN zgemm_nn #define ZGEMM_CN zgemm_cn #define ZGEMM_TN zgemm_tn diff --git a/interface/gemm.c b/interface/gemm.c index ad8780668..f4b9f1537 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -106,25 +106,34 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, B }; #ifdef SMALL_MATRIX_OPT +#ifndef DYNAMIC_ARCH +#define SMALL_KERNEL_ADDR(table, idx) ((void *)(table[idx])) +#else +#define SMALL_KERNEL_ADDR(table, idx) ((void *)(*(uintptr_t *)((char *)gotoblas + (size_t)(table[idx])))) +#endif + #ifndef COMPLEX -static int (*gemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG) = { +static size_t gemm_small_kernel[] = { #ifndef GEMM3M - GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, NULL, NULL, - GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, NULL, NULL, + GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, 0, 0, + GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, 0, 0, #endif }; -static int (*gemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { + +static size_t gemm_small_kernel_b0[] = { #ifndef GEMM3M - GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, NULL, NULL, - GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, NULL, NULL, + GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, 0, 0, + GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, 0, 0, #endif }; +#define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx)) +#define GEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx)) #else -static int (*zgemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG) = { +static size_t zgemm_small_kernel[] = { #ifndef GEMM3M GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN, GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT, @@ -133,7 +142,7 @@ static int (*zgemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLO #endif }; -static int (*zgemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { +static size_t zgemm_small_kernel_b0[] = { #ifndef GEMM3M GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN, GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT, @@ -141,6 +150,9 @@ static int (*zgemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLA GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC, #endif }; + 
+#define ZGEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel, (idx)) +#define ZGEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel_b0, (idx)) #endif #endif @@ -163,7 +175,7 @@ void NAME(char *TRANSA, char *TRANSB, IFLOAT *buffer; IFLOAT *sa, *sb; -#if defined (SMP) || defined(SMALL_MATRIX_OPT) +#ifdef SMP double MNK; #if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY) #ifndef COMPLEX @@ -287,11 +299,8 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS XFLOAT *buffer; XFLOAT *sa, *sb; -#if defined (SMP) || defined(SMALL_MATRIX_OPT) - double MNK; -#endif - #ifdef SMP + double MNK; #if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY) #ifndef COMPLEX #ifdef XDOUBLE @@ -459,32 +468,27 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS FUNCTION_PROFILE_START(); -#if defined(SMP) || defined(SMALL_MATRIX_OPT) - MNK = (double) args.m * (double) args.n * (double) args.k; -#endif - #ifdef SMALL_MATRIX_OPT #if !defined(COMPLEX) if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, *(FLOAT *)(args.alpha), *(FLOAT *)(args.beta))){ if(*(FLOAT *)(args.beta) == 0.0){ - (gemm_small_kernel_b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc); + (GEMM_SMALL_KERNEL_B0((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc); }else{ - (gemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); + (GEMM_SMALL_KERNEL((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); } return; } #else if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, alpha[0], alpha[1], beta[0], beta[1])){ if(beta[0] == 0.0 && beta[1] == 0.0){ - (zgemm_small_kernel_b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, args.c, args.ldc); + (ZGEMM_SMALL_KERNEL_B0((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, args.c, args.ldc); }else{ - (zgemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, beta[0], beta[1], args.c, args.ldc); + (ZGEMM_SMALL_KERNEL((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, beta[0], beta[1], args.c, args.ldc); } return; } #endif #endif - buffer = (XFLOAT *)blas_memory_alloc(0); @@ -497,7 +501,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS mode |= (transb << BLAS_TRANSB_SHIFT); #endif - + MNK = (double) args.m * (double) args.n * (double) args.k; if ( MNK <= (SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD) ) args.nthreads = 1; else diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index 1e846a61c..f303d0dc6 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -171,6 +171,14 @@ gotoblas_t TABLE_NAME = { sgemm_oncopyTS, sgemm_otcopyTS, #endif +#if BUILD_SINGLE == 1 +#ifdef SMALL_MATRIX_OPT + 
sgemm_small_matrix_permitTS, + sgemm_small_kernel_nnTS, sgemm_small_kernel_ntTS, sgemm_small_kernel_tnTS, sgemm_small_kernel_ttTS, + sgemm_small_kernel_b0_nnTS, sgemm_small_kernel_b0_ntTS, sgemm_small_kernel_b0_tnTS, sgemm_small_kernel_b0_ttTS, +#endif +#endif + #if (BUILD_SINGLE==1) || (BUILD_DOUBLE==1) strsm_kernel_LNTS, strsm_kernel_LTTS, strsm_kernel_RNTS, strsm_kernel_RTTS, #if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N @@ -257,6 +265,11 @@ gotoblas_t TABLE_NAME = { #endif #if (BUILD_DOUBLE==1) +#ifdef SMALL_MATRIX_OPT + dgemm_small_matrix_permitTS, + dgemm_small_kernel_nnTS, dgemm_small_kernel_ntTS, dgemm_small_kernel_tnTS, dgemm_small_kernel_ttTS, + dgemm_small_kernel_b0_nnTS, dgemm_small_kernel_b0_ntTS, dgemm_small_kernel_b0_tnTS, dgemm_small_kernel_b0_ttTS, +#endif dtrsm_kernel_LNTS, dtrsm_kernel_LTTS, dtrsm_kernel_RNTS, dtrsm_kernel_RTTS, #if DGEMM_DEFAULT_UNROLL_M != DGEMM_DEFAULT_UNROLL_N dtrsm_iunucopyTS, dtrsm_iunncopyTS, dtrsm_iutucopyTS, dtrsm_iutncopyTS, @@ -389,6 +402,18 @@ gotoblas_t TABLE_NAME = { #endif cgemm_oncopyTS, cgemm_otcopyTS, +#ifdef SMALL_MATRIX_OPT + cgemm_small_matrix_permitTS, + cgemm_small_kernel_nnTS, cgemm_small_kernel_ntTS, cgemm_small_kernel_nrTS, cgemm_small_kernel_ncTS, + cgemm_small_kernel_tnTS, cgemm_small_kernel_ttTS, cgemm_small_kernel_trTS, cgemm_small_kernel_tcTS, + cgemm_small_kernel_rnTS, cgemm_small_kernel_rtTS, cgemm_small_kernel_rrTS, cgemm_small_kernel_rcTS, + cgemm_small_kernel_cnTS, cgemm_small_kernel_ctTS, cgemm_small_kernel_crTS, cgemm_small_kernel_ccTS, + cgemm_small_kernel_b0_nnTS, cgemm_small_kernel_b0_ntTS, cgemm_small_kernel_b0_nrTS, cgemm_small_kernel_b0_ncTS, + cgemm_small_kernel_b0_tnTS, cgemm_small_kernel_b0_ttTS, cgemm_small_kernel_b0_trTS, cgemm_small_kernel_b0_tcTS, + cgemm_small_kernel_b0_rnTS, cgemm_small_kernel_b0_rtTS, cgemm_small_kernel_b0_rrTS, cgemm_small_kernel_b0_rcTS, + cgemm_small_kernel_b0_cnTS, cgemm_small_kernel_b0_ctTS, cgemm_small_kernel_b0_crTS, cgemm_small_kernel_b0_ccTS, +#endif + ctrsm_kernel_LNTS, ctrsm_kernel_LTTS, ctrsm_kernel_LRTS, ctrsm_kernel_LCTS, ctrsm_kernel_RNTS, ctrsm_kernel_RTTS, ctrsm_kernel_RRTS, ctrsm_kernel_RCTS, @@ -533,6 +558,18 @@ gotoblas_t TABLE_NAME = { #endif zgemm_oncopyTS, zgemm_otcopyTS, +#ifdef SMALL_MATRIX_OPT + zgemm_small_matrix_permitTS, + zgemm_small_kernel_nnTS, zgemm_small_kernel_ntTS, zgemm_small_kernel_nrTS, zgemm_small_kernel_ncTS, + zgemm_small_kernel_tnTS, zgemm_small_kernel_ttTS, zgemm_small_kernel_trTS, zgemm_small_kernel_tcTS, + zgemm_small_kernel_rnTS, zgemm_small_kernel_rtTS, zgemm_small_kernel_rrTS, zgemm_small_kernel_rcTS, + zgemm_small_kernel_cnTS, zgemm_small_kernel_ctTS, zgemm_small_kernel_crTS, zgemm_small_kernel_ccTS, + zgemm_small_kernel_b0_nnTS, zgemm_small_kernel_b0_ntTS, zgemm_small_kernel_b0_nrTS, zgemm_small_kernel_b0_ncTS, + zgemm_small_kernel_b0_tnTS, zgemm_small_kernel_b0_ttTS, zgemm_small_kernel_b0_trTS, zgemm_small_kernel_b0_tcTS, + zgemm_small_kernel_b0_rnTS, zgemm_small_kernel_b0_rtTS, zgemm_small_kernel_b0_rrTS, zgemm_small_kernel_b0_rcTS, + zgemm_small_kernel_b0_cnTS, zgemm_small_kernel_b0_ctTS, zgemm_small_kernel_b0_crTS, zgemm_small_kernel_b0_ccTS, +#endif + ztrsm_kernel_LNTS, ztrsm_kernel_LTTS, ztrsm_kernel_LRTS, ztrsm_kernel_LCTS, ztrsm_kernel_RNTS, ztrsm_kernel_RTTS, ztrsm_kernel_RRTS, ztrsm_kernel_RCTS, From fee5abd84bf01aba7a2223f7264fcc7da66d1b20 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 4 Aug 2021 08:50:15 +0000 Subject: [PATCH 31/37] Small Matrix: support cmake build --- cmake/system.cmake | 4 ++ 
kernel/CMakeLists.txt | 110 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/cmake/system.cmake b/cmake/system.cmake index f8bd6678e..e51dc1fdc 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -258,6 +258,10 @@ if (NEED_PIC) endif() endif () +if (SMALL_MATRIX_OPT) + set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT") +endif () + if (DYNAMIC_ARCH) if (X86 OR X86_64 OR ARM64 OR PPC) set(CCOMMON_OPT "${CCOMMON_OPT} -DDYNAMIC_ARCH") diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index f0793bdef..769a73b91 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -458,7 +458,117 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) GenerateNamedObjects("${KERNELDIR}/${${float_char}TRSMKERNEL_RN}" "UPPER;RN;TRSMKERNEL" "trsm_kernel_RN" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}TRSMKERNEL_RT}" "RT;TRSMKERNEL" "trsm_kernel_RT" false "" "" false ${float_type}) + if (NOT DEFINED ${float_char}GEMM_SMALL_M_PERMIT) + if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") + set(${float_char}GEMM_SMALL_M_PERMIT ../generic/zgemm_small_matrix_permit.c) + else () + set(${float_char}GEMM_SMALL_M_PERMIT ../generic/gemm_small_matrix_permit.c) + endif () + endif () + if (NOT DEFINED ${float_char}GEMM_SMALL_K_NN) + if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") + set(${float_char}GEMM_SMALL_K_NN ../generic/zgemm_small_matrix_kernel_nn.c) + else () + set(${float_char}GEMM_SMALL_K_NN ../generic/gemm_small_matrix_kernel_nn.c) + endif () + endif () + if (NOT DEFINED ${float_char}GEMM_SMALL_K_NT) + if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") + set(${float_char}GEMM_SMALL_K_NT ../generic/zgemm_small_matrix_kernel_nt.c) + else () + set(${float_char}GEMM_SMALL_K_NT ../generic/gemm_small_matrix_kernel_nt.c) + endif () + endif () + if (NOT DEFINED ${float_char}GEMM_SMALL_K_TN) + if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") + set(${float_char}GEMM_SMALL_K_TN ../generic/zgemm_small_matrix_kernel_tn.c) + else () + set(${float_char}GEMM_SMALL_K_TN ../generic/gemm_small_matrix_kernel_tn.c) + endif () + endif () + if (NOT DEFINED ${float_char}GEMM_SMALL_K_TT) + if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") + set(${float_char}GEMM_SMALL_K_TT ../generic/zgemm_small_matrix_kernel_tt.c) + else () + set(${float_char}GEMM_SMALL_K_TT ../generic/gemm_small_matrix_kernel_tt.c) + endif () + endif () + if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_NN) + if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") + set(${float_char}GEMM_SMALL_K_B0_NN ../generic/zgemm_small_matrix_kernel_b0_nn.c) + else () + set(${float_char}GEMM_SMALL_K_B0_NN ../generic/gemm_small_matrix_kernel_b0_nn.c) + endif () + endif () + if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_NT) + if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") + set(${float_char}GEMM_SMALL_K_B0_NT ../generic/zgemm_small_matrix_kernel_b0_nt.c) + else () + set(${float_char}GEMM_SMALL_K_B0_NT ../generic/gemm_small_matrix_kernel_b0_nt.c) + endif () + endif () + if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_TN) + if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") + set(${float_char}GEMM_SMALL_K_B0_TN ../generic/zgemm_small_matrix_kernel_b0_tn.c) + else () + set(${float_char}GEMM_SMALL_K_B0_TN ../generic/gemm_small_matrix_kernel_b0_tn.c) + endif () + endif () + if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_TT) + if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") + 
set(${float_char}GEMM_SMALL_K_B0_TT ../generic/zgemm_small_matrix_kernel_b0_tt.c) + else () + set(${float_char}GEMM_SMALL_K_B0_TT ../generic/gemm_small_matrix_kernel_b0_tt.c) + endif () + endif () + if (SMALL_MATRIX_OPT) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_M_PERMIT}" "" "gemm_small_matrix_permit" false "" "" false ${float_type}) + if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "NN" "gemm_small_kernel_nn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "NR" "gemm_small_kernel_nr" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "RN" "gemm_small_kernel_rn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "RR" "gemm_small_kernel_rr" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "NT" "gemm_small_kernel_nt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "NC" "gemm_small_kernel_nc" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "RT" "gemm_small_kernel_rt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "RC" "gemm_small_kernel_rc" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "TN" "gemm_small_kernel_tn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "TR" "gemm_small_kernel_tr" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "CN" "gemm_small_kernel_cn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "CR" "gemm_small_kernel_cr" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "TT" "gemm_small_kernel_tt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "TC" "gemm_small_kernel_tc" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "CT" "gemm_small_kernel_ct" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "CC" "gemm_small_kernel_cc" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "NN" "gemm_small_kernel_b0_nn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "NR" "gemm_small_kernel_b0_nr" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "RN" "gemm_small_kernel_b0_rn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "RR" "gemm_small_kernel_b0_rr" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "NT" "gemm_small_kernel_b0_nt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "NC" "gemm_small_kernel_b0_nc" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "RT" "gemm_small_kernel_b0_rt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "RC" 
"gemm_small_kernel_b0_rc" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "TN" "gemm_small_kernel_b0_tn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "TR" "gemm_small_kernel_b0_tr" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "CN" "gemm_small_kernel_b0_cn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "CR" "gemm_small_kernel_b0_cr" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "TT" "gemm_small_kernel_b0_tt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "TC" "gemm_small_kernel_b0_tc" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "CT" "gemm_small_kernel_b0_ct" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "CC" "gemm_small_kernel_b0_cc" false "" "" false ${float_type}) + + else () + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "" "gemm_small_kernel_nn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "" "gemm_small_kernel_nt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "" "gemm_small_kernel_tn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "" "gemm_small_kernel_tt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "" "gemm_small_kernel_b0_nn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "" "gemm_small_kernel_b0_nt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "" "gemm_small_kernel_b0_tn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "" "gemm_small_kernel_b0_tt" false "" "" false ${float_type}) + endif () + endif () if (NOT DEFINED ${float_char}OMATCOPY_CN) if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") From aa50185647ba6966dcdb731372af2ecd5ae3b1d4 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 5 Aug 2021 02:45:53 +0000 Subject: [PATCH 32/37] Small Matrix: better handle with GEMM3M marco --- interface/gemm.c | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/interface/gemm.c b/interface/gemm.c index f4b9f1537..775f654c3 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -105,6 +105,7 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, B #endif }; +#ifndef GEMM3M #ifdef SMALL_MATRIX_OPT #ifndef DYNAMIC_ARCH #define SMALL_KERNEL_ADDR(table, idx) ((void *)(table[idx])) @@ -115,18 +116,14 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, B #ifndef COMPLEX static size_t gemm_small_kernel[] = { -#ifndef GEMM3M GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, 0, 0, GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, 0, 0, -#endif }; static size_t gemm_small_kernel_b0[] = { -#ifndef GEMM3M GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, 0, 0, GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, 0, 0, -#endif }; #define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, 
@@ -134,27 +131,24 @@ static size_t gemm_small_kernel_b0[] = {
#else
static size_t zgemm_small_kernel[] = {
-#ifndef GEMM3M
GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN,
GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT,
GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR,
GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC,
-#endif
};

static size_t zgemm_small_kernel_b0[] = {
-#ifndef GEMM3M
GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN,
GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT,
GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR,
GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC,
-#endif
};

#define ZGEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel, (idx))

#define ZGEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel_b0, (idx))
#endif
#endif
+#endif

#ifndef CBLAS

@@ -468,6 +462,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS

FUNCTION_PROFILE_START();

+#ifndef GEMM3M
#ifdef SMALL_MATRIX_OPT
#if !defined(COMPLEX)
if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, *(FLOAT *)(args.alpha), *(FLOAT *)(args.beta))){
@@ -488,6 +483,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
return;
}
#endif
+#endif
#endif

buffer = (XFLOAT *)blas_memory_alloc(0);

From 76ea8db4da1a651bb4de744162de1ecfc6762e7c Mon Sep 17 00:00:00 2001
From: Wangyang Guo
Date: Thu, 5 Aug 2021 02:57:58 +0000
Subject: [PATCH 33/37] Small Matrix: enable by default for x86_64 arch

If no customized GEMM_SMALL_M_PERMIT kernel is defined, the build simply
falls back to the normal path.
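For reference, the permit kernel is the hook that makes enabling this by
default safe: it is called once per GEMM and decides whether the
small-matrix path is taken.  A do-nothing fallback (a sketch of the
generic version, not necessarily the verbatim contents of
kernel/generic/gemm_small_matrix_permit.c) only needs to decline:

    #include "common.h"

    /* Return 1 to take the small-matrix fast path, 0 to fall back to
     * the regular blocked GEMM.  Declining unconditionally makes
     * SMALL_MATRIX_OPT a no-op on targets without a tuned permit
     * kernel. */
    int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K,
              FLOAT alpha, FLOAT beta)
    {
        return 0;
    }

A tuned target would instead gate on problem size, for example permitting
only when (double)M * (double)N * (double)K stays below some threshold.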
---
 Makefile.system | 3 +++
 cmake/system.cmake | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/Makefile.system b/Makefile.system
index 20d8d2f2a..20db80d07 100644
--- a/Makefile.system
+++ b/Makefile.system
@@ -245,6 +245,9 @@ ONLY_CBLAS = 0
endif

#For small matrix optimization
+ifeq ($(ARCH), x86_64)
+SMALL_MATRIX_OPT = 1
+endif
ifeq ($(SMALL_MATRIX_OPT), 1)
CCOMMON_OPT += -DSMALL_MATRIX_OPT
endif

diff --git a/cmake/system.cmake b/cmake/system.cmake
index e51dc1fdc..7d2672998 100644
--- a/cmake/system.cmake
+++ b/cmake/system.cmake
@@ -258,6 +258,9 @@ if (NEED_PIC)
endif()
endif ()

+if (X86_64)
+ set(SMALL_MATRIX_OPT TRUE)
+endif ()
if (SMALL_MATRIX_OPT)
set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT")
endif ()

From 44d0032f3b8e9794d51b7807b3fb53905a2e9f1c Mon Sep 17 00:00:00 2001
From: Wangyang Guo
Date: Thu, 5 Aug 2021 04:43:47 +0000
Subject: [PATCH 34/37] Small Matrix: skylakex: fix build errors with old compilers

---
 kernel/x86_64/dgemm_small_kernel_nn_skylakex.c | 4 ++--
 kernel/x86_64/dgemm_small_kernel_nt_skylakex.c | 2 +-
 kernel/x86_64/dgemm_small_kernel_tn_skylakex.c | 4 ++--
 kernel/x86_64/dgemm_small_kernel_tt_skylakex.c | 10 +++++-----
 kernel/x86_64/sgemm_small_kernel_nt_skylakex.c | 2 +-
 kernel/x86_64/sgemm_small_kernel_tt_skylakex.c | 6 +++---
 6 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c
index ff2a04beb..d9b380fff 100644
--- a/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c
+++ b/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c
@@ -372,8 +372,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8,
2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8,
};
- __m512i idx_lo = _mm512_loadu_epi64(permute_table);
- __m512i idx_hi = _mm512_loadu_epi64(permute_table + 8);
+ __m512i idx_lo = _mm512_loadu_si512(permute_table);
+ __m512i idx_hi = _mm512_loadu_si512(permute_table + 8);
for (; i < m4; i += 4, mi += 4) {
for (j = 0; j < n4; j += 4) {
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);

diff --git a/kernel/x86_64/dgemm_small_kernel_nt_skylakex.c b/kernel/x86_64/dgemm_small_kernel_nt_skylakex.c
index 0a95a68e2..e757197ba 100644
--- a/kernel/x86_64/dgemm_small_kernel_nt_skylakex.c
+++ b/kernel/x86_64/dgemm_small_kernel_nt_skylakex.c
@@ -385,7 +385,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
for (int ii = 0; ii < 8; ii++) {
index_n[ii] = ii * ldc;
}
- __m512i vindex_n = _mm512_loadu_epi64(index_n);
+ __m512i vindex_n = _mm512_loadu_si512(index_n);
for (; i < m4; i += 4) {
for (j = 0; j < n32; j += 32) {
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);

diff --git a/kernel/x86_64/dgemm_small_kernel_tn_skylakex.c b/kernel/x86_64/dgemm_small_kernel_tn_skylakex.c
index 0881f35b2..18c797283 100644
--- a/kernel/x86_64/dgemm_small_kernel_tn_skylakex.c
+++ b/kernel/x86_64/dgemm_small_kernel_tn_skylakex.c
@@ -105,8 +105,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8,
2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8,
};
- __m512i idx_lo = _mm512_loadu_epi64(permute_table);
- __m512i idx_hi = _mm512_loadu_epi64(permute_table + 8);
+ __m512i idx_lo = _mm512_loadu_si512(permute_table);
+ __m512i idx_hi = _mm512_loadu_si512(permute_table + 8);

for (i = 0; i < m4; i += 4) {
for (j = 0; j < n4; j += 4) {

diff --git
a/kernel/x86_64/dgemm_small_kernel_tt_skylakex.c b/kernel/x86_64/dgemm_small_kernel_tt_skylakex.c index 8ff79d2c8..00f42aa76 100644 --- a/kernel/x86_64/dgemm_small_kernel_tt_skylakex.c +++ b/kernel/x86_64/dgemm_small_kernel_tt_skylakex.c @@ -189,8 +189,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp 0, 1, 4, 5, 0|8, 1|8, 4|8, 5|8, 2, 3, 6, 7, 2|8, 3|8, 6|8, 7|8, }; - __m512i idx_lo = _mm512_loadu_epi64(permute_table); - __m512i idx_hi = _mm512_loadu_epi64(permute_table + 8); + __m512i idx_lo = _mm512_loadu_si512(permute_table); + __m512i idx_hi = _mm512_loadu_si512(permute_table + 8); for (i = 0; i < m8; i += 8) { for (j = 0; j < n16; j += 16) { @@ -235,8 +235,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp 0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8, 2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8, }; - idx_lo = _mm512_loadu_epi64(permute_table2); - idx_hi = _mm512_loadu_epi64(permute_table2 + 8); + idx_lo = _mm512_loadu_si512(permute_table2); + idx_hi = _mm512_loadu_si512(permute_table2 + 8); for (j = 0; j < n32; j += 32) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); @@ -289,7 +289,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp for (int ii = 0; ii < 8; ii++) { index_n[ii] = ii * ldc; } - __m512i vindex_n = _mm512_loadu_epi64(index_n); + __m512i vindex_n = _mm512_loadu_si512(index_n); #if !defined(B0) __m512d beta_512 = _mm512_broadcastsd_pd(_mm_load_sd(&beta)); #endif diff --git a/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c index f293bf9f9..a7d87f8c4 100644 --- a/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c +++ b/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c @@ -385,7 +385,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp for (int ii = 0; ii < 16; ii++) { index_n[ii] = ii * ldc; } - __m512i vindex_n = _mm512_loadu_epi32(index_n); + __m512i vindex_n = _mm512_loadu_si512(index_n); for (; i < m4; i += 4) { for (j = 0; j < n64; j += 64) { DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); diff --git a/kernel/x86_64/sgemm_small_kernel_tt_skylakex.c b/kernel/x86_64/sgemm_small_kernel_tt_skylakex.c index 8da560ef7..023f58746 100644 --- a/kernel/x86_64/sgemm_small_kernel_tt_skylakex.c +++ b/kernel/x86_64/sgemm_small_kernel_tt_skylakex.c @@ -215,8 +215,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp 0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b, 0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f, }; - __m512i idx_lo = _mm512_loadu_epi32(permute_table); - __m512i idx_hi = _mm512_loadu_epi32(permute_table + 16); + __m512i idx_lo = _mm512_loadu_si512(permute_table); + __m512i idx_hi = _mm512_loadu_si512(permute_table + 16); __mmask16 kc = 0xcccc; __mmask16 k3 = 0x3333; __mmask8 mask8 = 0xff; // force use AVX128 instead of SSE @@ -311,7 +311,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp for (int ii = 0; ii < 16; ii++) { index_n[ii] = ii * ldc; } - __m512i vindex_n = _mm512_loadu_epi32(index_n); + __m512i vindex_n = _mm512_loadu_si512(index_n); #if !defined(B0) __m512 beta_512 = _mm512_broadcastss_ps(_mm_load_ss(&beta)); #endif From c17d6dacb23f0862f6f0318c55c097c361132663 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 5 Aug 2021 05:46:13 +0000 Subject: 
[PATCH 35/37] Small Matrix: skip compiling for unimplemented data types

---
 interface/gemm.c | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/interface/gemm.c b/interface/gemm.c
index 775f654c3..3497d8651 100644
--- a/interface/gemm.c
+++ b/interface/gemm.c
@@ -105,8 +105,13 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, B
#endif
};

-#ifndef GEMM3M
-#ifdef SMALL_MATRIX_OPT
+#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE) && !defined(BFLOAT16)
+#define USE_SMALL_MATRIX_OPT 1
+#else
+#define USE_SMALL_MATRIX_OPT 0
+#endif
+
+#if USE_SMALL_MATRIX_OPT
#ifndef DYNAMIC_ARCH
#define SMALL_KERNEL_ADDR(table, idx) ((void *)(table[idx]))
#else
@@ -148,7 +153,6 @@ static size_t zgemm_small_kernel_b0[] = {
#define ZGEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel_b0, (idx))
#endif
#endif
-#endif

#ifndef CBLAS

@@ -462,8 +466,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS

FUNCTION_PROFILE_START();

-#ifndef GEMM3M
-#ifdef SMALL_MATRIX_OPT
+#if USE_SMALL_MATRIX_OPT
#if !defined(COMPLEX)
if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, *(FLOAT *)(args.alpha), *(FLOAT *)(args.beta))){
if(*(FLOAT *)(args.beta) == 0.0){
@@ -483,7 +486,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
return;
}
#endif
-#endif
#endif

buffer = (XFLOAT *)blas_memory_alloc(0);

From 989e6bbdd39fe3d49789b803c4fd6b20a3a673e5 Mon Sep 17 00:00:00 2001
From: Wangyang Guo
Date: Fri, 13 Aug 2021 03:17:38 +0000
Subject: [PATCH 36/37] Small Matrix: reduce generic kernel source files

---
 kernel/CMakeLists.txt | 56 ++++-----
 kernel/Makefile.L3 | 112 +++++++++---------
 .../generic/gemm_small_matrix_kernel_b0_nn.c | 49 --------
 .../generic/gemm_small_matrix_kernel_b0_nt.c | 49 --------
 .../generic/gemm_small_matrix_kernel_b0_tn.c | 49 --------
 .../generic/gemm_small_matrix_kernel_b0_tt.c | 49 --------
 kernel/generic/gemm_small_matrix_kernel_nn.c | 11 +-
 kernel/generic/gemm_small_matrix_kernel_nt.c | 9 +-
 kernel/generic/gemm_small_matrix_kernel_tn.c | 8 ++
 kernel/generic/gemm_small_matrix_kernel_tt.c | 8 ++
 .../generic/zgemm_small_matrix_kernel_b0_nn.c | 74 ------------
 .../generic/zgemm_small_matrix_kernel_b0_nt.c | 77 ------------
 .../generic/zgemm_small_matrix_kernel_b0_tn.c | 77 ------------
 .../generic/zgemm_small_matrix_kernel_b0_tt.c | 77 ------------
 kernel/generic/zgemm_small_matrix_kernel_nn.c | 11 ++
 kernel/generic/zgemm_small_matrix_kernel_nt.c | 11 ++
 kernel/generic/zgemm_small_matrix_kernel_tn.c | 11 ++
 kernel/generic/zgemm_small_matrix_kernel_tt.c | 11 ++
 18 files changed, 161 insertions(+), 588 deletions(-)
 delete mode 100644 kernel/generic/gemm_small_matrix_kernel_b0_nn.c
 delete mode 100644 kernel/generic/gemm_small_matrix_kernel_b0_nt.c
 delete mode 100644 kernel/generic/gemm_small_matrix_kernel_b0_tn.c
 delete mode 100644 kernel/generic/gemm_small_matrix_kernel_b0_tt.c
 delete mode 100644 kernel/generic/zgemm_small_matrix_kernel_b0_nn.c
 delete mode 100644 kernel/generic/zgemm_small_matrix_kernel_b0_nt.c
 delete mode 100644 kernel/generic/zgemm_small_matrix_kernel_b0_tn.c
 delete mode 100644 kernel/generic/zgemm_small_matrix_kernel_b0_tt.c

diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt
index 769a73b91..d8a230436 100644
--- a/kernel/CMakeLists.txt
+++ b/kernel/CMakeLists.txt
@@ -495,30 +495,30 @@ function (build_core
TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) endif () if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_NN) if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") - set(${float_char}GEMM_SMALL_K_B0_NN ../generic/zgemm_small_matrix_kernel_b0_nn.c) + set(${float_char}GEMM_SMALL_K_B0_NN ../generic/zgemm_small_matrix_kernel_nn.c) else () - set(${float_char}GEMM_SMALL_K_B0_NN ../generic/gemm_small_matrix_kernel_b0_nn.c) + set(${float_char}GEMM_SMALL_K_B0_NN ../generic/gemm_small_matrix_kernel_nn.c) endif () endif () if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_NT) if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") - set(${float_char}GEMM_SMALL_K_B0_NT ../generic/zgemm_small_matrix_kernel_b0_nt.c) + set(${float_char}GEMM_SMALL_K_B0_NT ../generic/zgemm_small_matrix_kernel_nt.c) else () - set(${float_char}GEMM_SMALL_K_B0_NT ../generic/gemm_small_matrix_kernel_b0_nt.c) + set(${float_char}GEMM_SMALL_K_B0_NT ../generic/gemm_small_matrix_kernel_nt.c) endif () endif () if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_TN) if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") - set(${float_char}GEMM_SMALL_K_B0_TN ../generic/zgemm_small_matrix_kernel_b0_tn.c) + set(${float_char}GEMM_SMALL_K_B0_TN ../generic/zgemm_small_matrix_kernel_tn.c) else () - set(${float_char}GEMM_SMALL_K_B0_TN ../generic/gemm_small_matrix_kernel_b0_tn.c) + set(${float_char}GEMM_SMALL_K_B0_TN ../generic/gemm_small_matrix_kernel_tn.c) endif () endif () if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_TT) if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") - set(${float_char}GEMM_SMALL_K_B0_TT ../generic/zgemm_small_matrix_kernel_b0_tt.c) + set(${float_char}GEMM_SMALL_K_B0_TT ../generic/zgemm_small_matrix_kernel_tt.c) else () - set(${float_char}GEMM_SMALL_K_B0_TT ../generic/gemm_small_matrix_kernel_b0_tt.c) + set(${float_char}GEMM_SMALL_K_B0_TT ../generic/gemm_small_matrix_kernel_tt.c) endif () endif () @@ -541,32 +541,32 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "TC" "gemm_small_kernel_tc" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "CT" "gemm_small_kernel_ct" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "CC" "gemm_small_kernel_cc" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "NN" "gemm_small_kernel_b0_nn" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "NR" "gemm_small_kernel_b0_nr" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "RN" "gemm_small_kernel_b0_rn" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "RR" "gemm_small_kernel_b0_rr" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "NT" "gemm_small_kernel_b0_nt" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "NC" "gemm_small_kernel_b0_nc" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "RT" "gemm_small_kernel_b0_rt" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "RC" "gemm_small_kernel_b0_rc" false "" "" false ${float_type}) - 
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "TN" "gemm_small_kernel_b0_tn" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "TR" "gemm_small_kernel_b0_tr" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "CN" "gemm_small_kernel_b0_cn" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "CR" "gemm_small_kernel_b0_cr" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "TT" "gemm_small_kernel_b0_tt" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "TC" "gemm_small_kernel_b0_tc" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "CT" "gemm_small_kernel_b0_ct" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "CC" "gemm_small_kernel_b0_cc" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "NN;B0" "gemm_small_kernel_b0_nn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "NR;B0" "gemm_small_kernel_b0_nr" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "RN;B0" "gemm_small_kernel_b0_rn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "RR;B0" "gemm_small_kernel_b0_rr" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "NT;B0" "gemm_small_kernel_b0_nt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "NC;B0" "gemm_small_kernel_b0_nc" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "RT;B0" "gemm_small_kernel_b0_rt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "RC;B0" "gemm_small_kernel_b0_rc" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "TN;B0" "gemm_small_kernel_b0_tn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "TR;B0" "gemm_small_kernel_b0_tr" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "CN;B0" "gemm_small_kernel_b0_cn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "CR;B0" "gemm_small_kernel_b0_cr" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "TT;B0" "gemm_small_kernel_b0_tt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "TC;B0" "gemm_small_kernel_b0_tc" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "CT;B0" "gemm_small_kernel_b0_ct" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "CC;B0" "gemm_small_kernel_b0_cc" false "" "" false ${float_type}) else () GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "" "gemm_small_kernel_nn" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "" 
"gemm_small_kernel_nt" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "" "gemm_small_kernel_tn" false "" "" false ${float_type}) GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "" "gemm_small_kernel_tt" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "" "gemm_small_kernel_b0_nn" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "" "gemm_small_kernel_b0_nt" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "" "gemm_small_kernel_b0_tn" false "" "" false ${float_type}) - GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "" "gemm_small_kernel_b0_tt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "B0" "gemm_small_kernel_b0_nn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "B0" "gemm_small_kernel_b0_nt" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "B0" "gemm_small_kernel_b0_tn" false "" "" false ${float_type}) + GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "B0" "gemm_small_kernel_b0_tt" false "" "" false ${float_type}) endif () endif () diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index f977793a0..ef11e391c 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -4334,32 +4334,32 @@ $(KDIR)dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_ $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ ifndef DGEMM_SMALL_K_B0_NN -DGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_b0_nn.c +DGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_nn.c endif ifndef DGEMM_SMALL_K_B0_NT -DGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_b0_nt.c +DGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_nt.c endif ifndef DGEMM_SMALL_K_B0_TN -DGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_b0_tn.c +DGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_tn.c endif ifndef DGEMM_SMALL_K_B0_TT -DGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_b0_tt.c +DGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_tt.c endif $(KDIR)dgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_NN) - $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX -DB0 $< -o $@ $(KDIR)dgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_NT) - $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX -DB0 $< -o $@ $(KDIR)dgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_TN) - $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX -DB0 $< -o $@ $(KDIR)dgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_TT) - $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX -DB0 $< -o $@ ifndef SGEMM_SMALL_M_PERMIT SGEMM_SMALL_M_PERMIT = ../generic/gemm_small_matrix_permit.c @@ -4397,32 +4397,32 @@ $(KDIR)sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_ $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ ifndef SGEMM_SMALL_K_B0_NN -SGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_b0_nn.c +SGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_nn.c endif ifndef 
SGEMM_SMALL_K_B0_NT -SGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_b0_nt.c +SGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_nt.c endif ifndef SGEMM_SMALL_K_B0_TN -SGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_b0_tn.c +SGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_tn.c endif ifndef SGEMM_SMALL_K_B0_TT -SGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_b0_tt.c +SGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_tt.c endif $(KDIR)sgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_NN) - $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@ $(KDIR)sgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_NT) - $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@ $(KDIR)sgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_TN) - $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@ $(KDIR)sgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_TT) - $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@ ifndef CGEMM_SMALL_M_PERMIT CGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c @@ -4496,68 +4496,68 @@ $(KDIR)cgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_ $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCC $< -o $@ ifndef CGEMM_SMALL_K_B0_NN -CGEMM_SMALL_K_B0_NN = ../generic/zgemm_small_matrix_kernel_b0_nn.c +CGEMM_SMALL_K_B0_NN = ../generic/zgemm_small_matrix_kernel_nn.c endif ifndef CGEMM_SMALL_K_B0_NT -CGEMM_SMALL_K_B0_NT = ../generic/zgemm_small_matrix_kernel_b0_nt.c +CGEMM_SMALL_K_B0_NT = ../generic/zgemm_small_matrix_kernel_nt.c endif ifndef CGEMM_SMALL_K_B0_TN -CGEMM_SMALL_K_B0_TN = ../generic/zgemm_small_matrix_kernel_b0_tn.c +CGEMM_SMALL_K_B0_TN = ../generic/zgemm_small_matrix_kernel_tn.c endif ifndef CGEMM_SMALL_K_B0_TT -CGEMM_SMALL_K_B0_TT = ../generic/zgemm_small_matrix_kernel_b0_tt.c +CGEMM_SMALL_K_B0_TT = ../generic/zgemm_small_matrix_kernel_tt.c endif $(KDIR)cgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNN $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNN -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNR $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNR -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRN $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRN -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRR $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRR -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNT $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNT -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNC $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNC -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRT $< -o $@ + $(CC) $(CFLAGS) -c 
-UDOUBLE -DCOMPLEX -DRT -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRC $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRC -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTN $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTN -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTR $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTR -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCN $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCN -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCR $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCR -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTT $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTT -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTC $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTC -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCT $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCT -DB0 $< -o $@ $(KDIR)cgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT) - $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCC $< -o $@ + $(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCC -DB0 $< -o $@ ifndef ZGEMM_SMALL_M_PERMIT ZGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c @@ -4632,65 +4632,65 @@ $(KDIR)zgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_ $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCC $< -o $@ ifndef ZGEMM_SMALL_K_B0_NN -ZGEMM_SMALL_K_B0_NN = ../generic/zgemm_small_matrix_kernel_b0_nn.c +ZGEMM_SMALL_K_B0_NN = ../generic/zgemm_small_matrix_kernel_nn.c endif ifndef ZGEMM_SMALL_K_B0_NT -ZGEMM_SMALL_K_B0_NT = ../generic/zgemm_small_matrix_kernel_b0_nt.c +ZGEMM_SMALL_K_B0_NT = ../generic/zgemm_small_matrix_kernel_nt.c endif ifndef ZGEMM_SMALL_K_B0_TN -ZGEMM_SMALL_K_B0_TN = ../generic/zgemm_small_matrix_kernel_b0_tn.c +ZGEMM_SMALL_K_B0_TN = ../generic/zgemm_small_matrix_kernel_tn.c endif ifndef ZGEMM_SMALL_K_B0_TT -ZGEMM_SMALL_K_B0_TT = ../generic/zgemm_small_matrix_kernel_b0_tt.c +ZGEMM_SMALL_K_B0_TT = ../generic/zgemm_small_matrix_kernel_tt.c endif $(KDIR)zgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNN $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNN -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNR $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNR -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRN $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRN -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRR $< -o $@ + $(CC) $(CFLAGS) -c 
-DDOUBLE -DCOMPLEX -DRR -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNT $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNT -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNC $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNC -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRT $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRT -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRC $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRC -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTN $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTN -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTR $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTR -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCN $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCN -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCR $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCR -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTT $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTT -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTC $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTC -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCT $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCT -DB0 $< -o $@ $(KDIR)zgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT) - $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCC $< -o $@ + $(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCC -DB0 $< -o $@ diff --git a/kernel/generic/gemm_small_matrix_kernel_b0_nn.c b/kernel/generic/gemm_small_matrix_kernel_b0_nn.c deleted file mode 100644 index 3be918017..000000000 --- a/kernel/generic/gemm_small_matrix_kernel_b0_nn.c +++ /dev/null @@ -1,49 +0,0 @@ -/*************************************************************************** -Copyright (c) 2020, The OpenBLAS Project -All rights reserved. -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: -1. Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. -2. Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in -the documentation and/or other materials provided with the -distribution. -3. Neither the name of the OpenBLAS project nor the names of -its contributors may be used to endorse or promote products -derived from this software without specific prior written permission. 
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*****************************************************************************/ - -#include "common.h" - -int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb,FLOAT * C, BLASLONG ldc) -{ - //naive implemtation - //Column major - - BLASLONG i,j,k; - FLOAT result=0.0; - - for(i=0; i Date: Fri, 13 Aug 2021 03:28:44 +0000 Subject: [PATCH 37/37] Small Matrix: skylakex: remove unnecessary b0 source files --- kernel/x86_64/KERNEL.SKYLAKEX | 16 ++++++++-------- .../x86_64/dgemm_small_kernel_b0_nn_skylakex.c | 2 -- .../x86_64/dgemm_small_kernel_b0_nt_skylakex.c | 2 -- .../x86_64/dgemm_small_kernel_b0_tn_skylakex.c | 2 -- .../x86_64/dgemm_small_kernel_b0_tt_skylakex.c | 2 -- .../x86_64/sgemm_small_kernel_b0_nn_skylakex.c | 2 -- .../x86_64/sgemm_small_kernel_b0_nt_skylakex.c | 2 -- .../x86_64/sgemm_small_kernel_b0_tn_skylakex.c | 2 -- .../x86_64/sgemm_small_kernel_b0_tt_skylakex.c | 3 --- 9 files changed, 8 insertions(+), 25 deletions(-) delete mode 100644 kernel/x86_64/dgemm_small_kernel_b0_nn_skylakex.c delete mode 100644 kernel/x86_64/dgemm_small_kernel_b0_nt_skylakex.c delete mode 100644 kernel/x86_64/dgemm_small_kernel_b0_tn_skylakex.c delete mode 100644 kernel/x86_64/dgemm_small_kernel_b0_tt_skylakex.c delete mode 100644 kernel/x86_64/sgemm_small_kernel_b0_nn_skylakex.c delete mode 100644 kernel/x86_64/sgemm_small_kernel_b0_nt_skylakex.c delete mode 100644 kernel/x86_64/sgemm_small_kernel_b0_tn_skylakex.c delete mode 100644 kernel/x86_64/sgemm_small_kernel_b0_tt_skylakex.c diff --git a/kernel/x86_64/KERNEL.SKYLAKEX b/kernel/x86_64/KERNEL.SKYLAKEX index eb0cbaf98..6b4961bc2 100644 --- a/kernel/x86_64/KERNEL.SKYLAKEX +++ b/kernel/x86_64/KERNEL.SKYLAKEX @@ -12,13 +12,13 @@ STRSMKERNEL_RN = ../generic/trsm_kernel_RN.c STRSMKERNEL_RT = ../generic/trsm_kernel_RT.c SGEMM_SMALL_M_PERMIT = sgemm_small_kernel_permit_skylakex.c SGEMM_SMALL_K_NN = sgemm_small_kernel_nn_skylakex.c -SGEMM_SMALL_K_B0_NN = sgemm_small_kernel_b0_nn_skylakex.c +SGEMM_SMALL_K_B0_NN = sgemm_small_kernel_nn_skylakex.c SGEMM_SMALL_K_NT = sgemm_small_kernel_nt_skylakex.c -SGEMM_SMALL_K_B0_NT = sgemm_small_kernel_b0_nt_skylakex.c +SGEMM_SMALL_K_B0_NT = sgemm_small_kernel_nt_skylakex.c SGEMM_SMALL_K_TN = sgemm_small_kernel_tn_skylakex.c -SGEMM_SMALL_K_B0_TN = sgemm_small_kernel_b0_tn_skylakex.c +SGEMM_SMALL_K_B0_TN = sgemm_small_kernel_tn_skylakex.c SGEMM_SMALL_K_TT = sgemm_small_kernel_tt_skylakex.c -SGEMM_SMALL_K_B0_TT = sgemm_small_kernel_b0_tt_skylakex.c +SGEMM_SMALL_K_B0_TT = sgemm_small_kernel_tt_skylakex.c DGEMMKERNEL = dgemm_kernel_16x2_skylakex.c DTRMMKERNEL = dgemm_kernel_16x2_skylakex.c @@ -29,13 +29,13 @@ DGEMMOTCOPY = ../generic/gemm_tcopy_2.c DTRSMKERNEL_RN = ../generic/trsm_kernel_RN.c DGEMM_SMALL_M_PERMIT = 
dgemm_small_kernel_permit_skylakex.c DGEMM_SMALL_K_NN = dgemm_small_kernel_nn_skylakex.c -DGEMM_SMALL_K_B0_NN = dgemm_small_kernel_b0_nn_skylakex.c +DGEMM_SMALL_K_B0_NN = dgemm_small_kernel_nn_skylakex.c DGEMM_SMALL_K_NT = dgemm_small_kernel_nt_skylakex.c -DGEMM_SMALL_K_B0_NT = dgemm_small_kernel_b0_nt_skylakex.c +DGEMM_SMALL_K_B0_NT = dgemm_small_kernel_nt_skylakex.c DGEMM_SMALL_K_TN = dgemm_small_kernel_tn_skylakex.c -DGEMM_SMALL_K_B0_TN = dgemm_small_kernel_b0_tn_skylakex.c +DGEMM_SMALL_K_B0_TN = dgemm_small_kernel_tn_skylakex.c DGEMM_SMALL_K_TT = dgemm_small_kernel_tt_skylakex.c -DGEMM_SMALL_K_B0_TT = dgemm_small_kernel_b0_tt_skylakex.c +DGEMM_SMALL_K_B0_TT = dgemm_small_kernel_tt_skylakex.c SGEMM_BETA = sgemm_beta_skylakex.c DGEMM_BETA = dgemm_beta_skylakex.c diff --git a/kernel/x86_64/dgemm_small_kernel_b0_nn_skylakex.c b/kernel/x86_64/dgemm_small_kernel_b0_nn_skylakex.c deleted file mode 100644 index a58738a25..000000000 --- a/kernel/x86_64/dgemm_small_kernel_b0_nn_skylakex.c +++ /dev/null @@ -1,2 +0,0 @@ -#define B0 1 -#include "./dgemm_small_kernel_nn_skylakex.c" diff --git a/kernel/x86_64/dgemm_small_kernel_b0_nt_skylakex.c b/kernel/x86_64/dgemm_small_kernel_b0_nt_skylakex.c deleted file mode 100644 index eafe2ce49..000000000 --- a/kernel/x86_64/dgemm_small_kernel_b0_nt_skylakex.c +++ /dev/null @@ -1,2 +0,0 @@ -#define B0 1 -#include "./dgemm_small_kernel_nt_skylakex.c" diff --git a/kernel/x86_64/dgemm_small_kernel_b0_tn_skylakex.c b/kernel/x86_64/dgemm_small_kernel_b0_tn_skylakex.c deleted file mode 100644 index 1dfa0aaf1..000000000 --- a/kernel/x86_64/dgemm_small_kernel_b0_tn_skylakex.c +++ /dev/null @@ -1,2 +0,0 @@ -#define B0 1 -#include "./dgemm_small_kernel_tn_skylakex.c" diff --git a/kernel/x86_64/dgemm_small_kernel_b0_tt_skylakex.c b/kernel/x86_64/dgemm_small_kernel_b0_tt_skylakex.c deleted file mode 100644 index 93fab1836..000000000 --- a/kernel/x86_64/dgemm_small_kernel_b0_tt_skylakex.c +++ /dev/null @@ -1,2 +0,0 @@ -#define B0 1 -#include "./dgemm_small_kernel_tt_skylakex.c" diff --git a/kernel/x86_64/sgemm_small_kernel_b0_nn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_b0_nn_skylakex.c deleted file mode 100644 index 704e964b8..000000000 --- a/kernel/x86_64/sgemm_small_kernel_b0_nn_skylakex.c +++ /dev/null @@ -1,2 +0,0 @@ -#define B0 1 -#include "./sgemm_small_kernel_nn_skylakex.c" diff --git a/kernel/x86_64/sgemm_small_kernel_b0_nt_skylakex.c b/kernel/x86_64/sgemm_small_kernel_b0_nt_skylakex.c deleted file mode 100644 index 6d7934be1..000000000 --- a/kernel/x86_64/sgemm_small_kernel_b0_nt_skylakex.c +++ /dev/null @@ -1,2 +0,0 @@ -#define B0 1 -#include "./sgemm_small_kernel_nt_skylakex.c" diff --git a/kernel/x86_64/sgemm_small_kernel_b0_tn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_b0_tn_skylakex.c deleted file mode 100644 index 0f9745b72..000000000 --- a/kernel/x86_64/sgemm_small_kernel_b0_tn_skylakex.c +++ /dev/null @@ -1,2 +0,0 @@ -#define B0 1 -#include "./sgemm_small_kernel_tn_skylakex.c" diff --git a/kernel/x86_64/sgemm_small_kernel_b0_tt_skylakex.c b/kernel/x86_64/sgemm_small_kernel_b0_tt_skylakex.c deleted file mode 100644 index 27d9e0afd..000000000 --- a/kernel/x86_64/sgemm_small_kernel_b0_tt_skylakex.c +++ /dev/null @@ -1,3 +0,0 @@ -#define B0 1 -#define TT 1 -#include "./sgemm_small_kernel_tt_skylakex.c"
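A note on the b0 naming used throughout this series: after patches 36 and
37, each kernel source is compiled twice, once as-is and once with -DB0,
so a single file provides both the general and the beta == 0 variant.  A
minimal sketch of the pattern an NN kernel follows (an illustration
assembled from the signatures above, not a verbatim excerpt of the merged
generic kernel):

    #include "common.h"

    int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda,
              FLOAT alpha, FLOAT * B, BLASLONG ldb,
    #if !defined(B0)
              FLOAT beta,
    #endif
              FLOAT * C, BLASLONG ldc)
    {
        /* Naive column-major C = alpha*A*B (+ beta*C unless B0). */
        BLASLONG i, j, k;
        FLOAT result;

        for (i = 0; i < M; i++) {
            for (j = 0; j < N; j++) {
                result = 0.0;
                for (k = 0; k < K; k++)
                    result += A[k * lda + i] * B[j * ldb + k];
    #if defined(B0)
                /* beta == 0: overwrite C without ever reading it. */
                C[j * ldc + i] = alpha * result;
    #else
                C[j * ldc + i] = alpha * result + beta * C[j * ldc + i];
    #endif
            }
        }
        return 0;
    }

Dropping the beta argument in the B0 build matches the B0 function-pointer
casts in interface/gemm.c, and skipping the read of C avoids touching
uninitialized memory when beta is exactly zero.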