diff --git a/.gitignore b/.gitignore index 6803a919e..bca79f043 100644 --- a/.gitignore +++ b/.gitignore @@ -70,6 +70,7 @@ test/SBLAT2.SUMM test/SBLAT3.SUMM test/ZBLAT2.SUMM test/ZBLAT3.SUMM +test/SHBLAT3.SUMM test/cblat1 test/cblat2 test/cblat3 @@ -79,6 +80,7 @@ test/dblat3 test/sblat1 test/sblat2 test/sblat3 +test/test_shgemm test/zblat1 test/zblat2 test/zblat3 diff --git a/test/Makefile b/test/Makefile index 7a873b7e5..45f9821ec 100644 --- a/test/Makefile +++ b/test/Makefile @@ -64,9 +64,17 @@ endif endif endif +ifeq ($(BUILD_HALF),1) +level3 : test_shgemm sblat3 dblat3 cblat3 zblat3 +else level3 : sblat3 dblat3 cblat3 zblat3 +endif ifndef CROSS rm -f ?BLAT3.SUMM +ifeq ($(BUILD_HALF),1) + OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_shgemm > SHBLAT3.SUMM + @$(GREP) -q FATAL SHBLAT3.SUMM && cat SHBLAT3.SUMM || exit 0 +endif OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./sblat3 < ./sblat3.dat @$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0 OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./dblat3 < ./dblat3.dat @@ -78,6 +86,10 @@ ifndef CROSS ifdef SMP rm -f ?BLAT3.SUMM ifeq ($(USE_OPENMP), 1) +ifeq ($(BUILD_HALF),1) + OMP_NUM_THREADS=2 ./test_shgemm > SHBLAT3.SUMM + @$(GREP) -q FATAL SHBLAT3.SUMM && cat SHBLAT3.SUMM || exit 0 +endif OMP_NUM_THREADS=2 ./sblat3 < ./sblat3.dat @$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0 OMP_NUM_THREADS=2 ./dblat3 < ./dblat3.dat @@ -87,6 +99,10 @@ ifeq ($(USE_OPENMP), 1) OMP_NUM_THREADS=2 ./zblat3 < ./zblat3.dat @$(GREP) -q FATAL ZBLAT3.SUMM && cat ZBLAT3.SUMM || exit 0 else +ifeq ($(BUILD_HALF),1) + OPENBLAS_NUM_THREADS=2 ./test_shgemm > SHBLAT3.SUMM + @$(GREP) -q FATAL SHBLAT3.SUMM && cat SHBLAT3.SUMM || exit 0 +endif OPENBLAS_NUM_THREADS=2 ./sblat3 < ./sblat3.dat @$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0 OPENBLAS_NUM_THREADS=2 ./dblat3 < ./dblat3.dat @@ -165,6 +181,11 @@ zblat2 : zblat2.$(SUFFIX) ../$(LIBNAME) sblat3 : sblat3.$(SUFFIX) ../$(LIBNAME) $(FC) $(FLDFLAGS) -o sblat3 sblat3.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) +ifeq ($(BUILD_HALF),1) +test_shgemm : compare_sgemm_shgemm.c ../$(LIBNAME) + $(FC) $(FLDFLAGS) -o test_shgemm compare_sgemm_shgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) +endif + dblat3 : dblat3.$(SUFFIX) ../$(LIBNAME) $(FC) $(FLDFLAGS) -o dblat3 dblat3.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) @@ -187,7 +208,7 @@ clean: @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ sblat1 dblat1 cblat1 zblat1 \ sblat2 dblat2 cblat2 zblat2 \ - sblat3 dblat3 cblat3 zblat3 \ + test_shgemm sblat3 dblat3 cblat3 zblat3 \ sblat1p dblat1p cblat1p zblat1p \ sblat2p dblat2p cblat2p zblat2p \ sblat3p dblat3p cblat3p zblat3p \ diff --git a/test/compare_sgemm_shgemm.c b/test/compare_sgemm_shgemm.c index 978972b24..d5bd84b91 100644 --- a/test/compare_sgemm_shgemm.c +++ b/test/compare_sgemm_shgemm.c @@ -26,7 +26,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *****************************************************************************/ #include #include -#include "common.h" +#include "../common.h" #define SGEMM BLASFUNC(sgemm) #define SHGEMM BLASFUNC(shgemm) typedef union @@ -52,7 +52,7 @@ main (int argc, char *argv[]) int m, n, k; int i, j, l; int ret = 0; - int loop = 20; + int loop = 100; char transA = 'N', transB = 'N'; float alpha = 1.0, beta = 0.0; char transa = 'N'; @@ -71,8 +71,8 @@ main (int argc, char *argv[]) { for (int i = 0; i < m; i++) { - A[j * k + i] = j * 9.0; - B[j * k + i] = i * 2.0; + A[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5; + B[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5; C[j * k + i] = 0; AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16; BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16; @@ -85,11 +85,12 @@ main (int argc, char *argv[]) &m, BB, &k, &beta, CC, &m); for (i = 0; i < n; i++) - for (j = 0; j < m; j++) - for (l = 0; l < k; l++) - if (CC[i * m + j] != C[i * m + j]) - ret++; + for (j = 0; j < m; j++) + for (l = 0; l < k; l++) + if (fabs(CC[i * m + j]-C[i * m + j]) > 1.0) + ret++; } - fprintf (stderr, "Return code: %d\n", ret); + if (ret != 0) + fprintf (stderr, "FATAL ERROR SHGEMM - Return code: %d\n", ret); return ret; }