The RVV kernel generation script uses the provided LMUL to increase the number of accumulator registers. Since the effect of the LMUL is to group together the vector registers into larger ones, it actually should be used as a multiplier in the calculation of vlenmax. At the moment, no matter what LMUL is provided, the generated kernels would only set the maximum number of vector elements equal to VLEN/SEW. Commit changes the use of LMUL to properly adjust vlenmax. Note that an increase in LMUL results in a decrease in the number of effective vector registers.
674 lines
30 KiB
Python
Executable File
674 lines
30 KiB
Python
Executable File
#!/usr/bin/python3
|
|
|
|
import sys, os
|
|
import contextlib
|
|
|
|
#-----------------------------------------------------------------------
|
|
def ERROR(*args, **kwargs):
    """Print a message on stderr and terminate the script.

    Every positional and keyword argument is handed to ``print`` with
    ``file=sys.stderr``; afterwards the interpreter is asked to exit with
    status ``-1``.
    """
    print(*args, file=sys.stderr, **kwargs)
    # sys.exit(code) does nothing more than raise SystemExit(code)
    raise SystemExit(-1)
|
|
|
|
class Target(object):
    """Indentation-aware line emitter with scoped format mappings.

    Every line passed to :meth:`write` is run through ``str.format`` using the
    currently active mappings, prefixed with the current indentation and
    handed to the ``out`` callable supplied at construction time.
    """

    def __init__( self, out, mappings, initial_level=0, tab_width=4 ):
        """Store the sink and the initial formatting state.

        out           -- callable invoked once per emitted line
        mappings      -- dict of names available to format strings
        initial_level -- starting indentation depth
        tab_width     -- number of spaces per indentation level
        """
        self._level = initial_level
        self._tab_width = tab_width
        self._out = out
        self._mappings = mappings

    @contextlib.contextmanager
    def map( self, **items ):
        """Temporarily extend the format mappings with ``items``.

        Yields the merged mapping. The previous mapping is reinstated when the
        ``with`` body finishes, including when it raises.
        """
        old_mappings = self._mappings
        self._mappings = dict(old_mappings, **items)
        try:
            yield self._mappings
        finally:
            # restore even on error so a failed block cannot leak its names
            self._mappings = old_mappings

    @contextlib.contextmanager
    def block( self, start=None, end=None, **args ):
        """Emit an indented block, optionally framed by ``start``/``end`` lines.

        When ``start`` is given, an empty line and then ``start`` are written
        before the indentation is increased. When ``end`` is given, ``end``
        and then an empty line are written after the indentation is decreased.
        ``args`` are extra mappings that are active for the whole block,
        including the ``start`` and ``end`` lines. Yields the indentation
        level used inside the block.

        If the body raises, the indentation level and mappings are restored
        and the exception propagates without ``end`` being written.
        """
        with self.map(**args):
            if start is not None:
                self.write()
                self.write(start)
            self._level += 1
            try:
                yield self._level
            finally:
                self._level -= 1
            if end is not None:
                self.write(end)
                self.write()

    def write( self, fmt=None, *args, **kwargs ):
        """Emit one line.

        With no ``fmt`` an empty line is emitted (without indentation).
        Otherwise ``fmt`` is formatted with ``args`` and the active mappings,
        where ``kwargs`` override mappings of the same name for this call only.
        """
        if fmt is not None:
            mappings = dict(self._mappings, **kwargs) if kwargs else self._mappings
            self._out(self._indent_str() + fmt.format(*args, **mappings))
        else:
            self._out("")

    def _indent_str( self ):
        """Return the whitespace prefix for the current indentation level."""
        return ' ' * (self._level * self._tab_width)
|
|
|
|
#-----------------------------------------------------------------------
|
|
def generate_trmm_block( dest ):
    """Write the TRMM-specific prologue that sets up ``off`` and ``pass_K``.

    The emitted C declares ``pass_K`` (initialised to ``K``) and ``off``
    (chosen by ``#ifdef LEFT``). Under ``#ifdef BACKWARDS`` it advances
    ``ai``/``bi`` by ``off`` and shortens ``pass_K``; otherwise it sets
    ``pass_K`` to ``off`` plus the tile size ``{M}`` or ``{N}``. The ``{M}``,
    ``{N}``, ``{elt_size}`` and ``{index_type}`` placeholders are resolved by
    ``dest`` when each line is written.
    """
    def emit_branch( body ):
        # each preprocessor branch is written one indentation level deeper
        with dest.block():
            for entry in body:
                if callable(entry):
                    entry()
                else:
                    dest.write(entry)

    def emit_conditional( directive, when_defined, otherwise ):
        dest.write(directive)
        emit_branch(when_defined)
        dest.write("#else")
        emit_branch(otherwise)
        dest.write("#endif")

    def forward_pass_k():
        emit_conditional(
            "#ifdef LEFT",
            ["pass_K = off + {M};"],
            ["pass_K = off + {N};"],
        )

    dest.write("{index_type} pass_K = K;")
    emit_conditional(
        "#ifdef LEFT",
        ["{index_type} off = offset + m_top;"],
        ["{index_type} off = -offset + n_top;"],
    )
    emit_conditional(
        "#ifdef BACKWARDS",
        [
            "ai += off*{M}{elt_size};",
            "bi += off*{N}{elt_size};",
            "pass_K -= off;",
        ],
        [forward_pass_k],
    )
|
|
|
|
#-----------------------------------------------------------------------
|
|
def generate_gemm_kernel_inner_real( settings, dest, M, N, vlen, a_regs ):
    """Write the C body that computes one M x N tile of a real kernel.

    settings -- option table; 'op', 'param_precision' and 'force_acc_double'
                are read here
    dest     -- emitter providing map()/block()/write()
    M, N     -- tile dimensions, exposed to format strings as {M} and {N}
    vlen     -- not read by this function
    a_regs   -- number of vector variables used to hold one column of A

    The emitted code loads N scalars of B and a_regs vectors of A, forms the
    products for k=0, accumulates the remaining k in a loop, then scales by
    alpha and stores to C. For 'trmm' the TRMM prologue is emitted and C is
    overwritten instead of accumulated into. When the accumulators are wider
    than the parameters the results are narrowed before scaling.
    """
    is_trmm = settings['op'].value == 'trmm'
    narrow_result = (settings['param_precision'].value != 'double') and settings['force_acc_double'].value
    k_limit = 'pass_K' if is_trmm else 'K'

    # (j, i, flat accumulator index) for each result, j outermost
    tile = [(j, i, j*a_regs + i) for j in range(N) for i in range(a_regs)]
    final_idx = N*a_regs - 1

    def ci_step( i, idx ):
        # text appended after a C load/store to move ci to the next vector
        if idx == final_idx:
            return ''
        if i == a_regs-1:
            return ' ci += ldc-gvl*{};'.format(a_regs-1)
        return ' ci += gvl;'

    with dest.map(M=M, N=N):
        dest.write("{index_type} ai=m_top*K{elt_size};")
        dest.write("{index_type} bi=n_top*K{elt_size};")
        if is_trmm:
            generate_trmm_block( dest )

        # k = 0: declare and load the operands
        for col in range(N):
            dest.write("{param_scalar_t} B{i} = B[bi+{i}];", i=col)
        dest.write("bi += {N};")
        dest.write()

        for reg in range(a_regs):
            dest.write("{param_vector_t} A{i} = {VLEV}( &A[ai+{i}*gvl], gvl );", i=reg)
        dest.write("ai += {M};")
        dest.write()

        for j, i, idx in tile:
            dest.write("{acc_vector_t} result{dest} = {VMUL_TO_ACC}( A{i}, B{j}, gvl);", dest=idx, i=i, j=j)

        # k >= 1: reload operands and accumulate
        with dest.block("for({index_type} k=1; k<{Kend}; k++) {{", "}}", Kend=k_limit):
            for col in range(N):
                dest.write("B{i} = B[bi+{i}];", i=col )
            dest.write("bi += {N};")
            dest.write()

            for reg in range(a_regs):
                dest.write("A{i} = {VLEV}( &A[ai+{i}*gvl], gvl );", i=reg)

            dest.write("ai += {M};")
            dest.write()

            for j, i, idx in tile:
                dest.write("result{dest} = {VMACC_TO_ACC}( result{dest}, B{j}, A{i}, gvl);", dest=idx, j=j, i=i )

        dest.write()
        dest.write("{index_type} ci=n_top*ldc+m_top;")
        dest.write()

        if narrow_result:
            for _, _, idx in tile:
                dest.write("{param_vector_t} narrowed{idx} = {VFNCVT}( result{idx}, gvl );", idx=idx)

        if not is_trmm:
            for _, i, idx in tile:
                dest.write("{param_vector_t} c{idx} = {VLEV}( &C[ci], gvl);{increment}", idx=idx, increment=ci_step(i, idx))

        # pick the alpha-scaling statement once; it is the same for every result
        if narrow_result:
            if is_trmm:
                scale_stmt = "{param_vector_t} c{idx} = {VFMUL}( narrowed{idx}, alpha, gvl );"
            else:
                scale_stmt = "c{idx} = {VFMACC}( c{idx}, alpha, narrowed{idx}, gvl );"
        else:
            if is_trmm:
                scale_stmt = "{param_vector_t} c{idx} = {VFMUL}( result{idx}, alpha, gvl );"
            else:
                scale_stmt = "c{idx} = {VFMACC}( c{idx}, alpha, result{idx}, gvl );"
        for _, _, idx in tile:
            dest.write(scale_stmt, idx=idx)

        if not is_trmm:
            # ci was advanced by the loads above, so rewind it for the stores
            dest.write()
            dest.write("ci=n_top*ldc+m_top;")
            dest.write()

        for _, i, idx in tile:
            dest.write("{VSEV}( &C[ci], c{idx}, gvl);{increment}", idx=idx, increment=ci_step(i, idx))
|
|
|
|
|
|
#-----------------------------------------------------------------------
|
|
def generate_gemm_kernel_inner_complex( settings, dest, M, N, vlen, a_regs ):
    """Write the C body that computes one M x N tile of a complex kernel.

    settings -- option table; 'op', 'param_precision', 'force_acc_double' and
                'LMUL_ACC' are read here
    dest     -- emitter providing map()/block()/write()
    M, N     -- tile dimensions, exposed to format strings as {M} and {N}
    vlen     -- not read by this function
    a_regs   -- number of vector variables per component used for a column of A

    Real and imaginary parts are kept in separate vector variables. Partial
    products are first combined in temporaries (tmp*) and only then added to
    the accumulators (ACC*). The temporaries are reused in tiles of
    tmp_unroll_j x tmp_unroll_i results, sized from the registers left over
    once the A and accumulator variables are counted.

    Raises RuntimeError when a widened accumulator is requested or when fewer
    than two registers would remain for temporaries.
    """
    TRMM = (settings['op'].value == 'trmm')
    narrow_result = (settings['param_precision'].value != 'double') and settings['force_acc_double'].value

    if narrow_result:
        # we could, but we run out of registers really really fast
        raise RuntimeError("wide accumulator not supported for generated complex kernels")

    def ci_step( i, idx ):
        # statement emitted after a C load/store to move ci to the next vector
        if idx == N*a_regs-1:
            return ''
        if i == a_regs-1:
            return 'ci += ldc-gvl*{};'.format(a_regs-1)
        return ' ci += gvl;'

    with dest.map(
        M=M,
        N=N,
    ):
        dest.write("{index_type} ai=m_top*K*2;")
        dest.write("{index_type} bi=n_top*K*2;")
        if TRMM:
            generate_trmm_block( dest )

        for i in range(N):
            dest.write("{param_scalar_t} B{i}r = B[bi+{i}*2+0];", i=i)
            dest.write("{param_scalar_t} B{i}i = B[bi+{i}*2+1];", i=i)
        dest.write("bi += {N}*2;")
        dest.write()

        for i in range(a_regs):
            dest.write("{param_vector_t} A{i}r = {VLSEV}( &A[ai+{i}*gvl*2], sizeof(FLOAT)*2, gvl );", i=i)
            dest.write("{param_vector_t} A{i}i = {VLSEV}( &A[ai+{i}*gvl*2+1], sizeof(FLOAT)*2, gvl );", i=i)
        dest.write("ai += {M}*2;")
        dest.write()

        # for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
        accumulation_regs = a_regs * N
        dest.write("// {a_regs} vector regs to hold A array contents, {accumulation_regs} regs to hold values accumulated over k",
            a_regs=a_regs*2, accumulation_regs=accumulation_regs*2
        )
        pass_regs = (accumulation_regs + a_regs)*2
        tmp_regs = (32 // settings['LMUL_ACC'].value) - pass_regs
        if tmp_regs < 2:
            raise RuntimeError("Complex kernel would use too many registers!")

        dest.write("// leaving {tmp_regs} vector registers for temporaries", tmp_regs=tmp_regs)

        tmp_unroll_i = min(tmp_regs, a_regs)
        tmp_unroll_j = N
        while tmp_unroll_j > 1 and (tmp_regs/(tmp_unroll_i*2)) < tmp_unroll_j:
            tmp_unroll_j = int(tmp_unroll_j / 2)

        if tmp_unroll_i < a_regs or tmp_unroll_j < N:
            dest.write("// performing {ops} operations between reuses of temporaries", ops=tmp_unroll_j*tmp_unroll_i)

        def tiles():
            # Yield (is_first_tile, cells) for each group of results that shares
            # one set of temporaries. The upper bounds are clamped to N and
            # a_regs: without that, a trailing partial tile (N or a_regs not a
            # multiple of the unroll factor) would reference B/A/ACC variables
            # that are never declared.
            for tj in range(0, N, tmp_unroll_j):
                for ti in range(0, a_regs, tmp_unroll_i):
                    cells = [
                        dict(dest=j*a_regs+i, tmp=(i-ti)+tmp_unroll_i*(j-tj), i=i, j=j)
                        for j in range(tj, min(tj+tmp_unroll_j, N))
                        for i in range(ti, min(ti+tmp_unroll_i, a_regs))
                    ]
                    yield (ti == 0 and tj == 0), cells

        def emit_tmp_products( cells, declare ):
            # tmp = (imag*imag, real*imag) products, then fold in the real-B terms
            prefix = "{acc_vector_t} " if declare else ""
            for cell in cells:
                with dest.map(**cell):
                    dest.write(prefix + "tmp{tmp}r = {VMUL_TO_ACC}( A{i}i, B{j}i, gvl);")
                    dest.write(prefix + "tmp{tmp}i = {VMUL_TO_ACC}( A{i}r, B{j}i, gvl);")
            for cell in cells:
                with dest.map(**cell):
                    dest.write("tmp{tmp}r = VFMACC_RR( tmp{tmp}r, B{j}r, A{i}r, gvl);")
                    dest.write("tmp{tmp}i = VFMACC_RI( tmp{tmp}i, B{j}r, A{i}i, gvl);")

        # k = 0: the first tile declares the temporaries, later tiles reuse them
        for first, cells in tiles():
            emit_tmp_products( cells, declare=first )

            for cell in cells:
                with dest.map(**cell):
                    dest.write("{acc_vector_t} ACC{dest}r = tmp{tmp}r;")
                    dest.write("{acc_vector_t} ACC{dest}i = tmp{tmp}i;")

        with dest.block("for({index_type} k=1; k<{Kend}; k++) {{", "}}", Kend=('pass_K' if TRMM else 'K')):
            for i in range(N):
                dest.write("B{i}r = B[bi+{i}*2+0];", i=i)
                dest.write("B{i}i = B[bi+{i}*2+1];", i=i)
            dest.write("bi += {N}*2;")
            dest.write()

            for i in range(a_regs):
                dest.write("A{i}r = {VLSEV}( &A[ai+{i}*gvl*2], sizeof(FLOAT)*2, gvl );", i=i)
                dest.write("A{i}i = {VLSEV}( &A[ai+{i}*gvl*2+1], sizeof(FLOAT)*2, gvl );", i=i)

            dest.write("ai += {M}*2;")
            dest.write()

            for _, cells in tiles():
                # note the values in tmp{tmp}* are frequently of similar magnitude and opposite sign
                # so accumulating them directly to ACC would lose precision when ACC is larger
                emit_tmp_products( cells, declare=False )
                for cell in cells:
                    with dest.map(**cell):
                        dest.write("ACC{dest}r = {__riscv_}vfadd( ACC{dest}r, tmp{tmp}r, gvl);")
                        dest.write("ACC{dest}i = {__riscv_}vfadd( ACC{dest}i, tmp{tmp}i, gvl);")

        dest.write()
        dest.write("{index_type} ci=n_top*ldc+m_top;")
        dest.write()

        for j in range(N):
            if TRMM:
                for i in range(a_regs):
                    with dest.map(idx=j*a_regs+i):
                        dest.write("{param_vector_t} C{idx}r = {__riscv_}vfmul( ACC{idx}r, alphar, gvl );")
                        dest.write("{param_vector_t} C{idx}i = {__riscv_}vfmul( ACC{idx}i, alphar, gvl );")
            else:
                for i in range(a_regs):
                    idx = j*a_regs+i
                    with dest.map(idx=idx, increment=ci_step(i, idx)):
                        dest.write("{param_vector_t} C{idx}r = {VLSEV}( &C[ci*2+0], sizeof(FLOAT)*2, gvl );")
                        dest.write("{param_vector_t} C{idx}i = {VLSEV}( &C[ci*2+1], sizeof(FLOAT)*2, gvl );")
                        dest.write("{increment}")

        if not TRMM:
            for j in range(N):
                for i in range(a_regs):
                    with dest.map(idx=j*a_regs+i):
                        dest.write("C{idx}r = {__riscv_}vfmacc( C{idx}r, alphar, ACC{idx}r, gvl );")
                        dest.write("C{idx}i = {__riscv_}vfmacc( C{idx}i, alphar, ACC{idx}i, gvl );")

        for j in range(N):
            for i in range(a_regs):
                with dest.map(idx=j*a_regs+i):
                    dest.write("C{idx}r = {__riscv_}vfnmsac( C{idx}r, alphai, ACC{idx}i, gvl );")
                    dest.write("C{idx}i = {__riscv_}vfmacc ( C{idx}i, alphai, ACC{idx}r, gvl );")

        if not TRMM:
            # ci was advanced by the loads above, so rewind it for the stores
            dest.write()
            dest.write("ci=n_top*ldc+m_top;")
            dest.write()

        for j in range(N):
            for i in range(a_regs):
                idx = j*a_regs+i
                with dest.map(idx=idx, increment=ci_step(i, idx)):
                    dest.write("{VSSEV}( &C[ci*2+0], sizeof(FLOAT)*2, C{idx}r, gvl);")
                    dest.write("{VSSEV}( &C[ci*2+1], sizeof(FLOAT)*2, C{idx}i, gvl);")
                    dest.write("{increment}")
|
|
|
|
#-----------------------------------------------------------------------
|
|
def generate_gemm_kernel( settings, OUTPUT ):
    """Write a complete GEMM/TRMM kernel source through ``OUTPUT``.

    settings -- option table; every entry is also exposed (as a string) to the
                format strings used while emitting
    OUTPUT   -- callable receiving one emitted line at a time

    The emitted function processes N in blocks of settings['N'] ("main pass")
    and then in halved blocks for the remainder; within each, M is processed
    in blocks of settings['M'] followed by the tails from generate_M_tails().

    Exits via ERROR() for conjugate kernels and raises Exception when the
    requested unrolling needs more vector registers than are available.
    """
    if settings['conjugate'].value:
        ERROR('conjugate gemm not yet supported')

    is_complex = settings['complex'].value
    generate_gemm_kernel_inner = generate_gemm_kernel_inner_complex if is_complex else generate_gemm_kernel_inner_real
    dest = Target(OUTPUT, { k:str(settings[k].value) for k in settings })

    M = settings['M'].value
    N = settings['N'].value
    # Elements per A/C vector. gvl comes from VSETVL, which is built from
    # ELEN_PARAM and LMUL, so the limit is VLEN*LMUL/ELEN_PARAM. LMUL_ACC must
    # not be used here: with a widened accumulator it is 2*LMUL while the
    # element count stays the same, which would halve a_regs and leave part
    # of every M tile unprocessed.
    vlenmax = (settings['reg_width_bits'].value * settings['LMUL'].value) // settings['ELEN_PARAM'].value
    a_regs = max(M // vlenmax, 1)
    main_vl = min(vlenmax, max(M // a_regs, 1))

    # for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
    accumulation_regs = a_regs * N
    required_regs = accumulation_regs + a_regs
    if is_complex:
        required_regs = required_regs * 2 + 2
        # sign and fused-multiply selection for the conjugation variants;
        # used by the complex vector body (VFMACC_*) and scalar tails (S0..S3)
        dest.write('''
#if defined(NN) || defined(NT) || defined(TN) || defined(TT)
#define S0 1
#define S1 -1
#define S2 1
#define S3 1
#define VFMACC_RR __riscv_vfmsac{tail_policy}
#define VFMACC_RI __riscv_vfmacc{tail_policy}
#endif
#if defined(NR) || defined(NC) || defined(TR) || defined(TC)
#define S0 1
#define S1 1
#define S2 1
#define S3 -1
#define VFMACC_RR __riscv_vfmacc{tail_policy}
#define VFMACC_RI __riscv_vfmsac{tail_policy}
#endif
#if defined(RN) || defined(RT) || defined(CN) || defined(CT)
#define S0 1
#define S1 1
#define S2 -1
#define S3 1
#define VFMACC_RR __riscv_vfmacc{tail_policy}
#define VFMACC_RI __riscv_vfnmsac{tail_policy}
#endif
#if defined(RR) || defined(RC) || defined(CR) || defined(CC)
#define S0 1
#define S1 -1
#define S2 -1
#define S3 -1
#define VFMACC_RR __riscv_vfmsac{tail_policy}
#define VFMACC_RI __riscv_vfnmacc{tail_policy}
#endif
'''.format(tail_policy=settings['tail_policy'].value))


    if required_regs > (32 // settings['LMUL_ACC'].value):
        raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only {} are available".format(
            required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else ''), 32 // settings['LMUL_ACC'].value
        ))

    TRMM = (settings['op'].value == 'trmm')
    if TRMM:
        with dest.block("#if defined(LEFT) != defined(TRANSA)", "#endif"):
            dest.write("#define BACKWARDS")

    dest.write("int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, {alpha}, FLOAT* A, FLOAT* B, FLOAT* C, BLASLONG ldc{trmm})",
        alpha = ('FLOAT alphar, FLOAT alphai' if is_complex else 'FLOAT alpha'),
        trmm = (', BLASLONG offset' if TRMM else '')
    )

    with dest.block("{{", "}}", elt_size='*2' if is_complex else ''):
        if settings['trace'].value:
            dest.write("printf(\"\\n\\nENTRY: %s(%d) M %d N %d K %d ldc %d\\n\", __FILE__, __LINE__, M, N, K, ldc);")
        dest.write("{index_type} gvl = 0;")
        dest.write("{index_type} m_top = 0;")
        dest.write("{index_type} n_top = 0;")

        dest.write()
        dest.write()
        dest.write("// -- MAIN PASS")

        with dest.block("for ({index_type} j=0; j<N/{N}; j+=1) {{", "}}"):
            dest.write("m_top = 0;")
            dest.write("{index_type} gvl = {VSETVL}({vlenmax});", vlenmax=main_vl)
            dest.write()
            with dest.block("for ({index_type} i=0; i<M/{M}; i+=1) {{", "}}"):
                generate_gemm_kernel_inner( settings, dest, M, N, vlenmax, a_regs )
                dest.write( "m_top += {M};" )

            dest.write()
            dest.write()
            dest.write("// -- tails for main pass")
            generate_M_tails( dest, settings, M, N )

            dest.write( "n_top += {N};" )


        # leftover columns: one block per set bit of N below the unroll factor
        N_tail = N // 2
        while N_tail > 0:
            with dest.map(N=N_tail):
                dest.write()
                dest.write()
                dest.write("// -- tails for N={N}")
                with dest.block("if( N & {N} ) {{", "}}" ):
                    if settings['trace'].value:
                        dest.write("printf(\"N tail entry: %s(%d) M %d N %d K %d m_top %d n_top %d\\n\", __FILE__, __LINE__, M, N, K, m_top, n_top);")
                    dest.write("gvl = {VSETVL}({vlenmax});", vlenmax=main_vl)
                    dest.write("m_top = 0;")
                    with dest.block("for ({index_type} i=0; i<M/{M}; i+=1) {{", "}}"):
                        generate_gemm_kernel_inner( settings, dest, M, N_tail, vlenmax, a_regs )
                        dest.write("m_top += {M};")

                    generate_M_tails( dest, settings, M, N_tail )
                    dest.write("n_top += {N};")
            N_tail //= 2

        dest.write("return 0;")
|
|
|
|
|
|
#-----------------------------------------------------------------------
|
|
def generate_M_tails( dest, settings, M, N ):
    """Write the code handling rows left over after the M-unrolled loop.

    dest     -- emitter providing map()/block()/write()
    settings -- option table; 'M_tail_scalar_from', 'reg_width_bits', 'LMUL',
                'ELEN_PARAM', 'op', 'complex', 'trace', 'force_acc_double' and
                'param_scalar_t' are read here
    M, N     -- tile dimensions of the pass the tails belong to

    Tail sizes are M/2, M/4, ... down to 1, each guarded by ``if( M & size )``.
    Sizes above settings['M_tail_scalar_from'] reuse the vector generator with
    a shorter vector length; the remaining sizes are written as scalar C.
    """
    M_tail = M // 2
    M_tail_min = settings['M_tail_scalar_from'].value
    # Elements per A/C vector. gvl comes from VSETVL, which is built from
    # ELEN_PARAM and LMUL, so the limit is VLEN*LMUL/ELEN_PARAM. LMUL_ACC must
    # not be used here: with a widened accumulator it is 2*LMUL while the
    # element count stays the same, which would under-count a_regs.
    vlenmax = (settings['reg_width_bits'].value * settings['LMUL'].value) // settings['ELEN_PARAM'].value
    TRMM = (settings['op'].value == 'trmm')
    is_complex = settings['complex'].value
    generate_gemm_kernel_inner = generate_gemm_kernel_inner_complex if is_complex else generate_gemm_kernel_inner_real

    # vector tails
    while M_tail > M_tail_min:
        with dest.block("if( M & {M_tail} ) {{", "}}", M_tail=M_tail ):
            if settings['trace'].value:
                dest.write("printf(\"tail: %s(%d) M %d N %d K %d m_top %d n_top %d\\n\", __FILE__, __LINE__, M, N, K, m_top, n_top);")
            a_regs = max( 1, M_tail // vlenmax )
            vlen = M_tail // a_regs
            dest.write("gvl = {VSETVL}({vlen});\n", vlen=vlen)

            generate_gemm_kernel_inner( settings, dest, M_tail, N, vlen, a_regs )
            dest.write( "m_top += {M_tail};" )

        M_tail //= 2

    # scalar tails
    while M_tail > 0:
        with dest.block("if( M & {M_tail} ) {{", "}}",
            M_tail=M_tail,
            N=N,
            result_t = ('double' if settings['force_acc_double'].value else settings['param_scalar_t'].value)
        ):
            if settings['trace'].value:
                dest.write("printf(\"tail: %s(%d) M %d N %d K %d m_top %d n_top %d\\n\", __FILE__, __LINE__, M, N, K, m_top, n_top);")
            # one scalar accumulator per result element (two when complex)
            for r in range(M_tail * N * (2 if is_complex else 1)):
                dest.write("{result_t} result{r} = 0;",
                    r=r
                )

            dest.write("{index_type} ai=m_top*K{elt_size};")
            dest.write("{index_type} bi=n_top*K{elt_size};")

            if TRMM:
                with dest.map(M=M_tail, N=N):
                    generate_trmm_block( dest )

            with dest.block("for({index_type} k=0; k<{Kend}; k++) {{", "}}", Kend = ('pass_K' if TRMM else 'K') ):
                for ki in range( N ):
                    for kj in range( M_tail ):
                        # these lines are fully formatted here; they contain no
                        # placeholders by the time dest.write() sees them
                        if is_complex:
                            dest.write("result{dest}+=S0*A[ai+{kj}+0]*B[bi+{ki}+0] + S1*A[ai+{kj}+1]*B[bi+{ki}+1];".format(
                                dest=(ki*M_tail+kj)*2, kj=kj*2, ki=ki*2
                            ))
                            dest.write("result{dest}+=S2*A[ai+{kj}+1]*B[bi+{ki}+0] + S3*A[ai+{kj}+0]*B[bi+{ki}+1];".format(
                                dest=(ki*M_tail+kj)*2+1, kj=kj*2, ki=ki*2
                            ))
                        else:
                            dest.write("result{dest}+=A[ai+{kj}]*B[bi+{ki}];".format(
                                dest=ki*M_tail+kj, kj=kj, ki=ki
                            ))
                dest.write("ai+={M_tail}{elt_size};")
                dest.write("bi+={N}{elt_size};")

            dest.write("{index_type} ci=n_top*ldc+m_top;")
            if is_complex:
                dest.write("{result_t} Cr, Ci;")
            for ki in range( N ):
                for kj in range( M_tail ):
                    if is_complex:
                        if TRMM:
                            dest.write('Cr = result{dest}*alphar;', dest=(ki*M_tail+kj)*2+0)
                            dest.write('Ci = result{dest}*alphar;', dest=(ki*M_tail+kj)*2+1)
                        else:
                            dest.write('Cr = C[(ci+{ki}*ldc+{kj})*2+0];', ki=ki, kj=kj)
                            dest.write('Ci = C[(ci+{ki}*ldc+{kj})*2+1];', ki=ki, kj=kj)
                            dest.write('Cr += result{dest}*alphar;', dest=(ki*M_tail+kj)*2+0)
                            dest.write('Ci += result{dest}*alphar;', dest=(ki*M_tail+kj)*2+1)
                        dest.write('Cr -= result{dest}*alphai;', dest=(ki*M_tail+kj)*2+1)
                        dest.write('Ci += result{dest}*alphai;', dest=(ki*M_tail+kj)*2+0)
                        dest.write("C[(ci+{ki}*ldc+{kj})*2+0] = Cr;", ki=ki, kj=kj )
                        dest.write("C[(ci+{ki}*ldc+{kj})*2+1] = Ci;", ki=ki, kj=kj )
                    else:
                        # TRMM overwrites C, GEMM accumulates into it
                        op = '' if TRMM else '+'
                        dest.write("C[ci+{ki}*ldc+{kj}] {op}= alpha * result{dest};",
                            ki=ki, kj=kj, op=op, dest=ki*M_tail+kj
                        )
            dest.write("m_top+={M_tail};")

        M_tail //= 2
|
|
|
|
|
|
#-----------------------------------------------------------------------
|
|
class Setting(object):
    """One generator option: a current value plus an optional converter.

    A setting created with a converter is "configurable": assigning to
    ``value`` runs the assigned object through the converter first. A setting
    created without one only carries the value it was constructed with.
    """

    def __init__( self, value, convert = None ):
        self._convert = convert
        self._value = value

    @classmethod
    def ENUM( cls, *values ):
        """Return a converter mapping any-case input to its spelling in ``values``.

        The returned callable raises KeyError for input that matches none of
        the allowed values.
        """
        canonical = {}
        for candidate in values:
            canonical[candidate.lower()] = candidate

        def convert( value ):
            return canonical[value.lower()]
        return convert

    @classmethod
    def BOOL( cls, value ):
        """Converter: true for '1' or any text starting with 't' or 'T'."""
        if value.lower().startswith('t'):
            return True
        return value == '1'

    @property
    def value( self ):
        """The current value; assignment goes through the converter."""
        return self._value

    @value.setter
    def value( self, value ):
        self._value = self._convert( value )

    @property
    def configurable( self ):
        """True when the setting was created with a converter."""
        return self._convert is not None

    def __str__( self ):
        return str(self._value)
|
|
|
|
#-----------------------------------------------------------------------
|
|
def main():
    """Parse ``name=value`` arguments, derive intrinsic names and emit a kernel.

    Each command-line argument overrides one entry of the option table below.
    Derived entries (element widths, vector types, intrinsic names) are then
    added to the table and the kernel is written to the file named by the
    'output' setting, or to stdout when that setting is '-'. Any problem with
    the arguments ends the script through ERROR().
    """
    settings = {
        'op':                 Setting( 'gemm', Setting.ENUM( 'gemm', 'trmm' ) ),
        'M':                  Setting( 16, int ),
        'N':                  Setting( 4, int ),
        'reg_width_bits':     Setting( 256, int ),
        'LMUL':               Setting( 1, int ),
        'M_tail_scalar_from': Setting( 2, int ),
        'cpu':                Setting( 'zvl256b', str ),
        'param_precision':    Setting( 'float', Setting.ENUM( 'float', 'double' ) ),
        'force_acc_double':   Setting( False, Setting.BOOL ),
        'complex':            Setting( False, Setting.BOOL ),
        'conjugate':          Setting( False, Setting.BOOL ),
        'index_type':         Setting( 'BLASLONG', str ),
        'trace':              Setting( False, Setting.BOOL ),
        'output':             Setting( None, str ),
        'tail_policy':        Setting( '', str ), # _ta, if toolchain supports it
        '__riscv_':           Setting( '__riscv_', str),
    }

    for item in sys.argv[1:]:
        try:
            name, value = tuple(item.split( '=', 1 ))
        except ValueError:
            # no '=' in the argument, so it cannot be unpacked into two parts
            ERROR("couldn't parse {}, expected arguments of the form name=value".format(item))

        if name not in settings:
            ERROR("couldn't parse {}, {} it is not a known option\n".format( item, name )
                +"options (and current defaults) are\n{}".format(
                    " ".join([ '{}={}'.format(k, settings[k].value) for k in settings.keys()]))
            )

        try:
            settings[name].value = value
        except Exception:
            # the converter rejected the text (e.g. non-numeric int, unknown enum)
            import traceback
            traceback.print_exc()
            ERROR("couldn't parse {}".format(item))

    # RVV only defines whole-register groupings of 1, 2, 4 and 8; anything
    # else would yield intrinsic and type names that do not exist
    if settings['LMUL'].value not in (1, 2, 4, 8):
        ERROR("LMUL must be one of 1, 2, 4 or 8, got {}".format(settings['LMUL'].value))

    if settings['output'].value is None:
        if settings['complex'].value:
            prefix = 'z' if settings['param_precision'].value == 'double' else 'c'
        else:
            prefix = 'd' if settings['param_precision'].value == 'double' else 's'
        settings['output'] = Setting('{}{}_kernel_{}x{}_{}.c'.format(
            prefix,
            settings['op'],
            settings['M'],
            settings['N'],
            settings['cpu']
        ))

    if settings['param_precision'].value == 'double':
        settings['param_scalar_t'] = Setting( 'double' )
        settings['ELEN_PARAM'] = Setting(64)
    else:
        settings['param_scalar_t'] = Setting( 'float' )
        settings['ELEN_PARAM'] = Setting(32)

    settings['VFMUL'] = Setting( '{}vfmul_vf_f{}m{}{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['LMUL'], settings['tail_policy']) )
    settings['VFMACC'] = Setting( '{}vfmacc_vf_f{}m{}{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['LMUL'], settings['tail_policy']) )

    # by default the accumulators have the same element width and grouping as
    # the parameters
    settings['ELEN_ACC'] = settings['ELEN_PARAM']
    settings['LMUL_ACC'] = Setting(settings['LMUL'].value)
    widen = ''

    if settings['force_acc_double'].value and (settings['param_precision'].value == 'float'):
        # float inputs accumulated in double: twice the element width needs
        # twice the register grouping, and widening multiply intrinsics
        settings['ELEN_ACC'] = Setting(64)
        settings['LMUL_ACC'] = Setting(settings['LMUL'].value*2)
        if settings['LMUL_ACC'].value > 8:
            ERROR("force_acc_double needs an accumulator LMUL of {}, which exceeds the maximum of 8".format(settings['LMUL_ACC'].value))
        settings['VFNCVT'] = Setting('{}vfncvt_f_f_w_f{}m{}{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['LMUL'], settings['tail_policy']))
        widen = 'w'

    settings['VMUL_TO_ACC'] = Setting( '{}vf{}mul_vf_f{}m{}{}'.format(settings['__riscv_'], widen, settings['ELEN_ACC'], settings['LMUL_ACC'], settings['tail_policy']) )
    settings['VMACC_TO_ACC'] = Setting( '{}vf{}macc_vf_f{}m{}{}'.format(settings['__riscv_'], widen, settings['ELEN_ACC'], settings['LMUL_ACC'], settings['tail_policy']) )

    settings['param_vector_t']=Setting('vfloat{}m{}_t'.format(settings['ELEN_PARAM'], settings['LMUL']))
    settings['acc_vector_t']  =Setting('vfloat{}m{}_t'.format(settings['ELEN_ACC'], settings['LMUL_ACC']))
    settings['VLEV']          =Setting('{}vle{}_v_f{}m{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['ELEN_PARAM'], settings['LMUL']))
    settings['VSEV']          =Setting('{}vse{}_v_f{}m{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['ELEN_PARAM'], settings['LMUL']))
    settings['VLSEV']         =Setting('{}vlse{}_v_f{}m{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['ELEN_PARAM'], settings['LMUL']))
    settings['VSSEV']         =Setting('{}vsse{}_v_f{}m{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['ELEN_PARAM'], settings['LMUL']))
    settings['VSETVL']        =Setting('{}vsetvl_e{}m{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['LMUL']))


    to_stdout = (settings['output'].value == '-')
    if not to_stdout:
        print("Writing {}".format(settings['output'].value), file=sys.stderr)

    # When wrapping the stdout descriptor, do not let the wrapper close it on
    # exit; closefd=False is only permitted when opening by descriptor.
    with open(sys.stdout.fileno() if to_stdout else settings['output'].value, 'w', closefd=not to_stdout) as destination_file:
        def OUTPUT(*args, **kwargs):
            print(*args, file=destination_file, **kwargs)

        # header comment recording the options the kernel was generated with
        OUTPUT("/*\n\nAUTOGENERATED KERNEL\nSettings:\n {}".format(" ".join([ "{}={}\n".format(k, repr(settings[k].value)) for k in sorted(settings.keys()) if settings[k].configurable])))
        OUTPUT("Derived:\n {}\n*/\n".format(" ".join([ "{}={}\n".format(k, repr(settings[k].value)) for k in sorted(settings.keys()) if not settings[k].configurable])))

        OUTPUT('#include "common.h"')
        OUTPUT("\n")

        if settings['op'].value in ('gemm', 'trmm'):
            generate_gemm_kernel(settings, OUTPUT)
        else:
            ERROR("unsupported kernel type {}".format(settings['op']))


if __name__ == "__main__":
    main()
|