From fb287d17fc1a5f53920f0b8c29ba476b258950bc Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Wed, 25 Sep 2024 16:31:36 -0500 Subject: [PATCH] Common code. --- kernel/power/sbgemv_common_power10.c | 138 +++++++++++++++++ kernel/power/sbgemv_n_power10.c | 24 +-- kernel/power/sbgemv_t_power10.c | 221 ++++++++++++++------------- 3 files changed, 262 insertions(+), 121 deletions(-) diff --git a/kernel/power/sbgemv_common_power10.c b/kernel/power/sbgemv_common_power10.c index d24a98418..638e2655c 100644 --- a/kernel/power/sbgemv_common_power10.c +++ b/kernel/power/sbgemv_common_power10.c @@ -33,6 +33,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define USE_MERGE_MMA #endif +FORCEINLINE void vec_load_pair2(vec_bf16 *in0, vec_bf16 *in) +{ + vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0)); + vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2)); +} + FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp) { vec_bf16 in0 = (vec_bf16)vec_load_vec(in); @@ -40,6 +46,28 @@ FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 in __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp); } +FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp) +{ + vec_bf16 in01 = (vec_bf16)vec_load_vec(in0); + vec_bf16 in11 = (vec_bf16)vec_load_vec(in1); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); +} + +FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp) +{ + vec_bf16 in01 = (vec_bf16)vec_load_vec(in0); + vec_bf16 in11 = (vec_bf16)vec_load_vec(in1); + vec_bf16 in21 = (vec_bf16)vec_load_vec(in2); + vec_bf16 in31 = (vec_bf16)vec_load_vec(in3); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp); +} + FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) { vec_bf16 in0[2]; @@ -50,6 +78,94 @@ FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 * __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]); } +FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) +{ + vec_bf16 in01[2], in11[2]; + + vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0); + vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); +} + +FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp) +{ + vec_bf16 in01[2], in11[2], in21[2], in31[2]; + + vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0); + vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1); + vec_load_pair((vec_f32 *)in21, (vec_f32 *)in2); + vec_load_pair((vec_f32 *)in31, (vec_f32 *)in3); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]); +} + +FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) +{ + vec_bf16 in0[4]; + + vec_load_pair2(in0, in); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[3], (vec_uc8)inp[3]); +} + +FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) +{ + vec_bf16 in01[4], in11[4]; + + vec_load_pair2(in01, in0); + vec_load_pair2(in11, in1); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]); +} + +FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp) +{ + vec_bf16 in01[4], in11[4], in21[4], in31[4]; + + vec_load_pair2(in01, in0); + vec_load_pair2(in11, in1); + vec_load_pair2(in21, in2); + vec_load_pair2(in31, in3); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[3], (vec_uc8)inp[3]); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[3], (vec_uc8)inp[3]); +} + FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n) { vec_bf16 in0 = vec_loadN(in, n); @@ -57,6 +173,28 @@ FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 i __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp); } +FORCEINLINE void vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n); + vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); +} + +FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n); + vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n); + vec_bf16 in21 = (vec_bf16)vec_loadN(in2, n); + vec_bf16 in31 = (vec_bf16)vec_loadN(in3, n); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp); +} + FORCEINLINE void vec_mult1_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp) { vec_bf16 in00 = vec_mergeh(in0, in0); diff --git a/kernel/power/sbgemv_n_power10.c b/kernel/power/sbgemv_n_power10.c index f33a246a9..b1dcb2fcc 100644 --- a/kernel/power/sbgemv_n_power10.c +++ b/kernel/power/sbgemv_n_power10.c @@ -119,12 +119,12 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA if (n > 4) { vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0[ 0], n); - BLASLONG n3 = n & 3; - vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + n &= 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); - vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n); } else if (n) { vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0[ 0], n); @@ -213,12 +213,12 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA if (n > 4) { vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); - BLASLONG n3 = n & 3; - vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + n &= 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); - vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n); } else if (n) { vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); @@ -318,12 +318,12 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); - BLASLONG n3 = n & 3; - vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + n &= 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); - vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n); } else if (n) { vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); @@ -445,12 +445,12 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n); vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n); - BLASLONG n3 = n & 3; - vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + n &= 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); - vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n); } else if (n) { vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); diff --git a/kernel/power/sbgemv_t_power10.c b/kernel/power/sbgemv_t_power10.c index 810287e89..9a5c54f12 100644 --- a/kernel/power/sbgemv_t_power10.c +++ b/kernel/power/sbgemv_t_power10.c @@ -49,7 +49,7 @@ static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_bf16 *va0, *v_x; __vector_quad temp0; vec_f32 temp00[4]; - vec_bf16 inp[2]; + vec_bf16 inp[4]; __builtin_mma_xxsetaccz(&temp0); @@ -59,10 +59,18 @@ static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL BLASLONG n8 = n / 8; BLASLONG i = 0; - for (; i + 2 <= n8; i += 2) { + for (; i + 4 <= n8; i += 4) { + vec_load_pair2(inp, &v_x[i]); + + vec_load_mult4_mma(&temp0, &va0[i + 0], inp); + } + + if (n8 & 2) { vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); vec_load_mult2_mma(&temp0, &va0[i + 0], inp); + + i += 2; } if (n8 & 1) { @@ -89,12 +97,12 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL { IFLOAT *a0, *a1; vec_bf16 *va0, *va1, *v_x; - __vector_quad temp0, temp1; - vec_f32 temp00[4], temp01[4]; - vec_bf16 inp[2]; + __vector_quad temp0[2]; + vec_f32 temp00[4*2]; + vec_bf16 inp[4]; - __builtin_mma_xxsetaccz(&temp0); - __builtin_mma_xxsetaccz(&temp1); + __builtin_mma_xxsetaccz(&temp0[0]); + __builtin_mma_xxsetaccz(&temp0[1]); a0 = ap; a1 = ap + lda; @@ -104,18 +112,24 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL BLASLONG n8 = n / 8; BLASLONG i = 0; - for (; i + 2 <= n8; i += 2) { + for (; i + 4 <= n8; i += 4) { + vec_load_pair2(inp, &v_x[i]); + + vec_load_mult42_mma(&temp0[0], &va0[i + 0], &va1[i + 0], inp); + } + + if (n8 & 2) { vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); - vec_load_mult2_mma(&temp0, &va0[i + 0], inp); - vec_load_mult2_mma(&temp1, &va1[i + 0], inp); + vec_load_mult22_mma(&temp0[0], &va0[i + 0], &va1[i + 0], inp); + + i += 2; } if (n8 & 1) { inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); - vec_load_mult_mma(&temp0, &va0[i], inp[0]); - vec_load_mult_mma(&temp1, &va1[i], inp[0]); + vec_load_mult12a_mma(&temp0[0], &va0[i], &va1[i], inp[0]); i++; } @@ -124,29 +138,28 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL if (n) { inp[0] = vec_loadN(&v_x[i], n); - vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); - vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n); + vec_loadN_mult12a_mma(&temp0[0], &va0[i], &va1[i], inp[0], n); } - __builtin_mma_disassemble_acc((void*)temp00, &temp0); - __builtin_mma_disassemble_acc((void*)temp01, &temp1); + __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]); + __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]); y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]); - y[1] = (alpha * (temp01[0][0] + temp01[1][1] + temp01[2][2] + temp01[3][3])) + (beta * y[1]); + y[1] = (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3])) + (beta * y[1]); } static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) { IFLOAT *a0, *a1, *a2, *a3; vec_bf16 *va0, *va1, *va2, *va3, *v_x; - __vector_quad temp0, temp1, temp2, temp3; - vec_f32 temp00[4], temp01[4], temp02[4], temp03[4]; - vec_bf16 inp[2]; + __vector_quad temp0[4]; + vec_f32 temp00[4*4]; + vec_bf16 inp[4]; - __builtin_mma_xxsetaccz(&temp0); - __builtin_mma_xxsetaccz(&temp1); - __builtin_mma_xxsetaccz(&temp2); - __builtin_mma_xxsetaccz(&temp3); + __builtin_mma_xxsetaccz(&temp0[0]); + __builtin_mma_xxsetaccz(&temp0[1]); + __builtin_mma_xxsetaccz(&temp0[2]); + __builtin_mma_xxsetaccz(&temp0[3]); a0 = ap; a1 = ap + lda; @@ -160,22 +173,24 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL BLASLONG n8 = n / 8; BLASLONG i = 0; - for (; i + 2 <= n8; i += 2) { + for (; i + 4 <= n8; i += 4) { + vec_load_pair2(inp, &v_x[i]); + + vec_load_mult44_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp); + } + + if (n8 & 2) { vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); - vec_load_mult2_mma(&temp0, &va0[i + 0], inp); - vec_load_mult2_mma(&temp1, &va1[i + 0], inp); - vec_load_mult2_mma(&temp2, &va2[i + 0], inp); - vec_load_mult2_mma(&temp3, &va3[i + 0], inp); + vec_load_mult24_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp); + + i += 2; } if (n8 & 1) { inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); - vec_load_mult_mma(&temp0, &va0[i], inp[0]); - vec_load_mult_mma(&temp1, &va1[i], inp[0]); - vec_load_mult_mma(&temp2, &va2[i], inp[0]); - vec_load_mult_mma(&temp3, &va3[i], inp[0]); + vec_load_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0]); i++; } @@ -184,30 +199,27 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL if (n) { inp[0] = vec_loadN(&v_x[i], n); - vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); - vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n); - vec_loadN_mult_mma(&temp2, &va2[i], inp[0], n); - vec_loadN_mult_mma(&temp3, &va3[i], inp[0], n); + vec_loadN_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0], n); } - __builtin_mma_disassemble_acc((void*)temp00, &temp0); - __builtin_mma_disassemble_acc((void*)temp01, &temp1); - __builtin_mma_disassemble_acc((void*)temp02, &temp2); - __builtin_mma_disassemble_acc((void*)temp03, &temp3); + __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]); + __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]); + __builtin_mma_disassemble_acc((void*)(temp00 + 8), &temp0[2]); + __builtin_mma_disassemble_acc((void*)(temp00 + 12), &temp0[3]); vec_f32 t0, t1, t2, t3, t4, t5, t6, t7; vec_f32 a = { alpha, alpha, alpha, alpha }; vec_f32 b = { beta, beta, beta, beta }; vec_f32 *v_y = (vec_f32 *) y; - t0 = vec_mergeh(temp00[0], temp01[0]); - t1 = vec_mergeh(temp02[0], temp03[0]); - t2 = vec_mergeo(temp00[1], temp01[1]); - t3 = vec_mergeo(temp02[1], temp03[1]); - t4 = vec_mergel(temp00[2], temp01[2]); - t5 = vec_mergel(temp02[2], temp03[2]); - t6 = vec_mergeo(temp00[3], temp01[3]); - t7 = vec_mergeo(temp02[3], temp03[3]); + t0 = vec_mergeh(temp00[ 0], temp00[ 4]); + t1 = vec_mergeh(temp00[ 8], temp00[12]); + t2 = vec_mergeo(temp00[ 1], temp00[ 5]); + t3 = vec_mergeo(temp00[ 9], temp00[13]); + t4 = vec_mergel(temp00[ 2], temp00[ 6]); + t5 = vec_mergel(temp00[10], temp00[14]); + t6 = vec_mergeo(temp00[ 3], temp00[ 7]); + t7 = vec_mergeo(temp00[11], temp00[15]); t0 = vec_xxpermdi(t0, t1, 0); t2 = vec_xxpermdi(t2, t3, 0); t4 = vec_xxpermdi(t4, t5, 0); @@ -223,18 +235,18 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL { IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x; - __vector_quad temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7; - vec_f32 temp00[4], temp01[4], temp02[4], temp03[4], temp04[4], temp05[4], temp06[4], temp07[4]; - vec_bf16 inp[2]; + __vector_quad temp0[8]; + vec_f32 temp00[4*8]; + vec_bf16 inp[4]; - __builtin_mma_xxsetaccz(&temp0); - __builtin_mma_xxsetaccz(&temp1); - __builtin_mma_xxsetaccz(&temp2); - __builtin_mma_xxsetaccz(&temp3); - __builtin_mma_xxsetaccz(&temp4); - __builtin_mma_xxsetaccz(&temp5); - __builtin_mma_xxsetaccz(&temp6); - __builtin_mma_xxsetaccz(&temp7); + __builtin_mma_xxsetaccz(&temp0[0]); + __builtin_mma_xxsetaccz(&temp0[1]); + __builtin_mma_xxsetaccz(&temp0[2]); + __builtin_mma_xxsetaccz(&temp0[3]); + __builtin_mma_xxsetaccz(&temp0[4]); + __builtin_mma_xxsetaccz(&temp0[5]); + __builtin_mma_xxsetaccz(&temp0[6]); + __builtin_mma_xxsetaccz(&temp0[7]); a0 = ap; a1 = ap + lda; @@ -256,30 +268,27 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL BLASLONG n8 = n / 8; BLASLONG i = 0; - for (; i + 2 <= n8; i += 2) { + for (; i + 4 <= n8; i += 4) { + vec_load_pair2(inp, &v_x[i]); + + vec_load_mult44_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp); + vec_load_mult44_mma(&temp0[4], &va4[i + 0], &va5[i + 0], &va6[i + 0], &va7[i + 0], inp); + } + + if (n8 & 2) { vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); - vec_load_mult2_mma(&temp0, &va0[i + 0], inp); - vec_load_mult2_mma(&temp1, &va1[i + 0], inp); - vec_load_mult2_mma(&temp2, &va2[i + 0], inp); - vec_load_mult2_mma(&temp3, &va3[i + 0], inp); - vec_load_mult2_mma(&temp4, &va4[i + 0], inp); - vec_load_mult2_mma(&temp5, &va5[i + 0], inp); - vec_load_mult2_mma(&temp6, &va6[i + 0], inp); - vec_load_mult2_mma(&temp7, &va7[i + 0], inp); + vec_load_mult24_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp); + vec_load_mult24_mma(&temp0[4], &va4[i + 0], &va5[i + 0], &va6[i + 0], &va7[i + 0], inp); + + i += 2; } if (n8 & 1) { inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); - vec_load_mult_mma(&temp0, &va0[i], inp[0]); - vec_load_mult_mma(&temp1, &va1[i], inp[0]); - vec_load_mult_mma(&temp2, &va2[i], inp[0]); - vec_load_mult_mma(&temp3, &va3[i], inp[0]); - vec_load_mult_mma(&temp4, &va4[i], inp[0]); - vec_load_mult_mma(&temp5, &va5[i], inp[0]); - vec_load_mult_mma(&temp6, &va6[i], inp[0]); - vec_load_mult_mma(&temp7, &va7[i], inp[0]); + vec_load_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0]); + vec_load_mult14_mma(&temp0[4], &va4[i], &va5[i], &va6[i], &va7[i], inp[0]); i++; } @@ -288,38 +297,32 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL if (n) { inp[0] = vec_loadN(&v_x[i], n); - vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); - vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n); - vec_loadN_mult_mma(&temp2, &va2[i], inp[0], n); - vec_loadN_mult_mma(&temp3, &va3[i], inp[0], n); - vec_loadN_mult_mma(&temp4, &va4[i], inp[0], n); - vec_loadN_mult_mma(&temp5, &va5[i], inp[0], n); - vec_loadN_mult_mma(&temp6, &va6[i], inp[0], n); - vec_loadN_mult_mma(&temp7, &va7[i], inp[0], n); + vec_loadN_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0], n); + vec_loadN_mult14_mma(&temp0[4], &va4[i], &va5[i], &va6[i], &va7[i], inp[0], n); } - __builtin_mma_disassemble_acc((void*)temp00, &temp0); - __builtin_mma_disassemble_acc((void*)temp01, &temp1); - __builtin_mma_disassemble_acc((void*)temp02, &temp2); - __builtin_mma_disassemble_acc((void*)temp03, &temp3); - __builtin_mma_disassemble_acc((void*)temp04, &temp4); - __builtin_mma_disassemble_acc((void*)temp05, &temp5); - __builtin_mma_disassemble_acc((void*)temp06, &temp6); - __builtin_mma_disassemble_acc((void*)temp07, &temp7); + __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]); + __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]); + __builtin_mma_disassemble_acc((void*)(temp00 + 8), &temp0[2]); + __builtin_mma_disassemble_acc((void*)(temp00 + 12), &temp0[3]); + __builtin_mma_disassemble_acc((void*)(temp00 + 16), &temp0[4]); + __builtin_mma_disassemble_acc((void*)(temp00 + 20), &temp0[5]); + __builtin_mma_disassemble_acc((void*)(temp00 + 24), &temp0[6]); + __builtin_mma_disassemble_acc((void*)(temp00 + 28), &temp0[7]); vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17; vec_f32 a = { alpha, alpha, alpha, alpha }; vec_f32 b = { beta, beta, beta, beta }; vec_f32 *v_y = (vec_f32 *) y; - t0 = vec_mergeh(temp00[0], temp01[0]); - t1 = vec_mergeh(temp02[0], temp03[0]); - t2 = vec_mergeo(temp00[1], temp01[1]); - t3 = vec_mergeo(temp02[1], temp03[1]); - t4 = vec_mergel(temp00[2], temp01[2]); - t5 = vec_mergel(temp02[2], temp03[2]); - t6 = vec_mergeo(temp00[3], temp01[3]); - t7 = vec_mergeo(temp02[3], temp03[3]); + t0 = vec_mergeh(temp00[ 0], temp00[ 4]); + t1 = vec_mergeh(temp00[ 8], temp00[12]); + t2 = vec_mergeo(temp00[ 1], temp00[ 5]); + t3 = vec_mergeo(temp00[ 9], temp00[13]); + t4 = vec_mergel(temp00[ 2], temp00[ 6]); + t5 = vec_mergel(temp00[10], temp00[14]); + t6 = vec_mergeo(temp00[ 3], temp00[ 7]); + t7 = vec_mergeo(temp00[11], temp00[15]); t0 = vec_xxpermdi(t0, t1, 0); t2 = vec_xxpermdi(t2, t3, 0); t4 = vec_xxpermdi(t4, t5, 0); @@ -327,14 +330,14 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL t0 += t2 + t4 + t6; - t10 = vec_mergeh(temp04[0], temp05[0]); - t11 = vec_mergeh(temp06[0], temp07[0]); - t12 = vec_mergeo(temp04[1], temp05[1]); - t13 = vec_mergeo(temp06[1], temp07[1]); - t14 = vec_mergel(temp04[2], temp05[2]); - t15 = vec_mergel(temp06[2], temp07[2]); - t16 = vec_mergeo(temp04[3], temp05[3]); - t17 = vec_mergeo(temp06[3], temp07[3]); + t10 = vec_mergeh(temp00[16], temp00[20]); + t11 = vec_mergeh(temp00[24], temp00[28]); + t12 = vec_mergeo(temp00[17], temp00[21]); + t13 = vec_mergeo(temp00[25], temp00[29]); + t14 = vec_mergel(temp00[18], temp00[22]); + t15 = vec_mergel(temp00[26], temp00[30]); + t16 = vec_mergeo(temp00[19], temp00[23]); + t17 = vec_mergeo(temp00[27], temp00[31]); t10 = vec_xxpermdi(t10, t11, 0); t12 = vec_xxpermdi(t12, t13, 0); t14 = vec_xxpermdi(t14, t15, 0);