Minor improvement and turn off BF16 GEMV forwarding by default.
This commit is contained in:
parent
8541b25e1d
commit
39fd29f1de
|
@ -282,7 +282,6 @@ GEMM_GEMV_FORWARD = 1
|
||||||
endif
|
endif
|
||||||
ifeq ($(ARCH), power)
|
ifeq ($(ARCH), power)
|
||||||
GEMM_GEMV_FORWARD = 1
|
GEMM_GEMV_FORWARD = 1
|
||||||
GEMM_GEMV_FORWARD_BF16 = 1
|
|
||||||
endif
|
endif
|
||||||
|
|
||||||
ifeq ($(SMALL_MATRIX_OPT), 1)
|
ifeq ($(SMALL_MATRIX_OPT), 1)
|
||||||
|
|
|
@ -31,8 +31,10 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto
|
||||||
{
|
{
|
||||||
if (beta == 0) {
|
if (beta == 0) {
|
||||||
memset(output_vector, 0, sizeof(FLOAT) * n);
|
memset(output_vector, 0, sizeof(FLOAT) * n);
|
||||||
} else if ((output_vector != input_vector) && (beta == 1)) {
|
} else if (beta == 1) {
|
||||||
memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
|
if (output_vector != input_vector) {
|
||||||
|
memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
vec_f32 b = { beta, beta, beta, beta };
|
vec_f32 b = { beta, beta, beta, beta };
|
||||||
|
|
||||||
|
|
|
@ -205,15 +205,14 @@ main (int argc, char *argv[])
|
||||||
for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one.
|
for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one.
|
||||||
for (x = 1; x <= loop; x++)
|
for (x = 1; x <= loop; x++)
|
||||||
{
|
{
|
||||||
m = l + 1;
|
k = (x == 0) ? 0 : l + 1;
|
||||||
k = (x == 0) ? 0 : m;
|
|
||||||
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT));
|
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT));
|
||||||
float *B = (float *)malloc_safe(x * sizeof(FLOAT) * m);
|
float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l);
|
||||||
float *C = (float *)malloc_safe(x * sizeof(FLOAT) * m);
|
float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l);
|
||||||
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits));
|
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits));
|
||||||
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) * m);
|
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) << l);
|
||||||
float *DD = (float *)malloc_safe(x * sizeof(FLOAT));
|
float *DD = (float *)malloc_safe(x * sizeof(FLOAT));
|
||||||
float *CC = (float *)malloc_safe(x * sizeof(FLOAT) * m);
|
float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l);
|
||||||
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
|
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
|
||||||
(DD == NULL) || (CC == NULL))
|
(DD == NULL) || (CC == NULL))
|
||||||
return 1;
|
return 1;
|
||||||
|
@ -228,9 +227,9 @@ main (int argc, char *argv[])
|
||||||
sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one);
|
sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one);
|
||||||
AA[j * x + i].v = atmp;
|
AA[j * x + i].v = atmp;
|
||||||
}
|
}
|
||||||
B[j*m] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||||
sbstobf16_(&one, &B[j*m], &one, &btmp, &one);
|
sbstobf16_(&one, &B[j << l], &one, &btmp, &one);
|
||||||
BB[j*m].v = btmp;
|
BB[j << l].v = btmp;
|
||||||
}
|
}
|
||||||
for (y = 0; y < 2; y++)
|
for (y = 0; y < 2; y++)
|
||||||
{
|
{
|
||||||
|
@ -240,9 +239,9 @@ main (int argc, char *argv[])
|
||||||
transA = 'T';
|
transA = 'T';
|
||||||
}
|
}
|
||||||
|
|
||||||
memset(CC, 0, x * m * sizeof(FLOAT));
|
memset(CC, 0, x * sizeof(FLOAT) << l);
|
||||||
memset(DD, 0, x * sizeof(FLOAT));
|
memset(DD, 0, x * sizeof(FLOAT));
|
||||||
memset(C, 0, x * m * sizeof(FLOAT));
|
memset(C, 0, x * sizeof(FLOAT) << l);
|
||||||
|
|
||||||
SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k);
|
SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k);
|
||||||
SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k);
|
SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k);
|
||||||
|
@ -250,15 +249,15 @@ main (int argc, char *argv[])
|
||||||
for (j = 0; j < x; j++)
|
for (j = 0; j < x; j++)
|
||||||
for (i = 0; i < x; i++)
|
for (i = 0; i < x; i++)
|
||||||
if (transA == 'N') {
|
if (transA == 'N') {
|
||||||
DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j*m]);
|
DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]);
|
||||||
} else if (transA == 'T') {
|
} else if (transA == 'T') {
|
||||||
DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i*m]);
|
DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (j = 0; j < x; j++) {
|
for (j = 0; j < x; j++) {
|
||||||
if (fabs (CC[j*m] - C[j*m]) > 1.0)
|
if (fabs (CC[j << l] - C[j << l]) > 1.0)
|
||||||
ret++;
|
ret++;
|
||||||
if (fabs (CC[j*m] - DD[j]) > 1.0)
|
if (fabs (CC[j << l] - DD[j]) > 1.0)
|
||||||
ret++;
|
ret++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue