Files
OpenBLAS/relapack/src/ssygst.c
Martin Kroeker 798c448b0c Add support for INTERFACE64 and fix XERBLA calls
1. Replaced all instances of "int" with "blasint"
2. Added string length as "hidden" third parameter in calls to fortran XERBLA
2019-04-27 19:06:00 +02:00

213 lines
8.0 KiB
C

#include "relapack.h"
#if XSYGST_ALLOW_MALLOC
#include "stdlib.h"
#endif
/* Forward declaration of the recursive compute kernel (defined further down
 * in this file) so that the public driver can call it. */
static void RELAPACK_ssygst_rec(
    const blasint *itype, const char *uplo, const blasint *n,
    float *A, const blasint *ldA, const float *B, const blasint *ldB,
    float *Work, const blasint *lWork, blasint *info
);
/** SSYGST reduces a real symmetric-definite generalized eigenproblem to standard form.
 *
 * This routine is functionally equivalent to LAPACK's ssygst.
 * For details on its interface, see
 * http://www.netlib.org/lapack/explore-html/d8/d78/ssygst_8f.html
 *
 * Argument checks performed here (on failure *info is set, XERBLA is called
 * with the positive argument index, and the routine returns immediately):
 *   *info = -1  if *itype is not 1, 2 or 3
 *   *info = -2  if uplo is neither "L" nor "U" (as decided by lsame)
 *   *info = -3  if *n < 0
 *   *info = -5  if *ldA < max(1, *n)
 *   *info = -7  if *ldB < max(1, *n)
 *
 * When XSYGST_ALLOW_MALLOC is enabled, a scratch buffer of n1 * (n - n1)
 * floats (n1 = SREC_SPLIT(n)) is allocated and handed to the recursive
 * kernel together with its length; if the buffer cannot be obtained the
 * kernel is called with a NULL buffer and length 0.
 * */
void RELAPACK_ssygst(
    const blasint *itype, const char *uplo, const blasint *n,
    float *A, const blasint *ldA, const float *B, const blasint *ldB,
    blasint *info
) {
    // Check arguments
    const blasint lower = LAPACK(lsame)(uplo, "L");
    const blasint upper = LAPACK(lsame)(uplo, "U");
    *info = 0;
    if (*itype < 1 || *itype > 3)
        *info = -1;
    else if (!lower && !upper)
        *info = -2;
    else if (*n < 0)
        *info = -3;
    else if (*ldA < MAX(1, *n))
        *info = -5;
    else if (*ldB < MAX(1, *n))
        *info = -7;
    if (*info) {
        // XERBLA expects the (positive) position of the offending argument
        // plus the hidden Fortran string length.
        const blasint minfo = -*info;
        LAPACK(xerbla)("SSYGST", &minfo, strlen("SSYGST"));
        return;
    }

    // Clean char * arguments: the kernel compares *uplo against 'L' directly
    const char cleanuplo = lower ? 'L' : 'U';

    // Allocate work space (optional; lWork stays 0 when none is available)
    float *Work = NULL;
    blasint lWork = 0;
#if XSYGST_ALLOW_MALLOC
    const blasint n1 = SREC_SPLIT(*n);
    const blasint n2 = *n - n1;
    if (n1 > 0 && n2 > 0) {
        // Do the size arithmetic in size_t: n1 * n2 evaluated in blasint can
        // overflow (undefined behaviour) for large n with a 32-bit blasint.
        const size_t max_elems = (size_t)-1 / sizeof *Work;
        if ((size_t)n1 <= max_elems / (size_t)n2) {
            const size_t elems = (size_t)n1 * (size_t)n2;
            // The length is passed on as a blasint, so it must round-trip.
            const blasint elems_blasint = (blasint)elems;
            if (elems_blasint > 0 && (size_t)elems_blasint == elems) {
                Work = malloc(elems * sizeof *Work);
                if (Work)
                    lWork = elems_blasint;
            }
        }
    }
#endif

    // Recursive kernel
    RELAPACK_ssygst_rec(itype, &cleanuplo, n, A, ldA, B, ldB, Work, &lWork, info);

    // Free work space (free(NULL) is a no-op)
#if XSYGST_ALLOW_MALLOC
    free(Work);
#endif
}
/** ssygst's recursive compute kernel */
static void RELAPACK_ssygst_rec(
const blasint *itype, const char *uplo, const blasint *n,
float *A, const blasint *ldA, const float *B, const blasint *ldB,
float *Work, const blasint *lWork, blasint *info
) {
if (*n <= MAX(CROSSOVER_SSYGST, 1)) {
// Unblocked
LAPACK(ssygs2)(itype, uplo, n, A, ldA, B, ldB, info);
return;
}
// Constants
const float ZERO[] = { 0. };
const float ONE[] = { 1. };
const float MONE[] = { -1. };
const float HALF[] = { .5 };
const float MHALF[] = { -.5 };
const blasint iONE[] = { 1 };
// Loop iterator
blasint i;
// Splitting
const blasint n1 = SREC_SPLIT(*n);
const blasint n2 = *n - n1;
// A_TL A_TR
// A_BL A_BR
float *const A_TL = A;
float *const A_TR = A + *ldA * n1;
float *const A_BL = A + n1;
float *const A_BR = A + *ldA * n1 + n1;
// B_TL B_TR
// B_BL B_BR
const float *const B_TL = B;
const float *const B_TR = B + *ldB * n1;
const float *const B_BL = B + n1;
const float *const B_BR = B + *ldB * n1 + n1;
// recursion(A_TL, B_TL)
RELAPACK_ssygst_rec(itype, uplo, &n1, A_TL, ldA, B_TL, ldB, Work, lWork, info);
if (*itype == 1)
if (*uplo == 'L') {
// A_BL = A_BL / B_TL'
BLAS(strsm)("R", "L", "T", "N", &n2, &n1, ONE, B_TL, ldB, A_BL, ldA);
if (*lWork > n2 * n1) {
// T = -1/2 * B_BL * A_TL
BLAS(ssymm)("R", "L", &n2, &n1, MHALF, A_TL, ldA, B_BL, ldB, ZERO, Work, &n2);
// A_BL = A_BL + T
for (i = 0; i < n1; i++)
BLAS(saxpy)(&n2, ONE, Work + n2 * i, iONE, A_BL + *ldA * i, iONE);
} else
// A_BL = A_BL - 1/2 B_BL * A_TL
BLAS(ssymm)("R", "L", &n2, &n1, MHALF, A_TL, ldA, B_BL, ldB, ONE, A_BL, ldA);
// A_BR = A_BR - A_BL * B_BL' - B_BL * A_BL'
BLAS(ssyr2k)("L", "N", &n2, &n1, MONE, A_BL, ldA, B_BL, ldB, ONE, A_BR, ldA);
if (*lWork > n2 * n1)
// A_BL = A_BL + T
for (i = 0; i < n1; i++)
BLAS(saxpy)(&n2, ONE, Work + n2 * i, iONE, A_BL + *ldA * i, iONE);
else
// A_BL = A_BL - 1/2 B_BL * A_TL
BLAS(ssymm)("R", "L", &n2, &n1, MHALF, A_TL, ldA, B_BL, ldB, ONE, A_BL, ldA);
// A_BL = B_BR \ A_BL
BLAS(strsm)("L", "L", "N", "N", &n2, &n1, ONE, B_BR, ldB, A_BL, ldA);
} else {
// A_TR = B_TL' \ A_TR
BLAS(strsm)("L", "U", "T", "N", &n1, &n2, ONE, B_TL, ldB, A_TR, ldA);
if (*lWork > n2 * n1) {
// T = -1/2 * A_TL * B_TR
BLAS(ssymm)("L", "U", &n1, &n2, MHALF, A_TL, ldA, B_TR, ldB, ZERO, Work, &n1);
// A_TR = A_BL + T
for (i = 0; i < n2; i++)
BLAS(saxpy)(&n1, ONE, Work + n1 * i, iONE, A_TR + *ldA * i, iONE);
} else
// A_TR = A_TR - 1/2 A_TL * B_TR
BLAS(ssymm)("L", "U", &n1, &n2, MHALF, A_TL, ldA, B_TR, ldB, ONE, A_TR, ldA);
// A_BR = A_BR - A_TR' * B_TR - B_TR' * A_TR
BLAS(ssyr2k)("U", "T", &n2, &n1, MONE, A_TR, ldA, B_TR, ldB, ONE, A_BR, ldA);
if (*lWork > n2 * n1)
// A_TR = A_BL + T
for (i = 0; i < n2; i++)
BLAS(saxpy)(&n1, ONE, Work + n1 * i, iONE, A_TR + *ldA * i, iONE);
else
// A_TR = A_TR - 1/2 A_TL * B_TR
BLAS(ssymm)("L", "U", &n1, &n2, MHALF, A_TL, ldA, B_TR, ldB, ONE, A_TR, ldA);
// A_TR = A_TR / B_BR
BLAS(strsm)("R", "U", "N", "N", &n1, &n2, ONE, B_BR, ldB, A_TR, ldA);
}
else
if (*uplo == 'L') {
// A_BL = A_BL * B_TL
BLAS(strmm)("R", "L", "N", "N", &n2, &n1, ONE, B_TL, ldB, A_BL, ldA);
if (*lWork > n2 * n1) {
// T = 1/2 * A_BR * B_BL
BLAS(ssymm)("L", "L", &n2, &n1, HALF, A_BR, ldA, B_BL, ldB, ZERO, Work, &n2);
// A_BL = A_BL + T
for (i = 0; i < n1; i++)
BLAS(saxpy)(&n2, ONE, Work + n2 * i, iONE, A_BL + *ldA * i, iONE);
} else
// A_BL = A_BL + 1/2 A_BR * B_BL
BLAS(ssymm)("L", "L", &n2, &n1, HALF, A_BR, ldA, B_BL, ldB, ONE, A_BL, ldA);
// A_TL = A_TL + A_BL' * B_BL + B_BL' * A_BL
BLAS(ssyr2k)("L", "T", &n1, &n2, ONE, A_BL, ldA, B_BL, ldB, ONE, A_TL, ldA);
if (*lWork > n2 * n1)
// A_BL = A_BL + T
for (i = 0; i < n1; i++)
BLAS(saxpy)(&n2, ONE, Work + n2 * i, iONE, A_BL + *ldA * i, iONE);
else
// A_BL = A_BL + 1/2 A_BR * B_BL
BLAS(ssymm)("L", "L", &n2, &n1, HALF, A_BR, ldA, B_BL, ldB, ONE, A_BL, ldA);
// A_BL = B_BR * A_BL
BLAS(strmm)("L", "L", "T", "N", &n2, &n1, ONE, B_BR, ldB, A_BL, ldA);
} else {
// A_TR = B_TL * A_TR
BLAS(strmm)("L", "U", "N", "N", &n1, &n2, ONE, B_TL, ldB, A_TR, ldA);
if (*lWork > n2 * n1) {
// T = 1/2 * B_TR * A_BR
BLAS(ssymm)("R", "U", &n1, &n2, HALF, A_BR, ldA, B_TR, ldB, ZERO, Work, &n1);
// A_TR = A_TR + T
for (i = 0; i < n2; i++)
BLAS(saxpy)(&n1, ONE, Work + n1 * i, iONE, A_TR + *ldA * i, iONE);
} else
// A_TR = A_TR + 1/2 B_TR A_BR
BLAS(ssymm)("R", "U", &n1, &n2, HALF, A_BR, ldA, B_TR, ldB, ONE, A_TR, ldA);
// A_TL = A_TL + A_TR * B_TR' + B_TR * A_TR'
BLAS(ssyr2k)("U", "N", &n1, &n2, ONE, A_TR, ldA, B_TR, ldB, ONE, A_TL, ldA);
if (*lWork > n2 * n1)
// A_TR = A_TR + T
for (i = 0; i < n2; i++)
BLAS(saxpy)(&n1, ONE, Work + n1 * i, iONE, A_TR + *ldA * i, iONE);
else
// A_TR = A_TR + 1/2 B_TR * A_BR
BLAS(ssymm)("R", "U", &n1, &n2, HALF, A_BR, ldA, B_TR, ldB, ONE, A_TR, ldA);
// A_TR = A_TR * B_BR
BLAS(strmm)("R", "U", "T", "N", &n1, &n2, ONE, B_BR, ldB, A_TR, ldA);
}
// recursion(A_BR, B_BR)
RELAPACK_ssygst_rec(itype, uplo, &n2, A_BR, ldA, B_BR, ldB, Work, lWork, info);
}