diff --git a/Makefile.system b/Makefile.system
index 2c5ca9690..8c030842a 100644
--- a/Makefile.system
+++ b/Makefile.system
@@ -282,6 +282,7 @@ GEMM_GEMV_FORWARD = 1
 endif
 ifeq ($(ARCH), power)
 GEMM_GEMV_FORWARD = 1
+GEMM_GEMV_FORWARD_BF16 = 1
 endif

 ifeq ($(SMALL_MATRIX_OPT), 1)
diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c
index 156eadce7..47de837cc 100644
--- a/kernel/power/sbgemv_common.c
+++ b/kernel/power/sbgemv_common.c
@@ -122,7 +122,10 @@ FORCEINLINE void copy_x(BLASLONG n, IFLOAT *src, IFLOAT *dest, BLASLONG inc_src)
 FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta)
 {
   if (beta == 0) {
-    memset(dest, 0, sizeof(FLOAT) * n);
+    for (BLASLONG i = 0; i < n; i++) {
+      *dest++ = (FLOAT)0;
+      src += inc_src;
+    }
   } else if (beta == 1) {
     for (BLASLONG i = 0; i < n; i++) {
       *dest++ = *src;
@@ -163,4 +166,64 @@ FORCEINLINE void move_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest)
     dest += inc_dest;
   }
 }
+
+static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta)
+{
+  if (beta == 0) {
+    memset(output_vector, 0, sizeof(FLOAT) * n);
+  } else if (beta == 1) {
+    if (output_vector != input_vector) {
+      memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
+    }
+  } else {
+    vec_f32 b = { beta, beta, beta, beta };
+
+    vec_f32 *in = (vec_f32 *)input_vector;
+    vec_f32 *out = (vec_f32 *)output_vector;
+
+    BLASLONG n8 = n / 8;
+    BLASLONG i = 0;
+    vec_f32 v_inp0[2];
+
+    for (; i + 4 <= n8; i += 4) {
+      vec_f32 v_inp1[2], v_inp2[2], v_inp3[2];
+      vec_load_pair(v_inp0, &in[(i * 2) + 0]);
+      vec_load_pair(v_inp1, &in[(i * 2) + 2]);
+      vec_load_pair(v_inp2, &in[(i * 2) + 4]);
+      vec_load_pair(v_inp3, &in[(i * 2) + 6]);
+      v_inp0[0] *= b;
+      v_inp0[1] *= b;
+      v_inp1[0] *= b;
+      v_inp1[1] *= b;
+      v_inp2[0] *= b;
+      v_inp2[1] *= b;
+      v_inp3[0] *= b;
+      v_inp3[1] *= b;
+      vec_store_pair(&out[(i * 2) + 0], v_inp0);
+      vec_store_pair(&out[(i * 2) + 2], v_inp1);
+      vec_store_pair(&out[(i * 2) + 4], v_inp2);
+      vec_store_pair(&out[(i * 2) + 6], v_inp3);
+    }
+
+    for (; i < n8; i++) {
+      vec_load_pair(v_inp0, &in[(i * 2) + 0]);
+      v_inp0[0] *= b;
+      v_inp0[1] *= b;
+      vec_store_pair(&out[(i * 2) + 0], v_inp0);
+    }
+
+    n &= 7;
+    if (n > 4) {
+      BLASLONG n3 = n & 3;
+      vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
+      v_inp0[0] *= b;
+      v_inp0[1] *= b;
+      vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
+    } else if (n) {
+      v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
+      v_inp0[0] *= b;
+      vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n);
+    }
+  }
+}
 #endif
diff --git a/kernel/power/sbgemv_n.c b/kernel/power/sbgemv_n.c
index eab0b4e33..e6f7f587e 100644
--- a/kernel/power/sbgemv_n.c
+++ b/kernel/power/sbgemv_n.c
@@ -27,65 +27,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #ifndef SBGEMV_N_COMMON_C
 #define SBGEMV_N_COMMON_C

-static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta)
-{
-  if (beta == 0) {
-    memset(output_vector, 0, sizeof(FLOAT) * n);
-  } else if (beta == 1) {
-    if (output_vector != input_vector) {
-      memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
-    }
-  } else {
-    vec_f32 b = { beta, beta, beta, beta };
-
-    vec_f32 *in = (vec_f32 *)input_vector;
-    vec_f32 *out = (vec_f32 *)output_vector;
-
-    BLASLONG n8 = n / 8;
-    BLASLONG i = 0;
-    vec_f32 v_inp0[2];
-
-    for (; i + 4 <= n8; i += 4) {
-      vec_f32 v_inp1[2], v_inp2[2], v_inp3[2];
-      vec_load_pair(v_inp0, &in[(i * 2) + 0]);
-      vec_load_pair(v_inp1, &in[(i * 2) + 2]);
-      vec_load_pair(v_inp2, &in[(i * 2) + 4]);
-      vec_load_pair(v_inp3, &in[(i * 2) + 6]);
-      v_inp0[0] *= b;
-      v_inp0[1] *= b;
-      v_inp1[0] *= b;
-      v_inp1[1] *= b;
-      v_inp2[0] *= b;
-      v_inp2[1] *= b;
-      v_inp3[0] *= b;
-      v_inp3[1] *= b;
-      vec_store_pair(&out[(i * 2) + 0], v_inp0);
-      vec_store_pair(&out[(i * 2) + 2], v_inp1);
-      vec_store_pair(&out[(i * 2) + 4], v_inp2);
-      vec_store_pair(&out[(i * 2) + 6], v_inp3);
-    }
-
-    for (; i < n8; i++) {
-      vec_load_pair(v_inp0, &in[(i * 2) + 0]);
-      v_inp0[0] *= b;
-      v_inp0[1] *= b;
-      vec_store_pair(&out[(i * 2) + 0], v_inp0);
-    }
-
-    n &= 7;
-    if (n > 4) {
-      BLASLONG n3 = n & 3;
-      vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
-      v_inp0[0] *= b;
-      v_inp0[1] *= b;
-      vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
-    } else if (n) {
-      v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
-      v_inp0[0] *= b;
-      vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n);
-    }
-  }
-}

 #if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_N_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_N_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_N_VSX))
 #define USE_N_8
diff --git a/kernel/power/sbgemv_t.c b/kernel/power/sbgemv_t.c
index c6fdb6b1a..594b1fc57 100644
--- a/kernel/power/sbgemv_t.c
+++ b/kernel/power/sbgemv_t.c
@@ -41,6 +41,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
   if ((m < 1) || (n < 1)) return 0;

+  if (inc_y == 1) {
+    BF16GEMV_N_beta(n, y, y, beta);
+  }
+
   xbuffer = buffer;

   BLASLONG lda4 = lda << 2;

@@ -58,18 +62,21 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
     }

     a_ptr = a;
+    a += NB;
     y_ptr = y;

     if (inc_x != 1) {
       copy_x(NB, x, xbuffer, inc_x);
+      x += NB * inc_x;
     } else {
       xbuffer = x;
+      x += NB;
     }

     if (inc_y == 1) {
 #ifdef USE_T_8
       for (BLASLONG j = 0; j + 8 <= n; j += 8) {
-        BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
+        BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
         y_ptr += 8;
         a_ptr += lda8;
       }
@@ -77,23 +84,23 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
 #else
       for (BLASLONG j = 0; j + 4 <= n; j += 4) {
 #endif
-        BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
+        BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
         y_ptr += 4;
         a_ptr += lda4;
       }
       if (n & 2) {
-        BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
+        BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
         y_ptr += 2;
         a_ptr += (lda * 2);
       }
       if (n & 1) {
-        BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
+        BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
       }
     } else {
 #ifdef USE_T_8
       for (BLASLONG j = 0; j + 8 <= n; j += 8) {
         memset(ybuffer, 0, sizeof(FLOAT) * 8);
-        BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
+        BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
         copy_y(8, ybuffer, y_ptr, inc_y, beta);
         y_ptr += 8 * inc_y;
         a_ptr += lda8;
@@ -103,28 +110,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
       for (BLASLONG j = 0; j + 4 <= n; j += 4) {
 #endif
         memset(ybuffer, 0, sizeof(FLOAT) * 4);
-        BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
+        BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
         copy_y(4, ybuffer, y_ptr, inc_y, beta);
         y_ptr += 4 * inc_y;
         a_ptr += lda4;
       }
       if (n & 2) {
         memset(ybuffer, 0, sizeof(FLOAT) * 4);
-        BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
+        BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
         copy_y(2, ybuffer, y_ptr, inc_y, beta);
         y_ptr += 2 * inc_y;
         a_ptr += (lda * 2);
       }
       if (n & 1) {
         memset(ybuffer, 0, sizeof(FLOAT) * 4);
-        BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
+        BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
         copy_y(1, ybuffer, y_ptr, inc_y, beta);
       }
+      beta = (FLOAT)1;
     }
-
-    a += NB;
-    x += NB * inc_x;
-    beta = (FLOAT)1;
   }

   return 0;
diff --git a/kernel/power/sbgemv_t_power10.c b/kernel/power/sbgemv_t_power10.c
index d2f6087f0..40c166354 100644
--- a/kernel/power/sbgemv_t_power10.c
+++ b/kernel/power/sbgemv_t_power10.c
@@ -43,7 +43,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 #define USE_BFGEMV_8_T_MMA

-static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0;
   vec_bf16 *va0, *v_x;
@@ -90,10 +90,10 @@ static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL

   __builtin_mma_disassemble_acc((void*)temp00, &temp0);

-  y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]);
+  y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3]));
 }

-static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0, *a1;
   vec_bf16 *va0, *va1, *v_x;
@@ -142,11 +142,11 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL

   vec_reduce_2(temp00, &temp0[0]);

-  y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]);
-  y[1] = (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3])) + (beta * y[1]);
+  y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3]));
+  y[1] += (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3]));
 }

-static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0, *a1, *a2, *a3;
   vec_bf16 *va0, *va1, *va2, *va3, *v_x;
@@ -201,7 +201,6 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL

   vec_f32 t0, t1, t2, t3, t4, t5, t6, t7;
   vec_f32 a = { alpha, alpha, alpha, alpha };
-  vec_f32 b = { beta, beta, beta, beta };
   vec_f32 *v_y = (vec_f32 *) y;

   t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
@@ -219,11 +218,11 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL

   t0 += t2 + t4 + t6;

-  v_y[0] = (a * t0) + (b * v_y[0]);
+  v_y[0] += (a * t0);
 }

 #ifdef USE_BFGEMV_8_T_MMA
-static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
   vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x;
@@ -291,7 +290,6 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL

   vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17;
   vec_f32 a = { alpha, alpha, alpha, alpha };
-  vec_f32 b = { beta, beta, beta, beta };
   vec_f32 *v_y = (vec_f32 *) y;

   t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
@@ -326,8 +324,8 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
   vec_f32 inp2[2];

   vec_load_pair(inp2, v_y);
-  inp2[0] = (a * t0) + (b * inp2[0]);
-  inp2[1] = (a * t10) + (b * inp2[1]);
+  inp2[0] += (a * t0);
+  inp2[1] += (a * t10);
   vec_store_pair(v_y, inp2);
 }
 #endif
diff --git a/kernel/power/sbgemv_t_vsx.c b/kernel/power/sbgemv_t_vsx.c
index 9d5e6d997..e72d2f31e 100644
--- a/kernel/power/sbgemv_t_vsx.c
+++ b/kernel/power/sbgemv_t_vsx.c
@@ -40,7 +40,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 #define USE_BFGEMV_8_T_VSX

-static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0;
   vec_bf16 *va0, *v_x;
@@ -71,10 +71,10 @@ static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
     temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero);
   }

-  y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]);
+  y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3]));
 }

-static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0, *a1;
   vec_bf16 *va0, *va1, *v_x;
@@ -111,11 +111,11 @@ static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
     temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero);
   }

-  y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]);
-  y[1] = (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3])) + (beta * y[1]);
+  y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3]));
+  y[1] += (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3]));
 }

-static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0, *a1, *a2, *a3;
   vec_bf16 *va0, *va1, *va2, *va3, *v_x;
@@ -166,7 +166,6 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL

   vec_f32 t0, t1, t2, t3;
   vec_f32 a = { alpha, alpha, alpha, alpha };
-  vec_f32 b = { beta, beta, beta, beta };
   vec_f32 *v_y = (vec_f32 *) y;

   t0 = vec_mergeh(temp0, temp2);
@@ -179,11 +178,11 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
   temp3 = vec_mergel(t1, t3);
   temp0 += temp1 + temp2 + temp3;

-  v_y[0] = (a * temp0) + (b * v_y[0]);
+  v_y[0] += (a * temp0);
 }

 #ifdef USE_BFGEMV_8_T_VSX
-static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
   vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x;
@@ -259,7 +258,6 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL

   vec_f32 t0, t1, t2, t3, t10, t11, t12, t13;
   vec_f32 a = { alpha, alpha, alpha, alpha };
-  vec_f32 b = { beta, beta, beta, beta };
   vec_f32 *v_y = (vec_f32 *) y;

   t0 = vec_mergeh(temp0, temp2);
@@ -283,8 +281,8 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
   temp4 += temp5 + temp6 + temp7;

   vec_load_pair(inp, v_y);
-  inp[0] = (a * temp0) + (b * inp[0]);
-  inp[1] = (a * temp4) + (b * inp[1]);
+  inp[0] += (a * temp0);
+  inp[1] += (a * temp4);
   vec_store_pair(v_y, inp);
 }
 #endif
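Note on the shape of this change: the patch hoists the beta handling of y out of every per-width BF16GEMV_T_* kernel. For unit-stride y, BF16GEMV_N_beta (moved into sbgemv_common.c so the transposed path can use it) scales y once before the blocked loop over m; for strided y, beta stays inside copy_y for the first block only, after which beta is reset to 1. Each kernel then performs a pure accumulation, y[j] += alpha * dot(A_col_j, x), and no longer needs a beta argument. Below is a minimal scalar sketch of that structure; the names scale_y and sbgemv_t_ref are hypothetical reference code, not OpenBLAS API, and the real kernels read bfloat16 inputs and use VSX/MMA vector arithmetic.

#include <stddef.h>
#include <stdio.h>
#include <string.h>

/* Apply beta to y exactly once, up front (the role of BF16GEMV_N_beta).
 * Per BLAS convention, beta == 0 overwrites y rather than multiplying it. */
static void scale_y(size_t n, float *y, float beta)
{
  if (beta == 0.0f) {
    memset(y, 0, sizeof(float) * n);
  } else if (beta != 1.0f) {
    for (size_t i = 0; i < n; i++)
      y[i] *= beta;
  } /* beta == 1: leave y untouched */
}

/* Reference transposed GEMV, y := beta*y + alpha*A^T*x, with A column-major
 * m x n.  After the single scale_y pass, each column reduces to the same
 * pure accumulation the patched kernels now perform. */
static void sbgemv_t_ref(size_t m, size_t n, float alpha, const float *a,
                         size_t lda, const float *x, float beta, float *y)
{
  scale_y(n, y, beta); /* beta applied once, not once per kernel call */
  for (size_t j = 0; j < n; j++) {
    float dot = 0.0f;
    for (size_t i = 0; i < m; i++)
      dot += a[j * lda + i] * x[i];
    y[j] += alpha * dot; /* mirrors "y[0] += (alpha * ...)" in the patch */
  }
}

int main(void)
{
  /* 3x2 column-major A with columns {1,2,3} and {4,5,6}. */
  float a[6] = { 1, 2, 3, 4, 5, 6 };
  float x[3] = { 1, 1, 1 };
  float y[2] = { 10, 20 };
  sbgemv_t_ref(3, 2, 2.0f, a, 3, x, 0.5f, y);
  printf("%g %g\n", y[0], y[1]); /* 0.5*{10,20} + 2*{6,15} = {17,40} */
  return 0;
}

One consequence of this shape is that when m is processed in NB-sized blocks, as CNAME in sbgemv_t.c does, beta must affect only the first block. The diff achieves that by scaling ahead of the loop when inc_y == 1, or by setting beta = (FLOAT)1 immediately after the first strided pass; the old per-kernel (beta * y[i]) term would otherwise rescale partial sums on every block.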