Add accumulators to AArch64 GEMV Kernels

Splitting the accumulation across several independent registers keeps each partial sum smaller, so small products are less likely to go missing (be rounded away) as we accumulate.
Author: Chris Sidebottom, 2024-07-31 13:07:35 +01:00
Parent: b26424c6a2
Commit: ba2e989c67
3 changed files with 74 additions and 36 deletions
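The principle behind the change, sketched in plain C (an illustration only, not code from this patch; the function names are invented): with a single accumulator every product is folded into one ever-growing sum, so small products can be rounded away entirely, while several independent accumulators keep each partial sum smaller and also break the serial dependency between the multiply-adds.

#include <stddef.h>

/* Illustration only: the same dot product with one accumulator versus
 * four independent accumulators. */
float dot_one_acc(const float *a, const float *x, size_t n)
{
    float s = 0.0f;
    for (size_t i = 0; i < n; i++)
        s += a[i] * x[i];          /* one long dependency chain; tiny terms
                                      can vanish once s grows large */
    return s;
}

float dot_four_acc(const float *a, const float *x, size_t n)
{
    float s0 = 0.0f, s1 = 0.0f, s2 = 0.0f, s3 = 0.0f;
    size_t i = 0;
    for (; i + 4 <= n; i += 4) {   /* four independent partial sums */
        s0 += a[i + 0] * x[i + 0];
        s1 += a[i + 1] * x[i + 1];
        s2 += a[i + 2] * x[i + 2];
        s3 += a[i + 3] * x[i + 3];
    }
    for (; i < n; i++)             /* scalar tail */
        s0 += a[i] * x[i];
    return (s0 + s1) + (s2 + s3);  /* pairwise reduction at the end */
}

The assembly and SVE changes below apply the same pattern, using vector registers as the accumulators.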


@@ -1 +1,4 @@
include $(KERNELDIR)/KERNEL.ARMV8SVE
SGEMVTKERNEL = gemv_t_sve.c
DGEMVTKERNEL = gemv_t_sve.c


@@ -1,5 +1,5 @@
/*******************************************************************************
Copyright (c) 2015, The OpenBLAS Project
Copyright (c) 2015, 2024 The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -170,39 +170,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
.macro KERNEL_F32_FINALIZE
#if !defined(DOUBLE)
fadd v1.4s, v1.4s, v2.4s
// F8 only has 2 accumulators
// so add into those pairs
fadd v1.4s, v1.4s, v3.4s
fadd v1.4s, v1.4s, v4.4s
#else
fadd v1.2d, v1.2d, v2.2d
fadd v1.2d, v1.2d, v3.2d
fadd v1.2d, v1.2d, v4.2d
fadd v2.4s, v2.4s, v4.4s
#endif
.endm
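A rough intrinsics analogue of the new single-precision part of this finalize step (a sketch under assumed variable names, not the shipped code): the four KERNEL_F32 accumulators are folded pairwise into the two accumulators that KERNEL_F8 keeps using.

#include <arm_neon.h>

/* Sketch: fold the F32 loop's four accumulators into the two kept by F8. */
static inline void fold_into_pairs(float32x4_t *acc1, float32x4_t *acc2,
                                   float32x4_t acc3, float32x4_t acc4)
{
    *acc1 = vaddq_f32(*acc1, acc3);   /* like: fadd v1.4s, v1.4s, v3.4s */
    *acc2 = vaddq_f32(*acc2, acc4);   /* like: fadd v2.4s, v2.4s, v4.4s */
}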
.macro KERNEL_F4
.macro KERNEL_F8
#if !defined(DOUBLE)
ld1 {v2.4s}, [A_PTR], #16
ld1 {v3.4s}, [X_PTR], #16
fmla v1.4s, v2.4s, v3.4s
#else
ld1 {v2.2d}, [A_PTR], #16
ld1 {v3.2d}, [X_PTR], #16
fmla v1.2d, v2.2d, v3.2d
ld1 {v4.2d}, [A_PTR], #16
ld1 {v5.2d}, [X_PTR], #16
fmla v1.2d, v4.2d, v5.2d
ld1 {v13.4s, v14.4s}, [A_PTR], #32
ld1 {v17.4s, v18.4s}, [X_PTR], #32
fmla v1.4s, v13.4s, v17.4s
fmla v2.4s, v14.4s, v18.4s
#else
ld1 {v13.2d, v14.2d, v15.2d, v16.2d}, [A_PTR], #64
ld1 {v17.2d, v18.2d, v19.2d, v20.2d}, [X_PTR], #64
fmla v1.2d, v13.2d, v17.2d
fmla v2.2d, v14.2d, v18.2d
fmla v3.2d, v15.2d, v19.2d
fmla v4.2d, v16.2d, v20.2d
#endif
.endm
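An intrinsics sketch of one single-precision KERNEL_F8 iteration (illustrative only; the kernel ships the assembly above): eight floats are loaded from A and X and fed into two independent accumulators, so the two fmla chains do not serialize on each other.

#include <arm_neon.h>

/* Sketch of one single-precision KERNEL_F8 step (not the shipped code). */
static inline void kernel_f8_step(const float **a_ptr, const float **x_ptr,
                                  float32x4_t *acc1, float32x4_t *acc2)
{
    float32x4_t a0 = vld1q_f32(*a_ptr);
    float32x4_t a1 = vld1q_f32(*a_ptr + 4);
    float32x4_t x0 = vld1q_f32(*x_ptr);
    float32x4_t x1 = vld1q_f32(*x_ptr + 4);
    *acc1 = vfmaq_f32(*acc1, a0, x0);  /* fmla v1.4s, v13.4s, v17.4s */
    *acc2 = vfmaq_f32(*acc2, a1, x1);  /* fmla v2.4s, v14.4s, v18.4s */
    *a_ptr += 8;                       /* post-increment, as ld1 ..., #32 does */
    *x_ptr += 8;
}

The double-precision branch does the same with four 2-lane accumulators, since eight doubles span four vector registers.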
.macro KERNEL_F4_FINALIZE
.macro KERNEL_F8_FINALIZE
#if !defined(DOUBLE)
ext v2.16b, v1.16b, v1.16b, #8
// Take the top two elements of v1 and
// put them into the first two lanes of v3
ext v3.16b, v1.16b, v1.16b, #8
fadd v1.2s, v1.2s, v3.2s
ext v4.16b, v2.16b, v2.16b, #8
fadd v2.2s, v2.2s, v4.2s
// Final pair
fadd v1.2s, v1.2s, v2.2s
faddp TEMP, v1.2s
#else
faddp TEMP, v1.2d
faddp TEMP1, v2.2d
faddp TEMP2, v3.2d
faddp TEMP3, v4.2d
fadd TEMP, TEMP, TEMP1
fadd TEMP2, TEMP2, TEMP3
fadd TEMP, TEMP, TEMP2
#endif
.endm
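The matching single-precision reduction, again as an intrinsics sketch rather than the shipped code: each 4-lane accumulator is folded to two lanes, the two results are added as the final pair, and a pairwise add produces the scalar (the faddp into TEMP).

#include <arm_neon.h>

/* Sketch of the single-precision KERNEL_F8_FINALIZE reduction. */
static inline float kernel_f8_reduce(float32x4_t acc1, float32x4_t acc2)
{
    /* Fold the high halves onto the low halves (the ext/fadd pairs). */
    float32x2_t lo1 = vadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
    float32x2_t lo2 = vadd_f32(vget_low_f32(acc2), vget_high_f32(acc2));
    /* Final pair, then pairwise-add the remaining two lanes (the faddp). */
    float32x2_t pair = vadd_f32(lo1, lo2);
    return vpadds_f32(pair);
}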
@@ -258,7 +267,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
asr I, M, #5
cmp I, xzr
beq .Lgemv_t_kernel_F4
beq .Lgemv_t_kernel_F8
.Lgemv_t_kernel_F320:
@@ -269,24 +278,24 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
KERNEL_F32_FINALIZE
.Lgemv_t_kernel_F4:
.Lgemv_t_kernel_F8:
ands I, M, #31
asr I, I, #2
asr I, I, #3
cmp I, xzr
beq .Lgemv_t_kernel_F1
.Lgemv_t_kernel_F40:
.Lgemv_t_kernel_F80:
KERNEL_F4
KERNEL_F8
subs I, I, #1
bne .Lgemv_t_kernel_F40
bne .Lgemv_t_kernel_F80
.Lgemv_t_kernel_F1:
KERNEL_F4_FINALIZE
KERNEL_F8_FINALIZE
ands I, M, #3
ands I, M, #7
ble .Lgemv_t_kernel_F_END
.Lgemv_t_kernel_F10:
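Putting the driver changes together (an assumed reconstruction of the trip-count handling, not the exact shipped control flow): M is now split into 32-wide blocks, then 8-wide blocks, then a scalar tail of at most 7 elements, whereas the middle blocks used to be 4 wide with a tail of at most 3.

/* Assumed reconstruction of how M is split after this change. */
static void split_m(long m, long *blocks32, long *blocks8, long *tail)
{
    *blocks32 = m >> 5;          /* asr  I, M, #5                  */
    *blocks8  = (m & 31) >> 3;   /* ands I, M, #31 ; asr I, I, #3  */
    *tail     = m & 7;           /* ands I, M, #7                  */
}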


@@ -59,20 +59,46 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
a_ptr = a;
if (inc_x == 1) {
svbool_t pg_true = SV_TRUE();
uint64_t sve_size = SV_COUNT();
uint64_t sve_size2 = sve_size * 2;
BLASLONG m1 = m & -sve_size;
BLASLONG m2 = m & -sve_size2;
for (j = 0; j < n; j++) {
BLASLONG i = 0;
SV_TYPE temp_vec_v2_0 = SV_DUP(0.0);
SV_TYPE temp_vec_v2_1 = SV_DUP(0.0);
for (; i < m2; i += sve_size2) {
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
SV_TYPE x_vec0 = svld1(pg_true, x + i);
SV_TYPE a_vec1 = svld1(pg_true, a_ptr + i + sve_size);
SV_TYPE x_vec1 = svld1(pg_true, x + i + sve_size);
temp_vec_v2_0 = svmla_m(pg_true, temp_vec_v2_0, a_vec0, x_vec0);
temp_vec_v2_1 = svmla_m(pg_true, temp_vec_v2_1, a_vec1, x_vec1);
}
SV_TYPE temp_vec_v1 = SV_DUP(0.0);
for (; i < m1; i += sve_size) {
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
SV_TYPE x_vec0 = svld1(pg_true, x + i);
temp_vec_v1 = svmla_m(pg_true, temp_vec_v1, a_vec0, x_vec0);
}
SV_TYPE temp_vec = SV_DUP(0.0);
i = 0;
svbool_t pg = SV_WHILE(i, m);
while (svptest_any(SV_TRUE(), pg)) {
for (; i < m; i += sve_size) {
svbool_t pg = SV_WHILE(i, m);
SV_TYPE a_vec = svld1(pg, a_ptr + i);
SV_TYPE x_vec = svld1(pg, x + i);
temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec);
i += sve_size;
pg = SV_WHILE(i, m);
}
temp = svaddv(SV_TRUE(), temp_vec);
y[iy] += alpha * temp;
y[iy] += alpha * (
(svaddv(SV_TRUE(), temp_vec_v2_0) + svaddv(SV_TRUE(), temp_vec)) +
(svaddv(SV_TRUE(), temp_vec_v2_1) + svaddv(SV_TRUE(), temp_vec_v1))
);
iy += inc_y;
a_ptr += lda;
}
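For experimenting with the same pattern outside OpenBLAS, here is a minimal standalone SVE sketch with the shape of the new loop nest (assuming the ACLE intrinsics from arm_sve.h; the SV_* names above are presumably build-time aliases that select the single- or double-precision variants of these calls):

#include <arm_sve.h>
#include <stdint.h>

/* Standalone sketch (not the OpenBLAS code): dot product with two unrolled
 * accumulators, a single-vector remainder loop, and a predicated tail. */
static float dot_sve_two_acc(const float *a, const float *x, long m)
{
    svbool_t pg_true = svptrue_b32();
    long vl = (long)svcntw();          /* elements per vector */
    long m2 = m & -(2 * vl);           /* multiple of two vectors; the & trick
                                          assumes a power-of-two vector length,
                                          matching m & -sve_size above */
    long m1 = m & -vl;                 /* multiple of one vector */
    long i = 0;

    svfloat32_t acc0 = svdup_f32(0.0f);    /* like temp_vec_v2_0 */
    svfloat32_t acc1 = svdup_f32(0.0f);    /* like temp_vec_v2_1 */
    for (; i < m2; i += 2 * vl) {
        acc0 = svmla_m(pg_true, acc0, svld1(pg_true, a + i), svld1(pg_true, x + i));
        acc1 = svmla_m(pg_true, acc1, svld1(pg_true, a + i + vl), svld1(pg_true, x + i + vl));
    }

    svfloat32_t acc2 = svdup_f32(0.0f);    /* like temp_vec_v1 */
    for (; i < m1; i += vl)
        acc2 = svmla_m(pg_true, acc2, svld1(pg_true, a + i), svld1(pg_true, x + i));

    svfloat32_t acc3 = svdup_f32(0.0f);    /* like temp_vec */
    for (; i < m; i += vl) {               /* predicated tail */
        svbool_t pg = svwhilelt_b32((int64_t)i, (int64_t)m);
        acc3 = svmla_m(pg, acc3, svld1(pg, a + i), svld1(pg, x + i));
    }

    /* Pairwise combination of the partial sums, as in the patch. */
    return (svaddv(pg_true, acc0) + svaddv(pg_true, acc3))
         + (svaddv(pg_true, acc1) + svaddv(pg_true, acc2));
}

Build with SVE enabled (for example -march=armv8-a+sve) to try it.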