fix sve dtrsm kernels
This commit is contained in:
parent
8071e179f1
commit
aaa2b1a861
|
@ -182,8 +182,8 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,
|
||||||
|
|
||||||
i = m % sve_size;
|
i = m % sve_size;
|
||||||
if (i) {
|
if (i) {
|
||||||
aa = a + ((m & ~(i - 1)) - i) * k * COMPSIZE;
|
aa = a + (m - i) * k * COMPSIZE;
|
||||||
cc = c + ((m & ~(i - 1)) - i) * COMPSIZE;
|
cc = c + (m - i) * COMPSIZE;
|
||||||
|
|
||||||
if (k - kk > 0) {
|
if (k - kk > 0) {
|
||||||
GEMM_KERNEL(i, GEMM_UNROLL_N, k - kk, dm1,
|
GEMM_KERNEL(i, GEMM_UNROLL_N, k - kk, dm1,
|
||||||
|
@ -205,10 +205,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mod = i;
|
||||||
i = sve_size;
|
i = sve_size;
|
||||||
if (i <= m) {
|
if (i <= m) {
|
||||||
aa = a + ((m & ~(sve_size - 1)) - sve_size) * k * COMPSIZE;
|
aa = a + (m - mod - sve_size) * k * COMPSIZE;
|
||||||
cc = c + ((m & ~(sve_size - 1)) - sve_size) * COMPSIZE;
|
cc = c + (m - mod - sve_size) * COMPSIZE;
|
||||||
|
|
||||||
do {
|
do {
|
||||||
if (k - kk > 0) {
|
if (k - kk > 0) {
|
||||||
|
@ -217,7 +218,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,
|
||||||
ZERO,
|
ZERO,
|
||||||
#endif
|
#endif
|
||||||
aa + sve_size * kk * COMPSIZE,
|
aa + sve_size * kk * COMPSIZE,
|
||||||
b + sve_size * kk * COMPSIZE,
|
b + GEMM_UNROLL_N * kk * COMPSIZE,
|
||||||
cc,
|
cc,
|
||||||
ldc);
|
ldc);
|
||||||
}
|
}
|
||||||
|
@ -251,8 +252,8 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,
|
||||||
|
|
||||||
i = m % sve_size;
|
i = m % sve_size;
|
||||||
if (i) {
|
if (i) {
|
||||||
aa = a + ((m & ~(i - 1)) - i) * k * COMPSIZE;
|
aa = a + (m - i) * k * COMPSIZE;
|
||||||
cc = c + ((m & ~(i - 1)) - i) * COMPSIZE;
|
cc = c + (m - i) * COMPSIZE;
|
||||||
|
|
||||||
if (k - kk > 0) {
|
if (k - kk > 0) {
|
||||||
GEMM_KERNEL(i, j, k - kk, dm1,
|
GEMM_KERNEL(i, j, k - kk, dm1,
|
||||||
|
@ -273,10 +274,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mod = i;
|
||||||
i = sve_size;
|
i = sve_size;
|
||||||
if (i <= m) {
|
if (i <= m) {
|
||||||
aa = a + ((m & ~(sve_size - 1)) - sve_size) * k * COMPSIZE;
|
aa = a + (m - mod - sve_size) * k * COMPSIZE;
|
||||||
cc = c + ((m & ~(sve_size - 1)) - sve_size) * COMPSIZE;
|
cc = c + (m - mod - sve_size) * COMPSIZE;
|
||||||
|
|
||||||
do {
|
do {
|
||||||
if (k - kk > 0) {
|
if (k - kk > 0) {
|
||||||
|
|
|
@ -257,7 +257,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,
|
||||||
i += sve_size;
|
i += sve_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
i = sve_size % m;
|
i = m % sve_size;
|
||||||
if (i) {
|
if (i) {
|
||||||
if (kk > 0) {
|
if (kk > 0) {
|
||||||
GEMM_KERNEL(i, j, kk, dm1,
|
GEMM_KERNEL(i, j, kk, dm1,
|
||||||
|
|
|
@ -258,23 +258,23 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,
|
||||||
if (i <= m) {
|
if (i <= m) {
|
||||||
do {
|
do {
|
||||||
if (k - kk > 0) {
|
if (k - kk > 0) {
|
||||||
GEMM_KERNEL(GEMM_UNROLL_M, GEMM_UNROLL_N, k - kk, dm1,
|
GEMM_KERNEL(sve_size, GEMM_UNROLL_N, k - kk, dm1,
|
||||||
#ifdef COMPLEX
|
#ifdef COMPLEX
|
||||||
ZERO,
|
ZERO,
|
||||||
#endif
|
#endif
|
||||||
aa + GEMM_UNROLL_M * kk * COMPSIZE,
|
aa + sve_size * kk * COMPSIZE,
|
||||||
b + GEMM_UNROLL_N * kk * COMPSIZE,
|
b + GEMM_UNROLL_N * kk * COMPSIZE,
|
||||||
cc,
|
cc,
|
||||||
ldc);
|
ldc);
|
||||||
}
|
}
|
||||||
|
|
||||||
solve(GEMM_UNROLL_M, GEMM_UNROLL_N,
|
solve(sve_size, GEMM_UNROLL_N,
|
||||||
aa + (kk - GEMM_UNROLL_N) * GEMM_UNROLL_M * COMPSIZE,
|
aa + (kk - GEMM_UNROLL_N) * sve_size * COMPSIZE,
|
||||||
b + (kk - GEMM_UNROLL_N) * GEMM_UNROLL_N * COMPSIZE,
|
b + (kk - GEMM_UNROLL_N) * GEMM_UNROLL_N * COMPSIZE,
|
||||||
cc, ldc);
|
cc, ldc);
|
||||||
|
|
||||||
aa += GEMM_UNROLL_M * k * COMPSIZE;
|
aa += sve_size * k * COMPSIZE;
|
||||||
cc += GEMM_UNROLL_M * COMPSIZE;
|
cc += sve_size * COMPSIZE;
|
||||||
i += sve_size;
|
i += sve_size;
|
||||||
} while (i <= m);
|
} while (i <= m);
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,17 +48,18 @@
|
||||||
|
|
||||||
int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){
|
int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){
|
||||||
|
|
||||||
BLASLONG i, ii, j, jj;
|
BLASLONG i, ii, jj;
|
||||||
|
|
||||||
FLOAT *ao;
|
FLOAT *ao;
|
||||||
|
|
||||||
jj = offset;
|
jj = offset;
|
||||||
int js = 0;
|
|
||||||
#ifdef DOUBLE
|
#ifdef DOUBLE
|
||||||
|
int64_t js = 0;
|
||||||
svint64_t index = svindex_s64(0LL, lda);
|
svint64_t index = svindex_s64(0LL, lda);
|
||||||
svbool_t pn = svwhilelt_b64(js, n);
|
svbool_t pn = svwhilelt_b64(js, n);
|
||||||
int n_active = svcntp_b64(svptrue_b64(), pn);
|
int n_active = svcntp_b64(svptrue_b64(), pn);
|
||||||
#else
|
#else
|
||||||
|
int32_t js = 0;
|
||||||
svint32_t index = svindex_s32(0, lda);
|
svint32_t index = svindex_s32(0, lda);
|
||||||
svbool_t pn = svwhilelt_b32(js, n);
|
svbool_t pn = svwhilelt_b32(js, n);
|
||||||
int n_active = svcntp_b32(svptrue_b32(), pn);
|
int n_active = svcntp_b32(svptrue_b32(), pn);
|
||||||
|
@ -74,25 +75,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT
|
||||||
if (ii == jj) {
|
if (ii == jj) {
|
||||||
for (int j = 0; j < n_active; j++) {
|
for (int j = 0; j < n_active; j++) {
|
||||||
for (int k = 0; k < j; k++) {
|
for (int k = 0; k < j; k++) {
|
||||||
*(b + j * n_active + k) = *(a + k * lda + j);
|
*(b + j * n_active + k) = *(ao + k * lda + j);
|
||||||
}
|
}
|
||||||
*(b + j * n_active + j) = INV(*(a + j * lda + j));
|
*(b + j * n_active + j) = INV(*(ao + j * lda + j));
|
||||||
}
|
}
|
||||||
}
|
ao += n_active;
|
||||||
|
|
||||||
if (ii > jj) {
|
|
||||||
for (int j = 0; j < n_active; j++) {
|
|
||||||
svfloat64_t aj_vec = svld1_gather_index(pn, ao, index);
|
|
||||||
svst1(pn, b, aj_vec);
|
|
||||||
ao++;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
b += n_active * n_active;
|
b += n_active * n_active;
|
||||||
|
|
||||||
i += n_active;
|
i += n_active;
|
||||||
ii += n_active;
|
ii += n_active;
|
||||||
|
} else {
|
||||||
|
if (ii > jj) {
|
||||||
|
svfloat64_t aj_vec = svld1_gather_index(pn, ao, index);
|
||||||
|
svst1(pn, b, aj_vec);
|
||||||
|
}
|
||||||
|
ao++;
|
||||||
|
b += n_active;
|
||||||
|
i++;
|
||||||
|
ii++;
|
||||||
|
}
|
||||||
} while (i < m);
|
} while (i < m);
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -48,18 +48,17 @@
|
||||||
|
|
||||||
int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){
|
int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){
|
||||||
|
|
||||||
BLASLONG i, ii, j, jj;
|
BLASLONG i, ii, jj;
|
||||||
|
|
||||||
FLOAT *ao;
|
FLOAT *ao;
|
||||||
|
|
||||||
jj = offset;
|
jj = offset;
|
||||||
int js = 0;
|
|
||||||
#ifdef DOUBLE
|
#ifdef DOUBLE
|
||||||
svint64_t index = svindex_s64(0LL, lda);
|
int64_t js = 0;
|
||||||
svbool_t pn = svwhilelt_b64(js, n);
|
svbool_t pn = svwhilelt_b64(js, n);
|
||||||
int n_active = svcntp_b64(svptrue_b64(), pn);
|
int n_active = svcntp_b64(svptrue_b64(), pn);
|
||||||
#else
|
#else
|
||||||
svint32_t index = svindex_s32(0, lda);
|
int32_t js = 0;
|
||||||
svbool_t pn = svwhilelt_b32(js, n);
|
svbool_t pn = svwhilelt_b32(js, n);
|
||||||
int n_active = svcntp_b32(svptrue_b32(), pn);
|
int n_active = svcntp_b32(svptrue_b32(), pn);
|
||||||
#endif
|
#endif
|
||||||
|
@ -73,26 +72,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT
|
||||||
|
|
||||||
if (ii == jj) {
|
if (ii == jj) {
|
||||||
for (int j = 0; j < n_active; j++) {
|
for (int j = 0; j < n_active; j++) {
|
||||||
*(b + j * n_active + j) = INV(*(a + j * lda + j));
|
*(b + j * n_active + j) = INV(*(ao + j * lda + j));
|
||||||
for (int k = j+1; k < n_active; k++) {
|
for (int k = j+1; k < n_active; k++) {
|
||||||
*(b + j * n_active + k) = *(a + j * lda + k);
|
*(b + j * n_active + k) = *(ao + j * lda + k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (ii < jj) {
|
|
||||||
for (int j = 0; j < n_active; j++) {
|
|
||||||
svfloat64_t aj_vec = svld1(pn, ao);
|
|
||||||
svst1(pn, b, aj_vec);
|
|
||||||
ao += lda;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
b += n_active * n_active;
|
b += n_active * n_active;
|
||||||
|
ao += lda * n_active;
|
||||||
i += n_active;
|
i += n_active;
|
||||||
ii += n_active;
|
ii += n_active;
|
||||||
|
} else {
|
||||||
|
if (ii < jj) {
|
||||||
|
svfloat64_t aj_vec = svld1(pn, ao);
|
||||||
|
svst1(pn, b, aj_vec);
|
||||||
|
}
|
||||||
|
ao += lda;
|
||||||
|
b += n_active;
|
||||||
|
i ++;
|
||||||
|
ii ++;
|
||||||
|
}
|
||||||
} while (i < m);
|
} while (i < m);
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -48,17 +48,18 @@
|
||||||
|
|
||||||
int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){
|
int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){
|
||||||
|
|
||||||
BLASLONG i, ii, j, jj;
|
BLASLONG i, ii, jj;
|
||||||
|
|
||||||
FLOAT *ao;
|
FLOAT *ao;
|
||||||
|
|
||||||
jj = offset;
|
jj = offset;
|
||||||
int js = 0;
|
|
||||||
#ifdef DOUBLE
|
#ifdef DOUBLE
|
||||||
|
int64_t js = 0;
|
||||||
svint64_t index = svindex_s64(0LL, lda);
|
svint64_t index = svindex_s64(0LL, lda);
|
||||||
svbool_t pn = svwhilelt_b64(js, n);
|
svbool_t pn = svwhilelt_b64(js, n);
|
||||||
int n_active = svcntp_b64(svptrue_b64(), pn);
|
int n_active = svcntp_b64(svptrue_b64(), pn);
|
||||||
#else
|
#else
|
||||||
|
int32_t js = 0;
|
||||||
svint32_t index = svindex_s32(0, lda);
|
svint32_t index = svindex_s32(0, lda);
|
||||||
svbool_t pn = svwhilelt_b32(js, n);
|
svbool_t pn = svwhilelt_b32(js, n);
|
||||||
int n_active = svcntp_b32(svptrue_b32(), pn);
|
int n_active = svcntp_b32(svptrue_b32(), pn);
|
||||||
|
@ -73,25 +74,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT
|
||||||
|
|
||||||
if (ii == jj) {
|
if (ii == jj) {
|
||||||
for (int j = 0; j < n_active; j++) {
|
for (int j = 0; j < n_active; j++) {
|
||||||
*(b + j * n_active + j) = INV(*(a + j * lda + j));
|
*(b + j * n_active + j) = INV(*(ao + j * lda + j));
|
||||||
for (int k = j+1; k < n_active; k++) {
|
for (int k = j+1; k < n_active; k++) {
|
||||||
*(b + j * n_active + k) = *(a + k * lda + j);
|
*(b + j * n_active + k) = *(ao + k * lda + j);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
ao += n_active;
|
||||||
|
|
||||||
if (ii < jj) {
|
|
||||||
for (int j = 0; j < n_active; j++) {
|
|
||||||
svfloat64_t aj_vec = svld1_gather_index(pn, ao, index);
|
|
||||||
svst1(pn, b, aj_vec);
|
|
||||||
ao++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
b += n_active * n_active;
|
b += n_active * n_active;
|
||||||
|
|
||||||
i += n_active;
|
i += n_active;
|
||||||
ii += n_active;
|
ii += n_active;
|
||||||
|
} else {
|
||||||
|
if (ii < jj) {
|
||||||
|
svfloat64_t aj_vec = svld1_gather_index(pn, ao, index);
|
||||||
|
svst1(pn, b, aj_vec);
|
||||||
|
}
|
||||||
|
ao++;
|
||||||
|
b += n_active;
|
||||||
|
i++;
|
||||||
|
ii++;
|
||||||
|
}
|
||||||
} while (i < m);
|
} while (i < m);
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -48,18 +48,17 @@
|
||||||
|
|
||||||
int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){
|
int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){
|
||||||
|
|
||||||
BLASLONG i, ii, j, jj;
|
BLASLONG i, ii, jj;
|
||||||
|
|
||||||
FLOAT *ao;
|
FLOAT *ao;
|
||||||
|
|
||||||
jj = offset;
|
jj = offset;
|
||||||
int js = 0;
|
|
||||||
#ifdef DOUBLE
|
#ifdef DOUBLE
|
||||||
svint64_t index = svindex_s64(0LL, lda);
|
int64_t js = 0;
|
||||||
svbool_t pn = svwhilelt_b64(js, n);
|
svbool_t pn = svwhilelt_b64(js, n);
|
||||||
int n_active = svcntp_b64(svptrue_b64(), pn);
|
int n_active = svcntp_b64(svptrue_b64(), pn);
|
||||||
#else
|
#else
|
||||||
svint32_t index = svindex_s32(0, lda);
|
int32_t js = 0;
|
||||||
svbool_t pn = svwhilelt_b32(js, n);
|
svbool_t pn = svwhilelt_b32(js, n);
|
||||||
int n_active = svcntp_b32(svptrue_b32(), pn);
|
int n_active = svcntp_b32(svptrue_b32(), pn);
|
||||||
#endif
|
#endif
|
||||||
|
@ -74,25 +73,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT
|
||||||
if (ii == jj) {
|
if (ii == jj) {
|
||||||
for (int j = 0; j < n_active; j++) {
|
for (int j = 0; j < n_active; j++) {
|
||||||
for (int k = 0; k < j; k++) {
|
for (int k = 0; k < j; k++) {
|
||||||
*(b + j * n_active + k) = *(a + j * lda + k);
|
*(b + j * n_active + k) = *(ao + j * lda + k);
|
||||||
}
|
}
|
||||||
*(b + j * n_active + j) = INV(*(a + j * lda + j));
|
*(b + j * n_active + j) = INV(*(ao + j * lda + j));
|
||||||
}
|
}
|
||||||
}
|
ao += lda * n_active;
|
||||||
|
|
||||||
if (ii > jj) {
|
|
||||||
for (int j = 0; j < n_active; j++) {
|
|
||||||
svfloat64_t aj_vec = svld1(pn, ao);
|
|
||||||
svst1(pn, b, aj_vec);
|
|
||||||
ao += lda;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
b += n_active * n_active;
|
b += n_active * n_active;
|
||||||
|
|
||||||
i += n_active;
|
i += n_active;
|
||||||
ii += n_active;
|
ii += n_active;
|
||||||
|
} else {
|
||||||
|
if (ii > jj) {
|
||||||
|
svfloat64_t aj_vec = svld1(pn, ao);
|
||||||
|
svst1(pn, b, aj_vec);
|
||||||
|
}
|
||||||
|
ao += lda;
|
||||||
|
b += n_active;
|
||||||
|
i ++;
|
||||||
|
ii ++;
|
||||||
|
}
|
||||||
} while (i < m);
|
} while (i < m);
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue