[RISC-V] Improve RVV kernel generator LMUL usage

The RVV kernel generation script uses the provided LMUL to increase the number of accumulator registers.
Since the effect of the LMUL is to group together the vector registers into larger ones, it actually should be used as a multiplier in the calculation of vlenmax.
At the moment, no matter what LMUL is provided, the generated kernels would only set the maximum number of vector elements equal to VLEN/SEW.
Commit changes the use of LMUL to properly adjust vlenmax. Note that an increase in LMUL results in a decrease in the number of effective vector registers.
This commit is contained in:
Octavian Maghiar 2023-12-04 11:13:35 +00:00
parent 62f0f506ec
commit 4a12cf53ec
1 changed files with 13 additions and 10 deletions

View File

@ -197,13 +197,13 @@ def generate_gemm_kernel_inner_complex( settings, dest, M, N, vlen, a_regs ):
dest.write("ai += {M}*2;")
dest.write()
accumulation_regs = a_regs * N * settings['LMUL_ACC'].value
# for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
accumulation_regs = a_regs * N
dest.write("// {a_regs} vector regs to hold A array contents, {accumulation_regs} regs to hold values accumulated over k",
a_regs=a_regs*2, accumulation_regs=accumulation_regs*2
)
pass_regs = (accumulation_regs + a_regs)*2
tmp_regs = 32-pass_regs
tmp_regs = (32 // settings['LMUL_ACC'].value) - pass_regs
if tmp_regs < 2:
raise RuntimeError("Complex kernel would use too many registers!")
@ -337,10 +337,12 @@ def generate_gemm_kernel( settings, OUTPUT ):
M = settings['M'].value
N = settings['N'].value
vlenmax = int( settings['reg_width_bits'].value / settings['ELEN_PARAM'].value )
vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value /
settings['ELEN_PARAM'].value)
a_regs = max(int(M/vlenmax), 1)
accumulation_regs = a_regs * N * settings['LMUL_ACC'].value
# for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
accumulation_regs = a_regs * N
required_regs = accumulation_regs + a_regs
if is_complex:
required_regs = required_regs * 2 + 2
@ -380,9 +382,9 @@ def generate_gemm_kernel( settings, OUTPUT ):
'''.format(tail_policy=settings['tail_policy'].value))
if required_regs > 32:
raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only 32 are available".format(
required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else '')
if required_regs > (32 // settings['LMUL_ACC'].value):
raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only {} are available".format(
required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else ''), 32 // settings['LMUL_ACC'].value
))
TRMM = (settings['op'].value == 'trmm')
@ -448,7 +450,8 @@ def generate_gemm_kernel( settings, OUTPUT ):
def generate_M_tails( dest, settings, M, N ):
M_tail = int(M/2)
M_tail_min = settings['M_tail_scalar_from'].value
vlenmax = int( settings['reg_width_bits'].value / settings['ELEN_PARAM'].value )
vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value
/ settings['ELEN_PARAM'].value )
TRMM = (settings['op'].value == 'trmm')
is_complex = settings['complex'].value
generate_gemm_kernel_inner = generate_gemm_kernel_inner_complex if is_complex else generate_gemm_kernel_inner_real
@ -667,4 +670,4 @@ def main():
ERROR("unsupported kernel type {}".format(settings['op']))
if __name__ == "__main__":
main()
main()