More common code.
This commit is contained in:
parent
39fd29f1de
commit
2f142ee857
|
@ -138,6 +138,12 @@ FORCEINLINE vec_f32 vec_loadN_f32(void *src, BLASLONG n)
|
||||||
return (vec_f32)vec_loadN(src, n * (sizeof(FLOAT) / sizeof(IFLOAT)));
|
return (vec_f32)vec_loadN(src, n * (sizeof(FLOAT) / sizeof(IFLOAT)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
 * Load a pair of vectors from src into data[0..1].
 *
 * data : destination array; elements 0 and 1 are written.
 * src  : source vectors; src[0] is read in full, src[1] is read through
 *        vec_loadN_f32.
 * n    : count forwarded unchanged to vec_loadN_f32 for the second vector.
 *        The callers visible in this change pass (n & 3) of a remainder
 *        that is greater than 4 and at most 7, i.e. 1..3.
 *
 * NOTE(review): presumably vec_loadN_f32 reads only the first n floats of
 * src[1] so the tail load stays inside the caller's buffer - confirm
 * against vec_loadN.
 */
FORCEINLINE void vec_loadN2_f32(vec_f32 *data, vec_f32 *src, BLASLONG n)
|
||||||
|
{
|
||||||
|
/* First vector is copied whole with a plain vector assignment. */
data[0] = src[0];
|
||||||
|
/* Second vector goes through the length-limited loader. */
data[1] = vec_loadN_f32(&src[1], n);
|
||||||
|
}
|
||||||
|
|
||||||
FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n)
|
FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n)
|
||||||
{
|
{
|
||||||
FLOAT *dst2 = (FLOAT *)(dst);
|
FLOAT *dst2 = (FLOAT *)(dst);
|
||||||
|
@ -160,6 +166,12 @@ FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n)
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
 * Store a pair of vectors from data[0..1] to dst.
 *
 * data : source array; elements 0 and 1 are read.
 * dst  : destination vectors; dst[0] is written in full, dst[1] is written
 *        through vec_storeN_f32.
 * n    : count forwarded unchanged to vec_storeN_f32 for the second vector.
 *        The callers visible in this change pass (n & 3) of a remainder
 *        that is greater than 4 and at most 7, i.e. 1..3.
 *
 * NOTE(review): presumably vec_storeN_f32 writes only the first n floats of
 * dst[1], leaving memory past the end of the output untouched - confirm
 * against its definition.
 */
FORCEINLINE void vec_storeN2_f32(vec_f32 *data, vec_f32 *dst, BLASLONG n)
|
||||||
|
{
|
||||||
|
/* First vector is written whole with a plain vector assignment. */
dst[0] = data[0];
|
||||||
|
/* Second vector goes through the length-limited store. */
vec_storeN_f32(data[1], &dst[1], n);
|
||||||
|
}
|
||||||
|
|
||||||
FORCEINLINE vec_f32 vec_mult(vec_f32 *inp, vec_bf16 in0, vec_bf16 zero)
|
FORCEINLINE vec_f32 vec_mult(vec_f32 *inp, vec_bf16 in0, vec_bf16 zero)
|
||||||
{
|
{
|
||||||
vec_f32 v_in00 = BF16_HI(in0, zero);
|
vec_f32 v_in00 = BF16_HI(in0, zero);
|
||||||
|
|
|
@ -75,12 +75,10 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto
|
||||||
n &= 7;
|
n &= 7;
|
||||||
if (n > 4) {
|
if (n > 4) {
|
||||||
BLASLONG n3 = n & 3;
|
BLASLONG n3 = n & 3;
|
||||||
v_inp0[0] = in[(i * 2) + 0];
|
vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
|
||||||
v_inp0[1] = vec_loadN_f32(&in[(i * 2) + 1], n3);
|
|
||||||
v_inp0[0] *= b;
|
v_inp0[0] *= b;
|
||||||
v_inp0[1] *= b;
|
v_inp0[1] *= b;
|
||||||
out[(i * 2) + 0] = v_inp0[0];
|
vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
|
||||||
vec_storeN_f32(v_inp0[1], &out[(i * 2) + 1], n3);
|
|
||||||
} else if (n) {
|
} else if (n) {
|
||||||
v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
|
v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
|
||||||
v_inp0[0] *= b;
|
v_inp0[0] *= b;
|
||||||
|
|
|
@ -64,13 +64,11 @@ static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
|
||||||
n &= 7;
|
n &= 7;
|
||||||
if (n > 4) {
|
if (n > 4) {
|
||||||
BLASLONG n3 = n & 3;
|
BLASLONG n3 = n & 3;
|
||||||
vy0[0] = v_y[(i * 2) + 0];
|
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3);
|
|
||||||
|
|
||||||
vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
|
vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
|
||||||
|
|
||||||
v_y[(i * 2) + 0] = vy0[0];
|
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3);
|
|
||||||
} else if (n) {
|
} else if (n) {
|
||||||
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||||
|
|
||||||
|
@ -116,14 +114,12 @@ static void BF16GEMV_N_VSX_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
|
||||||
n &= 7;
|
n &= 7;
|
||||||
if (n > 4) {
|
if (n > 4) {
|
||||||
BLASLONG n3 = n & 3;
|
BLASLONG n3 = n & 3;
|
||||||
vy0[0] = v_y[(i * 2) + 0];
|
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3);
|
|
||||||
|
|
||||||
vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
|
vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
|
||||||
vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
|
vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
|
||||||
|
|
||||||
v_y[(i * 2) + 0] = vy0[0];
|
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3);
|
|
||||||
} else if (n) {
|
} else if (n) {
|
||||||
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||||
|
|
||||||
|
@ -178,16 +174,14 @@ static void BF16GEMV_N_VSX_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
|
||||||
n &= 7;
|
n &= 7;
|
||||||
if (n > 4) {
|
if (n > 4) {
|
||||||
BLASLONG n3 = n & 3;
|
BLASLONG n3 = n & 3;
|
||||||
vy0[0] = v_y[(i * 2) + 0];
|
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3);
|
|
||||||
|
|
||||||
vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
|
vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
|
||||||
vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
|
vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
|
||||||
vec_loadN_mult2(v_x2, &va2[i], n, zero, vy0);
|
vec_loadN_mult2(v_x2, &va2[i], n, zero, vy0);
|
||||||
vec_loadN_mult2(v_x3, &va3[i], n, zero, vy0);
|
vec_loadN_mult2(v_x3, &va3[i], n, zero, vy0);
|
||||||
|
|
||||||
v_y[(i * 2) + 0] = vy0[0];
|
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3);
|
|
||||||
} else if (n) {
|
} else if (n) {
|
||||||
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||||
|
|
||||||
|
@ -263,8 +257,7 @@ static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
|
||||||
n &= 7;
|
n &= 7;
|
||||||
if (n > 4) {
|
if (n > 4) {
|
||||||
BLASLONG n3 = n & 3;
|
BLASLONG n3 = n & 3;
|
||||||
vy0[0] = v_y[(i * 2) + 0];
|
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3);
|
|
||||||
|
|
||||||
vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
|
vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
|
||||||
vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
|
vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
|
||||||
|
@ -275,8 +268,7 @@ static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
|
||||||
vec_loadN_mult2(v_x6, &vb2[i], n, zero, vy0);
|
vec_loadN_mult2(v_x6, &vb2[i], n, zero, vy0);
|
||||||
vec_loadN_mult2(v_x7, &vb3[i], n, zero, vy0);
|
vec_loadN_mult2(v_x7, &vb3[i], n, zero, vy0);
|
||||||
|
|
||||||
v_y[(i * 2) + 0] = vy0[0];
|
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3);
|
|
||||||
} else
|
} else
|
||||||
if (n) {
|
if (n) {
|
||||||
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||||
|
|
Loading…
Reference in New Issue