From 103637887e35062aae8e795805c13ac9dfb6ca31 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Wed, 29 May 2024 15:46:10 +0200 Subject: [PATCH 1/9] add cblas_?gemm_batch --- cblas.h | 15 +++++++++++++++ common_level3.h | 7 ++++++- common_macro.h | 11 +++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/cblas.h b/cblas.h index ab8b7dcde..097b4303d 100644 --- a/cblas.h +++ b/cblas.h @@ -416,6 +416,18 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta, double *c, OPENBLAS_CONST blasint cldc); +void cblas_sgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST float ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST float ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); + +void cblas_dgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST double * alpha_array, OPENBLAS_CONST double ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST double ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST double * beta_array, double ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); + +void cblas_cgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); + +void cblas_zgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); + /*** BFLOAT16 and INT8 extensions ***/ /* convert float array to BFLOAT16 array by rounding */ void cblas_sbstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout); @@ -431,6 +443,9 @@ void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc); +void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); + #ifdef __cplusplus } #endif /* __cplusplus */ diff --git a/common_level3.h b/common_level3.h index 5080ada10..d370c1f96 100644 --- a/common_level3.h +++ b/common_level3.h @@ -1937,8 +1937,13 @@ int zimatcopy_k_rtc(BLASLONG, BLASLONG, double, double, double *, BLASLONG); int sgeadd_k(BLASLONG, BLASLONG, float, float*, BLASLONG, float, float *, BLASLONG); int dgeadd_k(BLASLONG, BLASLONG, double, double*, BLASLONG, double, double *, BLASLONG); int cgeadd_k(BLASLONG, BLASLONG, float, float, float*, BLASLONG, float, float, float *, BLASLONG); -int zgeadd_k(BLASLONG, BLASLONG, double,double, double*, BLASLONG, double, double, double *, BLASLONG); +int zgeadd_k(BLASLONG, BLASLONG, double,double, double*, BLASLONG, double, double, double *, BLASLONG); +int sgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); +int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); +int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); +int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); +int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); #ifdef __CUDACC__ } diff --git a/common_macro.h b/common_macro.h index 3226d0f11..a924651de 100644 --- a/common_macro.h +++ b/common_macro.h @@ -2655,9 +2655,20 @@ typedef struct { BLASLONG prea, preb, prec, pred; #endif + + //for gemm_batch + void * routine; + int routine_mode; + } blas_arg_t; #endif +#ifdef SMALL_MATRIX_OPT +#define BLAS_SMALL_OPT 0x10000U +#define BLAS_SMALL_B0_OPT 0x30000U +#endif + + #ifdef XDOUBLE #define TRSV_NUU qtrsv_NUU From 89c7bbcba65cc4dc639bf5030910b307d2861df8 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Wed, 29 May 2024 15:47:02 +0200 Subject: [PATCH 2/9] add cblas_?gemm_batch --- interface/Makefile | 24 ++- interface/gemm_batch.c | 372 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 391 insertions(+), 5 deletions(-) create mode 100644 interface/gemm_batch.c diff --git a/interface/Makefile b/interface/Makefile index 048d679d6..97439d87f 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -282,12 +282,12 @@ CSBLAS2OBJS = \ CSBLAS3OBJS = \ cblas_sgemm.$(SUFFIX) cblas_ssymm.$(SUFFIX) cblas_strmm.$(SUFFIX) cblas_strsm.$(SUFFIX) \ cblas_ssyrk.$(SUFFIX) cblas_ssyr2k.$(SUFFIX) cblas_somatcopy.$(SUFFIX) cblas_simatcopy.$(SUFFIX)\ - cblas_sgeadd.$(SUFFIX) cblas_sgemmt.$(SUFFIX) + cblas_sgeadd.$(SUFFIX) cblas_sgemmt.$(SUFFIX) cblas_sgemm_batch.$(SUFFIX) ifeq ($(BUILD_BFLOAT16),1) CSBBLAS1OBJS = cblas_sbdot.$(SUFFIX) CSBBLAS2OBJS = cblas_sbgemv.$(SUFFIX) -CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) +CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemm_batch.$(SUFFIX) CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) endif @@ -308,7 +308,7 @@ CDBLAS2OBJS = \ CDBLAS3OBJS += \ cblas_dgemm.$(SUFFIX) cblas_dsymm.$(SUFFIX) cblas_dtrmm.$(SUFFIX) cblas_dtrsm.$(SUFFIX) \ cblas_dsyrk.$(SUFFIX) cblas_dsyr2k.$(SUFFIX) cblas_domatcopy.$(SUFFIX) cblas_dimatcopy.$(SUFFIX) \ - cblas_dgeadd.$(SUFFIX) cblas_dgemmt.$(SUFFIX) + cblas_dgeadd.$(SUFFIX) cblas_dgemmt.$(SUFFIX) cblas_dgemm_batch.$(SUFFIX) CCBLAS1OBJS = \ cblas_icamax.$(SUFFIX) cblas_icamin.$(SUFFIX) cblas_scasum.$(SUFFIX) cblas_caxpy.$(SUFFIX) \ @@ -333,7 +333,7 @@ CCBLAS3OBJS = \ cblas_csyrk.$(SUFFIX) cblas_csyr2k.$(SUFFIX) \ cblas_chemm.$(SUFFIX) cblas_cherk.$(SUFFIX) cblas_cher2k.$(SUFFIX) \ cblas_comatcopy.$(SUFFIX) cblas_cimatcopy.$(SUFFIX)\ - cblas_cgeadd.$(SUFFIX) cblas_cgemmt.$(SUFFIX) + cblas_cgeadd.$(SUFFIX) cblas_cgemmt.$(SUFFIX) cblas_cgemm_batch.$(SUFFIX) CXERBLAOBJ = \ cblas_xerbla.$(SUFFIX) @@ -364,7 +364,7 @@ CZBLAS3OBJS = \ cblas_zsyrk.$(SUFFIX) cblas_zsyr2k.$(SUFFIX) \ cblas_zhemm.$(SUFFIX) cblas_zherk.$(SUFFIX) cblas_zher2k.$(SUFFIX)\ cblas_zomatcopy.$(SUFFIX) cblas_zimatcopy.$(SUFFIX) \ - cblas_zgeadd.$(SUFFIX) cblas_zgemmt.$(SUFFIX) + cblas_zgeadd.$(SUFFIX) cblas_zgemmt.$(SUFFIX) cblas_zgemm_batch.$(SUFFIX) ifeq ($(SUPPORT_GEMM3M), 1) @@ -2419,3 +2419,17 @@ cblas_zgeadd.$(SUFFIX) cblas_zgeadd.$(PSUFFIX) : zgeadd.c cblas_xerbla.$(SUFFIX) cblas_xerbla.$(PSUFFIX) : xerbla.c $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) +cblas_sbgemm_batch.$(SUFFIX) cblas_sbgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + +cblas_sgemm_batch.$(SUFFIX) cblas_sgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + +cblas_dgemm_batch.$(SUFFIX) cblas_dgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + +cblas_cgemm_batch.$(SUFFIX) cblas_cgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + +cblas_zgemm_batch.$(SUFFIX) cblas_zgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) diff --git a/interface/gemm_batch.c b/interface/gemm_batch.c new file mode 100644 index 000000000..d674f9615 --- /dev/null +++ b/interface/gemm_batch.c @@ -0,0 +1,372 @@ +/***************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + 3. Neither the name of the OpenBLAS project nor the names of + its contributors may be used to endorse or promote products + derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +**********************************************************************************/ + +#include +#include +#include "common.h" + +void openblas_warning(int verbose, const char * msg); + +#ifndef COMPLEX +#ifdef XDOUBLE +#define ERROR_NAME "QGEMM_BATCH " +#elif defined(DOUBLE) +#define ERROR_NAME "DGEMM_BATCH " +#define GEMM_BATCH_THREAD dgemm_batch_thread +#else +#define ERROR_NAME "SGEMM_BATCH " +#define GEMM_BATCH_THREAD sgemm_batch_thread +#endif +#else +#ifdef XDOUBLE +#define ERROR_NAME "XGEMM_BATCH " +#elif defined(DOUBLE) +#define ERROR_NAME "ZGEMM_BATCH " +#define GEMM_BATCH_THREAD zgemm_batch_thread +#else +#define ERROR_NAME "CGEMM_BATCH " +#define GEMM_BATCH_THREAD cgemm_batch_thread +#endif +#endif +static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = { + GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN, + GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT, + GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR, + GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC, +}; + +#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE) +#define USE_SMALL_MATRIX_OPT 1 +#else +#define USE_SMALL_MATRIX_OPT 0 +#endif + +#if USE_SMALL_MATRIX_OPT +#ifndef DYNAMIC_ARCH +#define SMALL_KERNEL_ADDR(table, idx) ((void *)(table[idx])) +#else +#define SMALL_KERNEL_ADDR(table, idx) ((void *)(*(uintptr_t *)((char *)gotoblas + (size_t)(table[idx])))) +#endif + + +#ifndef COMPLEX +static size_t gemm_small_kernel[] = { + GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, 0, 0, + GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, 0, 0, +}; + + +static size_t gemm_small_kernel_b0[] = { + GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, 0, 0, + GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, 0, 0, +}; + +#define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx)) +#define GEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx)) +#else + +static size_t zgemm_small_kernel[] = { + GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN, + GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT, + GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR, + GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC, +}; + +static size_t zgemm_small_kernel_b0[] = { + GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN, + GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT, + GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR, + GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC, +}; + +#define ZGEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel, (idx)) +#define ZGEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel_b0, (idx)) +#endif +#endif + +void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CBLAS_TRANSPOSE * transb_array, + blasint * m_array, blasint * n_array, blasint * k_array, +#ifndef COMPLEX + FLOAT * alpha_array, + FLOAT ** a_array, blasint * lda_array, + FLOAT ** b_array, blasint * ldb_array, + FLOAT * beta_array, + FLOAT ** c_array, blasint * ldc_array, blasint group_count, blasint * group_size) { +#else + void * valpha_array, + void ** va_array, blasint * lda_array, + void ** vb_array, blasint * ldb_array, + void * vbeta_array, + void ** vc_array, blasint * ldc_array, blasint group_count, blasint * group_size) { + + FLOAT * alpha_array=(FLOAT *)valpha_array; + FLOAT * beta_array=(FLOAT *)vbeta_array; + FLOAT ** a_array=(FLOAT**)va_array; + FLOAT ** b_array=(FLOAT**)vb_array; + FLOAT ** c_array=(FLOAT**)vc_array; + +#endif + blas_arg_t * args_array=NULL; + + int mode=0, group_mode=0; + blasint total_num=0; + + blasint i=0, j=0, matrix_idx=0, count=0; + + int group_transa, group_transb; + BLASLONG group_nrowa, group_nrowb; + blasint info; + + void * group_alpha, * group_beta; + BLASLONG group_m, group_n, group_k; + BLASLONG group_lda, group_ldb, group_ldc; + void * group_routine=NULL; +#ifdef SMALL_MATRIX_OPT + void * group_small_matrix_opt_routine=NULL; +#endif + +#if defined (SMP) || defined(SMALL_MATRIX_OPT) + double MNK; +#endif + + PRINT_DEBUG_CNAME; + + for(i=0; i= 0) { + BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); + free(args_array); + return; + } + + if (group_m == 0 || group_n == 0) continue; + + group_mode=mode; + +#if defined(SMP) || defined(SMALL_MATRIX_OPT) + MNK = (double) group_m * (double) group_n * (double) group_k; +#endif + +#ifdef SMALL_MATRIX_OPT + if (MNK <= 100.0*100.0*100.0){ + group_routine=NULL; +#if !defined(COMPLEX) + if(*(FLOAT *)(group_beta) == 0.0){ + group_mode=mode | BLAS_SMALL_B0_OPT; + group_small_matrix_opt_routine=(void *)(gemm_small_kernel_b0[(group_transb<<2)|group_transa]); + }else{ + group_mode=mode | BLAS_SMALL_OPT; + group_small_matrix_opt_routine=(void *)(gemm_small_kernel[(group_transb<<2)|group_transa]); + } +#else + if(((FLOAT *)(group_beta))[0] == 0.0 && ((FLOAT *)(group_beta))[1] == 0.0){ + group_mode=mode | BLAS_SMALL_B0_OPT; + group_small_matrix_opt_routine=(void *)(zgemm_small_kernel_b0[(group_transb<<2)|group_transa]); + }else{ + group_mode=mode | BLAS_SMALL_OPT; + group_small_matrix_opt_routine=(void *)(zgemm_small_kernel[(group_transb<<2)|group_transa]); + } + +#endif + + }else{ +#endif + group_routine=(void*)(gemm[(group_transb<<2)|group_transa]); +#ifdef SMALL_MATRIX_OPT + } +#endif + + + for(j=0; j0){ + GEMM_BATCH_THREAD(args_array,count); + } + + free(args_array); +} From 833a8880c608cad1d6e0b515da41611612be71e9 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Wed, 29 May 2024 15:47:50 +0200 Subject: [PATCH 3/9] add cblas_?gemm_batch --- exports/gensymbol | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/exports/gensymbol b/exports/gensymbol index 226035842..28dd883f2 100755 --- a/exports/gensymbol +++ b/exports/gensymbol @@ -60,7 +60,7 @@ cblasobjsc=" cblas_ctbsv cblas_ctpmv cblas_ctpsv cblas_ctrmm cblas_ctrmv cblas_ctrsm cblas_ctrsv cblas_scnrm2 cblas_scasum cblas_cgemmt cblas_icamax cblas_icamin cblas_icmin cblas_icmax cblas_scsum cblas_cimatcopy cblas_comatcopy - cblas_caxpyc cblas_crotg cblas_csrot cblas_scamax cblas_scamin + cblas_caxpyc cblas_crotg cblas_csrot cblas_scamax cblas_scamin cblas_cgemm_batch " cblasobjsd=" cblas_dasum cblas_daxpy cblas_dcopy cblas_ddot @@ -70,7 +70,7 @@ cblasobjsd=" cblas_dsyr2k cblas_dsyr cblas_dsyrk cblas_dtbmv cblas_dtbsv cblas_dtpmv cblas_dtpsv cblas_dtrmm cblas_dtrmv cblas_dtrsm cblas_dtrsv cblas_daxpby cblas_dgeadd cblas_dgemmt cblas_idamax cblas_idamin cblas_idmin cblas_idmax cblas_dsum cblas_dimatcopy cblas_domatcopy - cblas_damax cblas_damin + cblas_damax cblas_damin cblas_dgemm_batch " cblasobjss=" @@ -82,7 +82,7 @@ cblasobjss=" cblas_stbmv cblas_stbsv cblas_stpmv cblas_stpsv cblas_strmm cblas_strmv cblas_strsm cblas_strsv cblas_sgeadd cblas_sgemmt cblas_isamax cblas_isamin cblas_ismin cblas_ismax cblas_ssum cblas_simatcopy cblas_somatcopy - cblas_samax cblas_samin + cblas_samax cblas_samin cblas_sgemm_batch " cblasobjsz=" @@ -94,12 +94,12 @@ cblasobjsz=" cblas_ztrsv cblas_cdotc_sub cblas_cdotu_sub cblas_zdotc_sub cblas_zdotu_sub cblas_zaxpby cblas_zgeadd cblas_zgemmt cblas_izamax cblas_izamin cblas_izmin cblas_izmax cblas_dzsum cblas_zimatcopy cblas_zomatcopy - cblas_zaxpyc cblas_zdrot cblas_zrotg cblas_dzamax cblas_dzamin + cblas_zaxpyc cblas_zdrot cblas_zrotg cblas_dzamax cblas_dzamin cblas_zgemm_batch " cblasobjs="cblas_xerbla" -bfcblasobjs="cblas_sbgemm cblas_sbgemv cblas_sbdot cblas_sbstobf16 cblas_sbdtobf16 cblas_sbf16tos cblas_dbf16tod" +bfcblasobjs="cblas_sbgemm cblas_sbgemv cblas_sbdot cblas_sbstobf16 cblas_sbdtobf16 cblas_sbf16tos cblas_dbf16tod cblas_sbgemm_batch" exblasobjs=" qamax qamin qasum qaxpy qcabs1 qcopy qdot qgbmv qgemm From d0794f88dcb83fd1a3a9fd317124e9ec984b6255 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Wed, 29 May 2024 15:49:20 +0200 Subject: [PATCH 4/9] add gemm_batch driver --- driver/level3/Makefile | 23 ++++- driver/level3/gemm_batch_thread.c | 156 ++++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+), 4 deletions(-) create mode 100644 driver/level3/gemm_batch_thread.c diff --git a/driver/level3/Makefile b/driver/level3/Makefile index b8465d4ed..c30483842 100644 --- a/driver/level3/Makefile +++ b/driver/level3/Makefile @@ -37,7 +37,7 @@ SBLASOBJS += \ ssyrk_UN.$(SUFFIX) ssyrk_UT.$(SUFFIX) ssyrk_LN.$(SUFFIX) ssyrk_LT.$(SUFFIX) \ ssyr2k_UN.$(SUFFIX) ssyr2k_UT.$(SUFFIX) ssyr2k_LN.$(SUFFIX) ssyr2k_LT.$(SUFFIX) \ ssyrk_kernel_U.$(SUFFIX) ssyrk_kernel_L.$(SUFFIX) \ - ssyr2k_kernel_U.$(SUFFIX) ssyr2k_kernel_L.$(SUFFIX) + ssyr2k_kernel_U.$(SUFFIX) ssyr2k_kernel_L.$(SUFFIX) sgemm_batch_thread.$(SUFFIX) DBLASOBJS += \ dgemm_nn.$(SUFFIX) dgemm_nt.$(SUFFIX) dgemm_tn.$(SUFFIX) dgemm_tt.$(SUFFIX) \ @@ -53,7 +53,7 @@ DBLASOBJS += \ dsyrk_UN.$(SUFFIX) dsyrk_UT.$(SUFFIX) dsyrk_LN.$(SUFFIX) dsyrk_LT.$(SUFFIX) \ dsyr2k_UN.$(SUFFIX) dsyr2k_UT.$(SUFFIX) dsyr2k_LN.$(SUFFIX) dsyr2k_LT.$(SUFFIX) \ dsyrk_kernel_U.$(SUFFIX) dsyrk_kernel_L.$(SUFFIX) \ - dsyr2k_kernel_U.$(SUFFIX) dsyr2k_kernel_L.$(SUFFIX) + dsyr2k_kernel_U.$(SUFFIX) dsyr2k_kernel_L.$(SUFFIX) dgemm_batch_thread.$(SUFFIX) QBLASOBJS += \ qgemm_nn.$(SUFFIX) qgemm_nt.$(SUFFIX) qgemm_tn.$(SUFFIX) qgemm_tt.$(SUFFIX) \ @@ -103,7 +103,7 @@ CBLASOBJS += \ cherk_kernel_LN.$(SUFFIX) cherk_kernel_LC.$(SUFFIX) \ csyr2k_kernel_U.$(SUFFIX) csyr2k_kernel_L.$(SUFFIX) \ cher2k_kernel_UN.$(SUFFIX) cher2k_kernel_UC.$(SUFFIX) \ - cher2k_kernel_LN.$(SUFFIX) cher2k_kernel_LC.$(SUFFIX) + cher2k_kernel_LN.$(SUFFIX) cher2k_kernel_LC.$(SUFFIX) cgemm_batch_thread.$(SUFFIX) ZBLASOBJS += \ zgemm_nn.$(SUFFIX) zgemm_cn.$(SUFFIX) zgemm_tn.$(SUFFIX) zgemm_nc.$(SUFFIX) \ @@ -137,7 +137,7 @@ ZBLASOBJS += \ zherk_kernel_LN.$(SUFFIX) zherk_kernel_LC.$(SUFFIX) \ zsyr2k_kernel_U.$(SUFFIX) zsyr2k_kernel_L.$(SUFFIX) \ zher2k_kernel_UN.$(SUFFIX) zher2k_kernel_UC.$(SUFFIX) \ - zher2k_kernel_LN.$(SUFFIX) zher2k_kernel_LC.$(SUFFIX) + zher2k_kernel_LN.$(SUFFIX) zher2k_kernel_LC.$(SUFFIX) zgemm_batch_thread.$(SUFFIX) XBLASOBJS += \ @@ -2942,6 +2942,21 @@ gemm_thread_variable.$(PSUFFIX) : gemm_thread_variable.c ../../common.h beta_thread.$(PSUFFIX) : beta_thread.c ../../common.h $(CC) -c $(PFLAGS) $< -o $(@F) +sbgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h + $(CC) -c $(CFLAGS) $< -o $(@F) + +sgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h + $(CC) -c $(CFLAGS) $< -o $(@F) + +dgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h + $(CC) -c $(CFLAGS) $< -o $(@F) + +cgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h + $(CC) -c $(CFLAGS) $< -o $(@F) + +zgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h + $(CC) -c $(CFLAGS) $< -o $(@F) + sbgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) diff --git a/driver/level3/gemm_batch_thread.c b/driver/level3/gemm_batch_thread.c new file mode 100644 index 000000000..45d6977ba --- /dev/null +++ b/driver/level3/gemm_batch_thread.c @@ -0,0 +1,156 @@ +/***************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + 3. Neither the name of the OpenBLAS project nor the names of + its contributors may be used to endorse or promote products + derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +**********************************************************************************/ + +#include "common.h" + +void openblas_warning(int verbose, const char * msg); + +#ifdef SMALL_MATRIX_OPT +static int inner_small_matrix_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){ + int routine_mode; +#ifndef COMPLEX + int (*gemm_small_kernel)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG); + int (*gemm_small_kernel_b0)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG); +#else + int (*zgemm_small_kernel)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG); + int (*zgemm_small_kernel_b0)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG); + FLOAT alpha[2], beta[2]; +#endif + routine_mode=args->routine_mode; + if((routine_mode & BLAS_SMALL_B0_OPT) == BLAS_SMALL_B0_OPT){ +#ifndef COMPLEX + gemm_small_kernel_b0=args->routine; + gemm_small_kernel_b0(args->m, args->n, args->k, args->a, args->lda, *(FLOAT *)(args->alpha), args->b, args->ldb, args->c, args->ldc); +#else + zgemm_small_kernel_b0=args->routine; + alpha[0] = *((FLOAT *)args -> alpha + 0); + alpha[1] = *((FLOAT *)args -> alpha + 1); + zgemm_small_kernel_b0(args->m, args->n, args->k, args->a, args->lda, alpha[0], alpha[1], args->b, args->ldb, args->c, args->ldc); +#endif + return(0); + }else if(routine_mode & BLAS_SMALL_OPT){ +#ifndef COMPLEX + gemm_small_kernel=args->routine; + gemm_small_kernel(args->m, args->n, args->k, args->a, args->lda, *(FLOAT *)(args->alpha), args->b, args->ldb, *(FLOAT *)(args->beta), args->c, args->ldc); +#else + zgemm_small_kernel=args->routine; + alpha[0] = *((FLOAT *)args -> alpha + 0); + alpha[1] = *((FLOAT *)args -> alpha + 1); + beta[0] = *((FLOAT *)args -> beta + 0); + beta[1] = *((FLOAT *)args -> beta + 1); + zgemm_small_kernel(args->m, args->n, args->k, args->a, args->lda, alpha[0], alpha[1], args->b, args->ldb, beta[0], beta[1], args->c, args->ldc); +#endif + return(0); + } + return(1); +} +#endif + +int CNAME(blas_arg_t * args_array, BLASLONG nums){ + XFLOAT *buffer; + XFLOAT *sa, *sb; + int nthreads=1; + int (*routine)(blas_arg_t *, void *, void *, XFLOAT *, XFLOAT *, BLASLONG); + int i=0, /*j,*/ current_nums; + +#ifdef SMP + blas_queue_t * queue=NULL; +#endif + + if(nums <=0 ) return 0; + + buffer = (XFLOAT *)blas_memory_alloc(0); + sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A); + sb = (XFLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); + +#ifdef SMP + nthreads=num_cpu_avail(3); + + if(nthreads==1){ + +#endif + //single thread + for(i=0; inthreads)? nthreads: (nums-i); + + queue[i].sa=sa; + queue[i].sb=sb; + queue[i+current_nums-1].next=NULL; + + exec_blas(current_nums, &queue[i]); + } + free(queue); + } +#endif + blas_memory_free(buffer); + return 0; +} From 362a06339677f1018690ab409f6a8cde664de347 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Wed, 29 May 2024 23:16:58 +0200 Subject: [PATCH 5/9] remove return value --- interface/gemm_batch.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/interface/gemm_batch.c b/interface/gemm_batch.c index d674f9615..846d8e0f4 100644 --- a/interface/gemm_batch.c +++ b/interface/gemm_batch.c @@ -169,7 +169,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB if(args_array == NULL){ openblas_warning(0, "memory alloc failed!\n"); - return(1); + return; } #ifdef SMP From ff6670cb83ebfb12c199576ba1d2f459b0c9d750 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Thu, 30 May 2024 18:26:02 +0200 Subject: [PATCH 6/9] don't generate non-cblas files for gemm_batch --- interface/CMakeLists.txt | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/interface/CMakeLists.txt b/interface/CMakeLists.txt index 55374674a..01e67de6f 100644 --- a/interface/CMakeLists.txt +++ b/interface/CMakeLists.txt @@ -97,6 +97,9 @@ foreach (CBLAS_FLAG ${CBLAS_FLAGS}) #sdsdot, dsdot if (BUILD_SINGLE OR BUILD_DOUBLE) GenerateNamedObjects("sdsdot.c" "" "sdsdot" ${CBLAS_FLAG} "" "" true "SINGLE") + if(CBLAS_FLAG EQUAL 1) + GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" ${CBLAS_FLAG} "" "" true) +endif () endif () if (BUILD_DOUBLE) GenerateNamedObjects("dsdot.c" "" "dsdot" ${CBLAS_FLAG} "" "" true "SINGLE") @@ -125,6 +128,9 @@ if (BUILD_BFLOAT16) GenerateNamedObjects("tobf16.c" "DOUBLE_PREC" "sbdtobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("bf16to.c" "SINGLE_PREC" "sbf16tos" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("bf16to.c" "DOUBLE_PREC" "dbf16tod" ${CBLAS_FLAG} "" "" true "BFLOAT16") + if(CBLAS_FLAG EQUAL 1) + GenerateNamedObjects("gemm_batch.c" "" "sbgemm_batch" ${CBLAS_FLAG} "" "" true "BFLOAT16") +endif () endif () # complex-specific sources @@ -154,6 +160,9 @@ foreach (float_type ${FLOAT_TYPES}) GenerateNamedObjects("max.c" "USE_ABS" "scamax" ${CBLAS_FLAG} "" "" true "COMPLEX") GenerateNamedObjects("asum.c" "" "scasum" ${CBLAS_FLAG} "" "" true "COMPLEX") GenerateNamedObjects("sum.c" "" "scsum" ${CBLAS_FLAG} "" "" true "COMPLEX") + if(CBLAS_FLAG EQUAL 1) + GenerateNamedObjects("gemm_batch.c" "" "cgemm_batch" ${CBLAS_FLAG} "" "" true "COMPLEX") + endif () endif () if (${float_type} STREQUAL "ZCOMPLEX") GenerateNamedObjects("zscal.c" "SSCAL" "dscal" ${CBLAS_FLAG} "" "" false "ZCOMPLEX") @@ -163,6 +172,9 @@ foreach (float_type ${FLOAT_TYPES}) GenerateNamedObjects("max.c" "USE_ABS" "dzamax" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") GenerateNamedObjects("asum.c" "" "dzasum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") GenerateNamedObjects("sum.c" "" "dzsum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") + if(CBLAS_FLAG EQUAL 1) + GenerateNamedObjects("gemm_batch.c" "" "zgemm_batch" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") + endif () endif () endforeach () @@ -212,6 +224,7 @@ if ( BUILD_COMPLEX AND NOT BUILD_SINGLE) GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "SINGLE") GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "SINGLE") GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "SINGLE") + GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 1 "" "" false "SINGLE") GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "SINGLE") GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "SINGLE") GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "SINGLE") @@ -225,6 +238,7 @@ if ( BUILD_COMPLEX16 AND NOT BUILD_DOUBLE) GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "DOUBLE") GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "DOUBLE") GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "DOUBLE") + GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 1 "" "" false "DOUBLE") GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "DOUBLE") GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "DOUBLE") GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "DOUBLE") From 0d007adb18234289562b8556b5a0963125adda5a Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Thu, 30 May 2024 23:30:16 +0200 Subject: [PATCH 7/9] fix clang_cl-flang job to use flang-new after the llvm update --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 3ae8615a7..7f82a461d 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -133,7 +133,7 @@ jobs: mkdir build cd build call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" - cmake -G "Ninja" -DCMAKE_C_COMPILER=cl -DCMAKE_Fortran_COMPILER=flang -DC_LAPACK=1 -DCMAKE_MT=mt -DCMAKE_BUILD_TYPE=Release -DMSVC_STATIC_CRT=ON .. + cmake -G "Ninja" -DCMAKE_C_COMPILER=cl -DCMAKE_Fortran_COMPILER=flang-new -DC_LAPACK=1 -DCMAKE_MT=mt -DCMAKE_BUILD_TYPE=Release -DMSVC_STATIC_CRT=ON .. cmake --build . --config Release ctest From 076766df4e3a147da3e57ffec3f2fcba106c78b1 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Fri, 31 May 2024 18:23:18 +0200 Subject: [PATCH 8/9] Update CMakeLists.txt --- interface/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/interface/CMakeLists.txt b/interface/CMakeLists.txt index 01e67de6f..05b893b8c 100644 --- a/interface/CMakeLists.txt +++ b/interface/CMakeLists.txt @@ -98,7 +98,7 @@ foreach (CBLAS_FLAG ${CBLAS_FLAGS}) if (BUILD_SINGLE OR BUILD_DOUBLE) GenerateNamedObjects("sdsdot.c" "" "sdsdot" ${CBLAS_FLAG} "" "" true "SINGLE") if(CBLAS_FLAG EQUAL 1) - GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" ${CBLAS_FLAG} "" "" true) + GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" ${CBLAS_FLAG} "" "" false) endif () endif () if (BUILD_DOUBLE) From db070a92238fd451ebb66561d72c2ce2e1524b58 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Fri, 31 May 2024 18:29:27 +0200 Subject: [PATCH 9/9] add gemm_batch drivers --- driver/level3/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/driver/level3/CMakeLists.txt b/driver/level3/CMakeLists.txt index 75b25d039..b1ec94c23 100644 --- a/driver/level3/CMakeLists.txt +++ b/driver/level3/CMakeLists.txt @@ -68,6 +68,8 @@ if (USE_THREAD) endif () foreach (float_type ${FLOAT_TYPES}) + GenerateNamedObjects("gemm_batch_thread.c" "" "gemm_batch_thread" 0 "" "" false ${float_type}) + if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX") GenerateCombinationObjects("zherk_kernel.c" "LOWER;CONJ" "U;N" "HERK" 2 "herk_kernel" false ${float_type}) # TRANS needs to be set/unset when CONJ is set/unset, so can't use it as a combination