From 8e5a1083bbfca6d93e3d35c1490311cbda675761 Mon Sep 17 00:00:00 2001 From: Zhang Xianyi Date: Fri, 8 May 2015 05:33:17 +0800 Subject: [PATCH] Refs #532. Improve gemv parallelism for the small m and large n case. Split the matrix and do a reduction. --- driver/level2/gemv_thread.c | 84 ++++++++++++++++++++++++++++++++++++- 1 file changed, 83 insertions(+), 1 deletion(-) diff --git a/driver/level2/gemv_thread.c b/driver/level2/gemv_thread.c index ddd475367..061454848 100644 --- a/driver/level2/gemv_thread.c +++ b/driver/level2/gemv_thread.c @@ -62,6 +62,11 @@ #endif #endif +#ifndef TRANSA +#define Y_DUMMY_NUM 1024 +static FLOAT y_dummy[Y_DUMMY_NUM]; +#endif + static int gemv_kernel(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *dummy1, FLOAT *buffer, BLASLONG pos){ FLOAT *a, *x, *y; @@ -99,10 +104,15 @@ static int gemv_kernel(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, F a += n_from * lda * COMPSIZE; #ifdef TRANSA y += n_from * incy * COMPSIZE; +#else + //for split matrix row (n) direction and vector x of gemv_n + x += n_from * incx * COMPSIZE; + //store partial result for every thread + y += (m_to - m_from) * 1 * COMPSIZE * pos; #endif } - // fprintf(stderr, "M_From = %d M_To = %d N_From = %d N_To = %d\n", m_from, m_to, n_from, n_to); + //fprintf(stderr, "M_From = %d M_To = %d N_From = %d N_To = %d POS=%d\n", m_from, m_to, n_from, n_to, pos); GEMV(m_to - m_from, n_to - n_from, 0, *((FLOAT *)args -> alpha + 0), @@ -126,6 +136,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *alpha, FLOAT *a, BLASLONG lda, FLOAT *x BLASLONG width, i, num_cpu; +#ifndef TRANSA + int split_x=0; +#endif + #ifdef SMP #ifndef COMPLEX #ifdef XDOUBLE @@ -198,6 +212,58 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *alpha, FLOAT *a, BLASLONG lda, FLOAT *x i -= width; } +#ifndef TRANSA + //try to split matrix on row direction and x. + //Then, reduction. + if (num_cpu < nthreads) { + + //too small to split or bigger than the y_dummy buffer. 
+ double MN = (double) m * (double) n; + if ( MN <= (24.0 * 24.0 * (double) (GEMM_MULTITHREAD_THRESHOLD*GEMM_MULTITHREAD_THRESHOLD)) + || m*COMPSIZE*nthreads > Y_DUMMY_NUM) + goto Outer; + + num_cpu = 0; + range[0] = 0; + + memset(y_dummy, 0, sizeof(FLOAT) * m * COMPSIZE * nthreads); + + args.ldc = 1; + args.c = (void *)y_dummy; + + //split on row (n) and x + i=n; + split_x=1; + while (i > 0){ + + width = blas_quickdivide(i + nthreads - num_cpu - 1, nthreads - num_cpu); + if (width < 4) width = 4; + if (i < width) width = i; + + range[num_cpu + 1] = range[num_cpu] + width; + + queue[num_cpu].mode = mode; + queue[num_cpu].routine = gemv_kernel; + queue[num_cpu].args = &args; + + queue[num_cpu].position = num_cpu; + + queue[num_cpu].range_m = NULL; + queue[num_cpu].range_n = &range[num_cpu]; + + queue[num_cpu].sa = NULL; + queue[num_cpu].sb = NULL; + queue[num_cpu].next = &queue[num_cpu + 1]; + + num_cpu ++; + i -= width; + } + + } + + Outer: +#endif + if (num_cpu) { queue[0].sa = NULL; queue[0].sb = buffer; @@ -206,5 +272,21 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *alpha, FLOAT *a, BLASLONG lda, FLOAT *x exec_blas(num_cpu, queue); } +#ifndef TRANSA + if(split_x==1){ + //reduction + for(i=0; i