From a5b9a0d358b51dbb918636e2412934a7ffdcbd1c Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Wed, 24 Jul 2024 17:26:46 +0000 Subject: [PATCH] Unroll SVE gemv further and use on all SVE cores Unrolling the new SVE kernels looks to be good enough for all of the SVE cores I tested on. --- kernel/arm64/KERNEL.A64FX | 5 - kernel/arm64/KERNEL.ARMV8SVE | 8 +- kernel/arm64/gemv_n_sve.c | 218 ++++++++++++++++++++++++++++++----- kernel/arm64/gemv_t_sve.c | 147 ++++++++++++++++++----- 4 files changed, 315 insertions(+), 63 deletions(-) diff --git a/kernel/arm64/KERNEL.A64FX b/kernel/arm64/KERNEL.A64FX index 4abc84040..bc5999097 100644 --- a/kernel/arm64/KERNEL.A64FX +++ b/kernel/arm64/KERNEL.A64FX @@ -1,6 +1 @@ include $(KERNELDIR)/KERNEL.ARMV8SVE - -SGEMVNKERNEL = gemv_n_sve.c -DGEMVNKERNEL = gemv_n_sve.c -SGEMVTKERNEL = gemv_t_sve.c -DGEMVTKERNEL = gemv_t_sve.c diff --git a/kernel/arm64/KERNEL.ARMV8SVE b/kernel/arm64/KERNEL.ARMV8SVE index eeb4844bf..ebdb5794c 100644 --- a/kernel/arm64/KERNEL.ARMV8SVE +++ b/kernel/arm64/KERNEL.ARMV8SVE @@ -74,13 +74,13 @@ DSCALKERNEL = scal.S CSCALKERNEL = zscal.S ZSCALKERNEL = zscal.S -SGEMVNKERNEL = gemv_n.S -DGEMVNKERNEL = gemv_n.S +SGEMVNKERNEL = gemv_n_sve.c +DGEMVNKERNEL = gemv_n_sve.c CGEMVNKERNEL = zgemv_n.S ZGEMVNKERNEL = zgemv_n.S -SGEMVTKERNEL = gemv_t.S -DGEMVTKERNEL = gemv_t.S +SGEMVTKERNEL = gemv_t_sve.c +DGEMVTKERNEL = gemv_t_sve.c CGEMVTKERNEL = zgemv_t.S ZGEMVTKERNEL = zgemv_t.S diff --git a/kernel/arm64/gemv_n_sve.c b/kernel/arm64/gemv_n_sve.c index d3aa57ae3..f7f780e60 100644 --- a/kernel/arm64/gemv_n_sve.c +++ b/kernel/arm64/gemv_n_sve.c @@ -47,45 +47,209 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define SV_DUP svdup_f32 #endif -int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *buffer) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *buffer) { - BLASLONG i; - BLASLONG ix,iy; - BLASLONG j; - FLOAT *a_ptr; - FLOAT temp; + const uint64_t v_size = SV_COUNT(); + const uint64_t v_size2 = v_size * 2; + const svbool_t pg_true = SV_TRUE(); +#ifndef DOUBLE + const BLASLONG n8 = N & -8; +#endif + const BLASLONG n4 = N & -4; +#ifdef DOUBLE + const BLASLONG n2 = N & -2; +#endif + const BLASLONG v_m1 = M & -v_size; + const BLASLONG v_m2 = M & -v_size2; - ix = 0; - a_ptr = a; + BLASLONG ix = 0; if (inc_y == 1) { - uint64_t sve_size = SV_COUNT(); - for (j = 0; j < n; j++) { - SV_TYPE temp_vec = SV_DUP(alpha * x[ix]); - i = 0; - svbool_t pg = SV_WHILE(i, m); - while (svptest_any(SV_TRUE(), pg)) { - SV_TYPE a_vec = svld1(pg, a_ptr + i); - SV_TYPE y_vec = svld1(pg, y + i); - y_vec = svmla_x(pg, y_vec, temp_vec, a_vec); - svst1(pg, y + i, y_vec); - i += sve_size; - pg = SV_WHILE(i, m); + BLASLONG j = 0; + if (inc_x == 1) { +#ifndef DOUBLE + for (; j < n8; j += 8) { + SV_TYPE temp_vec1 = svmul_x(pg_true, svld1rq(pg_true, &x[ix]), alpha); + SV_TYPE temp_vec2 = svmul_x(pg_true, svld1rq(pg_true, &x[ix + 4]), alpha); + + BLASLONG i = 0; + for (; i < v_m1; i += v_size) { + SV_TYPE a_vec1 = svld1(pg_true, a + i); + SV_TYPE a_vec2 = svld1(pg_true, a + i + lda); + SV_TYPE a_vec3 = svld1(pg_true, a + i + lda * 2); + SV_TYPE a_vec4 = svld1(pg_true, a + i + lda * 3); + SV_TYPE a_vec5 = svld1(pg_true, a + i + lda * 4); + SV_TYPE a_vec6 = svld1(pg_true, a + i + lda * 5); + SV_TYPE a_vec7 = svld1(pg_true, a + i + lda * 6); + SV_TYPE a_vec8 = 
svld1(pg_true, a + i + lda * 7); + SV_TYPE y_vec = svld1(pg_true, y + i); + y_vec = svmla_lane(y_vec, a_vec1, temp_vec1, 0); + y_vec = svmla_lane(y_vec, a_vec2, temp_vec1, 1); + y_vec = svmla_lane(y_vec, a_vec3, temp_vec1, 2); + y_vec = svmla_lane(y_vec, a_vec4, temp_vec1, 3); + y_vec = svmla_lane(y_vec, a_vec5, temp_vec2, 0); + y_vec = svmla_lane(y_vec, a_vec6, temp_vec2, 1); + y_vec = svmla_lane(y_vec, a_vec7, temp_vec2, 2); + y_vec = svmla_lane(y_vec, a_vec8, temp_vec2, 3); + svst1(pg_true, y + i, y_vec); + } + + for (; i < M; i += v_size) { + svbool_t pg = SV_WHILE(i, M); + SV_TYPE a_vec1 = svld1(pg, a + i); + SV_TYPE a_vec2 = svld1(pg, a + i + lda); + SV_TYPE a_vec3 = svld1(pg, a + i + lda * 2); + SV_TYPE a_vec4 = svld1(pg, a + i + lda * 3); + SV_TYPE a_vec5 = svld1(pg, a + i + lda * 4); + SV_TYPE a_vec6 = svld1(pg, a + i + lda * 5); + SV_TYPE a_vec7 = svld1(pg, a + i + lda * 6); + SV_TYPE a_vec8 = svld1(pg, a + i + lda * 7); + + SV_TYPE y_vec = svld1(pg, y + i); + y_vec = svmla_lane(y_vec, a_vec1, temp_vec1, 0); + y_vec = svmla_lane(y_vec, a_vec2, temp_vec1, 1); + y_vec = svmla_lane(y_vec, a_vec3, temp_vec1, 2); + y_vec = svmla_lane(y_vec, a_vec4, temp_vec1, 3); + y_vec = svmla_lane(y_vec, a_vec5, temp_vec2, 0); + y_vec = svmla_lane(y_vec, a_vec6, temp_vec2, 1); + y_vec = svmla_lane(y_vec, a_vec7, temp_vec2, 2); + y_vec = svmla_lane(y_vec, a_vec8, temp_vec2, 3); + svst1(pg, y + i, y_vec); + } + + a += lda * 8; + ix += 8; } - a_ptr += lda; + for (; j < n4; j += 4) { + SV_TYPE temp_vec1 = svmul_x(pg_true, svld1rq(pg_true, &x[ix]), alpha); + + BLASLONG i = 0; + for (; i < v_m1; i += v_size) { + SV_TYPE a_vec1 = svld1(pg_true, a + i); + SV_TYPE a_vec2 = svld1(pg_true, a + i + lda); + SV_TYPE a_vec3 = svld1(pg_true, a + i + lda * 2); + SV_TYPE a_vec4 = svld1(pg_true, a + i + lda * 3); + SV_TYPE y_vec = svld1(pg_true, y + i); + y_vec = svmla_lane(y_vec, a_vec1, temp_vec1, 0); + y_vec = svmla_lane(y_vec, a_vec2, temp_vec1, 1); + y_vec = svmla_lane(y_vec, a_vec3, temp_vec1, 2); + y_vec = svmla_lane(y_vec, a_vec4, temp_vec1, 3); + svst1(pg_true, y + i, y_vec); + } + + for (; i < M; i += v_size) { + svbool_t pg = SV_WHILE(i, M); + SV_TYPE a_vec1 = svld1(pg, a + i); + SV_TYPE a_vec2 = svld1(pg, a + i + lda); + SV_TYPE a_vec3 = svld1(pg, a + i + lda * 2); + SV_TYPE a_vec4 = svld1(pg, a + i + lda * 3); + + SV_TYPE y_vec = svld1(pg, y + i); + y_vec = svmla_lane(y_vec, a_vec1, temp_vec1, 0); + y_vec = svmla_lane(y_vec, a_vec2, temp_vec1, 1); + y_vec = svmla_lane(y_vec, a_vec3, temp_vec1, 2); + y_vec = svmla_lane(y_vec, a_vec4, temp_vec1, 3); + svst1(pg, y + i, y_vec); + } + + a += lda * 4; + ix += 4; + } +#else + for (; j < n4; j += 4) { + SV_TYPE temp_vec1 = svmul_x(pg_true, svld1rq(pg_true, &x[ix]), alpha); + SV_TYPE temp_vec2 = svmul_x(pg_true, svld1rq(pg_true, &x[ix + 2]), alpha); + + BLASLONG i = 0; + for (; i < v_m1; i += v_size) { + SV_TYPE a_vec1 = svld1(pg_true, a + i); + SV_TYPE a_vec2 = svld1(pg_true, a + i + lda); + SV_TYPE a_vec3 = svld1(pg_true, a + i + lda * 2); + SV_TYPE a_vec4 = svld1(pg_true, a + i + lda * 3); + SV_TYPE y_vec = svld1(pg_true, y + i); + y_vec = svmla_lane(y_vec, a_vec1, temp_vec1, 0); + y_vec = svmla_lane(y_vec, a_vec2, temp_vec1, 1); + y_vec = svmla_lane(y_vec, a_vec3, temp_vec2, 0); + y_vec = svmla_lane(y_vec, a_vec4, temp_vec2, 1); + svst1(pg_true, y + i, y_vec); + } + for (; i < M; i += v_size) { + svbool_t pg = SV_WHILE(i, M); + SV_TYPE a_vec1 = svld1(pg, a + i); + SV_TYPE a_vec2 = svld1(pg, a + i + lda); + SV_TYPE a_vec3 = svld1(pg, a + i + lda * 2); + 
SV_TYPE a_vec4 = svld1(pg, a + i + lda * 3); + SV_TYPE y_vec = svld1(pg, y + i); + y_vec = svmla_lane(y_vec, a_vec1, temp_vec1, 0); + y_vec = svmla_lane(y_vec, a_vec2, temp_vec1, 1); + y_vec = svmla_lane(y_vec, a_vec3, temp_vec2, 0); + y_vec = svmla_lane(y_vec, a_vec4, temp_vec2, 1); + svst1(pg, y + i, y_vec); + } + + a += lda * 4; + ix += 4; + } + for (; j < n2; j += 2) { + SV_TYPE temp_vec1 = svmul_x(pg_true, svld1rq(pg_true, &x[ix]), alpha); + + BLASLONG i = 0; + for (; i < v_m1; i += v_size) { + SV_TYPE a_vec1 = svld1(pg_true, a + i); + SV_TYPE a_vec2 = svld1(pg_true, a + i + lda); + SV_TYPE y_vec = svld1(pg_true, y + i); + y_vec = svmla_lane(y_vec, a_vec1, temp_vec1, 0); + y_vec = svmla_lane(y_vec, a_vec2, temp_vec1, 1); + svst1(pg_true, y + i, y_vec); + } + for (; i < M; i += v_size) { + svbool_t pg = SV_WHILE(i, M); + SV_TYPE a_vec1 = svld1(pg, a + i); + SV_TYPE a_vec2 = svld1(pg, a + i + lda); + SV_TYPE y_vec = svld1(pg, y + i); + y_vec = svmla_lane(y_vec, a_vec1, temp_vec1, 0); + y_vec = svmla_lane(y_vec, a_vec2, temp_vec1, 1); + svst1(pg, y + i, y_vec); + } + + a += lda * 2; + ix += 2; + } +#endif + } + + for (; j < N; j++) { + SV_TYPE temp_vec1 = SV_DUP(alpha * x[ix]); + SV_TYPE temp_vec2 = temp_vec1; + + BLASLONG i = 0; + for (; i < v_m1; i += v_size) { + SV_TYPE a_vec = svld1(pg_true, a + i); + SV_TYPE y_vec = svld1(pg_true, y + i); + y_vec = svmla_x(pg_true, y_vec, temp_vec1, a_vec); + svst1(pg_true, y + i, y_vec); + } + for (; i < M; i += v_size) { + svbool_t pg = SV_WHILE(i, M); + SV_TYPE a_vec = svld1(pg, a + i); + SV_TYPE y_vec = svld1(pg, y + i); + y_vec = svmla_x(pg, y_vec, temp_vec1, a_vec); + svst1(pg, y + i, y_vec); + } + a += lda; ix += inc_x; } return(0); } - for (j = 0; j < n; j++) { - temp = alpha * x[ix]; - iy = 0; - for (i = 0; i < m; i++) { - y[iy] += temp * a_ptr[i]; + for (BLASLONG j = 0; j < N; j++) { + FLOAT temp = alpha * x[ix]; + BLASLONG iy = 0; + for (BLASLONG i = 0; i < M; i++) { + y[iy] += temp * a[i]; iy += inc_y; } - a_ptr += lda; + a += lda; ix += inc_x; } return (0); diff --git a/kernel/arm64/gemv_t_sve.c b/kernel/arm64/gemv_t_sve.c index bff08b257..96d667bf9 100644 --- a/kernel/arm64/gemv_t_sve.c +++ b/kernel/arm64/gemv_t_sve.c @@ -47,48 +47,141 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#define SV_DUP svdup_f32 #endif -int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *buffer) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *buffer) { - BLASLONG i; - BLASLONG ix,iy; - BLASLONG j; - FLOAT *a_ptr; - FLOAT temp; + const uint64_t v_size = SV_COUNT(); + const uint64_t v_size2 = v_size * 2; + const svbool_t pg_true = SV_TRUE(); + const BLASLONG n4 = N & -4; + const BLASLONG n2 = N & -2; + const BLASLONG v_m1 = M & -v_size; + const BLASLONG v_m2 = M & -v_size2; - iy = 0; - a_ptr = a; + BLASLONG iy = 0; if (inc_x == 1) { - uint64_t sve_size = SV_COUNT(); - for (j = 0; j < n; j++) { - SV_TYPE temp_vec = SV_DUP(0.0); - i = 0; - svbool_t pg = SV_WHILE(i, m); - while (svptest_any(SV_TRUE(), pg)) { - SV_TYPE a_vec = svld1(pg, a_ptr + i); - SV_TYPE x_vec = svld1(pg, x + i); - temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec); - i += sve_size; - pg = SV_WHILE(i, m); + BLASLONG j = 0; + + for (; j < n4; j += 4) { + SV_TYPE temp_vec1 = SV_DUP(0.0); + SV_TYPE temp_vec2 = SV_DUP(0.0); + SV_TYPE temp_vec3 = SV_DUP(0.0); + SV_TYPE temp_vec4 = SV_DUP(0.0); + BLASLONG i = 0; + + for (; i < v_m1; i += v_size) { + SV_TYPE a_vec1 = svld1(pg_true, a + i); + SV_TYPE a_vec2 = svld1(pg_true, a + i + lda); + SV_TYPE a_vec3 = svld1(pg_true, a + i + lda * 2); + SV_TYPE a_vec4 = svld1(pg_true, a + i + lda * 3); + SV_TYPE x_vec = svld1(pg_true, x + i); + temp_vec1 = svmla_x(pg_true, temp_vec1, a_vec1, x_vec); + temp_vec2 = svmla_x(pg_true, temp_vec2, a_vec2, x_vec); + temp_vec3 = svmla_x(pg_true, temp_vec3, a_vec3, x_vec); + temp_vec4 = svmla_x(pg_true, temp_vec4, a_vec4, x_vec); } - temp = svaddv(SV_TRUE(), temp_vec); + + for (; i < M; i += v_size) { + svbool_t pg = SV_WHILE(i, M); + SV_TYPE a_vec1 = svld1(pg, a + i); + SV_TYPE a_vec2 = svld1(pg, a + i + lda); + SV_TYPE a_vec3 = svld1(pg, a + i + lda * 2); + SV_TYPE a_vec4 = svld1(pg, a + i + lda * 3); + SV_TYPE x_vec = svld1(pg, x + i); + temp_vec1 = svmla_x(pg, temp_vec1, a_vec1, x_vec); + temp_vec2 = svmla_x(pg, temp_vec2, a_vec2, x_vec); + temp_vec3 = svmla_x(pg, temp_vec3, a_vec3, x_vec); + temp_vec4 = svmla_x(pg, temp_vec4, a_vec4, x_vec); + } + + FLOAT temp1 = svaddv(pg_true, temp_vec1); + FLOAT temp2 = svaddv(pg_true, temp_vec2); + FLOAT temp3 = svaddv(pg_true, temp_vec3); + FLOAT temp4 = svaddv(pg_true, temp_vec4); + + y[iy] += alpha * temp1; + y[iy + inc_y] += alpha * temp2; + y[iy + inc_y * 2] += alpha * temp3; + y[iy + inc_y * 3] += alpha * temp4; + + iy += inc_y * 4; + a += lda * 4; + } + + for (; j < n2; j += 2) { + SV_TYPE temp_vec1 = SV_DUP(0.0); + SV_TYPE temp_vec2 = SV_DUP(0.0); + BLASLONG i = 0; + for (; i < v_m1; i += v_size) { + SV_TYPE a_vec1 = svld1(pg_true, a + i); + SV_TYPE a_vec2 = svld1(pg_true, a + lda + i); + SV_TYPE x_vec = svld1(pg_true, x + i); + temp_vec1 = svmla_x(pg_true, temp_vec1, a_vec1, x_vec); + temp_vec2 = svmla_x(pg_true, temp_vec2, a_vec2, x_vec); + } + for (; i < M; i += v_size) { + svbool_t pg = SV_WHILE(i, M); + SV_TYPE a_vec1 = svld1(pg, a + i); + SV_TYPE a_vec2 = svld1(pg, a + lda + i); + SV_TYPE x_vec = svld1(pg, x + i); + temp_vec1 = svmla_x(pg_true, temp_vec1, a_vec1, x_vec); + temp_vec2 = svmla_x(pg_true, temp_vec2, a_vec2, x_vec); + } + FLOAT temp1 = svaddv(pg_true, temp_vec1); + y[iy] += alpha * temp1; + FLOAT temp2 = svaddv(pg_true, temp_vec2); + y[iy + inc_y] += alpha * temp2; + iy += inc_y + inc_y; + a += lda + 
lda; + } + + for (; j < N; j++) { + SV_TYPE temp_vec = SV_DUP(0.0); + BLASLONG i = 0; + for (; i < v_m2; i += v_size2) { + SV_TYPE a_vec1 = svld1(pg_true, a + i); + SV_TYPE a_vec2 = svld1(pg_true, a + i + v_size); + SV_TYPE x_vec1 = svld1(pg_true, x + i); + SV_TYPE x_vec2 = svld1(pg_true, x + i + v_size); + temp_vec = svadd_x(pg_true, temp_vec, + svadd_x(pg_true, + svmul_x(pg_true, a_vec1, x_vec1), + svmul_x(pg_true, a_vec2, x_vec2) + ) + ); + } + for (; i < v_m1; i += v_size) { + SV_TYPE a_vec = svld1(pg_true, a + i); + SV_TYPE x_vec = svld1(pg_true, x + i); + temp_vec = svmla_x(pg_true, temp_vec, a_vec, x_vec); + } + for (; i < M; i += v_size) { + svbool_t pg = SV_WHILE(i, M); + SV_TYPE a_vec = svld1(pg, a + i); + SV_TYPE x_vec = svld1(pg, x + i); + temp_vec = svmla_x(pg, temp_vec, a_vec, x_vec); + } + FLOAT temp = svaddv(pg_true, temp_vec); y[iy] += alpha * temp; iy += inc_y; - a_ptr += lda; + a += lda; } return(0); } - for (j = 0; j < n; j++) { - temp = 0.0; - ix = 0; - for (i = 0; i < m; i++) { - temp += a_ptr[i] * x[ix]; + BLASLONG j = 0; + for (j = 0; j < N; j++) { + FLOAT temp = 0.0; + BLASLONG ix = 0; + BLASLONG i = 0; + for (i = 0; i < M; i++) { + temp += a[i] * x[ix]; ix += inc_x; } y[iy] += alpha * temp; iy += inc_y; - a_ptr += lda; + a += lda; } + return (0); }
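
For reference, the unrolled SVE paths above implement the standard GEMV updates y := alpha*A*x + y (the "N" kernel, one column of A per element of x) and y := alpha*A^T*x + y (the "T" kernel, one dot product per column), the same semantics as the scalar fallback loops kept at the bottom of each file. The sketch below is a minimal scalar reference under assumed conditions (double precision, unit increments, column-major storage with lda >= m); the names ref_dgemv_n and ref_dgemv_t are illustrative only and not part of the patch, but the loops mirror the fallback code and can be used to cross-check the unrolled kernels on small inputs.

/*
 * Scalar reference for the kernels touched by this patch.
 * Assumes column-major storage: A(i,j) == a[i + j*lda], lda >= m.
 * ref_dgemv_n / ref_dgemv_t are hypothetical helper names for illustration.
 */
#include <stdio.h>

/* y := alpha * A * x + y  -- mirrors the gemv_n_sve.c fallback loop */
static void ref_dgemv_n(long m, long n, double alpha, const double *a, long lda,
                        const double *x, double *y)
{
    for (long j = 0; j < n; j++)
        for (long i = 0; i < m; i++)
            y[i] += alpha * x[j] * a[i + j * lda];
}

/* y := alpha * A^T * x + y  -- mirrors the gemv_t_sve.c fallback loop */
static void ref_dgemv_t(long m, long n, double alpha, const double *a, long lda,
                        const double *x, double *y)
{
    for (long j = 0; j < n; j++) {
        double temp = 0.0;
        for (long i = 0; i < m; i++)
            temp += a[i + j * lda] * x[i];
        y[j] += alpha * temp;
    }
}

int main(void)
{
    /* 3x2 column-major matrix: columns {1,2,3} and {4,5,6} */
    double a[6]  = {1, 2, 3, 4, 5, 6};
    double x[2]  = {1, 1};
    double xt[3] = {1, 1, 1};
    double yn[3] = {0}, yt[2] = {0};

    ref_dgemv_n(3, 2, 1.0, a, 3, x, yn);   /* expect yn = {5, 7, 9} */
    ref_dgemv_t(3, 2, 1.0, a, 3, xt, yt);  /* expect yt = {6, 15}   */
    printf("%g %g %g | %g %g\n", yn[0], yn[1], yn[2], yt[0], yt[1]);
    return 0;
}

The unrolled single-precision N path in the patch consumes eight columns of A per pass (four in double precision), broadcasting groups of x values with svld1rq and folding them in with svmla_lane; the T path accumulates four column dot products at a time before the svaddv reductions. The scalar loops above give the values those unrolled passes must reproduce.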