Minor improvement and turn off BF16 GEMV forwarding by default.
This commit is contained in:
parent
8541b25e1d
commit
39fd29f1de
|
@ -282,7 +282,6 @@ GEMM_GEMV_FORWARD = 1
|
|||
endif
|
||||
ifeq ($(ARCH), power)
|
||||
GEMM_GEMV_FORWARD = 1
|
||||
GEMM_GEMV_FORWARD_BF16 = 1
|
||||
endif
|
||||
|
||||
ifeq ($(SMALL_MATRIX_OPT), 1)
|
||||
|
|
|
@ -31,8 +31,10 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto
|
|||
{
|
||||
if (beta == 0) {
|
||||
memset(output_vector, 0, sizeof(FLOAT) * n);
|
||||
} else if ((output_vector != input_vector) && (beta == 1)) {
|
||||
memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
|
||||
} else if (beta == 1) {
|
||||
if (output_vector != input_vector) {
|
||||
memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
|
||||
}
|
||||
} else {
|
||||
vec_f32 b = { beta, beta, beta, beta };
|
||||
|
||||
|
|
|
@ -205,15 +205,14 @@ main (int argc, char *argv[])
|
|||
for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one.
|
||||
for (x = 1; x <= loop; x++)
|
||||
{
|
||||
m = l + 1;
|
||||
k = (x == 0) ? 0 : m;
|
||||
k = (x == 0) ? 0 : l + 1;
|
||||
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT));
|
||||
float *B = (float *)malloc_safe(x * sizeof(FLOAT) * m);
|
||||
float *C = (float *)malloc_safe(x * sizeof(FLOAT) * m);
|
||||
float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l);
|
||||
float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l);
|
||||
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits));
|
||||
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) * m);
|
||||
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) << l);
|
||||
float *DD = (float *)malloc_safe(x * sizeof(FLOAT));
|
||||
float *CC = (float *)malloc_safe(x * sizeof(FLOAT) * m);
|
||||
float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l);
|
||||
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
|
||||
(DD == NULL) || (CC == NULL))
|
||||
return 1;
|
||||
|
@ -228,9 +227,9 @@ main (int argc, char *argv[])
|
|||
sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one);
|
||||
AA[j * x + i].v = atmp;
|
||||
}
|
||||
B[j*m] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||
sbstobf16_(&one, &B[j*m], &one, &btmp, &one);
|
||||
BB[j*m].v = btmp;
|
||||
B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||
sbstobf16_(&one, &B[j << l], &one, &btmp, &one);
|
||||
BB[j << l].v = btmp;
|
||||
}
|
||||
for (y = 0; y < 2; y++)
|
||||
{
|
||||
|
@ -240,9 +239,9 @@ main (int argc, char *argv[])
|
|||
transA = 'T';
|
||||
}
|
||||
|
||||
memset(CC, 0, x * m * sizeof(FLOAT));
|
||||
memset(CC, 0, x * sizeof(FLOAT) << l);
|
||||
memset(DD, 0, x * sizeof(FLOAT));
|
||||
memset(C, 0, x * m * sizeof(FLOAT));
|
||||
memset(C, 0, x * sizeof(FLOAT) << l);
|
||||
|
||||
SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k);
|
||||
SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k);
|
||||
|
@ -250,15 +249,15 @@ main (int argc, char *argv[])
|
|||
for (j = 0; j < x; j++)
|
||||
for (i = 0; i < x; i++)
|
||||
if (transA == 'N') {
|
||||
DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j*m]);
|
||||
DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]);
|
||||
} else if (transA == 'T') {
|
||||
DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i*m]);
|
||||
DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]);
|
||||
}
|
||||
|
||||
for (j = 0; j < x; j++) {
|
||||
if (fabs (CC[j*m] - C[j*m]) > 1.0)
|
||||
if (fabs (CC[j << l] - C[j << l]) > 1.0)
|
||||
ret++;
|
||||
if (fabs (CC[j*m] - DD[j]) > 1.0)
|
||||
if (fabs (CC[j << l] - DD[j]) > 1.0)
|
||||
ret++;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue