Merge pull request #4006 from martin-frbg/issue4005
Fix ?GEMMT implementation
This commit is contained in:
commit
73e6fcb925
|
@ -35,29 +35,26 @@
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#ifdef FUNCTION_PROFILE
|
|
||||||
#include "functable.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifndef COMPLEX
|
#ifndef COMPLEX
|
||||||
#define SMP_THRESHOLD_MIN 65536.0
|
#define SMP_THRESHOLD_MIN 65536.0
|
||||||
#ifdef XDOUBLE
|
#ifdef XDOUBLE
|
||||||
#define ERROR_NAME "QGEMT "
|
#define ERROR_NAME "QGEMMT "
|
||||||
#elif defined(DOUBLE)
|
#elif defined(DOUBLE)
|
||||||
#define ERROR_NAME "DGEMT "
|
#define ERROR_NAME "DGEMMT "
|
||||||
#elif defined(BFLOAT16)
|
#elif defined(BFLOAT16)
|
||||||
#define ERROR_NAME "SBGEMT "
|
#define ERROR_NAME "SBGEMMT "
|
||||||
#else
|
#else
|
||||||
#define ERROR_NAME "SGEMT "
|
#define ERROR_NAME "SGEMMT "
|
||||||
#endif
|
#endif
|
||||||
#else
|
#else
|
||||||
#define SMP_THRESHOLD_MIN 8192.0
|
#define SMP_THRESHOLD_MIN 8192.0
|
||||||
#ifdef XDOUBLE
|
#ifdef XDOUBLE
|
||||||
#define ERROR_NAME "XGEMT "
|
#define ERROR_NAME "XGEMMT "
|
||||||
#elif defined(DOUBLE)
|
#elif defined(DOUBLE)
|
||||||
#define ERROR_NAME "ZGEMT "
|
#define ERROR_NAME "ZGEMMT "
|
||||||
#else
|
#else
|
||||||
#define ERROR_NAME "CGEMT "
|
#define ERROR_NAME "CGEMMT "
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -68,13 +65,13 @@
|
||||||
#ifndef CBLAS
|
#ifndef CBLAS
|
||||||
|
|
||||||
void NAME(char *UPLO, char *TRANSA, char *TRANSB,
|
void NAME(char *UPLO, char *TRANSA, char *TRANSB,
|
||||||
blasint * M, blasint * N, blasint * K,
|
blasint * M, blasint * K,
|
||||||
FLOAT * Alpha,
|
FLOAT * Alpha,
|
||||||
IFLOAT * a, blasint * ldA,
|
IFLOAT * a, blasint * ldA,
|
||||||
IFLOAT * b, blasint * ldB, FLOAT * Beta, FLOAT * c, blasint * ldC)
|
IFLOAT * b, blasint * ldB, FLOAT * Beta, FLOAT * c, blasint * ldC)
|
||||||
{
|
{
|
||||||
|
|
||||||
blasint m, n, k;
|
blasint m, k;
|
||||||
blasint lda, ldb, ldc;
|
blasint lda, ldb, ldc;
|
||||||
int transa, transb, uplo;
|
int transa, transb, uplo;
|
||||||
blasint info;
|
blasint info;
|
||||||
|
@ -92,7 +89,6 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
|
||||||
PRINT_DEBUG_NAME;
|
PRINT_DEBUG_NAME;
|
||||||
|
|
||||||
m = *M;
|
m = *M;
|
||||||
n = *N;
|
|
||||||
k = *K;
|
k = *K;
|
||||||
|
|
||||||
#if defined(COMPLEX)
|
#if defined(COMPLEX)
|
||||||
|
@ -167,8 +163,6 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
|
||||||
info = 13;
|
info = 13;
|
||||||
if (k < 0)
|
if (k < 0)
|
||||||
info = 5;
|
info = 5;
|
||||||
if (n < 0)
|
|
||||||
info = 4;
|
|
||||||
if (m < 0)
|
if (m < 0)
|
||||||
info = 3;
|
info = 3;
|
||||||
if (transb < 0)
|
if (transb < 0)
|
||||||
|
@ -184,7 +178,7 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
|
||||||
|
|
||||||
void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
||||||
enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint M,
|
enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint M,
|
||||||
blasint N, blasint k,
|
blasint k,
|
||||||
#ifndef COMPLEX
|
#ifndef COMPLEX
|
||||||
FLOAT alpha,
|
FLOAT alpha,
|
||||||
IFLOAT * A, blasint LDA,
|
IFLOAT * A, blasint LDA,
|
||||||
|
@ -205,7 +199,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
||||||
|
|
||||||
int transa, transb, uplo;
|
int transa, transb, uplo;
|
||||||
blasint info;
|
blasint info;
|
||||||
blasint m, n, lda, ldb;
|
blasint m, lda, ldb;
|
||||||
FLOAT *a, *b;
|
FLOAT *a, *b;
|
||||||
XFLOAT *buffer;
|
XFLOAT *buffer;
|
||||||
|
|
||||||
|
@ -248,9 +242,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
||||||
transb = 3;
|
transb = 3;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
m = M;
|
|
||||||
n = N;
|
|
||||||
|
|
||||||
a = (void *)A;
|
a = (void *)A;
|
||||||
b = (void *)B;
|
b = (void *)B;
|
||||||
lda = LDA;
|
lda = LDA;
|
||||||
|
@ -262,8 +253,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
||||||
info = 13;
|
info = 13;
|
||||||
if (k < 0)
|
if (k < 0)
|
||||||
info = 5;
|
info = 5;
|
||||||
if (n < 0)
|
|
||||||
info = 4;
|
|
||||||
if (m < 0)
|
if (m < 0)
|
||||||
info = 3;
|
info = 3;
|
||||||
if (transb < 0)
|
if (transb < 0)
|
||||||
|
@ -273,8 +262,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (order == CblasRowMajor) {
|
if (order == CblasRowMajor) {
|
||||||
m = N;
|
|
||||||
n = M;
|
|
||||||
|
|
||||||
a = (void *)B;
|
a = (void *)B;
|
||||||
b = (void *)A;
|
b = (void *)A;
|
||||||
|
@ -319,8 +306,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
||||||
info = 13;
|
info = 13;
|
||||||
if (k < 0)
|
if (k < 0)
|
||||||
info = 5;
|
info = 5;
|
||||||
if (n < 0)
|
|
||||||
info = 4;
|
|
||||||
if (m < 0)
|
if (m < 0)
|
||||||
info = 3;
|
info = 3;
|
||||||
if (transb < 0)
|
if (transb < 0)
|
||||||
|
@ -407,37 +392,35 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if ((m == 0) || (n == 0))
|
if ((m == 0) )
|
||||||
return;
|
return;
|
||||||
|
|
||||||
IDEBUG_START;
|
IDEBUG_START;
|
||||||
|
|
||||||
FUNCTION_PROFILE_START();
|
|
||||||
|
|
||||||
const blasint incb = (transb == 0) ? 1 : ldb;
|
const blasint incb = (transb == 0) ? 1 : ldb;
|
||||||
|
|
||||||
if (uplo == 1) {
|
if (uplo == 1) {
|
||||||
for (i = 0; i < n; i++) {
|
for (i = 0; i < m; i++) {
|
||||||
j = n - i;
|
j = m - i;
|
||||||
|
|
||||||
l = j;
|
l = j;
|
||||||
#if defined(COMPLEX)
|
#if defined(COMPLEX)
|
||||||
aa = a + i * 2;
|
aa = a + i * 2;
|
||||||
bb = b + i * ldb * 2;
|
bb = b + i * ldb * 2;
|
||||||
if (transa) {
|
if (transa) {
|
||||||
l = k;
|
|
||||||
aa = a + lda * i * 2;
|
aa = a + lda * i * 2;
|
||||||
bb = b + i * 2;
|
|
||||||
}
|
}
|
||||||
|
if (transb)
|
||||||
|
bb = b + i * 2;
|
||||||
cc = c + i * 2 * ldc + i * 2;
|
cc = c + i * 2 * ldc + i * 2;
|
||||||
#else
|
#else
|
||||||
aa = a + i;
|
aa = a + i;
|
||||||
bb = b + i * ldb;
|
bb = b + i * ldb;
|
||||||
if (transa) {
|
if (transa) {
|
||||||
l = k;
|
|
||||||
aa = a + lda * i;
|
aa = a + lda * i;
|
||||||
bb = b + i;
|
|
||||||
}
|
}
|
||||||
|
if (transb)
|
||||||
|
bb = b + i;
|
||||||
cc = c + i * ldc + i;
|
cc = c + i * ldc + i;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -458,8 +441,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
||||||
|
|
||||||
IDEBUG_START;
|
IDEBUG_START;
|
||||||
|
|
||||||
FUNCTION_PROFILE_START();
|
|
||||||
|
|
||||||
buffer_size = j + k + 128 / sizeof(FLOAT);
|
buffer_size = j + k + 128 / sizeof(FLOAT);
|
||||||
#ifdef WINDOWS_ABI
|
#ifdef WINDOWS_ABI
|
||||||
buffer_size += 160 / sizeof(FLOAT);
|
buffer_size += 160 / sizeof(FLOAT);
|
||||||
|
@ -479,20 +460,34 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(COMPLEX)
|
#if defined(COMPLEX)
|
||||||
|
if (!transa)
|
||||||
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
|
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
|
||||||
aa, lda, bb, incb, cc, 1,
|
aa, lda, bb, incb, cc, 1,
|
||||||
buffer);
|
buffer);
|
||||||
|
else
|
||||||
|
(gemv[(int)transa]) (k, j, 0, alpha_r, alpha_i,
|
||||||
|
aa, lda, bb, incb, cc, 1,
|
||||||
|
buffer);
|
||||||
#else
|
#else
|
||||||
|
if (!transa)
|
||||||
(gemv[(int)transa]) (j, k, 0, alpha, aa, lda,
|
(gemv[(int)transa]) (j, k, 0, alpha, aa, lda,
|
||||||
bb, incb, cc, 1, buffer);
|
bb, incb, cc, 1, buffer);
|
||||||
|
else
|
||||||
|
(gemv[(int)transa]) (k, j, 0, alpha, aa, lda,
|
||||||
|
bb, incb, cc, 1, buffer);
|
||||||
#endif
|
#endif
|
||||||
#ifdef SMP
|
#ifdef SMP
|
||||||
} else {
|
} else {
|
||||||
|
if (!transa)
|
||||||
(gemv_thread[(int)transa]) (j, k, alpha, aa,
|
(gemv_thread[(int)transa]) (j, k, alpha, aa,
|
||||||
lda, bb, incb, cc,
|
lda, bb, incb, cc,
|
||||||
1, buffer,
|
1, buffer,
|
||||||
nthreads);
|
nthreads);
|
||||||
|
else
|
||||||
|
(gemv_thread[(int)transa]) (k, j, alpha, aa,
|
||||||
|
lda, bb, incb, cc,
|
||||||
|
1, buffer,
|
||||||
|
nthreads);
|
||||||
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -501,21 +496,19 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
for (i = 0; i < n; i++) {
|
for (i = 0; i < m; i++) {
|
||||||
j = i + 1;
|
j = i + 1;
|
||||||
|
|
||||||
l = j;
|
l = j;
|
||||||
#if defined COMPLEX
|
#if defined COMPLEX
|
||||||
bb = b + i * ldb * 2;
|
bb = b + i * ldb * 2;
|
||||||
if (transa) {
|
if (transb) {
|
||||||
l = k;
|
|
||||||
bb = b + i * 2;
|
bb = b + i * 2;
|
||||||
}
|
}
|
||||||
cc = c + i * 2 * ldc;
|
cc = c + i * 2 * ldc;
|
||||||
#else
|
#else
|
||||||
bb = b + i * ldb;
|
bb = b + i * ldb;
|
||||||
if (transa) {
|
if (transb) {
|
||||||
l = k;
|
|
||||||
bb = b + i;
|
bb = b + i;
|
||||||
}
|
}
|
||||||
cc = c + i * ldc;
|
cc = c + i * ldc;
|
||||||
|
@ -537,8 +530,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
||||||
#endif
|
#endif
|
||||||
IDEBUG_START;
|
IDEBUG_START;
|
||||||
|
|
||||||
FUNCTION_PROFILE_START();
|
|
||||||
|
|
||||||
buffer_size = j + k + 128 / sizeof(FLOAT);
|
buffer_size = j + k + 128 / sizeof(FLOAT);
|
||||||
#ifdef WINDOWS_ABI
|
#ifdef WINDOWS_ABI
|
||||||
buffer_size += 160 / sizeof(FLOAT);
|
buffer_size += 160 / sizeof(FLOAT);
|
||||||
|
@ -558,30 +549,39 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(COMPLEX)
|
#if defined(COMPLEX)
|
||||||
|
if (!transa)
|
||||||
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
|
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
|
||||||
a, lda, bb, incb, cc, 1,
|
a, lda, bb, incb, cc, 1,
|
||||||
buffer);
|
buffer);
|
||||||
|
else
|
||||||
|
(gemv[(int)transa]) (k, j, 0, alpha_r, alpha_i,
|
||||||
|
a, lda, bb, incb, cc, 1,
|
||||||
|
buffer);
|
||||||
#else
|
#else
|
||||||
|
if (!transa)
|
||||||
(gemv[(int)transa]) (j, k, 0, alpha, a, lda, bb,
|
(gemv[(int)transa]) (j, k, 0, alpha, a, lda, bb,
|
||||||
incb, cc, 1, buffer);
|
incb, cc, 1, buffer);
|
||||||
|
else
|
||||||
|
(gemv[(int)transa]) (k, j, 0, alpha, a, lda, bb,
|
||||||
|
incb, cc, 1, buffer);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef SMP
|
#ifdef SMP
|
||||||
} else {
|
} else {
|
||||||
|
if (!transa)
|
||||||
(gemv_thread[(int)transa]) (j, k, alpha, a, lda,
|
(gemv_thread[(int)transa]) (j, k, alpha, a, lda,
|
||||||
bb, incb, cc, 1,
|
bb, incb, cc, 1,
|
||||||
buffer, nthreads);
|
buffer, nthreads);
|
||||||
|
else
|
||||||
|
(gemv_thread[(int)transa]) (k, j, alpha, a, lda,
|
||||||
|
bb, incb, cc, 1,
|
||||||
|
buffer, nthreads);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
STACK_FREE(buffer);
|
STACK_FREE(buffer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
FUNCTION_PROFILE_END(COMPSIZE * COMPSIZE,
|
|
||||||
args.m * args.k + args.k * args.n +
|
|
||||||
args.m * args.n, 2 * args.m * args.n * args.k);
|
|
||||||
|
|
||||||
IDEBUG_END;
|
IDEBUG_END;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue