Common code.

This commit is contained in:
Chip Kerchner 2024-09-25 16:31:36 -05:00
parent 8ab6245771
commit fb287d17fc
3 changed files with 262 additions and 121 deletions

View File

@ -33,6 +33,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define USE_MERGE_MMA #define USE_MERGE_MMA
#endif #endif
FORCEINLINE void vec_load_pair2(vec_bf16 *in0, vec_bf16 *in)
{
vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0));
vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2));
}
FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp) FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
{ {
vec_bf16 in0 = (vec_bf16)vec_load_vec(in); vec_bf16 in0 = (vec_bf16)vec_load_vec(in);
@ -40,6 +46,28 @@ FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 in
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp); __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp);
} }
FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
{
vec_bf16 in01 = (vec_bf16)vec_load_vec(in0);
vec_bf16 in11 = (vec_bf16)vec_load_vec(in1);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
}
FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp)
{
vec_bf16 in01 = (vec_bf16)vec_load_vec(in0);
vec_bf16 in11 = (vec_bf16)vec_load_vec(in1);
vec_bf16 in21 = (vec_bf16)vec_load_vec(in2);
vec_bf16 in31 = (vec_bf16)vec_load_vec(in3);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
}
FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
{ {
vec_bf16 in0[2]; vec_bf16 in0[2];
@ -50,6 +78,94 @@ FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]); __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
} }
FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
{
vec_bf16 in01[2], in11[2];
vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0);
vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
}
FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
{
vec_bf16 in01[2], in11[2], in21[2], in31[2];
vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0);
vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1);
vec_load_pair((vec_f32 *)in21, (vec_f32 *)in2);
vec_load_pair((vec_f32 *)in31, (vec_f32 *)in3);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]);
}
FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
{
vec_bf16 in0[4];
vec_load_pair2(in0, in);
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[2], (vec_uc8)inp[2]);
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[3], (vec_uc8)inp[3]);
}
FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
{
vec_bf16 in01[4], in11[4];
vec_load_pair2(in01, in0);
vec_load_pair2(in11, in1);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]);
}
FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
{
vec_bf16 in01[4], in11[4], in21[4], in31[4];
vec_load_pair2(in01, in0);
vec_load_pair2(in11, in1);
vec_load_pair2(in21, in2);
vec_load_pair2(in31, in3);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[2], (vec_uc8)inp[2]);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[2], (vec_uc8)inp[2]);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[3], (vec_uc8)inp[3]);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[3], (vec_uc8)inp[3]);
}
FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n) FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
{ {
vec_bf16 in0 = vec_loadN(in, n); vec_bf16 in0 = vec_loadN(in, n);
@ -57,6 +173,28 @@ FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 i
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp); __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp);
} }
FORCEINLINE void vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n)
{
vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n);
vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
}
FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n)
{
vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n);
vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n);
vec_bf16 in21 = (vec_bf16)vec_loadN(in2, n);
vec_bf16 in31 = (vec_bf16)vec_loadN(in3, n);
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
}
FORCEINLINE void vec_mult1_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp) FORCEINLINE void vec_mult1_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
{ {
vec_bf16 in00 = vec_mergeh(in0, in0); vec_bf16 in00 = vec_mergeh(in0, in0);

View File

@ -119,12 +119,12 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
if (n > 4) { if (n > 4) {
vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0[ 0], n); vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0[ 0], n);
BLASLONG n3 = n & 3; n &= 3;
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n);
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n);
} else if (n) { } else if (n) {
vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0[ 0], n); vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0[ 0], n);
@ -213,12 +213,12 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
if (n > 4) { if (n > 4) {
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
BLASLONG n3 = n & 3; n &= 3;
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n);
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n);
} else if (n) { } else if (n) {
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
@ -318,12 +318,12 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);
BLASLONG n3 = n & 3; n &= 3;
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n);
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n);
} else if (n) { } else if (n) {
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);
@ -445,12 +445,12 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n); vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n);
vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n); vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n);
BLASLONG n3 = n & 3; n &= 3;
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n);
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n);
} else if (n) { } else if (n) {
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);

View File

@ -49,7 +49,7 @@ static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
vec_bf16 *va0, *v_x; vec_bf16 *va0, *v_x;
__vector_quad temp0; __vector_quad temp0;
vec_f32 temp00[4]; vec_f32 temp00[4];
vec_bf16 inp[2]; vec_bf16 inp[4];
__builtin_mma_xxsetaccz(&temp0); __builtin_mma_xxsetaccz(&temp0);
@ -59,10 +59,18 @@ static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
BLASLONG n8 = n / 8; BLASLONG n8 = n / 8;
BLASLONG i = 0; BLASLONG i = 0;
for (; i + 2 <= n8; i += 2) { for (; i + 4 <= n8; i += 4) {
vec_load_pair2(inp, &v_x[i]);
vec_load_mult4_mma(&temp0, &va0[i + 0], inp);
}
if (n8 & 2) {
vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]);
vec_load_mult2_mma(&temp0, &va0[i + 0], inp); vec_load_mult2_mma(&temp0, &va0[i + 0], inp);
i += 2;
} }
if (n8 & 1) { if (n8 & 1) {
@ -89,12 +97,12 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
{ {
IFLOAT *a0, *a1; IFLOAT *a0, *a1;
vec_bf16 *va0, *va1, *v_x; vec_bf16 *va0, *va1, *v_x;
__vector_quad temp0, temp1; __vector_quad temp0[2];
vec_f32 temp00[4], temp01[4]; vec_f32 temp00[4*2];
vec_bf16 inp[2]; vec_bf16 inp[4];
__builtin_mma_xxsetaccz(&temp0); __builtin_mma_xxsetaccz(&temp0[0]);
__builtin_mma_xxsetaccz(&temp1); __builtin_mma_xxsetaccz(&temp0[1]);
a0 = ap; a0 = ap;
a1 = ap + lda; a1 = ap + lda;
@ -104,18 +112,24 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
BLASLONG n8 = n / 8; BLASLONG n8 = n / 8;
BLASLONG i = 0; BLASLONG i = 0;
for (; i + 2 <= n8; i += 2) { for (; i + 4 <= n8; i += 4) {
vec_load_pair2(inp, &v_x[i]);
vec_load_mult42_mma(&temp0[0], &va0[i + 0], &va1[i + 0], inp);
}
if (n8 & 2) {
vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]);
vec_load_mult2_mma(&temp0, &va0[i + 0], inp); vec_load_mult22_mma(&temp0[0], &va0[i + 0], &va1[i + 0], inp);
vec_load_mult2_mma(&temp1, &va1[i + 0], inp);
i += 2;
} }
if (n8 & 1) { if (n8 & 1) {
inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); inp[0] = (vec_bf16)vec_load_vec(&v_x[i]);
vec_load_mult_mma(&temp0, &va0[i], inp[0]); vec_load_mult12a_mma(&temp0[0], &va0[i], &va1[i], inp[0]);
vec_load_mult_mma(&temp1, &va1[i], inp[0]);
i++; i++;
} }
@ -124,29 +138,28 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
if (n) { if (n) {
inp[0] = vec_loadN(&v_x[i], n); inp[0] = vec_loadN(&v_x[i], n);
vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); vec_loadN_mult12a_mma(&temp0[0], &va0[i], &va1[i], inp[0], n);
vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n);
} }
__builtin_mma_disassemble_acc((void*)temp00, &temp0); __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]);
__builtin_mma_disassemble_acc((void*)temp01, &temp1); __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]);
y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]); y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]);
y[1] = (alpha * (temp01[0][0] + temp01[1][1] + temp01[2][2] + temp01[3][3])) + (beta * y[1]); y[1] = (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3])) + (beta * y[1]);
} }
static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
{ {
IFLOAT *a0, *a1, *a2, *a3; IFLOAT *a0, *a1, *a2, *a3;
vec_bf16 *va0, *va1, *va2, *va3, *v_x; vec_bf16 *va0, *va1, *va2, *va3, *v_x;
__vector_quad temp0, temp1, temp2, temp3; __vector_quad temp0[4];
vec_f32 temp00[4], temp01[4], temp02[4], temp03[4]; vec_f32 temp00[4*4];
vec_bf16 inp[2]; vec_bf16 inp[4];
__builtin_mma_xxsetaccz(&temp0); __builtin_mma_xxsetaccz(&temp0[0]);
__builtin_mma_xxsetaccz(&temp1); __builtin_mma_xxsetaccz(&temp0[1]);
__builtin_mma_xxsetaccz(&temp2); __builtin_mma_xxsetaccz(&temp0[2]);
__builtin_mma_xxsetaccz(&temp3); __builtin_mma_xxsetaccz(&temp0[3]);
a0 = ap; a0 = ap;
a1 = ap + lda; a1 = ap + lda;
@ -160,22 +173,24 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
BLASLONG n8 = n / 8; BLASLONG n8 = n / 8;
BLASLONG i = 0; BLASLONG i = 0;
for (; i + 2 <= n8; i += 2) { for (; i + 4 <= n8; i += 4) {
vec_load_pair2(inp, &v_x[i]);
vec_load_mult44_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp);
}
if (n8 & 2) {
vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]);
vec_load_mult2_mma(&temp0, &va0[i + 0], inp); vec_load_mult24_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp);
vec_load_mult2_mma(&temp1, &va1[i + 0], inp);
vec_load_mult2_mma(&temp2, &va2[i + 0], inp); i += 2;
vec_load_mult2_mma(&temp3, &va3[i + 0], inp);
} }
if (n8 & 1) { if (n8 & 1) {
inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); inp[0] = (vec_bf16)vec_load_vec(&v_x[i]);
vec_load_mult_mma(&temp0, &va0[i], inp[0]); vec_load_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0]);
vec_load_mult_mma(&temp1, &va1[i], inp[0]);
vec_load_mult_mma(&temp2, &va2[i], inp[0]);
vec_load_mult_mma(&temp3, &va3[i], inp[0]);
i++; i++;
} }
@ -184,30 +199,27 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
if (n) { if (n) {
inp[0] = vec_loadN(&v_x[i], n); inp[0] = vec_loadN(&v_x[i], n);
vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); vec_loadN_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0], n);
vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n);
vec_loadN_mult_mma(&temp2, &va2[i], inp[0], n);
vec_loadN_mult_mma(&temp3, &va3[i], inp[0], n);
} }
__builtin_mma_disassemble_acc((void*)temp00, &temp0); __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]);
__builtin_mma_disassemble_acc((void*)temp01, &temp1); __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]);
__builtin_mma_disassemble_acc((void*)temp02, &temp2); __builtin_mma_disassemble_acc((void*)(temp00 + 8), &temp0[2]);
__builtin_mma_disassemble_acc((void*)temp03, &temp3); __builtin_mma_disassemble_acc((void*)(temp00 + 12), &temp0[3]);
vec_f32 t0, t1, t2, t3, t4, t5, t6, t7; vec_f32 t0, t1, t2, t3, t4, t5, t6, t7;
vec_f32 a = { alpha, alpha, alpha, alpha }; vec_f32 a = { alpha, alpha, alpha, alpha };
vec_f32 b = { beta, beta, beta, beta }; vec_f32 b = { beta, beta, beta, beta };
vec_f32 *v_y = (vec_f32 *) y; vec_f32 *v_y = (vec_f32 *) y;
t0 = vec_mergeh(temp00[0], temp01[0]); t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
t1 = vec_mergeh(temp02[0], temp03[0]); t1 = vec_mergeh(temp00[ 8], temp00[12]);
t2 = vec_mergeo(temp00[1], temp01[1]); t2 = vec_mergeo(temp00[ 1], temp00[ 5]);
t3 = vec_mergeo(temp02[1], temp03[1]); t3 = vec_mergeo(temp00[ 9], temp00[13]);
t4 = vec_mergel(temp00[2], temp01[2]); t4 = vec_mergel(temp00[ 2], temp00[ 6]);
t5 = vec_mergel(temp02[2], temp03[2]); t5 = vec_mergel(temp00[10], temp00[14]);
t6 = vec_mergeo(temp00[3], temp01[3]); t6 = vec_mergeo(temp00[ 3], temp00[ 7]);
t7 = vec_mergeo(temp02[3], temp03[3]); t7 = vec_mergeo(temp00[11], temp00[15]);
t0 = vec_xxpermdi(t0, t1, 0); t0 = vec_xxpermdi(t0, t1, 0);
t2 = vec_xxpermdi(t2, t3, 0); t2 = vec_xxpermdi(t2, t3, 0);
t4 = vec_xxpermdi(t4, t5, 0); t4 = vec_xxpermdi(t4, t5, 0);
@ -223,18 +235,18 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
{ {
IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x; vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x;
__vector_quad temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7; __vector_quad temp0[8];
vec_f32 temp00[4], temp01[4], temp02[4], temp03[4], temp04[4], temp05[4], temp06[4], temp07[4]; vec_f32 temp00[4*8];
vec_bf16 inp[2]; vec_bf16 inp[4];
__builtin_mma_xxsetaccz(&temp0); __builtin_mma_xxsetaccz(&temp0[0]);
__builtin_mma_xxsetaccz(&temp1); __builtin_mma_xxsetaccz(&temp0[1]);
__builtin_mma_xxsetaccz(&temp2); __builtin_mma_xxsetaccz(&temp0[2]);
__builtin_mma_xxsetaccz(&temp3); __builtin_mma_xxsetaccz(&temp0[3]);
__builtin_mma_xxsetaccz(&temp4); __builtin_mma_xxsetaccz(&temp0[4]);
__builtin_mma_xxsetaccz(&temp5); __builtin_mma_xxsetaccz(&temp0[5]);
__builtin_mma_xxsetaccz(&temp6); __builtin_mma_xxsetaccz(&temp0[6]);
__builtin_mma_xxsetaccz(&temp7); __builtin_mma_xxsetaccz(&temp0[7]);
a0 = ap; a0 = ap;
a1 = ap + lda; a1 = ap + lda;
@ -256,30 +268,27 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
BLASLONG n8 = n / 8; BLASLONG n8 = n / 8;
BLASLONG i = 0; BLASLONG i = 0;
for (; i + 2 <= n8; i += 2) { for (; i + 4 <= n8; i += 4) {
vec_load_pair2(inp, &v_x[i]);
vec_load_mult44_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp);
vec_load_mult44_mma(&temp0[4], &va4[i + 0], &va5[i + 0], &va6[i + 0], &va7[i + 0], inp);
}
if (n8 & 2) {
vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]);
vec_load_mult2_mma(&temp0, &va0[i + 0], inp); vec_load_mult24_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp);
vec_load_mult2_mma(&temp1, &va1[i + 0], inp); vec_load_mult24_mma(&temp0[4], &va4[i + 0], &va5[i + 0], &va6[i + 0], &va7[i + 0], inp);
vec_load_mult2_mma(&temp2, &va2[i + 0], inp);
vec_load_mult2_mma(&temp3, &va3[i + 0], inp); i += 2;
vec_load_mult2_mma(&temp4, &va4[i + 0], inp);
vec_load_mult2_mma(&temp5, &va5[i + 0], inp);
vec_load_mult2_mma(&temp6, &va6[i + 0], inp);
vec_load_mult2_mma(&temp7, &va7[i + 0], inp);
} }
if (n8 & 1) { if (n8 & 1) {
inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); inp[0] = (vec_bf16)vec_load_vec(&v_x[i]);
vec_load_mult_mma(&temp0, &va0[i], inp[0]); vec_load_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0]);
vec_load_mult_mma(&temp1, &va1[i], inp[0]); vec_load_mult14_mma(&temp0[4], &va4[i], &va5[i], &va6[i], &va7[i], inp[0]);
vec_load_mult_mma(&temp2, &va2[i], inp[0]);
vec_load_mult_mma(&temp3, &va3[i], inp[0]);
vec_load_mult_mma(&temp4, &va4[i], inp[0]);
vec_load_mult_mma(&temp5, &va5[i], inp[0]);
vec_load_mult_mma(&temp6, &va6[i], inp[0]);
vec_load_mult_mma(&temp7, &va7[i], inp[0]);
i++; i++;
} }
@ -288,38 +297,32 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
if (n) { if (n) {
inp[0] = vec_loadN(&v_x[i], n); inp[0] = vec_loadN(&v_x[i], n);
vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); vec_loadN_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0], n);
vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n); vec_loadN_mult14_mma(&temp0[4], &va4[i], &va5[i], &va6[i], &va7[i], inp[0], n);
vec_loadN_mult_mma(&temp2, &va2[i], inp[0], n);
vec_loadN_mult_mma(&temp3, &va3[i], inp[0], n);
vec_loadN_mult_mma(&temp4, &va4[i], inp[0], n);
vec_loadN_mult_mma(&temp5, &va5[i], inp[0], n);
vec_loadN_mult_mma(&temp6, &va6[i], inp[0], n);
vec_loadN_mult_mma(&temp7, &va7[i], inp[0], n);
} }
__builtin_mma_disassemble_acc((void*)temp00, &temp0); __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]);
__builtin_mma_disassemble_acc((void*)temp01, &temp1); __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]);
__builtin_mma_disassemble_acc((void*)temp02, &temp2); __builtin_mma_disassemble_acc((void*)(temp00 + 8), &temp0[2]);
__builtin_mma_disassemble_acc((void*)temp03, &temp3); __builtin_mma_disassemble_acc((void*)(temp00 + 12), &temp0[3]);
__builtin_mma_disassemble_acc((void*)temp04, &temp4); __builtin_mma_disassemble_acc((void*)(temp00 + 16), &temp0[4]);
__builtin_mma_disassemble_acc((void*)temp05, &temp5); __builtin_mma_disassemble_acc((void*)(temp00 + 20), &temp0[5]);
__builtin_mma_disassemble_acc((void*)temp06, &temp6); __builtin_mma_disassemble_acc((void*)(temp00 + 24), &temp0[6]);
__builtin_mma_disassemble_acc((void*)temp07, &temp7); __builtin_mma_disassemble_acc((void*)(temp00 + 28), &temp0[7]);
vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17; vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17;
vec_f32 a = { alpha, alpha, alpha, alpha }; vec_f32 a = { alpha, alpha, alpha, alpha };
vec_f32 b = { beta, beta, beta, beta }; vec_f32 b = { beta, beta, beta, beta };
vec_f32 *v_y = (vec_f32 *) y; vec_f32 *v_y = (vec_f32 *) y;
t0 = vec_mergeh(temp00[0], temp01[0]); t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
t1 = vec_mergeh(temp02[0], temp03[0]); t1 = vec_mergeh(temp00[ 8], temp00[12]);
t2 = vec_mergeo(temp00[1], temp01[1]); t2 = vec_mergeo(temp00[ 1], temp00[ 5]);
t3 = vec_mergeo(temp02[1], temp03[1]); t3 = vec_mergeo(temp00[ 9], temp00[13]);
t4 = vec_mergel(temp00[2], temp01[2]); t4 = vec_mergel(temp00[ 2], temp00[ 6]);
t5 = vec_mergel(temp02[2], temp03[2]); t5 = vec_mergel(temp00[10], temp00[14]);
t6 = vec_mergeo(temp00[3], temp01[3]); t6 = vec_mergeo(temp00[ 3], temp00[ 7]);
t7 = vec_mergeo(temp02[3], temp03[3]); t7 = vec_mergeo(temp00[11], temp00[15]);
t0 = vec_xxpermdi(t0, t1, 0); t0 = vec_xxpermdi(t0, t1, 0);
t2 = vec_xxpermdi(t2, t3, 0); t2 = vec_xxpermdi(t2, t3, 0);
t4 = vec_xxpermdi(t4, t5, 0); t4 = vec_xxpermdi(t4, t5, 0);
@ -327,14 +330,14 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
t0 += t2 + t4 + t6; t0 += t2 + t4 + t6;
t10 = vec_mergeh(temp04[0], temp05[0]); t10 = vec_mergeh(temp00[16], temp00[20]);
t11 = vec_mergeh(temp06[0], temp07[0]); t11 = vec_mergeh(temp00[24], temp00[28]);
t12 = vec_mergeo(temp04[1], temp05[1]); t12 = vec_mergeo(temp00[17], temp00[21]);
t13 = vec_mergeo(temp06[1], temp07[1]); t13 = vec_mergeo(temp00[25], temp00[29]);
t14 = vec_mergel(temp04[2], temp05[2]); t14 = vec_mergel(temp00[18], temp00[22]);
t15 = vec_mergel(temp06[2], temp07[2]); t15 = vec_mergel(temp00[26], temp00[30]);
t16 = vec_mergeo(temp04[3], temp05[3]); t16 = vec_mergeo(temp00[19], temp00[23]);
t17 = vec_mergeo(temp06[3], temp07[3]); t17 = vec_mergeo(temp00[27], temp00[31]);
t10 = vec_xxpermdi(t10, t11, 0); t10 = vec_xxpermdi(t10, t11, 0);
t12 = vec_xxpermdi(t12, t13, 0); t12 = vec_xxpermdi(t12, t13, 0);
t14 = vec_xxpermdi(t14, t15, 0); t14 = vec_xxpermdi(t14, t15, 0);