diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c index 395317441..cd508a0cf 100644 --- a/test/compare_sgemm_sbgemm.c +++ b/test/compare_sgemm_sbgemm.c @@ -29,6 +29,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "../common.h" #define SGEMM BLASFUNC(sgemm) #define SBGEMM BLASFUNC(sbgemm) +#define SGEMV BLASFUNC(sgemv) +#define SBGEMV BLASFUNC(sbgemv) typedef union { unsigned short v; @@ -187,7 +189,79 @@ main (int argc, char *argv[]) free(CC); } - if (ret != 0) + if (ret != 0) { fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret); + return ret; + } + + k = 1; + for (x = 1; x <= loop; x++) + { + float *A = (float *)malloc(x * x * sizeof(FLOAT)); + float *B = (float *)malloc(x * sizeof(FLOAT)); + float *C = (float *)malloc(x * sizeof(FLOAT)); + bfloat16_bits *AA = (bfloat16_bits *)malloc(x * x * sizeof(bfloat16_bits)); + bfloat16_bits *BB = (bfloat16_bits *)malloc(x * sizeof(bfloat16_bits)); + float *DD = (float *)malloc(x * sizeof(FLOAT)); + float *CC = (float *)malloc(x * sizeof(FLOAT)); + if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || + (DD == NULL) || (CC == NULL)) + return 1; + bfloat16 atmp, btmp; + blasint one = 1; + + for (j = 0; j < x; j++) + { + for (i = 0; i < x; i++) + { + A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one); + AA[j * x + i].v = atmp; + } + B[j] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + sbstobf16_(&one, &B[j], &one, &btmp, &one); + BB[j].v = btmp; + } + for (y = 0; y < 2; y++) + { + if (y == 0) { + transA = 'N'; + } else { + transA = 'T'; + } + + memset(CC, 0, x * sizeof(FLOAT)); + memset(DD, 0, x * sizeof(FLOAT)); + memset(C, 0, x * sizeof(FLOAT)); + + SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k); + SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k); + + for (j = 0; j < x; j++) + for (i = 0; i < x; i++) + if (transA == 'N') { + DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j]); + } else if (transA == 'T') { + DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i]); + } + + for (j = 0; j < x; j++) { + if (fabs (CC[j] - C[j]) > 1.0) + ret++; + if (fabs (CC[j] - DD[j]) > 1.0) + ret++; + } + } + free(A); + free(B); + free(C); + free(AA); + free(BB); + free(DD); + free(CC); + } + + if (ret != 0) + fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret); return ret; }