diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c index 4b546fb1f..395317441 100644 --- a/test/compare_sgemm_sbgemm.c +++ b/test/compare_sgemm_sbgemm.c @@ -113,13 +113,19 @@ main (int argc, char *argv[]) for (j = 0; j < m; j++) { - for (i = 0; i < n; i++) + for (i = 0; i < k; i++) { A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one); - sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one); AA[j * k + i].v = atmp; + } + } + for (j = 0; j < n; j++) + { + for (i = 0; i < k; i++) + { + B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one); BB[j * k + i].v = btmp; } } @@ -147,10 +153,7 @@ main (int argc, char *argv[]) for (i = 0; i < n; i++) for (j = 0; j < m; j++) - if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) - ret++; - for (i = 0; i < n; i++) - for (j = 0; j < m; j++) + { for (l = 0; l < k; l++) if (transA == 'N' && transB == 'N') { @@ -169,10 +172,11 @@ main (int argc, char *argv[]) DD[i * m + j] += float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]); } - for (i = 0; i < n; i++) - for (j = 0; j < m; j++) + if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) + ret++; if (fabs (CC[i * m + j] - DD[i * m + j]) > 1.0) ret++; + } } free(A); free(B);