Remove beta from optimized functions.

This commit is contained in:
Chip Kerchner 2024-10-03 13:27:33 -05:00
parent 7cc00f68c9
commit 7ec3c16d82
6 changed files with 101 additions and 96 deletions

View File

@ -282,6 +282,7 @@ GEMM_GEMV_FORWARD = 1
endif endif
ifeq ($(ARCH), power) ifeq ($(ARCH), power)
GEMM_GEMV_FORWARD = 1 GEMM_GEMV_FORWARD = 1
GEMM_GEMV_FORWARD_BF16 = 1
endif endif
ifeq ($(SMALL_MATRIX_OPT), 1) ifeq ($(SMALL_MATRIX_OPT), 1)

View File

@ -122,7 +122,10 @@ FORCEINLINE void copy_x(BLASLONG n, IFLOAT *src, IFLOAT *dest, BLASLONG inc_src)
FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta) FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta)
{ {
if (beta == 0) { if (beta == 0) {
memset(dest, 0, sizeof(FLOAT) * n); for (BLASLONG i = 0; i < n; i++) {
*dest++ = (FLOAT)0;
src += inc_src;
}
} else if (beta == 1) { } else if (beta == 1) {
for (BLASLONG i = 0; i < n; i++) { for (BLASLONG i = 0; i < n; i++) {
*dest++ = *src; *dest++ = *src;
@ -163,4 +166,64 @@ FORCEINLINE void move_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest)
dest += inc_dest; dest += inc_dest;
} }
} }
static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta)
{
if (beta == 0) {
memset(output_vector, 0, sizeof(FLOAT) * n);
} else if (beta == 1) {
if (output_vector != input_vector) {
memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
}
} else {
vec_f32 b = { beta, beta, beta, beta };
vec_f32 *in = (vec_f32 *)input_vector;
vec_f32 *out = (vec_f32 *)output_vector;
BLASLONG n8 = n / 8;
BLASLONG i = 0;
vec_f32 v_inp0[2];
for (; i + 4 <= n8; i += 4) {
vec_f32 v_inp1[2], v_inp2[2], v_inp3[2];
vec_load_pair(v_inp0, &in[(i * 2) + 0]);
vec_load_pair(v_inp1, &in[(i * 2) + 2]);
vec_load_pair(v_inp2, &in[(i * 2) + 4]);
vec_load_pair(v_inp3, &in[(i * 2) + 6]);
v_inp0[0] *= b;
v_inp0[1] *= b;
v_inp1[0] *= b;
v_inp1[1] *= b;
v_inp2[0] *= b;
v_inp2[1] *= b;
v_inp3[0] *= b;
v_inp3[1] *= b;
vec_store_pair(&out[(i * 2) + 0], v_inp0);
vec_store_pair(&out[(i * 2) + 2], v_inp1);
vec_store_pair(&out[(i * 2) + 4], v_inp2);
vec_store_pair(&out[(i * 2) + 6], v_inp3);
}
for (; i < n8; i++) {
vec_load_pair(v_inp0, &in[(i * 2) + 0]);
v_inp0[0] *= b;
v_inp0[1] *= b;
vec_store_pair(&out[(i * 2) + 0], v_inp0);
}
n &= 7;
if (n > 4) {
BLASLONG n3 = n & 3;
vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
v_inp0[0] *= b;
v_inp0[1] *= b;
vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
} else if (n) {
v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
v_inp0[0] *= b;
vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n);
}
}
}
#endif #endif

View File

@ -27,65 +27,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef SBGEMV_N_COMMON_C #ifndef SBGEMV_N_COMMON_C
#define SBGEMV_N_COMMON_C #define SBGEMV_N_COMMON_C
static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta)
{
if (beta == 0) {
memset(output_vector, 0, sizeof(FLOAT) * n);
} else if (beta == 1) {
if (output_vector != input_vector) {
memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
}
} else {
vec_f32 b = { beta, beta, beta, beta };
vec_f32 *in = (vec_f32 *)input_vector;
vec_f32 *out = (vec_f32 *)output_vector;
BLASLONG n8 = n / 8;
BLASLONG i = 0;
vec_f32 v_inp0[2];
for (; i + 4 <= n8; i += 4) {
vec_f32 v_inp1[2], v_inp2[2], v_inp3[2];
vec_load_pair(v_inp0, &in[(i * 2) + 0]);
vec_load_pair(v_inp1, &in[(i * 2) + 2]);
vec_load_pair(v_inp2, &in[(i * 2) + 4]);
vec_load_pair(v_inp3, &in[(i * 2) + 6]);
v_inp0[0] *= b;
v_inp0[1] *= b;
v_inp1[0] *= b;
v_inp1[1] *= b;
v_inp2[0] *= b;
v_inp2[1] *= b;
v_inp3[0] *= b;
v_inp3[1] *= b;
vec_store_pair(&out[(i * 2) + 0], v_inp0);
vec_store_pair(&out[(i * 2) + 2], v_inp1);
vec_store_pair(&out[(i * 2) + 4], v_inp2);
vec_store_pair(&out[(i * 2) + 6], v_inp3);
}
for (; i < n8; i++) {
vec_load_pair(v_inp0, &in[(i * 2) + 0]);
v_inp0[0] *= b;
v_inp0[1] *= b;
vec_store_pair(&out[(i * 2) + 0], v_inp0);
}
n &= 7;
if (n > 4) {
BLASLONG n3 = n & 3;
vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
v_inp0[0] *= b;
v_inp0[1] *= b;
vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
} else if (n) {
v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
v_inp0[0] *= b;
vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n);
}
}
}
#if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_N_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_N_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_N_VSX)) #if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_N_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_N_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_N_VSX))
#define USE_N_8 #define USE_N_8

View File

@ -41,6 +41,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
if ((m < 1) || (n < 1)) return 0; if ((m < 1) || (n < 1)) return 0;
if (inc_y == 1) {
BF16GEMV_N_beta(n, y, y, beta);
}
xbuffer = buffer; xbuffer = buffer;
BLASLONG lda4 = lda << 2; BLASLONG lda4 = lda << 2;
@ -58,18 +62,21 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
} }
a_ptr = a; a_ptr = a;
a += NB;
y_ptr = y; y_ptr = y;
if (inc_x != 1) { if (inc_x != 1) {
copy_x(NB, x, xbuffer, inc_x); copy_x(NB, x, xbuffer, inc_x);
x += NB * inc_x;
} else { } else {
xbuffer = x; xbuffer = x;
x += NB;
} }
if (inc_y == 1) { if (inc_y == 1) {
#ifdef USE_T_8 #ifdef USE_T_8
for (BLASLONG j = 0; j + 8 <= n; j += 8) { for (BLASLONG j = 0; j + 8 <= n; j += 8) {
BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
y_ptr += 8; y_ptr += 8;
a_ptr += lda8; a_ptr += lda8;
} }
@ -77,23 +84,23 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
#else #else
for (BLASLONG j = 0; j + 4 <= n; j += 4) { for (BLASLONG j = 0; j + 4 <= n; j += 4) {
#endif #endif
BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
y_ptr += 4; y_ptr += 4;
a_ptr += lda4; a_ptr += lda4;
} }
if (n & 2) { if (n & 2) {
BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
y_ptr += 2; y_ptr += 2;
a_ptr += (lda * 2); a_ptr += (lda * 2);
} }
if (n & 1) { if (n & 1) {
BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
} }
} else { } else {
#ifdef USE_T_8 #ifdef USE_T_8
for (BLASLONG j = 0; j + 8 <= n; j += 8) { for (BLASLONG j = 0; j + 8 <= n; j += 8) {
memset(ybuffer, 0, sizeof(FLOAT) * 8); memset(ybuffer, 0, sizeof(FLOAT) * 8);
BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
copy_y(8, ybuffer, y_ptr, inc_y, beta); copy_y(8, ybuffer, y_ptr, inc_y, beta);
y_ptr += 8 * inc_y; y_ptr += 8 * inc_y;
a_ptr += lda8; a_ptr += lda8;
@ -103,28 +110,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
for (BLASLONG j = 0; j + 4 <= n; j += 4) { for (BLASLONG j = 0; j + 4 <= n; j += 4) {
#endif #endif
memset(ybuffer, 0, sizeof(FLOAT) * 4); memset(ybuffer, 0, sizeof(FLOAT) * 4);
BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
copy_y(4, ybuffer, y_ptr, inc_y, beta); copy_y(4, ybuffer, y_ptr, inc_y, beta);
y_ptr += 4 * inc_y; y_ptr += 4 * inc_y;
a_ptr += lda4; a_ptr += lda4;
} }
if (n & 2) { if (n & 2) {
memset(ybuffer, 0, sizeof(FLOAT) * 4); memset(ybuffer, 0, sizeof(FLOAT) * 4);
BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
copy_y(2, ybuffer, y_ptr, inc_y, beta); copy_y(2, ybuffer, y_ptr, inc_y, beta);
y_ptr += 2 * inc_y; y_ptr += 2 * inc_y;
a_ptr += (lda * 2); a_ptr += (lda * 2);
} }
if (n & 1) { if (n & 1) {
memset(ybuffer, 0, sizeof(FLOAT) * 4); memset(ybuffer, 0, sizeof(FLOAT) * 4);
BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
copy_y(1, ybuffer, y_ptr, inc_y, beta); copy_y(1, ybuffer, y_ptr, inc_y, beta);
} }
beta = (FLOAT)1;
} }
a += NB;
x += NB * inc_x;
beta = (FLOAT)1;
} }
return 0; return 0;

View File

@ -43,7 +43,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define USE_BFGEMV_8_T_MMA #define USE_BFGEMV_8_T_MMA
static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{ {
IFLOAT *a0; IFLOAT *a0;
vec_bf16 *va0, *v_x; vec_bf16 *va0, *v_x;
@ -90,10 +90,10 @@ static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
__builtin_mma_disassemble_acc((void*)temp00, &temp0); __builtin_mma_disassemble_acc((void*)temp00, &temp0);
y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]); y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3]));
} }
static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{ {
IFLOAT *a0, *a1; IFLOAT *a0, *a1;
vec_bf16 *va0, *va1, *v_x; vec_bf16 *va0, *va1, *v_x;
@ -142,11 +142,11 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
vec_reduce_2(temp00, &temp0[0]); vec_reduce_2(temp00, &temp0[0]);
y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]); y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3]));
y[1] = (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3])) + (beta * y[1]); y[1] += (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3]));
} }
static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{ {
IFLOAT *a0, *a1, *a2, *a3; IFLOAT *a0, *a1, *a2, *a3;
vec_bf16 *va0, *va1, *va2, *va3, *v_x; vec_bf16 *va0, *va1, *va2, *va3, *v_x;
@ -201,7 +201,6 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
vec_f32 t0, t1, t2, t3, t4, t5, t6, t7; vec_f32 t0, t1, t2, t3, t4, t5, t6, t7;
vec_f32 a = { alpha, alpha, alpha, alpha }; vec_f32 a = { alpha, alpha, alpha, alpha };
vec_f32 b = { beta, beta, beta, beta };
vec_f32 *v_y = (vec_f32 *) y; vec_f32 *v_y = (vec_f32 *) y;
t0 = vec_mergeh(temp00[ 0], temp00[ 4]); t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
@ -219,11 +218,11 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
t0 += t2 + t4 + t6; t0 += t2 + t4 + t6;
v_y[0] = (a * t0) + (b * v_y[0]); v_y[0] += (a * t0);
} }
#ifdef USE_BFGEMV_8_T_MMA #ifdef USE_BFGEMV_8_T_MMA
static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{ {
IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x; vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x;
@ -291,7 +290,6 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17; vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17;
vec_f32 a = { alpha, alpha, alpha, alpha }; vec_f32 a = { alpha, alpha, alpha, alpha };
vec_f32 b = { beta, beta, beta, beta };
vec_f32 *v_y = (vec_f32 *) y; vec_f32 *v_y = (vec_f32 *) y;
t0 = vec_mergeh(temp00[ 0], temp00[ 4]); t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
@ -326,8 +324,8 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
vec_f32 inp2[2]; vec_f32 inp2[2];
vec_load_pair(inp2, v_y); vec_load_pair(inp2, v_y);
inp2[0] = (a * t0) + (b * inp2[0]); inp2[0] += (a * t0);
inp2[1] = (a * t10) + (b * inp2[1]); inp2[1] += (a * t10);
vec_store_pair(v_y, inp2); vec_store_pair(v_y, inp2);
} }
#endif #endif

View File

@ -40,7 +40,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define USE_BFGEMV_8_T_VSX #define USE_BFGEMV_8_T_VSX
static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{ {
IFLOAT *a0; IFLOAT *a0;
vec_bf16 *va0, *v_x; vec_bf16 *va0, *v_x;
@ -71,10 +71,10 @@ static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero);
} }
y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]); y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3]));
} }
static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{ {
IFLOAT *a0, *a1; IFLOAT *a0, *a1;
vec_bf16 *va0, *va1, *v_x; vec_bf16 *va0, *va1, *v_x;
@ -111,11 +111,11 @@ static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero); temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero);
} }
y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]); y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3]));
y[1] = (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3])) + (beta * y[1]); y[1] += (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3]));
} }
static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{ {
IFLOAT *a0, *a1, *a2, *a3; IFLOAT *a0, *a1, *a2, *a3;
vec_bf16 *va0, *va1, *va2, *va3, *v_x; vec_bf16 *va0, *va1, *va2, *va3, *v_x;
@ -166,7 +166,6 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
vec_f32 t0, t1, t2, t3; vec_f32 t0, t1, t2, t3;
vec_f32 a = { alpha, alpha, alpha, alpha }; vec_f32 a = { alpha, alpha, alpha, alpha };
vec_f32 b = { beta, beta, beta, beta };
vec_f32 *v_y = (vec_f32 *) y; vec_f32 *v_y = (vec_f32 *) y;
t0 = vec_mergeh(temp0, temp2); t0 = vec_mergeh(temp0, temp2);
@ -179,11 +178,11 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
temp3 = vec_mergel(t1, t3); temp3 = vec_mergel(t1, t3);
temp0 += temp1 + temp2 + temp3; temp0 += temp1 + temp2 + temp3;
v_y[0] = (a * temp0) + (b * v_y[0]); v_y[0] += (a * temp0);
} }
#ifdef USE_BFGEMV_8_T_VSX #ifdef USE_BFGEMV_8_T_VSX
static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{ {
IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x; vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x;
@ -259,7 +258,6 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
vec_f32 t0, t1, t2, t3, t10, t11, t12, t13; vec_f32 t0, t1, t2, t3, t10, t11, t12, t13;
vec_f32 a = { alpha, alpha, alpha, alpha }; vec_f32 a = { alpha, alpha, alpha, alpha };
vec_f32 b = { beta, beta, beta, beta };
vec_f32 *v_y = (vec_f32 *) y; vec_f32 *v_y = (vec_f32 *) y;
t0 = vec_mergeh(temp0, temp2); t0 = vec_mergeh(temp0, temp2);
@ -283,8 +281,8 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
temp4 += temp5 + temp6 + temp7; temp4 += temp5 + temp6 + temp7;
vec_load_pair(inp, v_y); vec_load_pair(inp, v_y);
inp[0] = (a * temp0) + (b * inp[0]); inp[0] += (a * temp0);
inp[1] = (a * temp4) + (b * inp[1]); inp[1] += (a * temp4);
vec_store_pair(v_y, inp); vec_store_pair(v_y, inp);
} }
#endif #endif