diff --git a/Makefile.system b/Makefile.system index 8c030842a..2c5ca9690 100644 --- a/Makefile.system +++ b/Makefile.system @@ -282,7 +282,6 @@ GEMM_GEMV_FORWARD = 1 endif ifeq ($(ARCH), power) GEMM_GEMV_FORWARD = 1 -GEMM_GEMV_FORWARD_BF16 = 1 endif ifeq ($(SMALL_MATRIX_OPT), 1) diff --git a/kernel/power/sbgemv_n.c b/kernel/power/sbgemv_n.c index db64915e0..fa7df858f 100644 --- a/kernel/power/sbgemv_n.c +++ b/kernel/power/sbgemv_n.c @@ -31,8 +31,10 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto { if (beta == 0) { memset(output_vector, 0, sizeof(FLOAT) * n); - } else if ((output_vector != input_vector) && (beta == 1)) { - memcpy(output_vector, input_vector, sizeof(FLOAT) * n); + } else if (beta == 1) { + if (output_vector != input_vector) { + memcpy(output_vector, input_vector, sizeof(FLOAT) * n); + } } else { vec_f32 b = { beta, beta, beta, beta }; diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c index a86c73d1c..05d9b33ab 100644 --- a/test/compare_sgemm_sbgemm.c +++ b/test/compare_sgemm_sbgemm.c @@ -205,15 +205,14 @@ main (int argc, char *argv[]) for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one. for (x = 1; x <= loop; x++) { - m = l + 1; - k = (x == 0) ? 0 : m; + k = (x == 0) ? 0 : l + 1; float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); - float *B = (float *)malloc_safe(x * sizeof(FLOAT) * m); - float *C = (float *)malloc_safe(x * sizeof(FLOAT) * m); + float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l); + float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l); bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits)); - bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) * m); + bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) << l); float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); - float *CC = (float *)malloc_safe(x * sizeof(FLOAT) * m); + float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l); if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || (DD == NULL) || (CC == NULL)) return 1; @@ -228,9 +227,9 @@ main (int argc, char *argv[]) sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one); AA[j * x + i].v = atmp; } - B[j*m] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - sbstobf16_(&one, &B[j*m], &one, &btmp, &one); - BB[j*m].v = btmp; + B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + sbstobf16_(&one, &B[j << l], &one, &btmp, &one); + BB[j << l].v = btmp; } for (y = 0; y < 2; y++) { @@ -240,9 +239,9 @@ main (int argc, char *argv[]) transA = 'T'; } - memset(CC, 0, x * m * sizeof(FLOAT)); + memset(CC, 0, x * sizeof(FLOAT) << l); memset(DD, 0, x * sizeof(FLOAT)); - memset(C, 0, x * m * sizeof(FLOAT)); + memset(C, 0, x * sizeof(FLOAT) << l); SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k); SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k); @@ -250,15 +249,15 @@ main (int argc, char *argv[]) for (j = 0; j < x; j++) for (i = 0; i < x; i++) if (transA == 'N') { - DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j*m]); + DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]); } else if (transA == 'T') { - DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i*m]); + DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]); } for (j = 0; j < x; j++) { - if (fabs (CC[j*m] - C[j*m]) > 1.0) + if (fabs (CC[j << l] - C[j << l]) > 1.0) ret++; - if (fabs (CC[j*m] - DD[j]) > 1.0) + if (fabs (CC[j << l] - DD[j]) > 1.0) ret++; } }