Add accumulators to AArch64 GEMV Kernels
This helps to reduce values going missing as we accumulate.
This commit is contained in:
parent
b26424c6a2
commit
ba2e989c67
|
@ -1 +1,4 @@
|
||||||
include $(KERNELDIR)/KERNEL.ARMV8SVE
|
include $(KERNELDIR)/KERNEL.ARMV8SVE
|
||||||
|
|
||||||
|
SGEMVTKERNEL = gemv_t_sve.c
|
||||||
|
DGEMVTKERNEL = gemv_t_sve.c
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
Copyright (c) 2015, The OpenBLAS Project
|
Copyright (c) 2015, 2024 The OpenBLAS Project
|
||||||
All rights reserved.
|
All rights reserved.
|
||||||
Redistribution and use in source and binary forms, with or without
|
Redistribution and use in source and binary forms, with or without
|
||||||
modification, are permitted provided that the following conditions are
|
modification, are permitted provided that the following conditions are
|
||||||
|
@ -170,39 +170,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
.macro KERNEL_F32_FINALIZE
|
.macro KERNEL_F32_FINALIZE
|
||||||
#if !defined(DOUBLE)
|
#if !defined(DOUBLE)
|
||||||
fadd v1.4s, v1.4s, v2.4s
|
// F8 only has 2 accumulators
|
||||||
|
// so add into those pairs
|
||||||
fadd v1.4s, v1.4s, v3.4s
|
fadd v1.4s, v1.4s, v3.4s
|
||||||
fadd v1.4s, v1.4s, v4.4s
|
fadd v2.4s, v2.4s, v4.4s
|
||||||
#else
|
|
||||||
fadd v1.2d, v1.2d, v2.2d
|
|
||||||
fadd v1.2d, v1.2d, v3.2d
|
|
||||||
fadd v1.2d, v1.2d, v4.2d
|
|
||||||
#endif
|
#endif
|
||||||
.endm
|
.endm
|
||||||
|
|
||||||
.macro KERNEL_F4
|
.macro KERNEL_F8
|
||||||
#if !defined(DOUBLE)
|
#if !defined(DOUBLE)
|
||||||
ld1 {v2.4s}, [A_PTR], #16
|
ld1 {v13.4s, v14.4s}, [A_PTR], #32
|
||||||
ld1 {v3.4s}, [X_PTR], #16
|
ld1 {v17.4s, v18.4s}, [X_PTR], #32
|
||||||
fmla v1.4s, v2.4s, v3.4s
|
fmla v1.4s, v13.4s, v17.4s
|
||||||
#else
|
fmla v2.4s, v14.4s, v18.4s
|
||||||
ld1 {v2.2d}, [A_PTR], #16
|
#else
|
||||||
ld1 {v3.2d}, [X_PTR], #16
|
ld1 {v13.2d, v14.2d, v15.2d, v16.2d}, [A_PTR], #64
|
||||||
fmla v1.2d, v2.2d, v3.2d
|
ld1 {v17.2d, v18.2d, v19.2d, v20.2d}, [X_PTR], #64
|
||||||
|
fmla v1.2d, v13.2d, v17.2d
|
||||||
ld1 {v4.2d}, [A_PTR], #16
|
fmla v2.2d, v14.2d, v18.2d
|
||||||
ld1 {v5.2d}, [X_PTR], #16
|
fmla v3.2d, v15.2d, v19.2d
|
||||||
fmla v1.2d, v4.2d, v5.2d
|
fmla v4.2d, v16.2d, v20.2d
|
||||||
#endif
|
#endif
|
||||||
.endm
|
.endm
|
||||||
|
|
||||||
.macro KERNEL_F4_FINALIZE
|
.macro KERNEL_F8_FINALIZE
|
||||||
#if !defined(DOUBLE)
|
#if !defined(DOUBLE)
|
||||||
ext v2.16b, v1.16b, v1.16b, #8
|
// Take the top two elements of v1 and
|
||||||
|
// put them into the first two lanes of v3
|
||||||
|
ext v3.16b, v1.16b, v1.16b, #8
|
||||||
|
fadd v1.2s, v1.2s, v3.2s
|
||||||
|
ext v4.16b, v2.16b, v2.16b, #8
|
||||||
|
fadd v2.2s, v2.2s, v4.2s
|
||||||
|
// Final pair
|
||||||
fadd v1.2s, v1.2s, v2.2s
|
fadd v1.2s, v1.2s, v2.2s
|
||||||
faddp TEMP, v1.2s
|
faddp TEMP, v1.2s
|
||||||
#else
|
#else
|
||||||
faddp TEMP, v1.2d
|
faddp TEMP, v1.2d
|
||||||
|
faddp TEMP1, v2.2d
|
||||||
|
faddp TEMP2, v3.2d
|
||||||
|
faddp TEMP3, v4.2d
|
||||||
|
fadd TEMP, TEMP, TEMP1
|
||||||
|
fadd TEMP2, TEMP2, TEMP3
|
||||||
|
fadd TEMP, TEMP, TEMP2
|
||||||
#endif
|
#endif
|
||||||
.endm
|
.endm
|
||||||
|
|
||||||
|
@ -258,7 +267,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
asr I, M, #5
|
asr I, M, #5
|
||||||
cmp I, xzr
|
cmp I, xzr
|
||||||
beq .Lgemv_t_kernel_F4
|
beq .Lgemv_t_kernel_F8
|
||||||
|
|
||||||
.Lgemv_t_kernel_F320:
|
.Lgemv_t_kernel_F320:
|
||||||
|
|
||||||
|
@ -269,24 +278,24 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
KERNEL_F32_FINALIZE
|
KERNEL_F32_FINALIZE
|
||||||
|
|
||||||
.Lgemv_t_kernel_F4:
|
.Lgemv_t_kernel_F8:
|
||||||
ands I, M, #31
|
ands I, M, #31
|
||||||
asr I, I, #2
|
asr I, I, #3
|
||||||
cmp I, xzr
|
cmp I, xzr
|
||||||
beq .Lgemv_t_kernel_F1
|
beq .Lgemv_t_kernel_F1
|
||||||
|
|
||||||
.Lgemv_t_kernel_F40:
|
.Lgemv_t_kernel_F80:
|
||||||
|
|
||||||
KERNEL_F4
|
KERNEL_F8
|
||||||
|
|
||||||
subs I, I, #1
|
subs I, I, #1
|
||||||
bne .Lgemv_t_kernel_F40
|
bne .Lgemv_t_kernel_F80
|
||||||
|
|
||||||
.Lgemv_t_kernel_F1:
|
.Lgemv_t_kernel_F1:
|
||||||
|
|
||||||
KERNEL_F4_FINALIZE
|
KERNEL_F8_FINALIZE
|
||||||
|
|
||||||
ands I, M, #3
|
ands I, M, #7
|
||||||
ble .Lgemv_t_kernel_F_END
|
ble .Lgemv_t_kernel_F_END
|
||||||
|
|
||||||
.Lgemv_t_kernel_F10:
|
.Lgemv_t_kernel_F10:
|
||||||
|
|
|
@ -59,20 +59,46 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
|
||||||
a_ptr = a;
|
a_ptr = a;
|
||||||
|
|
||||||
if (inc_x == 1) {
|
if (inc_x == 1) {
|
||||||
|
svbool_t pg_true = SV_TRUE();
|
||||||
uint64_t sve_size = SV_COUNT();
|
uint64_t sve_size = SV_COUNT();
|
||||||
|
uint64_t sve_size2 = sve_size * 2;
|
||||||
|
BLASLONG m1 = m & -sve_size;
|
||||||
|
BLASLONG m2 = m & -sve_size2;
|
||||||
|
|
||||||
for (j = 0; j < n; j++) {
|
for (j = 0; j < n; j++) {
|
||||||
|
BLASLONG i = 0;
|
||||||
|
|
||||||
|
SV_TYPE temp_vec_v2_0 = SV_DUP(0.0);
|
||||||
|
SV_TYPE temp_vec_v2_1 = SV_DUP(0.0);
|
||||||
|
for (; i < m2; i += sve_size2) {
|
||||||
|
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
|
||||||
|
SV_TYPE x_vec0 = svld1(pg_true, x + i);
|
||||||
|
SV_TYPE a_vec1 = svld1(pg_true, a_ptr + i + sve_size);
|
||||||
|
SV_TYPE x_vec1 = svld1(pg_true, x + i + sve_size);
|
||||||
|
temp_vec_v2_0 = svmla_m(pg_true, temp_vec_v2_0, a_vec0, x_vec0);
|
||||||
|
temp_vec_v2_1 = svmla_m(pg_true, temp_vec_v2_1, a_vec1, x_vec1);
|
||||||
|
}
|
||||||
|
|
||||||
|
SV_TYPE temp_vec_v1 = SV_DUP(0.0);
|
||||||
|
for (; i < m1; i += sve_size) {
|
||||||
|
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
|
||||||
|
SV_TYPE x_vec0 = svld1(pg_true, x + i);
|
||||||
|
temp_vec_v1 = svmla_m(pg_true, temp_vec_v1, a_vec0, x_vec0);
|
||||||
|
}
|
||||||
|
|
||||||
SV_TYPE temp_vec = SV_DUP(0.0);
|
SV_TYPE temp_vec = SV_DUP(0.0);
|
||||||
i = 0;
|
for (; i < m; i += sve_size) {
|
||||||
svbool_t pg = SV_WHILE(i, m);
|
svbool_t pg = SV_WHILE(i, m);
|
||||||
while (svptest_any(SV_TRUE(), pg)) {
|
|
||||||
SV_TYPE a_vec = svld1(pg, a_ptr + i);
|
SV_TYPE a_vec = svld1(pg, a_ptr + i);
|
||||||
SV_TYPE x_vec = svld1(pg, x + i);
|
SV_TYPE x_vec = svld1(pg, x + i);
|
||||||
temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec);
|
temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec);
|
||||||
i += sve_size;
|
|
||||||
pg = SV_WHILE(i, m);
|
|
||||||
}
|
}
|
||||||
temp = svaddv(SV_TRUE(), temp_vec);
|
|
||||||
y[iy] += alpha * temp;
|
y[iy] += alpha * (
|
||||||
|
(svaddv(SV_TRUE(), temp_vec_v2_0) + svaddv(SV_TRUE(), temp_vec)) +
|
||||||
|
(svaddv(SV_TRUE(), temp_vec_v2_1) + svaddv(SV_TRUE(), temp_vec_v1))
|
||||||
|
);
|
||||||
|
|
||||||
iy += inc_y;
|
iy += inc_y;
|
||||||
a_ptr += lda;
|
a_ptr += lda;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue