More MMA BF16 GEMV code.

parent c9ce37d527
commit 05aa63e738
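All of the helpers touched by this diff wrap the same two POWER10 MMA bf16 primitives: an outer-product accumulate into a __vector_quad and a disassemble of that accumulator back into four f32 vectors. A minimal sketch of that pattern, not taken from the patch and assuming a compiler with MMA support (vec_uc8 and vec_f32 stand in here for the typedefs that sbgemv_common.c provides):

#include <altivec.h>

typedef __vector unsigned char vec_uc8;
typedef __vector float vec_f32;

/* Accumulate one bf16 rank-2 outer product into acc and spill the 4x4
 * f32 accumulator into out4[0..3].  The accumulator is assumed to have
 * been primed first (e.g. with __builtin_mma_xxsetaccz or a
 * non-accumulating __builtin_mma_xvbf16ger2). */
static inline void bf16_acc_step(__vector_quad *acc, vec_uc8 a, vec_uc8 b, vec_f32 *out4)
{
  __builtin_mma_xvbf16ger2pp(acc, a, b);            /* acc += a (x) b */
  __builtin_mma_disassemble_acc((void *)out4, acc); /* acc -> out4    */
}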
@@ -29,6 +29,10 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define SBGEMV_COMMON_MMA_C
#include "sbgemv_common.c"

#if defined(_AIX) || defined(__clang__)
#define USE_MERGE_MMA
#endif

FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
{
vec_bf16 in0 = (vec_bf16)vec_load_vec(in);
@@ -69,11 +73,13 @@ FORCEINLINE void vec_mult2_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
__builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01);
}

#ifndef USE_MERGE_MMA
FORCEINLINE void vec_mult4_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 inp)
{
vec_mult2_mma(out + 0, in0[0], inp);
vec_mult2_mma(out + 2, in0[1], inp);
}
#endif

FORCEINLINE void vec_loadN_mult11_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
{
@@ -96,6 +102,7 @@ FORCEINLINE void vec_load_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16
vec_mult2_mma(out, in0, inp);
}

#ifndef USE_MERGE_MMA
FORCEINLINE void vec_load_mult18_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
{
vec_bf16 in0[4];
@@ -106,6 +113,7 @@ FORCEINLINE void vec_load_mult18_mma(__vector_quad *out, vec_bf16 *in, vec_bf16
vec_mult4_mma(&out[0], in0 + 0, inp);
vec_mult4_mma(&out[4], in0 + 2, inp);
}
#endif

FORCEINLINE void vec_reduce1_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
{
@@ -120,6 +128,7 @@ FORCEINLINE void vec_reduce2_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_al
vec_reduce1_mma(&out[1], &temp[4], v_alpha, &vy0[1]);
}

#ifndef USE_MERGE_MMA
FORCEINLINE void vec_reduce8_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
{
vec_reduce2_mma(&out[0], &temp[0], v_alpha, vy0 + 0);
@@ -127,6 +136,23 @@ FORCEINLINE void vec_reduce8_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_al
vec_reduce2_mma(&out[4], &temp[16], v_alpha, vy0 + 4);
vec_reduce2_mma(&out[6], &temp[24], v_alpha, vy0 + 6);
}
#else
FORCEINLINE void vec_reduce44_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
{
__builtin_mma_disassemble_acc((void*)temp, &out[0]);

vy0[0] += (temp[0] * v_alpha);
vy0[2] += (temp[1] * v_alpha);
vy0[4] += (temp[2] * v_alpha);
vy0[6] += (temp[3] * v_alpha);
}

FORCEINLINE void vec_reduce84_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
{
vec_reduce44_mma(&out[0], &temp[0], v_alpha, vy0 + 0);
vec_reduce44_mma(&out[1], &temp[4], v_alpha, vy0 + 1);
}
#endif

FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
{
@@ -166,18 +192,25 @@ FORCEINLINE void vec_load_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf1
vec_mult2a_mma(out, in0, in1, inp);
}

FORCEINLINE void vec_load_mult28a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
FORCEINLINE void vec_load4_mma(vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *ina, vec_bf16 *inb)
{
vec_bf16 in0[4], in1[4];

vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0));
vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0));
vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2));
vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2));
}

#ifndef USE_MERGE_MMA
FORCEINLINE void vec_load_mult28a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
{
vec_bf16 in0[4], in1[4];

vec_load4_mma(in0, in1, ina, inb);

vec_mult4a_mma(&out[0], in0 + 0, in1 + 0, inp);
vec_mult4a_mma(&out[4], in0 + 2, in1 + 2, inp);
}
#endif

FORCEINLINE void vec_loadN_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
{
@@ -209,6 +242,48 @@ FORCEINLINE void vec_mult4b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1
vec_mult2b_mma(out + 2, in0[1], in1[1], inp);
}

#ifdef USE_MERGE_MMA
FORCEINLINE void vec_mult1c_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
{
vec_bf16 in00 = vec_mergeh(in0, in0);

__builtin_mma_xvbf16ger2pp(out, (vec_uc8)inp, (vec_uc8)in00);
}

FORCEINLINE void vec_mult2c_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
{
vec_bf16 in01 = vec_mergel(in0, in0);

vec_mult1c_mma(&out[0], in0, inp);

__builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)inp, (vec_uc8)in01);
}

FORCEINLINE void vec_mult44_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
{
vec_mult2_mma(out, in[0], inp[0]);
vec_mult2c_mma(out, in[1], inp[1]);
}

FORCEINLINE void vec_mult44c_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
{
vec_mult2c_mma(out, in[0], inp[0]);
vec_mult2c_mma(out, in[1], inp[1]);
}

FORCEINLINE void vec_mult44a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
{
vec_mult2a_mma(out, in0[0], in1[0], inp[0]);
vec_mult2b_mma(out, in0[1], in1[1], inp[1]);
}

FORCEINLINE void vec_mult44b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
{
vec_mult2b_mma(out, in0[0], in1[0], inp[0]);
vec_mult2b_mma(out, in0[1], in1[1], inp[1]);
}
#endif

FORCEINLINE void vec_loadN_mult11b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
{
vec_bf16 in0 = vec_loadN(ina, n);
@@ -225,18 +300,48 @@ FORCEINLINE void vec_load_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf1
vec_mult2b_mma(out, in0, in1, inp);
}

#ifndef USE_MERGE_MMA
FORCEINLINE void vec_load_mult28b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
{
vec_bf16 in0[4], in1[4];

vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0));
vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0));
vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2));
vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2));
vec_load4_mma(in0, in1, ina, inb);

vec_mult4b_mma(&out[0], in0 + 0, in1 + 0, inp);
vec_mult4b_mma(&out[4], in0 + 2, in1 + 2, inp);
}
#else
FORCEINLINE void vec_load_mult184_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
{
vec_bf16 in0[4];

vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0));
vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2));

vec_mult44_mma(out, in0 + 0, inp + 0);
vec_mult44c_mma(out, in0 + 2, inp + 2);
}

FORCEINLINE void vec_load_mult284a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
{
vec_bf16 in0[4], in1[4];

vec_load4_mma(in0, in1, ina, inb);

vec_mult44a_mma(out, in0 + 0, in1 + 0, inp + 0);
vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2);
}

FORCEINLINE void vec_load_mult284b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
{
vec_bf16 in0[4], in1[4];

vec_load4_mma(in0, in1, ina, inb);

vec_mult44b_mma(out, in0 + 0, in1 + 0, inp + 0);
vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2);
}
#endif

FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
{
@@ -262,4 +367,64 @@ FORCEINLINE void vec_store4_pair(vec_f32 *v_y, vec_f32 *vy0)
vec_store_pair(v_y + 6, vy0 + 6);
}

#ifdef USE_MERGE_MMA
FORCEINLINE void vec_load8_pair(vec_f32 *vy0, vec_f32 *v_y)
{
vec_load4_pair(vy0 + 0, v_y + 0);
vec_load4_pair(vy0 + 8, v_y + 8);
}

FORCEINLINE void vec_store8_pair(vec_f32 *v_y, vec_f32 *vy0)
{
vec_store4_pair(v_y + 0, vy0 + 0);
vec_store4_pair(v_y + 8, vy0 + 8);
}

#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define VEC_SHIFT(data, shift) vec_sld(data, data, 16 - shift)
#else
#define VEC_SHIFT(data, shift) vec_sld(data, data, shift)
#endif

typedef __vector unsigned int vec_ui32;

static vec_ui32 mask_0 = { 0xffffffff, 0x00000000, 0x00000000, 0x00000000 };
static vec_ui32 mask_1 = { 0x00000000, 0xffffffff, 0x00000000, 0x00000000 };
static vec_ui32 mask_2 = { 0x00000000, 0x00000000, 0xffffffff, 0x00000000 };
static vec_ui32 mask_3 = { 0x00000000, 0x00000000, 0x00000000, 0xffffffff };

FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0)
{
v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)mask_0);

v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 4);
v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 8);
v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 12);
}

FORCEINLINE void vec_make_mult2(vec_bf16 *v_x0)
{
v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)mask_1);
vec_make_mult1(v_x0);

v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 12);
v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 4);
v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 8);
}

FORCEINLINE void vec_make_mult4(vec_bf16 *v_x0)
{
v_x0[10] = vec_and(v_x0[0], (vec_bf16)mask_2);
v_x0[15] = vec_and(v_x0[0], (vec_bf16)mask_3);
vec_make_mult2(v_x0);

v_x0[ 8] = VEC_SHIFT(v_x0[10], 8);
v_x0[ 9] = VEC_SHIFT(v_x0[10], 12);
v_x0[11] = VEC_SHIFT(v_x0[10], 4);
v_x0[12] = VEC_SHIFT(v_x0[15], 4);
v_x0[13] = VEC_SHIFT(v_x0[15], 8);
v_x0[14] = VEC_SHIFT(v_x0[15], 12);
}
#endif

#endif
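The new vec_make_mult1/2/4 helpers above prepare the x operands for the merged-MMA path: the loaded x vector is masked down to one 32-bit bf16 pair and then rotated across the four 32-bit lanes with VEC_SHIFT, so each accumulator in a group gets its own copy. A hedged stand-alone sketch of that rotation, not taken from the patch (little-endian form of VEC_SHIFT written out; names here are illustrative only):

#include <altivec.h>

typedef __vector unsigned int vec_ui32;

/* Keep only element 0 of v_x0[0], then produce three more copies with
 * that element rotated by 4, 8 and 12 bytes, as vec_make_mult1 does. */
static void make_mult1_demo(vec_ui32 *v_x0)
{
  const vec_ui32 mask_0 = { 0xffffffff, 0, 0, 0 };

  v_x0[0] = vec_and(v_x0[0], mask_0);
  v_x0[1] = vec_sld(v_x0[0], v_x0[0], 4);
  v_x0[2] = vec_sld(v_x0[0], v_x0[0], 8);
  v_x0[3] = vec_sld(v_x0[0], v_x0[0], 12);
}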
@@ -28,7 +28,9 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef SBGEMV_N_MMA_C
#define SBGEMV_N_MMA_C

#if !defined(_AIX) || defined(__clang__)
#define USE_BFGEMV_N_MMA
#endif

#ifdef USE_BFGEMV_N_MMA
#include "sbgemv_common_power10.c"
@@ -47,7 +49,7 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
{
IFLOAT *a0;
__vector_quad temp[2*4];
vec_f32 temp0[8*4], vy0[2*4];
vec_f32 temp0[8*4];
vec_f32 v_alpha = { alpha, alpha, alpha, alpha };

a0 = ap[0];
@@ -55,26 +57,61 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
vec_bf16 *va0 = (vec_bf16 *)a0;

vec_bf16 *x_bf = (vec_bf16 *)(xo);
vec_bf16 v_x0 = vec_loadN(x_bf, 1);

vec_f32 *v_y = (vec_f32 *)y;
BLASLONG n8 = n / 8;
BLASLONG i = 0;

#ifdef USE_MERGE_MMA
vec_bf16 v_x0[4];
v_x0[0] = vec_loadN(x_bf, 1);
vec_f32 vy0[2*4*2];

vec_make_mult1(v_x0);

for (; i + 8 <= n8; i += 8) {
vec_load8_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]);
vec_load_mult184_mma(&temp[2], &va0[i + 4], &v_x0[ 0]);

vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8);

vec_store8_pair(&v_y[(i * 2) + 0], vy0);
}

if (n8 & 4) {
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]);

vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0);

vec_store4_pair(&v_y[(i * 2) + 0], vy0);

i += 4;
}
#else
vec_bf16 v_x0[1];
v_x0[0] = vec_loadN(x_bf, 1);
vec_f32 vy0[2*4];

for (; i + 4 <= n8; i += 4) {
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult18_mma(&temp[0], &va0[i + 0], v_x0);
vec_load_mult18_mma(&temp[0], &va0[i + 0], v_x0[ 0]);

vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0);

vec_store4_pair(&v_y[(i * 2) + 0], vy0);
}
#endif

for (; i < n8; i++) {
vec_load_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult12_mma(&temp[0], &va0[i], v_x0);
vec_load_mult12_mma(&temp[0], &va0[i], v_x0[ 0]);

vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
@@ -86,7 +123,7 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
BLASLONG n3 = n & 3;
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);

vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0, n);
vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0[ 0], n);

vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
@@ -94,7 +131,7 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
} else if (n) {
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);

vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0, n);
vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0[ 0], n);

vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0);
@@ -106,7 +143,7 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
{
IFLOAT *a0, *a1;
__vector_quad temp[2*4];
vec_f32 temp0[8*4], vy0[2*4];
vec_f32 temp0[8*4];
vec_f32 v_alpha = { alpha, alpha, alpha, alpha };

a0 = ap[0];
@@ -116,26 +153,61 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
vec_bf16 *va1 = (vec_bf16 *)a1;

vec_bf16 *x_bf = (vec_bf16 *)(xo);
vec_bf16 v_x0 = vec_loadN(x_bf, 2);

vec_f32 *v_y = (vec_f32 *)y;
BLASLONG n8 = n / 8;
BLASLONG i = 0;

#ifdef USE_MERGE_MMA
vec_bf16 v_x0[4];
vec_f32 vy0[2*4*2];
v_x0[0] = vec_loadN(x_bf, 2);

vec_make_mult1(v_x0);

for (; i + 8 <= n8; i += 8) {
vec_load8_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]);

vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8);

vec_store8_pair(&v_y[(i * 2) + 0], vy0);
}

if (n8 & 4) {
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);

vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0);

vec_store4_pair(&v_y[(i * 2) + 0], vy0);

i += 4;
}
#else
vec_bf16 v_x0[1];
vec_f32 vy0[2*4];
v_x0[0] = vec_loadN(x_bf, 2);

for (; i + 4 <= n8; i += 4) {
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0);
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]);

vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0);

vec_store4_pair(&v_y[(i * 2) + 0], vy0);
}
#endif

for (; i < n8; i++) {
vec_load_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0);
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]);

vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
@@ -147,7 +219,7 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
BLASLONG n3 = n & 3;
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);

vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0, n);
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);

vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
@@ -155,7 +227,7 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
} else if (n) {
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);

vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0, n);
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);

vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0);
@@ -167,7 +239,7 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
{
IFLOAT *a0, *a1, *a2, *a3;
__vector_quad temp[2*4];
vec_f32 temp0[8*4], vy0[2*4];
vec_f32 temp0[8*4];
vec_f32 v_alpha = { alpha, alpha, alpha, alpha };

a0 = ap[0];
@@ -181,30 +253,68 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
vec_bf16 *va3 = (vec_bf16 *)a3;

vec_bf16 *x_bf = (vec_bf16 *)(xo);
vec_bf16 v_x00 = vec_loadN(x_bf, 4);

vec_bf16 v_x01 = (vec_bf16)vec_splat((vec_f32)v_x00, 1);

vec_f32 *v_y = (vec_f32 *)y;
BLASLONG n8 = n / 8;
BLASLONG i = 0;

#ifdef USE_MERGE_MMA
vec_bf16 v_x0[8];
vec_f32 vy0[2*4*2];
v_x0[0] = vec_loadN(x_bf, 4);

vec_make_mult2(v_x0);

for (; i + 8 <= n8; i += 8) {
vec_load8_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);
vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]);
vec_load_mult284b_mma(&temp[2], &va2[i + 4], &va3[i + 4], &v_x0[ 4]);

vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8);

vec_store8_pair(&v_y[(i * 2) + 0], vy0);
}

if (n8 & 4) {
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);

vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0);

vec_store4_pair(&v_y[(i * 2) + 0], vy0);

i += 4;
}
#else
vec_bf16 v_x0[5];
vec_f32 vy0[2*4];
v_x0[0] = vec_loadN(x_bf, 4);

v_x0[ 4] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 1);

for (; i + 4 <= n8; i += 4) {
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x00);
vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x01);
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]);
vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x0[ 4]);

vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0);

vec_store4_pair(&v_y[(i * 2) + 0], vy0);
}
#endif

for (; i < n8; i++) {
vec_load_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00);
vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01);
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]);
vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4]);

vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
@@ -216,8 +326,8 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
BLASLONG n3 = n & 3;
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);

vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00, n);
vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01, n);
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);

vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
@@ -225,8 +335,8 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
} else if (n) {
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);

vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x00, n);
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x01, n);
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);

vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0);
@@ -239,7 +349,7 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
{
IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3;
__vector_quad temp[2*4];
vec_f32 temp0[8*4], vy0[2*4];
vec_f32 temp0[8*4];
vec_f32 v_alpha = { alpha, alpha, alpha, alpha };

a0 = ap[0];
@@ -261,36 +371,80 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
vec_bf16 *vb3 = (vec_bf16 *)b3;

vec_bf16 *x_bf = (vec_bf16 *)(xo);
vec_bf16 v_x00 = (vec_bf16)vec_load_vec(x_bf);

vec_bf16 v_x01 = (vec_bf16)vec_splat((vec_f32)v_x00, 1);
vec_bf16 v_x02 = (vec_bf16)vec_splat((vec_f32)v_x00, 2);
vec_bf16 v_x03 = (vec_bf16)vec_splat((vec_f32)v_x00, 3);

vec_f32 *v_y = (vec_f32 *)y;
BLASLONG n8 = n / 8;
BLASLONG i = 0;

#ifdef USE_MERGE_MMA
vec_bf16 v_x0[16];
vec_f32 vy0[2*4*2];
v_x0[0] = (vec_bf16)vec_load_vec(x_bf);

vec_make_mult4(v_x0);

for (; i + 8 <= n8; i += 8) {
vec_load8_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);
vec_load_mult284b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]);
vec_load_mult284b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]);
vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]);
vec_load_mult284b_mma(&temp[2], &va2[i + 4], &va3[i + 4], &v_x0[ 4]);
vec_load_mult284b_mma(&temp[2], &vb0[i + 4], &vb1[i + 4], &v_x0[ 8]);
vec_load_mult284b_mma(&temp[2], &vb2[i + 4], &vb3[i + 4], &v_x0[12]);

vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8);

vec_store8_pair(&v_y[(i * 2) + 0], vy0);
}

if (n8 & 4) {
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);
vec_load_mult284b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]);
vec_load_mult284b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]);

vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0);

vec_store4_pair(&v_y[(i * 2) + 0], vy0);

i += 4;
}
#else
vec_bf16 v_x0[13];
vec_f32 vy0[2*4];
v_x0[0] = (vec_bf16)vec_load_vec(x_bf);

v_x0[ 4] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 1);
v_x0[ 8] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 2);
v_x0[12] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 3);

for (; i + 4 <= n8; i += 4) {
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x00);
vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x01);
vec_load_mult28b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], v_x02);
vec_load_mult28b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], v_x03);
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]);
vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x0[ 4]);
vec_load_mult28b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], v_x0[ 8]);
vec_load_mult28b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], v_x0[12]);

vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0);

vec_store4_pair(&v_y[(i * 2) + 0], vy0);
}
#endif

for (; i < n8; i++) {
vec_load_pair(vy0, &v_y[(i * 2) + 0]);

vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00);
vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01);
vec_load_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x02);
vec_load_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x03);
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]);
vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4]);
vec_load_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8]);
vec_load_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12]);

vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
@@ -302,10 +456,10 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
BLASLONG n3 = n & 3;
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);

vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00, n);
vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01, n);
vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x02, n);
vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x03, n);
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);
vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n);
vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n);

vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
@@ -313,10 +467,10 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
} else if (n) {
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);

vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x00, n);
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x01, n);
vec_loadN_mult11b_mma(&temp[0], &vb0[i], &vb1[i], v_x02, n);
vec_loadN_mult11b_mma(&temp[0], &vb2[i], &vb3[i], v_x03, n);
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);
vec_loadN_mult11b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n);
vec_loadN_mult11b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n);

vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0);
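For reference, a hedged sketch (not part of the patch) of what the merged reduce step amounts to: each accumulator pair is disassembled and its rows, scaled by alpha, are folded into interleaved vec_f32 slots of y, mirroring vec_reduce44_mma/vec_reduce84_mma from the common file:

#include <altivec.h>

typedef __vector float vec_f32;

/* Spill two accumulators and fold them into y: rows of acc[0] go to the
 * even vy0 slots, rows of acc[1] to the odd slots, each scaled by alpha. */
static inline void reduce84_demo(__vector_quad *acc, vec_f32 v_alpha, vec_f32 *vy0)
{
  vec_f32 t[8];
  int k;

  __builtin_mma_disassemble_acc((void *)(t + 0), &acc[0]);
  __builtin_mma_disassemble_acc((void *)(t + 4), &acc[1]);

  for (k = 0; k < 4; k++) {
    vy0[(2 * k) + 0] += (t[k + 0] * v_alpha);
    vy0[(2 * k) + 1] += (t[k + 4] * v_alpha);
  }
}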