Special case beta is one.
This commit is contained in:
parent
76227e2948
commit
8541b25e1d
|
@ -252,6 +252,11 @@ FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_s
|
||||||
{
|
{
|
||||||
if (beta == 0) {
|
if (beta == 0) {
|
||||||
memset(dest, 0, sizeof(FLOAT) * n);
|
memset(dest, 0, sizeof(FLOAT) * n);
|
||||||
|
} else if (beta == 1) {
|
||||||
|
for (BLASLONG i = 0; i < n; i++) {
|
||||||
|
*dest++ = *src;
|
||||||
|
src += inc_src;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
for (BLASLONG i = 0; i < n; i++) {
|
for (BLASLONG i = 0; i < n; i++) {
|
||||||
*dest++ = *src * beta;
|
*dest++ = *src * beta;
|
||||||
|
@ -267,6 +272,11 @@ FORCEINLINE void copy_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, F
|
||||||
*dest = *src++;
|
*dest = *src++;
|
||||||
dest += inc_src;
|
dest += inc_src;
|
||||||
}
|
}
|
||||||
|
} else if (beta == 1) {
|
||||||
|
for (BLASLONG i = 0; i < n; i++) {
|
||||||
|
*dest += *src++;
|
||||||
|
dest += inc_src;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
for (BLASLONG i = 0; i < n; i++) {
|
for (BLASLONG i = 0; i < n; i++) {
|
||||||
*dest = *src++ + (beta * *dest);
|
*dest = *src++ + (beta * *dest);
|
||||||
|
|
|
@ -31,6 +31,8 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto
|
||||||
{
|
{
|
||||||
if (beta == 0) {
|
if (beta == 0) {
|
||||||
memset(output_vector, 0, sizeof(FLOAT) * n);
|
memset(output_vector, 0, sizeof(FLOAT) * n);
|
||||||
|
} else if ((output_vector != input_vector) && (beta == 1)) {
|
||||||
|
memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
|
||||||
} else {
|
} else {
|
||||||
vec_f32 b = { beta, beta, beta, beta };
|
vec_f32 b = { beta, beta, beta, beta };
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue