From 28b5334f22dc972613493bb8cd207657c002e240 Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Tue, 23 Jul 2024 20:42:39 +0000 Subject: [PATCH] Complete implementation of GEMV forwarding --- interface/gemm.c | 89 +++++++++++++++++++++++++----------------------- 1 file changed, 47 insertions(+), 42 deletions(-) diff --git a/interface/gemm.c b/interface/gemm.c index c3118672f..c8a20d119 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -1,4 +1,5 @@ /*********************************************************************/ +/* Copyright 2024 The OpenBLAS Project */ /* Copyright 2009, 2010 The University of Texas at Austin. */ /* All rights reserved. */ /* */ @@ -63,13 +64,10 @@ #ifndef GEMM3M #ifdef XDOUBLE #define ERROR_NAME "XGEMM " -#define GEMV BLASFUNC(xgemv) #elif defined(DOUBLE) #define ERROR_NAME "ZGEMM " -#define GEMV BLASFUNC(zgemv) #else #define ERROR_NAME "CGEMM " -#define GEMV BLASFUNC(cgemv) #endif #else #ifdef XDOUBLE @@ -492,42 +490,54 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS } #endif #endif // defined(__linux__) && defined(__x86_64__) && defined(BFLOAT16) - // fprintf(stderr,"G E M M interface m n k %d %d %d\n",args.m,args.n,args.k); if ((args.m == 0) || (args.n == 0)) return; -#if 1 -#ifndef GEMM3M - if (args.m == 1) { - char *NT=(char*)malloc(2*sizeof(char)); - if (transb&1)strcpy(NT,"T"); - else NT="N"; -// fprintf(stderr,"G E M V\n"); - GEMV(NT, &args.n ,&args.k, args.alpha, args.b, &args.ldb, args.a, &args.m, args.beta, args.c, &args.m); -//SUBROUTINE SGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY) -return; - } else { - if (args.n == 1) { -#ifndef CBLAS - char *NT=(char*)malloc(2*sizeof(char)); - strcpy(NT,"N"); -#else - char *NT=(char*)malloc(2*sizeof(char)); - if (transb&1)strcpy(NT,"T"); - else strcpy(NT,"N"); -#endif -// fprintf(stderr,"G E M V ! ! ! lda=%d ldb=%d ldc=%d\n",args.lda,args.ldb,args.ldc); - GEMV(NT, &args.m ,&args.k, args.alpha, args.a, &args.lda, args.b, &args.n, args.beta, args.c, &args.n); -//SUBROUTINE SGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY) - return; +#if !defined(GEMM3M) && !defined(COMPLEX) + // Check if we can convert GEMM -> GEMV + if (args.k != 0) { + if (args.n == 1) { + blasint inc_x = 1; + blasint inc_y = 1; + // These were passed in as blasint, but the struct translates them to blaslong + blasint m = args.m; + blasint n = args.k; + blasint lda = args.lda; + // Create new transpose parameters + char NT = 'N'; + if (transa & 1) { + NT = 'T'; + m = args.k; + n = args.m; + } + if (transb & 1) { + inc_x = args.ldb; + } + GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y); + return; + } + if (args.m == 1) { + blasint inc_x = args.lda; + blasint inc_y = args.ldc; + // These were passed in as blasint, but the struct translates them to blaslong + blasint m = args.k; + blasint n = args.n; + blasint ldb = args.ldb; + // Create new transpose parameters + char NT = 'T'; + if (transa & 1) { + inc_x = 1; + } + if (transb & 1) { + NT = 'N'; + m = args.n; + n = args.k; + } + GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y); + return; } } #endif -#endif -#if 0 - fprintf(stderr, "m = %4d n = %d k = %d lda = %4d ldb = %4d ldc = %4d\n", - args.m, args.n, args.k, args.lda, args.ldb, args.ldc); -#endif IDEBUG_START; @@ -557,15 +567,10 @@ return; buffer = (XFLOAT *)blas_memory_alloc(0); -//For Loongson servers, like the 3C5000 (featuring 16 cores), applying an -//offset to the buffer is essential for minimizing cache conflicts and optimizing performance. -#if defined(LOONGSON3R5) && !defined(NO_AFFINITY) - char model_name[128]; - get_cpu_model(model_name); - if ((strstr(model_name, "3C5000") != NULL) || (strstr(model_name, "3D5000") != NULL)) - sa = (XFLOAT *)((BLASLONG)buffer + (WhereAmI() & 0xf) * GEMM_OFFSET_A); - else - sa = (XFLOAT *)((BLASLONG)buffer + GEMM_OFFSET_A); +//For target LOONGSON3R5, applying an offset to the buffer is essential +//for minimizing cache conflicts and optimizing performance. +#if defined(ARCH_LOONGARCH64) && !defined(NO_AFFINITY) + sa = (XFLOAT *)((BLASLONG)buffer + (WhereAmI() & 0xf) * GEMM_OFFSET_A); #else sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A); #endif