Remove beta from optimized functions.

Chip Kerchner 2024-10-03 13:27:33 -05:00
parent 7cc00f68c9
commit 7ec3c16d82
6 changed files with 101 additions and 96 deletions
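In short: the per-call beta handling that every optimized BF16GEMV_T_* kernel carried is hoisted into a single pre-pass over y (BF16GEMV_N_beta, or copy_y for strided output), so the hot kernels reduce to pure alpha-scaled accumulation. A minimal scalar sketch of the before/after kernel contract — the function names here are illustrative, not from the patch:

/* Old contract: each kernel applied beta on every call. */
static void gemv_t_col_old(long n, const float *a, const float *x,
                           float *y, float alpha, float beta)
{
    float dot = 0.0f;
    for (long i = 0; i < n; i++)
        dot += a[i] * x[i];
    y[0] = (alpha * dot) + (beta * y[0]);
}

/* New contract: y was pre-scaled by beta once up front
 * (BF16GEMV_N_beta in this patch), so the kernel only accumulates. */
static void gemv_t_col_new(long n, const float *a, const float *x,
                           float *y, float alpha)
{
    float dot = 0.0f;
    for (long i = 0; i < n; i++)
        dot += a[i] * x[i];
    y[0] += alpha * dot;
}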

View File

@@ -282,6 +282,7 @@ GEMM_GEMV_FORWARD = 1
endif
ifeq ($(ARCH), power)
GEMM_GEMV_FORWARD = 1
+GEMM_GEMV_FORWARD_BF16 = 1
endif
ifeq ($(SMALL_MATRIX_OPT), 1)
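Build-side note: GEMM_GEMV_FORWARD_BF16 is a plain `=` assignment, so under standard make semantics a command-line override should still take precedence (an assumption about the build, not something this diff shows):

# hypothetical override when comparing against the non-forwarding path
make ARCH=power GEMM_GEMV_FORWARD_BF16=0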

View File

@@ -122,7 +122,10 @@ FORCEINLINE void copy_x(BLASLONG n, IFLOAT *src, IFLOAT *dest, BLASLONG inc_src)
FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta)
{
if (beta == 0) {
-memset(dest, 0, sizeof(FLOAT) * n);
+for (BLASLONG i = 0; i < n; i++) {
+*dest++ = (FLOAT)0;
+src += inc_src;
+}
} else if (beta == 1) {
for (BLASLONG i = 0; i < n; i++) {
*dest++ = *src;
@@ -163,4 +166,64 @@ FORCEINLINE void move_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest)
dest += inc_dest;
}
}
+static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta)
+{
+if (beta == 0) {
+memset(output_vector, 0, sizeof(FLOAT) * n);
+} else if (beta == 1) {
+if (output_vector != input_vector) {
+memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
+}
+} else {
+vec_f32 b = { beta, beta, beta, beta };
+vec_f32 *in = (vec_f32 *)input_vector;
+vec_f32 *out = (vec_f32 *)output_vector;
+BLASLONG n8 = n / 8;
+BLASLONG i = 0;
+vec_f32 v_inp0[2];
+for (; i + 4 <= n8; i += 4) {
+vec_f32 v_inp1[2], v_inp2[2], v_inp3[2];
+vec_load_pair(v_inp0, &in[(i * 2) + 0]);
+vec_load_pair(v_inp1, &in[(i * 2) + 2]);
+vec_load_pair(v_inp2, &in[(i * 2) + 4]);
+vec_load_pair(v_inp3, &in[(i * 2) + 6]);
+v_inp0[0] *= b;
+v_inp0[1] *= b;
+v_inp1[0] *= b;
+v_inp1[1] *= b;
+v_inp2[0] *= b;
+v_inp2[1] *= b;
+v_inp3[0] *= b;
+v_inp3[1] *= b;
+vec_store_pair(&out[(i * 2) + 0], v_inp0);
+vec_store_pair(&out[(i * 2) + 2], v_inp1);
+vec_store_pair(&out[(i * 2) + 4], v_inp2);
+vec_store_pair(&out[(i * 2) + 6], v_inp3);
+}
+for (; i < n8; i++) {
+vec_load_pair(v_inp0, &in[(i * 2) + 0]);
+v_inp0[0] *= b;
+v_inp0[1] *= b;
+vec_store_pair(&out[(i * 2) + 0], v_inp0);
+}
+n &= 7;
+if (n > 4) {
+BLASLONG n3 = n & 3;
+vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
+v_inp0[0] *= b;
+v_inp0[1] *= b;
+vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
+} else if (n) {
+v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
+v_inp0[0] *= b;
+vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n);
+}
+}
+}
#endif
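For reference, a plain-C equivalent of the new BF16GEMV_N_beta above (bf16gemv_n_beta_scalar is illustrative, not part of the patch). The vectorized version performs the same out[i] = beta * in[i] pass, 32 floats per unrolled iteration, with masked loads/stores for the remainder and early-outs for beta == 0 and beta == 1:

#include <string.h>

/* long stands in for BLASLONG and float for FLOAT (an assumption
 * about a typical LP64 single-precision build). */
static void bf16gemv_n_beta_scalar(long n, float *out, const float *in, float beta)
{
    if (beta == 0) {
        memset(out, 0, sizeof(float) * n);                  /* y = 0        */
    } else if (beta == 1) {
        if (out != in)
            memcpy(out, in, sizeof(float) * n);             /* y unchanged  */
    } else {
        for (long i = 0; i < n; i++)
            out[i] = beta * in[i];                          /* y = beta * y */
    }
}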

View File

@@ -27,65 +27,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef SBGEMV_N_COMMON_C
#define SBGEMV_N_COMMON_C
-static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta)
-{
-if (beta == 0) {
-memset(output_vector, 0, sizeof(FLOAT) * n);
-} else if (beta == 1) {
-if (output_vector != input_vector) {
-memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
-}
-} else {
-vec_f32 b = { beta, beta, beta, beta };
-vec_f32 *in = (vec_f32 *)input_vector;
-vec_f32 *out = (vec_f32 *)output_vector;
-BLASLONG n8 = n / 8;
-BLASLONG i = 0;
-vec_f32 v_inp0[2];
-for (; i + 4 <= n8; i += 4) {
-vec_f32 v_inp1[2], v_inp2[2], v_inp3[2];
-vec_load_pair(v_inp0, &in[(i * 2) + 0]);
-vec_load_pair(v_inp1, &in[(i * 2) + 2]);
-vec_load_pair(v_inp2, &in[(i * 2) + 4]);
-vec_load_pair(v_inp3, &in[(i * 2) + 6]);
-v_inp0[0] *= b;
-v_inp0[1] *= b;
-v_inp1[0] *= b;
-v_inp1[1] *= b;
-v_inp2[0] *= b;
-v_inp2[1] *= b;
-v_inp3[0] *= b;
-v_inp3[1] *= b;
-vec_store_pair(&out[(i * 2) + 0], v_inp0);
-vec_store_pair(&out[(i * 2) + 2], v_inp1);
-vec_store_pair(&out[(i * 2) + 4], v_inp2);
-vec_store_pair(&out[(i * 2) + 6], v_inp3);
-}
-for (; i < n8; i++) {
-vec_load_pair(v_inp0, &in[(i * 2) + 0]);
-v_inp0[0] *= b;
-v_inp0[1] *= b;
-vec_store_pair(&out[(i * 2) + 0], v_inp0);
-}
-n &= 7;
-if (n > 4) {
-BLASLONG n3 = n & 3;
-vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
-v_inp0[0] *= b;
-v_inp0[1] *= b;
-vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
-} else if (n) {
-v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
-v_inp0[0] *= b;
-vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n);
-}
-}
-}
#if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_N_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_N_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_N_VSX))
#define USE_N_8

View File

@@ -41,6 +41,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
if ((m < 1) || (n < 1)) return 0;
+if (inc_y == 1) {
+BF16GEMV_N_beta(n, y, y, beta);
+}
xbuffer = buffer;
BLASLONG lda4 = lda << 2;
@@ -58,18 +62,21 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
}
a_ptr = a;
+a += NB;
y_ptr = y;
if (inc_x != 1) {
copy_x(NB, x, xbuffer, inc_x);
+x += NB * inc_x;
} else {
xbuffer = x;
+x += NB;
}
if (inc_y == 1) {
#ifdef USE_T_8
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
-BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
+BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
y_ptr += 8;
a_ptr += lda8;
}
@@ -77,23 +84,23 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
#else
for (BLASLONG j = 0; j + 4 <= n; j += 4) {
#endif
-BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
+BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
y_ptr += 4;
a_ptr += lda4;
}
if (n & 2) {
-BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
+BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
y_ptr += 2;
a_ptr += (lda * 2);
}
if (n & 1) {
-BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
+BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
}
} else {
#ifdef USE_T_8
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
memset(ybuffer, 0, sizeof(FLOAT) * 8);
-BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
+BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
copy_y(8, ybuffer, y_ptr, inc_y, beta);
y_ptr += 8 * inc_y;
a_ptr += lda8;
@@ -103,29 +110,26 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
for (BLASLONG j = 0; j + 4 <= n; j += 4) {
#endif
memset(ybuffer, 0, sizeof(FLOAT) * 4);
-BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
+BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
copy_y(4, ybuffer, y_ptr, inc_y, beta);
y_ptr += 4 * inc_y;
a_ptr += lda4;
}
if (n & 2) {
memset(ybuffer, 0, sizeof(FLOAT) * 4);
-BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
+BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
copy_y(2, ybuffer, y_ptr, inc_y, beta);
y_ptr += 2 * inc_y;
a_ptr += (lda * 2);
}
if (n & 1) {
memset(ybuffer, 0, sizeof(FLOAT) * 4);
-BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
+BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
copy_y(1, ybuffer, y_ptr, inc_y, beta);
}
}
-a += NB;
-x += NB * inc_x;
-beta = (FLOAT)1;
}
}
return 0;
}
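The reshaped driver now applies beta to y exactly once before the block loop, which is why the old end-of-loop beta = (FLOAT)1 reset disappears. A condensed sketch of the control flow, with unit strides and float in place of bf16 (illustrative, not the real CNAME):

/* Phase 1: y = beta * y, done once (BF16GEMV_N_beta in the patch).
 * Phase 2: y[j] += alpha * dot(column j of A, x), accumulated block
 * by block over the rows of A. */
static void sbgemv_t_sketch(long m, long n, float alpha, const float *a,
                            long lda, const float *x, float *y, float beta)
{
    for (long j = 0; j < n; j++)
        y[j] *= beta;
    for (long j = 0; j < n; j++) {
        float dot = 0.0f;
        for (long i = 0; i < m; i++)
            dot += a[i + j * lda] * x[i];
        y[j] += alpha * dot;
    }
}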

View File

@@ -43,7 +43,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define USE_BFGEMV_8_T_MMA
-static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
IFLOAT *a0;
vec_bf16 *va0, *v_x;
@@ -90,10 +90,10 @@ static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
__builtin_mma_disassemble_acc((void*)temp00, &temp0);
-y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]);
+y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3]));
}
-static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
IFLOAT *a0, *a1;
vec_bf16 *va0, *va1, *v_x;
@@ -142,11 +142,11 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
vec_reduce_2(temp00, &temp0[0]);
-y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]);
-y[1] = (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3])) + (beta * y[1]);
+y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3]));
+y[1] += (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3]));
}
-static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
IFLOAT *a0, *a1, *a2, *a3;
vec_bf16 *va0, *va1, *va2, *va3, *v_x;
@@ -201,7 +201,6 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
vec_f32 t0, t1, t2, t3, t4, t5, t6, t7;
vec_f32 a = { alpha, alpha, alpha, alpha };
-vec_f32 b = { beta, beta, beta, beta };
vec_f32 *v_y = (vec_f32 *) y;
t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
@@ -219,11 +218,11 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
t0 += t2 + t4 + t6;
-v_y[0] = (a * t0) + (b * v_y[0]);
+v_y[0] += (a * t0);
}
#ifdef USE_BFGEMV_8_T_MMA
-static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x;
@@ -291,7 +290,6 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17;
vec_f32 a = { alpha, alpha, alpha, alpha };
-vec_f32 b = { beta, beta, beta, beta };
vec_f32 *v_y = (vec_f32 *) y;
t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
@@ -326,8 +324,8 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
vec_f32 inp2[2];
vec_load_pair(inp2, v_y);
-inp2[0] = (a * t0) + (b * inp2[0]);
-inp2[1] = (a * t10) + (b * inp2[1]);
+inp2[0] += (a * t0);
+inp2[1] += (a * t10);
vec_store_pair(v_y, inp2);
}
#endif
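Each MMA accumulator disassembles into a 4x4 tile whose diagonal holds the four partial sums of one dot product; the kernels above sum that diagonal and, after this change, accumulate into y instead of blending with beta * y. A scalar picture of the reduction (reduce_acc_diag is an illustrative name, assuming the tile layout produced by __builtin_mma_disassemble_acc in these kernels):

/* Sum the diagonal of a disassembled 4x4 accumulator tile and scale. */
static float reduce_acc_diag(const float acc[4][4], float alpha)
{
    return alpha * (acc[0][0] + acc[1][1] + acc[2][2] + acc[3][3]);
}

/* New contract:  y[0] += reduce_acc_diag(acc, alpha);
 * Old contract:  y[0]  = reduce_acc_diag(acc, alpha) + beta * y[0]; */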

View File

@@ -40,7 +40,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define USE_BFGEMV_8_T_VSX
-static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
IFLOAT *a0;
vec_bf16 *va0, *v_x;
@@ -71,10 +71,10 @@ static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero);
}
-y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]);
+y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3]));
}
-static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
IFLOAT *a0, *a1;
vec_bf16 *va0, *va1, *v_x;
@@ -111,11 +111,11 @@ static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero);
}
-y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]);
-y[1] = (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3])) + (beta * y[1]);
+y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3]));
+y[1] += (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3]));
}
-static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
IFLOAT *a0, *a1, *a2, *a3;
vec_bf16 *va0, *va1, *va2, *va3, *v_x;
@@ -166,7 +166,6 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
vec_f32 t0, t1, t2, t3;
vec_f32 a = { alpha, alpha, alpha, alpha };
-vec_f32 b = { beta, beta, beta, beta };
vec_f32 *v_y = (vec_f32 *) y;
t0 = vec_mergeh(temp0, temp2);
@@ -179,11 +178,11 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
temp3 = vec_mergel(t1, t3);
temp0 += temp1 + temp2 + temp3;
-v_y[0] = (a * temp0) + (b * v_y[0]);
+v_y[0] += (a * temp0);
}
#ifdef USE_BFGEMV_8_T_VSX
-static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x;
@@ -259,7 +258,6 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
vec_f32 t0, t1, t2, t3, t10, t11, t12, t13;
vec_f32 a = { alpha, alpha, alpha, alpha };
-vec_f32 b = { beta, beta, beta, beta };
vec_f32 *v_y = (vec_f32 *) y;
t0 = vec_mergeh(temp0, temp2);
@@ -283,8 +281,8 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
temp4 += temp5 + temp6 + temp7;
vec_load_pair(inp, v_y);
-inp[0] = (a * temp0) + (b * inp[0]);
-inp[1] = (a * temp4) + (b * inp[1]);
+inp[0] += (a * temp0);
+inp[1] += (a * temp4);
vec_store_pair(v_y, inp);
}
#endif
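In the 4- and 8-column VSX kernels, the vec_mergeh/vec_mergel sequence transposes the per-column partial-sum vectors so that a single vector add finishes all four reductions at once; with beta gone, the result is simply accumulated into v_y. A scalar equivalent of that step (transpose_reduce_4 is an illustrative name):

/* t[j] is column j's 4-lane partial-sum vector; out[j] is the full
 * reduction for column j, as produced by the merge/add sequence above. */
static void transpose_reduce_4(const float t[4][4], float out[4])
{
    for (int j = 0; j < 4; j++)
        out[j] = t[j][0] + t[j][1] + t[j][2] + t[j][3];
}

/* The kernel then performs v_y[0] += a * result elementwise,
 * beta having already been applied to y upstream. */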