Merge pull request #2591 from RajalakshmiSR/testhalf

Add test for shgemm
This commit is contained in:
Martin Kroeker 2020-05-01 09:59:39 +02:00 committed by GitHub
commit 596f5df9e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 10 deletions

2
.gitignore vendored
View File

@ -70,6 +70,7 @@ test/SBLAT2.SUMM
test/SBLAT3.SUMM test/SBLAT3.SUMM
test/ZBLAT2.SUMM test/ZBLAT2.SUMM
test/ZBLAT3.SUMM test/ZBLAT3.SUMM
test/SHBLAT3.SUMM
test/cblat1 test/cblat1
test/cblat2 test/cblat2
test/cblat3 test/cblat3
@ -79,6 +80,7 @@ test/dblat3
test/sblat1 test/sblat1
test/sblat2 test/sblat2
test/sblat3 test/sblat3
test/test_shgemm
test/zblat1 test/zblat1
test/zblat2 test/zblat2
test/zblat3 test/zblat3

View File

@ -64,9 +64,17 @@ endif
endif endif
endif endif
ifeq ($(BUILD_HALF),1)
level3 : test_shgemm sblat3 dblat3 cblat3 zblat3
else
level3 : sblat3 dblat3 cblat3 zblat3 level3 : sblat3 dblat3 cblat3 zblat3
endif
ifndef CROSS ifndef CROSS
rm -f ?BLAT3.SUMM rm -f ?BLAT3.SUMM
ifeq ($(BUILD_HALF),1)
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_shgemm > SHBLAT3.SUMM
@$(GREP) -q FATAL SHBLAT3.SUMM && cat SHBLAT3.SUMM || exit 0
endif
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./sblat3 < ./sblat3.dat OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./sblat3 < ./sblat3.dat
@$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0 @$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./dblat3 < ./dblat3.dat OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./dblat3 < ./dblat3.dat
@ -78,6 +86,10 @@ ifndef CROSS
ifdef SMP ifdef SMP
rm -f ?BLAT3.SUMM rm -f ?BLAT3.SUMM
ifeq ($(USE_OPENMP), 1) ifeq ($(USE_OPENMP), 1)
ifeq ($(BUILD_HALF),1)
OMP_NUM_THREADS=2 ./test_shgemm > SHBLAT3.SUMM
@$(GREP) -q FATAL SHBLAT3.SUMM && cat SHBLAT3.SUMM || exit 0
endif
OMP_NUM_THREADS=2 ./sblat3 < ./sblat3.dat OMP_NUM_THREADS=2 ./sblat3 < ./sblat3.dat
@$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0 @$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0
OMP_NUM_THREADS=2 ./dblat3 < ./dblat3.dat OMP_NUM_THREADS=2 ./dblat3 < ./dblat3.dat
@ -87,6 +99,10 @@ ifeq ($(USE_OPENMP), 1)
OMP_NUM_THREADS=2 ./zblat3 < ./zblat3.dat OMP_NUM_THREADS=2 ./zblat3 < ./zblat3.dat
@$(GREP) -q FATAL ZBLAT3.SUMM && cat ZBLAT3.SUMM || exit 0 @$(GREP) -q FATAL ZBLAT3.SUMM && cat ZBLAT3.SUMM || exit 0
else else
ifeq ($(BUILD_HALF),1)
OPENBLAS_NUM_THREADS=2 ./test_shgemm > SHBLAT3.SUMM
@$(GREP) -q FATAL SHBLAT3.SUMM && cat SHBLAT3.SUMM || exit 0
endif
OPENBLAS_NUM_THREADS=2 ./sblat3 < ./sblat3.dat OPENBLAS_NUM_THREADS=2 ./sblat3 < ./sblat3.dat
@$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0 @$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0
OPENBLAS_NUM_THREADS=2 ./dblat3 < ./dblat3.dat OPENBLAS_NUM_THREADS=2 ./dblat3 < ./dblat3.dat
@ -165,6 +181,11 @@ zblat2 : zblat2.$(SUFFIX) ../$(LIBNAME)
sblat3 : sblat3.$(SUFFIX) ../$(LIBNAME) sblat3 : sblat3.$(SUFFIX) ../$(LIBNAME)
$(FC) $(FLDFLAGS) -o sblat3 sblat3.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) $(FC) $(FLDFLAGS) -o sblat3 sblat3.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
ifeq ($(BUILD_HALF),1)
test_shgemm : compare_sgemm_shgemm.c ../$(LIBNAME)
$(FC) $(FLDFLAGS) -o test_shgemm compare_sgemm_shgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
endif
dblat3 : dblat3.$(SUFFIX) ../$(LIBNAME) dblat3 : dblat3.$(SUFFIX) ../$(LIBNAME)
$(FC) $(FLDFLAGS) -o dblat3 dblat3.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) $(FC) $(FLDFLAGS) -o dblat3 dblat3.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
@ -187,7 +208,7 @@ clean:
@rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \
sblat1 dblat1 cblat1 zblat1 \ sblat1 dblat1 cblat1 zblat1 \
sblat2 dblat2 cblat2 zblat2 \ sblat2 dblat2 cblat2 zblat2 \
sblat3 dblat3 cblat3 zblat3 \ test_shgemm sblat3 dblat3 cblat3 zblat3 \
sblat1p dblat1p cblat1p zblat1p \ sblat1p dblat1p cblat1p zblat1p \
sblat2p dblat2p cblat2p zblat2p \ sblat2p dblat2p cblat2p zblat2p \
sblat3p dblat3p cblat3p zblat3p \ sblat3p dblat3p cblat3p zblat3p \

View File

@ -26,7 +26,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/ *****************************************************************************/
#include <stdio.h> #include <stdio.h>
#include <stdint.h> #include <stdint.h>
#include "common.h" #include "../common.h"
#define SGEMM BLASFUNC(sgemm) #define SGEMM BLASFUNC(sgemm)
#define SHGEMM BLASFUNC(shgemm) #define SHGEMM BLASFUNC(shgemm)
typedef union typedef union
@ -52,7 +52,7 @@ main (int argc, char *argv[])
int m, n, k; int m, n, k;
int i, j, l; int i, j, l;
int ret = 0; int ret = 0;
int loop = 20; int loop = 100;
char transA = 'N', transB = 'N'; char transA = 'N', transB = 'N';
float alpha = 1.0, beta = 0.0; float alpha = 1.0, beta = 0.0;
char transa = 'N'; char transa = 'N';
@ -71,8 +71,8 @@ main (int argc, char *argv[])
{ {
for (int i = 0; i < m; i++) for (int i = 0; i < m; i++)
{ {
A[j * k + i] = j * 9.0; A[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5;
B[j * k + i] = i * 2.0; B[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5;
C[j * k + i] = 0; C[j * k + i] = 0;
AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16; AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16;
BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16; BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16;
@ -87,9 +87,10 @@ main (int argc, char *argv[])
for (i = 0; i < n; i++) for (i = 0; i < n; i++)
for (j = 0; j < m; j++) for (j = 0; j < m; j++)
for (l = 0; l < k; l++) for (l = 0; l < k; l++)
if (CC[i * m + j] != C[i * m + j]) if (fabs(CC[i * m + j]-C[i * m + j]) > 1.0)
ret++; ret++;
} }
fprintf (stderr, "Return code: %d\n", ret); if (ret != 0)
fprintf (stderr, "FATAL ERROR SHGEMM - Return code: %d\n", ret);
return ret; return ret;
} }