fix a bug of trmm
This commit is contained in:
parent
0f112077e6
commit
ebe64a3c03
|
@ -122,6 +122,10 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
if (min_l > GEMM_Q) min_l = GEMM_Q;
|
if (min_l > GEMM_Q) min_l = GEMM_Q;
|
||||||
min_i = min_l;
|
min_i = min_l;
|
||||||
if (min_i > GEMM_P) min_i = GEMM_P;
|
if (min_i > GEMM_P) min_i = GEMM_P;
|
||||||
|
if( min_i > GEMM_UNROLL_M){
|
||||||
|
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
|
@ -135,14 +139,10 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
|
|
||||||
for(jjs = js; jjs < js + min_j; jjs += min_jj){
|
for(jjs = js; jjs < js + min_j; jjs += min_jj){
|
||||||
min_jj = min_j + js - jjs;
|
min_jj = min_j + js - jjs;
|
||||||
#if defined(SKYLAKEX) || defined(COOPERLAKE)
|
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
|
||||||
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
|
|
||||||
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
|
|
||||||
#else
|
|
||||||
if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
|
|
||||||
else
|
else
|
||||||
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
|
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
|
||||||
#endif
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
GEMM_ONCOPY(min_l, min_jj, b + (jjs * ldb) * COMPSIZE, ldb, sb + min_l * (jjs - js) * COMPSIZE);
|
GEMM_ONCOPY(min_l, min_jj, b + (jjs * ldb) * COMPSIZE, ldb, sb + min_l * (jjs - js) * COMPSIZE);
|
||||||
|
@ -161,9 +161,13 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
for(is = min_i; is < min_l; is += GEMM_P){
|
for(is = min_i; is < min_l; is += min_i){
|
||||||
min_i = min_l - is;
|
min_i = min_l - is;
|
||||||
if (min_i > GEMM_P) min_i = GEMM_P;
|
if (min_i > GEMM_P) min_i = GEMM_P;
|
||||||
|
if( min_i > GEMM_UNROLL_M){
|
||||||
|
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
|
@ -192,6 +196,11 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
if (min_l > GEMM_Q) min_l = GEMM_Q;
|
if (min_l > GEMM_Q) min_l = GEMM_Q;
|
||||||
min_i = ls;
|
min_i = ls;
|
||||||
if (min_i > GEMM_P) min_i = GEMM_P;
|
if (min_i > GEMM_P) min_i = GEMM_P;
|
||||||
|
if( min_i > GEMM_UNROLL_M){
|
||||||
|
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
|
@ -205,14 +214,10 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
|
|
||||||
for(jjs = js; jjs < js + min_j; jjs += min_jj){
|
for(jjs = js; jjs < js + min_j; jjs += min_jj){
|
||||||
min_jj = min_j + js - jjs;
|
min_jj = min_j + js - jjs;
|
||||||
#if defined(SKYLAKEX) || defined(COOPERLAKE)
|
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
|
||||||
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
|
|
||||||
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
|
|
||||||
#else
|
|
||||||
if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
|
|
||||||
else
|
else
|
||||||
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
|
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
|
||||||
#endif
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
GEMM_ONCOPY(min_l, min_jj, b + (ls + jjs * ldb) * COMPSIZE, ldb, sb + min_l * (jjs - js) * COMPSIZE);
|
GEMM_ONCOPY(min_l, min_jj, b + (ls + jjs * ldb) * COMPSIZE, ldb, sb + min_l * (jjs - js) * COMPSIZE);
|
||||||
|
@ -231,9 +236,12 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
STOP_RPCC(gemmcost);
|
STOP_RPCC(gemmcost);
|
||||||
}
|
}
|
||||||
|
|
||||||
for(is = min_i; is < ls; is += GEMM_P){
|
for(is = min_i; is < ls; is += min_i){
|
||||||
min_i = ls - is;
|
min_i = ls - is;
|
||||||
if (min_i > GEMM_P) min_i = GEMM_P;
|
if (min_i > GEMM_P) min_i = GEMM_P;
|
||||||
|
if( min_i > GEMM_UNROLL_M){
|
||||||
|
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
|
||||||
|
}
|
||||||
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
|
@ -256,9 +264,12 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
STOP_RPCC(gemmcost);
|
STOP_RPCC(gemmcost);
|
||||||
}
|
}
|
||||||
|
|
||||||
for(is = ls; is < ls + min_l; is += GEMM_P){
|
for(is = ls; is < ls + min_l; is += min_i){
|
||||||
min_i = ls + min_l - is;
|
min_i = ls + min_l - is;
|
||||||
if (min_i > GEMM_P) min_i = GEMM_P;
|
if (min_i > GEMM_P) min_i = GEMM_P;
|
||||||
|
if( min_i > GEMM_UNROLL_M){
|
||||||
|
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
|
||||||
|
}
|
||||||
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
|
@ -287,6 +298,10 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
if (min_l > GEMM_Q) min_l = GEMM_Q;
|
if (min_l > GEMM_Q) min_l = GEMM_Q;
|
||||||
min_i = min_l;
|
min_i = min_l;
|
||||||
if (min_i > GEMM_P) min_i = GEMM_P;
|
if (min_i > GEMM_P) min_i = GEMM_P;
|
||||||
|
if (min_i > GEMM_UNROLL_M){
|
||||||
|
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
|
@ -300,14 +315,10 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
|
|
||||||
for(jjs = js; jjs < js + min_j; jjs += min_jj){
|
for(jjs = js; jjs < js + min_j; jjs += min_jj){
|
||||||
min_jj = min_j + js - jjs;
|
min_jj = min_j + js - jjs;
|
||||||
#if defined(SKYLAKEX) || defined(COOPERLAKE)
|
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
|
||||||
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
|
|
||||||
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
|
|
||||||
#else
|
|
||||||
if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
|
|
||||||
else
|
else
|
||||||
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
|
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
|
||||||
#endif
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
GEMM_ONCOPY(min_l, min_jj, b + (m - min_l + jjs * ldb) * COMPSIZE, ldb,
|
GEMM_ONCOPY(min_l, min_jj, b + (m - min_l + jjs * ldb) * COMPSIZE, ldb,
|
||||||
|
@ -327,9 +338,13 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
STOP_RPCC(trmmcost);
|
STOP_RPCC(trmmcost);
|
||||||
}
|
}
|
||||||
|
|
||||||
for(is = m - min_l + min_i; is < m; is += GEMM_P){
|
for(is = m - min_l + min_i; is < m; is += min_i){
|
||||||
min_i = m - is;
|
min_i = m - is;
|
||||||
if (min_i > GEMM_P) min_i = GEMM_P;
|
if (min_i > GEMM_P) min_i = GEMM_P;
|
||||||
|
if (min_i > GEMM_UNROLL_M){
|
||||||
|
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
|
@ -357,6 +372,10 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
if (min_l > GEMM_Q) min_l = GEMM_Q;
|
if (min_l > GEMM_Q) min_l = GEMM_Q;
|
||||||
min_i = min_l;
|
min_i = min_l;
|
||||||
if (min_i > GEMM_P) min_i = GEMM_P;
|
if (min_i > GEMM_P) min_i = GEMM_P;
|
||||||
|
if (min_i > GEMM_UNROLL_M){
|
||||||
|
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
|
@ -370,14 +389,10 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
|
|
||||||
for(jjs = js; jjs < js + min_j; jjs += min_jj){
|
for(jjs = js; jjs < js + min_j; jjs += min_jj){
|
||||||
min_jj = min_j + js - jjs;
|
min_jj = min_j + js - jjs;
|
||||||
#if defined(SKYLAKEX) || defined(COOPERLAKE)
|
if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
|
||||||
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */
|
|
||||||
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
|
|
||||||
#else
|
|
||||||
if (min_jj >= GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3;
|
|
||||||
else
|
else
|
||||||
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
|
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
|
||||||
#endif
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
GEMM_ONCOPY(min_l, min_jj, b + (ls - min_l + jjs * ldb) * COMPSIZE, ldb,
|
GEMM_ONCOPY(min_l, min_jj, b + (ls - min_l + jjs * ldb) * COMPSIZE, ldb,
|
||||||
|
@ -397,9 +412,13 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
STOP_RPCC(trmmcost);
|
STOP_RPCC(trmmcost);
|
||||||
}
|
}
|
||||||
|
|
||||||
for(is = ls - min_l + min_i; is < ls; is += GEMM_P){
|
for(is = ls - min_l + min_i; is < ls; is += min_i){
|
||||||
min_i = ls - is;
|
min_i = ls - is;
|
||||||
if (min_i > GEMM_P) min_i = GEMM_P;
|
if (min_i > GEMM_P) min_i = GEMM_P;
|
||||||
|
if( min_i > GEMM_UNROLL_M){
|
||||||
|
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
|
@ -423,9 +442,12 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
for(is = ls; is < m; is += GEMM_P){
|
for(is = ls; is < m; is += min_i){
|
||||||
min_i = m - is;
|
min_i = m - is;
|
||||||
if (min_i > GEMM_P) min_i = GEMM_P;
|
if (min_i > GEMM_P) min_i = GEMM_P;
|
||||||
|
if( min_i > GEMM_UNROLL_M){
|
||||||
|
min_i = (min_i / GEMM_UNROLL_M) * GEMM_UNROLL_M;
|
||||||
|
}
|
||||||
|
|
||||||
START_RPCC();
|
START_RPCC();
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue