diff --git a/cblas.h b/cblas.h index 4bc5588d8..3c938553d 100644 --- a/cblas.h +++ b/cblas.h @@ -382,6 +382,17 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta, double *c, OPENBLAS_CONST blasint cldc); +void cblas_sgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST float ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST float ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); + +void cblas_dgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST double * alpha_array, OPENBLAS_CONST double ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST double ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST double * beta_array, double ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); + +void cblas_cgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); + +void cblas_zgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); #ifdef __cplusplus } diff --git a/common_level3.h b/common_level3.h index 1528e0134..70f7df2c5 100644 --- a/common_level3.h +++ b/common_level3.h @@ -1919,6 +1919,10 @@ int dgeadd_k(BLASLONG, BLASLONG, double, double*, BLASLONG, double, double *, BL int cgeadd_k(BLASLONG, BLASLONG, float, float, float*, BLASLONG, float, float, float *, BLASLONG); int zgeadd_k(BLASLONG, BLASLONG, double,double, double*, BLASLONG, double, double, double *, BLASLONG); +int sgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); +int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); +int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); +int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); #ifdef __CUDACC__ } diff --git a/common_macro.h b/common_macro.h index 1762650bb..3f6d31809 100644 --- a/common_macro.h +++ b/common_macro.h @@ -2636,7 +2636,17 @@ typedef struct { BLASLONG prea, preb, prec, pred; #endif + //for gemm_batch + void * routine; + int routine_mode; + } blas_arg_t; + +#ifdef SMALL_MATRIX_OPT +#define BLAS_SMALL_OPT 0x10000U +#define BLAS_SMALL_B0_OPT 0x30000U +#endif + #endif #ifdef XDOUBLE diff --git a/driver/level3/Makefile b/driver/level3/Makefile index 09a62d9bf..141b66e82 100644 --- a/driver/level3/Makefile +++ b/driver/level3/Makefile @@ -37,7 +37,7 @@ SBLASOBJS += \ ssyrk_UN.$(SUFFIX) ssyrk_UT.$(SUFFIX) ssyrk_LN.$(SUFFIX) ssyrk_LT.$(SUFFIX) \ ssyr2k_UN.$(SUFFIX) ssyr2k_UT.$(SUFFIX) ssyr2k_LN.$(SUFFIX) ssyr2k_LT.$(SUFFIX) \ ssyrk_kernel_U.$(SUFFIX) ssyrk_kernel_L.$(SUFFIX) \ - ssyr2k_kernel_U.$(SUFFIX) ssyr2k_kernel_L.$(SUFFIX) + ssyr2k_kernel_U.$(SUFFIX) ssyr2k_kernel_L.$(SUFFIX) sgemm_batch_thread.$(SUFFIX) DBLASOBJS += \ dgemm_nn.$(SUFFIX) dgemm_nt.$(SUFFIX) dgemm_tn.$(SUFFIX) dgemm_tt.$(SUFFIX) \ @@ -53,7 +53,7 @@ DBLASOBJS += \ dsyrk_UN.$(SUFFIX) dsyrk_UT.$(SUFFIX) dsyrk_LN.$(SUFFIX) dsyrk_LT.$(SUFFIX) \ dsyr2k_UN.$(SUFFIX) dsyr2k_UT.$(SUFFIX) dsyr2k_LN.$(SUFFIX) dsyr2k_LT.$(SUFFIX) \ dsyrk_kernel_U.$(SUFFIX) dsyrk_kernel_L.$(SUFFIX) \ - dsyr2k_kernel_U.$(SUFFIX) dsyr2k_kernel_L.$(SUFFIX) + dsyr2k_kernel_U.$(SUFFIX) dsyr2k_kernel_L.$(SUFFIX) dgemm_batch_thread.$(SUFFIX) QBLASOBJS += \ qgemm_nn.$(SUFFIX) qgemm_nt.$(SUFFIX) qgemm_tn.$(SUFFIX) qgemm_tt.$(SUFFIX) \ @@ -103,7 +103,7 @@ CBLASOBJS += \ cherk_kernel_LN.$(SUFFIX) cherk_kernel_LC.$(SUFFIX) \ csyr2k_kernel_U.$(SUFFIX) csyr2k_kernel_L.$(SUFFIX) \ cher2k_kernel_UN.$(SUFFIX) cher2k_kernel_UC.$(SUFFIX) \ - cher2k_kernel_LN.$(SUFFIX) cher2k_kernel_LC.$(SUFFIX) + cher2k_kernel_LN.$(SUFFIX) cher2k_kernel_LC.$(SUFFIX) cgemm_batch_thread.$(SUFFIX) ZBLASOBJS += \ zgemm_nn.$(SUFFIX) zgemm_cn.$(SUFFIX) zgemm_tn.$(SUFFIX) zgemm_nc.$(SUFFIX) \ @@ -137,7 +137,7 @@ ZBLASOBJS += \ zherk_kernel_LN.$(SUFFIX) zherk_kernel_LC.$(SUFFIX) \ zsyr2k_kernel_U.$(SUFFIX) zsyr2k_kernel_L.$(SUFFIX) \ zher2k_kernel_UN.$(SUFFIX) zher2k_kernel_UC.$(SUFFIX) \ - zher2k_kernel_LN.$(SUFFIX) zher2k_kernel_LC.$(SUFFIX) + zher2k_kernel_LN.$(SUFFIX) zher2k_kernel_LC.$(SUFFIX) zgemm_batch_thread.$(SUFFIX) XBLASOBJS += \ @@ -2888,6 +2888,18 @@ gemm_thread_variable.$(PSUFFIX) : gemm_thread_variable.c ../../common.h beta_thread.$(PSUFFIX) : beta_thread.c ../../common.h $(CC) -c $(PFLAGS) $< -o $(@F) +sgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h + $(CC) -c $(CFLAGS) $< -o $(@F) + +dgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h + $(CC) -c $(CFLAGS) $< -o $(@F) + +cgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h + $(CC) -c $(CFLAGS) $< -o $(@F) + +zgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h + $(CC) -c $(CFLAGS) $< -o $(@F) + shgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) diff --git a/driver/level3/gemm_batch_thread.c b/driver/level3/gemm_batch_thread.c new file mode 100644 index 000000000..8e0b77c75 --- /dev/null +++ b/driver/level3/gemm_batch_thread.c @@ -0,0 +1,151 @@ +/***************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + 3. Neither the name of the OpenBLAS project nor the names of + its contributors may be used to endorse or promote products + derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +**********************************************************************************/ + +#include "common.h" + +void openblas_warning(int verbose, const char * msg); + +#ifdef SMALL_MATRIX_OPT +static int inner_small_matrix_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){ + int routine_mode; +#ifndef COMPLEX + int (*gemm_small_kernel)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG); + int (*gemm_small_kernel_b0)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG); +#else + int (*zgemm_small_kernel)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG); + int (*zgemm_small_kernel_b0)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG); + FLOAT alpha[2], beta[2]; +#endif + routine_mode=args.routine_mode; + if(routine_mode & BLAS_SMALL_B0_OPT){ +#ifndef COMPLEX + gemm_small_kernel_b0=args.routine; + gemm_small_kernel_b0(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc); +#else + zgemm_small_kernel_b0=args.routine; + alpha[0]=(FLOAT *)(args.alpha)[0]; + alpha[1]=(FLOAT *)(args.alpha)[1]; + zgemm_small_kernel_b0(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, args.c, args.ldc); +#endif + }else if(routine_mode & BLAS_SMALL_OPT){ +#ifndef COMPLEX + gemm_small_kernel=args.routine; + gemm_small_kernel(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); +#else + zgemm_small_kernel=args.routine; + alpha[0]=(FLOAT *)(args.alpha)[0]; + alpha[1]=(FLOAT *)(args.alpha)[1]; + beta[0]=(FLOAT *)(args.beta)[0]; + beta[1]=(FLOAT *)(args.beta)[1]; + zgemm_small_kernel(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, beta[0], beta[1], args.c, args.ldc); +#endif + } +} +#endif + +int CNAME(blas_arg_t * args_array, BLASLONG nums){ + XFLOAT *buffer; + XFLOAT *sa, *sb; + int nthreads=1; + int (*routine)(blas_arg_t *, void *, void *, double *, double *, BLASLONG); + int i=0, j, current_nums; + +#ifdef SMP + blas_queue_t * queue=NULL; +#endif + + if(nums <=0 ) return 0; + + buffer = (XFLOAT *)blas_memory_alloc(0); + sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A); + sb = (XFLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); + +#ifdef SMP + nthreads=num_cpu_avail(3); + + if(nthreads==1){ +#endif + //single thread + for(i=0; inthreads)? nthreads: (nums-i); + + queue[i].sa=sa; + queue[i].sb=sb; + queue[i+current_nums-1].next=NULL; + + exec_blas(current_nums, &queue[i]); + } + free(queue); + } +#endif + blas_memory_free(buffer); + return 0; +} diff --git a/exports/gensymbol b/exports/gensymbol index 73b4be248..e8b992c51 100644 --- a/exports/gensymbol +++ b/exports/gensymbol @@ -81,6 +81,7 @@ cblas_ismin, cblas_idmin, cblas_icmin, cblas_izmin, cblas_ismax, cblas_idmax, cblas_icmax, cblas_izmax, cblas_ssum, cblas_dsum, cblas_scsum, cblas_dzsum, + cblas_sgemm_batch, cblas_dgemm_batch, cblas_cgemm_batch, cblas_zgemm_batch, cblas_xerbla ); diff --git a/interface/Makefile b/interface/Makefile index 2dbd60073..c10741865 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -278,7 +278,7 @@ CSBLAS2OBJS = \ CSBLAS3OBJS = \ cblas_sgemm.$(SUFFIX) cblas_ssymm.$(SUFFIX) cblas_strmm.$(SUFFIX) cblas_strsm.$(SUFFIX) \ cblas_ssyrk.$(SUFFIX) cblas_ssyr2k.$(SUFFIX) cblas_somatcopy.$(SUFFIX) cblas_simatcopy.$(SUFFIX)\ - cblas_sgeadd.$(SUFFIX) + cblas_sgeadd.$(SUFFIX) cblas_sgemm_batch.$(SUFFIX) ifeq ($(BUILD_HALF),1) CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX) @@ -300,7 +300,7 @@ CDBLAS2OBJS = \ CDBLAS3OBJS += \ cblas_dgemm.$(SUFFIX) cblas_dsymm.$(SUFFIX) cblas_dtrmm.$(SUFFIX) cblas_dtrsm.$(SUFFIX) \ cblas_dsyrk.$(SUFFIX) cblas_dsyr2k.$(SUFFIX) cblas_domatcopy.$(SUFFIX) cblas_dimatcopy.$(SUFFIX) \ - cblas_dgeadd.$(SUFFIX) + cblas_dgeadd.$(SUFFIX) cblas_dgemm_batch.$(SUFFIX) CCBLAS1OBJS = \ cblas_icamax.$(SUFFIX) cblas_icamin.$(SUFFIX) cblas_scasum.$(SUFFIX) cblas_caxpy.$(SUFFIX) \ @@ -325,7 +325,7 @@ CCBLAS3OBJS = \ cblas_csyrk.$(SUFFIX) cblas_csyr2k.$(SUFFIX) \ cblas_chemm.$(SUFFIX) cblas_cherk.$(SUFFIX) cblas_cher2k.$(SUFFIX) \ cblas_comatcopy.$(SUFFIX) cblas_cimatcopy.$(SUFFIX)\ - cblas_cgeadd.$(SUFFIX) cblas_xerbla.$(SUFFIX) + cblas_cgeadd.$(SUFFIX) cblas_xerbla.$(SUFFIX) cblas_cgemm_batch.$(SUFFIX) @@ -353,7 +353,7 @@ CZBLAS3OBJS = \ cblas_zsyrk.$(SUFFIX) cblas_zsyr2k.$(SUFFIX) \ cblas_zhemm.$(SUFFIX) cblas_zherk.$(SUFFIX) cblas_zher2k.$(SUFFIX)\ cblas_zomatcopy.$(SUFFIX) cblas_zimatcopy.$(SUFFIX) \ - cblas_zgeadd.$(SUFFIX) + cblas_zgeadd.$(SUFFIX) cblas_zgemm_batch.$(SUFFIX) ifeq ($(SUPPORT_GEMM3M), 1) @@ -2236,3 +2236,15 @@ cblas_zgeadd.$(SUFFIX) cblas_zgeadd.$(PSUFFIX) : zgeadd.c cblas_xerbla.$(SUFFIX) cblas_xerbla.$(PSUFFIX) : xerbla.c $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + +cblas_sgemm_batch.$(SUFFIX) cblas_sgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + +cblas_dgemm_batch.$(SUFFIX) cblas_dgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + +cblas_cgemm_batch.$(SUFFIX) cblas_cgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + +cblas_zgemm_batch.$(SUFFIX) cblas_zgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) diff --git a/interface/gemm_batch.c b/interface/gemm_batch.c new file mode 100644 index 000000000..048b5abbf --- /dev/null +++ b/interface/gemm_batch.c @@ -0,0 +1,358 @@ +/***************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + 3. Neither the name of the OpenBLAS project nor the names of + its contributors may be used to endorse or promote products + derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +**********************************************************************************/ + +#include +#include +#include "common.h" + +void openblas_warning(int verbose, const char * msg); + +#ifndef COMPLEX +#ifdef XDOUBLE +#define ERROR_NAME "QGEMM_BATCH " +#elif defined(DOUBLE) +#define ERROR_NAME "DGEMM_BATCH " +#define GEMM_BATCH_THREAD dgemm_batch_thread +#else +#define ERROR_NAME "SGEMM_BATCH " +#define GEMM_BATCH_THREAD sgemm_batch_thread +#endif +#else +#ifdef XDOUBLE +#define ERROR_NAME "XGEMM_BATCH " +#elif defined(DOUBLE) +#define ERROR_NAME "ZGEMM_BATCH " +#define GEMM_BATCH_THREAD zgemm_batch_thread +#else +#define ERROR_NAME "CGEMM_BATCH " +#define GEMM_BATCH_THREAD cgemm_batch_thread +#endif +#endif +static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = { + GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN, + GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT, + GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR, + GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC, +}; + +#ifdef SMALL_MATRIX_OPT + +#ifndef COMPLEX +static int (*gemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG) = { +#ifndef GEMM3M + GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, NULL, NULL, + GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, NULL, NULL, +#endif +}; + +static int (*gemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { +#ifndef GEMM3M + GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, NULL, NULL, + GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, NULL, NULL, +#endif +}; + +#else + +static int (*zgemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG) = { +#ifndef GEMM3M + GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN, + GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT, + GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR, + GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC, +#endif +}; + +static int (*zgemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = { +#ifndef GEMM3M + GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN, + GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT, + GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR, + GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC, +#endif +}; +#endif +#endif + +void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CBLAS_TRANSPOSE * transb_array, + blasint * m_array, blasint * n_array, blasint * k_array, +#ifndef COMPLEX + FLOAT * alpha_array, + FLOAT ** a_array, blasint * lda_array, + FLOAT ** b_array, blasint * ldb_array, + FLOAT * beta_array, + FLOAT ** c_array, blasint * ldc_array, blasint group_count, blasint * group_size) { +#else + void * valpha_array, + void ** va_array, blasint * lda_array, + void ** vb_array, blasint * ldb_array, + void * vbeta_array, + void ** vc_array, blasint * ldc_array, blasint group_count, blasint * group_size) { + + FLOAT * alpha_array=(FLOAT *)valpha_array; + FLOAT * beta_array=(FLOAT *)vbeta_array; + FLOAT ** a_array=(FLOAT**)va_array; + FLOAT ** b_array=(FLOAT**)vb_array; + FLOAT ** c_array=(FLOAT**)vc_array; + +#endif + blas_arg_t * args_array=NULL; + + int mode=0, group_mode=0; + blasint total_num=0; + + blasint i=0, j=0, matrix_idx=0, count=0; + + int group_transa, group_transb; + BLASLONG group_nrowa, group_nrowb; + blasint info; + + void * group_alpha, * group_beta; + BLASLONG group_m, group_n, group_k; + BLASLONG group_lda, group_ldb, group_ldc; + void * group_routine=NULL; +#ifdef SMALL_MATRIX_OPT + void * group_small_matrix_opt_routine=NULL; +#endif + +#if defined (SMP) || defined(SMALL_MATRIX_OPT) + double MNK; +#endif + + PRINT_DEBUG_CNAME; + + for(i=0; i= 0) { + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + free(args_array); + return; + } + + if (group_m == 0 || group_n == 0) continue; + + group_mode=mode; + +#if defined(SMP) || defined(SMALL_MATRIX_OPT) + MNK = (double) group_m * (double) group_n * (double) group_k; +#endif + +#ifdef SMALL_MATRIX_OPT + if(MNK <= 100.0*100.0*100.0){ + group_routine=NULL; +#if !defined(COMPLEX) + if(*(FLOAT *)(group_beta) == 0.0){ + group_mode=mode | BLAS_SMALL_B0_OPT; + group_small_matrix_opt_routine=(void *)(gemm_small_kernel_b0[(group_transb<<2)|group_transa]); + }else{ + group_mode=mode | BLAS_SMALL_OPT; + group_small_matrix_opt_routine=(void *)(gemm_small_kernel[(group_transb<<2)|group_transa]); + } +#else + if(((FLOAT *)(group_beta))[0] == 0.0 && ((FLOAT *)(group_beta))[1] == 0.0){ + group_mode=mode | BLAS_SMALL_B0_OPT; + group_small_matrix_opt_routine=(void *)(zgemm_small_kernel_b0[(group_transb<<2)|group_transa]); + }else{ + group_mode=mode | BLAS_SMALL_OPT; + group_small_matrix_opt_routine=(void *)(zgemm_small_kernel[(group_transb<<2)|group_transa]); + } + +#endif + + }else{ +#endif + group_routine=(void*)(gemm[(group_transb<<2)|group_transa]); +#ifdef SMALL_MATRIX_OPT + } +#endif + + + for(j=0; j0){ + GEMM_BATCH_THREAD(args_array,count); + } + + free(args_array); +}