diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c index bc74233ab..de589458b 100644 --- a/test/compare_sgemm_sbgemm.c +++ b/test/compare_sgemm_sbgemm.c @@ -86,14 +86,26 @@ main (int argc, char *argv[]) { blasint m, n, k; int i, j, l; - blasint x; + blasint x, y; int ret = 0; int loop = 100; char transA = 'N', transB = 'N'; float alpha = 1.0, beta = 0.0; for (x = 0; x <= loop; x++) + { + for (y = 0; y < 4; y++) { + if ((y == 0) || (y == 2)) { + transA = 'N'; + } else { + transA = 'T'; + } + if ((y == 0) || (y == 1)) { + transB = 'N'; + } else { + transB = 'T'; + } m = k = n = x; float A[m * k]; float B[k * n]; @@ -104,43 +116,55 @@ main (int argc, char *argv[]) blasint one=1; for (j = 0; j < m; j++) - { - for (i = 0; i < m; i++) - { - A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - C[j * k + i] = 0; - sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one); - sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one); - AA[j * k + i].v = atmp; - BB[j * k + i].v = btmp; - CC[j * k + i] = 0; - DD[j * k + i] = 0; - } - } + { + for (i = 0; i < m; i++) + { + A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + C[j * k + i] = 0; + sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one); + sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one); + AA[j * k + i].v = atmp; + BB[j * k + i].v = btmp; + CC[j * k + i] = 0; + DD[j * k + i] = 0; + } + } SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, - &m, B, &k, &beta, C, &m); + &m, B, &k, &beta, C, &m); SBGEMM (&transA, &transB, &m, &n, &k, &alpha, (bfloat16*) AA, - &m, (bfloat16*)BB, &k, &beta, CC, &m); + &m, (bfloat16*)BB, &k, &beta, CC, &m); for (i = 0; i < n; i++) - for (j = 0; j < m; j++) - if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) - ret++; - if (transA == 'N' && transB == 'N') - { - for (i = 0; i < n; i++) - for (j = 0; j < m; j++) - for (l = 0; l < k; l++) - { - DD[i * m + j] += - float16to32 (AA[l * m + j]) * float16to32 (BB[l + k * i]); - } - for (i = 0; i < n; i++) - for (j = 0; j < m; j++) - if (CC[i * m + j] != DD[i * m + j]) - ret++; - } + for (j = 0; j < m; j++) + if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) + ret++; + for (i = 0; i < n; i++) + for (j = 0; j < m; j++) + for (l = 0; l < k; l++) + if (transA == 'N' && transB == 'N') + { + DD[i * m + j] += + float16to32 (AA[l * m + j]) * float16to32 (BB[l + k * i]); + } else if (transA == 'T' && transB == 'N') + { + DD[i * m + j] += + float16to32 (AA[k * j + l]) * float16to32 (BB[l + k * i]); + } else if (transA == 'N' && transB == 'T') + { + DD[i * m + j] += + float16to32 (AA[l * m + j]) * float16to32 (BB[i + l * n]); + } else if (transA == 'T' && transB == 'T') + { + DD[i * m + j] += + float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]); + } + for (i = 0; i < n; i++) + for (j = 0; j < m; j++) + if (CC[i * m + j] != DD[i * m + j]) + ret++; } + } + if (ret != 0) fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret); return ret;