From 67bf4b6998d9ca2ec51311655004dc9bb32ea21c Mon Sep 17 00:00:00 2001
From: Mark Ryan
Date: Fri, 12 Jul 2024 11:16:48 +0000
Subject: [PATCH] Fix axpby_rvv kernels for cases where inc_y = 0

The following openblas_utest tests fail when the RISCV64_ZVL128B target
is enabled:

TEST 89/103 axpby:zaxpby_inc_0 [FAIL]
TEST 92/103 axpby:caxpby_inc_0 [FAIL]
TEST 95/103 axpby:daxpby_inc_0 [FAIL]
TEST 98/103 axpby:saxpby_inc_0 [FAIL]

The vectorized kernels produce incorrect results when inc_y == 0. This
patch updates the kernels to fall back to the scalar algorithms when
inc_y == 0, fixing the failing tests.

Signed-off-by: Mark Ryan
---
 kernel/riscv64/axpby_rvv.c  | 12 +++++
 kernel/riscv64/zaxpby_rvv.c | 93 +++++++++++++++++++++++--------------
 2 files changed, 70 insertions(+), 35 deletions(-)

diff --git a/kernel/riscv64/axpby_rvv.c b/kernel/riscv64/axpby_rvv.c
index d7fb86eab..27abc0ff4 100644
--- a/kernel/riscv64/axpby_rvv.c
+++ b/kernel/riscv64/axpby_rvv.c
@@ -114,6 +114,11 @@ int CNAME(BLASLONG n, FLOAT alpha, FLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *
             vy = VFMULVF_FLOAT(vy, beta, vl);
             VSEV_FLOAT (y, vy, vl);
         }
+    } else if (inc_y == 0) {
+        FLOAT vf = y[0];
+        for (; n > 0; n--)
+            vf *= beta;
+        y[0] = vf;
     } else {
         BLASLONG stride_y = inc_y * sizeof(FLOAT);
         for (size_t vl; n > 0; n -= vl, y += vl*inc_y) {
@@ -134,6 +139,13 @@ int CNAME(BLASLONG n, FLOAT alpha, FLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *
             vy = VFMACCVF_FLOAT(vy, alpha, vx, vl);
             VSEV_FLOAT (y, vy, vl);
         }
+    } else if (inc_y == 0) {
+        FLOAT vf = y[0];
+        for (; n > 0; n--) {
+            vf = (vf * beta) + (x[0] * alpha);
+            x += inc_x;
+        }
+        y[0] = vf;
     } else if (1 == inc_x) {
         BLASLONG stride_y = inc_y * sizeof(FLOAT);
         for (size_t vl; n > 0; n -= vl, x += vl, y += vl*inc_y) {
diff --git a/kernel/riscv64/zaxpby_rvv.c b/kernel/riscv64/zaxpby_rvv.c
index 66e38c1e4..9bf5bdd5b 100644
--- a/kernel/riscv64/zaxpby_rvv.c
+++ b/kernel/riscv64/zaxpby_rvv.c
@@ -79,8 +79,10 @@ int CNAME(BLASLONG n, FLOAT alpha_r, FLOAT alpha_i, FLOAT *x, BLASLONG inc_x, FL
     BLASLONG stride_x = inc_x2 * sizeof(FLOAT);
     BLASLONG stride_y = inc_y2 * sizeof(FLOAT);
+    BLASLONG ix;
     FLOAT_V_T vx0, vx1, vy0, vy1;
     FLOAT_VX2_T vxx2, vyx2;
+    FLOAT temp;
 
     if ( beta_r == 0.0 && beta_i == 0.0) {
@@ -125,53 +127,74 @@ int CNAME(BLASLONG n, FLOAT alpha_r, FLOAT alpha_i, FLOAT *x, BLASLONG inc_x, FL
 
     if ( alpha_r == 0.0 && alpha_i == 0.0 ) {
-        for (size_t vl; n > 0; n -= vl, y += vl*inc_y2)
-        {
-            vl = VSETVL(n);
+        if ( inc_y == 0 ) {
+            for (; n > 0; n--)
+            {
+                temp = (beta_r * y[0] - beta_i * y[1]);
+                y[1] = (beta_r * y[1] + beta_i * y[0]);
+                y[0] = temp;
+            }
+        } else {
+            for (size_t vl; n > 0; n -= vl, y += vl*inc_y2)
+            {
+                vl = VSETVL(n);
 
-            vyx2 = VLSSEG_FLOAT(y, stride_y, vl);
-            vy0 = VGET_VX2(vyx2, 0);
-            vy1 = VGET_VX2(vyx2, 1);
-
-            v0 = VFMULVF_FLOAT(vy1, beta_i, vl);
-            v0 = VFMSACVF_FLOAT(v0, beta_r, vy0, vl);
+                vyx2 = VLSSEG_FLOAT(y, stride_y, vl);
+                vy0 = VGET_VX2(vyx2, 0);
+                vy1 = VGET_VX2(vyx2, 1);
 
-            v1 = VFMULVF_FLOAT(vy1, beta_r, vl);
-            v1 = VFMACCVF_FLOAT(v1, beta_i, vy0, vl);
+                v0 = VFMULVF_FLOAT(vy1, beta_i, vl);
+                v0 = VFMSACVF_FLOAT(v0, beta_r, vy0, vl);
 
-            v_x2 = VSET_VX2(v_x2, 0, v0);
-            v_x2 = VSET_VX2(v_x2, 1, v1);
-            VSSSEG_FLOAT(y, stride_y, v_x2, vl);
+                v1 = VFMULVF_FLOAT(vy1, beta_r, vl);
+                v1 = VFMACCVF_FLOAT(v1, beta_i, vy0, vl);
+
+                v_x2 = VSET_VX2(v_x2, 0, v0);
+                v_x2 = VSET_VX2(v_x2, 1, v1);
+                VSSSEG_FLOAT(y, stride_y, v_x2, vl);
+            }
         }
     }
     else
     {
-        for (size_t vl; n > 0; n -= vl, x += vl*inc_x2, y += vl*inc_y2)
-        {
-            vl = VSETVL(n);
+        if ( inc_y == 0 ) {
+            ix = 0;
+            for (; n > 0; n--) {
+                temp = (alpha_r * x[ix] - alpha_i * x[ix+1]) +
+                       (beta_r * y[0] - beta_i * y[1]);
+                y[1] = (alpha_r * x[ix+1] + alpha_i * x[ix]) +
+                       (beta_r * y[1] + beta_i * y[0]);
+                y[0] = temp;
+                ix += inc_x2;
+            }
+        } else {
+            for (size_t vl; n > 0; n -= vl, x += vl*inc_x2, y += vl*inc_y2)
+            {
+                vl = VSETVL(n);
 
-            vxx2 = VLSSEG_FLOAT(x, stride_x, vl);
-            vyx2 = VLSSEG_FLOAT(y, stride_y, vl);
+                vxx2 = VLSSEG_FLOAT(x, stride_x, vl);
+                vyx2 = VLSSEG_FLOAT(y, stride_y, vl);
 
-            vx0 = VGET_VX2(vxx2, 0);
-            vx1 = VGET_VX2(vxx2, 1);
-            vy0 = VGET_VX2(vyx2, 0);
-            vy1 = VGET_VX2(vyx2, 1);
+                vx0 = VGET_VX2(vxx2, 0);
+                vx1 = VGET_VX2(vxx2, 1);
+                vy0 = VGET_VX2(vyx2, 0);
+                vy1 = VGET_VX2(vyx2, 1);
 
-            v0 = VFMULVF_FLOAT(vx0, alpha_r, vl);
-            v0 = VFNMSACVF_FLOAT(v0, alpha_i, vx1, vl);
-            v0 = VFMACCVF_FLOAT(v0, beta_r, vy0, vl);
-            v0 = VFNMSACVF_FLOAT(v0, beta_i, vy1, vl);
-
-            v1 = VFMULVF_FLOAT(vx1, alpha_r, vl);
-            v1 = VFMACCVF_FLOAT(v1, alpha_i, vx0, vl);
-            v1 = VFMACCVF_FLOAT(v1, beta_r, vy1, vl);
-            v1 = VFMACCVF_FLOAT(v1, beta_i, vy0, vl);
+                v0 = VFMULVF_FLOAT(vx0, alpha_r, vl);
+                v0 = VFNMSACVF_FLOAT(v0, alpha_i, vx1, vl);
+                v0 = VFMACCVF_FLOAT(v0, beta_r, vy0, vl);
+                v0 = VFNMSACVF_FLOAT(v0, beta_i, vy1, vl);
 
-            v_x2 = VSET_VX2(v_x2, 0, v0);
-            v_x2 = VSET_VX2(v_x2, 1, v1);
+                v1 = VFMULVF_FLOAT(vx1, alpha_r, vl);
+                v1 = VFMACCVF_FLOAT(v1, alpha_i, vx0, vl);
+                v1 = VFMACCVF_FLOAT(v1, beta_r, vy1, vl);
+                v1 = VFMACCVF_FLOAT(v1, beta_i, vy0, vl);
 
-            VSSSEG_FLOAT(y, stride_y, v_x2, vl);
+                v_x2 = VSET_VX2(v_x2, 0, v0);
+                v_x2 = VSET_VX2(v_x2, 1, v1);
+
+                VSSSEG_FLOAT(y, stride_y, v_x2, vl);
+            }
        }
    }
}
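
As a quick sanity check outside the utest suite, the standalone sketch below drives the repaired inc_y == 0 path end to end. It is illustrative only, not part of the patch: it assumes an OpenBLAS build that exposes the cblas_daxpby extension, and the file name, sizes, and tolerance are made up. The reference value it checks against is the same sequential recurrence (y0 = beta*y0 + alpha*x[i]) that the new scalar fallback implements.

/* daxpby_inc0_repro.c -- hypothetical reproduction sketch, not part of
 * the patch. Assumes OpenBLAS's cblas_daxpby extension is available.
 * Build (example): gcc daxpby_inc0_repro.c -o repro -lopenblas
 */
#include <stdio.h>
#include <math.h>
#include <cblas.h>

int main(void)
{
    const int n = 8;
    const double alpha = 2.0, beta = 3.0;
    double x[8];
    double y[1] = { 1.0 };   /* inc_y == 0: every iteration touches y[0] */

    for (int i = 0; i < n; i++)
        x[i] = (double)(i + 1);

    /* With inc_y == 0 the update must be applied sequentially, so the
     * reference result is the recurrence y0 = beta*y0 + alpha*x[i],
     * matching the scalar fallback added by this patch. */
    double expect = y[0];
    for (int i = 0; i < n; i++)
        expect = beta * expect + alpha * x[i];

    cblas_daxpby(n, alpha, x, 1, beta, y, 0);

    printf("%s: y[0] = %g, expected %g\n",
           fabs(y[0] - expect) < 1e-9 ? "PASS" : "FAIL", y[0], expect);
    return fabs(y[0] - expect) < 1e-9 ? 0 : 1;
}

Before the fix, a RISCV64_ZVL128B build would take the vector path here and store n partial results through a zero stride instead of accumulating into y[0]; with the fallback in place the scalar loop above and the kernel agree.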