Merge pull request #1846 from fenrus75/threadsize
gemm/dgemm: add a way for an arch kernel to specify preferred sizes
This commit is contained in:
commit
f1c02273cb
|
@ -48,6 +48,10 @@
|
||||||
#define SWITCH_RATIO 2
|
#define SWITCH_RATIO 2
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifndef GEMM_PREFERED_SIZE
|
||||||
|
#define GEMM_PREFERED_SIZE 1
|
||||||
|
#endif
|
||||||
|
|
||||||
//The array of job_t may overflow the stack.
|
//The array of job_t may overflow the stack.
|
||||||
//Instead, use malloc to alloc job_t.
|
//Instead, use malloc to alloc job_t.
|
||||||
#if MAX_CPU_NUMBER > BLAS3_MEM_ALLOC_THRESHOLD
|
#if MAX_CPU_NUMBER > BLAS3_MEM_ALLOC_THRESHOLD
|
||||||
|
@ -510,6 +514,16 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
 * Round a proposed partition width up to a multiple of `multiple`.
 *
 *   remainder - amount of work still left to distribute
 *   width     - proposed size of the next partition
 *   multiple  - preferred granularity of a partition
 *
 * Returns `width` unchanged when rounding is pointless or impossible:
 *   - `multiple` is 1 or less (nothing to align to; also avoids a
 *     division by zero if the granularity is ever configured as 0),
 *   - `multiple` is larger than the remaining work,
 *   - `width` does not exceed `multiple`.
 * Otherwise returns the smallest multiple of `multiple` that is >= `width`.
 *
 * The result may exceed `remainder`; the callers in this file clamp the
 * width after subtracting it from the remaining work.
 *
 * NOTE(review): the parameters are plain `int`; the call sites appear to
 * pass wider index values, which would be truncated for very large
 * dimensions -- confirm against the declarations in the caller.
 */
static int round_up(int remainder, int width, int multiple)
{
	/* A granularity of 1 is a no-op; 0 or negative values are invalid. */
	if (multiple <= 1)
		return width;

	if (multiple > remainder || width <= multiple)
		return width;

	/* Ceiling-divide, then scale back up to the next multiple. */
	width = (width + multiple - 1) / multiple;
	width = width * multiple;
	return width;
}
|
||||||
|
|
||||||
|
|
||||||
static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
|
static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
|
||||||
*range_n, FLOAT *sa, FLOAT *sb,
|
*range_n, FLOAT *sa, FLOAT *sb,
|
||||||
BLASLONG nthreads_m, BLASLONG nthreads_n) {
|
BLASLONG nthreads_m, BLASLONG nthreads_n) {
|
||||||
|
@ -601,9 +615,14 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
|
||||||
num_parts = 0;
|
num_parts = 0;
|
||||||
while (m > 0){
|
while (m > 0){
|
||||||
width = blas_quickdivide(m + nthreads_m - num_parts - 1, nthreads_m - num_parts);
|
width = blas_quickdivide(m + nthreads_m - num_parts - 1, nthreads_m - num_parts);
|
||||||
|
|
||||||
|
width = round_up(m, width, GEMM_PREFERED_SIZE);
|
||||||
|
|
||||||
m -= width;
|
m -= width;
|
||||||
|
|
||||||
if (m < 0) width = width + m;
|
if (m < 0) width = width + m;
|
||||||
range_M[num_parts + 1] = range_M[num_parts] + width;
|
range_M[num_parts + 1] = range_M[num_parts] + width;
|
||||||
|
|
||||||
num_parts ++;
|
num_parts ++;
|
||||||
}
|
}
|
||||||
for (i = num_parts; i < MAX_CPU_NUMBER; i++) {
|
for (i = num_parts; i < MAX_CPU_NUMBER; i++) {
|
||||||
|
@ -645,9 +664,12 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
|
||||||
if (width < SWITCH_RATIO) {
|
if (width < SWITCH_RATIO) {
|
||||||
width = SWITCH_RATIO;
|
width = SWITCH_RATIO;
|
||||||
}
|
}
|
||||||
|
width = round_up(n, width, GEMM_PREFERED_SIZE);
|
||||||
|
|
||||||
n -= width;
|
n -= width;
|
||||||
if (n < 0) width = width + n;
|
if (n < 0) width = width + n;
|
||||||
range_N[num_parts + 1] = range_N[num_parts] + width;
|
range_N[num_parts + 1] = range_N[num_parts] + width;
|
||||||
|
|
||||||
num_parts ++;
|
num_parts ++;
|
||||||
}
|
}
|
||||||
for (j = num_parts; j < MAX_CPU_NUMBER; j++) {
|
for (j = num_parts; j < MAX_CPU_NUMBER; j++) {
|
||||||
|
|
|
@ -55,6 +55,8 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta,
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (m == 0 || n == 0)
|
||||||
|
return 0;
|
||||||
|
|
||||||
c_offset = c;
|
c_offset = c;
|
||||||
|
|
||||||
|
@ -69,7 +71,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta,
|
||||||
|
|
||||||
i = m;
|
i = m;
|
||||||
|
|
||||||
while (i > 32) {
|
while (i >= 32) {
|
||||||
_mm512_storeu_pd(c_offset1, z_zero);
|
_mm512_storeu_pd(c_offset1, z_zero);
|
||||||
_mm512_storeu_pd(c_offset1 + 8, z_zero);
|
_mm512_storeu_pd(c_offset1 + 8, z_zero);
|
||||||
_mm512_storeu_pd(c_offset1 + 16, z_zero);
|
_mm512_storeu_pd(c_offset1 + 16, z_zero);
|
||||||
|
@ -77,7 +79,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta,
|
||||||
c_offset1 += 32;
|
c_offset1 += 32;
|
||||||
i -= 32;
|
i -= 32;
|
||||||
}
|
}
|
||||||
while (i > 8) {
|
while (i >= 8) {
|
||||||
_mm512_storeu_pd(c_offset1, z_zero);
|
_mm512_storeu_pd(c_offset1, z_zero);
|
||||||
c_offset1 += 8;
|
c_offset1 += 8;
|
||||||
i -= 8;
|
i -= 8;
|
||||||
|
|
|
@ -55,6 +55,8 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta,
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (n == 0 || m == 0)
|
||||||
|
return;
|
||||||
|
|
||||||
c_offset = c;
|
c_offset = c;
|
||||||
|
|
||||||
|
@ -71,13 +73,13 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta,
|
||||||
|
|
||||||
i = m;
|
i = m;
|
||||||
|
|
||||||
while (i > 32) {
|
while (i >= 32) {
|
||||||
_mm512_storeu_ps(c_offset1, z_zero);
|
_mm512_storeu_ps(c_offset1, z_zero);
|
||||||
_mm512_storeu_ps(c_offset1 + 16, z_zero);
|
_mm512_storeu_ps(c_offset1 + 16, z_zero);
|
||||||
c_offset1 += 32;
|
c_offset1 += 32;
|
||||||
i -= 32;
|
i -= 32;
|
||||||
}
|
}
|
||||||
while (i > 8) {
|
while (i >= 8) {
|
||||||
_mm256_storeu_ps(c_offset1, y_zero);
|
_mm256_storeu_ps(c_offset1, y_zero);
|
||||||
c_offset1 += 8;
|
c_offset1 += 8;
|
||||||
i -= 8;
|
i -= 8;
|
||||||
|
|
Loading…
Reference in New Issue