From 1b10ff129a7cf0b072feacfbf1357eff44749d09 Mon Sep 17 00:00:00 2001 From: wernsaar Date: Fri, 25 Jul 2014 10:00:23 +0200 Subject: [PATCH] optimizations for trmm --- driver/level3/trmm_L.c | 16 ++++++++++++---- driver/level3/trmm_R.c | 24 ++++++++++++++++++------ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/driver/level3/trmm_L.c b/driver/level3/trmm_L.c index c0a822b51..8a81d31a0 100644 --- a/driver/level3/trmm_L.c +++ b/driver/level3/trmm_L.c @@ -135,7 +135,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; START_RPCC(); @@ -199,7 +201,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; START_RPCC(); @@ -288,7 +292,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; START_RPCC(); @@ -352,7 +358,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; START_RPCC(); diff --git a/driver/level3/trmm_R.c b/driver/level3/trmm_R.c index 6012386c8..bdd9370cd 100644 --- a/driver/level3/trmm_R.c +++ b/driver/level3/trmm_R.c @@ -119,7 +119,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = 0; jjs < ls - js; jjs += min_jj){ min_jj = ls - js - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #ifndef TRANSA GEMM_ONCOPY(min_l, min_jj, a + (ls + (js + jjs) * lda) * COMPSIZE, lda, sb + min_l * jjs * COMPSIZE); @@ -137,7 +139,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = 0; jjs < min_l; jjs += min_jj){ min_jj = min_l - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #ifndef TRANSA TRMM_OLNCOPY(min_l, min_jj, a, lda, ls, ls + jjs, sb + min_l * (ls - js + jjs) * COMPSIZE); @@ -188,7 +192,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #ifndef TRANSA GEMM_ONCOPY(min_l, min_jj, a + (ls + jjs * lda) * COMPSIZE, lda, sb + min_l * (jjs - js) * COMPSIZE); @@ -239,7 +245,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = 0; jjs < min_l; jjs += min_jj){ min_jj = min_l - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #ifndef TRANSA TRMM_OUNCOPY(min_l, min_jj, a, lda, ls, ls + jjs, sb + min_l * jjs * COMPSIZE); @@ -258,7 +266,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = 0; jjs < js - ls - min_l; jjs += min_jj){ min_jj = js - ls - min_l - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #ifndef TRANSA GEMM_ONCOPY(min_l, min_jj, a + (ls + (ls + min_l + jjs) * lda) * COMPSIZE, lda, @@ -313,7 +323,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #ifndef TRANSA GEMM_ONCOPY(min_l, min_jj, a + (ls + (jjs - min_j) * lda) * COMPSIZE, lda, sb + min_l * (jjs - js) * COMPSIZE);