sve zgemm kernel

This commit is contained in:
Bine Brank 2021-12-26 08:44:05 +01:00
parent 683a7548bf
commit 878064f394
1 changed files with 132 additions and 412 deletions

View File

@ -48,6 +48,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define pCRow2 x14
#define pCRow3 x15
#define pA x16
#define lanes x17
#define alphaR x19
#define alphaI x20
@ -168,7 +170,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
.macro KERNELv1x4_I
ld2d {z0.d, z1.d}, p1/z, [pA]
ld2d {z2.d, z3.d}, p1/z, [pA, lanes, lsl #4] // next one
ld2d {z2.d, z3.d}, p1/z, [pA, #2, mul vl] // next one
add pA, pA, lanes, lsl #5 // pA += lanes*2*2*8
ld1rd z8.d, p0/z, [pB]
@ -561,17 +563,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Kernel setup: capture alpha, convert LDC to a byte stride, build the
// all-true predicate, then decide whether any 4-column panels of N exist.
prfm PLDL1KEEP, [origPA] // prefetch the start of A (load, L1, keep)
fmov alphaR, d0 // alphaR = raw bits of d0 (presumably the real part of alpha -- confirm with caller)
dup alphaz_R, alphaR // broadcast alphaR into alphaz_R (defined outside this view)
fmov alphaI, d1 // alphaI = raw bits of d1 (presumably the imaginary part of alpha -- confirm with caller)
dup alphaz_I, alphaI // broadcast alphaI into alphaz_I (defined outside this view)
lsl LDC, LDC, #4 // ldc = ldc * 2 * 8, i.e. scaled by 16
ptrue p0.d // p0 = all doubleword lanes active
mov pB, origPB
// Loop over N in panels of 4 columns
mov counterJ, origN
asr counterJ, counterJ, #2 // J = N / 4
cmp counterJ, #0
ble .Lzgemm_kernel_L2_BEGIN // no full 4-column panel: go to the L2 section
/******************************************************************************/
.Lzgemm_kernel_L4_BEGIN:
mov pCRow0, pC // pCRow0 = current C pointer
add pCRow1, pCRow0, LDC // pCRow1 = pCRow0 advanced by one scaled LDC stride
@ -582,204 +589,112 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Start of the M loop for a 4-column panel.
// NOTE(review): this listing appears to be a diff rendered without +/- markers;
// the "M4" label and the M/4 count below sit next to the "Mv1" SVE setup.
// As written, the M/4 count is overwritten by "mov counterI, #0" -- confirm
// against the real file which lines are live.
mov pA, origPA // pA = start of A array
.Lzgemm_kernel_L4_M4_BEGIN:
.Lzgemm_kernel_L4_Mv1_BEGIN:
mov counterI, origM
asr counterI, counterI, #2 // counterI = M / 4
cmp counterI, #0
ble .Lzgemm_kernel_L4_M2_BEGIN // taken when M / 4 <= 0
/* The loop over M is predicated with SVE: each pass covers the doubleword lanes enabled by p1, so the rows left over when M is not a multiple of the lane count are handled by one partially-active pass */
mov counterI, #0 // counterI = index of the first row of the current pass
whilelt p1.d, counterI, origM // p1 lane i active while counterI + i < M
cntp lanes, p0, p1.d // lanes = number of active lanes in p1 (M dimension)
// Head of the K loop for one M pass of a 4-column panel.
// counterL = K / 8. With fewer than two groups of 8 the code goes to the _32
// path; otherwise one group is issued here, starting with the _I macro.
// NOTE(review): M4/4x4 and Mv1/v1x4 lines appear interleaved (diff rendering).
// Where two branches on the same flags are adjacent, the second can never be
// taken as written. The KERNEL*/INIT* macros are defined outside this view.
.align 5
.Lzgemm_kernel_L4_M4_20:
.Lzgemm_kernel_L4_Mv1_20:
mov pB, origPB // restart B at the beginning of the panel
INITv1x4 // fill with zeros
asr counterL , origK, #3 // counterL = K / 8
cmp counterL , #2
blt .Lzgemm_kernel_L4_M4_32 // fewer than 2 groups of 8
blt .Lzgemm_kernel_L4_Mv1_32 // same condition, Mv1 label
KERNEL4x4_I
KERNEL4x4_M2
KERNEL4x4_M1
KERNEL4x4_M2
KERNEL4x4_M1
KERNEL4x4_M2
KERNEL4x4_M1
KERNEL4x4_M2
KERNELv1x4_I
KERNELv1x4_M2
KERNELv1x4_M1
KERNELv1x4_M2
KERNELv1x4_M1
KERNELv1x4_M2
KERNELv1x4_M1
KERNELv1x4_M2
subs counterL, counterL, #2 // subtract 2
ble .Lzgemm_kernel_L4_M4_22a // counterL was exactly 2: skip the steady-state loop
ble .Lzgemm_kernel_L4_Mv1_22a // same condition, Mv1 label
// Steady-state K loop: each trip issues one group of eight kernel macros
// (alternating M1/M2) and decrements counterL; repeats while counterL > 0.
// NOTE(review): KERNEL4x4_* and KERNELv1x4_* groups both appear here,
// presumably the before/after sides of the diff -- confirm which is live.
.align 5
.Lzgemm_kernel_L4_M4_22:
.Lzgemm_kernel_L4_Mv1_22:
KERNEL4x4_M1
KERNEL4x4_M2
KERNEL4x4_M1
KERNEL4x4_M2
KERNEL4x4_M1
KERNEL4x4_M2
KERNEL4x4_M1
KERNEL4x4_M2
KERNELv1x4_M1
KERNELv1x4_M2
KERNELv1x4_M1
KERNELv1x4_M2
KERNELv1x4_M1
KERNELv1x4_M2
KERNELv1x4_M1
KERNELv1x4_M2
subs counterL, counterL, #1 // one more group of 8 done
bgt .Lzgemm_kernel_L4_M4_22 // loop while counterL > 0
bgt .Lzgemm_kernel_L4_Mv1_22 // same condition, Mv1 label
// Final group of eight for the unrolled path: same M1/M2 alternation as the
// steady-state loop, but the last macro is the _E variant. Control then jumps
// to the _44 remainder handling.
// NOTE(review): 4x4 and v1x4 groups both appear (diff rendering); the second
// unconditional branch is unreachable as written.
.align 5
.Lzgemm_kernel_L4_M4_22a:
.Lzgemm_kernel_L4_Mv1_22a:
KERNEL4x4_M1
KERNEL4x4_M2
KERNEL4x4_M1
KERNEL4x4_M2
KERNEL4x4_M1
KERNEL4x4_M2
KERNEL4x4_M1
KERNEL4x4_E
KERNELv1x4_M1
KERNELv1x4_M2
KERNELv1x4_M1
KERNELv1x4_M2
KERNELv1x4_M1
KERNELv1x4_M2
KERNELv1x4_M1
KERNELv1x4_E
b .Lzgemm_kernel_L4_M4_44 // continue with the K % 8 remainder
b .Lzgemm_kernel_L4_Mv1_44 // same target role, Mv1 label
// Short-K path, reached from the "cmp counterL, #2 / blt" test, i.e. when
// K / 8 is below 2. If bit 0 of counterL is clear, no full group of 8 is run
// and control goes to _40; otherwise exactly one group (_I ... _E) is issued.
// NOTE(review): 4x4 and v1x4 groups both appear (diff rendering); duplicated
// branches after the first are dead as written.
.align 5
.Lzgemm_kernel_L4_M4_32:
.Lzgemm_kernel_L4_Mv1_32:
tst counterL, #1 // Z set when bit 0 of counterL is clear
ble .Lzgemm_kernel_L4_M4_40 // after tst with mask #1 this acts as "branch if zero"
ble .Lzgemm_kernel_L4_Mv1_40 // same condition, Mv1 label
KERNEL4x4_I
KERNEL4x4_M2
KERNEL4x4_M1
KERNEL4x4_M2
KERNEL4x4_M1
KERNEL4x4_M2
KERNEL4x4_M1
KERNEL4x4_E
KERNELv1x4_I
KERNELv1x4_M2
KERNELv1x4_M1
KERNELv1x4_M2
KERNELv1x4_M1
KERNELv1x4_M2
KERNELv1x4_M1
KERNELv1x4_E
b .Lzgemm_kernel_L4_M4_44 // continue with the K % 8 remainder
b .Lzgemm_kernel_L4_Mv1_44 // same target role, Mv1 label
// _40: entry used when no full group of 8 was executed; runs the INIT macro.
// _44: remainder handling -- counterL = K % 8, skipped entirely when zero.
// _46: one *_SUB kernel macro per remaining K step.
// NOTE(review): 4x4 and v1x4 variants both appear (diff rendering).
.Lzgemm_kernel_L4_M4_40:
.Lzgemm_kernel_L4_Mv1_40:
INIT4x4
INITv1x4
.Lzgemm_kernel_L4_M4_44:
.Lzgemm_kernel_L4_Mv1_44:
ands counterL , origK, #7 // counterL = K % 8, flags set from the result
ble .Lzgemm_kernel_L4_M4_100 // remainder is zero: go straight to the save step
ble .Lzgemm_kernel_L4_Mv1_100 // same condition, Mv1 label
.align 5
.Lzgemm_kernel_L4_M4_46:
KERNEL4x4_SUB
.Lzgemm_kernel_L4_Mv1_46:
KERNELv1x4_SUB
subs counterL, counterL, #1 // one remainder step done
bne .Lzgemm_kernel_L4_M4_46 // loop until counterL reaches zero
bne .Lzgemm_kernel_L4_Mv1_46 // same condition, Mv1 label
// End of one M pass: prefetch upcoming A/B data, run the SAVE macro(s), then
// advance the M index by one vector of doubleword lanes and rebuild p1.
// The loop repeats while whilelt leaves at least one lane active.
// NOTE(review): the "subs counterI / bne ..._M4_20" pair and the SAVE4x4 line
// look like the pre-change side of the diff -- confirm which lines are live.
.Lzgemm_kernel_L4_M4_100:
.Lzgemm_kernel_L4_Mv1_100:
prfm PLDL1KEEP, [pA] // prefetch at the current A position
prfm PLDL1KEEP, [pA, #64] // and 64 bytes further on
prfm PLDL1KEEP, [origPB] // prefetch the start of B
SAVE4x4
SAVEv1x4
.Lzgemm_kernel_L4_M4_END:
subs counterI, counterI, #1
bne .Lzgemm_kernel_L4_M4_20
.Lzgemm_kernel_L4_Mv1_END:
.Lzgemm_kernel_L4_M2_BEGIN:
incd counterI // counterI += number of doubleword lanes per vector
whilelt p1.d, counterI, origM // p1 lane i active while counterI + i < M; sets the flags tested below
cntp lanes, p0, p1.d // lanes contain number of active SVE lanes in M dimension
b.any .Lzgemm_kernel_L4_Mv1_20 // another pass while any lane of p1 is active
// Fixed-width M tail for a 4-column panel: handles two rows when bit 1 of M
// is set. Skips to L4_END when M % 4 == 0 and to the one-row tail when bit 1
// is clear. The KERNEL2x4_SUB / INIT2x4 / SAVE2x4 macros are defined outside
// this view.
mov counterI, origM // counterI = M
tst counterI , #3 // Z set when M % 4 == 0
ble .Lzgemm_kernel_L4_END // no leftover rows
tst counterI, #2 // test bit 1 of M; counterI itself is not modified
ble .Lzgemm_kernel_L4_M1_BEGIN // bit 1 clear: no two-row tail
.Lzgemm_kernel_L4_M2_20:
INIT2x4
mov pB, origPB // restart B at the beginning of the panel
asr counterL , origK, #3 // counterL = K / 8
cmp counterL , #0
ble .Lzgemm_kernel_L4_M2_40 // K < 8: only the remainder loop runs
.Lzgemm_kernel_L4_M2_22:
// eight K steps per trip
KERNEL2x4_SUB
KERNEL2x4_SUB
KERNEL2x4_SUB
KERNEL2x4_SUB
KERNEL2x4_SUB
KERNEL2x4_SUB
KERNEL2x4_SUB
KERNEL2x4_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L4_M2_22 // loop while counterL > 0
.Lzgemm_kernel_L4_M2_40:
ands counterL , origK, #7 // counterL = K % 8
ble .Lzgemm_kernel_L4_M2_100 // remainder is zero
.Lzgemm_kernel_L4_M2_42:
KERNEL2x4_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L4_M2_42 // one K step per trip
.Lzgemm_kernel_L4_M2_100:
SAVE2x4
.Lzgemm_kernel_L4_M2_END:
// Fixed-width M tail for a 4-column panel: handles one row when bit 0 of
// counterI is set. counterI is not reloaded here; it is set to origM by the
// preceding two-row section. The KERNEL1x4_SUB / INIT1x4 / SAVE1x4 macros are
// defined outside this view.
.Lzgemm_kernel_L4_M1_BEGIN:
tst counterI, #1 // test bit 0 of counterI; counterI itself is not modified
ble .Lzgemm_kernel_L4_END // bit 0 clear: no one-row tail
.Lzgemm_kernel_L4_M1_20:
INIT1x4
mov pB, origPB // restart B at the beginning of the panel
asr counterL , origK, #3 // counterL = K / 8
cmp counterL , #0
ble .Lzgemm_kernel_L4_M1_40 // K < 8: only the remainder loop runs
.Lzgemm_kernel_L4_M1_22:
// eight K steps per trip
KERNEL1x4_SUB
KERNEL1x4_SUB
KERNEL1x4_SUB
KERNEL1x4_SUB
KERNEL1x4_SUB
KERNEL1x4_SUB
KERNEL1x4_SUB
KERNEL1x4_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L4_M1_22 // loop while counterL > 0
.Lzgemm_kernel_L4_M1_40:
ands counterL , origK, #7 // counterL = K % 8
ble .Lzgemm_kernel_L4_M1_100 // remainder is zero
.Lzgemm_kernel_L4_M1_42:
KERNEL1x4_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L4_M1_42 // one K step per trip
.Lzgemm_kernel_L4_M1_100:
SAVE1x4
.Lzgemm_kernel_L4_END:
@ -810,157 +725,61 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Start of the M loop for a 2-column panel.
// NOTE(review): the M/4 count and the SVE predicate setup appear together
// (diff rendering); as written "mov counterI, #0" overwrites the M/4 count.
.Lzgemm_kernel_L2_M4_BEGIN:
.Lzgemm_kernel_L2_Mv1_BEGIN:
mov counterI, origM
asr counterI, counterI, #2 // counterI = M / 4
cmp counterI,#0
ble .Lzgemm_kernel_L2_M2_BEGIN // taken when M / 4 <= 0
mov counterI, #0 // counterI = index of the first row of the current pass
whilelt p1.d, counterI, origM // p1 lane i active while counterI + i < M
cntp lanes, p0, p1.d // lanes = number of active lanes in p1
// K loop for one M pass of a 2-column panel: run the INIT macro, then issue
// *_SUB kernel macros in groups while counterL (= K / 8) is positive.
// NOTE(review): 4x2 and v1x2 variants both appear (diff rendering). Duplicated
// conditional branches after the first are dead as written.
.Lzgemm_kernel_L2_M4_20:
INIT4x2
.Lzgemm_kernel_L2_Mv1_20:
INITv1x2
mov pB, origPB // restart B at the beginning of the panel
asr counterL , origK, #3 // counterL = K / 8
cmp counterL,#0
ble .Lzgemm_kernel_L2_M4_40 // K < 8: only the remainder loop runs
ble .Lzgemm_kernel_L2_Mv1_40 // same condition, Mv1 label
.align 5
.Lzgemm_kernel_L2_M4_22:
KERNEL4x2_SUB
KERNEL4x2_SUB
KERNEL4x2_SUB
KERNEL4x2_SUB
.Lzgemm_kernel_L2_Mv1_22:
KERNELv1x2_SUB
KERNELv1x2_SUB
KERNELv1x2_SUB
KERNELv1x2_SUB
KERNEL4x2_SUB
KERNEL4x2_SUB
KERNEL4x2_SUB
KERNEL4x2_SUB
KERNELv1x2_SUB
KERNELv1x2_SUB
KERNELv1x2_SUB
KERNELv1x2_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L2_M4_22 // loop while counterL > 0
bgt .Lzgemm_kernel_L2_Mv1_22 // same condition, Mv1 label
// K % 8 remainder loop and save step for one M pass of a 2-column panel.
// NOTE(review): 4x2 and v1x2 variants both appear (diff rendering); the
// "subs counterI / bgt ..._M4_20" pair looks like the pre-change loop control.
.Lzgemm_kernel_L2_M4_40:
.Lzgemm_kernel_L2_Mv1_40:
ands counterL , origK, #7 // counterL = K % 8
ble .Lzgemm_kernel_L2_M4_100 // remainder is zero
ble .Lzgemm_kernel_L2_Mv1_100 // same condition, Mv1 label
.Lzgemm_kernel_L2_M4_42:
.Lzgemm_kernel_L2_Mv1_42:
KERNEL4x2_SUB
KERNELv1x2_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L2_M4_42 // one remainder step per trip
bgt .Lzgemm_kernel_L2_Mv1_42 // same condition, Mv1 label
.Lzgemm_kernel_L2_M4_100:
.Lzgemm_kernel_L2_Mv1_100:
SAVE4x2
SAVEv1x2
.Lzgemm_kernel_L2_M4_END:
subs counterI, counterI, #1
bgt .Lzgemm_kernel_L2_M4_20 // loop while counterI > 0
.Lzgemm_kernel_L2_Mv1_END:
// Fixed-width M tail for a 2-column panel: handles two rows when bit 1 of M
// is set. Skips to L2_END when M % 4 == 0 and to the one-row tail when bit 1
// is clear. The KERNEL2x2_SUB / INIT2x2 / SAVE2x2 macros are defined outside
// this view.
.Lzgemm_kernel_L2_M2_BEGIN:
mov counterI, origM // counterI = M
tst counterI , #3 // Z set when M % 4 == 0
ble .Lzgemm_kernel_L2_END // no leftover rows
tst counterI, #2 // test bit 1 of M; counterI itself is not modified
ble .Lzgemm_kernel_L2_M1_BEGIN // bit 1 clear: no two-row tail
.Lzgemm_kernel_L2_M2_20:
INIT2x2
mov pB, origPB // restart B at the beginning of the panel
asr counterL , origK, #3 // counterL = K / 8
cmp counterL,#0
ble .Lzgemm_kernel_L2_M2_40 // K < 8: only the remainder loop runs
.Lzgemm_kernel_L2_M2_22:
// eight K steps per trip
KERNEL2x2_SUB
KERNEL2x2_SUB
KERNEL2x2_SUB
KERNEL2x2_SUB
KERNEL2x2_SUB
KERNEL2x2_SUB
KERNEL2x2_SUB
KERNEL2x2_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L2_M2_22 // loop while counterL > 0
.Lzgemm_kernel_L2_M2_40:
ands counterL , origK, #7 // counterL = K % 8
ble .Lzgemm_kernel_L2_M2_100 // remainder is zero
.Lzgemm_kernel_L2_M2_42:
KERNEL2x2_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L2_M2_42 // one K step per trip
.Lzgemm_kernel_L2_M2_100:
SAVE2x2
.Lzgemm_kernel_L2_M2_END:
// Fixed-width one-row M tail for a 2-column panel, followed by the SVE
// predicate advance that closes the Mv1 loop.
// NOTE(review): the incd/whilelt/cntp/b.any sequence after SAVE1x2 presumably
// belongs to the post-change Mv1 loop and only appears here because the diff
// markers were stripped -- confirm against the real file.
.Lzgemm_kernel_L2_M1_BEGIN:
tst counterI, #1 // test bit 0 of counterI; counterI itself is not modified
ble .Lzgemm_kernel_L2_END // bit 0 clear: no one-row tail
.Lzgemm_kernel_L2_M1_20:
INIT1x2
mov pB, origPB // restart B at the beginning of the panel
asr counterL , origK, #3 // counterL = K / 8
cmp counterL, #0
ble .Lzgemm_kernel_L2_M1_40 // K < 8: only the remainder loop runs
.Lzgemm_kernel_L2_M1_22:
// eight K steps per trip
KERNEL1x2_SUB
KERNEL1x2_SUB
KERNEL1x2_SUB
KERNEL1x2_SUB
KERNEL1x2_SUB
KERNEL1x2_SUB
KERNEL1x2_SUB
KERNEL1x2_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L2_M1_22 // loop while counterL > 0
.Lzgemm_kernel_L2_M1_40:
ands counterL , origK, #7 // counterL = K % 8
ble .Lzgemm_kernel_L2_M1_100 // remainder is zero
.Lzgemm_kernel_L2_M1_42:
KERNEL1x2_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L2_M1_42 // one K step per trip
.Lzgemm_kernel_L2_M1_100:
SAVE1x2
incd counterI // counterI += number of doubleword lanes per vector
whilelt p1.d, counterI, origM // p1 lane i active while counterI + i < M; sets the flags tested below
cntp lanes, p0, p1.d // lanes = number of active lanes in p1
b.any .Lzgemm_kernel_L2_Mv1_20 // another pass while any lane of p1 is active
.Lzgemm_kernel_L2_END:
@ -981,163 +800,64 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Start of the M loop for a single-column panel.
// NOTE(review): the SVE predicate setup is followed by an M/4 count that, as
// written, overwrites counterI right after the Mv1_20 label. This looks like
// the pre-change side of the diff -- confirm which lines are live.
mov pA, origPA // pA = A
.Lzgemm_kernel_L1_Mv1_BEGIN:
mov counterI, #0 // counterI = index of the first row of the current pass
whilelt p1.d, counterI, origM // p1 lane i active while counterI + i < M
cntp lanes, p0, p1.d // lanes = number of active lanes in p1
.Lzgemm_kernel_L1_M4_BEGIN:
.Lzgemm_kernel_L1_Mv1_20:
mov counterI, origM
asr counterI, counterI, #2 // counterI = M / 4
cmp counterI, #0
ble .Lzgemm_kernel_L1_M2_BEGIN // taken when M / 4 <= 0
// K loop, K % 8 remainder loop and save step for one M pass of a
// single-column panel.
// NOTE(review): 4x1 and v1x1 variants both appear (diff rendering). Duplicated
// conditional branches after the first are dead as written, and the trailing
// "subs counterI / bgt ..._M4_20" pair looks like pre-change loop control.
.Lzgemm_kernel_L1_M4_20:
INIT4x1
INITv1x1
mov pB, origPB // restart B at the beginning of the panel
asr counterL , origK, #3 // counterL = K / 8
cmp counterL , #0
ble .Lzgemm_kernel_L1_M4_40 // K < 8: only the remainder loop runs
ble .Lzgemm_kernel_L1_Mv1_40 // same condition, Mv1 label
.align 5
.Lzgemm_kernel_L1_M4_22:
KERNEL4x1_SUB
KERNEL4x1_SUB
KERNEL4x1_SUB
KERNEL4x1_SUB
.Lzgemm_kernel_L1_Mv1_22:
KERNELv1x1_SUB
KERNELv1x1_SUB
KERNELv1x1_SUB
KERNELv1x1_SUB
KERNEL4x1_SUB
KERNEL4x1_SUB
KERNEL4x1_SUB
KERNEL4x1_SUB
KERNELv1x1_SUB
KERNELv1x1_SUB
KERNELv1x1_SUB
KERNELv1x1_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L1_M4_22 // loop while counterL > 0
bgt .Lzgemm_kernel_L1_Mv1_22 // same condition, Mv1 label
.Lzgemm_kernel_L1_M4_40:
.Lzgemm_kernel_L1_Mv1_40:
ands counterL , origK, #7 // counterL = K % 8
ble .Lzgemm_kernel_L1_M4_100 // remainder is zero
ble .Lzgemm_kernel_L1_Mv1_100 // same condition, Mv1 label
.Lzgemm_kernel_L1_M4_42:
.Lzgemm_kernel_L1_Mv1_42:
KERNEL4x1_SUB
KERNELv1x1_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L1_M4_42 // one remainder step per trip
bgt .Lzgemm_kernel_L1_Mv1_42 // same condition, Mv1 label
.Lzgemm_kernel_L1_M4_100:
.Lzgemm_kernel_L1_Mv1_100:
SAVE4x1
SAVEv1x1
.Lzgemm_kernel_L1_M4_END:
subs counterI, counterI, #1
bgt .Lzgemm_kernel_L1_M4_20 // loop while counterI > 0
// Fixed-width M tail for a single-column panel: handles two rows when bit 1
// of M is set. Skips to L1_END when M % 4 == 0 and to the one-row tail when
// bit 1 is clear. The KERNEL2x1_SUB / INIT2x1 / SAVE2x1 macros are defined
// outside this view.
.Lzgemm_kernel_L1_M2_BEGIN:
mov counterI, origM // counterI = M
tst counterI , #3 // Z set when M % 4 == 0
ble .Lzgemm_kernel_L1_END // no leftover rows
tst counterI, #2 // test bit 1 of M; counterI itself is not modified
ble .Lzgemm_kernel_L1_M1_BEGIN // bit 1 clear: no two-row tail
.Lzgemm_kernel_L1_M2_20:
INIT2x1
mov pB, origPB // restart B at the beginning of the panel
asr counterL , origK, #3 // counterL = K / 8
cmp counterL , #0
ble .Lzgemm_kernel_L1_M2_40 // K < 8: only the remainder loop runs
.Lzgemm_kernel_L1_M2_22:
// eight K steps per trip
KERNEL2x1_SUB
KERNEL2x1_SUB
KERNEL2x1_SUB
KERNEL2x1_SUB
KERNEL2x1_SUB
KERNEL2x1_SUB
KERNEL2x1_SUB
KERNEL2x1_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L1_M2_22 // loop while counterL > 0
.Lzgemm_kernel_L1_M2_40:
ands counterL , origK, #7 // counterL = K % 8
ble .Lzgemm_kernel_L1_M2_100 // remainder is zero
.Lzgemm_kernel_L1_M2_42:
KERNEL2x1_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L1_M2_42 // one K step per trip
.Lzgemm_kernel_L1_M2_100:
SAVE2x1
.Lzgemm_kernel_L1_M2_END:
// Fixed-width one-row M tail for a single-column panel, followed by the SVE
// predicate advance at L1_Mv1_END that closes the Mv1 loop.
// NOTE(review): the one-row tail presumably belongs to the pre-change side of
// the diff and the Mv1_END sequence to the post-change side -- confirm against
// the real file.
.Lzgemm_kernel_L1_M1_BEGIN:
tst counterI, #1 // test bit 0 of counterI; counterI itself is not modified
ble .Lzgemm_kernel_L1_END // bit 0 clear: no one-row tail
.Lzgemm_kernel_L1_M1_20:
INIT1x1
mov pB, origPB // restart B at the beginning of the panel
asr counterL , origK, #3 // counterL = K / 8
cmp counterL , #0
ble .Lzgemm_kernel_L1_M1_40 // K < 8: only the remainder loop runs
.Lzgemm_kernel_L1_M1_22:
// eight K steps per trip
KERNEL1x1_SUB
KERNEL1x1_SUB
KERNEL1x1_SUB
KERNEL1x1_SUB
KERNEL1x1_SUB
KERNEL1x1_SUB
KERNEL1x1_SUB
KERNEL1x1_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L1_M1_22 // loop while counterL > 0
.Lzgemm_kernel_L1_M1_40:
ands counterL , origK, #7 // counterL = K % 8
ble .Lzgemm_kernel_L1_M1_100 // remainder is zero
.Lzgemm_kernel_L1_M1_42:
KERNEL1x1_SUB
subs counterL, counterL, #1
bgt .Lzgemm_kernel_L1_M1_42 // one K step per trip
.Lzgemm_kernel_L1_M1_100:
SAVE1x1
.Lzgemm_kernel_L1_Mv1_END:
incd counterI // counterI += number of doubleword lanes per vector
whilelt p1.d, counterI, origM // p1 lane i active while counterI + i < M; sets the flags tested below
cntp lanes, p0, p1.d // lanes = number of active lanes in p1
b.any .Lzgemm_kernel_L1_Mv1_20 // another pass while any lane of p1 is active
.Lzgemm_kernel_L1_END:
/******************************************************************************/
.Lzgemm_kernel_L999:
mov x0, #0 // set return value