More common code.

This commit is contained in:
Chip Kerchner 2024-09-09 14:41:55 -05:00
parent 39fd29f1de
commit 2f142ee857
3 changed files with 22 additions and 20 deletions

View File

@ -138,6 +138,12 @@ FORCEINLINE vec_f32 vec_loadN_f32(void *src, BLASLONG n)
return (vec_f32)vec_loadN(src, n * (sizeof(FLOAT) / sizeof(IFLOAT)));
}
FORCEINLINE void vec_loadN2_f32(vec_f32 *data, vec_f32 *src, BLASLONG n)
{
data[0] = src[0];
data[1] = vec_loadN_f32(&src[1], n);
}
FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n)
{
FLOAT *dst2 = (FLOAT *)(dst);
@ -160,6 +166,12 @@ FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n)
#endif
}
FORCEINLINE void vec_storeN2_f32(vec_f32 *data, vec_f32 *dst, BLASLONG n)
{
dst[0] = data[0];
vec_storeN_f32(data[1], &dst[1], n);
}
FORCEINLINE vec_f32 vec_mult(vec_f32 *inp, vec_bf16 in0, vec_bf16 zero)
{
vec_f32 v_in00 = BF16_HI(in0, zero);

View File

@ -75,12 +75,10 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto
n &= 7;
if (n > 4) {
BLASLONG n3 = n & 3;
v_inp0[0] = in[(i * 2) + 0];
v_inp0[1] = vec_loadN_f32(&in[(i * 2) + 1], n3);
vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
v_inp0[0] *= b;
v_inp0[1] *= b;
out[(i * 2) + 0] = v_inp0[0];
vec_storeN_f32(v_inp0[1], &out[(i * 2) + 1], n3);
vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
} else if (n) {
v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
v_inp0[0] *= b;

View File

@ -64,13 +64,11 @@ static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
n &= 7;
if (n > 4) {
BLASLONG n3 = n & 3;
vy0[0] = v_y[(i * 2) + 0];
vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3);
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
v_y[(i * 2) + 0] = vy0[0];
vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3);
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
} else if (n) {
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
@ -116,14 +114,12 @@ static void BF16GEMV_N_VSX_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
n &= 7;
if (n > 4) {
BLASLONG n3 = n & 3;
vy0[0] = v_y[(i * 2) + 0];
vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3);
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
v_y[(i * 2) + 0] = vy0[0];
vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3);
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
} else if (n) {
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
@ -178,16 +174,14 @@ static void BF16GEMV_N_VSX_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
n &= 7;
if (n > 4) {
BLASLONG n3 = n & 3;
vy0[0] = v_y[(i * 2) + 0];
vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3);
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
vec_loadN_mult2(v_x2, &va2[i], n, zero, vy0);
vec_loadN_mult2(v_x3, &va3[i], n, zero, vy0);
v_y[(i * 2) + 0] = vy0[0];
vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3);
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
} else if (n) {
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
@ -263,8 +257,7 @@ static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
n &= 7;
if (n > 4) {
BLASLONG n3 = n & 3;
vy0[0] = v_y[(i * 2) + 0];
vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3);
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
@ -275,8 +268,7 @@ static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
vec_loadN_mult2(v_x6, &vb2[i], n, zero, vy0);
vec_loadN_mult2(v_x7, &vb3[i], n, zero, vy0);
v_y[(i * 2) + 0] = vy0[0];
vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3);
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
} else
if (n) {
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);