Complete implementation of GEMV forwarding

This commit is contained in:
Chris Sidebottom 2024-07-23 20:42:39 +00:00
parent 3db5dbc88e
commit 28b5334f22
1 changed file with 47 additions and 42 deletions

View File

@ -1,4 +1,5 @@
/*********************************************************************/ /*********************************************************************/
/* Copyright 2024 The OpenBLAS Project */
/* Copyright 2009, 2010 The University of Texas at Austin. */ /* Copyright 2009, 2010 The University of Texas at Austin. */
/* All rights reserved. */ /* All rights reserved. */
/* */ /* */
@ -63,13 +64,10 @@
#ifndef GEMM3M #ifndef GEMM3M
#ifdef XDOUBLE #ifdef XDOUBLE
#define ERROR_NAME "XGEMM " #define ERROR_NAME "XGEMM "
#define GEMV BLASFUNC(xgemv)
#elif defined(DOUBLE) #elif defined(DOUBLE)
#define ERROR_NAME "ZGEMM " #define ERROR_NAME "ZGEMM "
#define GEMV BLASFUNC(zgemv)
#else #else
#define ERROR_NAME "CGEMM " #define ERROR_NAME "CGEMM "
#define GEMV BLASFUNC(cgemv)
#endif #endif
#else #else
#ifdef XDOUBLE #ifdef XDOUBLE
@ -492,42 +490,54 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
} }
#endif #endif
#endif // defined(__linux__) && defined(__x86_64__) && defined(BFLOAT16) #endif // defined(__linux__) && defined(__x86_64__) && defined(BFLOAT16)
// fprintf(stderr,"G E M M interface m n k %d %d %d\n",args.m,args.n,args.k);
if ((args.m == 0) || (args.n == 0)) return; if ((args.m == 0) || (args.n == 0)) return;
#if 1 #if !defined(GEMM3M) && !defined(COMPLEX)
#ifndef GEMM3M // Check if we can convert GEMM -> GEMV
if (args.m == 1) { if (args.k != 0) {
char *NT=(char*)malloc(2*sizeof(char));
if (transb&1)strcpy(NT,"T");
else NT="N";
// fprintf(stderr,"G E M V\n");
GEMV(NT, &args.n ,&args.k, args.alpha, args.b, &args.ldb, args.a, &args.m, args.beta, args.c, &args.m);
//SUBROUTINE SGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
return;
} else {
if (args.n == 1) { if (args.n == 1) {
#ifndef CBLAS blasint inc_x = 1;
char *NT=(char*)malloc(2*sizeof(char)); blasint inc_y = 1;
strcpy(NT,"N"); // These were passed in as blasint, but the struct translates them to blaslong
#else blasint m = args.m;
char *NT=(char*)malloc(2*sizeof(char)); blasint n = args.k;
if (transb&1)strcpy(NT,"T"); blasint lda = args.lda;
else strcpy(NT,"N"); // Create new transpose parameters
#endif char NT = 'N';
// fprintf(stderr,"G E M V ! ! ! lda=%d ldb=%d ldc=%d\n",args.lda,args.ldb,args.ldc); if (transa & 1) {
GEMV(NT, &args.m ,&args.k, args.alpha, args.a, &args.lda, args.b, &args.n, args.beta, args.c, &args.n); NT = 'T';
//SUBROUTINE SGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY) m = args.k;
n = args.m;
}
if (transb & 1) {
inc_x = args.ldb;
}
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
return;
}
if (args.m == 1) {
blasint inc_x = args.lda;
blasint inc_y = args.ldc;
// These were passed in as blasint, but the struct translates them to blaslong
blasint m = args.k;
blasint n = args.n;
blasint ldb = args.ldb;
// Create new transpose parameters
char NT = 'T';
if (transa & 1) {
inc_x = 1;
}
if (transb & 1) {
NT = 'N';
m = args.n;
n = args.k;
}
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
return; return;
} }
} }
#endif #endif
#endif
#if 0
fprintf(stderr, "m = %4d n = %d k = %d lda = %4d ldb = %4d ldc = %4d\n",
args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
#endif
IDEBUG_START; IDEBUG_START;
@ -557,15 +567,10 @@ return;
buffer = (XFLOAT *)blas_memory_alloc(0); buffer = (XFLOAT *)blas_memory_alloc(0);
//For Loongson servers, like the 3C5000 (featuring 16 cores), applying an //For target LOONGSON3R5, applying an offset to the buffer is essential
//offset to the buffer is essential for minimizing cache conflicts and optimizing performance. //for minimizing cache conflicts and optimizing performance.
#if defined(LOONGSON3R5) && !defined(NO_AFFINITY) #if defined(ARCH_LOONGARCH64) && !defined(NO_AFFINITY)
char model_name[128];
get_cpu_model(model_name);
if ((strstr(model_name, "3C5000") != NULL) || (strstr(model_name, "3D5000") != NULL))
sa = (XFLOAT *)((BLASLONG)buffer + (WhereAmI() & 0xf) * GEMM_OFFSET_A); sa = (XFLOAT *)((BLASLONG)buffer + (WhereAmI() & 0xf) * GEMM_OFFSET_A);
else
sa = (XFLOAT *)((BLASLONG)buffer + GEMM_OFFSET_A);
#else #else
sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A); sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A);
#endif #endif