Merge pull request #4006 from martin-frbg/issue4005
Fix ?GEMMT implementation
This commit is contained in:
		
						commit
						73e6fcb925
					
				|  | @ -35,29 +35,26 @@ | ||||||
| #include <stdio.h> | #include <stdio.h> | ||||||
| #include <stdlib.h> | #include <stdlib.h> | ||||||
| #include "common.h" | #include "common.h" | ||||||
| #ifdef FUNCTION_PROFILE |  | ||||||
| #include "functable.h" |  | ||||||
| #endif |  | ||||||
| 
 | 
 | ||||||
| #ifndef COMPLEX | #ifndef COMPLEX | ||||||
| #define SMP_THRESHOLD_MIN 65536.0 | #define SMP_THRESHOLD_MIN 65536.0 | ||||||
| #ifdef XDOUBLE | #ifdef XDOUBLE | ||||||
| #define ERROR_NAME "QGEMT " | #define ERROR_NAME "QGEMMT " | ||||||
| #elif defined(DOUBLE) | #elif defined(DOUBLE) | ||||||
| #define ERROR_NAME "DGEMT " | #define ERROR_NAME "DGEMMT " | ||||||
| #elif defined(BFLOAT16) | #elif defined(BFLOAT16) | ||||||
| #define ERROR_NAME "SBGEMT " | #define ERROR_NAME "SBGEMMT " | ||||||
| #else | #else | ||||||
| #define ERROR_NAME "SGEMT " | #define ERROR_NAME "SGEMMT " | ||||||
| #endif | #endif | ||||||
| #else | #else | ||||||
| #define SMP_THRESHOLD_MIN 8192.0 | #define SMP_THRESHOLD_MIN 8192.0 | ||||||
| #ifdef XDOUBLE | #ifdef XDOUBLE | ||||||
| #define ERROR_NAME "XGEMT " | #define ERROR_NAME "XGEMMT " | ||||||
| #elif defined(DOUBLE) | #elif defined(DOUBLE) | ||||||
| #define ERROR_NAME "ZGEMT " | #define ERROR_NAME "ZGEMMT " | ||||||
| #else | #else | ||||||
| #define ERROR_NAME "CGEMT " | #define ERROR_NAME "CGEMMT " | ||||||
| #endif | #endif | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
|  | @ -68,13 +65,13 @@ | ||||||
| #ifndef CBLAS | #ifndef CBLAS | ||||||
| 
 | 
 | ||||||
| void NAME(char *UPLO, char *TRANSA, char *TRANSB, | void NAME(char *UPLO, char *TRANSA, char *TRANSB, | ||||||
| 	  blasint * M, blasint * N, blasint * K, | 	  blasint * M, blasint * K, | ||||||
| 	  FLOAT * Alpha, | 	  FLOAT * Alpha, | ||||||
| 	  IFLOAT * a, blasint * ldA, | 	  IFLOAT * a, blasint * ldA, | ||||||
| 	  IFLOAT * b, blasint * ldB, FLOAT * Beta, FLOAT * c, blasint * ldC) | 	  IFLOAT * b, blasint * ldB, FLOAT * Beta, FLOAT * c, blasint * ldC) | ||||||
| { | { | ||||||
| 
 | 
 | ||||||
| 	blasint m, n, k; | 	blasint m, k; | ||||||
| 	blasint lda, ldb, ldc; | 	blasint lda, ldb, ldc; | ||||||
| 	int transa, transb, uplo; | 	int transa, transb, uplo; | ||||||
| 	blasint info; | 	blasint info; | ||||||
|  | @ -92,7 +89,6 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB, | ||||||
| 	PRINT_DEBUG_NAME; | 	PRINT_DEBUG_NAME; | ||||||
| 
 | 
 | ||||||
| 	m = *M; | 	m = *M; | ||||||
| 	n = *N; |  | ||||||
| 	k = *K; | 	k = *K; | ||||||
| 
 | 
 | ||||||
| #if defined(COMPLEX) | #if defined(COMPLEX) | ||||||
|  | @ -167,8 +163,6 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB, | ||||||
| 		info = 13; | 		info = 13; | ||||||
| 	if (k < 0) | 	if (k < 0) | ||||||
| 		info = 5; | 		info = 5; | ||||||
| 	if (n < 0) |  | ||||||
| 		info = 4; |  | ||||||
| 	if (m < 0) | 	if (m < 0) | ||||||
| 		info = 3; | 		info = 3; | ||||||
| 	if (transb < 0) | 	if (transb < 0) | ||||||
|  | @ -184,7 +178,7 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB, | ||||||
| 
 | 
 | ||||||
| void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, | void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, | ||||||
| 	   enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint M, | 	   enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint M, | ||||||
| 	   blasint N, blasint k, | 	   blasint k, | ||||||
| #ifndef COMPLEX | #ifndef COMPLEX | ||||||
| 	   FLOAT alpha, | 	   FLOAT alpha, | ||||||
| 	   IFLOAT * A, blasint LDA, | 	   IFLOAT * A, blasint LDA, | ||||||
|  | @ -205,7 +199,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, | ||||||
| 
 | 
 | ||||||
| 	int transa, transb, uplo; | 	int transa, transb, uplo; | ||||||
| 	blasint info; | 	blasint info; | ||||||
| 	blasint m, n, lda, ldb; | 	blasint m, lda, ldb; | ||||||
| 	FLOAT *a, *b; | 	FLOAT *a, *b; | ||||||
| 	XFLOAT *buffer; | 	XFLOAT *buffer; | ||||||
| 
 | 
 | ||||||
|  | @ -248,9 +242,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, | ||||||
| 			transb = 3; | 			transb = 3; | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
| 		m = M; |  | ||||||
| 		n = N; |  | ||||||
| 
 |  | ||||||
| 		a = (void *)A; | 		a = (void *)A; | ||||||
| 		b = (void *)B; | 		b = (void *)B; | ||||||
| 		lda = LDA; | 		lda = LDA; | ||||||
|  | @ -262,8 +253,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, | ||||||
| 			info = 13; | 			info = 13; | ||||||
| 		if (k < 0) | 		if (k < 0) | ||||||
| 			info = 5; | 			info = 5; | ||||||
| 		if (n < 0) |  | ||||||
| 			info = 4; |  | ||||||
| 		if (m < 0) | 		if (m < 0) | ||||||
| 			info = 3; | 			info = 3; | ||||||
| 		if (transb < 0) | 		if (transb < 0) | ||||||
|  | @ -273,8 +262,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if (order == CblasRowMajor) { | 	if (order == CblasRowMajor) { | ||||||
| 		m = N; |  | ||||||
| 		n = M; |  | ||||||
| 
 | 
 | ||||||
| 		a = (void *)B; | 		a = (void *)B; | ||||||
| 		b = (void *)A; | 		b = (void *)A; | ||||||
|  | @ -319,8 +306,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, | ||||||
| 			info = 13; | 			info = 13; | ||||||
| 		if (k < 0) | 		if (k < 0) | ||||||
| 			info = 5; | 			info = 5; | ||||||
| 		if (n < 0) |  | ||||||
| 			info = 4; |  | ||||||
| 		if (m < 0) | 		if (m < 0) | ||||||
| 			info = 3; | 			info = 3; | ||||||
| 		if (transb < 0) | 		if (transb < 0) | ||||||
|  | @ -407,37 +392,35 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, | ||||||
| 
 | 
 | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
| 	if ((m == 0) || (n == 0)) | 	if ((m == 0) ) | ||||||
| 		return; | 		return; | ||||||
| 
 | 
 | ||||||
| 	IDEBUG_START; | 	IDEBUG_START; | ||||||
| 
 | 
 | ||||||
| 	FUNCTION_PROFILE_START(); |  | ||||||
| 
 |  | ||||||
| 	const blasint incb = (transb == 0) ? 1 : ldb; | 	const blasint incb = (transb == 0) ? 1 : ldb; | ||||||
| 
 | 
 | ||||||
| 	if (uplo == 1) { | 	if (uplo == 1) { | ||||||
| 		for (i = 0; i < n; i++) { | 		for (i = 0; i < m; i++) { | ||||||
| 			j = n - i; | 			j = m - i; | ||||||
| 
 | 
 | ||||||
| 			l = j; | 			l = j; | ||||||
| #if defined(COMPLEX) | #if defined(COMPLEX) | ||||||
| 			aa = a + i * 2; | 			aa = a + i * 2; | ||||||
| 			bb = b + i * ldb * 2; | 			bb = b + i * ldb * 2; | ||||||
| 			if (transa) { | 			if (transa) { | ||||||
| 				l = k; |  | ||||||
| 				aa = a + lda * i * 2; | 				aa = a + lda * i * 2; | ||||||
| 				bb = b + i * 2; |  | ||||||
| 			} | 			} | ||||||
|  | 			if (transb) | ||||||
|  | 				bb = b + i * 2; | ||||||
| 			cc = c + i * 2 * ldc + i * 2; | 			cc = c + i * 2 * ldc + i * 2; | ||||||
| #else | #else | ||||||
| 			aa = a + i; | 			aa = a + i; | ||||||
| 			bb = b + i * ldb; | 			bb = b + i * ldb; | ||||||
| 			if (transa) { | 			if (transa) { | ||||||
| 				l = k; |  | ||||||
| 				aa = a + lda * i; | 				aa = a + lda * i; | ||||||
| 				bb = b + i; |  | ||||||
| 			} | 			} | ||||||
|  | 			if (transb) | ||||||
|  | 				bb = b + i; | ||||||
| 			cc = c + i * ldc + i; | 			cc = c + i * ldc + i; | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
|  | @ -458,8 +441,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, | ||||||
| 
 | 
 | ||||||
| 			IDEBUG_START; | 			IDEBUG_START; | ||||||
| 
 | 
 | ||||||
| 			FUNCTION_PROFILE_START(); |  | ||||||
| 
 |  | ||||||
| 			buffer_size = j + k + 128 / sizeof(FLOAT); | 			buffer_size = j + k + 128 / sizeof(FLOAT); | ||||||
| #ifdef WINDOWS_ABI | #ifdef WINDOWS_ABI | ||||||
| 			buffer_size += 160 / sizeof(FLOAT); | 			buffer_size += 160 / sizeof(FLOAT); | ||||||
|  | @ -479,20 +460,34 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
| #if defined(COMPLEX) | #if defined(COMPLEX) | ||||||
|  | 				if (!transa) | ||||||
| 				(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i, | 				(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i, | ||||||
| 						     aa, lda, bb, incb, cc, 1, | 						     aa, lda, bb, incb, cc, 1, | ||||||
| 						     buffer); | 						     buffer); | ||||||
|  | 				else | ||||||
|  | 				(gemv[(int)transa]) (k, j, 0, alpha_r, alpha_i, | ||||||
|  | 						     aa, lda, bb, incb, cc, 1, | ||||||
|  | 						     buffer); | ||||||
| #else | #else | ||||||
|  | 				if (!transa) | ||||||
| 				(gemv[(int)transa]) (j, k, 0, alpha, aa, lda, | 				(gemv[(int)transa]) (j, k, 0, alpha, aa, lda, | ||||||
| 						     bb, incb, cc, 1, buffer); | 						     bb, incb, cc, 1, buffer); | ||||||
|  | 				else | ||||||
|  | 				(gemv[(int)transa]) (k, j, 0, alpha, aa, lda, | ||||||
|  | 						     bb, incb, cc, 1, buffer); | ||||||
| #endif | #endif | ||||||
| #ifdef SMP | #ifdef SMP | ||||||
| 			} else { | 			} else { | ||||||
| 
 | 				if (!transa) | ||||||
| 				(gemv_thread[(int)transa]) (j, k, alpha, aa, | 				(gemv_thread[(int)transa]) (j, k, alpha, aa, | ||||||
| 							    lda, bb, incb, cc, | 							    lda, bb, incb, cc, | ||||||
| 							    1, buffer, | 							    1, buffer, | ||||||
| 							    nthreads); | 							    nthreads); | ||||||
|  | 				else | ||||||
|  | 				(gemv_thread[(int)transa]) (k, j, alpha, aa, | ||||||
|  | 							    lda, bb, incb, cc, | ||||||
|  | 							    1, buffer, | ||||||
|  | 							    nthreads); | ||||||
| 
 | 
 | ||||||
| 			} | 			} | ||||||
| #endif | #endif | ||||||
|  | @ -501,21 +496,19 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 
 | 
 | ||||||
| 		for (i = 0; i < n; i++) { | 		for (i = 0; i < m; i++) { | ||||||
| 			j = i + 1; | 			j = i + 1; | ||||||
| 
 | 
 | ||||||
| 			l = j; | 			l = j; | ||||||
| #if defined COMPLEX | #if defined COMPLEX | ||||||
| 			bb = b + i * ldb * 2; | 			bb = b + i * ldb * 2; | ||||||
| 			if (transa) { | 			if (transb) { | ||||||
| 				l = k; |  | ||||||
| 				bb = b + i * 2; | 				bb = b + i * 2; | ||||||
| 			} | 			} | ||||||
| 			cc = c + i * 2 * ldc; | 			cc = c + i * 2 * ldc; | ||||||
| #else | #else | ||||||
| 			bb = b + i * ldb; | 			bb = b + i * ldb; | ||||||
| 			if (transa) { | 			if (transb) { | ||||||
| 				l = k; |  | ||||||
| 				bb = b + i; | 				bb = b + i; | ||||||
| 			} | 			} | ||||||
| 			cc = c + i * ldc; | 			cc = c + i * ldc; | ||||||
|  | @ -537,8 +530,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, | ||||||
| #endif | #endif | ||||||
| 			IDEBUG_START; | 			IDEBUG_START; | ||||||
| 
 | 
 | ||||||
| 			FUNCTION_PROFILE_START(); |  | ||||||
| 
 |  | ||||||
| 			buffer_size = j + k + 128 / sizeof(FLOAT); | 			buffer_size = j + k + 128 / sizeof(FLOAT); | ||||||
| #ifdef WINDOWS_ABI | #ifdef WINDOWS_ABI | ||||||
| 			buffer_size += 160 / sizeof(FLOAT); | 			buffer_size += 160 / sizeof(FLOAT); | ||||||
|  | @ -558,30 +549,39 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
| #if defined(COMPLEX) | #if defined(COMPLEX) | ||||||
|  | 				if (!transa) | ||||||
| 				(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i, | 				(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i, | ||||||
| 						     a, lda, bb, incb, cc, 1, | 						     a, lda, bb, incb, cc, 1, | ||||||
| 						     buffer); | 						     buffer); | ||||||
|  | 				else | ||||||
|  | 				(gemv[(int)transa]) (k, j, 0, alpha_r, alpha_i, | ||||||
|  | 						     a, lda, bb, incb, cc, 1, | ||||||
|  | 						     buffer); | ||||||
| #else | #else | ||||||
|  | 				if (!transa) | ||||||
| 				(gemv[(int)transa]) (j, k, 0, alpha, a, lda, bb, | 				(gemv[(int)transa]) (j, k, 0, alpha, a, lda, bb, | ||||||
| 						     incb, cc, 1, buffer); | 						     incb, cc, 1, buffer); | ||||||
|  | 				else | ||||||
|  | 				(gemv[(int)transa]) (k, j, 0, alpha, a, lda, bb, | ||||||
|  | 						     incb, cc, 1, buffer); | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
| #ifdef SMP | #ifdef SMP | ||||||
| 			} else { | 			} else { | ||||||
| 
 | 				if (!transa) | ||||||
| 				(gemv_thread[(int)transa]) (j, k, alpha, a, lda, | 				(gemv_thread[(int)transa]) (j, k, alpha, a, lda, | ||||||
| 							    bb, incb, cc, 1, | 							    bb, incb, cc, 1, | ||||||
| 							    buffer, nthreads); | 							    buffer, nthreads); | ||||||
| 
 | 				else | ||||||
|  | 				(gemv_thread[(int)transa]) (k, j, alpha, a, lda, | ||||||
|  | 							    bb, incb, cc, 1, | ||||||
|  | 							    buffer, nthreads); | ||||||
| 			} | 			} | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
| 			STACK_FREE(buffer); | 			STACK_FREE(buffer); | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	FUNCTION_PROFILE_END(COMPSIZE * COMPSIZE, |  | ||||||
| 			     args.m * args.k + args.k * args.n + |  | ||||||
| 			     args.m * args.n, 2 * args.m * args.n * args.k); |  | ||||||
| 
 | 
 | ||||||
| 	IDEBUG_END; | 	IDEBUG_END; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue