diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c index 4afa8bf93..bc74233ab 100644 --- a/test/compare_sgemm_sbgemm.c +++ b/test/compare_sgemm_sbgemm.c @@ -81,16 +81,6 @@ float16to32 (bfloat16_bits f16) return f32.v; } -float -float32to16 (float32_bits f32) -{ - bfloat16_bits f16; - f16.bits.s = f32.bits.s; - f16.bits.e = f32.bits.e; - f16.bits.m = (uint32_t) f32.bits.m >> 16; - return f32.v; -} - int main (int argc, char *argv[]) { @@ -110,6 +100,8 @@ main (int argc, char *argv[]) float C[m * n]; bfloat16_bits AA[m * k], BB[k * n]; float DD[m * n], CC[m * n]; + bfloat16 atmp,btmp; + blasint one=1; for (j = 0; j < m; j++) { @@ -118,8 +110,10 @@ main (int argc, char *argv[]) A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; C[j * k + i] = 0; - AA[j * k + i].v = float32to16( A[j * k + i] ); - BB[j * k + i].v = float32to16( B[j * k + i] ); + sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one); + sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one); + AA[j * k + i].v = atmp; + BB[j * k + i].v = btmp; CC[j * k + i] = 0; DD[j * k + i] = 0; }