From 878064f39463631e0daf78395248083f1c8b251f Mon Sep 17 00:00:00 2001 From: Bine Brank Date: Sun, 26 Dec 2021 08:44:05 +0100 Subject: [PATCH] sve zgemm kernel --- kernel/arm64/zgemm_kernel_sve_v1x4.S | 544 +++++++-------------------- 1 file changed, 132 insertions(+), 412 deletions(-) diff --git a/kernel/arm64/zgemm_kernel_sve_v1x4.S b/kernel/arm64/zgemm_kernel_sve_v1x4.S index 0fc966f8c..1201d6dac 100644 --- a/kernel/arm64/zgemm_kernel_sve_v1x4.S +++ b/kernel/arm64/zgemm_kernel_sve_v1x4.S @@ -48,6 +48,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define pCRow2 x14 #define pCRow3 x15 #define pA x16 +#define lanes x17 /* count of active p1 lanes for the current M step (written by cntp) */ + #define alphaR x19 #define alphaI x20 @@ -168,7 +170,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. .macro KERNELv1x4_I ld2d {z0.d, z1.d}, p1/z, [pA] - ld2d {z2.d, z3.d}, p1/z, [pA, lanes, lsl #4] // next one + ld2d {z2.d, z3.d}, p1/z, [pA, #2, mul vl] // next one /* NOTE(review): "#2, mul vl" is a fixed offset of two full vector lengths, whereas pA is advanced by lanes*32 bytes right after; the two agree only when every lane of p1 is active. Confirm the partial-p1 (M tail) case reads the intended data. */ add pA, pA, lanes, lsl #5 // pA += lanes*2*2*8 ld1rd z8.d, p0/z, [pB] @@ -561,17 +563,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. prfm PLDL1KEEP, [origPA] fmov alphaR, d0 + dup alphaz_R, alphaR /* presumably replicates the real part of alpha into every lane; alphaz_R is defined outside this view */ fmov alphaI, d1 + dup alphaz_I, alphaI /* presumably the same for the imaginary part of alpha */ lsl LDC, LDC, #4 // ldc = ldc * 2 * 8 + ptrue p0.d // create true predicate mov pB, origPB +// Loop over N mov counterJ, origN asr counterJ, counterJ, #2 // J = J / 4 cmp counterJ, #0 ble .Lzgemm_kernel_L2_BEGIN +/******************************************************************************/ .Lzgemm_kernel_L4_BEGIN: mov pCRow0, pC add pCRow1, pCRow0, LDC @@ -582,204 +589,112 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mov pA, origPA // pA = start of A array -.Lzgemm_kernel_L4_M4_BEGIN: +.Lzgemm_kernel_L4_Mv1_BEGIN: - mov counterI, origM - asr counterI, counterI, #2 // counterI = counterI / 4 - cmp counterI, #0 - ble .Lzgemm_kernel_L4_M2_BEGIN +/* Loop over M is done in an SVE fashion. 
This has the benefit of the last M%SVE_LEN iterations being done in a single sweep */ + mov counterI, #0 + whilelt p1.d, counterI, origM /* p1: one active lane per remaining M index, counted from counterI up to origM */ + cntp lanes, p0, p1.d // lanes contain number of active SVE lanes in M dimension .align 5 -.Lzgemm_kernel_L4_M4_20: +.Lzgemm_kernel_L4_Mv1_20: mov pB, origPB + INITv1x4 // fill with zeros + asr counterL , origK, #3 cmp counterL , #2 - blt .Lzgemm_kernel_L4_M4_32 + blt .Lzgemm_kernel_L4_Mv1_32 - KERNEL4x4_I - KERNEL4x4_M2 - KERNEL4x4_M1 - KERNEL4x4_M2 - KERNEL4x4_M1 - KERNEL4x4_M2 - KERNEL4x4_M1 - KERNEL4x4_M2 + KERNELv1x4_I + KERNELv1x4_M2 + KERNELv1x4_M1 + KERNELv1x4_M2 + KERNELv1x4_M1 + KERNELv1x4_M2 + KERNELv1x4_M1 + KERNELv1x4_M2 subs counterL, counterL, #2 // subtract 2 - ble .Lzgemm_kernel_L4_M4_22a + ble .Lzgemm_kernel_L4_Mv1_22a .align 5 -.Lzgemm_kernel_L4_M4_22: +.Lzgemm_kernel_L4_Mv1_22: - KERNEL4x4_M1 - KERNEL4x4_M2 - KERNEL4x4_M1 - KERNEL4x4_M2 - KERNEL4x4_M1 - KERNEL4x4_M2 - KERNEL4x4_M1 - KERNEL4x4_M2 + KERNELv1x4_M1 + KERNELv1x4_M2 + KERNELv1x4_M1 + KERNELv1x4_M2 + KERNELv1x4_M1 + KERNELv1x4_M2 + KERNELv1x4_M1 + KERNELv1x4_M2 subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L4_M4_22 + bgt .Lzgemm_kernel_L4_Mv1_22 .align 5 -.Lzgemm_kernel_L4_M4_22a: +.Lzgemm_kernel_L4_Mv1_22a: - KERNEL4x4_M1 - KERNEL4x4_M2 - KERNEL4x4_M1 - KERNEL4x4_M2 - KERNEL4x4_M1 - KERNEL4x4_M2 - KERNEL4x4_M1 - KERNEL4x4_E + KERNELv1x4_M1 + KERNELv1x4_M2 + KERNELv1x4_M1 + KERNELv1x4_M2 + KERNELv1x4_M1 + KERNELv1x4_M2 + KERNELv1x4_M1 + KERNELv1x4_E - b .Lzgemm_kernel_L4_M4_44 + b .Lzgemm_kernel_L4_Mv1_44 .align 5 -.Lzgemm_kernel_L4_M4_32: +.Lzgemm_kernel_L4_Mv1_32: tst counterL, #1 - ble .Lzgemm_kernel_L4_M4_40 + ble .Lzgemm_kernel_L4_Mv1_40 - KERNEL4x4_I - KERNEL4x4_M2 - KERNEL4x4_M1 - KERNEL4x4_M2 - KERNEL4x4_M1 - KERNEL4x4_M2 - KERNEL4x4_M1 - KERNEL4x4_E + KERNELv1x4_I + KERNELv1x4_M2 + KERNELv1x4_M1 + KERNELv1x4_M2 + KERNELv1x4_M1 + KERNELv1x4_M2 + KERNELv1x4_M1 + KERNELv1x4_E - b .Lzgemm_kernel_L4_M4_44 + b .Lzgemm_kernel_L4_Mv1_44 
-.Lzgemm_kernel_L4_M4_40: +.Lzgemm_kernel_L4_Mv1_40: - INIT4x4 + INITv1x4 /* NOTE(review): .Lzgemm_kernel_L4_Mv1_20 already runs INITv1x4 before branching here, so this second INITv1x4 looks redundant; confirm and drop if so */ -.Lzgemm_kernel_L4_M4_44: +.Lzgemm_kernel_L4_Mv1_44: ands counterL , origK, #7 - ble .Lzgemm_kernel_L4_M4_100 + ble .Lzgemm_kernel_L4_Mv1_100 .align 5 -.Lzgemm_kernel_L4_M4_46: - KERNEL4x4_SUB +.Lzgemm_kernel_L4_Mv1_46: + KERNELv1x4_SUB subs counterL, counterL, #1 - bne .Lzgemm_kernel_L4_M4_46 + bne .Lzgemm_kernel_L4_Mv1_46 -.Lzgemm_kernel_L4_M4_100: +.Lzgemm_kernel_L4_Mv1_100: prfm PLDL1KEEP, [pA] prfm PLDL1KEEP, [pA, #64] prfm PLDL1KEEP, [origPB] - SAVE4x4 + SAVEv1x4 -.Lzgemm_kernel_L4_M4_END: - subs counterI, counterI, #1 - bne .Lzgemm_kernel_L4_M4_20 +.Lzgemm_kernel_L4_Mv1_END: -.Lzgemm_kernel_L4_M2_BEGIN: + incd counterI /* step counterI by one vector's worth of 64-bit lanes */ + whilelt p1.d, counterI, origM //SVE instruction + cntp lanes, p0, p1.d // lanes contain number of active SVE lanes in M dimension + b.any .Lzgemm_kernel_L4_Mv1_20 /* another M sweep while p1 has an active lane (flags from whilelt above) */ - mov counterI, origM - tst counterI , #3 - ble .Lzgemm_kernel_L4_END - - tst counterI, #2 // counterI = counterI / 2 - ble .Lzgemm_kernel_L4_M1_BEGIN - -.Lzgemm_kernel_L4_M2_20: - - INIT2x4 - - mov pB, origPB - asr counterL , origK, #3 // counterL = counterL / 8 - cmp counterL , #0 - ble .Lzgemm_kernel_L4_M2_40 - -.Lzgemm_kernel_L4_M2_22: - - KERNEL2x4_SUB - KERNEL2x4_SUB - KERNEL2x4_SUB - KERNEL2x4_SUB - - KERNEL2x4_SUB - KERNEL2x4_SUB - KERNEL2x4_SUB - KERNEL2x4_SUB - - subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L4_M2_22 - - -.Lzgemm_kernel_L4_M2_40: - - ands counterL , origK, #7 // counterL = counterL % 8 - ble .Lzgemm_kernel_L4_M2_100 - -.Lzgemm_kernel_L4_M2_42: - - KERNEL2x4_SUB - - subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L4_M2_42 - -.Lzgemm_kernel_L4_M2_100: - - SAVE2x4 - -.Lzgemm_kernel_L4_M2_END: - - -.Lzgemm_kernel_L4_M1_BEGIN: - - tst counterI, #1 // counterI = counterI % 2 - ble .Lzgemm_kernel_L4_END - -.Lzgemm_kernel_L4_M1_20: - - INIT1x4 - - mov pB, origPB - asr counterL , origK, #3 // counterL = counterL / 8 - cmp counterL , #0 - ble .Lzgemm_kernel_L4_M1_40 - 
-.Lzgemm_kernel_L4_M1_22: - KERNEL1x4_SUB - KERNEL1x4_SUB - KERNEL1x4_SUB - KERNEL1x4_SUB - - KERNEL1x4_SUB - KERNEL1x4_SUB - KERNEL1x4_SUB - KERNEL1x4_SUB - - subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L4_M1_22 - - -.Lzgemm_kernel_L4_M1_40: - - ands counterL , origK, #7 // counterL = counterL % 8 - ble .Lzgemm_kernel_L4_M1_100 - -.Lzgemm_kernel_L4_M1_42: - - KERNEL1x4_SUB - - subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L4_M1_42 - -.Lzgemm_kernel_L4_M1_100: - - SAVE1x4 .Lzgemm_kernel_L4_END: @@ -810,157 +725,61 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -.Lzgemm_kernel_L2_M4_BEGIN: +.Lzgemm_kernel_L2_Mv1_BEGIN: - mov counterI, origM - asr counterI, counterI, #2 // counterI = counterI / 4 - cmp counterI,#0 - ble .Lzgemm_kernel_L2_M2_BEGIN + mov counterI, #0 + whilelt p1.d, counterI, origM //SVE instruction + cntp lanes, p0, p1.d /* lanes = number of active lanes in p1 */ -.Lzgemm_kernel_L2_M4_20: - INIT4x2 +.Lzgemm_kernel_L2_Mv1_20: + + INITv1x2 mov pB, origPB asr counterL , origK, #3 // counterL = counterL / 8 cmp counterL,#0 - ble .Lzgemm_kernel_L2_M4_40 + ble .Lzgemm_kernel_L2_Mv1_40 .align 5 -.Lzgemm_kernel_L2_M4_22: - KERNEL4x2_SUB - KERNEL4x2_SUB - KERNEL4x2_SUB - KERNEL4x2_SUB +.Lzgemm_kernel_L2_Mv1_22: + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB - KERNEL4x2_SUB - KERNEL4x2_SUB - KERNEL4x2_SUB - KERNEL4x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L2_M4_22 + bgt .Lzgemm_kernel_L2_Mv1_22 -.Lzgemm_kernel_L2_M4_40: +.Lzgemm_kernel_L2_Mv1_40: ands counterL , origK, #7 // counterL = counterL % 8 - ble .Lzgemm_kernel_L2_M4_100 + ble .Lzgemm_kernel_L2_Mv1_100 -.Lzgemm_kernel_L2_M4_42: +.Lzgemm_kernel_L2_Mv1_42: - KERNEL4x2_SUB + KERNELv1x2_SUB subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L2_M4_42 + bgt .Lzgemm_kernel_L2_Mv1_42 -.Lzgemm_kernel_L2_M4_100: +.Lzgemm_kernel_L2_Mv1_100: - SAVE4x2 + SAVEv1x2 -.Lzgemm_kernel_L2_M4_END: - - subs counterI, counterI, #1 - 
bgt .Lzgemm_kernel_L2_M4_20 +.Lzgemm_kernel_L2_Mv1_END: -.Lzgemm_kernel_L2_M2_BEGIN: - - mov counterI, origM - tst counterI , #3 - ble .Lzgemm_kernel_L2_END - - tst counterI, #2 // counterI = counterI / 2 - ble .Lzgemm_kernel_L2_M1_BEGIN - -.Lzgemm_kernel_L2_M2_20: - - INIT2x2 - - mov pB, origPB - asr counterL , origK, #3 // counterL = counterL / 8 - cmp counterL,#0 - ble .Lzgemm_kernel_L2_M2_40 - -.Lzgemm_kernel_L2_M2_22: - - KERNEL2x2_SUB - KERNEL2x2_SUB - KERNEL2x2_SUB - KERNEL2x2_SUB - - KERNEL2x2_SUB - KERNEL2x2_SUB - KERNEL2x2_SUB - KERNEL2x2_SUB - - subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L2_M2_22 - - -.Lzgemm_kernel_L2_M2_40: - - ands counterL , origK, #7 // counterL = counterL % 8 - ble .Lzgemm_kernel_L2_M2_100 - -.Lzgemm_kernel_L2_M2_42: - - KERNEL2x2_SUB - - subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L2_M2_42 - -.Lzgemm_kernel_L2_M2_100: - - SAVE2x2 - -.Lzgemm_kernel_L2_M2_END: - - -.Lzgemm_kernel_L2_M1_BEGIN: - - tst counterI, #1 // counterI = counterI % 2 - ble .Lzgemm_kernel_L2_END - -.Lzgemm_kernel_L2_M1_20: - - INIT1x2 - - mov pB, origPB - asr counterL , origK, #3 // counterL = counterL / 8 - cmp counterL, #0 - ble .Lzgemm_kernel_L2_M1_40 - -.Lzgemm_kernel_L2_M1_22: - KERNEL1x2_SUB - KERNEL1x2_SUB - KERNEL1x2_SUB - KERNEL1x2_SUB - - KERNEL1x2_SUB - KERNEL1x2_SUB - KERNEL1x2_SUB - KERNEL1x2_SUB - - subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L2_M1_22 - - -.Lzgemm_kernel_L2_M1_40: - - ands counterL , origK, #7 // counterL = counterL % 8 - ble .Lzgemm_kernel_L2_M1_100 - -.Lzgemm_kernel_L2_M1_42: - - KERNEL1x2_SUB - - subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L2_M1_42 - -.Lzgemm_kernel_L2_M1_100: - - SAVE1x2 + incd counterI + whilelt p1.d, counterI, origM //SVE instruction + cntp lanes, p0, p1.d + b.any .Lzgemm_kernel_L2_Mv1_20 /* another M sweep while p1 has an active lane (flags from whilelt above) */ .Lzgemm_kernel_L2_END: @@ -981,163 +800,64 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
mov pA, origPA // pA = A +.Lzgemm_kernel_L1_Mv1_BEGIN: + + mov counterI, #0 + whilelt p1.d, counterI, origM //SVE instruction + cntp lanes, p0, p1.d /* lanes = number of active lanes in p1 */ -.Lzgemm_kernel_L1_M4_BEGIN: +.Lzgemm_kernel_L1_Mv1_20: - mov counterI, origM - asr counterI, counterI, #2 // counterI = counterI / 4 - cmp counterI, #0 - ble .Lzgemm_kernel_L1_M2_BEGIN - -.Lzgemm_kernel_L1_M4_20: - - INIT4x1 + INITv1x1 mov pB, origPB asr counterL , origK, #3 // counterL = counterL / 8 cmp counterL , #0 - ble .Lzgemm_kernel_L1_M4_40 + ble .Lzgemm_kernel_L1_Mv1_40 .align 5 -.Lzgemm_kernel_L1_M4_22: - KERNEL4x1_SUB - KERNEL4x1_SUB - KERNEL4x1_SUB - KERNEL4x1_SUB +.Lzgemm_kernel_L1_Mv1_22: + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB - KERNEL4x1_SUB - KERNEL4x1_SUB - KERNEL4x1_SUB - KERNEL4x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L1_M4_22 + bgt .Lzgemm_kernel_L1_Mv1_22 -.Lzgemm_kernel_L1_M4_40: +.Lzgemm_kernel_L1_Mv1_40: ands counterL , origK, #7 // counterL = counterL % 8 - ble .Lzgemm_kernel_L1_M4_100 + ble .Lzgemm_kernel_L1_Mv1_100 -.Lzgemm_kernel_L1_M4_42: +.Lzgemm_kernel_L1_Mv1_42: - KERNEL4x1_SUB + KERNELv1x1_SUB subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L1_M4_42 + bgt .Lzgemm_kernel_L1_Mv1_42 -.Lzgemm_kernel_L1_M4_100: +.Lzgemm_kernel_L1_Mv1_100: - SAVE4x1 + SAVEv1x1 -.Lzgemm_kernel_L1_M4_END: - - subs counterI, counterI, #1 - bgt .Lzgemm_kernel_L1_M4_20 - - -.Lzgemm_kernel_L1_M2_BEGIN: - - mov counterI, origM - tst counterI , #3 - ble .Lzgemm_kernel_L1_END - - tst counterI, #2 // counterI = counterI / 2 - ble .Lzgemm_kernel_L1_M1_BEGIN - -.Lzgemm_kernel_L1_M2_20: - - INIT2x1 - - mov pB, origPB - asr counterL , origK, #3 // counterL = counterL / 8 - cmp counterL , #0 - ble .Lzgemm_kernel_L1_M2_40 - -.Lzgemm_kernel_L1_M2_22: - - KERNEL2x1_SUB - KERNEL2x1_SUB - KERNEL2x1_SUB - KERNEL2x1_SUB - - KERNEL2x1_SUB - KERNEL2x1_SUB - KERNEL2x1_SUB - KERNEL2x1_SUB - - subs counterL, counterL, #1 - 
bgt .Lzgemm_kernel_L1_M2_22 - - -.Lzgemm_kernel_L1_M2_40: - - ands counterL , origK, #7 // counterL = counterL % 8 - ble .Lzgemm_kernel_L1_M2_100 - -.Lzgemm_kernel_L1_M2_42: - - KERNEL2x1_SUB - - subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L1_M2_42 - -.Lzgemm_kernel_L1_M2_100: - - SAVE2x1 - -.Lzgemm_kernel_L1_M2_END: - - -.Lzgemm_kernel_L1_M1_BEGIN: - - tst counterI, #1 // counterI = counterI % 2 - ble .Lzgemm_kernel_L1_END - -.Lzgemm_kernel_L1_M1_20: - - INIT1x1 - - mov pB, origPB - asr counterL , origK, #3 // counterL = counterL / 8 - cmp counterL , #0 - ble .Lzgemm_kernel_L1_M1_40 - -.Lzgemm_kernel_L1_M1_22: - KERNEL1x1_SUB - KERNEL1x1_SUB - KERNEL1x1_SUB - KERNEL1x1_SUB - - KERNEL1x1_SUB - KERNEL1x1_SUB - KERNEL1x1_SUB - KERNEL1x1_SUB - - subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L1_M1_22 - - -.Lzgemm_kernel_L1_M1_40: - - ands counterL , origK, #7 // counterL = counterL % 8 - ble .Lzgemm_kernel_L1_M1_100 - -.Lzgemm_kernel_L1_M1_42: - - KERNEL1x1_SUB - - subs counterL, counterL, #1 - bgt .Lzgemm_kernel_L1_M1_42 - -.Lzgemm_kernel_L1_M1_100: - - SAVE1x1 +.Lzgemm_kernel_L1_Mv1_END: + incd counterI + whilelt p1.d, counterI, origM //SVE instruction + cntp lanes, p0, p1.d + b.any .Lzgemm_kernel_L1_Mv1_20 /* another M sweep while p1 has an active lane (flags from whilelt above) */ .Lzgemm_kernel_L1_END: +/******************************************************************************/ .Lzgemm_kernel_L999: mov x0, #0 // set return value