[RISC-V] Improve RVV kernel generator LMUL usage
The RVV kernel generation script uses the provided LMUL to increase the number of accumulator registers. Since the effect of the LMUL is to group together the vector registers into larger ones, it actually should be used as a multiplier in the calculation of vlenmax. At the moment, no matter what LMUL is provided, the generated kernels would only set the maximum number of vector elements equal to VLEN/SEW. Commit changes the use of LMUL to properly adjust vlenmax. Note that an increase in LMUL results in a decrease in the number of effective vector registers.
This commit is contained in:
parent
62f0f506ec
commit
4a12cf53ec
|
@ -197,13 +197,13 @@ def generate_gemm_kernel_inner_complex( settings, dest, M, N, vlen, a_regs ):
|
|||
dest.write("ai += {M}*2;")
|
||||
dest.write()
|
||||
|
||||
|
||||
accumulation_regs = a_regs * N * settings['LMUL_ACC'].value
|
||||
# for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
|
||||
accumulation_regs = a_regs * N
|
||||
dest.write("// {a_regs} vector regs to hold A array contents, {accumulation_regs} regs to hold values accumulated over k",
|
||||
a_regs=a_regs*2, accumulation_regs=accumulation_regs*2
|
||||
)
|
||||
pass_regs = (accumulation_regs + a_regs)*2
|
||||
tmp_regs = 32-pass_regs
|
||||
tmp_regs = (32 // settings['LMUL_ACC'].value) - pass_regs
|
||||
if tmp_regs < 2:
|
||||
raise RuntimeError("Complex kernel would use too many registers!")
|
||||
|
||||
|
@ -337,10 +337,12 @@ def generate_gemm_kernel( settings, OUTPUT ):
|
|||
|
||||
M = settings['M'].value
|
||||
N = settings['N'].value
|
||||
vlenmax = int( settings['reg_width_bits'].value / settings['ELEN_PARAM'].value )
|
||||
vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value /
|
||||
settings['ELEN_PARAM'].value)
|
||||
a_regs = max(int(M/vlenmax), 1)
|
||||
|
||||
accumulation_regs = a_regs * N * settings['LMUL_ACC'].value
|
||||
# for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
|
||||
accumulation_regs = a_regs * N
|
||||
required_regs = accumulation_regs + a_regs
|
||||
if is_complex:
|
||||
required_regs = required_regs * 2 + 2
|
||||
|
@ -380,9 +382,9 @@ def generate_gemm_kernel( settings, OUTPUT ):
|
|||
'''.format(tail_policy=settings['tail_policy'].value))
|
||||
|
||||
|
||||
if required_regs > 32:
|
||||
raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only 32 are available".format(
|
||||
required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else '')
|
||||
if required_regs > (32 // settings['LMUL_ACC'].value):
|
||||
raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only {} are available".format(
|
||||
required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else ''), 32 // settings['LMUL_ACC'].value
|
||||
))
|
||||
|
||||
TRMM = (settings['op'].value == 'trmm')
|
||||
|
@ -448,7 +450,8 @@ def generate_gemm_kernel( settings, OUTPUT ):
|
|||
def generate_M_tails( dest, settings, M, N ):
|
||||
M_tail = int(M/2)
|
||||
M_tail_min = settings['M_tail_scalar_from'].value
|
||||
vlenmax = int( settings['reg_width_bits'].value / settings['ELEN_PARAM'].value )
|
||||
vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value
|
||||
/ settings['ELEN_PARAM'].value )
|
||||
TRMM = (settings['op'].value == 'trmm')
|
||||
is_complex = settings['complex'].value
|
||||
generate_gemm_kernel_inner = generate_gemm_kernel_inner_complex if is_complex else generate_gemm_kernel_inner_real
|
||||
|
@ -667,4 +670,4 @@ def main():
|
|||
ERROR("unsupported kernel type {}".format(settings['op']))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue