Change malloc zero to return one byte and update the SBGEMM test to again use sizes of zero.
This commit is contained in:
parent
b1802f4dc8
commit
868aa857bc
|
@ -85,6 +85,14 @@ float16to32 (bfloat16_bits f16)
|
|||
|
||||
#define SBGEMM_LARGEST 256
|
||||
|
||||
void *malloc_safe(size_t size)
|
||||
{
|
||||
if (size == 0)
|
||||
return malloc(1);
|
||||
else
|
||||
return malloc(size);
|
||||
}
|
||||
|
||||
int
|
||||
main (int argc, char *argv[])
|
||||
{
|
||||
|
@ -96,17 +104,17 @@ main (int argc, char *argv[])
|
|||
char transA = 'N', transB = 'N';
|
||||
float alpha = 1.0, beta = 0.0;
|
||||
|
||||
for (x = 1; x <= loop; x++)
|
||||
for (x = 0; x <= loop; x++)
|
||||
{
|
||||
if ((x > 100) && (x != SBGEMM_LARGEST)) continue;
|
||||
m = k = n = x;
|
||||
float *A = (float *)malloc(m * k * sizeof(FLOAT));
|
||||
float *B = (float *)malloc(k * n * sizeof(FLOAT));
|
||||
float *C = (float *)malloc(m * n * sizeof(FLOAT));
|
||||
bfloat16_bits *AA = (bfloat16_bits *)malloc(m * k * sizeof(bfloat16_bits));
|
||||
bfloat16_bits *BB = (bfloat16_bits *)malloc(k * n * sizeof(bfloat16_bits));
|
||||
float *DD = (float *)malloc(m * n * sizeof(FLOAT));
|
||||
float *CC = (float *)malloc(m * n * sizeof(FLOAT));
|
||||
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
|
||||
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
|
||||
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
|
||||
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(m * k * sizeof(bfloat16_bits));
|
||||
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits));
|
||||
float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT));
|
||||
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT));
|
||||
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
|
||||
(DD == NULL) || (CC == NULL))
|
||||
return 1;
|
||||
|
@ -195,15 +203,15 @@ main (int argc, char *argv[])
|
|||
}
|
||||
|
||||
k = 1;
|
||||
for (x = 1; x <= loop; x++)
|
||||
for (x = 0; x <= loop; x++)
|
||||
{
|
||||
float *A = (float *)malloc(x * x * sizeof(FLOAT));
|
||||
float *B = (float *)malloc(x * sizeof(FLOAT));
|
||||
float *C = (float *)malloc(x * sizeof(FLOAT));
|
||||
bfloat16_bits *AA = (bfloat16_bits *)malloc(x * x * sizeof(bfloat16_bits));
|
||||
bfloat16_bits *BB = (bfloat16_bits *)malloc(x * sizeof(bfloat16_bits));
|
||||
float *DD = (float *)malloc(x * sizeof(FLOAT));
|
||||
float *CC = (float *)malloc(x * sizeof(FLOAT));
|
||||
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT));
|
||||
float *B = (float *)malloc_safe(x * sizeof(FLOAT));
|
||||
float *C = (float *)malloc_safe(x * sizeof(FLOAT));
|
||||
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits));
|
||||
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits));
|
||||
float *DD = (float *)malloc_safe(x * sizeof(FLOAT));
|
||||
float *CC = (float *)malloc_safe(x * sizeof(FLOAT));
|
||||
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
|
||||
(DD == NULL) || (CC == NULL))
|
||||
return 1;
|
||||
|
|
Loading…
Reference in New Issue