Reduce number of data partitions in n.

This commit is contained in:
Tim Moon 2017-10-04 12:37:49 -07:00
parent 9de52b489a
commit 30486a356c
1 changed file with 20 additions and 17 deletions

View File

@@ -525,7 +525,7 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
BLASLONG range_M_buffer[MAX_CPU_NUMBER + 2];
BLASLONG range_N_buffer[MAX_CPU_NUMBER + 2];
BLASLONG *range_M, *range_N;
BLASLONG num_cpu_m, num_cpu_n;
BLASLONG num_parts;
BLASLONG nthreads = args -> nthreads;
@@ -596,16 +596,16 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
}
/* Partition m into nthreads_m regions */
num_cpu_m = 0;
num_parts = 0;
while (m > 0){
width = blas_quickdivide(m + nthreads_m - num_cpu_m - 1, nthreads_m - num_cpu_m);
width = blas_quickdivide(m + nthreads_m - num_parts - 1, nthreads_m - num_parts);
m -= width;
if (m < 0) width = width + m;
range_M[num_cpu_m + 1] = range_M[num_cpu_m] + width;
num_cpu_m ++;
range_M[num_parts + 1] = range_M[num_parts] + width;
num_parts ++;
}
for (i = num_cpu_m; i < MAX_CPU_NUMBER; i++) {
range_M[i + 1] = range_M[num_cpu_m];
for (i = num_parts; i < MAX_CPU_NUMBER; i++) {
range_M[i + 1] = range_M[num_parts];
}
/* Initialize parameters for parallel execution */
@@ -637,16 +637,19 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
/* Partition (a step of) n into nthreads regions */
range_N[0] = js;
num_cpu_n = 0;
num_parts = 0;
while (n > 0){
width = blas_quickdivide(n + nthreads - num_cpu_n - 1, nthreads - num_cpu_n);
width = blas_quickdivide(n + nthreads - num_parts - 1, nthreads - num_parts);
if (width < SWITCH_RATIO) {
width = SWITCH_RATIO;
}
n -= width;
if (n < 0) width = width + n;
range_N[num_cpu_n + 1] = range_N[num_cpu_n] + width;
num_cpu_n ++;
range_N[num_parts + 1] = range_N[num_parts] + width;
num_parts ++;
}
for (j = num_cpu_n; j < MAX_CPU_NUMBER; j++) {
range_N[j + 1] = range_N[num_cpu_n];
for (j = num_parts; j < MAX_CPU_NUMBER; j++) {
range_N[j + 1] = range_N[num_parts];
}
/* Clear synchronization flags */
@@ -683,7 +686,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
n = range_n[1] - range_n[0];
}
/* CPU partitions in m should have at least SWITCH_RATIO rows */
/* Partitions in m should have at least SWITCH_RATIO rows */
if (m < 2 * SWITCH_RATIO) {
nthreads_m = 1;
} else {
@@ -693,11 +696,11 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
}
}
/* At most one CPU partition in n should have less than nthreads_m columns */
if (n < nthreads_m) {
/* Partitions in n should have at most SWITCH_RATIO * nthreads_m columns */
if (n < SWITCH_RATIO * nthreads_m) {
nthreads_n = 1;
} else {
nthreads_n = blas_quickdivide(n + nthreads_m - 1, nthreads_m);
nthreads_n = (n + SWITCH_RATIO * nthreads_m - 1) / (SWITCH_RATIO * nthreads_m);
if (nthreads_m * nthreads_n > args -> nthreads) {
nthreads_n = blas_quickdivide(args -> nthreads, nthreads_m);
}