optimized cgemv_n_4.c

This commit is contained in:
wernsaar 2014-09-10 19:26:14 +02:00
parent be95700b30
commit f98e1244c4
1 changed file with 113 additions and 25 deletions

View File

@ -30,7 +30,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "common.h"
#if defined(HASWELL)
#include "cgemv_n_microk_haswell-2.c"
#include "cgemv_n_microk_haswell-4.c"
#endif
@ -73,6 +73,41 @@ static void cgemv_kernel_4x4(BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y)
#endif
#ifndef HAVE_KERNEL_4x2
/*
 * Generic C kernel covering two matrix columns at once.
 *
 * n   - number of complex rows to process (2*n FLOAT values per column)
 * ap  - ap[0] and ap[1] point to the two columns, stored as interleaved
 *       (real, imaginary) pairs
 * x   - two complex multipliers: x[0],x[1] pair with column 0 and
 *       x[2],x[3] pair with column 1
 * y   - accumulator of n interleaved complex values, updated in place
 *
 * When CONJ and XCONJ are both defined or both undefined the plain complex
 * product a*x is accumulated; otherwise the imaginary part of the column
 * entry enters with the opposite sign (i.e. conj(a)*x).
 */
static void cgemv_kernel_4x2(BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y)
{
	FLOAT *col0 = ap[0];
	FLOAT *col1 = ap[1];
	BLASLONG re = 0;

	while ( re < 2*n )
	{
		/* re indexes the real part, im the imaginary part of the same entry */
		BLASLONG im = re + 1;

#if defined(CONJ) == defined(XCONJ)
		/* first column contribution */
		y[re] += col0[re]*x[0] - col0[im] * x[1];
		y[im] += col0[re]*x[1] + col0[im] * x[0];
		/* second column contribution */
		y[re] += col1[re]*x[2] - col1[im] * x[3];
		y[im] += col1[re]*x[3] + col1[im] * x[2];
#else
		/* first column contribution */
		y[re] += col0[re]*x[0] + col0[im] * x[1];
		y[im] += col0[re]*x[1] - col0[im] * x[0];
		/* second column contribution */
		y[re] += col1[re]*x[2] + col1[im] * x[3];
		y[im] += col1[re]*x[3] - col1[im] * x[2];
#endif

		re += 2;
	}
}
#endif
#ifndef HAVE_KERNEL_4x1
static void cgemv_kernel_4x1(BLASLONG n, FLOAT *ap, FLOAT *x, FLOAT *y)
{
BLASLONG i;
@ -93,21 +128,18 @@ static void cgemv_kernel_4x1(BLASLONG n, FLOAT *ap, FLOAT *x, FLOAT *y)
}
/*
 * Clears the first 2*n FLOAT entries of dest, i.e. n values stored as
 * interleaved (real, imaginary) pairs.
 */
static void zero_y(BLASLONG n, FLOAT *dest)
{
	BLASLONG k;
	BLASLONG count = 2*n;

	for ( k = 0; k < count; k++ )
	{
		dest[k] = 0.0;
	}
}
#endif
/*
 * add_y: scales the n interleaved complex values in src by the complex
 * scalar (alpha_r, alpha_i) and adds the result to dest.
 *
 * n        - number of complex elements to process
 * src      - source values, contiguous (real, imaginary) pairs
 * dest     - destination, updated in place
 * inc_dest - distance between consecutive destination elements, counted in
 *            FLOAT units; the value 2 selects the contiguous fast path
 * alpha_r  - real part of the scale factor
 * alpha_i  - imaginary part of the scale factor
 *
 * Without XCONJ the contiguous path adds alpha * src; with XCONJ it adds
 * alpha * conj(src).
 *
 * The prototype carries the GCC noinline attribute, so the compiler keeps
 * this as an out-of-line function.
 */
static void add_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest,FLOAT alpha_r, FLOAT alpha_i) __attribute__ ((noinline));
static void add_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest,FLOAT alpha_r, FLOAT alpha_i)
{
BLASLONG i;
/* Strided destination: one element per iteration, src advances by one
   complex pair and dest by inc_dest.
   NOTE(review): part of this loop body is not visible in this view;
   presumably it mirrors the scaling done in the contiguous path - confirm. */
if ( inc_dest != 2 )
{
FLOAT temp_r;
FLOAT temp_i;
for ( i=0; i<n; i++ )
@ -126,6 +158,53 @@ static void add_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest,FLOAT a
src+=2;
dest += inc_dest;
}
return;
}
/* Contiguous destination: manually unrolled, four complex elements
   (eight FLOAT values) per iteration. */
FLOAT temp_r0;
FLOAT temp_i0;
FLOAT temp_r1;
FLOAT temp_i1;
FLOAT temp_r2;
FLOAT temp_i2;
FLOAT temp_r3;
FLOAT temp_i3;
/* NOTE(review): every iteration touches four elements unconditionally, so an
   n that is not a multiple of 4 reads and writes beyond the first n elements.
   Presumably callers only pass multiples of 4 - confirm against the caller. */
for ( i=0; i<n; i+=4 )
{
#if !defined(XCONJ)
/* temp = alpha * src */
temp_r0 = alpha_r * src[0] - alpha_i * src[1];
temp_i0 = alpha_r * src[1] + alpha_i * src[0];
temp_r1 = alpha_r * src[2] - alpha_i * src[3];
temp_i1 = alpha_r * src[3] + alpha_i * src[2];
temp_r2 = alpha_r * src[4] - alpha_i * src[5];
temp_i2 = alpha_r * src[5] + alpha_i * src[4];
temp_r3 = alpha_r * src[6] - alpha_i * src[7];
temp_i3 = alpha_r * src[7] + alpha_i * src[6];
#else
/* temp = alpha * conj(src) */
temp_r0 = alpha_r * src[0] + alpha_i * src[1];
temp_i0 = -alpha_r * src[1] + alpha_i * src[0];
temp_r1 = alpha_r * src[2] + alpha_i * src[3];
temp_i1 = -alpha_r * src[3] + alpha_i * src[2];
temp_r2 = alpha_r * src[4] + alpha_i * src[5];
temp_i2 = -alpha_r * src[5] + alpha_i * src[4];
temp_r3 = alpha_r * src[6] + alpha_i * src[7];
temp_i3 = -alpha_r * src[7] + alpha_i * src[6];
#endif
/* accumulate the four scaled values into the destination */
dest[0] += temp_r0;
dest[1] += temp_i0;
dest[2] += temp_r1;
dest[3] += temp_i1;
dest[4] += temp_r2;
dest[5] += temp_i2;
dest[6] += temp_r3;
dest[7] += temp_i3;
/* step both pointers past the four elements just handled */
src += 8;
dest += 8;
}
return;
}
int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha_r,FLOAT alpha_i, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *buffer)
@ -186,7 +265,8 @@ printf("%s %d %d %.16f %.16f %d %d %d\n","zgemv_n",m,n,alpha_r,alpha_i,lda,inc_x
ap[2] = ap[1] + lda;
ap[3] = ap[2] + lda;
x_ptr = x;
zero_y(NB,ybuffer);
//zero_y(NB,ybuffer);
memset(ybuffer,0,NB*8);
if ( inc_x == 2 )
{
@ -202,7 +282,15 @@ printf("%s %d %d %.16f %.16f %d %d %d\n","zgemv_n",m,n,alpha_r,alpha_i,lda,inc_x
x_ptr += 8;
}
for( i = 0; i < n2 ; i++)
if ( n2 & 2 )
{
cgemv_kernel_4x2(NB,ap,x_ptr,ybuffer);
x_ptr += 4;
a_ptr += 2 * lda;
}
if ( n2 & 1 )
{
cgemv_kernel_4x1(NB,a_ptr,x_ptr,ybuffer);
x_ptr += 2;