Add test for shgemm
This patch has Makefile changes to add test for shgemm which compares sgemm and shgemm result.
This commit is contained in:
parent
d394d4e677
commit
564b0d39ef
|
@ -70,6 +70,7 @@ test/SBLAT2.SUMM
|
|||
test/SBLAT3.SUMM
|
||||
test/ZBLAT2.SUMM
|
||||
test/ZBLAT3.SUMM
|
||||
test/SHBLAT3.SUMM
|
||||
test/cblat1
|
||||
test/cblat2
|
||||
test/cblat3
|
||||
|
@ -79,6 +80,7 @@ test/dblat3
|
|||
test/sblat1
|
||||
test/sblat2
|
||||
test/sblat3
|
||||
test/test_shgemm
|
||||
test/zblat1
|
||||
test/zblat2
|
||||
test/zblat3
|
||||
|
|
|
@ -64,9 +64,17 @@ endif
|
|||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
level3 : test_shgemm sblat3 dblat3 cblat3 zblat3
|
||||
else
|
||||
level3 : sblat3 dblat3 cblat3 zblat3
|
||||
endif
|
||||
ifndef CROSS
|
||||
rm -f ?BLAT3.SUMM
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_shgemm > SHBLAT3.SUMM
|
||||
@$(GREP) -q FATAL SHBLAT3.SUMM && cat SHBLAT3.SUMM || exit 0
|
||||
endif
|
||||
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./sblat3 < ./sblat3.dat
|
||||
@$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0
|
||||
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./dblat3 < ./dblat3.dat
|
||||
|
@ -78,6 +86,10 @@ ifndef CROSS
|
|||
ifdef SMP
|
||||
rm -f ?BLAT3.SUMM
|
||||
ifeq ($(USE_OPENMP), 1)
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
OMP_NUM_THREADS=2 ./test_shgemm > SHBLAT3.SUMM
|
||||
@$(GREP) -q FATAL SHBLAT3.SUMM && cat SHBLAT3.SUMM || exit 0
|
||||
endif
|
||||
OMP_NUM_THREADS=2 ./sblat3 < ./sblat3.dat
|
||||
@$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0
|
||||
OMP_NUM_THREADS=2 ./dblat3 < ./dblat3.dat
|
||||
|
@ -87,6 +99,10 @@ ifeq ($(USE_OPENMP), 1)
|
|||
OMP_NUM_THREADS=2 ./zblat3 < ./zblat3.dat
|
||||
@$(GREP) -q FATAL ZBLAT3.SUMM && cat ZBLAT3.SUMM || exit 0
|
||||
else
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
OPENBLAS_NUM_THREADS=2 ./test_shgemm > SHBLAT3.SUMM
|
||||
@$(GREP) -q FATAL SHBLAT3.SUMM && cat SHBLAT3.SUMM || exit 0
|
||||
endif
|
||||
OPENBLAS_NUM_THREADS=2 ./sblat3 < ./sblat3.dat
|
||||
@$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0
|
||||
OPENBLAS_NUM_THREADS=2 ./dblat3 < ./dblat3.dat
|
||||
|
@ -165,6 +181,11 @@ zblat2 : zblat2.$(SUFFIX) ../$(LIBNAME)
|
|||
sblat3 : sblat3.$(SUFFIX) ../$(LIBNAME)
|
||||
$(FC) $(FLDFLAGS) -o sblat3 sblat3.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
|
||||
|
||||
ifeq ($(BUILD_HALF),1)
|
||||
test_shgemm : compare_sgemm_shgemm.c ../$(LIBNAME)
|
||||
$(FC) $(FLDFLAGS) -o test_shgemm compare_sgemm_shgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
|
||||
endif
|
||||
|
||||
dblat3 : dblat3.$(SUFFIX) ../$(LIBNAME)
|
||||
$(FC) $(FLDFLAGS) -o dblat3 dblat3.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
|
||||
|
||||
|
@ -187,7 +208,7 @@ clean:
|
|||
@rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \
|
||||
sblat1 dblat1 cblat1 zblat1 \
|
||||
sblat2 dblat2 cblat2 zblat2 \
|
||||
sblat3 dblat3 cblat3 zblat3 \
|
||||
test_shgemm sblat3 dblat3 cblat3 zblat3 \
|
||||
sblat1p dblat1p cblat1p zblat1p \
|
||||
sblat2p dblat2p cblat2p zblat2p \
|
||||
sblat3p dblat3p cblat3p zblat3p \
|
||||
|
|
|
@ -26,7 +26,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
*****************************************************************************/
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include "common.h"
|
||||
#include "../common.h"
|
||||
#define SGEMM BLASFUNC(sgemm)
|
||||
#define SHGEMM BLASFUNC(shgemm)
|
||||
typedef union
|
||||
|
@ -52,7 +52,7 @@ main (int argc, char *argv[])
|
|||
int m, n, k;
|
||||
int i, j, l;
|
||||
int ret = 0;
|
||||
int loop = 20;
|
||||
int loop = 100;
|
||||
char transA = 'N', transB = 'N';
|
||||
float alpha = 1.0, beta = 0.0;
|
||||
char transa = 'N';
|
||||
|
@ -71,8 +71,8 @@ main (int argc, char *argv[])
|
|||
{
|
||||
for (int i = 0; i < m; i++)
|
||||
{
|
||||
A[j * k + i] = j * 9.0;
|
||||
B[j * k + i] = i * 2.0;
|
||||
A[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5;
|
||||
B[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5;
|
||||
C[j * k + i] = 0;
|
||||
AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16;
|
||||
BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16;
|
||||
|
@ -85,11 +85,12 @@ main (int argc, char *argv[])
|
|||
&m, BB, &k, &beta, CC, &m);
|
||||
|
||||
for (i = 0; i < n; i++)
|
||||
for (j = 0; j < m; j++)
|
||||
for (l = 0; l < k; l++)
|
||||
if (CC[i * m + j] != C[i * m + j])
|
||||
ret++;
|
||||
for (j = 0; j < m; j++)
|
||||
for (l = 0; l < k; l++)
|
||||
if (fabs(CC[i * m + j]-C[i * m + j]) > 1.0)
|
||||
ret++;
|
||||
}
|
||||
fprintf (stderr, "Return code: %d\n", ret);
|
||||
if (ret != 0)
|
||||
fprintf (stderr, "FATAL ERROR SHGEMM - Return code: %d\n", ret);
|
||||
return ret;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue