diff --git a/common_d.h b/common_d.h index dad304a5f..f5d7935fa 100644 --- a/common_d.h +++ b/common_d.h @@ -163,6 +163,11 @@ #define DGEMM_SMALL_KERNEL_TN dgemm_small_kernel_tn #define DGEMM_SMALL_KERNEL_TT dgemm_small_kernel_tt +#define DGEMM_SMALL_KERNEL_A1B0_NN dgemm_small_kernel_a1b0_nn +#define DGEMM_SMALL_KERNEL_A1B0_NT dgemm_small_kernel_a1b0_nt +#define DGEMM_SMALL_KERNEL_A1B0_TN dgemm_small_kernel_a1b0_tn +#define DGEMM_SMALL_KERNEL_A1B0_TT dgemm_small_kernel_a1b0_tt + #else #define DAMAX_K gotoblas -> damax_k diff --git a/common_level3.h b/common_level3.h index 751592b67..31d514cd5 100644 --- a/common_level3.h +++ b/common_level3.h @@ -525,6 +525,17 @@ int dgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLO int dgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); int dgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); int dgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); + +int sgemm_small_kernel_a1b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int sgemm_small_kernel_a1b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int sgemm_small_kernel_a1b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +int sgemm_small_kernel_a1b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + +int dgemm_small_kernel_a1b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int dgemm_small_kernel_a1b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int dgemm_small_kernel_a1b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +int dgemm_small_kernel_a1b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + #endif int cgemm_kernel_n(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG); diff --git a/common_macro.h b/common_macro.h index eb2abcdc0..2f7263023 100644 --- a/common_macro.h +++ b/common_macro.h @@ -648,6 +648,10 @@ #define GEMM_SMALL_KERNEL_NT DGEMM_SMALL_KERNEL_NT #define GEMM_SMALL_KERNEL_TN DGEMM_SMALL_KERNEL_TN #define GEMM_SMALL_KERNEL_TT DGEMM_SMALL_KERNEL_TT +#define GEMM_SMALL_KERNEL_A1B0_NN DGEMM_SMALL_KERNEL_A1B0_NN +#define GEMM_SMALL_KERNEL_A1B0_NT DGEMM_SMALL_KERNEL_A1B0_NT +#define GEMM_SMALL_KERNEL_A1B0_TN DGEMM_SMALL_KERNEL_A1B0_TN +#define GEMM_SMALL_KERNEL_A1B0_TT DGEMM_SMALL_KERNEL_A1B0_TT #elif defined(BFLOAT16) @@ -941,6 +945,11 @@ #define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN #define GEMM_SMALL_KERNEL_TT SGEMM_SMALL_KERNEL_TT +#define GEMM_SMALL_KERNEL_A1B0_NN SGEMM_SMALL_KERNEL_A1B0_NN +#define GEMM_SMALL_KERNEL_A1B0_NT SGEMM_SMALL_KERNEL_A1B0_NT +#define GEMM_SMALL_KERNEL_A1B0_TN SGEMM_SMALL_KERNEL_A1B0_TN +#define GEMM_SMALL_KERNEL_A1B0_TT SGEMM_SMALL_KERNEL_A1B0_TT + #endif #else @@ -1252,6 +1261,11 @@ #define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN #define GEMM_SMALL_KERNEL_TT SGEMM_SMALL_KERNEL_TT +#define GEMM_SMALL_KERNEL_A1B0_NN SGEMM_SMALL_KERNEL_A1B0_NN +#define GEMM_SMALL_KERNEL_A1B0_NT SGEMM_SMALL_KERNEL_A1B0_NT +#define GEMM_SMALL_KERNEL_A1B0_TN SGEMM_SMALL_KERNEL_A1B0_TN +#define GEMM_SMALL_KERNEL_A1B0_TT SGEMM_SMALL_KERNEL_A1B0_TT + #endif #else #ifdef XDOUBLE diff --git a/common_s.h b/common_s.h index 6ad98ba8b..440b78723 100644 --- a/common_s.h +++ b/common_s.h @@ -169,6 +169,11 @@ #define SGEMM_SMALL_KERNEL_TN sgemm_small_kernel_tn #define SGEMM_SMALL_KERNEL_TT sgemm_small_kernel_tt +#define SGEMM_SMALL_KERNEL_A1B0_NN sgemm_small_kernel_a1b0_nn +#define SGEMM_SMALL_KERNEL_A1B0_NT sgemm_small_kernel_a1b0_nt +#define SGEMM_SMALL_KERNEL_A1B0_TN sgemm_small_kernel_a1b0_tn +#define SGEMM_SMALL_KERNEL_A1B0_TT sgemm_small_kernel_a1b0_tt + #else #define SAMAX_K gotoblas -> samax_k diff --git a/interface/gemm.c b/interface/gemm.c index d2fb42ff7..da602f7a9 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -115,6 +115,15 @@ static int (*gemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLON #endif #endif }; + +static int (*gemm_small_kernel_a1b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { +#ifndef GEMM3M +#ifndef COMPLEX + GEMM_SMALL_KERNEL_A1B0_NN, GEMM_SMALL_KERNEL_A1B0_TN, NULL, NULL, + GEMM_SMALL_KERNEL_A1B0_NT, GEMM_SMALL_KERNEL_A1B0_TT, NULL, NULL, +#endif +#endif +}; #endif #ifndef CBLAS @@ -435,8 +444,13 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS #if !defined(COMPLEX) //need to tune small matrices cases. if(MNK <= 100.0*100.0*100.0){ - (gemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, - args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); + + if(*(FLOAT *)(args.alpha) == 1.0 && *(FLOAT *)(args.beta) == 0.0){ + (gemm_small_kernel_a1b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda,args.b, args.ldb, args.c, args.ldc); + }else{ + (gemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); + } + return; } #endif diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 88e5eb2d6..448d22e4e 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -452,11 +452,15 @@ ifeq ($(SMALL_MATRIX_OPT), 1) SBLASOBJS += \ sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ - sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) + sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ + sgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) \ + sgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) DBLASOBJS += \ dgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ - dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) + dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ + dgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) \ + dgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) endif @@ -4282,6 +4286,34 @@ $(KDIR)dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_ $(KDIR)dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_TT) $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ +ifndef DGEMM_SAMLL_K_A1B0_NN +DGEMM_SAMLL_K_A1B0_NN = ../generic/gemm_small_matrix_kernel_a1b0_nn.c +endif + +ifndef DGEMM_SAMLL_K_A1B0_NT +DGEMM_SAMLL_K_A1B0_NT = ../generic/gemm_small_matrix_kernel_a1b0_nt.c +endif + +ifndef DGEMM_SAMLL_K_A1B0_TN +DGEMM_SAMLL_K_A1B0_TN = ../generic/gemm_small_matrix_kernel_a1b0_tn.c +endif + +ifndef DGEMM_SAMLL_K_A1B0_TT +DGEMM_SAMLL_K_A1B0_TT = ../generic/gemm_small_matrix_kernel_a1b0_tt.c +endif + +$(KDIR)dgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_A1B0_NN) + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)dgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_A1B0_NT) + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)dgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_A1B0_TN) + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)dgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SAMLL_K_A1B0_TT) + $(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ + ifndef SGEMM_SAMLL_K_NN SGEMM_SAMLL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c @@ -4310,3 +4342,31 @@ $(KDIR)sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_ $(KDIR)sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_TT) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + +ifndef SGEMM_SAMLL_K_A1B0_NN +SGEMM_SAMLL_K_A1B0_NN = ../generic/gemm_small_matrix_kernel_a1b0_nn.c +endif + +ifndef SGEMM_SAMLL_K_A1B0_NT +SGEMM_SAMLL_K_A1B0_NT = ../generic/gemm_small_matrix_kernel_a1b0_nt.c +endif + +ifndef SGEMM_SAMLL_K_A1B0_TN +SGEMM_SAMLL_K_A1B0_TN = ../generic/gemm_small_matrix_kernel_a1b0_tn.c +endif + +ifndef SGEMM_SAMLL_K_A1B0_TT +SGEMM_SAMLL_K_A1B0_TT = ../generic/gemm_small_matrix_kernel_a1b0_tt.c +endif + +$(KDIR)sgemm_small_kernel_a1b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_A1B0_NN) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)sgemm_small_kernel_a1b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_A1B0_NT) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)sgemm_small_kernel_a1b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_A1B0_TN) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)sgemm_small_kernel_a1b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SAMLL_K_A1B0_TT) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ diff --git a/kernel/generic/gemm_small_matrix_kernel_a1b0_nn.c b/kernel/generic/gemm_small_matrix_kernel_a1b0_nn.c new file mode 100644 index 000000000..8e3417027 --- /dev/null +++ b/kernel/generic/gemm_small_matrix_kernel_a1b0_nn.c @@ -0,0 +1,49 @@ +/*************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT * B, BLASLONG ldb,FLOAT * C, BLASLONG ldc) +{ + //naive implemtation + //Column major + + BLASLONG i,j,k; + FLOAT result=0.0; + + for(i=0; i