Tweak SVE dot kernel

This changes the SVE dot kernel to apply predication only when necessary
and streamlines the assembly a bit. Benchmarks suggest this can improve
performance by ~33%.
This commit is contained in:
Chris Sidebottom 2023-12-15 12:50:48 +00:00
parent d2f1594bca
commit 7a4fef4f60
1 changed files with 74 additions and 26 deletions

View File

@ -1,4 +1,5 @@
/*************************************************************************** /***************************************************************************
Copyright (c) 2023, The OpenBLAS Project
Copyright (c) 2022, Arm Ltd Copyright (c) 2022, Arm Ltd
All rights reserved. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
@ -30,37 +31,84 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <arm_sve.h> #include <arm_sve.h>
// Precision-dependent macros, shown as a side-by-side diff: the left column
// holds the removed intrinsics macros, the right column the added string
// fragments that are pasted into the assembly template.
//   DTYPE - element suffix of vector/predicate registers and prefix of the
//           scalar registers used by SUM_VECTOR and RET.
//   WIDTH - suffix of the cnt and ld1 mnemonics.
//   SHIFT - lsl amount used when indexing elements (3 for DOUBLE, 2 otherwise).
#ifdef DOUBLE #ifdef DOUBLE
#define SVE_TYPE svfloat64_t #define DTYPE "d"
#define SVE_ZERO svdup_f64(0.0) #define WIDTH "d"
#define SVE_WHILELT svwhilelt_b64 #define SHIFT "3"
#define SVE_ALL svptrue_b64()
#define SVE_WIDTH svcntd()
#else #else
#define SVE_TYPE svfloat32_t #define DTYPE "s"
#define SVE_ZERO svdup_f32(0.0) #define WIDTH "w"
#define SVE_WHILELT svwhilelt_b32 #define SHIFT "2"
#define SVE_ALL svptrue_b32()
#define SVE_WIDTH svcntw()
#endif #endif
// Assembly building blocks (right column). The left column on these lines is
// the start of the removed intrinsics implementation of dot_kernel_sve.
// COUNT: write the per-vector element count (cntd / cntw) to x9.
static FLOAT dot_kernel_sve(BLASLONG n, FLOAT *x, FLOAT *y) { #define COUNT \
SVE_TYPE acc_a = SVE_ZERO; " cnt"WIDTH" x9 \n"
// SETUP_TRUE: make p0 an all-lanes-active predicate.
SVE_TYPE acc_b = SVE_ZERO; #define SETUP_TRUE \
" ptrue p0."DTYPE" \n"
// OFFSET_INPUTS: x12 / x13 = X_ / Y_ advanced by x9 elements (one vector).
#define OFFSET_INPUTS \
" add x12, %[X_], x9, lsl #"SHIFT" \n" \
" add x13, %[Y_], x9, lsl #"SHIFT" \n"
// TAIL_WHILE: build predicate p1 with whilelo from index x8 against x0.
// NOTE(review): x0 is named literally; no macro references %[N_], so this
// presumably relies on n living in x0 - confirm.
#define TAIL_WHILE \
" whilelo p1."DTYPE", x8, x0 \n"
// UPDATE(pg, x, y, out): under predicate pg, load one vector from base x and
// one from base y at element index x8 into z2 / z3 (zeroing form), then
// fmla-accumulate z2 * z3 into out (merging form).
#define UPDATE(pg, x,y,out) \
" ld1"WIDTH" { z2."DTYPE" }, "pg"/z, ["x", x8, lsl #"SHIFT"] \n" \
" ld1"WIDTH" { z3."DTYPE" }, "pg"/z, ["y", x8, lsl #"SHIFT"] \n" \
" fmla "out"."DTYPE", "pg"/m, z2."DTYPE", z3."DTYPE" \n"
// SUM_VECTOR(v): faddv-reduce z<v> under p0 into the scalar register
// <DTYPE><v> (d<v> or s<v>).
#define SUM_VECTOR(v) \
" faddv "DTYPE""v", p0, z"v"."DTYPE" \n"
// RET: add scalar registers 1 and 0 and write the sum to the RET_ operand.
#define RET \
" fadd %"DTYPE"[RET_], "DTYPE"1, "DTYPE"0 \n"
// DOT_KERNEL: the complete template. Flow, in order:
//   1. x9 = elements per vector; zero z0, z1, d1 and the index x8; set p0.
//   2. x11 = x0 & -(2 * x9). If zero, skip to .Lskip_2x. Otherwise compute
//      x12 / x13 and run .Lvector_2x: accumulate into z1 from X_/Y_ and into
//      z0 from x12/x13, advance x8 by 2 * x9 (sub of the negated stride),
//      repeat while x8 < x11 (unsigned). Then reduce z1 into scalar 1.
//   3. .Lskip_2x: x10 = x0 & -x9. While x8 < x10, .Lvector_1x accumulates one
//      vector into z0 and advances x8 by x9.
//   4. .Ltail: if x10 != x0, build p1 with TAIL_WHILE and do one predicated
//      UPDATE into z0.
//   5. .Lend: reduce z0 into scalar 0, then RET adds scalars 1 and 0.
// NOTE(review): the neg/and rounding only yields a multiple of the stride
// when x9 is a power of two - confirm this holds for all supported targets.
// NOTE(review): the .L labels are fixed names; confirm this template is never
// emitted more than once per translation unit (e.g. through inlining).
BLASLONG sve_width = SVE_WIDTH; #define DOT_KERNEL \
COUNT \
" mov z1.d, #0 \n" \
" mov z0.d, #0 \n" \
" mov x8, #0 \n" \
" movi d1, #0x0 \n" \
SETUP_TRUE \
" neg x10, x9, lsl #1 \n" \
" ands x11, x10, x0 \n" \
" b.eq .Lskip_2x \n" \
OFFSET_INPUTS \
".Lvector_2x: \n" \
UPDATE("p0", "%[X_]", "%[Y_]", "z1") \
UPDATE("p0", "x12", "x13", "z0") \
" sub x8, x8, x10 \n" \
" cmp x8, x11 \n" \
" b.lo .Lvector_2x \n" \
SUM_VECTOR("1") \
".Lskip_2x: \n" \
" neg x10, x9 \n" \
" and x10, x10, x0 \n" \
" cmp x8, x10 \n" \
" b.hs .Ltail \n" \
".Lvector_1x: \n" \
UPDATE("p0", "%[X_]", "%[Y_]", "z0") \
" add x8, x8, x9 \n" \
" cmp x8, x10 \n" \
" b.lo .Lvector_1x \n" \
".Ltail: \n" \
" cmp x10, x0 \n" \
" b.eq .Lend \n" \
TAIL_WHILE \
UPDATE("p1", "%[X_]", "%[Y_]", "z0") \
".Lend: \n" \
SUM_VECTOR("0") \
RET
// New dot_kernel_sve (right column): runs DOT_KERNEL as one inline asm
// statement and returns the value written to the RET_ operand. The left
// column on these lines is the removed intrinsics loop and its final svaddv sum.
// Operands: RET_ is an early-clobber output in an FP/SIMD register ("=&w");
// N_, X_ and Y_ are general-register inputs bound to n, x and y.
// NOTE(review): the clobber list is empty, yet the template writes x8-x13,
// z0-z3, p0, p1 and the condition flags, and reads memory through X_ / Y_.
// Confirm whether "cc", "memory" and those registers need to be declared.
// NOTE(review): the template reads x0 directly and never references %[N_];
// confirm that n is guaranteed to be in x0 at this point.
for (BLASLONG i = 0; i < n; i += sve_width * 2) { static
svbool_t pg_a = SVE_WHILELT((uint64_t)i, (uint64_t)n); FLOAT
svbool_t pg_b = SVE_WHILELT((uint64_t)(i + sve_width), (uint64_t)n); dot_kernel_sve(BLASLONG n, FLOAT* x, FLOAT* y)
{
FLOAT ret;
SVE_TYPE x_vec_a = svld1(pg_a, &x[i]); asm(DOT_KERNEL
SVE_TYPE y_vec_a = svld1(pg_a, &y[i]); :
SVE_TYPE x_vec_b = svld1(pg_b, &x[i + sve_width]); [RET_] "=&w" (ret)
SVE_TYPE y_vec_b = svld1(pg_b, &y[i + sve_width]); :
[N_] "r" (n),
[X_] "r" (x),
[Y_] "r" (y)
:);
acc_a = svmla_m(pg_a, acc_a, x_vec_a, y_vec_a); return ret;
acc_b = svmla_m(pg_b, acc_b, x_vec_b, y_vec_b);
}
return svaddv(SVE_ALL, acc_a) + svaddv(SVE_ALL, acc_b);
} }