From 28915eed726404bd14ed2828d45fe5293c55603e Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Fri, 5 Jun 2020 10:05:34 +0200 Subject: [PATCH] Cosmetic fixes for non-C99 compilers --- test/compare_sgemm_shgemm.c | 65 +++++++++---------------------------- 1 file changed, 16 insertions(+), 49 deletions(-) diff --git a/test/compare_sgemm_shgemm.c b/test/compare_sgemm_shgemm.c index 7e254f844..d37ae6851 100644 --- a/test/compare_sgemm_shgemm.c +++ b/test/compare_sgemm_shgemm.c @@ -46,83 +46,50 @@ typedef union } bits; } bfloat16_bits; -typedef union -{ - float v; - struct - { - uint32_t m:23; - uint32_t e:8; - uint32_t s:1; - } bits; -} float32_bits; - -float -float16to32 (bfloat16_bits f16) -{ - float32_bits f32; - f32.bits.s = f16.bits.s; - f32.bits.e = f16.bits.e; - f32.bits.m = (uint32_t) f16.bits.m << 16; - return f32.v; -} - int main (int argc, char *argv[]) { int m, n, k; int i, j, l; + int x; int ret = 0; int loop = 100; char transA = 'N', transB = 'N'; float alpha = 1.0, beta = 0.0; + char transa = 'N'; + char transb = 'N'; - for (int x = 0; x <= loop; x++) + for (x = 0; x <= loop; x++) { m = k = n = x; float A[m * k]; float B[k * n]; float C[m * n]; bfloat16_bits AA[m * k], BB[k * n]; - float DD[m * n], CC[m * n]; + float CC[m * n]; - for (int j = 0; j < m; j++) + for (j = 0; j < m; j++) { - for (int i = 0; i < m; i++) + for (i = 0; i < m; i++) { - A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + A[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5; + B[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5; C[j * k + i] = 0; AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16; BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16; CC[j * k + i] = 0; - DD[j * k + i] = 0; } } SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, - &m, B, &k, &beta, C, &m); + &m, B, &k, &beta, C, &m); SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA, - &m, BB, &k, &beta, CC, &m); + &m, BB, &k, &beta, CC, &m); + for (i = 0; i < n; i++) - for (j = 0; j < m; j++) - for (l = 0; l < k; l++) - if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) - ret++; - if (transA == 'N' && transB == 'N') - { - for (i = 0; i < n; i++) - for (j = 0; j < m; j++) - for (l = 0; l < k; l++) - { - DD[i * m + j] += - float16to32 (AA[l * m + j]) * float16to32 (BB[l + k * i]); - } - for (i = 0; i < n; i++) - for (j = 0; j < m; j++) - for (l = 0; l < k; l++) - if (CC[i * m + j] != DD[i * m + j]) - ret++; - } + for (j = 0; j < m; j++) + for (l = 0; l < k; l++) + if (fabs(CC[i * m + j]-C[i * m + j]) > 1.0) + ret++; } if (ret != 0) fprintf (stderr, "FATAL ERROR SHGEMM - Return code: %d\n", ret);