diff --git a/interface/gemmt.c b/interface/gemmt.c index 3eed1dfe4..d35406411 100644 --- a/interface/gemmt.c +++ b/interface/gemmt.c @@ -35,29 +35,26 @@ #include #include #include "common.h" -#ifdef FUNCTION_PROFILE -#include "functable.h" -#endif #ifndef COMPLEX #define SMP_THRESHOLD_MIN 65536.0 #ifdef XDOUBLE -#define ERROR_NAME "QGEMT " +#define ERROR_NAME "QGEMMT " #elif defined(DOUBLE) -#define ERROR_NAME "DGEMT " +#define ERROR_NAME "DGEMMT " #elif defined(BFLOAT16) -#define ERROR_NAME "SBGEMT " +#define ERROR_NAME "SBGEMMT " #else -#define ERROR_NAME "SGEMT " +#define ERROR_NAME "SGEMMT " #endif #else #define SMP_THRESHOLD_MIN 8192.0 #ifdef XDOUBLE -#define ERROR_NAME "XGEMT " +#define ERROR_NAME "XGEMMT " #elif defined(DOUBLE) -#define ERROR_NAME "ZGEMT " +#define ERROR_NAME "ZGEMMT " #else -#define ERROR_NAME "CGEMT " +#define ERROR_NAME "CGEMMT " #endif #endif @@ -68,13 +65,13 @@ #ifndef CBLAS void NAME(char *UPLO, char *TRANSA, char *TRANSB, - blasint * M, blasint * N, blasint * K, + blasint * M, blasint * K, FLOAT * Alpha, IFLOAT * a, blasint * ldA, IFLOAT * b, blasint * ldB, FLOAT * Beta, FLOAT * c, blasint * ldC) { - blasint m, n, k; + blasint m, k; blasint lda, ldb, ldc; int transa, transb, uplo; blasint info; @@ -92,7 +89,6 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB, PRINT_DEBUG_NAME; m = *M; - n = *N; k = *K; #if defined(COMPLEX) @@ -167,8 +163,6 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB, info = 13; if (k < 0) info = 5; - if (n < 0) - info = 4; if (m < 0) info = 3; if (transb < 0) @@ -184,7 +178,7 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB, void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint M, - blasint N, blasint k, + blasint k, #ifndef COMPLEX FLOAT alpha, IFLOAT * A, blasint LDA, @@ -205,7 +199,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, int transa, transb, uplo; blasint info; - blasint m, n, lda, ldb; + blasint m, lda, ldb; FLOAT *a, *b; XFLOAT *buffer; @@ -248,9 +242,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, transb = 3; #endif - m = M; - n = N; - a = (void *)A; b = (void *)B; lda = LDA; @@ -262,8 +253,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, info = 13; if (k < 0) info = 5; - if (n < 0) - info = 4; if (m < 0) info = 3; if (transb < 0) @@ -273,8 +262,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, } if (order == CblasRowMajor) { - m = N; - n = M; a = (void *)B; b = (void *)A; @@ -319,8 +306,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, info = 13; if (k < 0) info = 5; - if (n < 0) - info = 4; if (m < 0) info = 3; if (transb < 0) @@ -407,37 +392,35 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, #endif - if ((m == 0) || (n == 0)) + if ((m == 0) ) return; IDEBUG_START; - FUNCTION_PROFILE_START(); - const blasint incb = (transb == 0) ? 1 : ldb; if (uplo == 1) { - for (i = 0; i < n; i++) { - j = n - i; + for (i = 0; i < m; i++) { + j = m - i; l = j; #if defined(COMPLEX) aa = a + i * 2; bb = b + i * ldb * 2; if (transa) { - l = k; aa = a + lda * i * 2; - bb = b + i * 2; } + if (transb) + bb = b + i * 2; cc = c + i * 2 * ldc + i * 2; #else aa = a + i; bb = b + i * ldb; if (transa) { - l = k; aa = a + lda * i; - bb = b + i; } + if (transb) + bb = b + i; cc = c + i * ldc + i; #endif @@ -458,8 +441,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, IDEBUG_START; - FUNCTION_PROFILE_START(); - buffer_size = j + k + 128 / sizeof(FLOAT); #ifdef WINDOWS_ABI buffer_size += 160 / sizeof(FLOAT); @@ -479,20 +460,34 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, #endif #if defined(COMPLEX) + if (!transa) (gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i, aa, lda, bb, incb, cc, 1, buffer); + else + (gemv[(int)transa]) (k, j, 0, alpha_r, alpha_i, + aa, lda, bb, incb, cc, 1, + buffer); #else + if (!transa) (gemv[(int)transa]) (j, k, 0, alpha, aa, lda, bb, incb, cc, 1, buffer); + else + (gemv[(int)transa]) (k, j, 0, alpha, aa, lda, + bb, incb, cc, 1, buffer); #endif #ifdef SMP } else { - + if (!transa) (gemv_thread[(int)transa]) (j, k, alpha, aa, lda, bb, incb, cc, 1, buffer, nthreads); + else + (gemv_thread[(int)transa]) (k, j, alpha, aa, + lda, bb, incb, cc, + 1, buffer, + nthreads); } #endif @@ -501,21 +496,19 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, } } else { - for (i = 0; i < n; i++) { + for (i = 0; i < m; i++) { j = i + 1; l = j; #if defined COMPLEX bb = b + i * ldb * 2; - if (transa) { - l = k; + if (transb) { bb = b + i * 2; } cc = c + i * 2 * ldc; #else bb = b + i * ldb; - if (transa) { - l = k; + if (transb) { bb = b + i; } cc = c + i * ldc; @@ -537,8 +530,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, #endif IDEBUG_START; - FUNCTION_PROFILE_START(); - buffer_size = j + k + 128 / sizeof(FLOAT); #ifdef WINDOWS_ABI buffer_size += 160 / sizeof(FLOAT); @@ -558,30 +549,39 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, #endif #if defined(COMPLEX) + if (!transa) (gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i, a, lda, bb, incb, cc, 1, buffer); + else + (gemv[(int)transa]) (k, j, 0, alpha_r, alpha_i, + a, lda, bb, incb, cc, 1, + buffer); #else + if (!transa) (gemv[(int)transa]) (j, k, 0, alpha, a, lda, bb, incb, cc, 1, buffer); + else + (gemv[(int)transa]) (k, j, 0, alpha, a, lda, bb, + incb, cc, 1, buffer); #endif #ifdef SMP } else { - + if (!transa) (gemv_thread[(int)transa]) (j, k, alpha, a, lda, bb, incb, cc, 1, buffer, nthreads); - + else + (gemv_thread[(int)transa]) (k, j, alpha, a, lda, + bb, incb, cc, 1, + buffer, nthreads); } #endif STACK_FREE(buffer); } } - FUNCTION_PROFILE_END(COMPSIZE * COMPSIZE, - args.m * args.k + args.k * args.n + - args.m * args.n, 2 * args.m * args.n * args.k); IDEBUG_END;