Improve input argument checks of gemmt

* Fix return value for invalid info
* Add missing checks for ldA, ldB
* Use reference-LAPACK like checks (ie ld=0,nrows=0 is invalid)
This commit is contained in:
Angelika Schwarz 2023-04-27 17:59:15 +02:00
parent df88536d1c
commit 899c3a6f6a
1 changed files with 52 additions and 21 deletions

View File

@ -77,6 +77,7 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
blasint info; blasint info;
char transA, transB, Uplo; char transA, transB, Uplo;
blasint nrowa, nrowb;
IFLOAT *buffer; IFLOAT *buffer;
IFLOAT *aa, *bb; IFLOAT *aa, *bb;
FLOAT *cc; FLOAT *cc;
@ -155,22 +156,31 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
if (Uplo == 'L') if (Uplo == 'L')
uplo = 1; uplo = 1;
nrowa = m;
if (transa) nrowa = k;
nrowb = k;
if (transb) nrowb = m;
info = 0; info = 0;
if (uplo < 0) if (ldc < MAX(1, m))
info = 14;
if (ldc < m)
info = 13; info = 13;
if (ldb < MAX(1, nrowa))
info = 10;
if (lda < MAX(1, nrowb))
info = 8;
if (k < 0) if (k < 0)
info = 5; info = 5;
if (m < 0) if (m < 0)
info = 3; info = 4;
if (transb < 0) if (transb < 0)
info = 2; info = 3;
if (transa < 0) if (transa < 0)
info = 2;
if (uplo < 0)
info = 1; info = 1;
if (info) { if (info != 0) {
BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME)); BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME));
return; return;
} }
@ -205,11 +215,14 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
PRINT_DEBUG_CNAME; PRINT_DEBUG_CNAME;
uplo = -1;
transa = -1; transa = -1;
transb = -1; transb = -1;
info = 0; info = 0;
if (order == CblasColMajor) { if (order == CblasColMajor) {
if (Uplo == CblasUpper) uplo = 0;
if (Uplo == CblasLower) uplo = 1;
if (TransA == CblasNoTrans) if (TransA == CblasNoTrans)
transa = 0; transa = 0;
@ -249,15 +262,27 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
info = -1; info = -1;
if (ldc < m) blasint nrowa, nrowb;
nrowa = m;
if (transa) nrowa = k;
nrowb = k;
if (transb) nrowb = m;
if (ldc < MAX(1, m))
info = 13; info = 13;
if (ldb < MAX(1, nrowb))
info = 10;
if (lda < MAX(1, nrowa))
info = 8;
if (k < 0) if (k < 0)
info = 5; info = 5;
if (m < 0) if (m < 0)
info = 3; info = 4;
if (transb < 0) if (transb < 0)
info = 2; info = 3;
if (transa < 0) if (transa < 0)
info = 2;
if (uplo < 0)
info = 1; info = 1;
} }
@ -269,6 +294,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
lda = LDB; lda = LDB;
ldb = LDA; ldb = LDA;
if (Uplo == CblasUpper) uplo = 0;
if (Uplo == CblasLower) uplo = 1;
if (TransB == CblasNoTrans) if (TransB == CblasNoTrans)
transa = 0; transa = 0;
if (TransB == CblasTrans) if (TransB == CblasTrans)
@ -302,27 +330,30 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
info = -1; info = -1;
if (ldc < m) blasint ncola, ncolb;
ncola = k;
if (transa) ncola = m;
ncolb = m;
if (transb) ncolb = k;
if (ldc < MAX(1,m))
info = 13; info = 13;
if (ldb < MAX(1, ncolb))
info = 10;
if (lda < MAX(1, ncola))
info = 8;
if (k < 0) if (k < 0)
info = 5; info = 5;
if (m < 0) if (m < 0)
info = 3; info = 4;
if (transb < 0) if (transb < 0)
info = 2; info = 3;
if (transa < 0) if (transa < 0)
info = 2;
if (uplo < 0)
info = 1; info = 1;
} }
uplo = -1;
if (Uplo == CblasUpper)
uplo = 0;
if (Uplo == CblasLower)
uplo = 1;
if (uplo < 0)
info = 14;
if (info >= 0) { if (info >= 0) {
BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME)); BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME));
return; return;