From 833bd0f8ffd1b7921ca8e98196e40183158f8cb7 Mon Sep 17 00:00:00 2001 From: wjc404 <52632443+wjc404@users.noreply.github.com> Date: Wed, 5 Feb 2020 10:09:41 +0800 Subject: [PATCH] Update trmm_L.c --- driver/level3/trmm_L.c | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/driver/level3/trmm_L.c b/driver/level3/trmm_L.c index 8a81d31a0..9117090b5 100644 --- a/driver/level3/trmm_L.c +++ b/driver/level3/trmm_L.c @@ -135,10 +135,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; +#ifdef SKYLAKEX + /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ + if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; +#else if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; - +#endif START_RPCC(); GEMM_ONCOPY(min_l, min_jj, b + (jjs * ldb) * COMPSIZE, ldb, sb + min_l * (jjs - js) * COMPSIZE); @@ -201,10 +205,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; +#ifdef SKYLAKEX + /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ + if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; +#else if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; - +#endif START_RPCC(); GEMM_ONCOPY(min_l, min_jj, b + (ls + jjs * ldb) * COMPSIZE, ldb, sb + min_l * (jjs - js) * COMPSIZE); @@ -292,10 +300,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; +#ifdef SKYLAKEX + /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ + if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; +#else if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; - +#endif START_RPCC(); GEMM_ONCOPY(min_l, min_jj, b + (m - min_l + jjs * ldb) * COMPSIZE, ldb, @@ -358,10 +370,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; +#ifdef SKYLAKEX + /* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve the best performance */ + if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N; +#else if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; else if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; - +#endif START_RPCC(); GEMM_ONCOPY(min_l, min_jj, b + (ls - min_l + jjs * ldb) * COMPSIZE, ldb,