diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c
index b11ab59de..1893eba51 100644
--- a/kernel/power/sbgemv_common.c
+++ b/kernel/power/sbgemv_common.c
@@ -138,6 +138,12 @@ FORCEINLINE vec_f32 vec_loadN_f32(void *src, BLASLONG n)
   return (vec_f32)vec_loadN(src, n * (sizeof(FLOAT) / sizeof(IFLOAT)));
 }
 
+FORCEINLINE void vec_loadN2_f32(vec_f32 *data, vec_f32 *src, BLASLONG n)
+{
+  data[0] = src[0];
+  data[1] = vec_loadN_f32(&src[1], n);
+}
+
 FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n)
 {
   FLOAT *dst2 = (FLOAT *)(dst);
@@ -160,6 +166,12 @@ FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n)
 #endif
 }
 
+FORCEINLINE void vec_storeN2_f32(vec_f32 *data, vec_f32 *dst, BLASLONG n)
+{
+  dst[0] = data[0];
+  vec_storeN_f32(data[1], &dst[1], n);
+}
+
 FORCEINLINE vec_f32 vec_mult(vec_f32 *inp, vec_bf16 in0, vec_bf16 zero)
 {
   vec_f32 v_in00 = BF16_HI(in0, zero);
diff --git a/kernel/power/sbgemv_n.c b/kernel/power/sbgemv_n.c
index fa7df858f..05c02a006 100644
--- a/kernel/power/sbgemv_n.c
+++ b/kernel/power/sbgemv_n.c
@@ -75,12 +75,10 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto
   n &= 7;
   if (n > 4) {
     BLASLONG n3 = n & 3;
-    v_inp0[0] = in[(i * 2) + 0];
-    v_inp0[1] = vec_loadN_f32(&in[(i * 2) + 1], n3);
+    vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
     v_inp0[0] *= b;
     v_inp0[1] *= b;
-    out[(i * 2) + 0] = v_inp0[0];
-    vec_storeN_f32(v_inp0[1], &out[(i * 2) + 1], n3);
+    vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
   } else if (n) {
     v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
     v_inp0[0] *= b;
diff --git a/kernel/power/sbgemv_n_vsx.c b/kernel/power/sbgemv_n_vsx.c
index ddbf908b3..45570950e 100644
--- a/kernel/power/sbgemv_n_vsx.c
+++ b/kernel/power/sbgemv_n_vsx.c
@@ -64,13 +64,11 @@ static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
   n &= 7;
   if (n > 4) {
     BLASLONG n3 = n & 3;
-    vy0[0] = v_y[(i * 2) + 0];
-    vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3);
+    vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
 
     vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
 
-    v_y[(i * 2) + 0] = vy0[0];
-    vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3);
+    vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
   } else if (n) {
     vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
 
@@ -116,14 +114,12 @@ static void BF16GEMV_N_VSX_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
   n &= 7;
   if (n > 4) {
     BLASLONG n3 = n & 3;
-    vy0[0] = v_y[(i * 2) + 0];
-    vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3);
+    vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
 
     vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
     vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
 
-    v_y[(i * 2) + 0] = vy0[0];
-    vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3);
+    vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
   } else if (n) {
     vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
 
@@ -178,16 +174,14 @@ static void BF16GEMV_N_VSX_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
   n &= 7;
   if (n > 4) {
     BLASLONG n3 = n & 3;
-    vy0[0] = v_y[(i * 2) + 0];
-    vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3);
+    vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
 
     vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
     vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
     vec_loadN_mult2(v_x2, &va2[i], n, zero, vy0);
     vec_loadN_mult2(v_x3, &va3[i], n, zero, vy0);
 
-    v_y[(i * 2) + 0] = vy0[0];
-    vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3);
+    vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
   } else if (n) {
     vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
 
@@ -263,8 +257,7 @@ static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
   n &= 7;
   if (n > 4) {
     BLASLONG n3 = n & 3;
-    vy0[0] = v_y[(i * 2) + 0];
-    vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3);
+    vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
 
     vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
     vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
@@ -275,8 +268,7 @@ static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
 
     vec_loadN_mult2(v_x6, &vb2[i], n, zero, vy0);
     vec_loadN_mult2(v_x7, &vb3[i], n, zero, vy0);
 
-    v_y[(i * 2) + 0] = vy0[0];
-    vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3);
+    vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
   } else if (n) {
     vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);