diff --git a/kernel/arm64/KERNEL.NEOVERSEV1 b/kernel/arm64/KERNEL.NEOVERSEV1
index bc5999097..53d157a0a 100644
--- a/kernel/arm64/KERNEL.NEOVERSEV1
+++ b/kernel/arm64/KERNEL.NEOVERSEV1
@@ -1 +1,4 @@
 include $(KERNELDIR)/KERNEL.ARMV8SVE
+
+SGEMVTKERNEL = gemv_t_sve.c
+DGEMVTKERNEL = gemv_t_sve.c
diff --git a/kernel/arm64/gemv_t.S b/kernel/arm64/gemv_t.S
index b04367ab3..a98eef49b 100644
--- a/kernel/arm64/gemv_t.S
+++ b/kernel/arm64/gemv_t.S
@@ -1,5 +1,5 @@
 /*******************************************************************************
-Copyright (c) 2015, The OpenBLAS Project
+Copyright (c) 2015, 2024 The OpenBLAS Project
 All rights reserved.
 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are
@@ -170,39 +170,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 .macro KERNEL_F32_FINALIZE
 #if !defined(DOUBLE)
-	fadd	v1.4s, v1.4s, v2.4s
+	// F8 only has 2 accumulators
+	// so add into those pairs
 	fadd	v1.4s, v1.4s, v3.4s
-	fadd	v1.4s, v1.4s, v4.4s
-#else
-	fadd	v1.2d, v1.2d, v2.2d
-	fadd	v1.2d, v1.2d, v3.2d
-	fadd	v1.2d, v1.2d, v4.2d
+	fadd	v2.4s, v2.4s, v4.4s
 #endif
 .endm
 
-.macro KERNEL_F4
+.macro KERNEL_F8
 #if !defined(DOUBLE)
-	ld1	{v2.4s}, [A_PTR], #16
-	ld1	{v3.4s}, [X_PTR], #16
-	fmla	v1.4s, v2.4s, v3.4s
-#else
-	ld1	{v2.2d}, [A_PTR], #16
-	ld1	{v3.2d}, [X_PTR], #16
-	fmla	v1.2d, v2.2d, v3.2d
-
-	ld1	{v4.2d}, [A_PTR], #16
-	ld1	{v5.2d}, [X_PTR], #16
-	fmla	v1.2d, v4.2d, v5.2d
+	ld1	{v13.4s, v14.4s}, [A_PTR], #32
+	ld1	{v17.4s, v18.4s}, [X_PTR], #32
+	fmla	v1.4s, v13.4s, v17.4s
+	fmla	v2.4s, v14.4s, v18.4s
+#else
+	ld1	{v13.2d, v14.2d, v15.2d, v16.2d}, [A_PTR], #64
+	ld1	{v17.2d, v18.2d, v19.2d, v20.2d}, [X_PTR], #64
+	fmla	v1.2d, v13.2d, v17.2d
+	fmla	v2.2d, v14.2d, v18.2d
+	fmla	v3.2d, v15.2d, v19.2d
+	fmla	v4.2d, v16.2d, v20.2d
 #endif
 .endm
 
-.macro KERNEL_F4_FINALIZE
+.macro KERNEL_F8_FINALIZE
 #if !defined(DOUBLE)
-	ext	v2.16b, v1.16b, v1.16b, #8
+	// Take the top two elements of v1 and
+	// put them into the first two lanes of v3
+	ext	v3.16b, v1.16b, v1.16b, #8
+	fadd	v1.2s, v1.2s, v3.2s
+	ext	v4.16b, v2.16b, v2.16b, #8
+	fadd	v2.2s, v2.2s, v4.2s
+	// Final pair
 	fadd	v1.2s, v1.2s, v2.2s
 	faddp	TEMP, v1.2s
 #else
 	faddp	TEMP, v1.2d
+	faddp	TEMP1, v2.2d
+	faddp	TEMP2, v3.2d
+	faddp	TEMP3, v4.2d
+	fadd	TEMP, TEMP, TEMP1
+	fadd	TEMP2, TEMP2, TEMP3
+	fadd	TEMP, TEMP, TEMP2
 #endif
 .endm
 
@@ -258,7 +267,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 	asr	I, M, #5
 	cmp	I, xzr
-	beq	.Lgemv_t_kernel_F4
+	beq	.Lgemv_t_kernel_F8
 
 .Lgemv_t_kernel_F320:
 
@@ -269,24 +278,24 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 	KERNEL_F32_FINALIZE
 
-.Lgemv_t_kernel_F4:
+.Lgemv_t_kernel_F8:
 
 	ands	I, M, #31
-	asr	I, I, #2
+	asr	I, I, #3
 	cmp	I, xzr
 	beq	.Lgemv_t_kernel_F1
 
-.Lgemv_t_kernel_F40:
+.Lgemv_t_kernel_F80:
 
-	KERNEL_F4
+	KERNEL_F8
 
 	subs	I, I, #1
-	bne	.Lgemv_t_kernel_F40
+	bne	.Lgemv_t_kernel_F80
 
 .Lgemv_t_kernel_F1:
 
-	KERNEL_F4_FINALIZE
+	KERNEL_F8_FINALIZE
 
-	ands	I, M, #3
+	ands	I, M, #7
 	ble	.Lgemv_t_kernel_F_END
 
 .Lgemv_t_kernel_F10:
diff --git a/kernel/arm64/gemv_t_sve.c b/kernel/arm64/gemv_t_sve.c
index ab700a374..183d9c3d1 100644
--- a/kernel/arm64/gemv_t_sve.c
+++ b/kernel/arm64/gemv_t_sve.c
@@ -59,20 +59,46 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
 	a_ptr = a;
 
 	if (inc_x == 1) {
+		svbool_t pg_true = SV_TRUE();
 		uint64_t sve_size = SV_COUNT();
+		uint64_t sve_size2 = sve_size * 2;
+		BLASLONG m1 = m & -sve_size;
+		BLASLONG m2 = m & -sve_size2;
+
 		for (j = 0; j < n; j++) {
+			BLASLONG i = 0;
+
+			SV_TYPE temp_vec_v2_0 = SV_DUP(0.0);
+			SV_TYPE temp_vec_v2_1 = SV_DUP(0.0);
+			for (; i < m2; i += sve_size2) {
+				SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
+				SV_TYPE x_vec0 = svld1(pg_true, x + i);
+				SV_TYPE a_vec1 = svld1(pg_true, a_ptr + i + sve_size);
+				SV_TYPE x_vec1 = svld1(pg_true, x + i + sve_size);
+				temp_vec_v2_0 = svmla_m(pg_true, temp_vec_v2_0, a_vec0, x_vec0);
+				temp_vec_v2_1 = svmla_m(pg_true, temp_vec_v2_1, a_vec1, x_vec1);
+			}
+
+			SV_TYPE temp_vec_v1 = SV_DUP(0.0);
+			for (; i < m1; i += sve_size) {
+				SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
+				SV_TYPE x_vec0 = svld1(pg_true, x + i);
+				temp_vec_v1 = svmla_m(pg_true, temp_vec_v1, a_vec0, x_vec0);
+			}
+
 			SV_TYPE temp_vec = SV_DUP(0.0);
-			i = 0;
-			svbool_t pg = SV_WHILE(i, m);
-			while (svptest_any(SV_TRUE(), pg)) {
+			for (; i < m; i += sve_size) {
+				svbool_t pg = SV_WHILE(i, m);
 				SV_TYPE a_vec = svld1(pg, a_ptr + i);
 				SV_TYPE x_vec = svld1(pg, x + i);
 				temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec);
-				i += sve_size;
-				pg = SV_WHILE(i, m);
 			}
-			temp = svaddv(SV_TRUE(), temp_vec);
-			y[iy] += alpha * temp;
+
+			y[iy] += alpha * (
+				(svaddv(SV_TRUE(), temp_vec_v2_0) + svaddv(SV_TRUE(), temp_vec)) +
+				(svaddv(SV_TRUE(), temp_vec_v2_1) + svaddv(SV_TRUE(), temp_vec_v1))
+			);
+
 			iy += inc_y;
 			a_ptr += lda;
 		}