From 899c3a6f6a801fef080962221755eac8d543d5df Mon Sep 17 00:00:00 2001 From: Angelika Schwarz <17718454+angsch@users.noreply.github.com> Date: Thu, 27 Apr 2023 17:59:15 +0200 Subject: [PATCH] Improve input argument checks of gemmt * Fix return value for invalid info * Add missing checks for ldA, ldB * Use reference-LAPACK like checks (ie ld=0,nrows=0 is invalid) --- interface/gemmt.c | 73 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 52 insertions(+), 21 deletions(-) diff --git a/interface/gemmt.c b/interface/gemmt.c index cebc7918d..046432670 100644 --- a/interface/gemmt.c +++ b/interface/gemmt.c @@ -77,6 +77,7 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB, blasint info; char transA, transB, Uplo; + blasint nrowa, nrowb; IFLOAT *buffer; IFLOAT *aa, *bb; FLOAT *cc; @@ -155,22 +156,31 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB, if (Uplo == 'L') uplo = 1; + nrowa = m; + if (transa) nrowa = k; + nrowb = k; + if (transb) nrowb = m; + info = 0; - if (uplo < 0) - info = 14; - if (ldc < m) + if (ldc < MAX(1, m)) info = 13; + if (ldb < MAX(1, nrowa)) + info = 10; + if (lda < MAX(1, nrowb)) + info = 8; if (k < 0) info = 5; if (m < 0) - info = 3; + info = 4; if (transb < 0) - info = 2; + info = 3; if (transa < 0) + info = 2; + if (uplo < 0) info = 1; - if (info) { + if (info != 0) { BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME)); return; } @@ -205,11 +215,14 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, PRINT_DEBUG_CNAME; + uplo = -1; transa = -1; transb = -1; info = 0; if (order == CblasColMajor) { + if (Uplo == CblasUpper) uplo = 0; + if (Uplo == CblasLower) uplo = 1; if (TransA == CblasNoTrans) transa = 0; @@ -249,15 +262,27 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, info = -1; - if (ldc < m) + blasint nrowa, nrowb; + nrowa = m; + if (transa) nrowa = k; + nrowb = k; + if (transb) nrowb = m; + + if (ldc < MAX(1, m)) info = 13; + if (ldb < MAX(1, nrowb)) + info = 10; + if (lda < MAX(1, nrowa)) + info = 8; if (k < 0) info = 5; if (m < 0) - info = 3; + info = 4; if (transb < 0) - info = 2; + info = 3; if (transa < 0) + info = 2; + if (uplo < 0) info = 1; } @@ -269,6 +294,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, lda = LDB; ldb = LDA; + if (Uplo == CblasUpper) uplo = 0; + if (Uplo == CblasLower) uplo = 1; + if (TransB == CblasNoTrans) transa = 0; if (TransB == CblasTrans) @@ -302,27 +330,30 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, info = -1; - if (ldc < m) + blasint ncola, ncolb; + ncola = k; + if (transa) ncola = m; + ncolb = m; + if (transb) ncolb = k; + + if (ldc < MAX(1,m)) info = 13; + if (ldb < MAX(1, ncolb)) + info = 10; + if (lda < MAX(1, ncola)) + info = 8; if (k < 0) info = 5; if (m < 0) - info = 3; + info = 4; if (transb < 0) - info = 2; + info = 3; if (transa < 0) + info = 2; + if (uplo < 0) info = 1; - } - uplo = -1; - if (Uplo == CblasUpper) - uplo = 0; - if (Uplo == CblasLower) - uplo = 1; - if (uplo < 0) - info = 14; - if (info >= 0) { BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME)); return;