108 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			C
		
	
	
	
			
		
		
	
	
			108 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			C
		
	
	
	
| #include "relapack.h"
 | |
| 
 | |
| static void RELAPACK_strtri_rec(const char *, const char *, const blasint *,
 | |
|     float *, const blasint *, blasint *);
 | |
| 
 | |
| 
 | |
| /** CTRTRI computes the inverse of a real upper or lower triangular matrix A.
 | |
|  *
 | |
|  * This routine is functionally equivalent to LAPACK's strtri.
 | |
|  * For details on its interface, see
 | |
|  * http://www.netlib.org/lapack/explore-html/de/d76/strtri_8f.html
 | |
|  * */
 | |
| void RELAPACK_strtri(
 | |
|     const char *uplo, const char *diag, const blasint *n,
 | |
|     float *A, const blasint *ldA,
 | |
|     blasint *info
 | |
| ) {
 | |
| 
 | |
|     // Check arguments
 | |
|     const blasint lower = LAPACK(lsame)(uplo, "L");
 | |
|     const blasint upper = LAPACK(lsame)(uplo, "U");
 | |
|     const blasint nounit = LAPACK(lsame)(diag, "N");
 | |
|     const blasint unit = LAPACK(lsame)(diag, "U");
 | |
|     *info = 0;
 | |
|     if (!lower && !upper)
 | |
|         *info = -1;
 | |
|     else if (!nounit && !unit)
 | |
|         *info = -2;
 | |
|     else if (*n < 0)
 | |
|         *info = -3;
 | |
|     else if (*ldA < MAX(1, *n))
 | |
|         *info = -5;
 | |
|     if (*info) {
 | |
|         const blasint minfo = -*info;
 | |
|         LAPACK(xerbla)("STRTRI", &minfo, strlen("STRTRI"));
 | |
|         return;
 | |
|     }
 | |
| 
 | |
|     // Clean char * arguments
 | |
|     const char cleanuplo = lower  ? 'L' : 'U';
 | |
|     const char cleandiag = nounit ? 'N' : 'U';
 | |
| 
 | |
|     // check for singularity
 | |
|     if (nounit) {
 | |
|         blasint i;
 | |
|         for (i = 0; i < *n; i++)
 | |
|             if (A[i + *ldA * i] == 0) {
 | |
|                 *info = i;
 | |
|                 return;
 | |
|             }
 | |
|     }
 | |
| 
 | |
|     // Recursive kernel
 | |
|     RELAPACK_strtri_rec(&cleanuplo, &cleandiag, n, A, ldA, info);
 | |
| }
 | |
| 
 | |
| 
 | |
| /** strtri's recursive compute kernel */
 | |
| static void RELAPACK_strtri_rec(
 | |
|     const char *uplo, const char *diag, const blasint *n,
 | |
|     float *A, const blasint *ldA,
 | |
|     blasint *info
 | |
| ){
 | |
| 
 | |
|     if (*n <= MAX(CROSSOVER_STRTRI, 1)) {
 | |
|         // Unblocked
 | |
|         LAPACK(strti2)(uplo, diag, n, A, ldA, info);
 | |
|         return;
 | |
|     }
 | |
| 
 | |
|     // Constants
 | |
|     const float ONE[]  = { 1. };
 | |
|     const float MONE[] = { -1. };
 | |
| 
 | |
|     // Splitting
 | |
|     const blasint n1 = SREC_SPLIT(*n);
 | |
|     const blasint n2 = *n - n1;
 | |
| 
 | |
|     // A_TL A_TR
 | |
|     // A_BL A_BR
 | |
|     float *const A_TL = A;
 | |
|     float *const A_TR = A + *ldA * n1;
 | |
|     float *const A_BL = A             + n1;
 | |
|     float *const A_BR = A + *ldA * n1 + n1;
 | |
| 
 | |
|     // recursion(A_TL)
 | |
|     RELAPACK_strtri_rec(uplo, diag, &n1, A_TL, ldA, info);
 | |
|     if (*info)
 | |
|         return;
 | |
| 
 | |
|     if (*uplo == 'L') {
 | |
|         // A_BL = - A_BL * A_TL
 | |
|         BLAS(strmm)("R", "L", "N", diag, &n2, &n1, MONE, A_TL, ldA, A_BL, ldA);
 | |
|         // A_BL = A_BR \ A_BL
 | |
|         BLAS(strsm)("L", "L", "N", diag, &n2, &n1, ONE, A_BR, ldA, A_BL, ldA);
 | |
|     } else {
 | |
|         // A_TR = - A_TL * A_TR
 | |
|         BLAS(strmm)("L", "U", "N", diag, &n1, &n2, MONE, A_TL, ldA, A_TR, ldA);
 | |
|         // A_TR = A_TR / A_BR
 | |
|         BLAS(strsm)("R", "U", "N", diag, &n1, &n2, ONE, A_BR, ldA, A_TR, ldA);
 | |
|     }
 | |
| 
 | |
|     // recursion(A_BR)
 | |
|     RELAPACK_strtri_rec(uplo, diag, &n2, A_BR, ldA, info);
 | |
|     if (*info)
 | |
|         *info += n1;
 | |
| }
 |