Merge pull request #4875 from ChipKerchner/addGEMVtoBF16Test
Add GEMV to SBGEMx vs SGEMx testing
This commit is contained in:
commit
4944148e66
|
@ -29,6 +29,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
#include "../common.h"
|
||||
#define SGEMM BLASFUNC(sgemm)
|
||||
#define SBGEMM BLASFUNC(sbgemm)
|
||||
#define SGEMV BLASFUNC(sgemv)
|
||||
#define SBGEMV BLASFUNC(sbgemv)
|
||||
typedef union
|
||||
{
|
||||
unsigned short v;
|
||||
|
@ -187,7 +189,79 @@ main (int argc, char *argv[])
|
|||
free(CC);
|
||||
}
|
||||
|
||||
if (ret != 0)
|
||||
if (ret != 0) {
|
||||
fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
k = 1;
|
||||
for (x = 1; x <= loop; x++)
|
||||
{
|
||||
float *A = (float *)malloc(x * x * sizeof(FLOAT));
|
||||
float *B = (float *)malloc(x * sizeof(FLOAT));
|
||||
float *C = (float *)malloc(x * sizeof(FLOAT));
|
||||
bfloat16_bits *AA = (bfloat16_bits *)malloc(x * x * sizeof(bfloat16_bits));
|
||||
bfloat16_bits *BB = (bfloat16_bits *)malloc(x * sizeof(bfloat16_bits));
|
||||
float *DD = (float *)malloc(x * sizeof(FLOAT));
|
||||
float *CC = (float *)malloc(x * sizeof(FLOAT));
|
||||
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
|
||||
(DD == NULL) || (CC == NULL))
|
||||
return 1;
|
||||
bfloat16 atmp, btmp;
|
||||
blasint one = 1;
|
||||
|
||||
for (j = 0; j < x; j++)
|
||||
{
|
||||
for (i = 0; i < x; i++)
|
||||
{
|
||||
A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||
sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one);
|
||||
AA[j * x + i].v = atmp;
|
||||
}
|
||||
B[j] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
||||
sbstobf16_(&one, &B[j], &one, &btmp, &one);
|
||||
BB[j].v = btmp;
|
||||
}
|
||||
for (y = 0; y < 2; y++)
|
||||
{
|
||||
if (y == 0) {
|
||||
transA = 'N';
|
||||
} else {
|
||||
transA = 'T';
|
||||
}
|
||||
|
||||
memset(CC, 0, x * sizeof(FLOAT));
|
||||
memset(DD, 0, x * sizeof(FLOAT));
|
||||
memset(C, 0, x * sizeof(FLOAT));
|
||||
|
||||
SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k);
|
||||
SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k);
|
||||
|
||||
for (j = 0; j < x; j++)
|
||||
for (i = 0; i < x; i++)
|
||||
if (transA == 'N') {
|
||||
DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j]);
|
||||
} else if (transA == 'T') {
|
||||
DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i]);
|
||||
}
|
||||
|
||||
for (j = 0; j < x; j++) {
|
||||
if (fabs (CC[j] - C[j]) > 1.0)
|
||||
ret++;
|
||||
if (fabs (CC[j] - DD[j]) > 1.0)
|
||||
ret++;
|
||||
}
|
||||
}
|
||||
free(A);
|
||||
free(B);
|
||||
free(C);
|
||||
free(AA);
|
||||
free(BB);
|
||||
free(DD);
|
||||
free(CC);
|
||||
}
|
||||
|
||||
if (ret != 0)
|
||||
fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret);
|
||||
return ret;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue