From 86ae89bf33f28780ccaa1044376c94401545b806 Mon Sep 17 00:00:00 2001 From: Bine Brank Date: Sun, 28 Nov 2021 18:12:47 +0100 Subject: [PATCH 1/6] add sgemm kernel and copy functions for sgemm and ssymm --- kernel/Makefile.L3 | 10 + kernel/arm64/KERNEL.A64FX | 34 +- kernel/arm64/sgemm_kernel_sve_v1x8.S | 874 +++++++++++++++++++++++++++ kernel/arm64/sgemm_ncopy_sve_v1.c | 78 +++ kernel/arm64/sgemm_tcopy_sve_v1.c | 77 +++ kernel/arm64/symm_lcopy_sve.c | 50 ++ kernel/arm64/symm_ucopy_sve.c | 50 ++ param.h | 4 +- 8 files changed, 1151 insertions(+), 26 deletions(-) create mode 100644 kernel/arm64/sgemm_kernel_sve_v1x8.S create mode 100644 kernel/arm64/sgemm_ncopy_sve_v1.c create mode 100644 kernel/arm64/sgemm_tcopy_sve_v1.c diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 695f8ae70..593e33dde 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -1809,11 +1809,21 @@ $(KDIR)ssymm_outcopy$(TSUFFIX).$(SUFFIX) : generic/symm_ucopy_$(SGEMM_UNROLL_N). $(KDIR)ssymm_oltcopy$(TSUFFIX).$(SUFFIX) : generic/symm_lcopy_$(SGEMM_UNROLL_N).c $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -DOUTER -DLOWER $< -o $@ +ifdef SSYMMUCOPY_M +$(KDIR)ssymm_iutcopy$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYMMUCOPY_M) + $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -ULOWER $< -o $@ +else $(KDIR)ssymm_iutcopy$(TSUFFIX).$(SUFFIX) : generic/symm_ucopy_$(SGEMM_UNROLL_M).c $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -ULOWER $< -o $@ +endif +ifdef SSYMMLCOPY_M +$(KDIR)ssymm_iltcopy$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYMMLCOPY_M) + $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -DLOWER $< -o $@ +else $(KDIR)ssymm_iltcopy$(TSUFFIX).$(SUFFIX) : generic/symm_lcopy_$(SGEMM_UNROLL_M).c $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -DLOWER $< -o $@ +endif $(KDIR)dsymm_outcopy$(TSUFFIX).$(SUFFIX) : generic/symm_ucopy_$(DGEMM_UNROLL_N).c $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -DDOUBLE -UCOMPLEX -DOUTER -ULOWER $< -o $@ diff --git a/kernel/arm64/KERNEL.A64FX b/kernel/arm64/KERNEL.A64FX index 83536f12d..ee66fea8e 100644 --- a/kernel/arm64/KERNEL.A64FX +++ b/kernel/arm64/KERNEL.A64FX @@ -114,35 +114,21 @@ DSDOTKERNEL = dot.S DGEMM_BETA = dgemm_beta.S SGEMM_BETA = sgemm_beta.S -SGEMMKERNEL = sgemm_kernel_$(SGEMM_UNROLL_M)x$(SGEMM_UNROLL_N).S -STRMMKERNEL = strmm_kernel_$(SGEMM_UNROLL_M)x$(SGEMM_UNROLL_N).S -ifneq ($(SGEMM_UNROLL_M), $(SGEMM_UNROLL_N)) -ifeq ($(SGEMM_UNROLL_M), 16) -SGEMMITCOPY = sgemm_tcopy_$(SGEMM_UNROLL_M).S -else -SGEMMITCOPY = ../generic/gemm_tcopy_$(SGEMM_UNROLL_M).c -endif -ifeq ($(SGEMM_UNROLL_M), 4) -SGEMMINCOPY = sgemm_ncopy_$(SGEMM_UNROLL_M).S -else -SGEMMINCOPY = ../generic/gemm_ncopy_$(SGEMM_UNROLL_M).c -endif +SGEMMKERNEL = sgemm_kernel_sve_v1x$(SGEMM_UNROLL_N).S +STRMMKERNEL = strmm_kernel_8x$(SGEMM_UNROLL_N).S + +SGEMMINCOPY = sgemm_ncopy_sve_v1.c +SGEMMITCOPY = sgemm_tcopy_sve_v1.c +SGEMMONCOPY = sgemm_ncopy_$(DGEMM_UNROLL_N).S +SGEMMOTCOPY = sgemm_tcopy_$(DGEMM_UNROLL_N).S + SGEMMINCOPYOBJ = sgemm_incopy$(TSUFFIX).$(SUFFIX) SGEMMITCOPYOBJ = sgemm_itcopy$(TSUFFIX).$(SUFFIX) -endif -ifeq ($(SGEMM_UNROLL_N), 16) -SGEMMOTCOPY = sgemm_tcopy_$(SGEMM_UNROLL_N).S -else -SGEMMOTCOPY = ../generic/gemm_tcopy_$(SGEMM_UNROLL_N).c -endif -ifeq ($(SGEMM_UNROLL_N), 4) -SGEMMONCOPY = sgemm_ncopy_$(SGEMM_UNROLL_N).S -else -SGEMMONCOPY = ../generic/gemm_ncopy_$(SGEMM_UNROLL_N).c -endif SGEMMONCOPYOBJ = sgemm_oncopy$(TSUFFIX).$(SUFFIX) SGEMMOTCOPYOBJ = sgemm_otcopy$(TSUFFIX).$(SUFFIX) +SSYMMUCOPY_M = symm_ucopy_sve.c +SSYMMLCOPY_M = symm_lcopy_sve.c DGEMMKERNEL = dgemm_kernel_sve_v2x$(DGEMM_UNROLL_N).S DTRMMKERNEL = dtrmm_kernel_sve_v1x$(DGEMM_UNROLL_N).S diff --git a/kernel/arm64/sgemm_kernel_sve_v1x8.S b/kernel/arm64/sgemm_kernel_sve_v1x8.S new file mode 100644 index 000000000..88c74bc0f --- /dev/null +++ b/kernel/arm64/sgemm_kernel_sve_v1x8.S @@ -0,0 +1,874 @@ +/******************************************************************************* +Copyright (c) 2015, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +#define ASSEMBLER +#include "common.h" + +/* X0 X1 X2 s0 X3 x4 x5 x6 */ +/*int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha0,FLOAT* ba,FLOAT* bb,FLOAT* C,BLASLONG ldc )*/ + +#define origM x0 +#define origN x1 +#define origK x2 +#define origPA x3 +#define origPB x4 +#define pC x5 +#define LDC x6 +#define temp x7 +#define counterL x8 +#define counterI x9 +#define counterJ x10 +#define pB x11 +#define pCRow0 x12 +#define pCRow1 x13 +#define pCRow2 x14 + +#define lanes x15 +#define pA x16 +#define alpha w17 + +#define alpha0 s10 +#define alphaZ z2.s + +#define A_PRE_SIZE 1536 +#define B_PRE_SIZE 512 +#define C_PRE_SIZE 128 + +// 00 origM +// 01 origN +// 02 origK +// 03 origPA +// 04 origPB +// 05 pC +// 06 origLDC -> LDC +// 07 temp +// 08 counterL +// 09 counterI +// 10 counterJ +// 11 pB +// 12 pCRow0 +// 13 pCRow1 +// 14 pCRow2 +// 15 lanes +// 16 pA +// 17 +// 18 must save +// 19 must save +// 20 must save +// 21 must save +// 22 must save +// 23 must save +// 24 must save +// 25 must save +// 26 must save +// 27 must save +// 28 must save +// 29 frame +// 30 link +// 31 sp + +//v00 ALPHA -> pA0_0 +//v01 pA0_1 +//v02 ALPHA0 +//v03 +//v04 +//v05 +//v06 +//v07 +//v08 must save pB0_0 +//v09 must save pB0_1 +//v10 must save pB0_2 +//v11 must save pB0_3 +//v12 must save pB0_4 +//v13 must save pB0_5 +//v14 must save pB0_6 +//v15 must save pB0_7 +//v16 must save C0 +//v17 must save C1 +//v18 must save C2 +//v19 must save C3 +//v20 must save C4 +//v21 must save C5 +//v22 must save C6 +//v23 must save C7 + +/******************************************************************************* +* Macro definitions +*******************************************************************************/ + +.macro INITv1x8 + dup z16.s, #0 + dup z17.s, #0 + dup z18.s, #0 + dup z19.s, #0 + dup z20.s, #0 + dup z21.s, #0 + dup z22.s, #0 + dup z23.s, #0 +.endm + +.macro KERNELv1x8_I + ld1w z0.s, p1/z, [pA] + ld1w z1.s, p1/z, [pA, lanes, lsl #2] // next one + add pA, pA, lanes, lsl #3 // pA = pA + lanes * 2 * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + ld1rw z10.s, p0/z, [pB, 8] + ld1rw z11.s, p0/z, [pB, 12] + ld1rw z12.s, p0/z, [pB, 16] + ld1rw z13.s, p0/z, [pB, 20] + ld1rw z14.s, p0/z, [pB, 24] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 + + fmla z16.s, p1/m, z0.s, z8.s + ld1rw z8.s, p0/z, [pB] + fmla z17.s, p1/m, z0.s, z9.s + ld1rw z9.s, p0/z, [pB, 4] + fmla z18.s, p1/m, z0.s, z10.s + ld1rw z10.s, p0/z, [pB, 8] + fmla z19.s, p1/m, z0.s, z11.s + ld1rw z11.s, p0/z, [pB, 12] + fmla z20.s, p1/m, z0.s, z12.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE] + ld1rw z12.s, p0/z, [pB, 16] + fmla z21.s, p1/m, z0.s, z13.s + ld1rw z13.s, p0/z, [pB, 20] + fmla z22.s, p1/m, z0.s, z14.s + ld1rw z14.s, p0/z, [pB, 24] + fmla z23.s, p1/m, z0.s, z15.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE+64] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 +.endm + +.macro KERNELv1x8_M1 + ld1w z1.s, p1/z, [pA] + add pA, pA, lanes, lsl #2 // pA = pA + lanes * 4 + + fmla z16.s, p1/m, z0.s, z8.s + ld1rw z8.s, p0/z, [pB] + fmla z17.s, p1/m, z0.s, z9.s + ld1rw z9.s, p0/z, [pB, 4] + fmla z18.s, p1/m, z0.s, z10.s + ld1rw z10.s, p0/z, [pB, 8] + fmla z19.s, p1/m, z0.s, z11.s + ld1rw z11.s, p0/z, [pB, 12] + fmla z20.s, p1/m, z0.s, z12.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE] + ld1rw z12.s, p0/z, [pB, 16] + fmla z21.s, p1/m, z0.s, z13.s + ld1rw z13.s, p0/z, [pB, 20] + fmla z22.s, p1/m, z0.s, z14.s + ld1rw z14.s, p0/z, [pB, 24] + fmla z23.s, p1/m, z0.s, z15.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE+64] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 +.endm + +.macro KERNELv1x8_M2 + ld1w z0.s, p1/z, [pA] + add pA, pA, lanes, lsl #2 // pA = pA + lanes * 4 + + fmla z16.s, p1/m, z1.s, z8.s + ld1rw z8.s, p0/z, [pB] + fmla z17.s, p1/m, z1.s, z9.s + ld1rw z9.s, p0/z, [pB, 4] + fmla z18.s, p1/m, z1.s, z10.s + ld1rw z10.s, p0/z, [pB, 8] + fmla z19.s, p1/m, z1.s, z11.s + ld1rw z11.s, p0/z, [pB, 12] + fmla z20.s, p1/m, z1.s, z12.s + ld1rw z12.s, p0/z, [pB, 16] + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + fmla z21.s, p1/m, z1.s, z13.s + ld1rw z13.s, p0/z, [pB, 20] + fmla z22.s, p1/m, z1.s, z14.s + ld1rw z14.s, p0/z, [pB, 24] + fmla z23.s, p1/m, z1.s, z15.s + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 +.endm + +.macro KERNELv1x8_E + fmla z16.s, p1/m, z1.s, z8.s + fmla z17.s, p1/m, z1.s, z9.s + fmla z18.s, p1/m, z1.s, z10.s + fmla z19.s, p1/m, z1.s, z11.s + fmla z20.s, p1/m, z1.s, z12.s + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + fmla z21.s, p1/m, z1.s, z13.s + fmla z22.s, p1/m, z1.s, z14.s + fmla z23.s, p1/m, z1.s, z15.s +.endm + +.macro KERNELv1x8_SUB + ld1w z0.s, p1/z, [pA] + add pA, pA, lanes, lsl #2 // pA = pA + lanes * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + ld1rw z10.s, p0/z, [pB, 8] + ld1rw z11.s, p0/z, [pB, 12] + ld1rw z12.s, p0/z, [pB, 16] + ld1rw z13.s, p0/z, [pB, 20] + ld1rw z14.s, p0/z, [pB, 24] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 + + fmla z16.s, p1/m, z0.s, z8.s + fmla z17.s, p1/m, z0.s, z9.s + fmla z18.s, p1/m, z0.s, z10.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE] + fmla z19.s, p1/m, z0.s, z11.s + fmla z20.s, p1/m, z0.s, z12.s + fmla z21.s, p1/m, z0.s, z13.s + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + fmla z22.s, p1/m, z0.s, z14.s + fmla z23.s, p1/m, z0.s, z15.s + +.endm + +.macro SAVEv1x8 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + add pCRow1, pCRow0, LDC + ld1w z24.s, p1/z, [pCRow0] + fmla z24.s, p1/m, z16.s, alphaZ + st1w z24.s, p1, [pCRow0] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + ld1w z25.s, p1/z, [pCRow1] + fmla z25.s, p1/m, z17.s, alphaZ + st1w z25.s, p1, [pCRow1] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + ld1w z26.s, p1/z, [pCRow2] + fmla z26.s, p1/m, z18.s, alphaZ + st1w z26.s, p1, [pCRow2] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + ld1w z27.s, p1/z, [pCRow1] + fmla z27.s, p1/m, z19.s, alphaZ + st1w z27.s, p1, [pCRow1] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + ld1w z28.s, p1/z, [pCRow2] + fmla z28.s, p1/m, z20.s, alphaZ + st1w z28.s, p1, [pCRow2] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + ld1w z29.s, p1/z, [pCRow1] + fmla z29.s, p1/m, z21.s, alphaZ + st1w z29.s, p1, [pCRow1] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + ld1w z30.s, p1/z, [pCRow2] + fmla z30.s, p1/m, z22.s, alphaZ + st1w z30.s, p1, [pCRow2] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + ld1w z31.s, p1/z, [pCRow1] + fmla z31.s, p1/m, z23.s, alphaZ + st1w z31.s, p1, [pCRow1] + + add pCRow0, pCRow0, lanes, lsl #2 // pC = pC + lanes * 4 + +.endm + +/******************************************************************************/ + +.macro INITv1x4 + dup z16.s, #0 + dup z17.s, #0 + dup z18.s, #0 + dup z19.s, #0 +.endm + +.macro KERNELv1x4_SUB + ld1w z0.s, p1/z, [pA] + add pA, pA, lanes, lsl #2 // pA = pA + lanes * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + ld1rw z10.s, p0/z, [pB, 8] + ld1rw z11.s, p0/z, [pB, 12] + + add pB, pB, 16 + + fmla z16.s, p1/m, z0.s, z8.s + fmla z17.s, p1/m, z0.s, z9.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE] + fmla z18.s, p1/m, z0.s, z10.s + fmla z19.s, p1/m, z0.s, z11.s + +.endm + +.macro SAVEv1x4 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + add pCRow1, pCRow0, LDC + ld1w z24.s, p1/z, [pCRow0] + fmla z24.s, p1/m, z16.s, alphaZ + st1w z24.s, p1, [pCRow0] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + ld1w z25.s, p1/z, [pCRow1] + fmla z25.s, p1/m, z17.s, alphaZ + st1w z25.s, p1, [pCRow1] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + ld1w z26.s, p1/z, [pCRow2] + fmla z26.s, p1/m, z18.s, alphaZ + st1w z26.s, p1, [pCRow2] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + ld1w z27.s, p1/z, [pCRow1] + fmla z27.s, p1/m, z19.s, alphaZ + st1w z27.s, p1, [pCRow1] + + add pCRow0, pCRow0, lanes, lsl #2 // pC = pC + lanes * 4 + +.endm + +/******************************************************************************/ + +.macro INITv1x2 + dup z16.s, #0 + dup z17.s, #0 +.endm + +.macro KERNELv1x2_SUB + ld1w z0.s, p1/z, [pA] + add pA, pA, lanes, lsl #2 // pA = pA + lanes * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + + add pB, pB, 8 + + fmla z16.s, p1/m, z0.s, z8.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE] + fmla z17.s, p1/m, z0.s, z9.s + +.endm + +.macro SAVEv1x2 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + add pCRow1, pCRow0, LDC + ld1w z24.s, p1/z, [pCRow0] + fmla z24.s, p1/m, z16.s, alphaZ + st1w z24.s, p1, [pCRow0] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + ld1w z25.s, p1/z, [pCRow1] + fmla z25.s, p1/m, z17.s, alphaZ + st1w z25.s, p1, [pCRow1] + + add pCRow0, pCRow0, lanes, lsl #2 // pC = pC + lanes * 4 + +.endm + +/******************************************************************************/ + +.macro INITv1x1 + dup z16.s, #0 +.endm + +.macro KERNELv1x1_SUB + ld1w z0.s, p1/z, [pA] + add pA, pA, lanes, lsl #2 // pA = pA + lanes * 8 + + ld1rw z8.s, p0/z, [pB] + + add pB, pB, 4 + + fmla z16.s, p1/m, z0.s, z8.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE] + +.endm + +.macro SAVEv1x1 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + ld1w z24.s, p1/z, [pCRow0] + fmla z24.s, p1/m, z16.s, alphaZ + st1w z24.s, p1, [pCRow0] + + + add pCRow0, pCRow0, lanes, lsl #2 // pC = pC + lanes * 4 + +.endm + + +/******************************************************************************* +* End of macro definitions +*******************************************************************************/ + + PROLOGUE + + .align 5 + add sp, sp, #-(11 * 16) + stp d8, d9, [sp, #(0 * 16)] + stp d10, d11, [sp, #(1 * 16)] + stp d12, d13, [sp, #(2 * 16)] + stp d14, d15, [sp, #(3 * 16)] + stp d16, d17, [sp, #(4 * 16)] + stp x18, x19, [sp, #(5 * 16)] + stp x20, x21, [sp, #(6 * 16)] + stp x22, x23, [sp, #(7 * 16)] + stp x24, x25, [sp, #(8 * 16)] + stp x26, x27, [sp, #(9 * 16)] + str x28, [sp, #(10 * 16)] + + prfm PLDL1KEEP, [origPB] + prfm PLDL1KEEP, [origPA] + + fmov alpha, s0 + dup alphaZ, alpha + + lsl LDC, LDC, #2 // ldc = ldc * 4 + ptrue p0.s // create true predicate + + mov pB, origPB +// Loop over N + mov counterJ, origN + asr counterJ, counterJ, #3 // J = J / 8 + cmp counterJ, #0 + ble .Ldgemm_kernel_L4_BEGIN + +/******************************************************************************/ +/* Repeat this as long as there are 8 left in N */ + + .align 5 +.Ldgemm_kernel_L8_BEGIN: + mov pCRow0, pC + + add pC, pC, LDC, lsl #3 // add 8 x LDC + + mov pA, origPA // pA = start of A array + +.Ldgemm_kernel_L8_Mv1_BEGIN: + +/* Loop over M is done in an SVE fashion. This has the benefit of the last M%SVE_LEN iterations being done in a single sweep */ + mov counterI, #0 + whilelt p1.s, counterI, origM + cntp lanes, p0, p1.s // lanes contain number of active SVE lanes in M dimension + + .align 5 +.Ldgemm_kernel_L8_Mv1_20: + + mov pB, origPB + INITv1x8 // fill with zeros + + asr counterL , origK, #3 // L = K / 8 + cmp counterL , #2 // is there at least 4 to do? + blt .Ldgemm_kernel_L8_Mv1_32 + + KERNELv1x8_I + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + + subs counterL, counterL, #2 // subtract 2 + ble .Ldgemm_kernel_L8_Mv1_22a + + .align 5 +.Ldgemm_kernel_L8_Mv1_22: + + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + + subs counterL, counterL, #1 + bgt .Ldgemm_kernel_L8_Mv1_22 + + .align 5 +.Ldgemm_kernel_L8_Mv1_22a: + + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_E + + b .Ldgemm_kernel_L8_Mv1_44 + + .align 5 +.Ldgemm_kernel_L8_Mv1_32: + + tst counterL, #1 + ble .Ldgemm_kernel_L8_Mv1_40 + + KERNELv1x8_I + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_E + + + b .Ldgemm_kernel_L8_Mv1_44 + +.Ldgemm_kernel_L8_Mv1_40: + + INITv1x8 + +.Ldgemm_kernel_L8_Mv1_44: + + ands counterL , origK, #7 + ble .Ldgemm_kernel_L8_Mv1_100 + + .align 5 +.Ldgemm_kernel_L8_Mv1_46: + + KERNELv1x8_SUB + + subs counterL, counterL, #1 + bne .Ldgemm_kernel_L8_Mv1_46 + +.Ldgemm_kernel_L8_Mv1_100: + prfm PLDL1KEEP, [pA] + prfm PLDL1KEEP, [pA, #64] + prfm PLDL1KEEP, [origPB] + + SAVEv1x8 + +.Ldgemm_kernel_L8_Mv1_END: + + incw counterI + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s // lanes contain number of active SVE lanes in M dimension + b.any .Ldgemm_kernel_L8_Mv1_20 + +.Ldgemm_kernel_L8_END: + + lsl temp, origK, #5 + add origPB, origPB, temp // B = B + K * 8 * 4 + + subs counterJ, counterJ , #1 // j-- + bgt .Ldgemm_kernel_L8_BEGIN + +/******************************************************************************/ +/* Repeat the same thing if 4 left in N */ + + .align 5 +.Ldgemm_kernel_L4_BEGIN: + + mov counterJ , origN + tst counterJ , #4 + ble .Ldgemm_kernel_L2_BEGIN + + + mov pCRow0, pC + + add pC, pC, LDC, lsl #2 // add 4 x LDC + + mov pA, origPA // pA = start of A array + +.Ldgemm_kernel_L4_Mv1_BEGIN: + + mov counterI, #0 + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + + .align 5 +.Ldgemm_kernel_L4_Mv1_20: + + mov pB, origPB + INITv1x4 // fill with zeros + + asr counterL , origK, #3 // L = K / 8 + cmp counterL , #0 // is there at least 4 to do? + ble .Ldgemm_kernel_L4_Mv1_44 + + .align 5 +.Ldgemm_kernel_L4_Mv1_22: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x4_SUB + KERNELv1x4_SUB + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x4_SUB + KERNELv1x4_SUB + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x4_SUB + KERNELv1x4_SUB + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x4_SUB + KERNELv1x4_SUB + + subs counterL, counterL, #1 + bgt .Ldgemm_kernel_L4_Mv1_22 + +.Ldgemm_kernel_L4_Mv1_44: + + ands counterL , origK, #7 + ble .Ldgemm_kernel_L4_Mv1_100 + + .align 5 +.Ldgemm_kernel_L4_Mv1_46: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x4_SUB + + subs counterL, counterL, #1 + bne .Ldgemm_kernel_L4_Mv1_46 + +.Ldgemm_kernel_L4_Mv1_100: + prfm PLDL1KEEP, [pA] + prfm PLDL1KEEP, [pA, #64] + prfm PLDL1KEEP, [origPB] + + SAVEv1x4 + +.Ldgemm_kernel_L4_Mv1_END: + + incw counterI + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + b.any .Ldgemm_kernel_L4_Mv1_20 + + +.Ldgemm_kernel_L4_END: + lsl temp, origK, #4 + add origPB, origPB, temp // B = B + K * 4 * 4 + +/******************************************************************************/ +/* Repeat the same thing if 2 left in N */ + + .align 5 +.Ldgemm_kernel_L2_BEGIN: + + mov counterJ , origN + tst counterJ , #2 + ble .Ldgemm_kernel_L1_BEGIN + + mov pCRow0, pC + + add pC, pC, LDC, lsl #1 // add 2 x LDC + + mov pA, origPA // pA = start of A array + +.Ldgemm_kernel_L2_Mv1_BEGIN: + + mov counterI, #0 + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + + .align 5 +.Ldgemm_kernel_L2_Mv1_20: + + mov pB, origPB + INITv1x2 // fill with zeros + + asr counterL , origK, #3 // L = K / 8 + cmp counterL , #0 // is there at least 4 to do? + ble .Ldgemm_kernel_L2_Mv1_44 + + .align 5 +.Ldgemm_kernel_L2_Mv1_22: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + + subs counterL, counterL, #1 + bgt .Ldgemm_kernel_L2_Mv1_22 + +.Ldgemm_kernel_L2_Mv1_44: + + ands counterL , origK, #7 + ble .Ldgemm_kernel_L2_Mv1_100 + + .align 5 +.Ldgemm_kernel_L2_Mv1_46: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x2_SUB + + subs counterL, counterL, #1 + bne .Ldgemm_kernel_L2_Mv1_46 + +.Ldgemm_kernel_L2_Mv1_100: + prfm PLDL1KEEP, [pA] + prfm PLDL1KEEP, [pA, #64] + prfm PLDL1KEEP, [origPB] + + SAVEv1x2 + +.Ldgemm_kernel_L2_Mv1_END: + + incw counterI + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + b.any .Ldgemm_kernel_L2_Mv1_20 + + +.Ldgemm_kernel_L2_END: + add origPB, origPB, origK, lsl #3 // B = B + K * 2 * 4 + +/******************************************************************************/ +/* Repeat the same thing if 1 left in N */ + + .align 5 +.Ldgemm_kernel_L1_BEGIN: + + mov counterJ , origN + tst counterJ , #1 + ble .Ldgemm_kernel_L999 // done + + mov pCRow0, pC + + add pC, pC, LDC // add 1 x LDC + + mov pA, origPA // pA = start of A array + +.Ldgemm_kernel_L1_Mv1_BEGIN: + + mov counterI, #0 + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + + .align 5 +.Ldgemm_kernel_L1_Mv1_20: + + mov pB, origPB + INITv1x1 // fill with zeros + + asr counterL , origK, #3 // L = K / 8 + cmp counterL , #0 // is there at least 8 to do? + ble .Ldgemm_kernel_L1_Mv1_44 + + .align 5 +.Ldgemm_kernel_L1_Mv1_22: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + + subs counterL, counterL, #1 + bgt .Ldgemm_kernel_L1_Mv1_22 + +.Ldgemm_kernel_L1_Mv1_44: + + ands counterL , origK, #7 + ble .Ldgemm_kernel_L1_Mv1_100 + + .align 5 +.Ldgemm_kernel_L1_Mv1_46: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x1_SUB + + subs counterL, counterL, #1 + bgt .Ldgemm_kernel_L1_Mv1_46 + +.Ldgemm_kernel_L1_Mv1_100: + prfm PLDL1KEEP, [pA] + prfm PLDL1KEEP, [pA, #64] + prfm PLDL1KEEP, [origPB] + + SAVEv1x1 + +.Ldgemm_kernel_L1_Mv1_END: + + incw counterI + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + b.any .Ldgemm_kernel_L1_Mv1_20 + + +.Ldgemm_kernel_L1_END: + +/******************************************************************************/ + +.Ldgemm_kernel_L999: + mov x0, #0 // set return value + ldp d8, d9, [sp, #(0 * 16)] + ldp d10, d11, [sp, #(1 * 16)] + ldp d12, d13, [sp, #(2 * 16)] + ldp d14, d15, [sp, #(3 * 16)] + ldp d16, d17, [sp, #(4 * 16)] + ldp x18, x19, [sp, #(5 * 16)] + ldp x20, x21, [sp, #(6 * 16)] + ldp x22, x23, [sp, #(7 * 16)] + ldp x24, x25, [sp, #(8 * 16)] + ldp x26, x27, [sp, #(9 * 16)] + ldr x28, [sp, #(10 * 16)] + add sp, sp, #(11*16) + ret + + EPILOGUE + diff --git a/kernel/arm64/sgemm_ncopy_sve_v1.c b/kernel/arm64/sgemm_ncopy_sve_v1.c new file mode 100644 index 000000000..1bc186335 --- /dev/null +++ b/kernel/arm64/sgemm_ncopy_sve_v1.c @@ -0,0 +1,78 @@ +/*********************************************************************/ +/* Copyright 2009, 2010 The University of Texas at Austin. */ +/* All rights reserved. */ +/* */ +/* Redistribution and use in source and binary forms, with or */ +/* without modification, are permitted provided that the following */ +/* conditions are met: */ +/* */ +/* 1. Redistributions of source code must retain the above */ +/* copyright notice, this list of conditions and the following */ +/* disclaimer. */ +/* */ +/* 2. Redistributions in binary form must reproduce the above */ +/* copyright notice, this list of conditions and the following */ +/* disclaimer in the documentation and/or other materials */ +/* provided with the distribution. */ +/* */ +/* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */ +/* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */ +/* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */ +/* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */ +/* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */ +/* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */ +/* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */ +/* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */ +/* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */ +/* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */ +/* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */ +/* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */ +/* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */ +/* POSSIBILITY OF SUCH DAMAGE. */ +/* */ +/* The views and conclusions contained in the software and */ +/* documentation are those of the authors and should not be */ +/* interpreted as representing official policies, either expressed */ +/* or implied, of The University of Texas at Austin. */ +/*********************************************************************/ + +#include +#include "common.h" +#include + +// TODO: write in assembly with proper unrolling of inner loop +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){ + + BLASLONG j; + IFLOAT *aoffset, *aoffset1, *boffset; + + svint32_t lda_vec = svindex_s32(0LL, lda); + uint32_t sve_size = svcntw(); + + aoffset = a; + boffset = b; + + j = 0; + svbool_t pg = svwhilelt_b32(j, n); + uint32_t active = svcntp_b32(svptrue_b32(), pg); + do { + + aoffset1 = aoffset; + + uint32_t i_cnt = m; + while (i_cnt--) { + svfloat32_t a_vec = svld1_gather_index(pg, (float *) aoffset1, lda_vec); + svst1_f32(pg, (float *) boffset, a_vec); + aoffset1++; + boffset += active; + } + aoffset += sve_size * lda; + + j += svcntw(); + pg = svwhilelt_b32(j, n); + active = svcntp_b32(svptrue_b32(), pg); + + } while (svptest_any(svptrue_b32(), pg)); + + return 0; +} diff --git a/kernel/arm64/sgemm_tcopy_sve_v1.c b/kernel/arm64/sgemm_tcopy_sve_v1.c new file mode 100644 index 000000000..9f8cf502a --- /dev/null +++ b/kernel/arm64/sgemm_tcopy_sve_v1.c @@ -0,0 +1,77 @@ +/*********************************************************************/ +/* Copyright 2009, 2010 The University of Texas at Austin. */ +/* All rights reserved. */ +/* */ +/* Redistribution and use in source and binary forms, with or */ +/* without modification, are permitted provided that the following */ +/* conditions are met: */ +/* */ +/* 1. Redistributions of source code must retain the above */ +/* copyright notice, this list of conditions and the following */ +/* disclaimer. */ +/* */ +/* 2. Redistributions in binary form must reproduce the above */ +/* copyright notice, this list of conditions and the following */ +/* disclaimer in the documentation and/or other materials */ +/* provided with the distribution. */ +/* */ +/* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */ +/* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */ +/* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */ +/* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */ +/* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */ +/* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */ +/* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */ +/* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */ +/* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */ +/* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */ +/* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */ +/* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */ +/* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */ +/* POSSIBILITY OF SUCH DAMAGE. */ +/* */ +/* The views and conclusions contained in the software and */ +/* documentation are those of the authors and should not be */ +/* interpreted as representing official policies, either expressed */ +/* or implied, of The University of Texas at Austin. */ +/*********************************************************************/ + +#include +#include "common.h" +#include + +// TODO: write in assembly with proper unrolling of inner loop +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){ + + BLASLONG j; + IFLOAT *aoffset, *aoffset1, *boffset; + + uint32_t sve_size = svcntw(); + + aoffset = a; + boffset = b; + + j = 0; + svbool_t pg = svwhilelt_b32(j, n); + uint32_t active = svcntp_b32(svptrue_b32(), pg); + do { + + aoffset1 = aoffset; + + uint32_t i_cnt = m; + while (i_cnt--) { + svfloat32_t a_vec = svld1(pg, (float *) aoffset1); + svst1_f32(pg, (float *) boffset, a_vec); + aoffset1 += lda; + boffset += active; + } + aoffset += sve_size; + + j += svcntw(); + pg = svwhilelt_b32(j, n); + active = svcntp_b32(svptrue_b32(), pg); + + } while (svptest_any(svptrue_b32(), pg)); + + return 0; +} diff --git a/kernel/arm64/symm_lcopy_sve.c b/kernel/arm64/symm_lcopy_sve.c index 94a68ad7c..6ba4afc8b 100644 --- a/kernel/arm64/symm_lcopy_sve.c +++ b/kernel/arm64/symm_lcopy_sve.c @@ -44,6 +44,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON BLASLONG i, offset; +#if defined(DOUBLE) uint64_t sve_size = svcntd(); svint64_t posY_vec = svdup_s64(posY); svint64_t posX_vec = svdup_s64(posX); @@ -89,5 +90,54 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON active = svcntp_b64(svptrue_b64(), pg); } while (svptest_any(svptrue_b64(), pg)); +#else + uint32_t sve_size = svcntw(); + svint32_t posY_vec = svdup_s32(posY); + svint32_t posX_vec = svdup_s32(posX); + svint32_t lda_vec = svdup_s32(lda); + svint32_t one_vec = svdup_s32(1); + + int32_t N = n; + int32_t j = 0; + svbool_t pg = svwhilelt_b32(j, N); + int32_t active = svcntp_b32(svptrue_b32(), pg); + svint32_t index_neg = svindex_s32(0, -1); + svint32_t index = svindex_s32(0, 1); + do { + offset = posX - posY; + svint32_t vec_off = svdup_s32(offset); + svbool_t cmp = svcmpgt(pg, vec_off, index_neg); + + svint32_t temp = svadd_z(pg, posX_vec, index); + svint32_t temp1 = svmla_z(pg, temp, posY_vec, lda_vec); + svint32_t temp2 = svmla_z(pg, posY_vec, temp, lda); + svint32_t gat_ind = svsel(cmp, temp1, temp2); + + i = m; + while (i>0) { + svfloat32_t data_vec = svld1_gather_index(pg, a, gat_ind); + + gat_ind = svadd_m(cmp, gat_ind, lda_vec); + gat_ind = svadd_m(svnot_z(pg, cmp) , gat_ind, one_vec); + + svst1(pg, b, data_vec); + + b += active; + offset --; + vec_off = svsub_z(pg, vec_off, one_vec); + cmp = svcmpgt(pg, vec_off, index_neg); + + i--; + } + + posX += sve_size; + posX_vec = svdup_s32(posX); + j += sve_size; + pg = svwhilelt_b32(j, N); + active = svcntp_b32(svptrue_b32(), pg); + } while (svptest_any(svptrue_b32(), pg)); + +#endif + return 0; } diff --git a/kernel/arm64/symm_ucopy_sve.c b/kernel/arm64/symm_ucopy_sve.c index 3cf18e0fd..32da5bd16 100644 --- a/kernel/arm64/symm_ucopy_sve.c +++ b/kernel/arm64/symm_ucopy_sve.c @@ -44,6 +44,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON BLASLONG i, offset; +#if defined(DOUBLE) uint64_t sve_size = svcntd(); svint64_t posY_vec = svdup_s64(posY); svint64_t posX_vec = svdup_s64(posX); @@ -89,5 +90,54 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON active = svcntp_b64(svptrue_b64(), pg); } while (svptest_any(svptrue_b64(), pg)); +#else + uint32_t sve_size = svcntw(); + svint32_t posY_vec = svdup_s32(posY); + svint32_t posX_vec = svdup_s32(posX); + svint32_t lda_vec = svdup_s32(lda); + svint32_t one_vec = svdup_s32(1); + + int32_t N = n; + int32_t j = 0; + svbool_t pg = svwhilelt_b32(j, N); + int32_t active = svcntp_b32(svptrue_b32(), pg); + svint32_t index_neg = svindex_s32(0, -1); + svint32_t index = svindex_s32(0, 1); + do { + offset = posX - posY; + svint32_t vec_off = svdup_s32(offset); + svbool_t cmp = svcmpgt(pg, vec_off, index_neg); + + svint32_t temp = svadd_z(pg, posX_vec, index); + svint32_t temp1 = svmla_z(pg, temp, posY_vec, lda_vec); + svint32_t temp2 = svmla_z(pg, posY_vec, temp, lda); + svint32_t gat_ind = svsel(cmp, temp2, temp1); + + i = m; + while (i>0) { + svfloat32_t data_vec = svld1_gather_index(pg, a, gat_ind); + + gat_ind = svadd_m(cmp, gat_ind, one_vec); + gat_ind = svadd_m(svnot_z(pg, cmp) , gat_ind, lda_vec); + + svst1(pg, b, data_vec); + + b += active; + offset --; + vec_off = svsub_z(pg, vec_off, one_vec); + cmp = svcmpgt(pg, vec_off, index_neg); + + i--; + } + + posX += sve_size; + posX_vec = svdup_s32(posX); + j += sve_size; + pg = svwhilelt_b32(j, N); + active = svcntp_b32(svptrue_b32(), pg); + } while (svptest_any(svptrue_b32(), pg)); + +#endif + return 0; } diff --git a/param.h b/param.h index c1dff1367..e9419bd9d 100644 --- a/param.h +++ b/param.h @@ -3296,8 +3296,8 @@ is a big desktop or server with abundant cache rather than a phone or embedded d #elif defined(ARMV8SVE) || defined(A64FX) -#define SGEMM_DEFAULT_UNROLL_M 16 -#define SGEMM_DEFAULT_UNROLL_N 4 +#define SGEMM_DEFAULT_UNROLL_M 4 +#define SGEMM_DEFAULT_UNROLL_N 8 /* When all BLAS3 routines are implemeted with SVE, DGEMM_DEFAULT_UNROLL_M should be "sve_vl". Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy routines in both directions seperated. */ From 0de36f7b5ceea1c410ed98e62fd4748e9cc9324d Mon Sep 17 00:00:00 2001 From: Bine Brank Date: Mon, 29 Nov 2021 21:25:05 +0100 Subject: [PATCH 2/6] trmm sve copy fucntions for single precision --- kernel/arm64/trmm_lncopy_sve_v1.c | 21 ++++++++++++++++++--- kernel/arm64/trmm_ltcopy_sve_v1.c | 15 +++++++++++++++ kernel/arm64/trmm_uncopy_sve_v1.c | 21 ++++++++++++++++++--- kernel/arm64/trmm_utcopy_sve_v1.c | 15 +++++++++++++++ 4 files changed, 66 insertions(+), 6 deletions(-) diff --git a/kernel/arm64/trmm_lncopy_sve_v1.c b/kernel/arm64/trmm_lncopy_sve_v1.c index fc1b61325..918e945ac 100644 --- a/kernel/arm64/trmm_lncopy_sve_v1.c +++ b/kernel/arm64/trmm_lncopy_sve_v1.c @@ -48,12 +48,17 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON BLASLONG i, js; BLASLONG X; - svint64_t index = svindex_s64(0LL, lda); - - FLOAT *ao; js = 0; + FLOAT *ao; +#ifdef DOUBLE + svint64_t index = svindex_s64(0LL, lda); svbool_t pn = svwhilelt_b64(js, n); int n_active = svcntp_b64(svptrue_b64(), pn); +#else + svint32_t index = svindex_s32(0, lda); + svbool_t pn = svwhilelt_b32(js, n); + int n_active = svcntp_b32(svptrue_b32(), pn); +#endif do { X = posX; @@ -68,7 +73,11 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON do { if (X > posY) { +#ifdef DOUBLE svfloat64_t aj_vec = svld1_gather_index(pn, ao, index); +#else + svfloat32_t aj_vec = svld1_gather_index(pn, ao, index); +#endif svst1(pn, b, aj_vec); ao ++; b += n_active; @@ -113,9 +122,15 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON posY += n_active; js += n_active; +#ifdef DOUBLE pn = svwhilelt_b64(js, n); n_active = svcntp_b64(svptrue_b64(), pn); } while (svptest_any(svptrue_b64(), pn)); +#else + pn = svwhilelt_b32(js, n); + n_active = svcntp_b32(svptrue_b32(), pn); + } while (svptest_any(svptrue_b32(), pn)); +#endif return 0; } diff --git a/kernel/arm64/trmm_ltcopy_sve_v1.c b/kernel/arm64/trmm_ltcopy_sve_v1.c index 14c6762d2..b76cc56de 100644 --- a/kernel/arm64/trmm_ltcopy_sve_v1.c +++ b/kernel/arm64/trmm_ltcopy_sve_v1.c @@ -50,8 +50,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON FLOAT *ao; js = 0; +#ifdef DOUBLE svbool_t pn = svwhilelt_b64(js, n); int n_active = svcntp_b64(svptrue_b64(), pn); +#else + svbool_t pn = svwhilelt_b32(js, n); + int n_active = svcntp_b32(svptrue_b32(), pn); +#endif do { X = posX; @@ -72,7 +77,11 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON i ++; } else if (X < posY) { +#ifdef DOUBLE svfloat64_t aj_vec = svld1(pn, ao); +#else + svfloat32_t aj_vec = svld1(pn, ao); +#endif svst1(pn, b, aj_vec); ao += lda; b += n_active; @@ -112,9 +121,15 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON posY += n_active; js += n_active; +#ifdef DOUBLE pn = svwhilelt_b64(js, n); n_active = svcntp_b64(svptrue_b64(), pn); } while (svptest_any(svptrue_b64(), pn)); +#else + pn = svwhilelt_b32(js, n); + n_active = svcntp_b32(svptrue_b32(), pn); + } while (svptest_any(svptrue_b32(), pn)); +#endif return 0; diff --git a/kernel/arm64/trmm_uncopy_sve_v1.c b/kernel/arm64/trmm_uncopy_sve_v1.c index b8344d474..75fa163ae 100644 --- a/kernel/arm64/trmm_uncopy_sve_v1.c +++ b/kernel/arm64/trmm_uncopy_sve_v1.c @@ -48,12 +48,17 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON BLASLONG i, js; BLASLONG X; - svint64_t index = svindex_s64(0LL, lda); - - FLOAT *ao; js = 0; + FLOAT *ao; +#ifdef DOUBLE + svint64_t index = svindex_s64(0LL, lda); svbool_t pn = svwhilelt_b64(js, n); int n_active = svcntp_b64(svptrue_b64(), pn); +#else + svint32_t index = svindex_s32(0, lda); + svbool_t pn = svwhilelt_b32(js, n); + int n_active = svcntp_b32(svptrue_b32(), pn); +#endif do { X = posX; @@ -68,7 +73,11 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON do { if (X < posY) { +#ifdef DOUBLE svfloat64_t aj_vec = svld1_gather_index(pn, ao, index); +#else + svfloat32_t aj_vec = svld1_gather_index(pn, ao, index); +#endif svst1(pn, b, aj_vec); ao ++; b += n_active; @@ -113,9 +122,15 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON posY += n_active; js += n_active; +#ifdef DOUBLE pn = svwhilelt_b64(js, n); n_active = svcntp_b64(svptrue_b64(), pn); } while (svptest_any(svptrue_b64(), pn)); +#else + pn = svwhilelt_b32(js, n); + n_active = svcntp_b32(svptrue_b32(), pn); + } while (svptest_any(svptrue_b32(), pn)); +#endif return 0; } diff --git a/kernel/arm64/trmm_utcopy_sve_v1.c b/kernel/arm64/trmm_utcopy_sve_v1.c index 9be1c0abb..36a03242a 100644 --- a/kernel/arm64/trmm_utcopy_sve_v1.c +++ b/kernel/arm64/trmm_utcopy_sve_v1.c @@ -50,8 +50,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON FLOAT *ao; js = 0; +#ifdef DOUBLE svbool_t pn = svwhilelt_b64(js, n); int n_active = svcntp_b64(svptrue_b64(), pn); +#else + svbool_t pn = svwhilelt_b32(js, n); + int n_active = svcntp_b32(svptrue_b32(), pn); +#endif do { X = posX; @@ -72,7 +77,11 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON i ++; } else if (X > posY) { +#ifdef DOUBLE svfloat64_t aj_vec = svld1(pn, ao); +#else + svfloat32_t aj_vec = svld1(pn, ao); +#endif svst1(pn, b, aj_vec); ao += lda; b += n_active; @@ -111,9 +120,15 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLON posY += n_active; js += n_active; +#ifdef DOUBLE pn = svwhilelt_b64(js, n); n_active = svcntp_b64(svptrue_b64(), pn); } while (svptest_any(svptrue_b64(), pn)); +#else + pn = svwhilelt_b32(js, n); + n_active = svcntp_b32(svptrue_b32(), pn); + } while (svptest_any(svptrue_b32(), pn)); +#endif return 0; } From abe1ce3434c3cd6df0c8d00650d1722a31bff784 Mon Sep 17 00:00:00 2001 From: Bine Brank Date: Sun, 5 Dec 2021 14:03:08 +0100 Subject: [PATCH 3/6] strmm sve v1x8 kernel --- kernel/arm64/strmm_kernel_sve_v1x8.S | 1008 ++++++++++++++++++++++++++ 1 file changed, 1008 insertions(+) create mode 100644 kernel/arm64/strmm_kernel_sve_v1x8.S diff --git a/kernel/arm64/strmm_kernel_sve_v1x8.S b/kernel/arm64/strmm_kernel_sve_v1x8.S new file mode 100644 index 000000000..3c45e3e29 --- /dev/null +++ b/kernel/arm64/strmm_kernel_sve_v1x8.S @@ -0,0 +1,1008 @@ +/******************************************************************************* +Copyright (c) 2015, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +#define ASSEMBLER +#include "common.h" + +/* X0 X1 X2 s0 X3 x4 x5 x6 */ +/*int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha0,FLOAT* ba,FLOAT* bb,FLOAT* C,BLASLONG ldc )*/ + +#define origM x0 +#define origN x1 +#define origK x2 +#define origPA x3 +#define origPB x4 +#define pC x5 +#define LDC x6 +#define offset x7 +#define counterL x8 +#define counterI x9 +#define counterJ x10 +#define pB x11 +#define pCRow0 x12 +#define pCRow1 x13 +#define pCRow2 x14 + +#define lanes x15 +#define pA x16 +#define alpha w17 +//#define temp x18 +#define tempOffset x19 +#define tempK x20 +#define temp x21 + +#define alpha0 s10 +#define alphaZ z2.s + +#define A_PRE_SIZE 1536 +#define B_PRE_SIZE 512 +#define C_PRE_SIZE 128 + +// 00 origM +// 01 origN +// 02 origK +// 03 origPA +// 04 origPB +// 05 pC +// 06 origLDC -> LDC +// 07 temp +// 08 counterL +// 09 counterI +// 10 counterJ +// 11 pB +// 12 pCRow0 +// 13 pCRow1 +// 14 pCRow2 +// 15 lanes +// 16 pA +// 17 +// 18 must save +// 19 must save +// 20 must save +// 21 must save +// 22 must save +// 23 must save +// 24 must save +// 25 must save +// 26 must save +// 27 must save +// 28 must save +// 29 frame +// 30 link +// 31 sp + +//v00 ALPHA -> pA0_0 +//v01 pA0_1 +//v02 ALPHA0 +//v03 +//v04 +//v05 +//v06 +//v07 +//v08 must save pB0_0 +//v09 must save pB0_1 +//v10 must save pB0_2 +//v11 must save pB0_3 +//v12 must save pB0_4 +//v13 must save pB0_5 +//v14 must save pB0_6 +//v15 must save pB0_7 +//v16 must save C0 +//v17 must save C1 +//v18 must save C2 +//v19 must save C3 +//v20 must save C4 +//v21 must save C5 +//v22 must save C6 +//v23 must save C7 + +/******************************************************************************* +* Macro definitions +*******************************************************************************/ + +.macro INITv1x8 + dup z16.s, #0 + dup z17.s, #0 + dup z18.s, #0 + dup z19.s, #0 + dup z20.s, #0 + dup z21.s, #0 + dup z22.s, #0 + dup z23.s, #0 +.endm + +.macro KERNELv1x8_I + ld1w z0.s, p1/z, [pA] + ld1w z1.s, p1/z, [pA, lanes, lsl #2] // next one + add pA, pA, lanes, lsl #3 // pA = pA + lanes * 2 * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + ld1rw z10.s, p0/z, [pB, 8] + ld1rw z11.s, p0/z, [pB, 12] + ld1rw z12.s, p0/z, [pB, 16] + ld1rw z13.s, p0/z, [pB, 20] + ld1rw z14.s, p0/z, [pB, 24] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 + + fmla z16.s, p1/m, z0.s, z8.s + ld1rw z8.s, p0/z, [pB] + fmla z17.s, p1/m, z0.s, z9.s + ld1rw z9.s, p0/z, [pB, 4] + fmla z18.s, p1/m, z0.s, z10.s + ld1rw z10.s, p0/z, [pB, 8] + fmla z19.s, p1/m, z0.s, z11.s + ld1rw z11.s, p0/z, [pB, 12] + fmla z20.s, p1/m, z0.s, z12.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE] + ld1rw z12.s, p0/z, [pB, 16] + fmla z21.s, p1/m, z0.s, z13.s + ld1rw z13.s, p0/z, [pB, 20] + fmla z22.s, p1/m, z0.s, z14.s + ld1rw z14.s, p0/z, [pB, 24] + fmla z23.s, p1/m, z0.s, z15.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE+64] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 +.endm + +.macro KERNELv1x8_M1 + ld1w z1.s, p1/z, [pA] + add pA, pA, lanes, lsl #2 // pA = pA + lanes * 4 + + fmla z16.s, p1/m, z0.s, z8.s + ld1rw z8.s, p0/z, [pB] + fmla z17.s, p1/m, z0.s, z9.s + ld1rw z9.s, p0/z, [pB, 4] + fmla z18.s, p1/m, z0.s, z10.s + ld1rw z10.s, p0/z, [pB, 8] + fmla z19.s, p1/m, z0.s, z11.s + ld1rw z11.s, p0/z, [pB, 12] + fmla z20.s, p1/m, z0.s, z12.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE] + ld1rw z12.s, p0/z, [pB, 16] + fmla z21.s, p1/m, z0.s, z13.s + ld1rw z13.s, p0/z, [pB, 20] + fmla z22.s, p1/m, z0.s, z14.s + ld1rw z14.s, p0/z, [pB, 24] + fmla z23.s, p1/m, z0.s, z15.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE+64] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 +.endm + +.macro KERNELv1x8_M2 + ld1w z0.s, p1/z, [pA] + add pA, pA, lanes, lsl #2 // pA = pA + lanes * 4 + + fmla z16.s, p1/m, z1.s, z8.s + ld1rw z8.s, p0/z, [pB] + fmla z17.s, p1/m, z1.s, z9.s + ld1rw z9.s, p0/z, [pB, 4] + fmla z18.s, p1/m, z1.s, z10.s + ld1rw z10.s, p0/z, [pB, 8] + fmla z19.s, p1/m, z1.s, z11.s + ld1rw z11.s, p0/z, [pB, 12] + fmla z20.s, p1/m, z1.s, z12.s + ld1rw z12.s, p0/z, [pB, 16] + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + fmla z21.s, p1/m, z1.s, z13.s + ld1rw z13.s, p0/z, [pB, 20] + fmla z22.s, p1/m, z1.s, z14.s + ld1rw z14.s, p0/z, [pB, 24] + fmla z23.s, p1/m, z1.s, z15.s + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 +.endm + +.macro KERNELv1x8_E + fmla z16.s, p1/m, z1.s, z8.s + fmla z17.s, p1/m, z1.s, z9.s + fmla z18.s, p1/m, z1.s, z10.s + fmla z19.s, p1/m, z1.s, z11.s + fmla z20.s, p1/m, z1.s, z12.s + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + fmla z21.s, p1/m, z1.s, z13.s + fmla z22.s, p1/m, z1.s, z14.s + fmla z23.s, p1/m, z1.s, z15.s +.endm + +.macro KERNELv1x8_SUB + ld1w z0.s, p1/z, [pA] + add pA, pA, lanes, lsl #2 // pA = pA + lanes * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + ld1rw z10.s, p0/z, [pB, 8] + ld1rw z11.s, p0/z, [pB, 12] + ld1rw z12.s, p0/z, [pB, 16] + ld1rw z13.s, p0/z, [pB, 20] + ld1rw z14.s, p0/z, [pB, 24] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 + + fmla z16.s, p1/m, z0.s, z8.s + fmla z17.s, p1/m, z0.s, z9.s + fmla z18.s, p1/m, z0.s, z10.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE] + fmla z19.s, p1/m, z0.s, z11.s + fmla z20.s, p1/m, z0.s, z12.s + fmla z21.s, p1/m, z0.s, z13.s + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + fmla z22.s, p1/m, z0.s, z14.s + fmla z23.s, p1/m, z0.s, z15.s + +.endm + +.macro SAVEv1x8 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + add pCRow1, pCRow0, LDC + fmul z16.s, p1/m, z16.s, alphaZ + st1w z16.s, p1, [pCRow0] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + fmul z17.s, p1/m, z17.s, alphaZ + st1w z17.s, p1, [pCRow1] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + fmul z18.s, p1/m, z18.s, alphaZ + st1w z18.s, p1, [pCRow2] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + fmul z19.s, p1/m, z19.s, alphaZ + st1w z19.s, p1, [pCRow1] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + fmul z20.s, p1/m, z20.s, alphaZ + st1w z20.s, p1, [pCRow2] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + fmul z21.s, p1/m, z21.s, alphaZ + st1w z21.s, p1, [pCRow1] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + fmul z22.s, p1/m, z22.s, alphaZ + st1w z22.s, p1, [pCRow2] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + fmul z23.s, p1/m, z23.s, alphaZ + st1w z23.s, p1, [pCRow1] + + add pCRow0, pCRow0, lanes, lsl #2 // pC = pC + lanes * 4 + +.endm + +/******************************************************************************/ + +.macro INITv1x4 + dup z16.s, #0 + dup z17.s, #0 + dup z18.s, #0 + dup z19.s, #0 +.endm + +.macro KERNELv1x4_SUB + ld1w z0.s, p1/z, [pA] + add pA, pA, lanes, lsl #2 // pA = pA + lanes * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + ld1rw z10.s, p0/z, [pB, 8] + ld1rw z11.s, p0/z, [pB, 12] + + add pB, pB, 16 + + fmla z16.s, p1/m, z0.s, z8.s + fmla z17.s, p1/m, z0.s, z9.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE] + fmla z18.s, p1/m, z0.s, z10.s + fmla z19.s, p1/m, z0.s, z11.s + +.endm + +.macro SAVEv1x4 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + add pCRow1, pCRow0, LDC + fmul z16.s, p1/m, z16.s, alphaZ + st1w z16.s, p1, [pCRow0] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + fmul z17.s, p1/m, z17.s, alphaZ + st1w z17.s, p1, [pCRow1] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + fmul z18.s, p1/m, z18.s, alphaZ + st1w z18.s, p1, [pCRow2] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + fmul z19.s, p1/m, z19.s, alphaZ + st1w z19.s, p1, [pCRow1] + + add pCRow0, pCRow0, lanes, lsl #2 // pC = pC + lanes * 4 + +.endm + +/******************************************************************************/ + +.macro INITv1x2 + dup z16.s, #0 + dup z17.s, #0 +.endm + +.macro KERNELv1x2_SUB + ld1w z0.s, p1/z, [pA] + add pA, pA, lanes, lsl #2 // pA = pA + lanes * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + + add pB, pB, 8 + + fmla z16.s, p1/m, z0.s, z8.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE] + fmla z17.s, p1/m, z0.s, z9.s + +.endm + +.macro SAVEv1x2 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + add pCRow1, pCRow0, LDC + fmul z16.s, p1/m, z16.s, alphaZ + st1w z16.s, p1, [pCRow0] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + fmul z17.s, p1/m, z17.s, alphaZ + st1w z17.s, p1, [pCRow1] + + add pCRow0, pCRow0, lanes, lsl #2 // pC = pC + lanes * 4 + +.endm + +/******************************************************************************/ + +.macro INITv1x1 + dup z16.s, #0 +.endm + +.macro KERNELv1x1_SUB + ld1w z0.s, p1/z, [pA] + add pA, pA, lanes, lsl #2 // pA = pA + lanes * 4 + + ld1rw z8.s, p0/z, [pB] + + add pB, pB, 4 + + fmla z16.s, p1/m, z0.s, z8.s + prfm PLDL1KEEP, [pA, #A_PRE_SIZE] + +.endm + +.macro SAVEv1x1 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + fmul z16.s, p1/m, z16.s, alphaZ + st1w z16.s, p1, [pCRow0] + + + add pCRow0, pCRow0, lanes, lsl #2 // pC = pC + lanes * 4 + +.endm + + +/******************************************************************************* +* End of macro definitions +*******************************************************************************/ + + PROLOGUE + + .align 5 + add sp, sp, #-(11 * 16) + stp d8, d9, [sp, #(0 * 16)] + stp d10, d11, [sp, #(1 * 16)] + stp d12, d13, [sp, #(2 * 16)] + stp d14, d15, [sp, #(3 * 16)] + stp d16, d17, [sp, #(4 * 16)] + stp x18, x19, [sp, #(5 * 16)] + stp x20, x21, [sp, #(6 * 16)] + stp x22, x23, [sp, #(7 * 16)] + stp x24, x25, [sp, #(8 * 16)] + stp x26, x27, [sp, #(9 * 16)] + str x28, [sp, #(10 * 16)] + + prfm PLDL1KEEP, [origPB] + prfm PLDL1KEEP, [origPA] + + fmov alpha, s0 + dup alphaZ, alpha + + lsl LDC, LDC, #2 // ldc = ldc * 8 + ptrue p0.s // create true predicate + +#if !defined(LEFT) + neg tempOffset, offset +#endif + + mov pB, origPB +// Loop over N + mov counterJ, origN + asr counterJ, counterJ, #3 // J = J / 8 + cmp counterJ, #0 + ble .Lstrmm_kernel_L4_BEGIN + +/******************************************************************************/ +/* Repeat this as long as there are 8 left in N */ + + .align 5 +.Lstrmm_kernel_L8_BEGIN: + mov pCRow0, pC + + add pC, pC, LDC, lsl #3 // add 8 x LDC + +#if defined(LEFT) + mov tempOffset, offset +#endif + + mov pA, origPA // pA = start of A array + +.Lstrmm_kernel_L8_Mv1_BEGIN: + +/* Loop over M is done in an SVE fashion. This has the benefit of the last M%SVE_LEN iterations being done in a single sweep */ + mov counterI, #0 + whilelt p1.s, counterI, origM + cntp lanes, p0, p1.s // lanes contain number of active SVE lanes in M dimension + + .align 5 +.Lstrmm_kernel_L8_Mv1_20: + +#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + mov pB, origPB +#else + mov pB, origPB + mul temp, tempOffset, lanes + add pA, pA, temp, lsl #2 // add tempOffset*lanes*4 + lsl temp, tempOffset, #5 + add pB, pB, temp +#endif + +#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) + sub tempK, origK, tempOffset +#elif defined(LEFT) + add tempK, tempOffset, lanes +#else + add tempK, tempOffset, #8 +#endif + + INITv1x8 // fill with zeros + + asr counterL , tempK, #3 // L = K / 8 + cmp counterL , #2 // is there at least 4 to do? + blt .Lstrmm_kernel_L8_Mv1_32 + + KERNELv1x8_I + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + + subs counterL, counterL, #2 // subtract 2 + ble .Lstrmm_kernel_L8_Mv1_22a + + .align 5 +.Lstrmm_kernel_L8_Mv1_22: + + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + + subs counterL, counterL, #1 + bgt .Lstrmm_kernel_L8_Mv1_22 + + .align 5 +.Lstrmm_kernel_L8_Mv1_22a: + + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_E + + b .Lstrmm_kernel_L8_Mv1_44 + + .align 5 +.Lstrmm_kernel_L8_Mv1_32: + + tst counterL, #1 + ble .Lstrmm_kernel_L8_Mv1_40 + + KERNELv1x8_I + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_E + + + b .Lstrmm_kernel_L8_Mv1_44 + +.Lstrmm_kernel_L8_Mv1_40: + + INITv1x8 + +.Lstrmm_kernel_L8_Mv1_44: + + ands counterL , tempK, #7 + ble .Lstrmm_kernel_L8_Mv1_100 + + .align 5 +.Lstrmm_kernel_L8_Mv1_46: + + KERNELv1x8_SUB + + subs counterL, counterL, #1 + bne .Lstrmm_kernel_L8_Mv1_46 + +.Lstrmm_kernel_L8_Mv1_100: + prfm PLDL1KEEP, [pA] + prfm PLDL1KEEP, [pA, #64] + prfm PLDL1KEEP, [origPB] + + SAVEv1x8 + +#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + sub tempK, origK, tempOffset +#if defined(LEFT) + sub tempK, tempK, lanes +#else + sub tempK, tempK, #8 +#endif + mul temp, tempK, lanes + add pA, pA, temp, lsl #2 // add tempOffset*lanes*4 + lsl temp, tempK, #5 + add pB, pB, temp +#endif +#if defined(LEFT) + add tempOffset, tempOffset, lanes +#endif + +.Lstrmm_kernel_L8_Mv1_END: + + incw counterI + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + b.any .Lstrmm_kernel_L8_Mv1_20 + +.Lstrmm_kernel_L8_END: + + lsl temp, origK, #5 + add origPB, origPB, temp // B = B + K * 8 * 4 + +#if !defined(LEFT) + add tempOffset, tempOffset, #8 +#endif + + subs counterJ, counterJ , #1 // j-- + bgt .Lstrmm_kernel_L8_BEGIN + +/******************************************************************************/ +/* Repeat the same thing if 4 left in N */ + + .align 5 +.Lstrmm_kernel_L4_BEGIN: + + mov counterJ , origN + tst counterJ , #4 + ble .Lstrmm_kernel_L2_BEGIN + +#if defined(LEFT) + mov tempOffset, offset +#endif + + mov pCRow0, pC + + add pC, pC, LDC, lsl #2 // add 4 x LDC + + mov pA, origPA // pA = start of A array + +.Lstrmm_kernel_L4_Mv1_BEGIN: + + mov counterI, #0 + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + + .align 5 +.Lstrmm_kernel_L4_Mv1_20: + +#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + mov pB, origPB +#else + mov pB, origPB + mul temp, tempOffset, lanes + add pA, pA, temp, lsl #2 // add tempOffset*lanes*4 + lsl temp, tempOffset, #4 + add pB, pB, temp +#endif + +#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) + sub tempK, origK, tempOffset +#elif defined(LEFT) + add tempK, tempOffset, lanes +#else + add tempK, tempOffset, #4 +#endif + + INITv1x4 // fill with zeros + + asr counterL , tempK, #3 // L = K / 8 + cmp counterL , #0 // is there at least 4 to do? + ble .Lstrmm_kernel_L4_Mv1_44 + + .align 5 +.Lstrmm_kernel_L4_Mv1_22: + + KERNELv1x4_SUB + KERNELv1x4_SUB + KERNELv1x4_SUB + KERNELv1x4_SUB + KERNELv1x4_SUB + KERNELv1x4_SUB + KERNELv1x4_SUB + KERNELv1x4_SUB + + subs counterL, counterL, #1 + bgt .Lstrmm_kernel_L4_Mv1_22 + +.Lstrmm_kernel_L4_Mv1_44: + + ands counterL , tempK, #7 + ble .Lstrmm_kernel_L4_Mv1_100 + + .align 5 +.Lstrmm_kernel_L4_Mv1_46: + + KERNELv1x4_SUB + + subs counterL, counterL, #1 + bne .Lstrmm_kernel_L4_Mv1_46 + +.Lstrmm_kernel_L4_Mv1_100: + + SAVEv1x4 + +#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + sub tempK, origK, tempOffset +#if defined(LEFT) + sub tempK, tempK, lanes +#else + sub tempK, tempK, #4 +#endif + mul temp, tempK, lanes + add pA, pA, temp, lsl #2 // add tempOffset*lanes*4 + lsl temp, tempK, #4 + add pB, pB, temp +#endif +#if defined(LEFT) + add tempOffset, tempOffset, lanes +#endif + +.Lstrmm_kernel_L4_Mv1_END: + + incw counterI + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + b.any .Lstrmm_kernel_L4_Mv1_20 + + +.Lstrmm_kernel_L4_END: + lsl temp, origK, #4 + add origPB, origPB, temp // B = B + K * 4 * 4 +#if !defined(LEFT) + add tempOffset, tempOffset, #4 +#endif + +/******************************************************************************/ +/* Repeat the same thing if 2 left in N */ + + .align 5 +.Lstrmm_kernel_L2_BEGIN: + + mov counterJ , origN + tst counterJ , #2 + ble .Lstrmm_kernel_L1_BEGIN + + mov pCRow0, pC + + add pC, pC, LDC, lsl #1 // add 2 x LDC + +#if defined(LEFT) + mov tempOffset, offset +#endif + + mov pA, origPA // pA = start of A array + +.Lstrmm_kernel_L2_Mv1_BEGIN: + + mov counterI, #0 + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + + .align 5 +.Lstrmm_kernel_L2_Mv1_20: + +#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + mov pB, origPB +#else + mov pB, origPB + mul temp, tempOffset, lanes + add pA, pA, temp, lsl #2 // add tempOffset*lanes*4 + lsl temp, tempOffset, #3 + add pB, pB, temp +#endif + +#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) + sub tempK, origK, tempOffset +#elif defined(LEFT) + add tempK, tempOffset, lanes +#else + add tempK, tempOffset, #2 +#endif + + INITv1x2 // fill with zeros + + asr counterL , tempK, #3 // L = K / 8 + cmp counterL , #0 // is there at least 4 to do? + ble .Lstrmm_kernel_L2_Mv1_44 + + .align 5 +.Lstrmm_kernel_L2_Mv1_22: + + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + + subs counterL, counterL, #1 + bgt .Lstrmm_kernel_L2_Mv1_22 + +.Lstrmm_kernel_L2_Mv1_44: + + ands counterL , tempK, #7 + ble .Lstrmm_kernel_L2_Mv1_100 + + .align 5 +.Lstrmm_kernel_L2_Mv1_46: + + KERNELv1x2_SUB + + subs counterL, counterL, #1 + bne .Lstrmm_kernel_L2_Mv1_46 + +.Lstrmm_kernel_L2_Mv1_100: + + SAVEv1x2 + +#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + sub tempK, origK, tempOffset +#if defined(LEFT) + sub tempK, tempK, lanes +#else + sub tempK, tempK, #2 +#endif + mul temp, tempK, lanes + add pA, pA, temp, lsl #2 // add tempOffset*lanes*4 + lsl temp, tempK, #3 + add pB, pB, temp +#endif +#if defined(LEFT) + add tempOffset, tempOffset, lanes +#endif + + +.Lstrmm_kernel_L2_Mv1_END: + + incw counterI + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + b.any .Lstrmm_kernel_L2_Mv1_20 + + +.Lstrmm_kernel_L2_END: + add origPB, origPB, origK, lsl #3 // B = B + K * 2 * 4 +#if !defined(LEFT) + add tempOffset, tempOffset, #2 +#endif + +/******************************************************************************/ +/* Repeat the same thing if 1 left in N */ + + .align 5 +.Lstrmm_kernel_L1_BEGIN: + + mov counterJ , origN + tst counterJ , #1 + ble .Lstrmm_kernel_L999 // done + + mov pCRow0, pC + + add pC, pC, LDC // add 1 x LDC + +#if defined(LEFT) + mov tempOffset, offset +#endif + + mov pA, origPA // pA = start of A array + +.Lstrmm_kernel_L1_Mv1_BEGIN: + + mov counterI, #0 + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + + .align 5 +.Lstrmm_kernel_L1_Mv1_20: + +#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + mov pB, origPB +#else + mov pB, origPB + mul temp, tempOffset, lanes + add pA, pA, temp, lsl #2 // add tempOffset*lanes*4 + lsl temp, tempOffset, #2 + add pB, pB, temp +#endif + +#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) + sub tempK, origK, tempOffset +#elif defined(LEFT) + add tempK, tempOffset, lanes +#else + add tempK, tempOffset, #1 +#endif + + INITv1x1 // fill with zeros + + asr counterL , tempK, #3 // L = K / 8 + cmp counterL , #0 // is there at least 8 to do? + ble .Lstrmm_kernel_L1_Mv1_44 + + .align 5 +.Lstrmm_kernel_L1_Mv1_22: + + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + + subs counterL, counterL, #1 + bgt .Lstrmm_kernel_L1_Mv1_22 + +.Lstrmm_kernel_L1_Mv1_44: + + ands counterL , tempK, #7 + ble .Lstrmm_kernel_L1_Mv1_100 + + .align 5 +.Lstrmm_kernel_L1_Mv1_46: + + KERNELv1x1_SUB + + subs counterL, counterL, #1 + bgt .Lstrmm_kernel_L1_Mv1_46 + +.Lstrmm_kernel_L1_Mv1_100: + + SAVEv1x1 + +#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + sub tempK, origK, tempOffset +#if defined(LEFT) + sub tempK, tempK, lanes +#else + sub tempK, tempK, #1 +#endif + mul temp, tempK, lanes + add pA, pA, temp, lsl #2 // add tempOffset*lanes*4 + lsl temp, tempK, #2 + add pB, pB, temp +#endif +#if defined(LEFT) + add tempOffset, tempOffset, lanes +#endif + + + +.Lstrmm_kernel_L1_Mv1_END: + + incw counterI + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + b.any .Lstrmm_kernel_L1_Mv1_20 + + +.Lstrmm_kernel_L1_END: + +/******************************************************************************/ + +.Lstrmm_kernel_L999: + mov x0, #0 // set return value + ldp d8, d9, [sp, #(0 * 16)] + ldp d10, d11, [sp, #(1 * 16)] + ldp d12, d13, [sp, #(2 * 16)] + ldp d14, d15, [sp, #(3 * 16)] + ldp d16, d17, [sp, #(4 * 16)] + ldp x18, x19, [sp, #(5 * 16)] + ldp x20, x21, [sp, #(6 * 16)] + ldp x22, x23, [sp, #(7 * 16)] + ldp x24, x25, [sp, #(8 * 16)] + ldp x26, x27, [sp, #(9 * 16)] + ldr x28, [sp, #(10 * 16)] + add sp, sp, #(11*16) + ret + + EPILOGUE + From a1fea1fe2aed7d5169d85b132195c3a80116599f Mon Sep 17 00:00:00 2001 From: Bine Brank Date: Sun, 5 Dec 2021 18:47:29 +0100 Subject: [PATCH 4/6] sgemm v2x8 SVE kernel --- kernel/arm64/sgemm_kernel_sve_v2x8.S | 1683 ++++++++++++++++++++++++++ 1 file changed, 1683 insertions(+) create mode 100644 kernel/arm64/sgemm_kernel_sve_v2x8.S diff --git a/kernel/arm64/sgemm_kernel_sve_v2x8.S b/kernel/arm64/sgemm_kernel_sve_v2x8.S new file mode 100644 index 000000000..1cdd8253e --- /dev/null +++ b/kernel/arm64/sgemm_kernel_sve_v2x8.S @@ -0,0 +1,1683 @@ +/******************************************************************************* +Copyright (c) 2015, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +/* This is an SVE sgemm kernel with size 2*SVE_LEN x 8. +However, the data layout is the same as for the kernel 1*SVE_LEN x 8. +This means that we sweep two panels of packed A when iterating in a loop over K. +With this approach, we can reuse sgemm_n|tcopy_sve_v1.c packing functions. */ + +#define ASSEMBLER +#include "common.h" + +/* X0 X1 X2 s0 X3 x4 x5 x6 */ +/*int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha0,FLOAT* ba,FLOAT* bb,FLOAT* C,BLASLONG ldc )*/ + +#define origM x0 +#define origN x1 +#define origK x2 +#define origPA x3 +#define origPB x4 +#define pC x5 +#define LDC x6 +#define temp x7 +#define counterL x8 +#define counterI x9 +#define counterJ x10 +#define pB x11 +#define pCRow0 x12 +#define pCRow1 x13 +#define pCRow2 x14 + +#define lanes x15 +#define pA1 x16 +#define pA2 x17 +#define alpha w18 +#define vec_len x19 +#define vec_lenx2 x20 + +#define alpha0 s10 +#define alphaZ z7.s + +#define A_PRE_SIZE 1536 +#define B_PRE_SIZE 512 +#define C_PRE_SIZE 128 + +// 00 origM +// 01 origN +// 02 origK +// 03 origPA +// 04 origPB +// 05 pC +// 06 origLDC -> LDC +// 07 temp +// 08 counterL +// 09 counterI +// 10 counterJ +// 11 pB +// 12 pCRow0 +// 13 pCRow1 +// 14 pCRow2 +// 15 lanes +// 16 pA1 +// 17 pA1 +// 18 must save alpha +// 19 must save vec_len +// 20 must save +// 21 must save +// 22 must save +// 23 must save +// 24 must save +// 25 must save +// 26 must save +// 27 must save +// 28 must save +// 29 frame +// 30 link +// 31 sp + +//v00 ALPHA -> pA10_0 +//v01 pA10_1 +//v02 pA20_0 +//v03 pA20_1 +//v04 +//v05 +//v06 +//v07 ALPHA0 +//v08 must save pB0_0 +//v09 must save pB0_1 +//v10 must save pB0_2 +//v11 must save pB0_3 +//v12 must save pB0_4 +//v13 must save pB0_5 +//v14 must save pB0_6 +//v15 must save pB0_7 +//v16 must save C0 +//v17 must save C1 +//v18 must save C2 +//v19 must save C3 +//v20 must save C4 +//v21 must save C5 +//v22 must save C6 +//v23 must save C7 +//v24 must save C8 +//v25 must save C9 +//v26 must save C10 +//v27 must save C11 +//v28 must save C12 +//v29 must save C13 +//v30 must save C14 +//v31 must save C15 + +/******************************************************************************* +* Macro definitions +*******************************************************************************/ + +.macro INITv2x8 + dup z16.s, #0 + dup z17.s, #0 + dup z18.s, #0 + dup z19.s, #0 + dup z20.s, #0 + dup z21.s, #0 + dup z22.s, #0 + dup z23.s, #0 + dup z24.s, #0 + dup z25.s, #0 + dup z26.s, #0 + dup z27.s, #0 + dup z28.s, #0 + dup z29.s, #0 + dup z30.s, #0 + dup z31.s, #0 +.endm + +.macro KERNELv2x8_I + ld1w z0.s, p0/z, [pA1] + ld1w z1.s, p0/z, [pA2] + ld1w z2.s, p0/z, [pA1, vec_len, lsl #2] + ld1w z3.s, p0/z, [pA2, vec_len, lsl #2] + add pA1, pA1, vec_len, lsl #3 // pA1 = pA1 + vec_len * 4 *2 + add pA2, pA2, vec_len, lsl #3 // pA1 = pA1 + vec_len * 4 *2 + + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + ld1rw z10.s, p0/z, [pB, 8] + ld1rw z11.s, p0/z, [pB, 12] + ld1rw z12.s, p0/z, [pB, 16] + ld1rw z13.s, p0/z, [pB, 20] + ld1rw z14.s, p0/z, [pB, 24] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 + + fmla z16.s, p0/m, z0.s, z8.s + fmla z17.s, p0/m, z1.s, z8.s + ld1rw z8.s, p0/z, [pB] + fmla z18.s, p0/m, z0.s, z9.s + fmla z19.s, p0/m, z1.s, z9.s + ld1rw z9.s, p0/z, [pB, 4] + fmla z20.s, p0/m, z0.s, z10.s + fmla z21.s, p0/m, z1.s, z10.s + ld1rw z10.s, p0/z, [pB, 8] + fmla z22.s, p0/m, z0.s, z11.s + fmla z23.s, p0/m, z1.s, z11.s + ld1rw z11.s, p0/z, [pB, 12] + fmla z24.s, p0/m, z0.s, z12.s + fmla z25.s, p0/m, z1.s, z12.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE] + ld1rw z12.s, p0/z, [pB, 16] + fmla z26.s, p0/m, z0.s, z13.s + fmla z27.s, p0/m, z1.s, z13.s + prfm PLDL1KEEP, [pA2, #A_PRE_SIZE] + ld1rw z13.s, p0/z, [pB, 20] + fmla z28.s, p0/m, z0.s, z14.s + fmla z29.s, p0/m, z1.s, z14.s + ld1rw z14.s, p0/z, [pB, 24] + fmla z30.s, p0/m, z0.s, z15.s + fmla z31.s, p0/m, z1.s, z15.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE+64] + ld1rw z15.s, p0/z, [pB, 28] + prfm PLDL1KEEP, [pA2, #A_PRE_SIZE+64] + + add pB, pB, 32 +.endm + +.macro KERNELv2x8_M1 + ld1w z2.s, p0/z, [pA1] + ld1w z3.s, p0/z, [pA2] + add pA1, pA1, vec_len, lsl #2 // pA1 = pA1 + vec_len * 4 + add pA2, pA2, vec_len, lsl #2 // pA1 = pA1 + vec_len * 4 + + fmla z16.s, p0/m, z0.s, z8.s + fmla z17.s, p0/m, z1.s, z8.s + ld1rw z8.s, p0/z, [pB] + fmla z18.s, p0/m, z0.s, z9.s + fmla z19.s, p0/m, z1.s, z9.s + ld1rw z9.s, p0/z, [pB, 4] + fmla z20.s, p0/m, z0.s, z10.s + fmla z21.s, p0/m, z1.s, z10.s + ld1rw z10.s, p0/z, [pB, 8] + fmla z22.s, p0/m, z0.s, z11.s + fmla z23.s, p0/m, z1.s, z11.s + ld1rw z11.s, p0/z, [pB, 12] + fmla z24.s, p0/m, z0.s, z12.s + fmla z25.s, p0/m, z1.s, z12.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE] + ld1rw z12.s, p0/z, [pB, 16] + fmla z26.s, p0/m, z0.s, z13.s + fmla z27.s, p0/m, z1.s, z13.s + prfm PLDL1KEEP, [pA2, #A_PRE_SIZE] + ld1rw z13.s, p0/z, [pB, 20] + fmla z28.s, p0/m, z0.s, z14.s + fmla z29.s, p0/m, z1.s, z14.s + ld1rw z14.s, p0/z, [pB, 24] + fmla z30.s, p0/m, z0.s, z15.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE+64] + fmla z31.s, p0/m, z1.s, z15.s + prfm PLDL1KEEP, [pA2, #A_PRE_SIZE+64] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 +.endm + +.macro KERNELv2x8_M2 + ld1w z0.s, p0/z, [pA1] + ld1w z1.s, p0/z, [pA2] + add pA1, pA1, vec_len, lsl #2 // pA1 = pA1 + vec_len * 2 * 4 + add pA2, pA2, vec_len, lsl #2 // pA1 = pA1 + vec_len * 2 * 4 + + fmla z16.s, p0/m, z2.s, z8.s + fmla z17.s, p0/m, z3.s, z8.s + ld1rw z8.s, p0/z, [pB] + fmla z18.s, p0/m, z2.s, z9.s + fmla z19.s, p0/m, z3.s, z9.s + ld1rw z9.s, p0/z, [pB, 4] + fmla z20.s, p0/m, z2.s, z10.s + fmla z21.s, p0/m, z3.s, z10.s + ld1rw z10.s, p0/z, [pB, 8] + fmla z22.s, p0/m, z2.s, z11.s + fmla z23.s, p0/m, z3.s, z11.s + ld1rw z11.s, p0/z, [pB, 12] + fmla z24.s, p0/m, z2.s, z12.s + fmla z25.s, p0/m, z3.s, z12.s + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + ld1rw z12.s, p0/z, [pB, 16] + fmla z26.s, p0/m, z2.s, z13.s + fmla z27.s, p0/m, z3.s, z13.s + ld1rw z13.s, p0/z, [pB, 20] + fmla z28.s, p0/m, z2.s, z14.s + fmla z29.s, p0/m, z3.s, z14.s + ld1rw z14.s, p0/z, [pB, 24] + fmla z30.s, p0/m, z2.s, z15.s + fmla z31.s, p0/m, z3.s, z15.s + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 +.endm + +.macro KERNELv2x8_E + fmla z16.s, p0/m, z2.s, z8.s + fmla z17.s, p0/m, z3.s, z8.s + fmla z18.s, p0/m, z2.s, z9.s + fmla z19.s, p0/m, z3.s, z9.s + fmla z20.s, p0/m, z2.s, z10.s + fmla z21.s, p0/m, z3.s, z10.s + fmla z22.s, p0/m, z2.s, z11.s + fmla z23.s, p0/m, z3.s, z11.s + fmla z24.s, p0/m, z2.s, z12.s + fmla z25.s, p0/m, z3.s, z12.s + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + fmla z26.s, p0/m, z2.s, z13.s + fmla z27.s, p0/m, z3.s, z13.s + fmla z28.s, p0/m, z2.s, z14.s + fmla z29.s, p0/m, z3.s, z14.s + fmla z30.s, p0/m, z2.s, z15.s + fmla z31.s, p0/m, z3.s, z15.s +.endm + +.macro KERNELv2x8_SUB + ld1w z0.s, p0/z, [pA1] + ld1w z1.s, p0/z, [pA2] + add pA1, pA1, vec_len, lsl #2 // pA1 = pA1 + vec_len * 4 + add pA2, pA2, vec_len, lsl #2 // pA1 = pA1 + vec_len * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + ld1rw z10.s, p0/z, [pB, 8] + ld1rw z11.s, p0/z, [pB, 12] + ld1rw z12.s, p0/z, [pB, 16] + ld1rw z13.s, p0/z, [pB, 20] + ld1rw z14.s, p0/z, [pB, 24] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 + + fmla z16.s, p0/m, z0.s, z8.s + fmla z17.s, p0/m, z1.s, z8.s + fmla z18.s, p0/m, z0.s, z9.s + fmla z19.s, p0/m, z1.s, z9.s + fmla z20.s, p0/m, z0.s, z10.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE] + fmla z21.s, p0/m, z1.s, z10.s + fmla z22.s, p0/m, z0.s, z11.s + fmla z23.s, p0/m, z1.s, z11.s + fmla z24.s, p0/m, z0.s, z12.s + prfm PLDL1KEEP, [pA2, #A_PRE_SIZE] + fmla z25.s, p0/m, z1.s, z12.s + fmla z26.s, p0/m, z0.s, z13.s + fmla z27.s, p0/m, z1.s, z13.s + fmla z28.s, p0/m, z0.s, z14.s + fmla z29.s, p0/m, z1.s, z14.s + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + fmla z30.s, p0/m, z0.s, z15.s + fmla z31.s, p0/m, z1.s, z15.s +.endm + +.macro SAVEv2x8 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + add pCRow1, pCRow0, LDC + ld1w z8.s, p0/z, [pCRow0] + ld1w z9.s, p0/z, [pCRow0, #1, mul vl] + fmla z8.s, p0/m, z16.s, alphaZ + fmla z9.s, p0/m, z17.s, alphaZ + st1w z8.s, p0, [pCRow0] + st1w z9.s, p0, [pCRow0, #1, mul vl] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + ld1w z10.s, p0/z, [pCRow1] + ld1w z11.s, p0/z, [pCRow1, #1, mul vl] + fmla z10.s, p0/m, z18.s, alphaZ + fmla z11.s, p0/m, z19.s, alphaZ + st1w z10.s, p0, [pCRow1] + st1w z11.s, p0, [pCRow1, #1, mul vl] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + ld1w z12.s, p0/z, [pCRow2] + ld1w z13.s, p0/z, [pCRow2, #1, mul vl] + fmla z12.s, p0/m, z20.s, alphaZ + fmla z13.s, p0/m, z21.s, alphaZ + st1w z12.s, p0, [pCRow2] + st1w z13.s, p0, [pCRow2, #1, mul vl] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + ld1w z14.s, p0/z, [pCRow1] + ld1w z15.s, p0/z, [pCRow1, #1, mul vl] + fmla z14.s, p0/m, z22.s, alphaZ + fmla z15.s, p0/m, z23.s, alphaZ + st1w z14.s, p0, [pCRow1] + st1w z15.s, p0, [pCRow1, #1, mul vl] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + ld1w z8.s, p0/z, [pCRow2] + ld1w z9.s, p0/z, [pCRow2, #1, mul vl] + fmla z8.s, p0/m, z24.s, alphaZ + fmla z9.s, p0/m, z25.s, alphaZ + st1w z8.s, p0, [pCRow2] + st1w z9.s, p0, [pCRow2, #1, mul vl] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + ld1w z10.s, p0/z, [pCRow1] + ld1w z11.s, p0/z, [pCRow1, #1, mul vl] + fmla z10.s, p0/m, z26.s, alphaZ + fmla z11.s, p0/m, z27.s, alphaZ + st1w z10.s, p0, [pCRow1] + st1w z11.s, p0, [pCRow1, #1, mul vl] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + ld1w z12.s, p0/z, [pCRow2] + ld1w z13.s, p0/z, [pCRow2, #1, mul vl] + fmla z12.s, p0/m, z28.s, alphaZ + fmla z13.s, p0/m, z29.s, alphaZ + st1w z12.s, p0, [pCRow2] + st1w z13.s, p0, [pCRow2, #1, mul vl] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + ld1w z14.s, p0/z, [pCRow1] + ld1w z15.s, p0/z, [pCRow1, #1, mul vl] + fmla z14.s, p0/m, z30.s, alphaZ + fmla z15.s, p0/m, z31.s, alphaZ + st1w z14.s, p0, [pCRow1] + st1w z15.s, p0, [pCRow1, #1, mul vl] + + add pCRow0, pCRow0, vec_len, lsl #3 // pC = pC + vec_len * 4 * 2 + +.endm + +.macro INITv2x4 + dup z16.s, #0 + dup z17.s, #0 + dup z18.s, #0 + dup z19.s, #0 + dup z20.s, #0 + dup z21.s, #0 + dup z22.s, #0 + dup z23.s, #0 +.endm + +.macro KERNELv2x4_SUB + ld1w z0.s, p0/z, [pA1] + ld1w z1.s, p0/z, [pA2] + add pA1, pA1, vec_len, lsl #2 // pA1 = pA1 + vec_len * 4 + add pA2, pA2, vec_len, lsl #2 // pA1 = pA1 + vec_len * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + ld1rw z10.s, p0/z, [pB, 8] + ld1rw z11.s, p0/z, [pB, 12] + + add pB, pB, 16 + + fmla z16.s, p0/m, z0.s, z8.s + fmla z17.s, p0/m, z1.s, z8.s + fmla z18.s, p0/m, z0.s, z9.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE] + fmla z19.s, p0/m, z1.s, z9.s + fmla z20.s, p0/m, z0.s, z10.s + prfm PLDL1KEEP, [pA2, #A_PRE_SIZE] + fmla z21.s, p0/m, z1.s, z10.s + fmla z22.s, p0/m, z0.s, z11.s + fmla z23.s, p0/m, z1.s, z11.s +.endm + +.macro SAVEv2x4 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + add pCRow1, pCRow0, LDC + ld1w z8.s, p0/z, [pCRow0] + ld1w z9.s, p0/z, [pCRow0, #1, mul vl] + fmla z8.s, p0/m, z16.s, alphaZ + fmla z9.s, p0/m, z17.s, alphaZ + st1w z8.s, p0, [pCRow0] + st1w z9.s, p0, [pCRow0, #1, mul vl] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + ld1w z10.s, p0/z, [pCRow1] + ld1w z11.s, p0/z, [pCRow1, #1, mul vl] + fmla z10.s, p0/m, z18.s, alphaZ + fmla z11.s, p0/m, z19.s, alphaZ + st1w z10.s, p0, [pCRow1] + st1w z11.s, p0, [pCRow1, #1, mul vl] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + ld1w z12.s, p0/z, [pCRow2] + ld1w z13.s, p0/z, [pCRow2, #1, mul vl] + fmla z12.s, p0/m, z20.s, alphaZ + fmla z13.s, p0/m, z21.s, alphaZ + st1w z12.s, p0, [pCRow2] + st1w z13.s, p0, [pCRow2, #1, mul vl] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + ld1w z14.s, p0/z, [pCRow1] + ld1w z15.s, p0/z, [pCRow1, #1, mul vl] + fmla z14.s, p0/m, z22.s, alphaZ + fmla z15.s, p0/m, z23.s, alphaZ + st1w z14.s, p0, [pCRow1] + st1w z15.s, p0, [pCRow1, #1, mul vl] + + add pCRow0, pCRow0, vec_len, lsl #3 // pC = pC + vec_len * 4 * 2 + +.endm + +.macro INITv2x2 + dup z16.s, #0 + dup z17.s, #0 + dup z18.s, #0 + dup z19.s, #0 +.endm + +.macro KERNELv2x2_SUB + ld1w z0.s, p0/z, [pA1] + ld1w z1.s, p0/z, [pA2] + add pA1, pA1, vec_len, lsl #2 // pA1 = pA1 + vec_len * 4 + add pA2, pA2, vec_len, lsl #2 // pA1 = pA1 + vec_len * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + + add pB, pB, 8 + + fmla z16.s, p0/m, z0.s, z8.s + fmla z17.s, p0/m, z1.s, z8.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE] + fmla z18.s, p0/m, z0.s, z9.s + fmla z19.s, p0/m, z1.s, z9.s + prfm PLDL1KEEP, [pA2, #A_PRE_SIZE] +.endm + +.macro SAVEv2x2 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + add pCRow1, pCRow0, LDC + ld1w z8.s, p0/z, [pCRow0] + ld1w z9.s, p0/z, [pCRow0, #1, mul vl] + fmla z8.s, p0/m, z16.s, alphaZ + fmla z9.s, p0/m, z17.s, alphaZ + st1w z8.s, p0, [pCRow0] + st1w z9.s, p0, [pCRow0, #1, mul vl] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + ld1w z10.s, p0/z, [pCRow1] + ld1w z11.s, p0/z, [pCRow1, #1, mul vl] + fmla z10.s, p0/m, z18.s, alphaZ + fmla z11.s, p0/m, z19.s, alphaZ + st1w z10.s, p0, [pCRow1] + st1w z11.s, p0, [pCRow1, #1, mul vl] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + + + add pCRow0, pCRow0, vec_len, lsl #3 // pC = pC + vec_len * 4 * 2 +.endm + +.macro INITv2x1 + dup z16.s, #0 + dup z17.s, #0 +.endm + +.macro KERNELv2x1_SUB + ld1w z0.s, p0/z, [pA1] + ld1w z1.s, p0/z, [pA2] + add pA1, pA1, vec_len, lsl #2 // pA1 = pA1 + vec_len * 4 + add pA2, pA2, vec_len, lsl #2 // pA1 = pA1 + vec_len * 4 + + ld1rw z8.s, p0/z, [pB] + + add pB, pB, 4 + + fmla z16.s, p0/m, z0.s, z8.s + fmla z17.s, p0/m, z1.s, z8.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE] +.endm + +.macro SAVEv2x1 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + add pCRow1, pCRow0, LDC + ld1w z8.s, p0/z, [pCRow0] + ld1w z9.s, p0/z, [pCRow0, #1, mul vl] + fmla z8.s, p0/m, z16.s, alphaZ + fmla z9.s, p0/m, z17.s, alphaZ + st1w z8.s, p0, [pCRow0] + st1w z9.s, p0, [pCRow0, #1, mul vl] + + add pCRow0, pCRow0, vec_len, lsl #3 // pC = pC + vec_len * 4 * 2 + +.endm + +.macro INITv1x8 + dup z16.s, #0 + dup z17.s, #0 + dup z18.s, #0 + dup z19.s, #0 + dup z20.s, #0 + dup z21.s, #0 + dup z22.s, #0 + dup z23.s, #0 +.endm + +.macro KERNELv1x8_I + ld1w z0.s, p1/z, [pA1] + ld1w z1.s, p1/z, [pA1, lanes, lsl #2] // next one + add pA1, pA1, lanes, lsl #3 // pA1 = pA1 + lanes * 2 * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + ld1rw z10.s, p0/z, [pB, 8] + ld1rw z11.s, p0/z, [pB, 12] + ld1rw z12.s, p0/z, [pB, 16] + ld1rw z13.s, p0/z, [pB, 20] + ld1rw z14.s, p0/z, [pB, 24] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 + + fmla z16.s, p1/m, z0.s, z8.s + ld1rw z8.s, p0/z, [pB] + fmla z17.s, p1/m, z0.s, z9.s + ld1rw z9.s, p0/z, [pB, 4] + fmla z18.s, p1/m, z0.s, z10.s + ld1rw z10.s, p0/z, [pB, 8] + fmla z19.s, p1/m, z0.s, z11.s + ld1rw z11.s, p0/z, [pB, 12] + fmla z20.s, p1/m, z0.s, z12.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE] + ld1rw z12.s, p0/z, [pB, 16] + fmla z21.s, p1/m, z0.s, z13.s + ld1rw z13.s, p0/z, [pB, 20] + fmla z22.s, p1/m, z0.s, z14.s + ld1rw z14.s, p0/z, [pB, 24] + fmla z23.s, p1/m, z0.s, z15.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE+64] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 +.endm + +.macro KERNELv1x8_M1 + ld1w z1.s, p1/z, [pA1] + add pA1, pA1, lanes, lsl #2 // pA1 = pA1 + lanes * 4 + + fmla z16.s, p1/m, z0.s, z8.s + ld1rw z8.s, p0/z, [pB] + fmla z17.s, p1/m, z0.s, z9.s + ld1rw z9.s, p0/z, [pB, 4] + fmla z18.s, p1/m, z0.s, z10.s + ld1rw z10.s, p0/z, [pB, 8] + fmla z19.s, p1/m, z0.s, z11.s + ld1rw z11.s, p0/z, [pB, 12] + fmla z20.s, p1/m, z0.s, z12.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE] + ld1rw z12.s, p0/z, [pB, 16] + fmla z21.s, p1/m, z0.s, z13.s + ld1rw z13.s, p0/z, [pB, 20] + fmla z22.s, p1/m, z0.s, z14.s + ld1rw z14.s, p0/z, [pB, 24] + fmla z23.s, p1/m, z0.s, z15.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE+64] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 +.endm + +.macro KERNELv1x8_M2 + ld1w z0.s, p1/z, [pA1] + add pA1, pA1, lanes, lsl #2 // pA1 = pA1 + lanes * 4 + + fmla z16.s, p1/m, z1.s, z8.s + ld1rw z8.s, p0/z, [pB] + fmla z17.s, p1/m, z1.s, z9.s + ld1rw z9.s, p0/z, [pB, 4] + fmla z18.s, p1/m, z1.s, z10.s + ld1rw z10.s, p0/z, [pB, 8] + fmla z19.s, p1/m, z1.s, z11.s + ld1rw z11.s, p0/z, [pB, 12] + fmla z20.s, p1/m, z1.s, z12.s + ld1rw z12.s, p0/z, [pB, 16] + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + fmla z21.s, p1/m, z1.s, z13.s + ld1rw z13.s, p0/z, [pB, 20] + fmla z22.s, p1/m, z1.s, z14.s + ld1rw z14.s, p0/z, [pB, 24] + fmla z23.s, p1/m, z1.s, z15.s + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 +.endm + +.macro KERNELv1x8_E + fmla z16.s, p1/m, z1.s, z8.s + fmla z17.s, p1/m, z1.s, z9.s + fmla z18.s, p1/m, z1.s, z10.s + fmla z19.s, p1/m, z1.s, z11.s + fmla z20.s, p1/m, z1.s, z12.s + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + fmla z21.s, p1/m, z1.s, z13.s + fmla z22.s, p1/m, z1.s, z14.s + fmla z23.s, p1/m, z1.s, z15.s +.endm + +.macro KERNELv1x8_SUB + ld1w z0.s, p1/z, [pA1] + add pA1, pA1, lanes, lsl #2 // pA1 = pA1 + lanes * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + ld1rw z10.s, p0/z, [pB, 8] + ld1rw z11.s, p0/z, [pB, 12] + ld1rw z12.s, p0/z, [pB, 16] + ld1rw z13.s, p0/z, [pB, 20] + ld1rw z14.s, p0/z, [pB, 24] + ld1rw z15.s, p0/z, [pB, 28] + + add pB, pB, 32 + + fmla z16.s, p1/m, z0.s, z8.s + fmla z17.s, p1/m, z0.s, z9.s + fmla z18.s, p1/m, z0.s, z10.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE] + fmla z19.s, p1/m, z0.s, z11.s + fmla z20.s, p1/m, z0.s, z12.s + fmla z21.s, p1/m, z0.s, z13.s + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + fmla z22.s, p1/m, z0.s, z14.s + fmla z23.s, p1/m, z0.s, z15.s + + +.endm + +.macro SAVEv1x8 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + add pCRow1, pCRow0, LDC + ld1w z24.s, p1/z, [pCRow0] + fmla z24.s, p1/m, z16.s, alphaZ + st1w z24.s, p1, [pCRow0] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + ld1w z25.s, p1/z, [pCRow1] + fmla z25.s, p1/m, z17.s, alphaZ + st1w z25.s, p1, [pCRow1] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + ld1w z26.s, p1/z, [pCRow2] + fmla z26.s, p1/m, z18.s, alphaZ + st1w z26.s, p1, [pCRow2] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + ld1w z27.s, p1/z, [pCRow1] + fmla z27.s, p1/m, z19.s, alphaZ + st1w z27.s, p1, [pCRow1] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + ld1w z28.s, p1/z, [pCRow2] + fmla z28.s, p1/m, z20.s, alphaZ + st1w z28.s, p1, [pCRow2] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + ld1w z29.s, p1/z, [pCRow1] + fmla z29.s, p1/m, z21.s, alphaZ + st1w z29.s, p1, [pCRow1] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + ld1w z30.s, p1/z, [pCRow2] + fmla z30.s, p1/m, z22.s, alphaZ + st1w z30.s, p1, [pCRow2] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + ld1w z31.s, p1/z, [pCRow1] + fmla z31.s, p1/m, z23.s, alphaZ + st1w z31.s, p1, [pCRow1] + + add pCRow0, pCRow0, lanes, lsl #2 // pC = pC + lanes * 4 + +.endm + +/******************************************************************************/ + +.macro INITv1x4 + dup z16.s, #0 + dup z17.s, #0 + dup z18.s, #0 + dup z19.s, #0 +.endm + +.macro KERNELv1x4_SUB + ld1w z0.s, p1/z, [pA1] + add pA1, pA1, lanes, lsl #2 // pA1 = pA1 + lanes * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + ld1rw z10.s, p0/z, [pB, 8] + ld1rw z11.s, p0/z, [pB, 12] + + add pB, pB, 16 + + fmla z16.s, p1/m, z0.s, z8.s + fmla z17.s, p1/m, z0.s, z9.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE] + fmla z18.s, p1/m, z0.s, z10.s + fmla z19.s, p1/m, z0.s, z11.s + +.endm + +.macro SAVEv1x4 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + add pCRow1, pCRow0, LDC + ld1w z24.s, p1/z, [pCRow0] + fmla z24.s, p1/m, z16.s, alphaZ + st1w z24.s, p1, [pCRow0] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + add pCRow2, pCRow1, LDC + ld1w z25.s, p1/z, [pCRow1] + fmla z25.s, p1/m, z17.s, alphaZ + st1w z25.s, p1, [pCRow1] + prfm PLDL2KEEP, [pCRow2, #C_PRE_SIZE] + + add pCRow1, pCRow2, LDC + ld1w z26.s, p1/z, [pCRow2] + fmla z26.s, p1/m, z18.s, alphaZ + st1w z26.s, p1, [pCRow2] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + ld1w z27.s, p1/z, [pCRow1] + fmla z27.s, p1/m, z19.s, alphaZ + st1w z27.s, p1, [pCRow1] + + add pCRow0, pCRow0, lanes, lsl #2 // pC = pC + lanes * 4 + +.endm + +/******************************************************************************/ + +.macro INITv1x2 + dup z16.s, #0 + dup z17.s, #0 +.endm + +.macro KERNELv1x2_SUB + ld1w z0.s, p1/z, [pA1] + add pA1, pA1, lanes, lsl #2 // pA1 = pA1 + lanes * 4 + + ld1rw z8.s, p0/z, [pB] + ld1rw z9.s, p0/z, [pB, 4] + + add pB, pB, 8 + + fmla z16.s, p1/m, z0.s, z8.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE] + fmla z17.s, p1/m, z0.s, z9.s + +.endm + +.macro SAVEv1x2 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + add pCRow1, pCRow0, LDC + ld1w z24.s, p1/z, [pCRow0] + fmla z24.s, p1/m, z16.s, alphaZ + st1w z24.s, p1, [pCRow0] + prfm PLDL2KEEP, [pCRow1, #C_PRE_SIZE] + + ld1w z25.s, p1/z, [pCRow1] + fmla z25.s, p1/m, z17.s, alphaZ + st1w z25.s, p1, [pCRow1] + + add pCRow0, pCRow0, lanes, lsl #2 // pC = pC + lanes * 4 + +.endm + +/******************************************************************************/ + +.macro INITv1x1 + dup z16.s, #0 +.endm + +.macro KERNELv1x1_SUB + ld1w z0.s, p1/z, [pA1] + add pA1, pA1, lanes, lsl #2 // pA1 = pA1 + lanes * 4 + + ld1rw z8.s, p0/z, [pB] + + add pB, pB, 4 + + fmla z16.s, p1/m, z0.s, z8.s + prfm PLDL1KEEP, [pA1, #A_PRE_SIZE] + +.endm + +.macro SAVEv1x1 + + prfm PLDL2KEEP, [pCRow0, #C_PRE_SIZE] + + ld1w z24.s, p1/z, [pCRow0] + fmla z24.s, p1/m, z16.s, alphaZ + st1w z24.s, p1, [pCRow0] + + + add pCRow0, pCRow0, lanes, lsl #2 // pC = pC + lanes * 4 + +.endm + + +/******************************************************************************* +* End of macro definitions +*******************************************************************************/ + + PROLOGUE + + .align 5 + add sp, sp, #-(11 * 16) + stp d8, d9, [sp, #(0 * 16)] + stp d10, d11, [sp, #(1 * 16)] + stp d12, d13, [sp, #(2 * 16)] + stp d14, d15, [sp, #(3 * 16)] + stp d16, d17, [sp, #(4 * 16)] + stp x18, x19, [sp, #(5 * 16)] + stp x20, x21, [sp, #(6 * 16)] + stp x22, x23, [sp, #(7 * 16)] + stp x24, x25, [sp, #(8 * 16)] + stp x26, x27, [sp, #(9 * 16)] + str x28, [sp, #(10 * 16)] + + prfm PLDL1KEEP, [origPB] + prfm PLDL1KEEP, [origPA] + + fmov alpha, s0 + dup alphaZ, alpha + cntw vec_len + lsl vec_lenx2, vec_len, #1 + + lsl LDC, LDC, #2 // ldc = ldc * 8 + ptrue p0.s // create true predicate + + mov pB, origPB +// Loop over N + mov counterJ, origN + asr counterJ, counterJ, #3 // J = J / 8 + cmp counterJ, #0 + ble .Lsgemm_kernel_L4_BEGIN + +/******************************************************************************/ +/* Repeat this as long as there are 8 left in N */ + + .align 5 +.Lsgemm_kernel_L8_BEGIN: + mov pCRow0, pC + + add pC, pC, LDC, lsl #3 // add 8 x LDC + + mov pA1, origPA // pA1 = start of A array + +.Lsgemm_kernel_L8_Mv2_BEGIN: + + mov counterI, #0 + cmp origM, vec_lenx2 // Check if M < 2*SVE_LEN + blt .Lsgemm_kernel_L8_Mv1_BEGIN + + mov counterI, origM + +/* Until we have at least 2*SVE_LEN iters left in M, we do them with V2*8 kernel */ + mul temp, vec_len, origK // generate address of pA2 + add pA2, pA1, temp, lsl #2 // pA1 = start of A array + prfm PLDL1KEEP, [pA2] + + .align 5 +.Lsgemm_kernel_L8_Mv2_20: + + mov pB, origPB + INITv2x8 // fill with zeros + + asr counterL , origK, #3 // L = K / 8 + cmp counterL , #2 // is there at least 4 to do? + blt .Lsgemm_kernel_L8_Mv2_32 + + KERNELv2x8_I + KERNELv2x8_M2 + KERNELv2x8_M1 + KERNELv2x8_M2 + KERNELv2x8_M1 + KERNELv2x8_M2 + KERNELv2x8_M1 + KERNELv2x8_M2 + + subs counterL, counterL, #2 // subtract 2 + ble .Lsgemm_kernel_L8_Mv2_22a + + .align 5 +.Lsgemm_kernel_L8_Mv2_22: + + KERNELv2x8_M1 + KERNELv2x8_M2 + KERNELv2x8_M1 + KERNELv2x8_M2 + KERNELv2x8_M1 + KERNELv2x8_M2 + KERNELv2x8_M1 + KERNELv2x8_M2 + + subs counterL, counterL, #1 + bgt .Lsgemm_kernel_L8_Mv2_22 + + .align 5 +.Lsgemm_kernel_L8_Mv2_22a: + + KERNELv2x8_M1 + KERNELv2x8_M2 + KERNELv2x8_M1 + KERNELv2x8_M2 + KERNELv2x8_M1 + KERNELv2x8_M2 + KERNELv2x8_M1 + KERNELv2x8_E + + b .Lsgemm_kernel_L8_Mv2_44 + + .align 5 +.Lsgemm_kernel_L8_Mv2_32: + + tst counterL, #1 + ble .Lsgemm_kernel_L8_Mv2_40 + + KERNELv2x8_I + KERNELv2x8_M2 + KERNELv2x8_M1 + KERNELv2x8_M2 + KERNELv2x8_M1 + KERNELv2x8_M2 + KERNELv2x8_M1 + KERNELv2x8_E + + + b .Lsgemm_kernel_L8_Mv2_44 + +.Lsgemm_kernel_L8_Mv2_40: + + INITv2x8 + +.Lsgemm_kernel_L8_Mv2_44: + + ands counterL , origK, #7 + ble .Lsgemm_kernel_L8_Mv2_100 + + .align 5 +.Lsgemm_kernel_L8_Mv2_46: + + KERNELv2x8_SUB + + subs counterL, counterL, #1 + bne .Lsgemm_kernel_L8_Mv2_46 + +.Lsgemm_kernel_L8_Mv2_100: + prfm PLDL1KEEP, [pA1] + prfm PLDL1KEEP, [pA1, #64] + prfm PLDL1KEEP, [pA2] + prfm PLDL1KEEP, [pA2, #64] + prfm PLDL1KEEP, [origPB] + + SAVEv2x8 + mov pA1, pA2 // pA1 = pA2 + mul temp, vec_len, origK // generate address of pA2 + add pA2, pA1, temp, lsl #2 // + +.Lsgemm_kernel_L8_Mv2_END: + sub counterI, counterI, vec_lenx2 + cmp counterI, vec_lenx2 + bge .Lsgemm_kernel_L8_Mv2_20 + sub counterI, origM, counterI + + cmp counterI, origM + beq .Lsgemm_kernel_L8_END + +////////////////////////////////////////// +// We have less than 2*SVE_LEN left. We do this with V1x8 kernel. +.Lsgemm_kernel_L8_Mv1_BEGIN: + + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s // lanes contain number of active SVE lanes in M dimension + + .align 5 +.Lsgemm_kernel_L8_Mv1_20: + + mov pB, origPB + INITv1x8 // fill with zeros + + asr counterL , origK, #3 // L = K / 8 + cmp counterL , #2 // is there at least 4 to do? + blt .Lsgemm_kernel_L8_Mv1_32 + + KERNELv1x8_I + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + + subs counterL, counterL, #2 // subtract 2 + ble .Lsgemm_kernel_L8_Mv1_22a + + .align 5 +.Lsgemm_kernel_L8_Mv1_22: + + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + + subs counterL, counterL, #1 + bgt .Lsgemm_kernel_L8_Mv1_22 + + .align 5 +.Lsgemm_kernel_L8_Mv1_22a: + + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_E + + b .Lsgemm_kernel_L8_Mv1_44 + + .align 5 +.Lsgemm_kernel_L8_Mv1_32: + + tst counterL, #1 + ble .Lsgemm_kernel_L8_Mv1_40 + + KERNELv1x8_I + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_M2 + KERNELv1x8_M1 + KERNELv1x8_E + + + b .Lsgemm_kernel_L8_Mv1_44 + +.Lsgemm_kernel_L8_Mv1_40: + + INITv1x8 + +.Lsgemm_kernel_L8_Mv1_44: + + ands counterL , origK, #7 + ble .Lsgemm_kernel_L8_Mv1_100 + + .align 5 +.Lsgemm_kernel_L8_Mv1_46: + + KERNELv1x8_SUB + + subs counterL, counterL, #1 + bne .Lsgemm_kernel_L8_Mv1_46 + +.Lsgemm_kernel_L8_Mv1_100: + prfm PLDL1KEEP, [pA1] + prfm PLDL1KEEP, [pA1, #64] + prfm PLDL1KEEP, [origPB] + + SAVEv1x8 + +.Lsgemm_kernel_L8_Mv1_END: + + incw counterI + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s // lanes contain number of active SVE lanes in M dimension + b.any .Lsgemm_kernel_L8_Mv1_20 + +.Lsgemm_kernel_L8_END: + + lsl temp, origK, #5 + add origPB, origPB, temp // B = B + K * 8 * 4 + + subs counterJ, counterJ , #1 // j-- + bgt .Lsgemm_kernel_L8_BEGIN + +/******************************************************************************/ +/* Repeat the same thing if 4 left in N */ + + .align 5 +.Lsgemm_kernel_L4_BEGIN: + + mov counterJ , origN + tst counterJ , #4 + ble .Lsgemm_kernel_L2_BEGIN + + + mov pCRow0, pC + + add pC, pC, LDC, lsl #2 // add 4 x LDC + + mov pA1, origPA // pA1 = start of A array + +.Lsgemm_kernel_L4_Mv2_BEGIN: + + mov counterI, #0 + cmp origM, vec_lenx2 + blt .Lsgemm_kernel_L4_Mv1_BEGIN + + mov counterI, origM + + mul temp, vec_len, origK // generate address of pA2 + add pA2, pA1, temp, lsl #2 // pA1 = start of A array + + .align 5 +.Lsgemm_kernel_L4_Mv2_20: + + mov pB, origPB + INITv2x4 // fill with zeros + + asr counterL , origK, #3 // L = K / 8 + cmp counterL , #0 // is there at least 4 to do? + ble .Lsgemm_kernel_L4_Mv2_44 + + .align 5 +.Lsgemm_kernel_L4_Mv2_22: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv2x4_SUB + KERNELv2x4_SUB + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv2x4_SUB + KERNELv2x4_SUB + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv2x4_SUB + KERNELv2x4_SUB + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv2x4_SUB + KERNELv2x4_SUB + + subs counterL, counterL, #1 + bgt .Lsgemm_kernel_L4_Mv2_22 + +.Lsgemm_kernel_L4_Mv2_44: + + ands counterL , origK, #7 + ble .Lsgemm_kernel_L4_Mv2_100 + + .align 5 +.Lsgemm_kernel_L4_Mv2_46: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv2x4_SUB + + subs counterL, counterL, #1 + bne .Lsgemm_kernel_L4_Mv2_46 + +.Lsgemm_kernel_L4_Mv2_100: + prfm PLDL1KEEP, [pA1] + prfm PLDL1KEEP, [pA1, #64] + prfm PLDL1KEEP, [pA2] + prfm PLDL1KEEP, [pA2, #64] + prfm PLDL1KEEP, [origPB] + + SAVEv2x4 + mov pA1, pA2 // pA1 = pA2 + mul temp, vec_len, origK // generate address of pA2 + add pA2, pA1, temp, lsl #2 // + +.Lsgemm_kernel_L4_Mv2_END: + sub counterI, counterI, vec_lenx2 + cmp counterI, vec_lenx2 + bge .Lsgemm_kernel_L4_Mv2_20 + sub counterI, origM, counterI + + cmp counterI, origM + beq .Lsgemm_kernel_L4_END + +////////////////////////////////// +// We have less than 2*SVE_LEN left. We do this with V1x4 kernel. +.Lsgemm_kernel_L4_Mv1_BEGIN: + + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s // lanes contain number of active SVE lanes in M dimension + + .align 5 +.Lsgemm_kernel_L4_Mv1_20: + + mov pB, origPB + INITv1x4 // fill with zeros + + asr counterL , origK, #3 // L = K / 8 + cmp counterL , #0 // is there at least 4 to do? + ble .Lsgemm_kernel_L4_Mv1_44 + + .align 5 +.Lsgemm_kernel_L4_Mv1_22: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x4_SUB + KERNELv1x4_SUB + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x4_SUB + KERNELv1x4_SUB + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x4_SUB + KERNELv1x4_SUB + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x4_SUB + KERNELv1x4_SUB + + subs counterL, counterL, #1 + bgt .Lsgemm_kernel_L4_Mv1_22 + +.Lsgemm_kernel_L4_Mv1_44: + + ands counterL , origK, #7 + ble .Lsgemm_kernel_L4_Mv1_100 + + .align 5 +.Lsgemm_kernel_L4_Mv1_46: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x4_SUB + + subs counterL, counterL, #1 + bne .Lsgemm_kernel_L4_Mv1_46 + +.Lsgemm_kernel_L4_Mv1_100: + prfm PLDL1KEEP, [pA1] + prfm PLDL1KEEP, [pA1, #64] + prfm PLDL1KEEP, [origPB] + + SAVEv1x4 + +.Lsgemm_kernel_L4_Mv1_END: + + incw counterI + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + b.any .Lsgemm_kernel_L4_Mv1_20 + + +.Lsgemm_kernel_L4_END: + lsl temp, origK, #4 + add origPB, origPB, temp // B = B + K * 4 * 4 + +/******************************************************************************/ +/* Repeat the same thing if 2 left in N */ + + .align 5 +.Lsgemm_kernel_L2_BEGIN: + + mov counterJ , origN + tst counterJ , #2 + ble .Lsgemm_kernel_L1_BEGIN + + mov pCRow0, pC + + add pC, pC, LDC, lsl #1 // add 2 x LDC + + mov pA1, origPA // pA1 = start of A array + +.Lsgemm_kernel_L2_Mv2_BEGIN: + + mov counterI, #0 + cmp origM, vec_lenx2 + blt .Lsgemm_kernel_L2_Mv1_BEGIN + + mov counterI, origM + + mul temp, vec_len, origK // generate address of pA2 + add pA2, pA1, temp, lsl #2 // pA1 = start of A array + + .align 5 +.Lsgemm_kernel_L2_Mv2_20: + + mov pB, origPB + INITv2x2 // fill with zeros + + asr counterL , origK, #3 // L = K / 8 + cmp counterL , #0 // is there at least 4 to do? + ble .Lsgemm_kernel_L2_Mv2_44 + + .align 5 +.Lsgemm_kernel_L2_Mv2_22: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv2x2_SUB + KERNELv2x2_SUB + KERNELv2x2_SUB + KERNELv2x2_SUB + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv2x2_SUB + KERNELv2x2_SUB + KERNELv2x2_SUB + KERNELv2x2_SUB + + subs counterL, counterL, #1 + bgt .Lsgemm_kernel_L2_Mv2_22 + +.Lsgemm_kernel_L2_Mv2_44: + + ands counterL , origK, #7 + ble .Lsgemm_kernel_L2_Mv2_100 + + .align 5 +.Lsgemm_kernel_L2_Mv2_46: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv2x2_SUB + + subs counterL, counterL, #1 + bne .Lsgemm_kernel_L2_Mv2_46 + +.Lsgemm_kernel_L2_Mv2_100: + prfm PLDL1KEEP, [pA1] + prfm PLDL1KEEP, [pA1, #64] + prfm PLDL1KEEP, [pA2] + prfm PLDL1KEEP, [pA2, #64] + prfm PLDL1KEEP, [origPB] + + SAVEv2x2 + mov pA1, pA2 // pA1 = pA2 + mul temp, vec_len, origK // generate address of pA2 + add pA2, pA1, temp, lsl #2 // + +.Lsgemm_kernel_L2_Mv2_END: + sub counterI, counterI, vec_lenx2 + cmp counterI, vec_lenx2 + bge .Lsgemm_kernel_L2_Mv2_20 + sub counterI, origM, counterI + + cmp counterI, origM + beq .Lsgemm_kernel_L2_END + + +////////////////////////////////// +// We have less than 2*SVE_LEN left. We do this with V1x2 kernel. +.Lsgemm_kernel_L2_Mv1_BEGIN: + + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + + .align 5 +.Lsgemm_kernel_L2_Mv1_20: + + mov pB, origPB + INITv1x2 // fill with zeros + + asr counterL , origK, #3 // L = K / 8 + cmp counterL , #0 // is there at least 4 to do? + ble .Lsgemm_kernel_L2_Mv1_44 + + .align 5 +.Lsgemm_kernel_L2_Mv1_22: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + KERNELv1x2_SUB + + subs counterL, counterL, #1 + bgt .Lsgemm_kernel_L2_Mv1_22 + +.Lsgemm_kernel_L2_Mv1_44: + + ands counterL , origK, #7 + ble .Lsgemm_kernel_L2_Mv1_100 + + .align 5 +.Lsgemm_kernel_L2_Mv1_46: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x2_SUB + + subs counterL, counterL, #1 + bne .Lsgemm_kernel_L2_Mv1_46 + +.Lsgemm_kernel_L2_Mv1_100: + prfm PLDL1KEEP, [pA1] + prfm PLDL1KEEP, [pA1, #64] + prfm PLDL1KEEP, [origPB] + + SAVEv1x2 + +.Lsgemm_kernel_L2_Mv1_END: + + incw counterI + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + b.any .Lsgemm_kernel_L2_Mv1_20 + + +.Lsgemm_kernel_L2_END: + add origPB, origPB, origK, lsl #3 // B = B + K * 2 * 4 + +/******************************************************************************/ +/* Repeat the same thing if 1 left in N */ + + .align 5 +.Lsgemm_kernel_L1_BEGIN: + + mov counterJ , origN + tst counterJ , #1 + ble .Lsgemm_kernel_L999 // done + + mov pCRow0, pC + + add pC, pC, LDC // add 1 x LDC + + mov pA1, origPA // pA1 = start of A array + +.Lsgemm_kernel_L1_Mv2_BEGIN: + + mov counterI, #0 + cmp origM, vec_lenx2 + blt .Lsgemm_kernel_L1_Mv1_BEGIN + + mov counterI, origM + + mul temp, vec_len, origK // generate address of pA2 + add pA2, pA1, temp, lsl #2 // pA1 = start of A array + + + .align 5 +.Lsgemm_kernel_L1_Mv2_20: + + mov pB, origPB + INITv2x1 // fill with zeros + + asr counterL , origK, #3 // L = K / 8 + cmp counterL , #0 // is there at least 8 to do? + ble .Lsgemm_kernel_L1_Mv2_44 + + .align 5 +.Lsgemm_kernel_L1_Mv2_22: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv2x1_SUB + KERNELv2x1_SUB + KERNELv2x1_SUB + KERNELv2x1_SUB + KERNELv2x1_SUB + KERNELv2x1_SUB + KERNELv2x1_SUB + KERNELv2x1_SUB + + subs counterL, counterL, #1 + bgt .Lsgemm_kernel_L1_Mv2_22 + +.Lsgemm_kernel_L1_Mv2_44: + + ands counterL , origK, #7 + ble .Lsgemm_kernel_L1_Mv2_100 + + .align 5 +.Lsgemm_kernel_L1_Mv2_46: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv2x1_SUB + + subs counterL, counterL, #1 + bgt .Lsgemm_kernel_L1_Mv2_46 + +.Lsgemm_kernel_L1_Mv2_100: + prfm PLDL1KEEP, [pA1] + prfm PLDL1KEEP, [pA1, #64] + prfm PLDL1KEEP, [origPB] + + SAVEv2x1 + mov pA1, pA2 // pA1 = pA2 + mul temp, vec_len, origK // generate address of pA2 + add pA2, pA1, temp, lsl #2 // + +.Lsgemm_kernel_L1_Mv2_END: + sub counterI, counterI, vec_lenx2 + cmp counterI, vec_lenx2 + bge .Lsgemm_kernel_L1_Mv2_20 + sub counterI, origM, counterI + + cmp counterI, origM + beq .Lsgemm_kernel_L1_END + + +////////////////////////////////// +// We have less than 2*SVE_LEN left. We do this with V1x1 kernel. +.Lsgemm_kernel_L1_Mv1_BEGIN: + + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + + .align 5 +.Lsgemm_kernel_L1_Mv1_20: + + mov pB, origPB + INITv1x1 // fill with zeros + + asr counterL , origK, #3 // L = K / 8 + cmp counterL , #0 // is there at least 8 to do? + ble .Lsgemm_kernel_L1_Mv1_44 + + .align 5 +.Lsgemm_kernel_L1_Mv1_22: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + KERNELv1x1_SUB + + subs counterL, counterL, #1 + bgt .Lsgemm_kernel_L1_Mv1_22 + +.Lsgemm_kernel_L1_Mv1_44: + + ands counterL , origK, #7 + ble .Lsgemm_kernel_L1_Mv1_100 + + .align 5 +.Lsgemm_kernel_L1_Mv1_46: + + prfm PLDL1KEEP, [pB, #B_PRE_SIZE] + KERNELv1x1_SUB + + subs counterL, counterL, #1 + bgt .Lsgemm_kernel_L1_Mv1_46 + +.Lsgemm_kernel_L1_Mv1_100: + prfm PLDL1KEEP, [pA1] + prfm PLDL1KEEP, [pA1, #64] + prfm PLDL1KEEP, [origPB] + + SAVEv1x1 + +.Lsgemm_kernel_L1_Mv1_END: + + incw counterI + whilelt p1.s, counterI, origM //SVE instruction + cntp lanes, p0, p1.s + b.any .Lsgemm_kernel_L1_Mv1_20 + + +.Lsgemm_kernel_L1_END: + +/******************************************************************************/ + +.Lsgemm_kernel_L999: + mov x0, #0 // set return value + ldp d8, d9, [sp, #(0 * 16)] + ldp d10, d11, [sp, #(1 * 16)] + ldp d12, d13, [sp, #(2 * 16)] + ldp d14, d15, [sp, #(3 * 16)] + ldp d16, d17, [sp, #(4 * 16)] + ldp x18, x19, [sp, #(5 * 16)] + ldp x20, x21, [sp, #(6 * 16)] + ldp x22, x23, [sp, #(7 * 16)] + ldp x24, x25, [sp, #(8 * 16)] + ldp x26, x27, [sp, #(9 * 16)] + ldr x28, [sp, #(10 * 16)] + add sp, sp, #(11*16) + ret + + EPILOGUE + From 774267fdac4594f027916979b064e0151c1a2b9e Mon Sep 17 00:00:00 2001 From: Bine Brank Date: Sat, 11 Dec 2021 16:35:08 +0100 Subject: [PATCH 5/6] adjust Makefile.L3 for SVE --- kernel/Makefile.L3 | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 593e33dde..d22bd46a5 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -1483,29 +1483,61 @@ $(KDIR)xtrsm_kernel_RC$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XTRSMKERNEL_RT) $(XT $(CC) -c $(CFLAGS) -DTRSMKERNEL -DCOMPLEX -DXDOUBLE -UUPPER -DRT -DCONJ $< -o $@ +ifdef STRMMUNCOPY_M +$(KDIR)strmm_iunucopy$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMUNCOPY_M) + $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -ULOWER -DUNIT $< -o $@ + +$(KDIR)strmm_iunncopy$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMUNCOPY_M) + $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -ULOWER -UUNIT $< -o $@ +else $(KDIR)strmm_iunucopy$(TSUFFIX).$(SUFFIX) : generic/trmm_uncopy_$(SGEMM_UNROLL_M).c $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -ULOWER -DUNIT $< -o $@ $(KDIR)strmm_iunncopy$(TSUFFIX).$(SUFFIX) : generic/trmm_uncopy_$(SGEMM_UNROLL_M).c $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -ULOWER -UUNIT $< -o $@ +endif +ifdef STRMMLNCOPY_M +$(KDIR)strmm_ilnucopy$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMLNCOPY_M) + $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -DLOWER -DUNIT $< -o $@ + +$(KDIR)strmm_ilnncopy$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMLNCOPY_M) + $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -DLOWER -UUNIT $< -o $@ +else $(KDIR)strmm_ilnucopy$(TSUFFIX).$(SUFFIX) : generic/trmm_lncopy_$(SGEMM_UNROLL_M).c $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -DLOWER -DUNIT $< -o $@ $(KDIR)strmm_ilnncopy$(TSUFFIX).$(SUFFIX) : generic/trmm_lncopy_$(SGEMM_UNROLL_M).c $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -DLOWER -UUNIT $< -o $@ +endif +ifdef STRMMUTCOPY_M +$(KDIR)strmm_iutucopy$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMUTCOPY_M) + $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -ULOWER -DUNIT $< -o $@ + +$(KDIR)strmm_iutncopy$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMUTCOPY_M) + $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -ULOWER -UUNIT $< -o $@ +else $(KDIR)strmm_iutucopy$(TSUFFIX).$(SUFFIX) : generic/trmm_utcopy_$(SGEMM_UNROLL_M).c $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -ULOWER -DUNIT $< -o $@ $(KDIR)strmm_iutncopy$(TSUFFIX).$(SUFFIX) : generic/trmm_utcopy_$(SGEMM_UNROLL_M).c $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -ULOWER -UUNIT $< -o $@ +endif +ifdef STRMMLTCOPY_M +$(KDIR)strmm_iltucopy$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMLTCOPY_M) + $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -DLOWER -DUNIT $< -o $@ + +$(KDIR)strmm_iltncopy$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMLTCOPY_M) + $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -DLOWER -UUNIT $< -o $@ +else $(KDIR)strmm_iltucopy$(TSUFFIX).$(SUFFIX) : generic/trmm_ltcopy_$(SGEMM_UNROLL_M).c $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -DLOWER -DUNIT $< -o $@ $(KDIR)strmm_iltncopy$(TSUFFIX).$(SUFFIX) : generic/trmm_ltcopy_$(SGEMM_UNROLL_M).c $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -UOUTER -DLOWER -UUNIT $< -o $@ +endif $(KDIR)strmm_ounucopy$(TSUFFIX).$(SUFFIX) : generic/trmm_uncopy_$(SGEMM_UNROLL_N).c $(CC) -c $(CFLAGS) $(NO_UNINITIALIZED_WARN) -UDOUBLE -UCOMPLEX -DOUTER -ULOWER -DUNIT $< -o $@ From a8f62a347bb6a9d653f0b57bf5a05b5e3cd097a8 Mon Sep 17 00:00:00 2001 From: Bine Brank Date: Sat, 11 Dec 2021 16:37:23 +0100 Subject: [PATCH 6/6] fix UNROLL_MN and add to targets for SVE --- kernel/arm64/KERNEL.A64FX | 9 ++++++-- kernel/arm64/KERNEL.ARMV8SVE | 40 +++++++++++++++--------------------- param.h | 8 ++++++++ 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/kernel/arm64/KERNEL.A64FX b/kernel/arm64/KERNEL.A64FX index ee66fea8e..80be4ddd0 100644 --- a/kernel/arm64/KERNEL.A64FX +++ b/kernel/arm64/KERNEL.A64FX @@ -114,8 +114,8 @@ DSDOTKERNEL = dot.S DGEMM_BETA = dgemm_beta.S SGEMM_BETA = sgemm_beta.S -SGEMMKERNEL = sgemm_kernel_sve_v1x$(SGEMM_UNROLL_N).S -STRMMKERNEL = strmm_kernel_8x$(SGEMM_UNROLL_N).S +SGEMMKERNEL = sgemm_kernel_sve_v2x$(SGEMM_UNROLL_N).S +STRMMKERNEL = strmm_kernel_sve_v1x$(SGEMM_UNROLL_N).S SGEMMINCOPY = sgemm_ncopy_sve_v1.c SGEMMITCOPY = sgemm_tcopy_sve_v1.c @@ -127,6 +127,11 @@ SGEMMITCOPYOBJ = sgemm_itcopy$(TSUFFIX).$(SUFFIX) SGEMMONCOPYOBJ = sgemm_oncopy$(TSUFFIX).$(SUFFIX) SGEMMOTCOPYOBJ = sgemm_otcopy$(TSUFFIX).$(SUFFIX) +STRMMUNCOPY_M = trmm_uncopy_sve_v1.c +STRMMLNCOPY_M = trmm_lncopy_sve_v1.c +STRMMUTCOPY_M = trmm_utcopy_sve_v1.c +STRMMLTCOPY_M = trmm_ltcopy_sve_v1.c + SSYMMUCOPY_M = symm_ucopy_sve.c SSYMMLCOPY_M = symm_lcopy_sve.c diff --git a/kernel/arm64/KERNEL.ARMV8SVE b/kernel/arm64/KERNEL.ARMV8SVE index 1f605d10b..0364a929c 100644 --- a/kernel/arm64/KERNEL.ARMV8SVE +++ b/kernel/arm64/KERNEL.ARMV8SVE @@ -114,35 +114,27 @@ DSDOTKERNEL = dot.S DGEMM_BETA = dgemm_beta.S SGEMM_BETA = sgemm_beta.S -SGEMMKERNEL = sgemm_kernel_$(SGEMM_UNROLL_M)x$(SGEMM_UNROLL_N).S -STRMMKERNEL = strmm_kernel_$(SGEMM_UNROLL_M)x$(SGEMM_UNROLL_N).S -ifneq ($(SGEMM_UNROLL_M), $(SGEMM_UNROLL_N)) -ifeq ($(SGEMM_UNROLL_M), 16) -SGEMMITCOPY = sgemm_tcopy_$(SGEMM_UNROLL_M).S -else -SGEMMITCOPY = ../generic/gemm_tcopy_$(SGEMM_UNROLL_M).c -endif -ifeq ($(SGEMM_UNROLL_M), 4) -SGEMMINCOPY = sgemm_ncopy_$(SGEMM_UNROLL_M).S -else -SGEMMINCOPY = ../generic/gemm_ncopy_$(SGEMM_UNROLL_M).c -endif +SGEMMKERNEL = sgemm_kernel_sve_v2x$(SGEMM_UNROLL_N).S +STRMMKERNEL = strmm_kernel_sve_v1x$(SGEMM_UNROLL_N).S + +SGEMMINCOPY = sgemm_ncopy_sve_v1.c +SGEMMITCOPY = sgemm_tcopy_sve_v1.c +SGEMMONCOPY = sgemm_ncopy_$(DGEMM_UNROLL_N).S +SGEMMOTCOPY = sgemm_tcopy_$(DGEMM_UNROLL_N).S + SGEMMINCOPYOBJ = sgemm_incopy$(TSUFFIX).$(SUFFIX) SGEMMITCOPYOBJ = sgemm_itcopy$(TSUFFIX).$(SUFFIX) -endif -ifeq ($(SGEMM_UNROLL_N), 16) -SGEMMOTCOPY = sgemm_tcopy_$(SGEMM_UNROLL_N).S -else -SGEMMOTCOPY = ../generic/gemm_tcopy_$(SGEMM_UNROLL_N).c -endif -ifeq ($(SGEMM_UNROLL_N), 4) -SGEMMONCOPY = sgemm_ncopy_$(SGEMM_UNROLL_N).S -else -SGEMMONCOPY = ../generic/gemm_ncopy_$(SGEMM_UNROLL_N).c -endif SGEMMONCOPYOBJ = sgemm_oncopy$(TSUFFIX).$(SUFFIX) SGEMMOTCOPYOBJ = sgemm_otcopy$(TSUFFIX).$(SUFFIX) +STRMMUNCOPY_M = trmm_uncopy_sve_v1.c +STRMMLNCOPY_M = trmm_lncopy_sve_v1.c +STRMMUTCOPY_M = trmm_utcopy_sve_v1.c +STRMMLTCOPY_M = trmm_ltcopy_sve_v1.c + +SSYMMUCOPY_M = symm_ucopy_sve.c +SSYMMLCOPY_M = symm_lcopy_sve.c + DGEMMKERNEL = dgemm_kernel_sve_v2x$(DGEMM_UNROLL_N).S DTRMMKERNEL = dtrmm_kernel_sve_v1x$(DGEMM_UNROLL_N).S diff --git a/param.h b/param.h index e9419bd9d..f7b8eb07b 100644 --- a/param.h +++ b/param.h @@ -3296,14 +3296,22 @@ is a big desktop or server with abundant cache rather than a phone or embedded d #elif defined(ARMV8SVE) || defined(A64FX) +/* When all BLAS3 routines are implemeted with SVE, SGEMM_DEFAULT_UNROLL_M should be "sve_vl". +Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy routines in both directions seperated. */ #define SGEMM_DEFAULT_UNROLL_M 4 #define SGEMM_DEFAULT_UNROLL_N 8 +/* SGEMM_UNROLL_MN is calculated as max(SGEMM_UNROLL_M, SGEMM_UNROLL_N) + * Since we don't define SGEMM_UNROLL_M correctly we have to manually set this macro. + * If SVE size is ever more than 1024, this should be increased also. */ +#define SGEMM_DEFAULT_UNROLL_MN 32 /* When all BLAS3 routines are implemeted with SVE, DGEMM_DEFAULT_UNROLL_M should be "sve_vl". Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy routines in both directions seperated. */ #define DGEMM_DEFAULT_UNROLL_M 2 #define DGEMM_DEFAULT_UNROLL_N 8 +#define DGEMM_DEFAULT_UNROLL_MN 32 + #define CGEMM_DEFAULT_UNROLL_M 8 #define CGEMM_DEFAULT_UNROLL_N 4