141 lines
3.8 KiB
C
141 lines
3.8 KiB
C
#include "common.h"
|
|
|
|
|
|
/* For debugging and unit tests:
 * a drop-in replacement for the zgemm/cgemm/ztrmm/ctrmm kernels that supports
 * arbitrary combinations of unroll values.
 */
|
|
|
|
/* BACKWARDS only matters for TRMM builds. It is defined when exactly one of
 * LEFT and TRANSA is set; in that case the kernel starts each tile's k-loop
 * `off` steps into the packed panels (running to K) instead of truncating the
 * loop after `off + tile` steps. */
#if defined(TRMMKERNEL) && \
    ((defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)))
#define BACKWARDS
#endif
|
|
|
|
/* Tile sizes are taken from the build's per-target defaults:
 * CGEMM_DEFAULT_UNROLL_* when DOUBLE is not defined, ZGEMM_DEFAULT_UNROLL_*
 * when it is. */
#ifndef DOUBLE
#define UNROLL_M CGEMM_DEFAULT_UNROLL_M
#define UNROLL_N CGEMM_DEFAULT_UNROLL_N
#else
#define UNROLL_M ZGEMM_DEFAULT_UNROLL_M
#define UNROLL_N ZGEMM_DEFAULT_UNROLL_N
#endif
|
|
|
|
/*
 * Reference complex GEMM/TRMM micro-kernel.
 *
 * Tile by tile, accumulates sum_k op(a_k) * op(b_k) and applies alpha:
 * GEMM builds add the result into C, TRMM builds overwrite C (C is not read).
 * op() is identity or complex conjugation, chosen by the NN/NR/.../CC build
 * macro through the sign table below.
 *
 * M, N, K        rows/columns of the C region and depth of the product.
 * alphar, alphai real and imaginary parts of alpha.
 * A, B           packed panels of interleaved (re, im) pairs. In the non-TRMM
 *                case, for a row tile of height mb starting at row r0, step k
 *                reads A[(K*r0 + k*mb + i)*2 + {0,1}]; B is indexed the same
 *                way per column tile.
 * C              interleaved (re, im) output; element (i, j) of the region is
 *                C[(j*ldc + i)*2 + {0,1}], so ldc counts complex elements.
 * offset         (TRMMKERNEL only) diagonal offset used to clip each tile's
 *                k-range.
 *
 * When fewer than UNROLL_N columns (UNROLL_M rows) remain, the tile size is
 * halved until it fits. NOTE(review): this presumably matches the edge-tile
 * sizes produced by the packing routines -- confirm for non-power-of-two
 * unroll values.
 *
 * Returns 0.
 */
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alphar, FLOAT alphai,
          FLOAT *A, FLOAT *B, FLOAT *C, BLASLONG ldc
#ifdef TRMMKERNEL
          , BLASLONG offset
#endif
          )
{
    /* Accumulators for one tile. Column j of the tile starts at
     * acc[j * UNROLL_M * 2], so smaller edge tiles reuse the same buffer. */
    FLOAT acc[UNROLL_M * UNROLL_N * 2];

    /* sgn[] scales the four real partial products so that
     *   re += sgn[0]*a_re*b_re + sgn[1]*a_im*b_im
     *   im += sgn[2]*a_im*b_re + sgn[3]*a_re*b_im
     * yields the product noted next to each table. */
#if defined(NN) || defined(NT) || defined(TN) || defined(TT)
    FLOAT sgn[4] = { 1, -1, 1, 1 };     /* a * b */
#endif
#if defined(NR) || defined(NC) || defined(TR) || defined(TC)
    FLOAT sgn[4] = { 1, 1, 1, -1 };     /* a * conj(b) */
#endif
#if defined(RN) || defined(RT) || defined(CN) || defined(CT)
    FLOAT sgn[4] = { 1, 1, -1, 1 };     /* conj(a) * b */
#endif
#if defined(RR) || defined(RC) || defined(CR) || defined(CC)
    FLOAT sgn[4] = { 1, -1, -1, -1 };   /* conj(a) * conj(b) */
#endif

    BLASLONG nb = UNROLL_N;             /* width of the current column tile */

    for (BLASLONG col0 = 0; col0 < N; col0 += nb) {
        /* Halve the tile width until it fits in the remaining columns. */
        while (col0 + nb > N)
            nb >>= 1;

        BLASLONG mb = UNROLL_M;         /* height of the current row tile */

        for (BLASLONG row0 = 0; row0 < M; row0 += mb) {
            /* Halve the tile height until it fits in the remaining rows. */
            while (row0 + mb > M)
                mb >>= 1;

            BLASLONG a_pos = K * row0 * 2;  /* start of this row tile's A panel */
            BLASLONG b_pos = K * col0 * 2;  /* start of this column tile's B panel */
            BLASLONG depth = K;             /* number of k-steps to accumulate */

#ifdef TRMMKERNEL
            /* Position of this tile relative to the diagonal. */
#ifdef LEFT
            BLASLONG off = offset + row0;
#else
            BLASLONG off = -offset + col0;
#endif
#if defined(BACKWARDS)
            /* Skip the first `off` k-steps of both panels. */
            a_pos += off * mb * 2;
            b_pos += off * nb * 2;
            depth -= off;
#elif defined(LEFT)
            /* Only the first off + mb k-steps contribute. */
            depth = off + mb;
#else
            /* Only the first off + nb k-steps contribute. */
            depth = off + nb;
#endif
#endif

            memset(acc, 0, sizeof acc);

            for (BLASLONG p = 0; p < depth; p++) {
                for (BLASLONG j = 0; j < nb; j++) {
                    const FLOAT b_re = B[b_pos + j * 2 + 0];
                    const FLOAT b_im = B[b_pos + j * 2 + 1];
                    FLOAT *acc_col = &acc[j * UNROLL_M * 2];

                    for (BLASLONG i = 0; i < mb; i++) {
                        const FLOAT a_re = A[a_pos + i * 2 + 0];
                        const FLOAT a_im = A[a_pos + i * 2 + 1];

                        acc_col[i * 2 + 0] += sgn[0] * a_re * b_re + sgn[1] * a_im * b_im;
                        acc_col[i * 2 + 1] += sgn[2] * a_im * b_re + sgn[3] * a_re * b_im;
                    }
                }
                /* Advance both panels by one k-step of this tile. */
                a_pos += mb * 2;
                b_pos += nb * 2;
            }

            /* Complex index of the tile's (0, 0) element within C. */
            const BLASLONG c_pos = ldc * col0 + row0;

            for (BLASLONG j = 0; j < nb; j++) {
                const FLOAT *tile_col = &acc[j * UNROLL_M * 2];

                for (BLASLONG i = 0; i < mb; i++) {
                    FLOAT *dst = &C[(c_pos + j * ldc + i) * 2];
                    const FLOAT t_re = tile_col[i * 2 + 0];
                    const FLOAT t_im = tile_col[i * 2 + 1];
#ifdef TRMMKERNEL
                    /* TRMM overwrites C: start from zero instead of reading it. */
                    FLOAT re = 0;
                    FLOAT im = 0;
#else
                    FLOAT re = dst[0];
                    FLOAT im = dst[1];
#endif
                    /* (re, im) += alpha * (t_re, t_im) as complex numbers. */
                    re += t_re * alphar;
                    re -= t_im * alphai;
                    im += t_im * alphar;
                    im += t_re * alphai;

                    dst[0] = re;
                    dst[1] = im;
                }
            }
        }
    }

    return 0;
}
|