Common MMA code.

parent fb287d17fc
commit eb6f3a05ef
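This commit factors the repeated __builtin_mma_xvbf16ger2pp sequences in the bf16 kernels into shared helpers (vec_mult2d_mma, vec_mult4d_mma) and has the wider load-and-multiply routines delegate to their narrower counterparts. For orientation, the sketch below shows the single rank-2 bf16 update that all of these helpers wrap: each xvbf16ger2pp call multiplies a 4x2 bf16 tile by a 2x4 bf16 tile and accumulates into a 4x4 f32 accumulator, which is why the helpers cast bf16 vectors to vec_uc8 before passing them to the built-in. This is an illustrative sample, not code from the commit; it assumes GCC 10+ with -mcpu=power10, and the typedef names mirror (but are not copied from) the surrounding file.

#include <altivec.h>

typedef __vector unsigned char vec_uc8;   /* assumption: matches the file's typedef */
typedef __vector unsigned short vec_bf16; /* assumption: 8 bf16 lanes per 16-byte VSR */
typedef __vector float vec_f32;

/* One bf16 rank-2 update: acc += A (4x2 bf16) * B (2x4 bf16), accumulated in f32. */
static void rank2_update_sketch(vec_bf16 a, vec_bf16 x, vec_f32 out[4])
{
	__vector_quad acc;

	__builtin_mma_xxsetaccz(&acc);                            /* zero the 512-bit accumulator */
	__builtin_mma_xvbf16ger2pp(&acc, (vec_uc8)a, (vec_uc8)x); /* positive-multiply, positive-accumulate */
	__builtin_mma_disassemble_acc((void *)out, &acc);         /* spill the 4x4 f32 tile */
}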
@@ -48,22 +48,20 @@ FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp
 FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
 {
-	vec_bf16 in01 = (vec_bf16)vec_load_vec(in0);
 	vec_bf16 in11 = (vec_bf16)vec_load_vec(in1);
 
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
+	vec_load_mult_mma(out, in0, inp);
 
 	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
 }
 
 FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp)
 {
-	vec_bf16 in01 = (vec_bf16)vec_load_vec(in0);
-	vec_bf16 in11 = (vec_bf16)vec_load_vec(in1);
 	vec_bf16 in21 = (vec_bf16)vec_load_vec(in2);
 	vec_bf16 in31 = (vec_bf16)vec_load_vec(in3);
 
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
+	vec_load_mult12a_mma(out, in0, in1, inp);
 
 	__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
 	__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
 }
@@ -78,6 +76,12 @@ FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *
 	__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
 }
 
+FORCEINLINE void vec_mult2d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *inp)
+{
+	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
+	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
+}
+
 FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
 {
 	vec_bf16 in01[2], in11[2];
@@ -85,10 +89,8 @@ FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
 	vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0);
 	vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1);
 
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
+	vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0);
+	vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1);
 }
 
 FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
@@ -100,26 +102,22 @@ FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
 	vec_load_pair((vec_f32 *)in21, (vec_f32 *)in2);
 	vec_load_pair((vec_f32 *)in31, (vec_f32 *)in3);
 
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
-	__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]);
-	__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]);
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
-	__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]);
-	__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]);
+	vec_mult2d_mma(out + 0, in01 + 0, in11 + 0, inp + 0);
+	vec_mult2d_mma(out + 2, in21 + 0, in31 + 0, inp + 0);
+	vec_mult2d_mma(out + 0, in01 + 1, in11 + 1, inp + 1);
+	vec_mult2d_mma(out + 2, in21 + 1, in31 + 1, inp + 1);
 }
 
 FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
 {
-	vec_bf16 in0[4];
+	vec_bf16 in0[2];
 
-	vec_load_pair2(in0, in);
+	vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 2));
 
-	__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]);
-	__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
-	__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[2], (vec_uc8)inp[2]);
-	__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[3], (vec_uc8)inp[3]);
+	vec_load_mult2_mma(out, in + 0, inp + 0);
+
+	__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[2]);
+	__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[3]);
 }
 
 FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
@@ -129,14 +127,16 @@ FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
 	vec_load_pair2(in01, in0);
 	vec_load_pair2(in11, in1);
 
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]);
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]);
+	vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0);
+	vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1);
+	vec_mult2d_mma(out, in01 + 2, in11 + 2, inp + 2);
+	vec_mult2d_mma(out, in01 + 3, in11 + 3, inp + 3);
 }
 
+FORCEINLINE void vec_mult4d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *in21, vec_bf16 *in31, vec_bf16 *inp)
+{
+	vec_mult2d_mma(out + 0, in01, in11, inp);
+	vec_mult2d_mma(out + 2, in21, in31, inp);
+}
+
 FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
@@ -148,22 +148,10 @@ FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
 	vec_load_pair2(in21, in2);
 	vec_load_pair2(in31, in3);
 
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
-	__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]);
-	__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]);
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
-	__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]);
-	__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]);
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]);
-	__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[2], (vec_uc8)inp[2]);
-	__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[2], (vec_uc8)inp[2]);
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]);
-	__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[3], (vec_uc8)inp[3]);
-	__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[3], (vec_uc8)inp[3]);
+	vec_mult4d_mma(out, in01 + 0, in11 + 0, in21 + 0, in31 + 0, inp + 0);
+	vec_mult4d_mma(out, in01 + 1, in11 + 1, in21 + 1, in31 + 1, inp + 1);
+	vec_mult4d_mma(out, in01 + 2, in11 + 2, in21 + 2, in31 + 2, inp + 2);
+	vec_mult4d_mma(out, in01 + 3, in11 + 3, in21 + 3, in31 + 3, inp + 3);
 }
 
 FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
@@ -175,22 +163,20 @@ FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 i
 
 FORCEINLINE void vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n)
 {
-	vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n);
 	vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n);
 
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
+	vec_loadN_mult_mma(out, in0, inp, n);
 
 	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
 }
 
 FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n)
 {
-	vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n);
-	vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n);
 	vec_bf16 in21 = (vec_bf16)vec_loadN(in2, n);
 	vec_bf16 in31 = (vec_bf16)vec_loadN(in3, n);
 
-	__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
-	__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
+	vec_loadN_mult12a_mma(out, in0, in1, inp, n);
 
 	__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
 	__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
 }
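A hypothetical caller of the refactored helpers might look like the following sketch. It is not part of this commit: the helper signatures, BLASLONG, and the vector typedefs come from the file above, while mult14_block_sketch, a0..a3, x, and result are invented names for illustration.

/* Hypothetical driver: four accumulators over four row pointers, with the
 * loadN variant covering a tail of fewer than 8 bf16 lanes. */
static void mult14_block_sketch(vec_bf16 *a0, vec_bf16 *a1, vec_bf16 *a2, vec_bf16 *a3,
                                vec_bf16 x, BLASLONG n, vec_f32 result[4][4])
{
	__vector_quad acc[4];

	for (int i = 0; i < 4; i++)
		__builtin_mma_xxsetaccz(&acc[i]);          /* zero all four accumulators */

	if (n >= 8)
		vec_load_mult14_mma(acc, a0, a1, a2, a3, x);     /* full-vector path */
	else
		vec_loadN_mult14_mma(acc, a0, a1, a2, a3, x, n); /* partial tail of n lanes */

	for (int i = 0; i < 4; i++)
		__builtin_mma_disassemble_acc((void *)result[i], &acc[i]); /* spill f32 tiles */
}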