diff --git a/benchmark/gemm.c b/benchmark/gemm.c index 9d661e648..809813c92 100644 --- a/benchmark/gemm.c +++ b/benchmark/gemm.c @@ -121,13 +121,15 @@ static void *huge_malloc(BLASLONG size){ int main(int argc, char *argv[]){ FLOAT *a, *b, *c; - FLOAT alpha[] = {1.0, 1.0}; + FLOAT alpha[] = {1.0, 0.0}; FLOAT beta [] = {0.0, 0.0}; - char trans='N'; - blasint m, n, i, j; + char transa = 'N'; + char transb = 'N'; + blasint m, n, k, i, j, lda, ldb, ldc; int loops = 1; - int has_param_n=0; - int l; + int has_param_m = 0; + int has_param_n = 0; + int has_param_k = 0; char *p; int from = 1; @@ -135,86 +137,108 @@ int main(int argc, char *argv[]){ int step = 1; struct timeval start, stop; - double time1,timeg; + double time1, timeg; argc--;argv++; - if (argc > 0) { from = atol(*argv); argc--; argv++;} - if (argc > 0) { to = MAX(atol(*argv), from); argc--; argv++;} - if (argc > 0) { step = atol(*argv); argc--; argv++;} + if (argc > 0) { from = atol(*argv); argc--; argv++; } + if (argc > 0) { to = MAX(atol(*argv), from); argc--; argv++; } + if (argc > 0) { step = atol(*argv); argc--; argv++; } - if ((p = getenv("OPENBLAS_TRANS"))) trans=*p; - - fprintf(stderr, "From : %3d To : %3d Step=%d : Trans=%c\n", from, to, step, trans); - - if (( a = (FLOAT *)malloc(sizeof(FLOAT) * to * to * COMPSIZE)) == NULL){ - fprintf(stderr,"Out of Memory!!\n");exit(1); + if ((p = getenv("OPENBLAS_TRANS"))) { + transa=*p; + transb=*p; } - - if (( b = (FLOAT *)malloc(sizeof(FLOAT) * to * to * COMPSIZE)) == NULL){ - fprintf(stderr,"Out of Memory!!\n");exit(1); + if ((p = getenv("OPENBLAS_TRANSA"))) { + transa=*p; } - - if (( c = (FLOAT *)malloc(sizeof(FLOAT) * to * to * COMPSIZE)) == NULL){ - fprintf(stderr,"Out of Memory!!\n");exit(1); + if ((p = getenv("OPENBLAS_TRANSB"))) { + transb=*p; } + TOUPPER(transa); + TOUPPER(transb); + + fprintf(stderr, "From : %3d To : %3d Step=%d : Transa=%c : Transb=%c\n", from, to, step, transa, transb); p = getenv("OPENBLAS_LOOPS"); - if ( p != NULL ) - loops = atoi(p); + if ( p != NULL ) { + loops = atoi(p); + } + if ((p = getenv("OPENBLAS_PARAM_M"))) { + m = atoi(p); + has_param_m=1; + } else { + m = to; + } if ((p = getenv("OPENBLAS_PARAM_N"))) { - n = atoi(p); - has_param_n=1; + n = atoi(p); + has_param_n=1; + } else { + n = to; + } + if ((p = getenv("OPENBLAS_PARAM_K"))) { + k = atoi(p); + has_param_k=1; + } else { + k = to; + } + + if (( a = (FLOAT *)malloc(sizeof(FLOAT) * m * k * COMPSIZE)) == NULL) { + fprintf(stderr,"Out of Memory!!\n");exit(1); + } + if (( b = (FLOAT *)malloc(sizeof(FLOAT) * k * n * COMPSIZE)) == NULL) { + fprintf(stderr,"Out of Memory!!\n");exit(1); + } + if (( c = (FLOAT *)malloc(sizeof(FLOAT) * m * n * COMPSIZE)) == NULL) { + fprintf(stderr,"Out of Memory!!\n");exit(1); } #ifdef linux srandom(getpid()); #endif + + for (i = 0; i < m * k * COMPSIZE; i++) { + a[i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; + } + for (i = 0; i < k * n * COMPSIZE; i++) { + b[i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; + } + for (i = 0; i < m * n * COMPSIZE; i++) { + c[i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; + } - for(j = 0; j < to; j++){ - for(i = 0; i < to * COMPSIZE; i++){ - a[i + j * to * COMPSIZE] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; - b[i + j * to * COMPSIZE] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; - c[i + j * to * COMPSIZE] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; - } - } - - - - fprintf(stderr, " SIZE Flops Time\n"); - - for(m = from; m <= to; m += step) - { + fprintf(stderr, " SIZE Flops Time\n"); + for (i = from; i <= to; i += step) { + timeg=0; - if ( has_param_n == 1 && n <= m ) - n=n; - else - n=m; + if (!has_param_m) { m = i; } + if (!has_param_n) { n = i; } + if (!has_param_k) { k = i; } + if (transa == 'N') { lda = m; } + else { lda = k; } + if (transb == 'N') { ldb = k; } + else { ldb = n; } + ldc = m; - - fprintf(stderr, " %6dx%d : ", (int)m, (int)n); + fprintf(stderr, " M=%4d, N=%4d, K=%4d : ", (int)m, (int)n, (int)k); gettimeofday( &start, (struct timezone *)0); - for (l=0; l m; BLASLONG n = args -> n; BLASLONG nthreads = args -> nthreads; - BLASLONG divN, divT; - int mode; if (nthreads == 1) { GEMM_LOCAL(args, range_m, range_n, sa, sb, 0); @@ -706,66 +704,21 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO n = n_to - n_from; } - if ((m < nthreads * SWITCH_RATIO) || (n < nthreads * SWITCH_RATIO)) { + if ((m < 2 * SWITCH_RATIO) || (n < 2 * SWITCH_RATIO)) { GEMM_LOCAL(args, range_m, range_n, sa, sb, 0); return 0; } - divT = nthreads; - divN = 1; - -#if 0 - while ((GEMM_P * divT > m * SWITCH_RATIO) && (divT > 1)) { - do { - divT --; - divN = 1; - while (divT * divN < nthreads) divN ++; - } while ((divT * divN != nthreads) && (divT > 1)); + if (m < nthreads * SWITCH_RATIO) { + nthreads = blas_quickdivide(m, SWITCH_RATIO); } -#endif - - // fprintf(stderr, "divN = %4ld divT = %4ld\n", divN, divT); - - args -> nthreads = divT; - - if (divN == 1){ - - gemm_driver(args, range_m, range_n, sa, sb, 0); - } else { -#ifndef COMPLEX -#ifdef XDOUBLE - mode = BLAS_XDOUBLE | BLAS_REAL; -#elif defined(DOUBLE) - mode = BLAS_DOUBLE | BLAS_REAL; -#else - mode = BLAS_SINGLE | BLAS_REAL; -#endif -#else -#ifdef XDOUBLE - mode = BLAS_XDOUBLE | BLAS_COMPLEX; -#elif defined(DOUBLE) - mode = BLAS_DOUBLE | BLAS_COMPLEX; -#else - mode = BLAS_SINGLE | BLAS_COMPLEX; -#endif -#endif - -#if defined(TN) || defined(TT) || defined(TR) || defined(TC) || \ - defined(CN) || defined(CT) || defined(CR) || defined(CC) - mode |= (BLAS_TRANSA_T); -#endif -#if defined(NT) || defined(TT) || defined(RT) || defined(CT) || \ - defined(NC) || defined(TC) || defined(RC) || defined(CC) - mode |= (BLAS_TRANSB_T); -#endif - -#ifdef OS_WINDOWS - gemm_thread_n(mode, args, range_m, range_n, GEMM_LOCAL, sa, sb, divN); -#else - gemm_thread_n(mode, args, range_m, range_n, gemm_driver, sa, sb, divN); -#endif - + if (n < nthreads * SWITCH_RATIO) { + nthreads = blas_quickdivide(n, SWITCH_RATIO); } + args -> nthreads = nthreads; + + gemm_driver(args, range_m, range_n, sa, sb, 0); + return 0; }