optimized sgemv_t_4 kernel for very small sizes

This commit is contained in:
wernsaar 2014-09-06 19:41:57 +02:00
parent 3a7ab47ee9
commit c8eaf3ae2d
1 changed files with 84 additions and 14 deletions

View File

@ -423,14 +423,37 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
FLOAT *aj = a_ptr; FLOAT *aj = a_ptr;
y_ptr = y; y_ptr = y;
for ( j=0; j<n; j++ )
if ( lda == 3 && inc_y == 1 )
{ {
*y_ptr += *aj * xtemp0 + *(aj+1) * xtemp1 + *(aj+2) *xtemp2;
y_ptr += inc_y; for ( j=0; j< ( n & -4) ; j+=4 )
aj += lda; {
y_ptr[j] += aj[0] * xtemp0 + aj[1] * xtemp1 + aj[2] * xtemp2;
y_ptr[j+1] += aj[3] * xtemp0 + aj[4] * xtemp1 + aj[5] * xtemp2;
y_ptr[j+2] += aj[6] * xtemp0 + aj[7] * xtemp1 + aj[8] * xtemp2;
y_ptr[j+3] += aj[9] * xtemp0 + aj[10] * xtemp1 + aj[11] * xtemp2;
aj += 12;
}
for ( ; j<n; j++ )
{
y_ptr[j] += aj[0] * xtemp0 + aj[1] * xtemp1 + aj[2] * xtemp2;
aj += 3;
}
}
else
{
for ( j=0; j<n; j++ )
{
*y_ptr += *aj * xtemp0 + *(aj+1) * xtemp1 + *(aj+2) * xtemp2;
y_ptr += inc_y;
aj += lda;
}
} }
return(0); return(0);
} }
if ( m3 == 2 ) if ( m3 == 2 )
@ -441,11 +464,38 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
FLOAT *aj = a_ptr; FLOAT *aj = a_ptr;
y_ptr = y; y_ptr = y;
for ( j=0; j<n; j++ )
if ( lda == 2 && inc_y == 1 )
{ {
*y_ptr += *aj * xtemp0 + *(aj+1) * xtemp1 ;
y_ptr += inc_y; for ( j=0; j< ( n & -4) ; j+=4 )
aj += lda; {
y_ptr[j] += aj[0] * xtemp0 + aj[1] * xtemp1 ;
y_ptr[j+1] += aj[2] * xtemp0 + aj[3] * xtemp1 ;
y_ptr[j+2] += aj[4] * xtemp0 + aj[5] * xtemp1 ;
y_ptr[j+3] += aj[6] * xtemp0 + aj[7] * xtemp1 ;
aj += 8;
}
for ( ; j<n; j++ )
{
y_ptr[j] += aj[0] * xtemp0 + aj[1] * xtemp1 ;
aj += 2;
}
}
else
{
for ( j=0; j<n; j++ )
{
*y_ptr += *aj * xtemp0 + *(aj+1) * xtemp1 ;
y_ptr += inc_y;
aj += lda;
}
} }
return(0); return(0);
@ -454,13 +504,33 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
FLOAT xtemp = *x_ptr * alpha; FLOAT xtemp = *x_ptr * alpha;
FLOAT *aj = a_ptr; FLOAT *aj = a_ptr;
y_ptr = y; y_ptr = y;
for ( j=0; j<n; j++ ) if ( lda == 1 && inc_y == 1 )
{ {
*y_ptr += *aj * xtemp; for ( j=0; j< ( n & -4) ; j+=4 )
y_ptr += inc_y; {
aj += lda; y_ptr[j] += aj[j] * xtemp;
y_ptr[j+1] += aj[j+1] * xtemp;
y_ptr[j+2] += aj[j+2] * xtemp;
y_ptr[j+3] += aj[j+3] * xtemp;
}
for ( ; j<n ; j++ )
{
y_ptr[j] += aj[j] * xtemp;
}
} }
else
{
for ( j=0; j<n; j++ )
{
*y_ptr += *aj * xtemp;
y_ptr += inc_y;
aj += lda;
}
}
return(0); return(0);
} }