From c4e5ba1bfe8c7c4e263d5c14f4034e657347b591 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Wed, 2 Aug 2017 00:37:58 +0200 Subject: [PATCH] Make sure that range_n of last thread never exceeds the actual data size when splitting the workload --- driver/level2/gbmv_thread.c | 2 ++ driver/level2/sbmv_thread.c | 3 +++ driver/level2/spmv_thread.c | 2 ++ driver/level2/symv_thread.c | 4 +++- driver/level2/tbmv_thread.c | 3 +++ driver/level2/tpmv_thread.c | 4 +++- driver/level2/trmv_thread.c | 4 +++- 7 files changed, 19 insertions(+), 3 deletions(-) diff --git a/driver/level2/gbmv_thread.c b/driver/level2/gbmv_thread.c index e86b565f8..9d374676e 100644 --- a/driver/level2/gbmv_thread.c +++ b/driver/level2/gbmv_thread.c @@ -230,8 +230,10 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG ku, BLASLONG kl, FLOAT *alpha, FLOAT #ifndef TRANSA range_m[num_cpu] = num_cpu * ((m + 15) & ~15); + if (range_m[num_cpu] > m) range_m[num_cpu] = m; #else range_m[num_cpu] = num_cpu * ((n + 15) & ~15); + if (range_m[num_cpu] > n) range_m[num_cpu] = n; #endif queue[num_cpu].mode = mode; diff --git a/driver/level2/sbmv_thread.c b/driver/level2/sbmv_thread.c index 5718c0ec9..ce841ee0e 100644 --- a/driver/level2/sbmv_thread.c +++ b/driver/level2/sbmv_thread.c @@ -246,6 +246,7 @@ int CNAME(BLASLONG n, BLASLONG k, FLOAT *alpha, FLOAT *a, BLASLONG lda, FLOAT *x range_m[MAX_CPU_NUMBER - num_cpu - 1] = range_m[MAX_CPU_NUMBER - num_cpu] - width; range_n[num_cpu] = num_cpu * (((n + 15) & ~15) + 16); + if (range_n[num_cpu] > n) range_n[num_cpu] = n; queue[num_cpu].mode = mode; queue[num_cpu].routine = sbmv_kernel; @@ -285,6 +286,7 @@ int CNAME(BLASLONG n, BLASLONG k, FLOAT *alpha, FLOAT *a, BLASLONG lda, FLOAT *x range_m[num_cpu + 1] = range_m[num_cpu] + width; range_n[num_cpu] = num_cpu * (((n + 15) & ~15) + 16); + if (range_n[num_cpu] > n) range_n[num_cpu] = n; queue[num_cpu].mode = mode; queue[num_cpu].routine = sbmv_kernel; @@ -316,6 +318,7 @@ int CNAME(BLASLONG n, BLASLONG k, FLOAT *alpha, FLOAT *a, BLASLONG lda, FLOAT *x range_m[num_cpu + 1] = range_m[num_cpu] + width; range_n[num_cpu] = num_cpu * ((n + 15) & ~15); + if (range_n[num_cpu] > n) range_n[num_cpu] = n; queue[num_cpu].mode = mode; queue[num_cpu].routine = sbmv_kernel; diff --git a/driver/level2/spmv_thread.c b/driver/level2/spmv_thread.c index 035300841..0b4087430 100644 --- a/driver/level2/spmv_thread.c +++ b/driver/level2/spmv_thread.c @@ -246,6 +246,7 @@ int CNAME(BLASLONG m, FLOAT *alpha, FLOAT *a, FLOAT *x, BLASLONG incx, FLOAT *y, range_m[MAX_CPU_NUMBER - num_cpu - 1] = range_m[MAX_CPU_NUMBER - num_cpu] - width; range_n[num_cpu] = num_cpu * (((m + 15) & ~15) + 16); + if (range_n[num_cpu] > m) range_n[num_cpu] = m; queue[num_cpu].mode = mode; queue[num_cpu].routine = spmv_kernel; @@ -285,6 +286,7 @@ int CNAME(BLASLONG m, FLOAT *alpha, FLOAT *a, FLOAT *x, BLASLONG incx, FLOAT *y, range_m[num_cpu + 1] = range_m[num_cpu] + width; range_n[num_cpu] = num_cpu * (((m + 15) & ~15) + 16); + if (range_n[num_cpu] > m) range_n[num_cpu] = m; queue[num_cpu].mode = mode; queue[num_cpu].routine = spmv_kernel; diff --git a/driver/level2/symv_thread.c b/driver/level2/symv_thread.c index 6580178f1..8d4cd249c 100644 --- a/driver/level2/symv_thread.c +++ b/driver/level2/symv_thread.c @@ -177,7 +177,8 @@ int CNAME(BLASLONG m, FLOAT *alpha, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG i range_m[num_cpu + 1] = range_m[num_cpu] + width; range_n[num_cpu] = num_cpu * (((m + 15) & ~15) + 16); - + if (range_n[num_cpu] > m) range_n[num_cpu] = m; + queue[MAX_CPU_NUMBER - num_cpu - 1].mode = mode; queue[MAX_CPU_NUMBER - num_cpu - 1].routine = symv_kernel; queue[MAX_CPU_NUMBER - num_cpu - 1].args = &args; @@ -225,6 +226,7 @@ int CNAME(BLASLONG m, FLOAT *alpha, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG i range_m[num_cpu + 1] = range_m[num_cpu] + width; range_n[num_cpu] = num_cpu * (((m + 15) & ~15) + 16); + if (range_n[num_cpu] > m) range_n[num_cpu] = m; queue[num_cpu].mode = mode; queue[num_cpu].routine = symv_kernel; diff --git a/driver/level2/tbmv_thread.c b/driver/level2/tbmv_thread.c index 226a922e9..aaf4958e2 100644 --- a/driver/level2/tbmv_thread.c +++ b/driver/level2/tbmv_thread.c @@ -288,6 +288,7 @@ int CNAME(BLASLONG n, BLASLONG k, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc range_m[MAX_CPU_NUMBER - num_cpu - 1] = range_m[MAX_CPU_NUMBER - num_cpu] - width; range_n[num_cpu] = num_cpu * (((n + 15) & ~15) + 16); + if (range_n[num_cpu] > n) range_n[num_cpu] = n; queue[num_cpu].mode = mode; queue[num_cpu].routine = trmv_kernel; @@ -327,6 +328,7 @@ int CNAME(BLASLONG n, BLASLONG k, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc range_m[num_cpu + 1] = range_m[num_cpu] + width; range_n[num_cpu] = num_cpu * (((n + 15) & ~15) + 16); + if (range_n[num_cpu] > n) range_n[num_cpu] = n; queue[num_cpu].mode = mode; queue[num_cpu].routine = trmv_kernel; @@ -356,6 +358,7 @@ int CNAME(BLASLONG n, BLASLONG k, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc range_m[num_cpu + 1] = range_m[num_cpu] + width; range_n[num_cpu] = num_cpu * (((n + 15) & ~15) + 16); + if (range_n[num_cpu] > n) range_n[num_cpu] = n; queue[num_cpu].mode = mode; queue[num_cpu].routine = trmv_kernel; diff --git a/driver/level2/tpmv_thread.c b/driver/level2/tpmv_thread.c index c91b52775..79438ba29 100644 --- a/driver/level2/tpmv_thread.c +++ b/driver/level2/tpmv_thread.c @@ -307,7 +307,8 @@ int CNAME(BLASLONG m, FLOAT *a, FLOAT *x, BLASLONG incx, FLOAT *buffer, int nthr range_m[MAX_CPU_NUMBER - num_cpu - 1] = range_m[MAX_CPU_NUMBER - num_cpu] - width; range_n[num_cpu] = num_cpu * (((m + 15) & ~15) + 16); - + if (range_n[num_cpu] > m) range_n[num_cpu] = m; + queue[num_cpu].mode = mode; queue[num_cpu].routine = tpmv_kernel; queue[num_cpu].args = &args; @@ -346,6 +347,7 @@ int CNAME(BLASLONG m, FLOAT *a, FLOAT *x, BLASLONG incx, FLOAT *buffer, int nthr range_m[num_cpu + 1] = range_m[num_cpu] + width; range_n[num_cpu] = num_cpu * (((m + 15) & ~15) + 16); + if (range_n[num_cpu] > m) range_n[num_cpu] = m; queue[num_cpu].mode = mode; queue[num_cpu].routine = tpmv_kernel; diff --git a/driver/level2/trmv_thread.c b/driver/level2/trmv_thread.c index 0a155366c..8b931a0e8 100644 --- a/driver/level2/trmv_thread.c +++ b/driver/level2/trmv_thread.c @@ -346,7 +346,8 @@ int CNAME(BLASLONG m, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG incx, FLOAT *bu range_m[MAX_CPU_NUMBER - num_cpu - 1] = range_m[MAX_CPU_NUMBER - num_cpu] - width; range_n[num_cpu] = num_cpu * (((m + 15) & ~15) + 16); - + if (range_n[num_cpu] > m) range_n[num_cpu] = m; + queue[num_cpu].mode = mode; queue[num_cpu].routine = trmv_kernel; queue[num_cpu].args = &args; @@ -385,6 +386,7 @@ int CNAME(BLASLONG m, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG incx, FLOAT *bu range_m[num_cpu + 1] = range_m[num_cpu] + width; range_n[num_cpu] = num_cpu * (((m + 15) & ~15) + 16); + if (range_n[num_cpu] > m) range_n[num_cpu] = m; queue[num_cpu].mode = mode; queue[num_cpu].routine = trmv_kernel;