From c4c591ac5afc10b5619d1c58b10d5095dc82a2ff Mon Sep 17 00:00:00 2001 From: Qiyu8 Date: Tue, 10 Nov 2020 16:16:38 +0800 Subject: [PATCH] fix sum optimize issues --- kernel/arm/sum.c | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/kernel/arm/sum.c b/kernel/arm/sum.c index 63584b95c..a486a1868 100644 --- a/kernel/arm/sum.c +++ b/kernel/arm/sum.c @@ -42,24 +42,27 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x) n *= inc_x; if (inc_x == 1) { -#if V_SIMD +#if V_SIMD && (!defined(DOUBLE) || (defined(DOUBLE) && V_SIMD_F64 && V_SIMD > 128)) #ifdef DOUBLE const int vstep = v_nlanes_f64; - const int unrollx2 = n & (-vstep * 2); + const int unrollx4 = n & (-vstep * 4); const int unrollx = n & -vstep; v_f64 vsum0 = v_zero_f64(); v_f64 vsum1 = v_zero_f64(); - while (i < unrollx2) - { - vsum0 = v_add_f64(vsum0, v_loadu_f64(x)); - vsum1 = v_add_f64(vsum1, v_loadu_f64(x + vstep)); - i += vstep * 2; - } - vsum0 = v_add_f64(vsum0, vsum1); - while (i < unrollx) + v_f64 vsum2 = v_zero_f64(); + v_f64 vsum3 = v_zero_f64(); + for (; i < unrollx4; i += vstep * 4) + { + vsum0 = v_add_f64(vsum0, v_loadu_f64(x + i)); + vsum1 = v_add_f64(vsum1, v_loadu_f64(x + i + vstep)); + vsum2 = v_add_f64(vsum2, v_loadu_f64(x + i + vstep * 2)); + vsum3 = v_add_f64(vsum3, v_loadu_f64(x + i + vstep * 3)); + } + vsum0 = v_add_f64( + v_add_f64(vsum0, vsum1), v_add_f64(vsum2, vsum3)); + for (; i < unrollx; i += vstep) { vsum0 = v_add_f64(vsum0, v_loadu_f64(x + i)); - i += vstep; } sumf = v_sum_f64(vsum0); #else @@ -70,20 +73,18 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x) v_f32 vsum1 = v_zero_f32(); v_f32 vsum2 = v_zero_f32(); v_f32 vsum3 = v_zero_f32(); - while (i < unrollx4) + for (; i < unrollx4; i += vstep * 4) { - vsum0 = v_add_f32(vsum0, v_loadu_f32(x)); - vsum1 = v_add_f32(vsum1, v_loadu_f32(x + vstep)); - vsum2 = v_add_f32(vsum2, v_loadu_f32(x + vstep * 2)); - vsum3 = v_add_f32(vsum3, v_loadu_f32(x + vstep * 3)); - i += vstep * 4; + vsum0 = v_add_f32(vsum0, v_loadu_f32(x + i)); + vsum1 = v_add_f32(vsum1, v_loadu_f32(x + i + vstep)); + vsum2 = v_add_f32(vsum2, v_loadu_f32(x + i + vstep * 2)); + vsum3 = v_add_f32(vsum3, v_loadu_f32(x + i + vstep * 3)); } vsum0 = v_add_f32( v_add_f32(vsum0, vsum1), v_add_f32(vsum2, vsum3)); - while (i < unrollx) + for (; i < unrollx; i += vstep) { vsum0 = v_add_f32(vsum0, v_loadu_f32(x + i)); - i += vstep; } sumf = v_sum_f32(vsum0); #endif