diff --git a/driver/level3/level3_thread.c b/driver/level3/level3_thread.c index fec873e51..8ab4ef699 100644 --- a/driver/level3/level3_thread.c +++ b/driver/level3/level3_thread.c @@ -684,8 +684,6 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO BLASLONG m = args -> m; BLASLONG n = args -> n; BLASLONG nthreads = args -> nthreads; - BLASLONG divN, divT; - int mode; if (nthreads == 1) { GEMM_LOCAL(args, range_m, range_n, sa, sb, 0); @@ -706,66 +704,21 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO n = n_to - n_from; } - if ((m < nthreads * SWITCH_RATIO) || (n < nthreads * SWITCH_RATIO)) { + if ((m < 2 * SWITCH_RATIO) || (n < 2 * SWITCH_RATIO)) { GEMM_LOCAL(args, range_m, range_n, sa, sb, 0); return 0; } - divT = nthreads; - divN = 1; - -#if 0 - while ((GEMM_P * divT > m * SWITCH_RATIO) && (divT > 1)) { - do { - divT --; - divN = 1; - while (divT * divN < nthreads) divN ++; - } while ((divT * divN != nthreads) && (divT > 1)); + if (m < nthreads * SWITCH_RATIO) { + nthreads = blas_quickdivide(m, SWITCH_RATIO); } -#endif - - // fprintf(stderr, "divN = %4ld divT = %4ld\n", divN, divT); - - args -> nthreads = divT; - - if (divN == 1){ - - gemm_driver(args, range_m, range_n, sa, sb, 0); - } else { -#ifndef COMPLEX -#ifdef XDOUBLE - mode = BLAS_XDOUBLE | BLAS_REAL; -#elif defined(DOUBLE) - mode = BLAS_DOUBLE | BLAS_REAL; -#else - mode = BLAS_SINGLE | BLAS_REAL; -#endif -#else -#ifdef XDOUBLE - mode = BLAS_XDOUBLE | BLAS_COMPLEX; -#elif defined(DOUBLE) - mode = BLAS_DOUBLE | BLAS_COMPLEX; -#else - mode = BLAS_SINGLE | BLAS_COMPLEX; -#endif -#endif - -#if defined(TN) || defined(TT) || defined(TR) || defined(TC) || \ - defined(CN) || defined(CT) || defined(CR) || defined(CC) - mode |= (BLAS_TRANSA_T); -#endif -#if defined(NT) || defined(TT) || defined(RT) || defined(CT) || \ - defined(NC) || defined(TC) || defined(RC) || defined(CC) - mode |= (BLAS_TRANSB_T); -#endif - -#ifdef OS_WINDOWS - gemm_thread_n(mode, args, range_m, range_n, GEMM_LOCAL, sa, sb, divN); -#else - gemm_thread_n(mode, args, range_m, range_n, gemm_driver, sa, sb, divN); -#endif - + if (n < nthreads * SWITCH_RATIO) { + nthreads = blas_quickdivide(n, SWITCH_RATIO); } + args -> nthreads = nthreads; + + gemm_driver(args, range_m, range_n, sa, sb, 0); + return 0; }