Add a special case for beta equal to one.

This commit is contained in:
Chip Kerchner 2024-09-06 14:48:48 -05:00
parent 76227e2948
commit 8541b25e1d
2 changed files with 12 additions and 0 deletions

View File

@ -252,6 +252,11 @@ FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_s
{ {
if (beta == 0) { if (beta == 0) {
memset(dest, 0, sizeof(FLOAT) * n); memset(dest, 0, sizeof(FLOAT) * n);
} else if (beta == 1) {
for (BLASLONG i = 0; i < n; i++) {
*dest++ = *src;
src += inc_src;
}
} else { } else {
for (BLASLONG i = 0; i < n; i++) { for (BLASLONG i = 0; i < n; i++) {
*dest++ = *src * beta; *dest++ = *src * beta;
@ -267,6 +272,11 @@ FORCEINLINE void copy_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, F
*dest = *src++; *dest = *src++;
dest += inc_src; dest += inc_src;
} }
} else if (beta == 1) {
for (BLASLONG i = 0; i < n; i++) {
*dest += *src++;
dest += inc_src;
}
} else { } else {
for (BLASLONG i = 0; i < n; i++) { for (BLASLONG i = 0; i < n; i++) {
*dest = *src++ + (beta * *dest); *dest = *src++ + (beta * *dest);

View File

@ -31,6 +31,8 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto
{ {
if (beta == 0) { if (beta == 0) {
memset(output_vector, 0, sizeof(FLOAT) * n); memset(output_vector, 0, sizeof(FLOAT) * n);
} else if ((output_vector != input_vector) && (beta == 1)) {
memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
} else { } else {
vec_f32 b = { beta, beta, beta, beta }; vec_f32 b = { beta, beta, beta, beta };