From 103637887e35062aae8e795805c13ac9dfb6ca31 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Wed, 29 May 2024 15:46:10 +0200 Subject: [PATCH] add cblas_?gemm_batch --- cblas.h | 15 +++++++++++++++ common_level3.h | 7 ++++++- common_macro.h | 11 +++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/cblas.h b/cblas.h index ab8b7dcde..097b4303d 100644 --- a/cblas.h +++ b/cblas.h @@ -416,6 +416,18 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta, double *c, OPENBLAS_CONST blasint cldc); +void cblas_sgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST float ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST float ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); + +void cblas_dgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST double * alpha_array, OPENBLAS_CONST double ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST double ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST double * beta_array, double ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); + +void cblas_cgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); + +void cblas_zgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST void * alpha_array, OPENBLAS_CONST void ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST void ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST void * beta_array, void ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); + /*** BFLOAT16 and INT8 extensions ***/ /* convert float array to BFLOAT16 array by rounding */ void cblas_sbstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout); @@ -431,6 +443,9 @@ void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc); +void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, + OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); + #ifdef __cplusplus } #endif /* __cplusplus */ diff --git a/common_level3.h b/common_level3.h index 5080ada10..d370c1f96 100644 --- a/common_level3.h +++ b/common_level3.h @@ -1937,8 +1937,13 @@ int zimatcopy_k_rtc(BLASLONG, BLASLONG, double, double, double *, BLASLONG); int sgeadd_k(BLASLONG, BLASLONG, float, float*, BLASLONG, float, float *, BLASLONG); int dgeadd_k(BLASLONG, BLASLONG, double, double*, BLASLONG, double, double *, BLASLONG); int cgeadd_k(BLASLONG, BLASLONG, float, float, float*, BLASLONG, float, float, float *, BLASLONG); -int zgeadd_k(BLASLONG, BLASLONG, double,double, double*, BLASLONG, double, double, double *, BLASLONG); +int zgeadd_k(BLASLONG, BLASLONG, double,double, double*, BLASLONG, double, double, double *, BLASLONG); +int sgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); +int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); +int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); +int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); +int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); #ifdef __CUDACC__ } diff --git a/common_macro.h b/common_macro.h index 3226d0f11..a924651de 100644 --- a/common_macro.h +++ b/common_macro.h @@ -2655,9 +2655,20 @@ typedef struct { BLASLONG prea, preb, prec, pred; #endif + + //for gemm_batch + void * routine; + int routine_mode; + } blas_arg_t; #endif +#ifdef SMALL_MATRIX_OPT +#define BLAS_SMALL_OPT 0x10000U +#define BLAS_SMALL_B0_OPT 0x30000U +#endif + + #ifdef XDOUBLE #define TRSV_NUU qtrsv_NUU