performance optimizations for sgemv_n

Author: wernsaar
Date:   2014-07-18 11:25:21 +02:00
parent 3c5732615d
commit c8a4a56177
2 changed files with 146 additions and 23 deletions
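
In outline, the change keeps the 512-column panels but doubles the row blocking of the inner kernel from 32 to 64 and recomputes a_ptr per panel instead of carrying it across iterations. A minimal C sketch of the reworked panel loop (a paraphrase of the diff below, not the committed code verbatim):

	/* sketch of the new blocking in CNAME(), paraphrased from the diff */
	m1 = m / 64;                              /* full 64-row blocks             */
	m2 = m % 64;                              /* leftover rows                  */

	for (j = 0; j < n1; j++)                  /* one 512-column panel per pass  */
	{
		copy_x(512, x_ptr, xbuffer, inc_x);   /* pack x unless inc_x == 1       */
		a_ptr = a + j * 512 * lda;            /* panel start, recomputed per j  */
		y_ptr = y;

		for (i = 0; i < m1; i++)              /* 64 rows at a time              */
		{
			sgemv_kernel_64(512, alpha, a_ptr, lda, xbuffer, ybuffer);
			add_y(64, ybuffer, y_ptr, inc_y);
			y_ptr += 64 * inc_y;
			a_ptr += 64;
		}
		if (m2 & 32) { /* 32-row kernel */ }
		if (m2 & 16) { /* 16-row kernel, then 8/4/2/1 as before */ }

		x_ptr += 512 * inc_x;                 /* advance x only after the panel */
	}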

Changed file 1 of 2: the sgemv_n driver

@@ -70,12 +70,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
 	n1 = n / 512 ;
 	n2 = n % 512 ;
-	m1 = m / 32;
-	m2 = m % 32;
-	x_ptr = x;
-	a_ptr = a;
+	m1 = m / 64;
+	m2 = m % 64;
 	y_ptr = y;
+	x_ptr = x;
 	for (j=0; j<n1; j++)
 	{
@@ -85,12 +84,19 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
 		else
 			copy_x(512,x_ptr,xbuffer,inc_x);
-		x_ptr += 512 * inc_x;
-		a_ptr += j * 512;
+		a_ptr = a + j * 512 * lda;
 		y_ptr = y;
 		for(i = 0; i<m1; i++ )
+		{
+			sgemv_kernel_64(512,alpha,a_ptr,lda,xbuffer,ybuffer);
+			add_y(64,ybuffer,y_ptr,inc_y);
+			y_ptr += 64 * inc_y;
+			a_ptr += 64;
+		}
+		if ( m2 & 32 )
 		{
 			sgemv_kernel_32(512,alpha,a_ptr,lda,xbuffer,ybuffer);
 			add_y(32,ybuffer,y_ptr,inc_y);
@@ -98,6 +104,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
 			a_ptr += 32;
 		}
 		if ( m2 & 16 )
 		{
 			sgemv_kernel_16(512,alpha,a_ptr,lda,xbuffer,ybuffer);
@@ -131,6 +138,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
 			sgemv_kernel_1(512,alpha,a_ptr,lda,xbuffer,ybuffer);
 			add_y(1,ybuffer,y_ptr,inc_y);
 		}
+		x_ptr += 512 * inc_x;
 	}
@@ -142,9 +150,19 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
 		else
 			copy_x(n2,x_ptr,xbuffer,inc_x);
+		a_ptr = a + n1 * 512 * lda;
 		y_ptr = y;
 		for(i = 0; i<m1; i++ )
+		{
+			sgemv_kernel_64(n2,alpha,a_ptr,lda,xbuffer,ybuffer);
+			add_y(64,ybuffer,y_ptr,inc_y);
+			y_ptr += 64 * inc_y;
+			a_ptr += 64;
+		}
+		if ( m2 & 32 )
 		{
 			sgemv_kernel_32(n2,alpha,a_ptr,lda,xbuffer,ybuffer);
 			add_y(32,ybuffer,y_ptr,inc_y);
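
For reference, this is a plain-C model of what the new 64-row microkernel is expected to compute, inferred from the assembly in the second file (an assumption for illustration, not code from the commit): it accumulates 64 consecutive rows of A*x for column-major A, scales by alpha once after the loop, and overwrites the temporary buffer that the driver later folds into y through add_y.

	/* plain-C model of sgemv_kernel_64 (assumed equivalent, not the committed kernel) */
	static void sgemv_kernel_64_ref(long n, float alpha, float *a, long lda,
	                                float *x, float *y)
	{
		float temp[64] = { 0.0f };

		for (long j = 0; j < n; j++)          /* walk the columns of the 64 x n panel */
		{
			float xj = x[j];
			for (int i = 0; i < 64; i++)
				temp[i] += a[j * lda + i] * xj;   /* column-major: rows are contiguous */
		}
		for (int i = 0; i < 64; i++)
			y[i] = alpha * temp[i];           /* scale once, store (caller adds into y) */
	}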

Changed file 2 of 2: the sgemv_n microkernels

@@ -25,12 +25,11 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
-static void sgemv_kernel_32( long n, float alpha, float *a, long lda, float *x, float *y)
+static void sgemv_kernel_64( long n, float alpha, float *a, long lda, float *x, float *y)
 {
-	float *pre = a + lda*4*2;
+	float *pre = a + lda*3;
 	__asm __volatile
 	(
@@ -44,38 +43,56 @@ static void sgemv_kernel_32( long n, float alpha, float *a, long lda, float *x,
 	"prefetcht0 (%%r8)\n\t"				// Prefetch
 	"prefetcht0 64(%%r8)\n\t"			// Prefetch
+	"vxorps %%ymm8 , %%ymm8 , %%ymm8 \n\t"		// set to zero
+	"vxorps %%ymm9 , %%ymm9 , %%ymm9 \n\t"		// set to zero
+	"vxorps %%ymm10, %%ymm10, %%ymm10\n\t"		// set to zero
+	"vxorps %%ymm11, %%ymm11, %%ymm11\n\t"		// set to zero
 	"vxorps %%ymm12, %%ymm12, %%ymm12\n\t"		// set to zero
 	"vxorps %%ymm13, %%ymm13, %%ymm13\n\t"		// set to zero
 	"vxorps %%ymm14, %%ymm14, %%ymm14\n\t"		// set to zero
 	"vxorps %%ymm15, %%ymm15, %%ymm15\n\t"		// set to zero
+	".align 16				 \n\t"
 	".L01LOOP%=:				 \n\t"
 	"vbroadcastss (%%rdi), %%ymm0		 \n\t"	// load values of c
-	"addq		$4 , %%rdi		 \n\t"	// increment pointer of c
+	"nop					 \n\t"
 	"leaq	(%%r8 , %%rcx, 4), %%r8		 \n\t"	// add lda to pointer for prefetch
 	"prefetcht0 (%%r8)\n\t"				// Prefetch
+	"vfmaddps %%ymm8 , 0*4(%%rsi), %%ymm0, %%ymm8 \n\t"	// multiply a and c and add to temp
 	"prefetcht0 64(%%r8)\n\t"			// Prefetch
+	"vfmaddps %%ymm9 , 8*4(%%rsi), %%ymm0, %%ymm9 \n\t"	// multiply a and c and add to temp
+	"prefetcht0 128(%%r8)\n\t"			// Prefetch
+	"vfmaddps %%ymm10, 16*4(%%rsi), %%ymm0, %%ymm10\n\t"	// multiply a and c and add to temp
+	"vfmaddps %%ymm11, 24*4(%%rsi), %%ymm0, %%ymm11\n\t"	// multiply a and c and add to temp
+	"prefetcht0 192(%%r8)\n\t"			// Prefetch
+	"vfmaddps %%ymm12, 32*4(%%rsi), %%ymm0, %%ymm12\n\t"	// multiply a and c and add to temp
+	"vfmaddps %%ymm13, 40*4(%%rsi), %%ymm0, %%ymm13\n\t"	// multiply a and c and add to temp
+	"vfmaddps %%ymm14, 48*4(%%rsi), %%ymm0, %%ymm14\n\t"	// multiply a and c and add to temp
+	"vfmaddps %%ymm15, 56*4(%%rsi), %%ymm0, %%ymm15\n\t"	// multiply a and c and add to temp
-	"vfmaddps %%ymm12, 0*4(%%rsi), %%ymm0, %%ymm12\n\t"	// multiply a and c and add to temp
-	"vfmaddps %%ymm13, 8*4(%%rsi), %%ymm0, %%ymm13\n\t"	// multiply a and c and add to temp
-	"vfmaddps %%ymm14, 16*4(%%rsi), %%ymm0, %%ymm14\n\t"	// multiply a and c and add to temp
-	"vfmaddps %%ymm15, 24*4(%%rsi), %%ymm0, %%ymm15\n\t"	// multiply a and c and add to temp
+	"addq		$4 , %%rdi		 \n\t"	// increment pointer of c
 	"leaq	(%%rsi, %%rcx, 4), %%rsi	 \n\t"	// add lda to pointer of a
 	"dec	%%rax				 \n\t"	// n = n -1
 	"jnz	.L01LOOP%=			 \n\t"
+	"vmulps	%%ymm8 , %%ymm1, %%ymm8 \n\t"		// scale by alpha
+	"vmulps	%%ymm9 , %%ymm1, %%ymm9 \n\t"		// scale by alpha
+	"vmulps	%%ymm10, %%ymm1, %%ymm10\n\t"		// scale by alpha
+	"vmulps	%%ymm11, %%ymm1, %%ymm11\n\t"		// scale by alpha
 	"vmulps	%%ymm12, %%ymm1, %%ymm12\n\t"		// scale by alpha
 	"vmulps	%%ymm13, %%ymm1, %%ymm13\n\t"		// scale by alpha
 	"vmulps	%%ymm14, %%ymm1, %%ymm14\n\t"		// scale by alpha
 	"vmulps	%%ymm15, %%ymm1, %%ymm15\n\t"		// scale by alpha
-	"vmovups	%%ymm12, (%%rdx)	\n\t"	// store temp -> y
-	"vmovups	%%ymm13, 8*4(%%rdx)	\n\t"	// store temp -> y
-	"vmovups	%%ymm14, 16*4(%%rdx)	\n\t"	// store temp -> y
-	"vmovups	%%ymm15, 24*4(%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%ymm8 , (%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%ymm9 , 8*4(%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%ymm10, 16*4(%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%ymm11, 24*4(%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%ymm12, 32*4(%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%ymm13, 40*4(%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%ymm14, 48*4(%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%ymm15, 56*4(%%rdx)	\n\t"	// store temp -> y
 	:
 	:
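
The loop body above keeps eight 256-bit accumulators (ymm8 through ymm15) live and issues one vfmaddps per 8 rows, so every pass over a column updates 64 floats of the result. A hedged intrinsics sketch of one such column step, assuming FMA4 hardware (vfmaddps is the four-operand FMA4 form, exposed as _mm256_macc_ps) and AVX unaligned loads:

	#include <x86intrin.h>   /* AVX + FMA4 intrinsics; build with -mavx -mfma4 */

	/* sketch of one column step of the 64-row loop above (assumption, not the commit) */
	static inline void sgemv64_column_step(const float *a_col, float xj, __m256 acc[8])
	{
		__m256 x = _mm256_broadcast_ss(&xj);              /* vbroadcastss (%%rdi), %%ymm0 */
		for (int k = 0; k < 8; k++)                       /* ymm8..ymm15, 8 floats each   */
			acc[k] = _mm256_macc_ps(_mm256_loadu_ps(a_col + 8 * k), x, acc[k]);
	}

The real kernel additionally interleaves prefetcht0 hints and the pointer updates between the multiply-adds, which the sketch omits.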
@@ -88,6 +105,94 @@ static void sgemv_kernel_32( long n, float alpha, float *a, long lda, float *x,
 	  "m" (pre)	// 6
 	: "rax", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11",
 	  "xmm0" , "xmm1",
+	  "xmm8", "xmm9", "xmm10", "xmm11",
+	  "xmm12", "xmm13", "xmm14", "xmm15",
+	  "memory"
+	);
+}
+
+static void sgemv_kernel_32( long n, float alpha, float *a, long lda, float *x, float *y)
+{
+	float *pre = a + lda*3;
+	__asm __volatile
+	(
+	"movq		%0, %%rax\n\t"		// n -> rax
+	"vbroadcastss	%1, %%xmm1\n\t"		// alpha -> xmm1
+	"movq		%2, %%rsi\n\t"		// address of a -> rsi
+	"movq		%3, %%rcx\n\t"		// value of lda -> rcx
+	"movq		%4, %%rdi\n\t"		// address of x -> rdi
+	"movq		%5, %%rdx\n\t"		// address of y -> rdx
+	"movq		%6, %%r8\n\t"		// address for prefetch
+	"prefetcht0	(%%r8)\n\t"		// Prefetch
+	"prefetcht0	64(%%r8)\n\t"		// Prefetch
+	"vxorps %%xmm8 , %%xmm8 , %%xmm8 \n\t"	// set to zero
+	"vxorps %%xmm9 , %%xmm9 , %%xmm9 \n\t"	// set to zero
+	"vxorps %%xmm10, %%xmm10, %%xmm10\n\t"	// set to zero
+	"vxorps %%xmm11, %%xmm11, %%xmm11\n\t"	// set to zero
+	"vxorps %%xmm12, %%xmm12, %%xmm12\n\t"	// set to zero
+	"vxorps %%xmm13, %%xmm13, %%xmm13\n\t"	// set to zero
+	"vxorps %%xmm14, %%xmm14, %%xmm14\n\t"	// set to zero
+	"vxorps %%xmm15, %%xmm15, %%xmm15\n\t"	// set to zero
+	".align 16				 \n\t"
+	".L01LOOP%=:				 \n\t"
+	"vbroadcastss (%%rdi), %%xmm0		 \n\t"	// load values of c
+	"nop					 \n\t"
+	"leaq	(%%r8 , %%rcx, 4), %%r8		 \n\t"	// add lda to pointer for prefetch
+	"prefetcht0 (%%r8)\n\t"				// Prefetch
+	"vfmaddps %%xmm8 , 0*4(%%rsi), %%xmm0, %%xmm8 \n\t"	// multiply a and c and add to temp
+	"prefetcht0 64(%%r8)\n\t"			// Prefetch
+	"vfmaddps %%xmm9 , 4*4(%%rsi), %%xmm0, %%xmm9 \n\t"	// multiply a and c and add to temp
+	"vfmaddps %%xmm10, 8*4(%%rsi), %%xmm0, %%xmm10\n\t"	// multiply a and c and add to temp
+	"vfmaddps %%xmm11, 12*4(%%rsi), %%xmm0, %%xmm11\n\t"	// multiply a and c and add to temp
+	"vfmaddps %%xmm12, 16*4(%%rsi), %%xmm0, %%xmm12\n\t"	// multiply a and c and add to temp
+	"vfmaddps %%xmm13, 20*4(%%rsi), %%xmm0, %%xmm13\n\t"	// multiply a and c and add to temp
+	"vfmaddps %%xmm14, 24*4(%%rsi), %%xmm0, %%xmm14\n\t"	// multiply a and c and add to temp
+	"vfmaddps %%xmm15, 28*4(%%rsi), %%xmm0, %%xmm15\n\t"	// multiply a and c and add to temp
+	"addq		$4 , %%rdi		 \n\t"	// increment pointer of c
+	"leaq	(%%rsi, %%rcx, 4), %%rsi	 \n\t"	// add lda to pointer of a
+	"dec	%%rax				 \n\t"	// n = n -1
+	"jnz	.L01LOOP%=			 \n\t"
+	"vmulps	%%xmm8 , %%xmm1, %%xmm8 \n\t"		// scale by alpha
+	"vmulps	%%xmm9 , %%xmm1, %%xmm9 \n\t"		// scale by alpha
+	"vmulps	%%xmm10, %%xmm1, %%xmm10\n\t"		// scale by alpha
+	"vmulps	%%xmm11, %%xmm1, %%xmm11\n\t"		// scale by alpha
+	"vmulps	%%xmm12, %%xmm1, %%xmm12\n\t"		// scale by alpha
+	"vmulps	%%xmm13, %%xmm1, %%xmm13\n\t"		// scale by alpha
+	"vmulps	%%xmm14, %%xmm1, %%xmm14\n\t"		// scale by alpha
+	"vmulps	%%xmm15, %%xmm1, %%xmm15\n\t"		// scale by alpha
+	"vmovups	%%xmm8 , (%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%xmm9 , 4*4(%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%xmm10, 8*4(%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%xmm11, 12*4(%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%xmm12, 16*4(%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%xmm13, 20*4(%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%xmm14, 24*4(%%rdx)	\n\t"	// store temp -> y
+	"vmovups	%%xmm15, 28*4(%%rdx)	\n\t"	// store temp -> y
+	:
+	:
+	  "m" (n),	// 0
+	  "m" (alpha),	// 1
+	  "m" (a),	// 2
+	  "m" (lda),	// 3
+	  "m" (x),	// 4
+	  "m" (y),	// 5
+	  "m" (pre)	// 6
+	: "rax", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11",
+	  "xmm0" , "xmm1",
+	  "xmm8", "xmm9", "xmm10", "xmm11",
 	  "xmm12", "xmm13", "xmm14", "xmm15",
 	  "memory"
 	);
@@ -97,7 +202,7 @@ static void sgemv_kernel_32( long n, float alpha, float *a, long lda, float *x,
 static void sgemv_kernel_16( long n, float alpha, float *a, long lda, float *x, float *y)
 {
-	float *pre = a + lda*4*3;
+	float *pre = a + lda*1;
 	__asm __volatile
 	(
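
Both kernels also shorten the software-prefetch lead: pre now starts 3 columns ahead of a in sgemv_kernel_64 and the new sgemv_kernel_32 (previously lda*4*2, i.e. 8 columns) and only 1 column ahead in sgemv_kernel_16 (previously 12), with r8 advancing by one column per loop iteration. A hedged portable-C sketch of that pattern using the GCC/Clang __builtin_prefetch builtin (an illustration, not the committed code):

	/* hedged sketch of the kernels' prefetch pattern, not the committed code */
	static void prefetch_walk(const float *a, long lda, long n)
	{
		const float *pre = a + lda * 3;           /* start ~3 columns ahead, as kernel_64 does */

		for (long j = 0; j < n; j++)
		{
			__builtin_prefetch(pre,      0, 3);   /* prefetcht0 on the column start   */
			__builtin_prefetch(pre + 16, 0, 3);   /* ...and on the next cache line    */
			/* multiply/accumulate work on column j of A goes here */
			pre += lda;                           /* leaq (%%r8, %%rcx, 4), %%r8      */
		}
	}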