diff --git a/driver/level3/level3_thread.c b/driver/level3/level3_thread.c index 22a12d465..77ceac6e8 100644 --- a/driver/level3/level3_thread.c +++ b/driver/level3/level3_thread.c @@ -525,7 +525,7 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG BLASLONG range_M_buffer[MAX_CPU_NUMBER + 2]; BLASLONG range_N_buffer[MAX_CPU_NUMBER + 2]; BLASLONG *range_M, *range_N; - BLASLONG num_cpu_m, num_cpu_n; + BLASLONG num_parts; BLASLONG nthreads = args -> nthreads; @@ -596,16 +596,16 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG } /* Partition m into nthreads_m regions */ - num_cpu_m = 0; + num_parts = 0; while (m > 0){ - width = blas_quickdivide(m + nthreads_m - num_cpu_m - 1, nthreads_m - num_cpu_m); + width = blas_quickdivide(m + nthreads_m - num_parts - 1, nthreads_m - num_parts); m -= width; if (m < 0) width = width + m; - range_M[num_cpu_m + 1] = range_M[num_cpu_m] + width; - num_cpu_m ++; + range_M[num_parts + 1] = range_M[num_parts] + width; + num_parts ++; } - for (i = num_cpu_m; i < MAX_CPU_NUMBER; i++) { - range_M[i + 1] = range_M[num_cpu_m]; + for (i = num_parts; i < MAX_CPU_NUMBER; i++) { + range_M[i + 1] = range_M[num_parts]; } /* Initialize parameters for parallel execution */ @@ -637,16 +637,19 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG /* Partition (a step of) n into nthreads regions */ range_N[0] = js; - num_cpu_n = 0; + num_parts = 0; while (n > 0){ - width = blas_quickdivide(n + nthreads - num_cpu_n - 1, nthreads - num_cpu_n); + width = blas_quickdivide(n + nthreads - num_parts - 1, nthreads - num_parts); + if (width < SWITCH_RATIO) { + width = SWITCH_RATIO; + } n -= width; if (n < 0) width = width + n; - range_N[num_cpu_n + 1] = range_N[num_cpu_n] + width; - num_cpu_n ++; + range_N[num_parts + 1] = range_N[num_parts] + width; + num_parts ++; } - for (j = num_cpu_n; j < MAX_CPU_NUMBER; j++) { - range_N[j + 1] = range_N[num_cpu_n]; + for (j = num_parts; j < MAX_CPU_NUMBER; j++) { + range_N[j + 1] = range_N[num_parts]; } /* Clear synchronization flags */ @@ -683,7 +686,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO n = range_n[1] - range_n[0]; } - /* CPU partitions in m should have at least SWITCH_RATIO rows */ + /* Partitions in m should have at least SWITCH_RATIO rows */ if (m < 2 * SWITCH_RATIO) { nthreads_m = 1; } else { @@ -693,11 +696,11 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO } } - /* At most one CPU partition in n should have less than nthreads_m columns */ - if (n < nthreads_m) { + /* Partitions in n should have at most SWITCH_RATIO * nthreads_m columns */ + if (n < SWITCH_RATIO * nthreads_m) { nthreads_n = 1; } else { - nthreads_n = blas_quickdivide(n + nthreads_m - 1, nthreads_m); + nthreads_n = (n + SWITCH_RATIO * nthreads_m - 1) / (SWITCH_RATIO * nthreads_m); if (nthreads_m * nthreads_n > args -> nthreads) { nthreads_n = blas_quickdivide(args -> nthreads, nthreads_m); }