Fixed handling of complex conjugate matrices and error codes for complex cases

This commit is contained in:
Martin Kroeker 2024-01-26 14:25:38 +01:00 committed by GitHub
parent d6a5174e9c
commit 41515e6e7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 82 additions and 32 deletions

View File

@ -78,6 +78,9 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
char transA, transB, Uplo; char transA, transB, Uplo;
blasint nrowa, nrowb; blasint nrowa, nrowb;
#if defined(COMPLEX)
blasint ncolb;
#endif
IFLOAT *buffer; IFLOAT *buffer;
IFLOAT *aa, *bb; IFLOAT *aa, *bb;
FLOAT *cc; FLOAT *cc;
@ -157,17 +160,25 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
uplo = 1; uplo = 1;
nrowa = m; nrowa = m;
if (transa) nrowa = k; if (transa & 1) nrowa = k;
nrowb = k; nrowb = k;
if (transb) nrowb = m; #if defined(COMPLEX)
ncolb = m;
#endif
if (transb & 1) {
nrowb = m;
#if defined(COMPLEX)
ncolb = k;
#endif
}
info = 0; info = 0;
if (ldc < MAX(1, m)) if (ldc < MAX(1, m))
info = 13; info = 13;
if (ldb < MAX(1, nrowa)) if (ldb < MAX(1, nrowb))
info = 10; info = 10;
if (lda < MAX(1, nrowb)) if (lda < MAX(1, nrowa))
info = 8; info = 8;
if (k < 0) if (k < 0)
info = 5; info = 5;
@ -211,6 +222,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
blasint info; blasint info;
blasint lda, ldb; blasint lda, ldb;
FLOAT *a, *b; FLOAT *a, *b;
#if defined(COMPLEX)
blasint nrowb, ncolb;
#endif
XFLOAT *buffer; XFLOAT *buffer;
PRINT_DEBUG_CNAME; PRINT_DEBUG_CNAME;
@ -262,11 +276,22 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
info = -1; info = -1;
blasint nrowa, nrowb; blasint nrowa;
#if !defined(COMPLEX)
blasint nrowb;
#endif
nrowa = m; nrowa = m;
if (transa) nrowa = k; if (transa & 1) nrowa = k;
nrowb = k; nrowb = k;
if (transb) nrowb = m; #if defined(COMPLEX)
ncolb = m;
#endif
if (transb & 1) {
nrowb = m;
#if defined(COMPLEX)
ncolb = k;
#endif
}
if (ldc < MAX(1, m)) if (ldc < MAX(1, m))
info = 13; info = 13;
@ -330,26 +355,38 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
info = -1; info = -1;
blasint ncola, ncolb; blasint ncola;
ncola = k; #if !defined(COMPLEX)
if (transa) ncola = m; blasint ncolb;
#endif
ncola = m;
if (transa & 1) ncola = k;
ncolb = k;
#if defined(COMPLEX)
nrowb = m;
#endif
if (transb & 1) {
#if defined(COMPLEX)
nrowb = k;
#endif
ncolb = m; ncolb = m;
if (transb) ncolb = k; }
if (ldc < MAX(1,m)) if (ldc < MAX(1,m))
info = 13; info = 13;
if (ldb < MAX(1, ncolb)) if (ldb < MAX(1, ncolb))
info = 10;
if (lda < MAX(1, ncola))
info = 8; info = 8;
if (lda < MAX(1, ncola))
info = 10;
if (k < 0) if (k < 0)
info = 5; info = 5;
if (m < 0) if (m < 0)
info = 4; info = 4;
if (transb < 0) if (transb < 0)
info = 3;
if (transa < 0)
info = 2; info = 2;
if (transa < 0)
info = 3;
if (uplo < 0) if (uplo < 0)
info = 1; info = 1;
} }
@ -428,7 +465,20 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
IDEBUG_START; IDEBUG_START;
const blasint incb = (transb == 0) ? 1 : ldb; #if defined(COMPLEX)
if (transb > 1){
#ifndef CBLAS
IMATCOPY_K_CNC(nrowb, ncolb, (FLOAT)(1.0), (FLOAT)(0.0), b, ldb);
#else
if (order == CblasColMajor)
IMATCOPY_K_CNC(nrowb, ncolb, (FLOAT)(1.0), (FLOAT)(0.0), b, ldb);
if (order == CblasRowMajor)
IMATCOPY_K_RNC(nrowb, ncolb, (FLOAT)(1.0), (FLOAT)(0.0), b, ldb);
#endif
}
#endif
const blasint incb = ((transb & 1) == 0) ? 1 : ldb;
if (uplo == 1) { if (uplo == 1) {
for (i = 0; i < m; i++) { for (i = 0; i < m; i++) {
@ -438,19 +488,19 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#if defined(COMPLEX) #if defined(COMPLEX)
aa = a + i * 2; aa = a + i * 2;
bb = b + i * ldb * 2; bb = b + i * ldb * 2;
if (transa) { if (transa & 1) {
aa = a + lda * i * 2; aa = a + lda * i * 2;
} }
if (transb) if (transb & 1)
bb = b + i * 2; bb = b + i * 2;
cc = c + i * 2 * ldc + i * 2; cc = c + i * 2 * ldc + i * 2;
#else #else
aa = a + i; aa = a + i;
bb = b + i * ldb; bb = b + i * ldb;
if (transa) { if (transa & 1) {
aa = a + lda * i; aa = a + lda * i;
} }
if (transb) if (transb & 1)
bb = b + i; bb = b + i;
cc = c + i * ldc + i; cc = c + i * ldc + i;
#endif #endif
@ -461,7 +511,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
NULL, 0); NULL, 0);
if (alpha_r == ZERO && alpha_i == ZERO) if (alpha_r == ZERO && alpha_i == ZERO)
return; continue;
#else #else
if (beta != ONE) if (beta != ONE)
SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0); SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0);
@ -491,7 +541,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#endif #endif
#if defined(COMPLEX) #if defined(COMPLEX)
if (!transa) if (!(transa & 1))
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i, (gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
aa, lda, bb, incb, cc, 1, aa, lda, bb, incb, cc, 1,
buffer); buffer);
@ -500,7 +550,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
aa, lda, bb, incb, cc, 1, aa, lda, bb, incb, cc, 1,
buffer); buffer);
#else #else
if (!transa) if (!(transa & 1))
(gemv[(int)transa]) (j, k, 0, alpha, aa, lda, (gemv[(int)transa]) (j, k, 0, alpha, aa, lda,
bb, incb, cc, 1, buffer); bb, incb, cc, 1, buffer);
else else
@ -509,7 +559,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#endif #endif
#ifdef SMP #ifdef SMP
} else { } else {
if (!transa) if (!(transa & 1))
(gemv_thread[(int)transa]) (j, k, alpha, aa, (gemv_thread[(int)transa]) (j, k, alpha, aa,
lda, bb, incb, cc, lda, bb, incb, cc,
1, buffer, 1, buffer,
@ -533,13 +583,13 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
l = j; l = j;
#if defined COMPLEX #if defined COMPLEX
bb = b + i * ldb * 2; bb = b + i * ldb * 2;
if (transb) { if (transb & 1) {
bb = b + i * 2; bb = b + i * 2;
} }
cc = c + i * 2 * ldc; cc = c + i * 2 * ldc;
#else #else
bb = b + i * ldb; bb = b + i * ldb;
if (transb) { if (transb & 1) {
bb = b + i; bb = b + i;
} }
cc = c + i * ldc; cc = c + i * ldc;
@ -551,7 +601,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
NULL, 0); NULL, 0);
if (alpha_r == ZERO && alpha_i == ZERO) if (alpha_r == ZERO && alpha_i == ZERO)
return; continue;
#else #else
if (beta != ONE) if (beta != ONE)
SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0); SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0);
@ -580,7 +630,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#endif #endif
#if defined(COMPLEX) #if defined(COMPLEX)
if (!transa) if (!(transa & 1))
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i, (gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
a, lda, bb, incb, cc, 1, a, lda, bb, incb, cc, 1,
buffer); buffer);
@ -589,7 +639,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
a, lda, bb, incb, cc, 1, a, lda, bb, incb, cc, 1,
buffer); buffer);
#else #else
if (!transa) if (!(transa & 1))
(gemv[(int)transa]) (j, k, 0, alpha, a, lda, bb, (gemv[(int)transa]) (j, k, 0, alpha, a, lda, bb,
incb, cc, 1, buffer); incb, cc, 1, buffer);
else else
@ -599,7 +649,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#ifdef SMP #ifdef SMP
} else { } else {
if (!transa) if (!(transa & 1))
(gemv_thread[(int)transa]) (j, k, alpha, a, lda, (gemv_thread[(int)transa]) (j, k, alpha, a, lda,
bb, incb, cc, 1, bb, incb, cc, 1,
buffer, nthreads); buffer, nthreads);