Merge pull request #4802 from markdryan/markdryan/rvv_axpby_incy0

Fix axpby_rvv kernels for cases where inc_y = 0
This commit is contained in:
Martin Kroeker 2024-07-16 14:22:38 +02:00 committed by GitHub
commit b1aa2e1768
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 70 additions and 35 deletions

View File

@ -114,6 +114,11 @@ int CNAME(BLASLONG n, FLOAT alpha, FLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *
vy = VFMULVF_FLOAT(vy, beta, vl);
VSEV_FLOAT (y, vy, vl);
}
} else if (inc_y == 0) {
FLOAT vf = y[0];
for (; n > 0; n--)
vf *= beta;
y[0] = vf;
} else {
BLASLONG stride_y = inc_y * sizeof(FLOAT);
for (size_t vl; n > 0; n -= vl, y += vl*inc_y) {
@ -134,6 +139,13 @@ int CNAME(BLASLONG n, FLOAT alpha, FLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *
vy = VFMACCVF_FLOAT(vy, alpha, vx, vl);
VSEV_FLOAT (y, vy, vl);
}
} else if (inc_y == 0) {
FLOAT vf = y[0];
for (; n > 0; n--) {
vf = (vf * beta) + (x[0] * alpha);
x += inc_x;
}
y[0] = vf;
} else if (1 == inc_x) {
BLASLONG stride_y = inc_y * sizeof(FLOAT);
for (size_t vl; n > 0; n -= vl, x += vl, y += vl*inc_y) {

View File

@ -79,8 +79,10 @@ int CNAME(BLASLONG n, FLOAT alpha_r, FLOAT alpha_i, FLOAT *x, BLASLONG inc_x, FL
BLASLONG stride_x = inc_x2 * sizeof(FLOAT);
BLASLONG stride_y = inc_y2 * sizeof(FLOAT);
BLASLONG ix;
FLOAT_V_T vx0, vx1, vy0, vy1;
FLOAT_VX2_T vxx2, vyx2;
FLOAT temp;
if ( beta_r == 0.0 && beta_i == 0.0)
{
@ -125,53 +127,74 @@ int CNAME(BLASLONG n, FLOAT alpha_r, FLOAT alpha_i, FLOAT *x, BLASLONG inc_x, FL
if ( alpha_r == 0.0 && alpha_i == 0.0 )
{
for (size_t vl; n > 0; n -= vl, y += vl*inc_y2)
{
vl = VSETVL(n);
if ( inc_y == 0 ) {
for (; n > 0; n--)
{
temp = (beta_r * y[0] - beta_i * y[1]);
y[1] = (beta_r * y[1] + beta_i * y[0]);
y[0] = temp;
}
} else {
for (size_t vl; n > 0; n -= vl, y += vl*inc_y2)
{
vl = VSETVL(n);
vyx2 = VLSSEG_FLOAT(y, stride_y, vl);
vy0 = VGET_VX2(vyx2, 0);
vy1 = VGET_VX2(vyx2, 1);
v0 = VFMULVF_FLOAT(vy1, beta_i, vl);
v0 = VFMSACVF_FLOAT(v0, beta_r, vy0, vl);
vyx2 = VLSSEG_FLOAT(y, stride_y, vl);
vy0 = VGET_VX2(vyx2, 0);
vy1 = VGET_VX2(vyx2, 1);
v1 = VFMULVF_FLOAT(vy1, beta_r, vl);
v1 = VFMACCVF_FLOAT(v1, beta_i, vy0, vl);
v0 = VFMULVF_FLOAT(vy1, beta_i, vl);
v0 = VFMSACVF_FLOAT(v0, beta_r, vy0, vl);
v_x2 = VSET_VX2(v_x2, 0, v0);
v_x2 = VSET_VX2(v_x2, 1, v1);
VSSSEG_FLOAT(y, stride_y, v_x2, vl);
v1 = VFMULVF_FLOAT(vy1, beta_r, vl);
v1 = VFMACCVF_FLOAT(v1, beta_i, vy0, vl);
v_x2 = VSET_VX2(v_x2, 0, v0);
v_x2 = VSET_VX2(v_x2, 1, v1);
VSSSEG_FLOAT(y, stride_y, v_x2, vl);
}
}
}
else
{
for (size_t vl; n > 0; n -= vl, x += vl*inc_x2, y += vl*inc_y2)
{
vl = VSETVL(n);
if ( inc_y == 0 ) {
ix = 0;
for (; n > 0; n--) {
temp = (alpha_r * x[ix] - alpha_i * x[ix+1] ) +
(beta_r * y[0] - beta_i * y[1]);
y[1] = (alpha_r * x[ix+1] + alpha_i * x[ix]) +
(beta_r * y[1] + beta_i * y[0]);
y[0] = temp;
ix += inc_x2;
}
} else {
for (size_t vl; n > 0; n -= vl, x += vl*inc_x2, y += vl*inc_y2)
{
vl = VSETVL(n);
vxx2 = VLSSEG_FLOAT(x, stride_x, vl);
vyx2 = VLSSEG_FLOAT(y, stride_y, vl);
vxx2 = VLSSEG_FLOAT(x, stride_x, vl);
vyx2 = VLSSEG_FLOAT(y, stride_y, vl);
vx0 = VGET_VX2(vxx2, 0);
vx1 = VGET_VX2(vxx2, 1);
vy0 = VGET_VX2(vyx2, 0);
vy1 = VGET_VX2(vyx2, 1);
vx0 = VGET_VX2(vxx2, 0);
vx1 = VGET_VX2(vxx2, 1);
vy0 = VGET_VX2(vyx2, 0);
vy1 = VGET_VX2(vyx2, 1);
v0 = VFMULVF_FLOAT(vx0, alpha_r, vl);
v0 = VFNMSACVF_FLOAT(v0, alpha_i, vx1, vl);
v0 = VFMACCVF_FLOAT(v0, beta_r, vy0, vl);
v0 = VFNMSACVF_FLOAT(v0, beta_i, vy1, vl);
v1 = VFMULVF_FLOAT(vx1, alpha_r, vl);
v1 = VFMACCVF_FLOAT(v1, alpha_i, vx0, vl);
v1 = VFMACCVF_FLOAT(v1, beta_r, vy1, vl);
v1 = VFMACCVF_FLOAT(v1, beta_i, vy0, vl);
v0 = VFMULVF_FLOAT(vx0, alpha_r, vl);
v0 = VFNMSACVF_FLOAT(v0, alpha_i, vx1, vl);
v0 = VFMACCVF_FLOAT(v0, beta_r, vy0, vl);
v0 = VFNMSACVF_FLOAT(v0, beta_i, vy1, vl);
v_x2 = VSET_VX2(v_x2, 0, v0);
v_x2 = VSET_VX2(v_x2, 1, v1);
v1 = VFMULVF_FLOAT(vx1, alpha_r, vl);
v1 = VFMACCVF_FLOAT(v1, alpha_i, vx0, vl);
v1 = VFMACCVF_FLOAT(v1, beta_r, vy1, vl);
v1 = VFMACCVF_FLOAT(v1, beta_i, vy0, vl);
VSSSEG_FLOAT(y, stride_y, v_x2, vl);
v_x2 = VSET_VX2(v_x2, 0, v0);
v_x2 = VSET_VX2(v_x2, 1, v1);
VSSSEG_FLOAT(y, stride_y, v_x2, vl);
}
}
}
}