Files
OpenBLAS/relapack/src/strtri.c
Martin Kroeker 798c448b0c Add support for INTERFACE64 and fix XERBLA calls
1. Replaced all instances of "int" with "blasint"
2. Added string length as "hidden" third parameter in calls to fortran XERBLA
2019-04-27 19:06:00 +02:00

108 lines
2.9 KiB
C

#include "relapack.h"
/* Forward declaration of the recursive kernel defined further down. */
static void RELAPACK_strtri_rec(
    const char *uplo, const char *diag, const blasint *n,
    float *A, const blasint *ldA,
    blasint *info
);
/** STRTRI computes the inverse of a real upper or lower triangular matrix A.
 *
 * This routine is functionally equivalent to LAPACK's strtri.
 * For details on its interface, see
 * http://www.netlib.org/lapack/explore-html/de/d76/strtri_8f.html
 *
 * @param uplo  "U" for an upper, "L" for a lower triangular matrix.
 * @param diag  "N" for a non-unit, "U" for a unit diagonal.
 * @param n     Order of the matrix; must be >= 0.
 * @param A     The triangular matrix, stored with leading dimension *ldA;
 *              it is updated in place by the recursive kernel.
 * @param ldA   Leading dimension of A; must be >= max(1, *n).
 * @param info  Output status:
 *              = 0  success;
 *              < 0  argument number -info is invalid (xerbla is called);
 *              > 0  diagonal element number info (1-based, as in LAPACK)
 *                   is exactly zero; A is left untouched in that case.
 * */
void RELAPACK_strtri(
    const char *uplo, const char *diag, const blasint *n,
    float *A, const blasint *ldA,
    blasint *info
) {
    // Check arguments
    const blasint lower  = LAPACK(lsame)(uplo, "L");
    const blasint upper  = LAPACK(lsame)(uplo, "U");
    const blasint nounit = LAPACK(lsame)(diag, "N");
    const blasint unit   = LAPACK(lsame)(diag, "U");
    *info = 0;
    if (!lower && !upper)
        *info = -1;
    else if (!nounit && !unit)
        *info = -2;
    else if (*n < 0)
        *info = -3;
    else if (*ldA < MAX(1, *n))
        *info = -5;
    if (*info) {
        // xerbla expects the (positive) position of the offending argument
        const blasint minfo = -*info;
        LAPACK(xerbla)("STRTRI", &minfo, strlen("STRTRI"));
        return;
    }

    // Clean char * arguments
    const char cleanuplo = lower ? 'L' : 'U';
    const char cleandiag = nounit ? 'N' : 'U';

    // Check for singularity (only meaningful for a non-unit diagonal)
    if (nounit) {
        blasint i;
        for (i = 0; i < *n; i++) {
            if (A[i + *ldA * i] == 0) {
                // LAPACK reports the 1-based index of the zero diagonal
                // entry; storing the 0-based i would turn a zero at A[0]
                // into info == 0, i.e. a false "success".
                *info = i + 1;
                return;
            }
        }
    }

    // Recursive kernel
    RELAPACK_strtri_rec(&cleanuplo, &cleandiag, n, A, ldA, info);
}
/** Recursive compute kernel behind RELAPACK_strtri.
 *
 * The matrix is split into a 2x2 block form. The top-left block is inverted
 * recursively, the off-diagonal block is updated with one strmm and one
 * strsm call, and finally the bottom-right block is inverted recursively.
 * Problems of order at most MAX(CROSSOVER_STRTRI, 1) are handed to LAPACK's
 * unblocked strti2.
 *
 * uplo is compared against 'L' directly, so callers are expected to pass
 * the cleaned-up character ('L' or 'U').
 *
 * A non-zero *info after the first recursion aborts the computation; a
 * non-zero *info after the second recursion is shifted by the size of the
 * top-left block.
 */
static void RELAPACK_strtri_rec(
    const char *uplo, const char *diag, const blasint *n,
    float *A, const blasint *ldA,
    blasint *info
) {
    // Base case: unblocked LAPACK routine
    if (*n <= MAX(CROSSOVER_STRTRI, 1)) {
        LAPACK(strti2)(uplo, diag, n, A, ldA, info);
        return;
    }

    // Scalar factors, passed to BLAS by address
    const float one       = 1.;
    const float minus_one = -1.;

    // Block sizes: head rows/columns on top, tail rows/columns below
    const blasint head = SREC_SPLIT(*n);
    const blasint tail = *n - head;

    // Diagonal blocks
    //   top_left      .
    //   .             bottom_right
    float *const top_left     = A;
    float *const bottom_right = A + *ldA * head + head;

    // Invert the top-left block first; stop on failure
    RELAPACK_strtri_rec(uplo, diag, &head, top_left, ldA, info);
    if (*info != 0)
        return;

    if (*uplo != 'L') {
        // Upper triangular: update the top-right block
        float *const top_right = A + *ldA * head;
        // top_right = - top_left * top_right
        BLAS(strmm)("L", "U", "N", diag, &head, &tail, &minus_one, top_left, ldA, top_right, ldA);
        // top_right = top_right / bottom_right
        BLAS(strsm)("R", "U", "N", diag, &head, &tail, &one, bottom_right, ldA, top_right, ldA);
    } else {
        // Lower triangular: update the bottom-left block
        float *const bottom_left = A + head;
        // bottom_left = - bottom_left * top_left
        BLAS(strmm)("R", "L", "N", diag, &tail, &head, &minus_one, top_left, ldA, bottom_left, ldA);
        // bottom_left = bottom_right \ bottom_left
        BLAS(strsm)("L", "L", "N", diag, &tail, &head, &one, bottom_right, ldA, bottom_left, ldA);
    }

    // Invert the bottom-right block; shift a reported index into the
    // coordinates of the full matrix
    RELAPACK_strtri_rec(uplo, diag, &tail, bottom_right, ldA, info);
    if (*info != 0)
        *info += head;
}