Almost final code for MMA.

Chip Kerchner 2024-09-24 16:30:01 -05:00
parent 05aa63e738
commit df19375560
2 changed files with 79 additions and 53 deletions

File 1 of 2:

@@ -152,6 +152,14 @@ FORCEINLINE void vec_reduce84_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_a
 	vec_reduce44_mma(&out[0], &temp[0], v_alpha, vy0 + 0);
 	vec_reduce44_mma(&out[1], &temp[4], v_alpha, vy0 + 1);
 }
+
+FORCEINLINE void vec_reduce88_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
+{
+	vec_reduce44_mma(&out[0], &temp[ 0], v_alpha, vy0 + 0);
+	vec_reduce44_mma(&out[1], &temp[ 4], v_alpha, vy0 + 1);
+	vec_reduce44_mma(&out[2], &temp[ 8], v_alpha, vy0 + 8);
+	vec_reduce44_mma(&out[3], &temp[12], v_alpha, vy0 + 9);
+}
 #endif

 FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
@@ -341,6 +349,32 @@ FORCEINLINE void vec_load_mult284b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf
 	vec_mult44b_mma(out, in0 + 0, in1 + 0, inp + 0);
 	vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2);
 }
+
+FORCEINLINE void vec_load_mult288a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
+{
+	vec_bf16 in0[8], in1[8];
+
+	vec_load4_mma(in0 + 0, in1 + 0, ina + 0, inb + 0);
+	vec_load4_mma(in0 + 4, in1 + 4, ina + 4, inb + 4);
+
+	vec_mult44a_mma(out + 0, in0 + 0, in1 + 0, inp + 0);
+	vec_mult44a_mma(out + 2, in0 + 4, in1 + 4, inp + 0);
+	vec_mult44b_mma(out + 0, in0 + 2, in1 + 2, inp + 2);
+	vec_mult44b_mma(out + 2, in0 + 6, in1 + 6, inp + 2);
+}
+
+FORCEINLINE void vec_load_mult288b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
+{
+	vec_bf16 in0[8], in1[8];
+
+	vec_load4_mma(in0 + 0, in1 + 0, ina + 0, inb + 0);
+	vec_load4_mma(in0 + 4, in1 + 4, ina + 4, inb + 4);
+
+	vec_mult44b_mma(out + 0, in0 + 0, in1 + 0, inp + 0);
+	vec_mult44b_mma(out + 2, in0 + 4, in1 + 4, inp + 0);
+	vec_mult44b_mma(out + 0, in0 + 2, in1 + 2, inp + 2);
+	vec_mult44b_mma(out + 2, in0 + 6, in1 + 6, inp + 2);
+}
 #endif

 FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
@@ -381,49 +415,54 @@ FORCEINLINE void vec_store8_pair(vec_f32 *v_y, vec_f32 *vy0)
 }

 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-#define VEC_SHIFT(data, shift) vec_sld(data, data, 16 - shift)
+#define VEC_SHIFT(data, shift) vec_sldw(data, data, 4 - shift)
+
+#define MASK_0 0xf000
+#define MASK_1 0x0f00
+#define MASK_2 0x00f0
+#define MASK_3 0x000f
 #else
-#define VEC_SHIFT(data, shift) vec_sld(data, data, shift)
+#define VEC_SHIFT(data, shift) vec_sldw(data, data, shift)
+
+#define MASK_0 0x000f
+#define MASK_1 0x00f0
+#define MASK_2 0x0f00
+#define MASK_3 0xf000
 #endif

-typedef __vector unsigned int vec_ui32;
-
-static vec_ui32 mask_0 = { 0xffffffff, 0x00000000, 0x00000000, 0x00000000 };
-static vec_ui32 mask_1 = { 0x00000000, 0xffffffff, 0x00000000, 0x00000000 };
-static vec_ui32 mask_2 = { 0x00000000, 0x00000000, 0xffffffff, 0x00000000 };
-static vec_ui32 mask_3 = { 0x00000000, 0x00000000, 0x00000000, 0xffffffff };
-
-FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0)
+FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0, const bool mask)
 {
-	v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)mask_0);
+	if (mask) {
+		v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_0));
+	}

-	v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 4);
-	v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 8);
-	v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 12);
+	v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 1);
+	v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 2);
+	v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 3);
 }

 FORCEINLINE void vec_make_mult2(vec_bf16 *v_x0)
 {
-	v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)mask_1);
-	vec_make_mult1(v_x0);
+	v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_1));
+	vec_make_mult1(v_x0, true);

-	v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 12);
-	v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 4);
-	v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 8);
+	v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 3);
+	v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 1);
+	v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 2);
 }

 FORCEINLINE void vec_make_mult4(vec_bf16 *v_x0)
 {
-	v_x0[10] = vec_and(v_x0[0], (vec_bf16)mask_2);
-	v_x0[15] = vec_and(v_x0[0], (vec_bf16)mask_3);
+	v_x0[10] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_2));
+	v_x0[15] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_3));
 	vec_make_mult2(v_x0);

-	v_x0[ 8] = VEC_SHIFT(v_x0[10], 8);
-	v_x0[ 9] = VEC_SHIFT(v_x0[10], 12);
-	v_x0[11] = VEC_SHIFT(v_x0[10], 4);
-	v_x0[12] = VEC_SHIFT(v_x0[15], 4);
-	v_x0[13] = VEC_SHIFT(v_x0[15], 8);
-	v_x0[14] = VEC_SHIFT(v_x0[15], 12);
+	v_x0[ 8] = VEC_SHIFT(v_x0[10], 2);
+	v_x0[ 9] = VEC_SHIFT(v_x0[10], 3);
+	v_x0[11] = VEC_SHIFT(v_x0[10], 1);
+	v_x0[12] = VEC_SHIFT(v_x0[15], 1);
+	v_x0[13] = VEC_SHIFT(v_x0[15], 2);
+	v_x0[14] = VEC_SHIFT(v_x0[15], 3);
 }

 #endif
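Note on the change above: the vec_make_mult* helpers now build the replicated x vectors from Power10 vec_genbm() byte masks plus word-granularity vec_sldw() rotations, instead of statically defined word masks and byte-wise vec_sld(). Below is a minimal standalone sketch of that mask-and-rotate idea; it is not taken from this commit, uses placeholder byte data, and assumes a little-endian Power10 toolchain (e.g. gcc -mcpu=power10).

/* Sketch only: isolate the first 4-byte lane of a vector with a vec_genbm()
 * byte mask, then rotate it into the other three word slots with vec_sldw(),
 * the same pattern vec_make_mult1() applies to the replicated x values. */
#include <altivec.h>
#include <stdio.h>
#include <string.h>

typedef __vector unsigned char vec_u8;

static void dump(const char *name, vec_u8 v)
{
    unsigned char b[16];
    memcpy(b, &v, sizeof b);
    printf("%-5s:", name);
    for (int i = 0; i < 16; i++)
        printf(" %2d", b[i]);
    printf("\n");
}

int main(void)
{
    vec_u8 x = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };

    /* 0x000f sets bits 0..3, so vec_genbm() yields all-ones in byte elements
     * 0..3 (the little-endian MASK_0 above): only word 0 of x survives. */
    vec_u8 lane0 = vec_and(x, vec_genbm(0x000f));

    /* With a single non-zero word, rotations by 1, 2 and 3 words are enough
     * to place that word in each of the remaining three slots, which is why
     * vec_make_mult1() needs only three VEC_SHIFT calls after the mask. */
    dump("lane0", lane0);
    dump("rot1", vec_sldw(lane0, lane0, 1));
    dump("rot2", vec_sldw(lane0, lane0, 2));
    dump("rot3", vec_sldw(lane0, lane0, 3));
    return 0;
}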

File 2 of 2:

@@ -28,9 +28,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #ifndef SBGEMV_N_MMA_C
 #define SBGEMV_N_MMA_C

-#if !defined(_AIX) || defined(__clang__)
 #define USE_BFGEMV_N_MMA
-#endif

 #ifdef USE_BFGEMV_N_MMA
 #include "sbgemv_common_power10.c"
@@ -67,7 +65,7 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
 	v_x0[0] = vec_loadN(x_bf, 1);
 	vec_f32 vy0[2*4*2];

-	vec_make_mult1(v_x0);
+	vec_make_mult1(v_x0, false);

 	for (; i + 8 <= n8; i += 8) {
 		vec_load8_pair(vy0, &v_y[(i * 2) + 0]);
@@ -75,8 +73,7 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
 		vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]);
 		vec_load_mult184_mma(&temp[2], &va0[i + 4], &v_x0[ 0]);

-		vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
-		vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8);
+		vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);

 		vec_store8_pair(&v_y[(i * 2) + 0], vy0);
 	}
@@ -163,16 +160,14 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
 	vec_f32 vy0[2*4*2];
 	v_x0[0] = vec_loadN(x_bf, 2);

-	vec_make_mult1(v_x0);
+	vec_make_mult1(v_x0, false);

 	for (; i + 8 <= n8; i += 8) {
 		vec_load8_pair(vy0, &v_y[(i * 2) + 0]);

-		vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
-		vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]);
+		vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);

-		vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
-		vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8);
+		vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);

 		vec_store8_pair(&v_y[(i * 2) + 0], vy0);
 	}
@@ -268,13 +263,10 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
 	for (; i + 8 <= n8; i += 8) {
 		vec_load8_pair(vy0, &v_y[(i * 2) + 0]);

-		vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
-		vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);
-		vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]);
-		vec_load_mult284b_mma(&temp[2], &va2[i + 4], &va3[i + 4], &v_x0[ 4]);
+		vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
+		vec_load_mult288b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);

-		vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
-		vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8);
+		vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);

 		vec_store8_pair(&v_y[(i * 2) + 0], vy0);
 	}
@@ -386,17 +378,12 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
 	for (; i + 8 <= n8; i += 8) {
 		vec_load8_pair(vy0, &v_y[(i * 2) + 0]);

-		vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
-		vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);
-		vec_load_mult284b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]);
-		vec_load_mult284b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]);
-		vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]);
-		vec_load_mult284b_mma(&temp[2], &va2[i + 4], &va3[i + 4], &v_x0[ 4]);
-		vec_load_mult284b_mma(&temp[2], &vb0[i + 4], &vb1[i + 4], &v_x0[ 8]);
-		vec_load_mult284b_mma(&temp[2], &vb2[i + 4], &vb3[i + 4], &v_x0[12]);
+		vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
+		vec_load_mult288b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);
+		vec_load_mult288b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]);
+		vec_load_mult288b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]);

-		vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
-		vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8);
+		vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);

 		vec_store8_pair(&v_y[(i * 2) + 0], vy0);
 	}
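Note on the kernel loops above: each BF16GEMV_N inner iteration now drives all four accumulators with a single vec_load_mult288a/b_mma call per pair of matrix columns and reduces them with one vec_reduce88_mma, instead of handling the 8-column step as two separate halves through the 284/84 helpers. These helpers rest on the general MMA accumulate-then-disassemble pattern; below is a minimal standalone sketch of that pattern, not taken from this commit, with placeholder zero data, assuming a Power10 compiler with MMA support (e.g. gcc -mcpu=power10).

/* Sketch only: clear one MMA accumulator, feed it a bf16 outer-product
 * update, and split it back into four float vectors.  A real kernel keeps
 * several accumulators live and adds the disassembled partial sums, scaled
 * by alpha, into y. */
#include <altivec.h>
#include <stdio.h>
#include <string.h>

typedef __vector unsigned char vec_u8;
typedef __vector float vec_f32;

int main(void)
{
    /* Two 16-byte tiles of bf16 data (8 values each); zeros as placeholders,
     * where a real kernel loads the A panel and the replicated x vector. */
    vec_u8 a_tile = vec_splats((unsigned char)0);
    vec_u8 x_tile = vec_splats((unsigned char)0);

    __vector_quad acc;
    __builtin_mma_xxsetaccz(&acc);                    /* acc = 0                */
    __builtin_mma_xvbf16ger2pp(&acc, a_tile, x_tile); /* acc += a_tile * x_tile */

    vec_f32 rows[4];
    __builtin_mma_disassemble_acc(rows, &acc);        /* four f32 partial sums  */

    float r0[4];
    memcpy(r0, &rows[0], sizeof r0);
    printf("%f %f %f %f\n", (double)r0[0], (double)r0[1], (double)r0[2], (double)r0[3]);
    return 0;
}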