Merge pull request #4720 from martin-frbg/issue3039

Resurrect and complete cblas_?gemm_batch
This commit is contained in:
Martin Kroeker 2024-06-01 00:34:32 +02:00 committed by GitHub
commit 56bd57ca99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 620 additions and 16 deletions

View File

@ -133,7 +133,7 @@ jobs:
mkdir build mkdir build
cd build cd build
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
cmake -G "Ninja" -DCMAKE_C_COMPILER=cl -DCMAKE_Fortran_COMPILER=flang -DC_LAPACK=1 -DCMAKE_MT=mt -DCMAKE_BUILD_TYPE=Release -DMSVC_STATIC_CRT=ON .. cmake -G "Ninja" -DCMAKE_C_COMPILER=cl -DCMAKE_Fortran_COMPILER=flang-new -DC_LAPACK=1 -DCMAKE_MT=mt -DCMAKE_BUILD_TYPE=Release -DMSVC_STATIC_CRT=ON ..
cmake --build . --config Release cmake --build . --config Release
ctest ctest

15
cblas.h
View File

@ -416,6 +416,18 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint
void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta, void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta,
double *c, OPENBLAS_CONST blasint cldc); double *c, OPENBLAS_CONST blasint cldc);
void cblas_sgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST float ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST float ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);
void cblas_dgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST double * alpha_array, OPENBLAS_CONST double ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST double ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST double * beta_array, double ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);
void cblas_cgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);
void cblas_zgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);
/*** BFLOAT16 and INT8 extensions ***/ /*** BFLOAT16 and INT8 extensions ***/
/* convert float array to BFLOAT16 array by rounding */ /* convert float array to BFLOAT16 array by rounding */
void cblas_sbstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout); void cblas_sbstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout);
@ -431,6 +443,9 @@ void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum
void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc); OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif /* __cplusplus */ #endif /* __cplusplus */

View File

@ -1937,8 +1937,13 @@ int zimatcopy_k_rtc(BLASLONG, BLASLONG, double, double, double *, BLASLONG);
int sgeadd_k(BLASLONG, BLASLONG, float, float*, BLASLONG, float, float *, BLASLONG); int sgeadd_k(BLASLONG, BLASLONG, float, float*, BLASLONG, float, float *, BLASLONG);
int dgeadd_k(BLASLONG, BLASLONG, double, double*, BLASLONG, double, double *, BLASLONG); int dgeadd_k(BLASLONG, BLASLONG, double, double*, BLASLONG, double, double *, BLASLONG);
int cgeadd_k(BLASLONG, BLASLONG, float, float, float*, BLASLONG, float, float, float *, BLASLONG); int cgeadd_k(BLASLONG, BLASLONG, float, float, float*, BLASLONG, float, float, float *, BLASLONG);
int zgeadd_k(BLASLONG, BLASLONG, double,double, double*, BLASLONG, double, double, double *, BLASLONG); int zgeadd_k(BLASLONG, BLASLONG, double,double, double*, BLASLONG, double, double, double *, BLASLONG);
int sgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
#ifdef __CUDACC__ #ifdef __CUDACC__
} }

View File

@ -2655,9 +2655,20 @@ typedef struct {
BLASLONG prea, preb, prec, pred; BLASLONG prea, preb, prec, pred;
#endif #endif
//for gemm_batch
void * routine;
int routine_mode;
} blas_arg_t; } blas_arg_t;
#endif #endif
#ifdef SMALL_MATRIX_OPT
#define BLAS_SMALL_OPT 0x10000U
#define BLAS_SMALL_B0_OPT 0x30000U
#endif
#ifdef XDOUBLE #ifdef XDOUBLE
#define TRSV_NUU qtrsv_NUU #define TRSV_NUU qtrsv_NUU

View File

@ -68,6 +68,8 @@ if (USE_THREAD)
endif () endif ()
foreach (float_type ${FLOAT_TYPES}) foreach (float_type ${FLOAT_TYPES})
GenerateNamedObjects("gemm_batch_thread.c" "" "gemm_batch_thread" 0 "" "" false ${float_type})
if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX") if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX")
GenerateCombinationObjects("zherk_kernel.c" "LOWER;CONJ" "U;N" "HERK" 2 "herk_kernel" false ${float_type}) GenerateCombinationObjects("zherk_kernel.c" "LOWER;CONJ" "U;N" "HERK" 2 "herk_kernel" false ${float_type})
# TRANS needs to be set/unset when CONJ is set/unset, so can't use it as a combination # TRANS needs to be set/unset when CONJ is set/unset, so can't use it as a combination

View File

@ -37,7 +37,7 @@ SBLASOBJS += \
ssyrk_UN.$(SUFFIX) ssyrk_UT.$(SUFFIX) ssyrk_LN.$(SUFFIX) ssyrk_LT.$(SUFFIX) \ ssyrk_UN.$(SUFFIX) ssyrk_UT.$(SUFFIX) ssyrk_LN.$(SUFFIX) ssyrk_LT.$(SUFFIX) \
ssyr2k_UN.$(SUFFIX) ssyr2k_UT.$(SUFFIX) ssyr2k_LN.$(SUFFIX) ssyr2k_LT.$(SUFFIX) \ ssyr2k_UN.$(SUFFIX) ssyr2k_UT.$(SUFFIX) ssyr2k_LN.$(SUFFIX) ssyr2k_LT.$(SUFFIX) \
ssyrk_kernel_U.$(SUFFIX) ssyrk_kernel_L.$(SUFFIX) \ ssyrk_kernel_U.$(SUFFIX) ssyrk_kernel_L.$(SUFFIX) \
ssyr2k_kernel_U.$(SUFFIX) ssyr2k_kernel_L.$(SUFFIX) ssyr2k_kernel_U.$(SUFFIX) ssyr2k_kernel_L.$(SUFFIX) sgemm_batch_thread.$(SUFFIX)
DBLASOBJS += \ DBLASOBJS += \
dgemm_nn.$(SUFFIX) dgemm_nt.$(SUFFIX) dgemm_tn.$(SUFFIX) dgemm_tt.$(SUFFIX) \ dgemm_nn.$(SUFFIX) dgemm_nt.$(SUFFIX) dgemm_tn.$(SUFFIX) dgemm_tt.$(SUFFIX) \
@ -53,7 +53,7 @@ DBLASOBJS += \
dsyrk_UN.$(SUFFIX) dsyrk_UT.$(SUFFIX) dsyrk_LN.$(SUFFIX) dsyrk_LT.$(SUFFIX) \ dsyrk_UN.$(SUFFIX) dsyrk_UT.$(SUFFIX) dsyrk_LN.$(SUFFIX) dsyrk_LT.$(SUFFIX) \
dsyr2k_UN.$(SUFFIX) dsyr2k_UT.$(SUFFIX) dsyr2k_LN.$(SUFFIX) dsyr2k_LT.$(SUFFIX) \ dsyr2k_UN.$(SUFFIX) dsyr2k_UT.$(SUFFIX) dsyr2k_LN.$(SUFFIX) dsyr2k_LT.$(SUFFIX) \
dsyrk_kernel_U.$(SUFFIX) dsyrk_kernel_L.$(SUFFIX) \ dsyrk_kernel_U.$(SUFFIX) dsyrk_kernel_L.$(SUFFIX) \
dsyr2k_kernel_U.$(SUFFIX) dsyr2k_kernel_L.$(SUFFIX) dsyr2k_kernel_U.$(SUFFIX) dsyr2k_kernel_L.$(SUFFIX) dgemm_batch_thread.$(SUFFIX)
QBLASOBJS += \ QBLASOBJS += \
qgemm_nn.$(SUFFIX) qgemm_nt.$(SUFFIX) qgemm_tn.$(SUFFIX) qgemm_tt.$(SUFFIX) \ qgemm_nn.$(SUFFIX) qgemm_nt.$(SUFFIX) qgemm_tn.$(SUFFIX) qgemm_tt.$(SUFFIX) \
@ -103,7 +103,7 @@ CBLASOBJS += \
cherk_kernel_LN.$(SUFFIX) cherk_kernel_LC.$(SUFFIX) \ cherk_kernel_LN.$(SUFFIX) cherk_kernel_LC.$(SUFFIX) \
csyr2k_kernel_U.$(SUFFIX) csyr2k_kernel_L.$(SUFFIX) \ csyr2k_kernel_U.$(SUFFIX) csyr2k_kernel_L.$(SUFFIX) \
cher2k_kernel_UN.$(SUFFIX) cher2k_kernel_UC.$(SUFFIX) \ cher2k_kernel_UN.$(SUFFIX) cher2k_kernel_UC.$(SUFFIX) \
cher2k_kernel_LN.$(SUFFIX) cher2k_kernel_LC.$(SUFFIX) cher2k_kernel_LN.$(SUFFIX) cher2k_kernel_LC.$(SUFFIX) cgemm_batch_thread.$(SUFFIX)
ZBLASOBJS += \ ZBLASOBJS += \
zgemm_nn.$(SUFFIX) zgemm_cn.$(SUFFIX) zgemm_tn.$(SUFFIX) zgemm_nc.$(SUFFIX) \ zgemm_nn.$(SUFFIX) zgemm_cn.$(SUFFIX) zgemm_tn.$(SUFFIX) zgemm_nc.$(SUFFIX) \
@ -137,7 +137,7 @@ ZBLASOBJS += \
zherk_kernel_LN.$(SUFFIX) zherk_kernel_LC.$(SUFFIX) \ zherk_kernel_LN.$(SUFFIX) zherk_kernel_LC.$(SUFFIX) \
zsyr2k_kernel_U.$(SUFFIX) zsyr2k_kernel_L.$(SUFFIX) \ zsyr2k_kernel_U.$(SUFFIX) zsyr2k_kernel_L.$(SUFFIX) \
zher2k_kernel_UN.$(SUFFIX) zher2k_kernel_UC.$(SUFFIX) \ zher2k_kernel_UN.$(SUFFIX) zher2k_kernel_UC.$(SUFFIX) \
zher2k_kernel_LN.$(SUFFIX) zher2k_kernel_LC.$(SUFFIX) zher2k_kernel_LN.$(SUFFIX) zher2k_kernel_LC.$(SUFFIX) zgemm_batch_thread.$(SUFFIX)
XBLASOBJS += \ XBLASOBJS += \
@ -2942,6 +2942,21 @@ gemm_thread_variable.$(PSUFFIX) : gemm_thread_variable.c ../../common.h
beta_thread.$(PSUFFIX) : beta_thread.c ../../common.h beta_thread.$(PSUFFIX) : beta_thread.c ../../common.h
$(CC) -c $(PFLAGS) $< -o $(@F) $(CC) -c $(PFLAGS) $< -o $(@F)
sbgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F)
sgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F)
dgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F)
cgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F)
zgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F)
sbgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h sbgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

View File

@ -0,0 +1,156 @@
/*****************************************************************************
Copyright (c) 2020, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
**********************************************************************************/
#include "common.h"
void openblas_warning(int verbose, const char * msg);
#ifdef SMALL_MATRIX_OPT
static int inner_small_matrix_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){
int routine_mode;
#ifndef COMPLEX
int (*gemm_small_kernel)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG);
int (*gemm_small_kernel_b0)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG);
#else
int (*zgemm_small_kernel)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG);
int (*zgemm_small_kernel_b0)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG);
FLOAT alpha[2], beta[2];
#endif
routine_mode=args->routine_mode;
if((routine_mode & BLAS_SMALL_B0_OPT) == BLAS_SMALL_B0_OPT){
#ifndef COMPLEX
gemm_small_kernel_b0=args->routine;
gemm_small_kernel_b0(args->m, args->n, args->k, args->a, args->lda, *(FLOAT *)(args->alpha), args->b, args->ldb, args->c, args->ldc);
#else
zgemm_small_kernel_b0=args->routine;
alpha[0] = *((FLOAT *)args -> alpha + 0);
alpha[1] = *((FLOAT *)args -> alpha + 1);
zgemm_small_kernel_b0(args->m, args->n, args->k, args->a, args->lda, alpha[0], alpha[1], args->b, args->ldb, args->c, args->ldc);
#endif
return(0);
}else if(routine_mode & BLAS_SMALL_OPT){
#ifndef COMPLEX
gemm_small_kernel=args->routine;
gemm_small_kernel(args->m, args->n, args->k, args->a, args->lda, *(FLOAT *)(args->alpha), args->b, args->ldb, *(FLOAT *)(args->beta), args->c, args->ldc);
#else
zgemm_small_kernel=args->routine;
alpha[0] = *((FLOAT *)args -> alpha + 0);
alpha[1] = *((FLOAT *)args -> alpha + 1);
beta[0] = *((FLOAT *)args -> beta + 0);
beta[1] = *((FLOAT *)args -> beta + 1);
zgemm_small_kernel(args->m, args->n, args->k, args->a, args->lda, alpha[0], alpha[1], args->b, args->ldb, beta[0], beta[1], args->c, args->ldc);
#endif
return(0);
}
return(1);
}
#endif
int CNAME(blas_arg_t * args_array, BLASLONG nums){
XFLOAT *buffer;
XFLOAT *sa, *sb;
int nthreads=1;
int (*routine)(blas_arg_t *, void *, void *, XFLOAT *, XFLOAT *, BLASLONG);
int i=0, /*j,*/ current_nums;
#ifdef SMP
blas_queue_t * queue=NULL;
#endif
if(nums <=0 ) return 0;
buffer = (XFLOAT *)blas_memory_alloc(0);
sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A);
sb = (XFLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
#ifdef SMP
nthreads=num_cpu_avail(3);
if(nthreads==1){
#endif
//single thread
for(i=0; i<nums; i++){
routine=args_array[i].routine;
#ifdef SMALL_MATRIX_OPT
if(args_array[i].routine_mode & BLAS_SMALL_OPT){
inner_small_matrix_thread(&args_array[i], NULL, NULL, NULL, NULL, 0);
}else{
#endif
routine(&args_array[i], NULL, NULL, sa, sb, 0);
#ifdef SMALL_MATRIX_OPT
}
#endif
}
#ifdef SMP
} else {
//multi thread
queue=(blas_queue_t *)malloc((nums+1) * sizeof(blas_queue_t));
if(queue == NULL){
openblas_warning(0, "memory alloc failed!\n");
return(1);
}
for(i=0; i<nums; i++){
queue[i].args=&args_array[i];
queue[i].range_m=NULL;
queue[i].range_n=NULL;
queue[i].sa=NULL;
queue[i].sb=NULL;
queue[i].next=&queue[i+1];
queue[i].mode=args_array[i].routine_mode;
queue[i].routine=args_array[i].routine;
#ifdef SMALL_MATRIX_OPT
if((args_array[i].routine_mode & BLAS_SMALL_B0_OPT) || (args_array[i].routine_mode & BLAS_SMALL_OPT)){
queue[i].routine=inner_small_matrix_thread;
}
#endif
}
for(i=0; i<nums; i+=nthreads){
current_nums=((nums-i)>nthreads)? nthreads: (nums-i);
queue[i].sa=sa;
queue[i].sb=sb;
queue[i+current_nums-1].next=NULL;
exec_blas(current_nums, &queue[i]);
}
free(queue);
}
#endif
blas_memory_free(buffer);
return 0;
}

View File

@ -60,7 +60,7 @@ cblasobjsc="
cblas_ctbsv cblas_ctpmv cblas_ctpsv cblas_ctrmm cblas_ctrmv cblas_ctrsm cblas_ctrsv cblas_ctbsv cblas_ctpmv cblas_ctpsv cblas_ctrmm cblas_ctrmv cblas_ctrsm cblas_ctrsv
cblas_scnrm2 cblas_scasum cblas_cgemmt cblas_scnrm2 cblas_scasum cblas_cgemmt
cblas_icamax cblas_icamin cblas_icmin cblas_icmax cblas_scsum cblas_cimatcopy cblas_comatcopy cblas_icamax cblas_icamin cblas_icmin cblas_icmax cblas_scsum cblas_cimatcopy cblas_comatcopy
cblas_caxpyc cblas_crotg cblas_csrot cblas_scamax cblas_scamin cblas_caxpyc cblas_crotg cblas_csrot cblas_scamax cblas_scamin cblas_cgemm_batch
" "
cblasobjsd=" cblasobjsd="
cblas_dasum cblas_daxpy cblas_dcopy cblas_ddot cblas_dasum cblas_daxpy cblas_dcopy cblas_ddot
@ -70,7 +70,7 @@ cblasobjsd="
cblas_dsyr2k cblas_dsyr cblas_dsyrk cblas_dtbmv cblas_dtbsv cblas_dtpmv cblas_dtpsv cblas_dsyr2k cblas_dsyr cblas_dsyrk cblas_dtbmv cblas_dtbsv cblas_dtpmv cblas_dtpsv
cblas_dtrmm cblas_dtrmv cblas_dtrsm cblas_dtrsv cblas_daxpby cblas_dgeadd cblas_dgemmt cblas_dtrmm cblas_dtrmv cblas_dtrsm cblas_dtrsv cblas_daxpby cblas_dgeadd cblas_dgemmt
cblas_idamax cblas_idamin cblas_idmin cblas_idmax cblas_dsum cblas_dimatcopy cblas_domatcopy cblas_idamax cblas_idamin cblas_idmin cblas_idmax cblas_dsum cblas_dimatcopy cblas_domatcopy
cblas_damax cblas_damin cblas_damax cblas_damin cblas_dgemm_batch
" "
cblasobjss=" cblasobjss="
@ -82,7 +82,7 @@ cblasobjss="
cblas_stbmv cblas_stbsv cblas_stpmv cblas_stpsv cblas_strmm cblas_strmv cblas_strsm cblas_stbmv cblas_stbsv cblas_stpmv cblas_stpsv cblas_strmm cblas_strmv cblas_strsm
cblas_strsv cblas_sgeadd cblas_sgemmt cblas_strsv cblas_sgeadd cblas_sgemmt
cblas_isamax cblas_isamin cblas_ismin cblas_ismax cblas_ssum cblas_simatcopy cblas_somatcopy cblas_isamax cblas_isamin cblas_ismin cblas_ismax cblas_ssum cblas_simatcopy cblas_somatcopy
cblas_samax cblas_samin cblas_samax cblas_samin cblas_sgemm_batch
" "
cblasobjsz=" cblasobjsz="
@ -94,12 +94,12 @@ cblasobjsz="
cblas_ztrsv cblas_cdotc_sub cblas_cdotu_sub cblas_zdotc_sub cblas_zdotu_sub cblas_ztrsv cblas_cdotc_sub cblas_cdotu_sub cblas_zdotc_sub cblas_zdotu_sub
cblas_zaxpby cblas_zgeadd cblas_zgemmt cblas_zaxpby cblas_zgeadd cblas_zgemmt
cblas_izamax cblas_izamin cblas_izmin cblas_izmax cblas_dzsum cblas_zimatcopy cblas_zomatcopy cblas_izamax cblas_izamin cblas_izmin cblas_izmax cblas_dzsum cblas_zimatcopy cblas_zomatcopy
cblas_zaxpyc cblas_zdrot cblas_zrotg cblas_dzamax cblas_dzamin cblas_zaxpyc cblas_zdrot cblas_zrotg cblas_dzamax cblas_dzamin cblas_zgemm_batch
" "
cblasobjs="cblas_xerbla" cblasobjs="cblas_xerbla"
bfcblasobjs="cblas_sbgemm cblas_sbgemv cblas_sbdot cblas_sbstobf16 cblas_sbdtobf16 cblas_sbf16tos cblas_dbf16tod" bfcblasobjs="cblas_sbgemm cblas_sbgemv cblas_sbdot cblas_sbstobf16 cblas_sbdtobf16 cblas_sbf16tos cblas_dbf16tod cblas_sbgemm_batch"
exblasobjs=" exblasobjs="
qamax qamin qasum qaxpy qcabs1 qcopy qdot qgbmv qgemm qamax qamin qasum qaxpy qcabs1 qcopy qdot qgbmv qgemm

View File

@ -97,6 +97,9 @@ foreach (CBLAS_FLAG ${CBLAS_FLAGS})
#sdsdot, dsdot #sdsdot, dsdot
if (BUILD_SINGLE OR BUILD_DOUBLE) if (BUILD_SINGLE OR BUILD_DOUBLE)
GenerateNamedObjects("sdsdot.c" "" "sdsdot" ${CBLAS_FLAG} "" "" true "SINGLE") GenerateNamedObjects("sdsdot.c" "" "sdsdot" ${CBLAS_FLAG} "" "" true "SINGLE")
if(CBLAS_FLAG EQUAL 1)
GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" ${CBLAS_FLAG} "" "" false)
endif ()
endif () endif ()
if (BUILD_DOUBLE) if (BUILD_DOUBLE)
GenerateNamedObjects("dsdot.c" "" "dsdot" ${CBLAS_FLAG} "" "" true "SINGLE") GenerateNamedObjects("dsdot.c" "" "dsdot" ${CBLAS_FLAG} "" "" true "SINGLE")
@ -125,6 +128,9 @@ if (BUILD_BFLOAT16)
GenerateNamedObjects("tobf16.c" "DOUBLE_PREC" "sbdtobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("tobf16.c" "DOUBLE_PREC" "sbdtobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16")
GenerateNamedObjects("bf16to.c" "SINGLE_PREC" "sbf16tos" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("bf16to.c" "SINGLE_PREC" "sbf16tos" ${CBLAS_FLAG} "" "" true "BFLOAT16")
GenerateNamedObjects("bf16to.c" "DOUBLE_PREC" "dbf16tod" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("bf16to.c" "DOUBLE_PREC" "dbf16tod" ${CBLAS_FLAG} "" "" true "BFLOAT16")
if(CBLAS_FLAG EQUAL 1)
GenerateNamedObjects("gemm_batch.c" "" "sbgemm_batch" ${CBLAS_FLAG} "" "" true "BFLOAT16")
endif ()
endif () endif ()
# complex-specific sources # complex-specific sources
@ -154,6 +160,9 @@ foreach (float_type ${FLOAT_TYPES})
GenerateNamedObjects("max.c" "USE_ABS" "scamax" ${CBLAS_FLAG} "" "" true "COMPLEX") GenerateNamedObjects("max.c" "USE_ABS" "scamax" ${CBLAS_FLAG} "" "" true "COMPLEX")
GenerateNamedObjects("asum.c" "" "scasum" ${CBLAS_FLAG} "" "" true "COMPLEX") GenerateNamedObjects("asum.c" "" "scasum" ${CBLAS_FLAG} "" "" true "COMPLEX")
GenerateNamedObjects("sum.c" "" "scsum" ${CBLAS_FLAG} "" "" true "COMPLEX") GenerateNamedObjects("sum.c" "" "scsum" ${CBLAS_FLAG} "" "" true "COMPLEX")
if(CBLAS_FLAG EQUAL 1)
GenerateNamedObjects("gemm_batch.c" "" "cgemm_batch" ${CBLAS_FLAG} "" "" true "COMPLEX")
endif ()
endif () endif ()
if (${float_type} STREQUAL "ZCOMPLEX") if (${float_type} STREQUAL "ZCOMPLEX")
GenerateNamedObjects("zscal.c" "SSCAL" "dscal" ${CBLAS_FLAG} "" "" false "ZCOMPLEX") GenerateNamedObjects("zscal.c" "SSCAL" "dscal" ${CBLAS_FLAG} "" "" false "ZCOMPLEX")
@ -163,6 +172,9 @@ foreach (float_type ${FLOAT_TYPES})
GenerateNamedObjects("max.c" "USE_ABS" "dzamax" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") GenerateNamedObjects("max.c" "USE_ABS" "dzamax" ${CBLAS_FLAG} "" "" true "ZCOMPLEX")
GenerateNamedObjects("asum.c" "" "dzasum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") GenerateNamedObjects("asum.c" "" "dzasum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX")
GenerateNamedObjects("sum.c" "" "dzsum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") GenerateNamedObjects("sum.c" "" "dzsum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX")
if(CBLAS_FLAG EQUAL 1)
GenerateNamedObjects("gemm_batch.c" "" "zgemm_batch" ${CBLAS_FLAG} "" "" true "ZCOMPLEX")
endif ()
endif () endif ()
endforeach () endforeach ()
@ -212,6 +224,7 @@ if ( BUILD_COMPLEX AND NOT BUILD_SINGLE)
GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "SINGLE") GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "SINGLE")
GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "SINGLE") GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "SINGLE")
GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "SINGLE") GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "SINGLE")
GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 1 "" "" false "SINGLE")
GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "SINGLE") GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "SINGLE")
GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "SINGLE") GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "SINGLE")
GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "SINGLE") GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "SINGLE")
@ -225,6 +238,7 @@ if ( BUILD_COMPLEX16 AND NOT BUILD_DOUBLE)
GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "DOUBLE") GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "DOUBLE")
GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "DOUBLE") GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "DOUBLE")
GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "DOUBLE") GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "DOUBLE")
GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 1 "" "" false "DOUBLE")
GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "DOUBLE") GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "DOUBLE")
GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "DOUBLE") GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "DOUBLE")
GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "DOUBLE") GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "DOUBLE")

View File

@ -282,12 +282,12 @@ CSBLAS2OBJS = \
CSBLAS3OBJS = \ CSBLAS3OBJS = \
cblas_sgemm.$(SUFFIX) cblas_ssymm.$(SUFFIX) cblas_strmm.$(SUFFIX) cblas_strsm.$(SUFFIX) \ cblas_sgemm.$(SUFFIX) cblas_ssymm.$(SUFFIX) cblas_strmm.$(SUFFIX) cblas_strsm.$(SUFFIX) \
cblas_ssyrk.$(SUFFIX) cblas_ssyr2k.$(SUFFIX) cblas_somatcopy.$(SUFFIX) cblas_simatcopy.$(SUFFIX)\ cblas_ssyrk.$(SUFFIX) cblas_ssyr2k.$(SUFFIX) cblas_somatcopy.$(SUFFIX) cblas_simatcopy.$(SUFFIX)\
cblas_sgeadd.$(SUFFIX) cblas_sgemmt.$(SUFFIX) cblas_sgeadd.$(SUFFIX) cblas_sgemmt.$(SUFFIX) cblas_sgemm_batch.$(SUFFIX)
ifeq ($(BUILD_BFLOAT16),1) ifeq ($(BUILD_BFLOAT16),1)
CSBBLAS1OBJS = cblas_sbdot.$(SUFFIX) CSBBLAS1OBJS = cblas_sbdot.$(SUFFIX)
CSBBLAS2OBJS = cblas_sbgemv.$(SUFFIX) CSBBLAS2OBJS = cblas_sbgemv.$(SUFFIX)
CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemm_batch.$(SUFFIX)
CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX)
endif endif
@ -308,7 +308,7 @@ CDBLAS2OBJS = \
CDBLAS3OBJS += \ CDBLAS3OBJS += \
cblas_dgemm.$(SUFFIX) cblas_dsymm.$(SUFFIX) cblas_dtrmm.$(SUFFIX) cblas_dtrsm.$(SUFFIX) \ cblas_dgemm.$(SUFFIX) cblas_dsymm.$(SUFFIX) cblas_dtrmm.$(SUFFIX) cblas_dtrsm.$(SUFFIX) \
cblas_dsyrk.$(SUFFIX) cblas_dsyr2k.$(SUFFIX) cblas_domatcopy.$(SUFFIX) cblas_dimatcopy.$(SUFFIX) \ cblas_dsyrk.$(SUFFIX) cblas_dsyr2k.$(SUFFIX) cblas_domatcopy.$(SUFFIX) cblas_dimatcopy.$(SUFFIX) \
cblas_dgeadd.$(SUFFIX) cblas_dgemmt.$(SUFFIX) cblas_dgeadd.$(SUFFIX) cblas_dgemmt.$(SUFFIX) cblas_dgemm_batch.$(SUFFIX)
CCBLAS1OBJS = \ CCBLAS1OBJS = \
cblas_icamax.$(SUFFIX) cblas_icamin.$(SUFFIX) cblas_scasum.$(SUFFIX) cblas_caxpy.$(SUFFIX) \ cblas_icamax.$(SUFFIX) cblas_icamin.$(SUFFIX) cblas_scasum.$(SUFFIX) cblas_caxpy.$(SUFFIX) \
@ -333,7 +333,7 @@ CCBLAS3OBJS = \
cblas_csyrk.$(SUFFIX) cblas_csyr2k.$(SUFFIX) \ cblas_csyrk.$(SUFFIX) cblas_csyr2k.$(SUFFIX) \
cblas_chemm.$(SUFFIX) cblas_cherk.$(SUFFIX) cblas_cher2k.$(SUFFIX) \ cblas_chemm.$(SUFFIX) cblas_cherk.$(SUFFIX) cblas_cher2k.$(SUFFIX) \
cblas_comatcopy.$(SUFFIX) cblas_cimatcopy.$(SUFFIX)\ cblas_comatcopy.$(SUFFIX) cblas_cimatcopy.$(SUFFIX)\
cblas_cgeadd.$(SUFFIX) cblas_cgemmt.$(SUFFIX) cblas_cgeadd.$(SUFFIX) cblas_cgemmt.$(SUFFIX) cblas_cgemm_batch.$(SUFFIX)
CXERBLAOBJ = \ CXERBLAOBJ = \
cblas_xerbla.$(SUFFIX) cblas_xerbla.$(SUFFIX)
@ -364,7 +364,7 @@ CZBLAS3OBJS = \
cblas_zsyrk.$(SUFFIX) cblas_zsyr2k.$(SUFFIX) \ cblas_zsyrk.$(SUFFIX) cblas_zsyr2k.$(SUFFIX) \
cblas_zhemm.$(SUFFIX) cblas_zherk.$(SUFFIX) cblas_zher2k.$(SUFFIX)\ cblas_zhemm.$(SUFFIX) cblas_zherk.$(SUFFIX) cblas_zher2k.$(SUFFIX)\
cblas_zomatcopy.$(SUFFIX) cblas_zimatcopy.$(SUFFIX) \ cblas_zomatcopy.$(SUFFIX) cblas_zimatcopy.$(SUFFIX) \
cblas_zgeadd.$(SUFFIX) cblas_zgemmt.$(SUFFIX) cblas_zgeadd.$(SUFFIX) cblas_zgemmt.$(SUFFIX) cblas_zgemm_batch.$(SUFFIX)
ifeq ($(SUPPORT_GEMM3M), 1) ifeq ($(SUPPORT_GEMM3M), 1)
@ -2419,3 +2419,17 @@ cblas_zgeadd.$(SUFFIX) cblas_zgeadd.$(PSUFFIX) : zgeadd.c
cblas_xerbla.$(SUFFIX) cblas_xerbla.$(PSUFFIX) : xerbla.c cblas_xerbla.$(SUFFIX) cblas_xerbla.$(PSUFFIX) : xerbla.c
$(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F)
cblas_sbgemm_batch.$(SUFFIX) cblas_sbgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
$(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F)
cblas_sgemm_batch.$(SUFFIX) cblas_sgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
$(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F)
cblas_dgemm_batch.$(SUFFIX) cblas_dgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
$(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F)
cblas_cgemm_batch.$(SUFFIX) cblas_cgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
$(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F)
cblas_zgemm_batch.$(SUFFIX) cblas_zgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
$(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F)

372
interface/gemm_batch.c Normal file
View File

@ -0,0 +1,372 @@
/*****************************************************************************
Copyright (c) 2020, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
**********************************************************************************/
#include <stdio.h>
#include <stdlib.h>
#include "common.h"
void openblas_warning(int verbose, const char * msg);
#ifndef COMPLEX
#ifdef XDOUBLE
#define ERROR_NAME "QGEMM_BATCH "
#elif defined(DOUBLE)
#define ERROR_NAME "DGEMM_BATCH "
#define GEMM_BATCH_THREAD dgemm_batch_thread
#else
#define ERROR_NAME "SGEMM_BATCH "
#define GEMM_BATCH_THREAD sgemm_batch_thread
#endif
#else
#ifdef XDOUBLE
#define ERROR_NAME "XGEMM_BATCH "
#elif defined(DOUBLE)
#define ERROR_NAME "ZGEMM_BATCH "
#define GEMM_BATCH_THREAD zgemm_batch_thread
#else
#define ERROR_NAME "CGEMM_BATCH "
#define GEMM_BATCH_THREAD cgemm_batch_thread
#endif
#endif
static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = {
GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR,
GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC,
};
#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE)
#define USE_SMALL_MATRIX_OPT 1
#else
#define USE_SMALL_MATRIX_OPT 0
#endif
#if USE_SMALL_MATRIX_OPT
#ifndef DYNAMIC_ARCH
#define SMALL_KERNEL_ADDR(table, idx) ((void *)(table[idx]))
#else
#define SMALL_KERNEL_ADDR(table, idx) ((void *)(*(uintptr_t *)((char *)gotoblas + (size_t)(table[idx]))))
#endif
#ifndef COMPLEX
static size_t gemm_small_kernel[] = {
GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, 0, 0,
GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, 0, 0,
};
static size_t gemm_small_kernel_b0[] = {
GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, 0, 0,
GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, 0, 0,
};
#define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx))
#define GEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx))
#else
static size_t zgemm_small_kernel[] = {
GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN,
GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT,
GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR,
GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC,
};
static size_t zgemm_small_kernel_b0[] = {
GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN,
GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT,
GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR,
GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC,
};
#define ZGEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel, (idx))
#define ZGEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel_b0, (idx))
#endif
#endif
void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CBLAS_TRANSPOSE * transb_array,
blasint * m_array, blasint * n_array, blasint * k_array,
#ifndef COMPLEX
FLOAT * alpha_array,
FLOAT ** a_array, blasint * lda_array,
FLOAT ** b_array, blasint * ldb_array,
FLOAT * beta_array,
FLOAT ** c_array, blasint * ldc_array, blasint group_count, blasint * group_size) {
#else
void * valpha_array,
void ** va_array, blasint * lda_array,
void ** vb_array, blasint * ldb_array,
void * vbeta_array,
void ** vc_array, blasint * ldc_array, blasint group_count, blasint * group_size) {
FLOAT * alpha_array=(FLOAT *)valpha_array;
FLOAT * beta_array=(FLOAT *)vbeta_array;
FLOAT ** a_array=(FLOAT**)va_array;
FLOAT ** b_array=(FLOAT**)vb_array;
FLOAT ** c_array=(FLOAT**)vc_array;
#endif
blas_arg_t * args_array=NULL;
int mode=0, group_mode=0;
blasint total_num=0;
blasint i=0, j=0, matrix_idx=0, count=0;
int group_transa, group_transb;
BLASLONG group_nrowa, group_nrowb;
blasint info;
void * group_alpha, * group_beta;
BLASLONG group_m, group_n, group_k;
BLASLONG group_lda, group_ldb, group_ldc;
void * group_routine=NULL;
#ifdef SMALL_MATRIX_OPT
void * group_small_matrix_opt_routine=NULL;
#endif
#if defined (SMP) || defined(SMALL_MATRIX_OPT)
double MNK;
#endif
PRINT_DEBUG_CNAME;
for(i=0; i<group_count; i++){
total_num+=group_size[i];
}
args_array=(blas_arg_t *)malloc(total_num * sizeof(blas_arg_t));
if(args_array == NULL){
openblas_warning(0, "memory alloc failed!\n");
return;
}
#ifdef SMP
#ifndef COMPLEX
#ifdef XDOUBLE
mode = BLAS_XDOUBLE | BLAS_REAL;
#elif defined(DOUBLE)
mode = BLAS_DOUBLE | BLAS_REAL;
#else
mode = BLAS_SINGLE | BLAS_REAL;
#endif
#else
#ifdef XDOUBLE
mode = BLAS_XDOUBLE | BLAS_COMPLEX;
#elif defined(DOUBLE)
mode = BLAS_DOUBLE | BLAS_COMPLEX;
#else
mode = BLAS_SINGLE | BLAS_COMPLEX;
#endif
#endif
#endif
for(i=0; i<group_count; matrix_idx+=group_size[i], i++){
group_alpha = (void *)&alpha_array[i * COMPSIZE];
group_beta = (void *)&beta_array[i * COMPSIZE];
group_m = group_n = group_k = 0;
group_lda = group_ldb = group_ldc = 0;
group_transa = -1;
group_transb = -1;
info = 0;
if (order == CblasColMajor) {
group_m = m_array[i];
group_n = n_array[i];
group_k = k_array[i];
group_lda = lda_array[i];
group_ldb = ldb_array[i];
group_ldc = ldc_array[i];
if (transa_array[i] == CblasNoTrans) group_transa = 0;
if (transa_array[i] == CblasTrans) group_transa = 1;
#ifndef COMPLEX
if (transa_array[i] == CblasConjNoTrans) group_transa = 0;
if (transa_array[i] == CblasConjTrans) group_transa = 1;
#else
if (transa_array[i] == CblasConjNoTrans) group_transa = 2;
if (transa_array[i] == CblasConjTrans) group_transa = 3;
#endif
if (transb_array[i] == CblasNoTrans) group_transb = 0;
if (transb_array[i] == CblasTrans) group_transb = 1;
#ifndef COMPLEX
if (transb_array[i] == CblasConjNoTrans) group_transb = 0;
if (transb_array[i] == CblasConjTrans) group_transb = 1;
#else
if (transb_array[i] == CblasConjNoTrans) group_transb = 2;
if (transb_array[i] == CblasConjTrans) group_transb = 3;
#endif
group_nrowa = group_m;
if (group_transa & 1) group_nrowa = group_k;
group_nrowb = group_k;
if (group_transb & 1) group_nrowb = group_n;
info=-1;
if (group_ldc < group_m) info = 13;
if (group_ldb < group_nrowb) info = 10;
if (group_lda < group_nrowa) info = 8;
if (group_k < 0) info = 5;
if (group_n < 0) info = 4;
if (group_m < 0) info = 3;
if (group_transb < 0) info = 2;
if (group_transa < 0) info = 1;
}else if (order == CblasRowMajor) {
group_m = n_array[i];
group_n = m_array[i];
group_k = k_array[i];
group_lda = ldb_array[i];
group_ldb = lda_array[i];
group_ldc = ldc_array[i];
if (transb_array[i] == CblasNoTrans) group_transa = 0;
if (transb_array[i] == CblasTrans) group_transa = 1;
#ifndef COMPLEX
if (transb_array[i] == CblasConjNoTrans) group_transa = 0;
if (transb_array[i] == CblasConjTrans) group_transa = 1;
#else
if (transb_array[i] == CblasConjNoTrans) group_transa = 2;
if (transb_array[i] == CblasConjTrans) group_transa = 3;
#endif
if (transa_array[i] == CblasNoTrans) group_transb = 0;
if (transa_array[i] == CblasTrans) group_transb = 1;
#ifndef COMPLEX
if (transa_array[i] == CblasConjNoTrans) group_transb = 0;
if (transa_array[i] == CblasConjTrans) group_transb = 1;
#else
if (transa_array[i] == CblasConjNoTrans) group_transb = 2;
if (transa_array[i] == CblasConjTrans) group_transb = 3;
#endif
group_nrowa = group_m;
if (group_transa & 1) group_nrowa = group_k;
group_nrowb = group_k;
if (group_transb & 1) group_nrowb = group_n;
info=-1;
if (group_ldc < group_m) info = 13;
if (group_ldb < group_nrowb) info = 10;
if (group_lda < group_nrowa) info = 8;
if (group_k < 0) info = 5;
if (group_n < 0) info = 4;
if (group_m < 0) info = 3;
if (group_transb < 0) info = 2;
if (group_transa < 0) info = 1;
}
if (info >= 0) {
BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
free(args_array);
return;
}
if (group_m == 0 || group_n == 0) continue;
group_mode=mode;
#if defined(SMP) || defined(SMALL_MATRIX_OPT)
MNK = (double) group_m * (double) group_n * (double) group_k;
#endif
#ifdef SMALL_MATRIX_OPT
if (MNK <= 100.0*100.0*100.0){
group_routine=NULL;
#if !defined(COMPLEX)
if(*(FLOAT *)(group_beta) == 0.0){
group_mode=mode | BLAS_SMALL_B0_OPT;
group_small_matrix_opt_routine=(void *)(gemm_small_kernel_b0[(group_transb<<2)|group_transa]);
}else{
group_mode=mode | BLAS_SMALL_OPT;
group_small_matrix_opt_routine=(void *)(gemm_small_kernel[(group_transb<<2)|group_transa]);
}
#else
if(((FLOAT *)(group_beta))[0] == 0.0 && ((FLOAT *)(group_beta))[1] == 0.0){
group_mode=mode | BLAS_SMALL_B0_OPT;
group_small_matrix_opt_routine=(void *)(zgemm_small_kernel_b0[(group_transb<<2)|group_transa]);
}else{
group_mode=mode | BLAS_SMALL_OPT;
group_small_matrix_opt_routine=(void *)(zgemm_small_kernel[(group_transb<<2)|group_transa]);
}
#endif
}else{
#endif
group_routine=(void*)(gemm[(group_transb<<2)|group_transa]);
#ifdef SMALL_MATRIX_OPT
}
#endif
for(j=0; j<group_size[i]; j++){
args_array[count].m=group_m;
args_array[count].n=group_n;
args_array[count].k=group_k;
args_array[count].lda=group_lda;
args_array[count].ldb=group_ldb;
args_array[count].ldc=group_ldc;
args_array[count].alpha=group_alpha;
args_array[count].beta=group_beta;
if (order == CblasColMajor) {
args_array[count].a=(a_array[matrix_idx+j]);
args_array[count].b=(b_array[matrix_idx+j]);
}else if(order == CblasRowMajor){
args_array[count].a=(b_array[matrix_idx+j]);
args_array[count].b=(a_array[matrix_idx+j]);
}
args_array[count].c=(c_array[matrix_idx+j]);
args_array[count].routine_mode=group_mode;
args_array[count].routine=group_routine;
#ifdef SMALL_MATRIX_OPT
if (!group_routine)
args_array[count].routine=group_small_matrix_opt_routine;
#endif
count++;
}
}
if(count>0){
GEMM_BATCH_THREAD(args_array,count);
}
free(args_array);
}