Common MMA code.

This commit is contained in:
Chip Kerchner 2024-09-26 09:28:56 -05:00
parent fb287d17fc
commit eb6f3a05ef
1 changed files with 40 additions and 54 deletions

View File

@ -48,22 +48,20 @@ FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 in
FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp) FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
{ {
vec_bf16 in01 = (vec_bf16)vec_load_vec(in0);
vec_bf16 in11 = (vec_bf16)vec_load_vec(in1); vec_bf16 in11 = (vec_bf16)vec_load_vec(in1);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); vec_load_mult_mma(out, in0, inp);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
} }
FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp) FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp)
{ {
vec_bf16 in01 = (vec_bf16)vec_load_vec(in0);
vec_bf16 in11 = (vec_bf16)vec_load_vec(in1);
vec_bf16 in21 = (vec_bf16)vec_load_vec(in2); vec_bf16 in21 = (vec_bf16)vec_load_vec(in2);
vec_bf16 in31 = (vec_bf16)vec_load_vec(in3); vec_bf16 in31 = (vec_bf16)vec_load_vec(in3);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); vec_load_mult12a_mma(out, in0, in1, inp);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp); __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp); __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
} }
@ -78,6 +76,12 @@ FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]); __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
} }
FORCEINLINE void vec_mult2d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *inp)
{
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
}
FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
{ {
vec_bf16 in01[2], in11[2]; vec_bf16 in01[2], in11[2];
@ -85,10 +89,8 @@ FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0); vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0);
vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1); vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
} }
FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp) FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
@ -100,26 +102,22 @@ FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
vec_load_pair((vec_f32 *)in21, (vec_f32 *)in2); vec_load_pair((vec_f32 *)in21, (vec_f32 *)in2);
vec_load_pair((vec_f32 *)in31, (vec_f32 *)in3); vec_load_pair((vec_f32 *)in31, (vec_f32 *)in3);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); vec_mult2d_mma(out + 0, in01 + 0, in11 + 0, inp + 0);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); vec_mult2d_mma(out + 2, in21 + 0, in31 + 0, inp + 0);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]); vec_mult2d_mma(out + 0, in01 + 1, in11 + 1, inp + 1);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]); vec_mult2d_mma(out + 2, in21 + 1, in31 + 1, inp + 1);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]);
} }
FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
{ {
vec_bf16 in0[4]; vec_bf16 in0[2];
vec_load_pair2(in0, in); vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 2));
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]); vec_load_mult2_mma(out, in + 0, inp + 0);
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[2], (vec_uc8)inp[2]); __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[2]);
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[3], (vec_uc8)inp[3]); __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[3]);
} }
FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
@ -129,14 +127,16 @@ FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
vec_load_pair2(in01, in0); vec_load_pair2(in01, in0);
vec_load_pair2(in11, in1); vec_load_pair2(in11, in1);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); vec_mult2d_mma(out, in01 + 2, in11 + 2, inp + 2);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); vec_mult2d_mma(out, in01 + 3, in11 + 3, inp + 3);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]); }
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]); FORCEINLINE void vec_mult4d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *in21, vec_bf16 *in31, vec_bf16 *inp)
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]); {
vec_mult2d_mma(out + 0, in01, in11, inp);
vec_mult2d_mma(out + 2, in21, in31, inp);
} }
FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp) FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
@ -148,22 +148,10 @@ FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
vec_load_pair2(in21, in2); vec_load_pair2(in21, in2);
vec_load_pair2(in31, in3); vec_load_pair2(in31, in3);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); vec_mult4d_mma(out, in01 + 0, in11 + 0, in21 + 0, in31 + 0, inp + 0);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); vec_mult4d_mma(out, in01 + 1, in11 + 1, in21 + 1, in31 + 1, inp + 1);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]); vec_mult4d_mma(out, in01 + 2, in11 + 2, in21 + 2, in31 + 2, inp + 2);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]); vec_mult4d_mma(out, in01 + 3, in11 + 3, in21 + 3, in31 + 3, inp + 3);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[2], (vec_uc8)inp[2]);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[2], (vec_uc8)inp[2]);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[3], (vec_uc8)inp[3]);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[3], (vec_uc8)inp[3]);
} }
FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n) FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
@ -175,22 +163,20 @@ FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 i
FORCEINLINE void vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n) FORCEINLINE void vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n)
{ {
vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n);
vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n); vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); vec_loadN_mult_mma(out, in0, inp, n);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
} }
FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n) FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n)
{ {
vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n);
vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n);
vec_bf16 in21 = (vec_bf16)vec_loadN(in2, n); vec_bf16 in21 = (vec_bf16)vec_loadN(in2, n);
vec_bf16 in31 = (vec_bf16)vec_loadN(in3, n); vec_bf16 in31 = (vec_bf16)vec_loadN(in3, n);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); vec_loadN_mult12a_mma(out, in0, in1, inp, n);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp); __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp); __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
} }