diff --git a/kernel/riscv64/KERNEL.x280 b/kernel/riscv64/KERNEL.x280 index 4d64354fb..217d8534e 100644 --- a/kernel/riscv64/KERNEL.x280 +++ b/kernel/riscv64/KERNEL.x280 @@ -118,8 +118,8 @@ DGEMVTKERNEL = gemv_t_rvv.c CGEMVTKERNEL = zgemv_t_rvv.c ZGEMVTKERNEL = zgemv_t_rvv.c -CTRMMKERNEL = ztrmmkernel_2x2_rvv.c -ZTRMMKERNEL = ztrmmkernel_2x2_rvv.c +CTRMMKERNEL = ztrmmkernel_rvv_v1x4.c +ZTRMMKERNEL = ztrmmkernel_rvv_v1x4.c # SGEMM_UNROLL_N set in params.h ifeq ($(SGEMM_UNROLL_N), 8) @@ -168,17 +168,28 @@ DSYMMUCOPY_M = symm_ucopy_rvv_v1.c DSYMMLCOPY_M = symm_lcopy_rvv_v1.c endif -CGEMMKERNEL = ../generic/zgemmkernel_2x2.c -CGEMMONCOPY = ../generic/zgemm_ncopy_2.c -CGEMMOTCOPY = ../generic/zgemm_tcopy_2.c -CGEMMONCOPYOBJ = cgemm_oncopy.o -CGEMMOTCOPYOBJ = cgemm_otcopy.o +CGEMMKERNEL = zgemmkernel_rvv_v1x4.c +CGEMMINCOPY = zgemm_ncopy_rvv_v1.c +CGEMMITCOPY = zgemm_tcopy_rvv_v1.c +CGEMMONCOPY = zgemm_ncopy_4_rvv.c +CGEMMOTCOPY = zgemm_tcopy_4_rvv.c -ZGEMMKERNEL = ../generic/zgemmkernel_2x2.c -ZGEMMONCOPY = ../generic/zgemm_ncopy_2.c -ZGEMMOTCOPY = ../generic/zgemm_tcopy_2.c -ZGEMMONCOPYOBJ = zgemm_oncopy.o -ZGEMMOTCOPYOBJ = zgemm_otcopy.o +CGEMMINCOPYOBJ = cgemm_incopy$(TSUFFIX).$(SUFFIX) +CGEMMITCOPYOBJ = cgemm_itcopy$(TSUFFIX).$(SUFFIX) +CGEMMONCOPYOBJ = cgemm_oncopy$(TSUFFIX).$(SUFFIX) +CGEMMOTCOPYOBJ = cgemm_otcopy$(TSUFFIX).$(SUFFIX) + +ZGEMMKERNEL = zgemmkernel_rvv_v1x4.c + +ZGEMMINCOPY = zgemm_ncopy_rvv_v1.c +ZGEMMITCOPY = zgemm_tcopy_rvv_v1.c +ZGEMMONCOPY = zgemm_ncopy_4_rvv.c +ZGEMMOTCOPY = zgemm_tcopy_4_rvv.c + +ZGEMMINCOPYOBJ = zgemm_incopy$(TSUFFIX).$(SUFFIX) +ZGEMMITCOPYOBJ = zgemm_itcopy$(TSUFFIX).$(SUFFIX) +ZGEMMONCOPYOBJ = zgemm_oncopy$(TSUFFIX).$(SUFFIX) +ZGEMMOTCOPYOBJ = zgemm_otcopy$(TSUFFIX).$(SUFFIX) STRSMKERNEL_LN = trsm_kernel_LN_rvv_v1.c STRSMKERNEL_LT = trsm_kernel_LT_rvv_v1.c @@ -190,20 +201,25 @@ DTRSMKERNEL_LT = trsm_kernel_LT_rvv_v1.c DTRSMKERNEL_RN = trsm_kernel_RN_rvv_v1.c DTRSMKERNEL_RT = trsm_kernel_RT_rvv_v1.c -CTRSMKERNEL_LN = ../generic/trsm_kernel_LN.c -CTRSMKERNEL_LT = ../generic/trsm_kernel_LT.c -CTRSMKERNEL_RN = ../generic/trsm_kernel_RN.c -CTRSMKERNEL_RT = ../generic/trsm_kernel_RT.c +CTRSMKERNEL_LN = trsm_kernel_LN_rvv_v1.c +CTRSMKERNEL_LT = trsm_kernel_LT_rvv_v1.c +CTRSMKERNEL_RN = trsm_kernel_RN_rvv_v1.c +CTRSMKERNEL_RT = trsm_kernel_RT_rvv_v1.c -ZTRSMKERNEL_LN = ../generic/trsm_kernel_LN.c -ZTRSMKERNEL_LT = ../generic/trsm_kernel_LT.c -ZTRSMKERNEL_RN = ../generic/trsm_kernel_RN.c -ZTRSMKERNEL_RT = ../generic/trsm_kernel_RT.c +ZTRSMKERNEL_LN = trsm_kernel_LN_rvv_v1.c +ZTRSMKERNEL_LT = trsm_kernel_LT_rvv_v1.c +ZTRSMKERNEL_RN = trsm_kernel_RN_rvv_v1.c +ZTRSMKERNEL_RT = trsm_kernel_RT_rvv_v1.c -TRSMCOPYLN_M = trsm_lncopy_rvv_v1.c -TRSMCOPYLT_M = trsm_ltcopy_rvv_v1.c -TRSMCOPYUN_M = trsm_uncopy_rvv_v1.c -TRSMCOPYUT_M = trsm_utcopy_rvv_v1.c +TRSMCOPYLN_M = trsm_lncopy_rvv_v1.c +TRSMCOPYLT_M = trsm_ltcopy_rvv_v1.c +TRSMCOPYUN_M = trsm_uncopy_rvv_v1.c +TRSMCOPYUT_M = trsm_utcopy_rvv_v1.c + +ZTRSMCOPYLN_M = ztrsm_lncopy_rvv_v1.c +ZTRSMCOPYLT_M = ztrsm_ltcopy_rvv_v1.c +ZTRSMCOPYUN_M = ztrsm_uncopy_rvv_v1.c +ZTRSMCOPYUT_M = ztrsm_utcopy_rvv_v1.c SSYMV_U_KERNEL = symv_U_rvv.c SSYMV_L_KERNEL = symv_L_rvv.c @@ -214,6 +230,27 @@ CSYMV_L_KERNEL = ../generic/zsymv_k.c ZSYMV_U_KERNEL = ../generic/zsymv_k.c ZSYMV_L_KERNEL = ../generic/zsymv_k.c +ZHEMMLTCOPY_M = zhemm_ltcopy_rvv_v1.c +ZHEMMUTCOPY_M = zhemm_utcopy_rvv_v1.c + +CHEMMLTCOPY_M = zhemm_ltcopy_rvv_v1.c +CHEMMUTCOPY_M = zhemm_utcopy_rvv_v1.c + +ZSYMMUCOPY_M = zsymm_ucopy_rvv_v1.c +ZSYMMLCOPY_M = 
zsymm_lcopy_rvv_v1.c + +CSYMMUCOPY_M = zsymm_ucopy_rvv_v1.c +CSYMMLCOPY_M = zsymm_lcopy_rvv_v1.c + +ZTRMMUNCOPY_M = ztrmm_uncopy_rvv_v1.c +ZTRMMLNCOPY_M = ztrmm_lncopy_rvv_v1.c +ZTRMMUTCOPY_M = ztrmm_utcopy_rvv_v1.c +ZTRMMLTCOPY_M = ztrmm_ltcopy_rvv_v1.c + +CTRMMUNCOPY_M = ztrmm_uncopy_rvv_v1.c +CTRMMLNCOPY_M = ztrmm_lncopy_rvv_v1.c +CTRMMUTCOPY_M = ztrmm_utcopy_rvv_v1.c +CTRMMLTCOPY_M = ztrmm_ltcopy_rvv_v1.c LSAME_KERNEL = ../generic/lsame.c diff --git a/kernel/riscv64/trmm_lncopy_rvv_v1.c b/kernel/riscv64/trmm_lncopy_rvv_v1.c index 73a8233f8..3457ca3e1 100644 --- a/kernel/riscv64/trmm_lncopy_rvv_v1.c +++ b/kernel/riscv64/trmm_lncopy_rvv_v1.c @@ -36,10 +36,10 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define VSEV_FLOAT vse32_v_f32m2 #define VLSEV_FLOAT vlse32_v_f32m2 #define VBOOL_T vbool16_t -#define UINT_V_T vint32m2_t -#define VID_V_UINT vid_v_i32m2 -#define VMSGTU_VX_UINT vmsgt_vx_i32m2_b16 -#define VMSEQ_VX_UINT vmseq_vx_i32m2_b16 +#define UINT_V_T vuint32m2_t +#define VID_V_UINT vid_v_u32m2 +#define VMSGTU_VX_UINT vmsgtu_vx_u32m2_b16 +#define VMSEQ_VX_UINT vmseq_vx_u32m2_b16 #define VFMERGE_VFM_FLOAT vfmerge_vfm_f32m2 #else #define VSETVL(n) vsetvl_e64m2(n) diff --git a/kernel/riscv64/trsm_kernel_LN_rvv_v1.c b/kernel/riscv64/trsm_kernel_LN_rvv_v1.c index 11a0398ca..2cba06b38 100644 --- a/kernel/riscv64/trsm_kernel_LN_rvv_v1.c +++ b/kernel/riscv64/trsm_kernel_LN_rvv_v1.c @@ -31,28 +31,31 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define VSETVL(n) vsetvl_e32m2(n) #define VSETVL_MAX vsetvlmax_e32m2() #define FLOAT_V_T vfloat32m2_t -#define VLEV_FLOAT vle32_v_f32m2 #define VLSEV_FLOAT vlse32_v_f32m2 -#define VLSEG2_FLOAT vlseg2e32_v_f32m2 -#define VSEV_FLOAT vse32_v_f32m2 #define VSSEV_FLOAT vsse32_v_f32m2 +#define VSEV_FLOAT vse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 #define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSSEG2_FLOAT vssseg2e32_v_f32m2 #define VFMACCVF_FLOAT vfmacc_vf_f32m2 -#define VFMULVF_FLOAT vfmul_vf_f32m2 #define VFNMSACVF_FLOAT vfnmsac_vf_f32m2 +#define VFMULVF_FLOAT vfmul_vf_f32m2 #else #define VSETVL(n) vsetvl_e64m2(n) #define VSETVL_MAX vsetvlmax_e64m2() #define FLOAT_V_T vfloat64m2_t -#define VLEV_FLOAT vle64_v_f64m2 #define VLSEV_FLOAT vlse64_v_f64m2 -#define VLSEG2_FLOAT vlseg2e64_v_f64m2 -#define VSEV_FLOAT vse64_v_f64m2 #define VSSEV_FLOAT vsse64_v_f64m2 +#define VSEV_FLOAT vse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 #define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSSEG2_FLOAT vssseg2e64_v_f64m2 +#define VFMVVF_FLOAT vfmv_v_f_f64m2 #define VFMACCVF_FLOAT vfmacc_vf_f64m2 -#define VFMULVF_FLOAT vfmul_vf_f64m2 #define VFNMSACVF_FLOAT vfnmsac_vf_f64m2 +#define VFMULVF_FLOAT vfmul_vf_f64m2 #endif @@ -88,606 +91,107 @@ static FLOAT dm1 = -1.; #ifndef COMPLEX -#if GEMM_DEFAULT_UNROLL_N == 1 - static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa, bb; - FLOAT *pa, *pc; + FLOAT aa; + FLOAT* pc; int i, j, k; - //fprintf(stderr, "%s , %s, m = %4ld n = %4ld offset = %4ld\n", __FILE__, __FUNCTION__, m, n, ldc); // Debug + + BLASLONG stride_ldc = sizeof(FLOAT) * ldc; + + FLOAT_V_T vb, vc; size_t vl; - FLOAT_V_T va, vc; a += (m - 1) * m; b += (m - 1) * n; - for (i = m - 1; i >= 0; i--) - { + for (i = m - 1; i >= 0; i--) { + aa = *(a + i); - for (j = 0; j < n; j ++) - { - bb = *(c + i + j * ldc); - bb *= aa; - *b = bb; - *(c + i + j * ldc) = bb; - b 
++; + pc = c; + for (j = n; j > 0; j -= vl) { + vl = VSETVL(j); + vb = VLSEV_FLOAT(pc + i, stride_ldc, vl); + vb = VFMULVF_FLOAT(vb, aa, vl); + VSEV_FLOAT(b, vb, vl); + VSSEV_FLOAT(pc + i, stride_ldc, vb, vl); + b += vl; - pa = a; - pc = c + j * ldc; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc = VLEV_FLOAT(pc, vl); - va = VLEV_FLOAT(pa, vl); - vc = VFNMSACVF_FLOAT(vc, bb, va, vl); - VSEV_FLOAT(pc, vc, vl); - pa += vl; - pc += vl; + for (k = 0; k < i; k ++) { + vc = VLSEV_FLOAT(pc + k, stride_ldc, vl); + vc = VFNMSACVF_FLOAT(vc, *(a + k), vb, vl); + VSSEV_FLOAT(pc + k, stride_ldc, vc, vl); } + pc += vl * ldc; } a -= m; b -= 2 * n; } -} -#elif GEMM_DEFAULT_UNROLL_N == 2 - -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa, bb0, bb1; - FLOAT *pa, *pc, *pc0, *pc1; - FLOAT *pb0, *pb1; - - int i, j, k; - fprintf(stderr, "%s , %s, m = %4ld n = %4ld offset = %4ld\n", __FILE__, __FUNCTION__, m, n, ldc); // Debug - - size_t vl; - FLOAT_V_T va, vc0, vc1; - - a += (m - 1) * m; - b += (m - 1) * n; - - for (i = m - 1; i >= 0; i--) - { - aa = *(a + i); - pc = c + i; - for (j = 0; j < n/2; j ++) - { - //bb = *(c + i + j * ldc); - pb0 = pc + j * ldc * 2; - pb1 = pb0 + ldc; - //bb *= aa; - bb0 = (*pb0) * aa; - bb1 = (*pb1) * aa; - //*b = bb; - *b = bb0; - *(b+1) = bb1; - *pb0 = bb0; - *pb1 = bb1; - - //*(c + i + j * ldc) = bb; - //b ++; - - b += 2; - //pa = a + i + 1; - pc0 = c + j * ldc * 2; - pc1 = pc0 + ldc; - pa = a; - //pc = c + j * ldc; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - vc1 = VLEV_FLOAT(pc1, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - vc1 = VFNMSACVF_FLOAT(vc1, bb1, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - VSEV_FLOAT(pc1, vc1, vl); - - pa += vl; - pc0 += vl; - pc1 += vl; - } - } - pc += ldc * (n/2) * 2; - if (n & 1) - { - pb0 = pc; - bb0 = (*pb0) * aa; - *b = bb0; - *pb0 = bb0; - b += 1; - - pc0 = pc - i; - pa = a; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - - pa += vl; - pc0 += vl; - } - } - - a -= m; - b -= 2 * n; - } - -} - -#elif GEMM_DEFAULT_UNROLL_N == 4 - -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa, bb0, bb1, bb2, bb3; - FLOAT *pa, *pc, *pc0, *pc1, *pc2, *pc3; - FLOAT *pb0, *pb1, *pb2, *pb3; - - int i, j, k; - - size_t vl; - FLOAT_V_T va, vc0, vc1, vc2, vc3; - - a += (m - 1) * m; - b += (m - 1) * n; - - for (i = m - 1; i >= 0; i--) - { - aa = *(a + i); - pc = c + i; - for (j = 0; j < n/4; j ++) - { - pb0 = pc + j * ldc * 4; - pb1 = pb0 + ldc; - pb2 = pb1 + ldc; - pb3 = pb2 + ldc; - - bb0 = (*pb0) * aa; - bb1 = (*pb1) * aa; - bb2 = (*pb2) * aa; - bb3 = (*pb3) * aa; - - *b = bb0; - *(b+1) = bb1; - *(b+2) = bb2; - *(b+3) = bb3; - - *pb0 = bb0; - *pb1 = bb1; - *pb2 = bb2; - *pb3 = bb3; - - b += 4; - - pc0 = c + j * ldc * 4; - pc1 = pc0 + ldc; - pc2 = pc1 + ldc; - pc3 = pc2 + ldc; - - pa = a; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - vc1 = VLEV_FLOAT(pc1, vl); - vc2 = VLEV_FLOAT(pc2, vl); - vc3 = VLEV_FLOAT(pc3, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - vc1 = VFNMSACVF_FLOAT(vc1, bb1, va, vl); - vc2 = VFNMSACVF_FLOAT(vc2, bb2, va, vl); - vc3 = VFNMSACVF_FLOAT(vc3, bb3, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - VSEV_FLOAT(pc1, vc1, vl); - VSEV_FLOAT(pc2, vc2, vl); - VSEV_FLOAT(pc3, 
vc3, vl); - - pa += vl; - pc0 += vl; - pc1 += vl; - pc2 += vl; - pc3 += vl; - } - } - pc += ldc * (n/4) * 4; - - if (n & 2) - { - pb0 = pc + j * ldc * 2; - pb1 = pb0 + ldc; - - bb0 = (*pb0) * aa; - bb1 = (*pb1) * aa; - - *b = bb0; - *(b+1) = bb1; - - *pb0 = bb0; - *pb1 = bb1; - - b += 2; - - pc0 = c + j * ldc * 2; - pc1 = pc0 + ldc; - - pa = a; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - vc1 = VLEV_FLOAT(pc1, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - vc1 = VFNMSACVF_FLOAT(vc1, bb1, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - VSEV_FLOAT(pc1, vc1, vl); - - pa += vl; - pc0 += vl; - pc1 += vl; - } - pc += ldc * 2; - } - - if (n & 1) - { - pb0 = pc; - bb0 = (*pb0) * aa; - *b = bb0; - *pb0 = bb0; - b += 1; - - pc0 = pc - i; - pa = a; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - - pa += vl; - pc0 += vl; - } - } - - a -= m; - b -= 2 * n; - } - -} -#elif GEMM_DEFAULT_UNROLL_N == 8 - -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa, bb0, bb1, bb2, bb3, bb4, bb5, bb6, bb7; - FLOAT *pa, *pc, *pc0, *pc1, *pc2, *pc3, *pc4, *pc5, *pc6, *pc7; - FLOAT *pb0, *pb1, *pb2, *pb3, *pb4, *pb5, *pb6, *pb7; - - int i, j, k; - - size_t vl; - FLOAT_V_T va, vc0, vc1, vc2, vc3, vc4, vc5, vc6, vc7; - - a += (m - 1) * m; - b += (m - 1) * n; - - for (i = m - 1; i >= 0; i--) - { - aa = *(a + i); - pc = c + i; - for (j = 0; j < n/8; j ++) - { - pb0 = pc + j * ldc * 8; - pb1 = pb0 + ldc; - pb2 = pb1 + ldc; - pb3 = pb2 + ldc; - pb4 = pb3 + ldc; - pb5 = pb4 + ldc; - pb6 = pb5 + ldc; - pb7 = pb6 + ldc; - - bb0 = (*pb0) * aa; - bb1 = (*pb1) * aa; - bb2 = (*pb2) * aa; - bb3 = (*pb3) * aa; - bb4 = (*pb4) * aa; - bb5 = (*pb5) * aa; - bb6 = (*pb6) * aa; - bb7 = (*pb7) * aa; - - *b = bb0; - *(b+1) = bb1; - *(b+2) = bb2; - *(b+3) = bb3; - *(b+4) = bb4; - *(b+5) = bb5; - *(b+6) = bb6; - *(b+7) = bb7; - - *pb0 = bb0; - *pb1 = bb1; - *pb2 = bb2; - *pb3 = bb3; - *pb4 = bb4; - *pb5 = bb5; - *pb6 = bb6; - *pb7 = bb7; - - b += 8; - - pc0 = c + j * ldc * 8; - pc1 = pc0 + ldc; - pc2 = pc1 + ldc; - pc3 = pc2 + ldc; - pc4 = pc3 + ldc; - pc5 = pc4 + ldc; - pc6 = pc5 + ldc; - pc7 = pc6 + ldc; - - pa = a; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - vc1 = VLEV_FLOAT(pc1, vl); - vc2 = VLEV_FLOAT(pc2, vl); - vc3 = VLEV_FLOAT(pc3, vl); - vc4 = VLEV_FLOAT(pc4, vl); - vc5 = VLEV_FLOAT(pc5, vl); - vc6 = VLEV_FLOAT(pc6, vl); - vc7 = VLEV_FLOAT(pc7, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - vc1 = VFNMSACVF_FLOAT(vc1, bb1, va, vl); - vc2 = VFNMSACVF_FLOAT(vc2, bb2, va, vl); - vc3 = VFNMSACVF_FLOAT(vc3, bb3, va, vl); - vc4 = VFNMSACVF_FLOAT(vc4, bb4, va, vl); - vc5 = VFNMSACVF_FLOAT(vc5, bb5, va, vl); - vc6 = VFNMSACVF_FLOAT(vc6, bb6, va, vl); - vc7 = VFNMSACVF_FLOAT(vc7, bb7, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - VSEV_FLOAT(pc1, vc1, vl); - VSEV_FLOAT(pc2, vc2, vl); - VSEV_FLOAT(pc3, vc3, vl); - VSEV_FLOAT(pc4, vc4, vl); - VSEV_FLOAT(pc5, vc5, vl); - VSEV_FLOAT(pc6, vc6, vl); - VSEV_FLOAT(pc7, vc7, vl); - - pa += vl; - pc0 += vl; - pc1 += vl; - pc2 += vl; - pc3 += vl; - pc4 += vl; - pc5 += vl; - pc6 += vl; - pc7 += vl; - } - } - pc += ldc * (n/8) * 8; - - if (n & 4) - { - pb0 = pc + j * ldc * 4; - pb1 = pb0 + ldc; - pb2 = pb1 + ldc; - pb3 = pb2 + ldc; - - bb0 = (*pb0) * aa; - bb1 = (*pb1) * aa; - bb2 = (*pb2) * aa; - bb3 = 
(*pb3) * aa; - - *b = bb0; - *(b+1) = bb1; - *(b+2) = bb2; - *(b+3) = bb3; - - *pb0 = bb0; - *pb1 = bb1; - *pb2 = bb2; - *pb3 = bb3; - - b += 4; - - pc0 = c + j * ldc * 4; - pc1 = pc0 + ldc; - pc2 = pc1 + ldc; - pc3 = pc2 + ldc; - - pa = a; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - vc1 = VLEV_FLOAT(pc1, vl); - vc2 = VLEV_FLOAT(pc2, vl); - vc3 = VLEV_FLOAT(pc3, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - vc1 = VFNMSACVF_FLOAT(vc1, bb1, va, vl); - vc2 = VFNMSACVF_FLOAT(vc2, bb2, va, vl); - vc3 = VFNMSACVF_FLOAT(vc3, bb3, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - VSEV_FLOAT(pc1, vc1, vl); - VSEV_FLOAT(pc2, vc2, vl); - VSEV_FLOAT(pc3, vc3, vl); - - pa += vl; - pc0 += vl; - pc1 += vl; - pc2 += vl; - pc3 += vl; - } - pc += ldc * 4; - } - - if (n & 2) - { - pb0 = pc + j * ldc * 2; - pb1 = pb0 + ldc; - - bb0 = (*pb0) * aa; - bb1 = (*pb1) * aa; - - *b = bb0; - *(b+1) = bb1; - - *pb0 = bb0; - *pb1 = bb1; - - b += 2; - - pc0 = c + j * ldc * 2; - pc1 = pc0 + ldc; - - pa = a; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - vc1 = VLEV_FLOAT(pc1, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - vc1 = VFNMSACVF_FLOAT(vc1, bb1, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - VSEV_FLOAT(pc1, vc1, vl); - - pa += vl; - pc0 += vl; - pc1 += vl; - } - pc += ldc * 2; - } - - if (n & 1) - { - pb0 = pc; - bb0 = (*pb0) * aa; - *b = bb0; - *pb0 = bb0; - b += 1; - - pc0 = pc - i; - pa = a; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - - pa += vl; - pc0 += vl; - } - } - - a -= m; - b -= 2 * n; - } - } -#else -static inline void solve_generic(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa, bb; - - int i, j, k; - - a += (m - 1) * m; - b += (m - 1) * n; - - for (i = m - 1; i >= 0; i--) { - - aa = *(a + i); - - for (j = 0; j < n; j ++) { - bb = *(c + i + j * ldc); - bb *= aa; - *b = bb; - *(c + i + j * ldc) = bb; - b ++; - - for (k = 0; k < i; k ++){ - *(c + k + j * ldc) -= bb * *(a + k); - } - - } - a -= m; - b -= 2 * n; - } - -} - -#endif - #else static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - FLOAT aa1, aa2; - FLOAT bb1, bb2; - FLOAT cc1, cc2; + FLOAT aa1, aa2; + FLOAT *pc; + int i, j, k; - int i, j, k; + BLASLONG stride_ldc = sizeof(FLOAT) * ldc * 2; - ldc *= 2; - a += (m - 1) * m * 2; - b += (m - 1) * n * 2; + FLOAT_V_T vb1, vb2, vc1, vc2, vs1, vs2; + size_t vl; + a += (m - 1) * m * 2; + b += (m - 1) * n * 2; - for (i = m - 1; i >= 0; i--) { + for (i = m - 1; i >= 0; i--) { - aa1 = *(a + i * 2 + 0); - aa2 = *(a + i * 2 + 1); - - for (j = 0; j < n; j ++) { - bb1 = *(c + i * 2 + 0 + j * ldc); - bb2 = *(c + i * 2 + 1 + j * ldc); + aa1 = *(a + i * 2 + 0); + aa2 = *(a + i * 2 + 1); + pc = c; + for (j = n; j > 0; j -= vl) { + vl = VSETVL(j); + VLSSEG2_FLOAT(&vb1, &vb2, pc + i * 2, stride_ldc, vl); #ifndef CONJ - cc1 = aa1 * bb1 - aa2 * bb2; - cc2 = aa1 * bb2 + aa2 * bb1; + vs1 = VFMULVF_FLOAT(vb1, aa1, vl); + vs1 = VFNMSACVF_FLOAT(vs1, aa2, vb2, vl); + vs2 = VFMULVF_FLOAT(vb2, aa1, vl); + vs2 = VFMACCVF_FLOAT(vs2, aa2, vb1, vl); #else - cc1 = aa1 * bb1 + aa2 * bb2; - cc2 = aa1 * bb2 - aa2 * bb1; + vs1 = VFMULVF_FLOAT(vb1, aa1, vl); + vs1 = VFMACCVF_FLOAT(vs1, aa2, vb2, vl); + vs2 = VFMULVF_FLOAT(vb2, aa1, vl); + vs2 = VFNMSACVF_FLOAT(vs2, aa2, vb1, vl); #endif + VSSEG2_FLOAT(b, vs1, 
vs2, vl); + VSSSEG2_FLOAT(pc + i * 2, stride_ldc, vs1, vs2, vl); + b += vl * 2; - - *(b + 0) = cc1; - *(b + 1) = cc2; - *(c + i * 2 + 0 + j * ldc) = cc1; - *(c + i * 2 + 1 + j * ldc) = cc2; - b += 2; - - for (k = 0; k < i; k ++){ + for (k = 0; k < i; k ++) { + VLSSEG2_FLOAT(&vc1, &vc2, pc + k * 2, stride_ldc, vl); #ifndef CONJ - *(c + k * 2 + 0 + j * ldc) -= cc1 * *(a + k * 2 + 0) - cc2 * *(a + k * 2 + 1); - *(c + k * 2 + 1 + j * ldc) -= cc1 * *(a + k * 2 + 1) + cc2 * *(a + k * 2 + 0); -#else - *(c + k * 2 + 0 + j * ldc) -= cc1 * *(a + k * 2 + 0) + cc2 * *(a + k * 2 + 1); - *(c + k * 2 + 1 + j * ldc) -= - cc1 * *(a + k * 2 + 1) + cc2 * *(a + k * 2 + 0); + vc1 = VFMACCVF_FLOAT(vc1, *(a + k * 2 + 1), vs2, vl); + vc1 = VFNMSACVF_FLOAT(vc1, *(a + k * 2 + 0), vs1, vl); + vc2 = VFNMSACVF_FLOAT(vc2, *(a + k * 2 + 1), vs1, vl); + vc2 = VFNMSACVF_FLOAT(vc2, *(a + k * 2 + 0), vs2, vl); +#else + vc1 = VFNMSACVF_FLOAT(vc1, *(a + k * 2 + 1), vs2, vl); + vc1 = VFNMSACVF_FLOAT(vc1, *(a + k * 2 + 0), vs1, vl); + vc2 = VFMACCVF_FLOAT(vc2, *(a + k * 2 + 1), vs1, vl); + vc2 = VFNMSACVF_FLOAT(vc2, *(a + k * 2 + 0), vs2, vl); #endif - } - + VSSSEG2_FLOAT(pc + k * 2, stride_ldc, vc1, vc2, vl); + } + pc += vl * ldc * 2; + } + a -= m * 2; + b -= 4 * n; } - a -= m * 2; - b -= 4 * n; - } - } + #endif diff --git a/kernel/riscv64/trsm_kernel_LT_rvv_v1.c b/kernel/riscv64/trsm_kernel_LT_rvv_v1.c index 0380bd1bb..492a5631f 100644 --- a/kernel/riscv64/trsm_kernel_LT_rvv_v1.c +++ b/kernel/riscv64/trsm_kernel_LT_rvv_v1.c @@ -31,28 +31,31 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define VSETVL(n) vsetvl_e32m2(n) #define VSETVL_MAX vsetvlmax_e32m2() #define FLOAT_V_T vfloat32m2_t -#define VLEV_FLOAT vle32_v_f32m2 #define VLSEV_FLOAT vlse32_v_f32m2 -#define VLSEG2_FLOAT vlseg2e32_v_f32m2 -#define VSEV_FLOAT vse32_v_f32m2 #define VSSEV_FLOAT vsse32_v_f32m2 +#define VSEV_FLOAT vse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 #define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSSEG2_FLOAT vssseg2e32_v_f32m2 #define VFMACCVF_FLOAT vfmacc_vf_f32m2 -#define VFMULVF_FLOAT vfmul_vf_f32m2 #define VFNMSACVF_FLOAT vfnmsac_vf_f32m2 +#define VFMULVF_FLOAT vfmul_vf_f32m2 #else #define VSETVL(n) vsetvl_e64m2(n) #define VSETVL_MAX vsetvlmax_e64m2() #define FLOAT_V_T vfloat64m2_t -#define VLEV_FLOAT vle64_v_f64m2 #define VLSEV_FLOAT vlse64_v_f64m2 -#define VLSEG2_FLOAT vlseg2e64_v_f64m2 -#define VSEV_FLOAT vse64_v_f64m2 #define VSSEV_FLOAT vsse64_v_f64m2 +#define VSEV_FLOAT vse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 #define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSSEG2_FLOAT vssseg2e64_v_f64m2 +#define VFMVVF_FLOAT vfmv_v_f_f64m2 #define VFMACCVF_FLOAT vfmacc_vf_f64m2 -#define VFMULVF_FLOAT vfmul_vf_f64m2 #define VFNMSACVF_FLOAT vfnmsac_vf_f64m2 +#define VFMULVF_FLOAT vfmul_vf_f64m2 #endif @@ -87,616 +90,101 @@ static FLOAT dm1 = -1.; // Optimizes the implementation in ../arm64/trsm_kernel_LT_sve.c #ifndef COMPLEX -#if GEMM_DEFAULT_UNROLL_N == 1 -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) -{ - FLOAT aa, bb; - FLOAT *pa, *pc; +static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { + + FLOAT aa; + FLOAT* pc; int i, j, k; + + BLASLONG stride_ldc = sizeof(FLOAT) * ldc; + + FLOAT_V_T vb, vc; + size_t vl; - FLOAT_V_T va, vc; - for (i = 0; i < m; i++) - { + + for (i = 0; i < m; i++) { + aa = *(a + i); - for (j = 0; j < n; j ++) 
- { - bb = *(c + i + j * ldc); - bb *= aa; - *b = bb; - *(c + i + j * ldc) = bb; - b++; - pa = a + i + 1; - pc = c + j * ldc + i + 1; - for (k = (m - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc = VLEV_FLOAT(pc, vl); - va = VLEV_FLOAT(pa, vl); - vc = VFNMSACVF_FLOAT(vc, bb, va, vl); - VSEV_FLOAT(pc, vc, vl); - pa += vl; - pc += vl; + pc = c; + for (j = n; j > 0; j -= vl) { + vl = VSETVL(j); + vb = VLSEV_FLOAT(pc + i, stride_ldc, vl); + vb = VFMULVF_FLOAT(vb, aa, vl); + VSEV_FLOAT(b, vb, vl); + VSSEV_FLOAT(pc + i, stride_ldc, vb, vl); + b += vl; + + for (k = i + 1; k < m; k++) { + vc = VLSEV_FLOAT(pc + k, stride_ldc, vl); + vc = VFNMSACVF_FLOAT(vc, *(a + k), vb, vl); + VSSEV_FLOAT(pc + k, stride_ldc, vc, vl); } + pc += vl * ldc; } a += m; } } -#elif GEMM_DEFAULT_UNROLL_N == 2 - -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) -{ - - FLOAT aa, bb0, bb1; - FLOAT *pa, *pc, *pc0, *pc1; - FLOAT *pb0, *pb1; - - int i, j, k; - size_t vl; - FLOAT_V_T va, vc0, vc1; - for (i = 0; i < m; i++) - { - aa = *(a + i); - pc = c + i; - for (j = 0; j < n/2; j ++) - { - pb0 = pc + j * ldc * 2; - pb1 = pb0 + ldc; - bb0 = (*pb0) * aa; - bb1 = (*pb1) * aa; - *b = bb0; - *(b+1) = bb1; - *pb0 = bb0; - *pb1 = bb1; - b += 2; - pa = a + i + 1; - pc0 = pb0 + 1; - pc1 = pc0 + ldc; - for (k = (m - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - vc1 = VLEV_FLOAT(pc1, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - vc1 = VFNMSACVF_FLOAT(vc1, bb1, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - VSEV_FLOAT(pc1, vc1, vl); - pa += vl; - pc0 += vl; - pc1 += vl; - } - } - pc += ldc * (n/2) * 2; - if (n & 1) - { - pb0 = pc; - bb0 = *(pb0); - bb0 *= aa; - *b = bb0; - *(c + i) = bb0; - b++; - pa = a + i + 1; - pc0 = pb0 + 1; - for (k = (m - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - pa += vl; - pc0 += vl; - } - } - - a += m; - } -} -#elif GEMM_DEFAULT_UNROLL_N == 4 - -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) -{ - - FLOAT aa, bb0, bb1, bb2, bb3; - FLOAT *pa, *pc; - FLOAT *pc0, *pc1, *pc2, *pc3; - FLOAT *pb0, *pb1, *pb2, *pb3; - - int i, j, k; - size_t vl; - FLOAT_V_T va; - FLOAT_V_T vc0, vc1, vc2, vc3; - for (i = 0; i < m; i++) - { - aa = *(a + i); - pc = c + i; - for (j = 0; j < n/4; j ++) - { - pb0 = pc; - pb1 = pb0 + ldc; - pb2 = pb1 + ldc; - pb3 = pb2 + ldc; - - bb0 = (*pb0) * aa; - bb1 = (*pb1) * aa; - bb2 = (*pb2) * aa; - bb3 = (*pb3) * aa; - - *b = bb0; - *(b+1) = bb1; - *(b+2) = bb2; - *(b+3) = bb3; - - *pb0 = bb0; - *pb1 = bb1; - *pb2 = bb2; - *pb3 = bb3; - b += 4; - - pa = a + i + 1; - pc0 = pb0 + 1; - pc1 = pc0 + ldc; - pc2 = pc1 + ldc; - pc3 = pc2 + ldc; - - for (k = (m - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - vc1 = VLEV_FLOAT(pc1, vl); - vc2 = VLEV_FLOAT(pc2, vl); - vc3 = VLEV_FLOAT(pc3, vl); - - va = VLEV_FLOAT(pa, vl); - - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - vc1 = VFNMSACVF_FLOAT(vc1, bb1, va, vl); - vc2 = VFNMSACVF_FLOAT(vc2, bb2, va, vl); - vc3 = VFNMSACVF_FLOAT(vc3, bb3, va, vl); - - VSEV_FLOAT(pc0, vc0, vl); - VSEV_FLOAT(pc1, vc1, vl); - VSEV_FLOAT(pc2, vc2, vl); - VSEV_FLOAT(pc3, vc3, vl); - - pa += vl; - pc0 += vl; - pc1 += vl; - pc2 += vl; - pc3 += vl; - } - } - pc += ldc * (n/4) * 4; - - if (n & 2) - { - pb0 = pc; - pb1 = pb0 + ldc; - bb0 = (*pb0) * aa; - bb1 = (*pb1) * aa; - *b = bb0; - 
*(b+1) = bb1; - *pb0 = bb0; - *pb1 = bb1; - b += 2; - pa = a + i + 1; - pc0 = pb0 + 1; - pc1 = pc0 + ldc; - for (k = (m - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - vc1 = VLEV_FLOAT(pc1, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - vc1 = VFNMSACVF_FLOAT(vc1, bb1, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - VSEV_FLOAT(pc1, vc1, vl); - pa += vl; - pc0 += vl; - pc1 += vl; - } - pc += ldc * 2; - } - - if (n & 1) - { - pb0 = pc; - bb0 = *(pb0); - bb0 *= aa; - *b = bb0; - *(c + i) = bb0; - b++; - pa = a + i + 1; - pc0 = pb0 + 1; - for (k = (m - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - pa += vl; - pc0 += vl; - } - } - - a += m; - } -} -#elif GEMM_DEFAULT_UNROLL_N == 8 - -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) -{ - - FLOAT aa, bb0, bb1, bb2, bb3, bb4, bb5, bb6, bb7; - FLOAT *pa, *pc; - FLOAT *pc0, *pc1, *pc2, *pc3, *pc4, *pc5, *pc6, *pc7; - FLOAT *pb0, *pb1, *pb2, *pb3, *pb4, *pb5, *pb6, *pb7; - - int i, j, k; - size_t vl; - FLOAT_V_T va; - FLOAT_V_T vc0, vc1, vc2, vc3, vc4, vc5, vc6, vc7; - for (i = 0; i < m; i++) - { - aa = *(a + i); - pc = c + i; - for (j = 0; j < n/8; j ++) - { - pb0 = pc + j * ldc * 8; - pb1 = pb0 + ldc; - pb2 = pb1 + ldc; - pb3 = pb2 + ldc; - pb4 = pb3 + ldc; - pb5 = pb4 + ldc; - pb6 = pb5 + ldc; - pb7 = pb6 + ldc; - - bb0 = (*pb0) * aa; - bb1 = (*pb1) * aa; - bb2 = (*pb2) * aa; - bb3 = (*pb3) * aa; - bb4 = (*pb4) * aa; - bb5 = (*pb5) * aa; - bb6 = (*pb6) * aa; - bb7 = (*pb7) * aa; - - *b = bb0; - *(b+1) = bb1; - *(b+2) = bb2; - *(b+3) = bb3; - *(b+4) = bb4; - *(b+5) = bb5; - *(b+6) = bb6; - *(b+7) = bb7; - - *pb0 = bb0; - *pb1 = bb1; - *pb2 = bb2; - *pb3 = bb3; - *pb4 = bb4; - *pb5 = bb5; - *pb6 = bb6; - *pb7 = bb7; - b += 8; - - pa = a + i + 1; - pc0 = pb0 + 1; - pc1 = pc0 + ldc; - pc2 = pc1 + ldc; - pc3 = pc2 + ldc; - pc4 = pc3 + ldc; - pc5 = pc4 + ldc; - pc6 = pc5 + ldc; - pc7 = pc6 + ldc; - - for (k = (m - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - vc1 = VLEV_FLOAT(pc1, vl); - vc2 = VLEV_FLOAT(pc2, vl); - vc3 = VLEV_FLOAT(pc3, vl); - vc4 = VLEV_FLOAT(pc4, vl); - vc5 = VLEV_FLOAT(pc5, vl); - vc6 = VLEV_FLOAT(pc6, vl); - vc7 = VLEV_FLOAT(pc7, vl); - - va = VLEV_FLOAT(pa, vl); - - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - vc1 = VFNMSACVF_FLOAT(vc1, bb1, va, vl); - vc2 = VFNMSACVF_FLOAT(vc2, bb2, va, vl); - vc3 = VFNMSACVF_FLOAT(vc3, bb3, va, vl); - vc4 = VFNMSACVF_FLOAT(vc4, bb4, va, vl); - vc5 = VFNMSACVF_FLOAT(vc5, bb5, va, vl); - vc6 = VFNMSACVF_FLOAT(vc6, bb6, va, vl); - vc7 = VFNMSACVF_FLOAT(vc7, bb7, va, vl); - - VSEV_FLOAT(pc0, vc0, vl); - VSEV_FLOAT(pc1, vc1, vl); - VSEV_FLOAT(pc2, vc2, vl); - VSEV_FLOAT(pc3, vc3, vl); - VSEV_FLOAT(pc4, vc4, vl); - VSEV_FLOAT(pc5, vc5, vl); - VSEV_FLOAT(pc6, vc6, vl); - VSEV_FLOAT(pc7, vc7, vl); - - pa += vl; - pc0 += vl; - pc1 += vl; - pc2 += vl; - pc3 += vl; - pc4 += vl; - pc5 += vl; - pc6 += vl; - pc7 += vl; - } - } - pc += ldc * (n/8) * 8; - - if (n & 4) - { - pb0 = pc; - pb1 = pb0 + ldc; - pb2 = pb1 + ldc; - pb3 = pb2 + ldc; - - bb0 = (*pb0) * aa; - bb1 = (*pb1) * aa; - bb2 = (*pb2) * aa; - bb3 = (*pb3) * aa; - - *b = bb0; - *(b+1) = bb1; - *(b+2) = bb2; - *(b+3) = bb3; - - *pb0 = bb0; - *pb1 = bb1; - *pb2 = bb2; - *pb3 = bb3; - b += 4; - - pa = a + i + 1; - pc0 = pb0 + 1; - pc1 = pc0 + ldc; - pc2 = pc1 + ldc; - pc3 = pc2 + ldc; - - for (k = (m - i 
- 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - vc1 = VLEV_FLOAT(pc1, vl); - vc2 = VLEV_FLOAT(pc2, vl); - vc3 = VLEV_FLOAT(pc3, vl); - - va = VLEV_FLOAT(pa, vl); - - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - vc1 = VFNMSACVF_FLOAT(vc1, bb1, va, vl); - vc2 = VFNMSACVF_FLOAT(vc2, bb2, va, vl); - vc3 = VFNMSACVF_FLOAT(vc3, bb3, va, vl); - - VSEV_FLOAT(pc0, vc0, vl); - VSEV_FLOAT(pc1, vc1, vl); - VSEV_FLOAT(pc2, vc2, vl); - VSEV_FLOAT(pc3, vc3, vl); - - pa += vl; - pc0 += vl; - pc1 += vl; - pc2 += vl; - pc3 += vl; - } - pc += ldc * 4; - } - - if (n & 2) - { - pb0 = pc; - pb1 = pb0 + ldc; - bb0 = (*pb0) * aa; - bb1 = (*pb1) * aa; - *b = bb0; - *(b+1) = bb1; - *pb0 = bb0; - *pb1 = bb1; - b += 2; - pa = a + i + 1; - pc0 = pb0 + 1; - pc1 = pc0 + ldc; - for (k = (m - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - vc1 = VLEV_FLOAT(pc1, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - vc1 = VFNMSACVF_FLOAT(vc1, bb1, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - VSEV_FLOAT(pc1, vc1, vl); - pa += vl; - pc0 += vl; - pc1 += vl; - } - pc += ldc * 2; - } - - if (n & 1) - { - pb0 = pc; - bb0 = *(pb0); - bb0 *= aa; - *b = bb0; - *(c + i) = bb0; - b++; - pa = a + i + 1; - pc0 = pb0 + 1; - for (k = (m - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLEV_FLOAT(pc0, vl); - va = VLEV_FLOAT(pa, vl); - vc0 = VFNMSACVF_FLOAT(vc0, bb0, va, vl); - VSEV_FLOAT(pc0, vc0, vl); - pa += vl; - pc0 += vl; - } - } - - a += m; - } -} #else static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - FLOAT aa, bb; + FLOAT aa1, aa2; + FLOAT *pc; + int i, j, k; - int i, j, k; + BLASLONG stride_ldc = sizeof(FLOAT) * ldc * 2; - for (i = 0; i < m; i++) { + FLOAT_V_T vb1, vb2, vc1, vc2, vs1, vs2; + size_t vl; - aa = *(a + i); + ldc *= 2; - for (j = 0; j < n; j ++) { - bb = *(c + i + j * ldc); - bb *= aa; - *b = bb; - *(c + i + j * ldc) = bb; - b ++; - - for (k = i + 1; k < m; k ++){ - *(c + k + j * ldc) -= bb * *(a + k); - } - - } - a += m; - } -} - -#endif - -#else - -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa1, aa2; - FLOAT bb1, bb2; - FLOAT cc1, cc2; - - int i, j, k; - - ldc *= 2; - - for (i = 0; i < m; i++) { - - aa1 = *(a + i * 2 + 0); - aa2 = *(a + i * 2 + 1); - - for (j = 0; j < n; j ++) { - bb1 = *(c + i * 2 + 0 + j * ldc); - bb2 = *(c + i * 2 + 1 + j * ldc); + for (i = 0; i < m; i++) { + aa1 = *(a + i * 2 + 0); + aa2 = *(a + i * 2 + 1); + pc = c; + for (j = n; j > 0; j -= vl) { + vl = VSETVL(j); + VLSSEG2_FLOAT(&vb1, &vb2, pc + i * 2, stride_ldc, vl); #ifndef CONJ - cc1 = aa1 * bb1 - aa2 * bb2; - cc2 = aa1 * bb2 + aa2 * bb1; + vs1 = VFMULVF_FLOAT(vb1, aa1, vl); + vs1 = VFNMSACVF_FLOAT(vs1, aa2, vb2, vl); + vs2 = VFMULVF_FLOAT(vb2, aa1, vl); + vs2 = VFMACCVF_FLOAT(vs2, aa2, vb1, vl); #else - cc1 = aa1 * bb1 + aa2 * bb2; - cc2 = aa1 * bb2 - aa2 * bb1; + vs1 = VFMULVF_FLOAT(vb1, aa1, vl); + vs1 = VFMACCVF_FLOAT(vs1, aa2, vb2, vl); + vs2 = VFMULVF_FLOAT(vb2, aa1, vl); + vs2 = VFNMSACVF_FLOAT(vs2, aa2, vb1, vl); #endif + VSSEG2_FLOAT(b, vs1, vs2, vl); + VSSSEG2_FLOAT(pc + i * 2, stride_ldc, vs1, vs2, vl); + b += vl * 2; - *(b + 0) = cc1; - *(b + 1) = cc2; - *(c + i * 2 + 0 + j * ldc) = cc1; - *(c + i * 2 + 1 + j * ldc) = cc2; - b += 2; - - for (k = i + 1; k < m; k ++){ + for (k = i + 1; k < m; k++) { + VLSSEG2_FLOAT(&vc1, &vc2, pc + k * 2, stride_ldc, vl); #ifndef CONJ - *(c + k * 2 + 0 + j * ldc) -= cc1 * *(a + k * 2 + 0) - cc2 * *(a + k * 2 + 
1); - *(c + k * 2 + 1 + j * ldc) -= cc1 * *(a + k * 2 + 1) + cc2 * *(a + k * 2 + 0); -#else - *(c + k * 2 + 0 + j * ldc) -= cc1 * *(a + k * 2 + 0) + cc2 * *(a + k * 2 + 1); - *(c + k * 2 + 1 + j * ldc) -= -cc1 * *(a + k * 2 + 1) + cc2 * *(a + k * 2 + 0); + vc1 = VFMACCVF_FLOAT(vc1, *(a + k * 2 + 1), vs2, vl); + vc1 = VFNMSACVF_FLOAT(vc1, *(a + k * 2 + 0), vs1, vl); + vc2 = VFNMSACVF_FLOAT(vc2, *(a + k * 2 + 1), vs1, vl); + vc2 = VFNMSACVF_FLOAT(vc2, *(a + k * 2 + 0), vs2, vl); +#else + vc1 = VFNMSACVF_FLOAT(vc1, *(a + k * 2 + 1), vs2, vl); + vc1 = VFNMSACVF_FLOAT(vc1, *(a + k * 2 + 0), vs1, vl); + vc2 = VFMACCVF_FLOAT(vc2, *(a + k * 2 + 1), vs1, vl); + vc2 = VFNMSACVF_FLOAT(vc2, *(a + k * 2 + 0), vs2, vl); #endif - } - - } - a += m * 2; - } -} - - -static inline void solve_N1(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa1, aa2; - FLOAT bb1, bb2; - FLOAT cc1, cc2; - FLOAT *pa, *pc; - - int i, j, k; - - size_t vl; - FLOAT_V_T va0, va1, vc0, vc1; - - ldc *= 2; - - for (i = 0; i < m; i++) { - - aa1 = *(a + i * 2 + 0); - aa2 = *(a + i * 2 + 1); - - for (j = 0; j < n; j ++) { - bb1 = *(c + i * 2 + 0 + j * ldc); - bb2 = *(c + i * 2 + 1 + j * ldc); - -#ifndef CONJ - cc1 = aa1 * bb1 - aa2 * bb2; - cc2 = aa1 * bb2 + aa2 * bb1; -#else - cc1 = aa1 * bb1 + aa2 * bb2; - cc2 = aa1 * bb2 - aa2 * bb1; -#endif - - *(b + 0) = cc1; - *(b + 1) = cc2; - *(c + i * 2 + 0 + j * ldc) = cc1; - *(c + i * 2 + 1 + j * ldc) = cc2; - b += 2; - - pa = a + (i + 1) * 2; - pc = c + j * ldc + (i + 1) * 2; - for (k = (m - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - VLSEG2_FLOAT(&va0, &va1, pa, vl); - VLSEG2_FLOAT(&vc0, &vc1, pc, vl); -#ifndef CONJ - vc0 = VFNMSACVF_FLOAT(vc0, cc1, va0); - vc0 = VFMACCVF_FLOAT(vc0, cc2, va1); - vc1 = VFNMSACVF_FLOAT(vc1, cc1, va1); - vc1 = VFNMSACVF_FLOAT(vc1, cc2, va0); -#else - vc0 = VFNMSACVF_FLOAT(vc0, cc1, va0); - vc0 = VFNMSACVF_FLOAT(vc0, cc2, va1); - vc1 = VFMACCVF_FLOAT(vc1, cc1, va1); - vc1 = VFNMSACVF_FLOAT(vc1, cc2, va0); -#endif - VSSEG2_FLOAT(pc, vc0, vc1, vl); - pa += vl * 2; - pc += vl * 2; + VSSSEG2_FLOAT(pc + k * 2, stride_ldc, vc1, vc2, vl); + } + pc += vl * ldc * 2; } - } + + a += m * 2; } - a += m * 2; - } } #endif @@ -714,7 +202,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, size_t vl = VSETVL_MAX; - //fprintf(stderr, "%s , %s, m = %4ld n = %4ld k = %4ld offset = %4ld\n", __FILE__, __FUNCTION__, m, n, k, offset); // Debug + //fprintf(stderr, "%s , %s, m = %4ld n = %4ld k = %4ld offset = %4ld\n", __FILE__, __FUNCTION__, m, n, k, offset); // Debug j = (n >> GEMM_UNROLL_N_SHIFT); diff --git a/kernel/riscv64/trsm_kernel_RN_rvv_v1.c b/kernel/riscv64/trsm_kernel_RN_rvv_v1.c index 41368be60..4751ae012 100644 --- a/kernel/riscv64/trsm_kernel_RN_rvv_v1.c +++ b/kernel/riscv64/trsm_kernel_RN_rvv_v1.c @@ -32,28 +32,32 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#define VSETVL_MAX vsetvlmax_e32m2() #define FLOAT_V_T vfloat32m2_t #define VLEV_FLOAT vle32_v_f32m2 -#define VLSEV_FLOAT vlse32_v_f32m2 -#define VLSEG2_FLOAT vlseg2e32_v_f32m2 -#define VSEV_FLOAT vse32_v_f32m2 #define VSSEV_FLOAT vsse32_v_f32m2 +#define VSEV_FLOAT vse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 #define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSSEG2_FLOAT vssseg2e32_v_f32m2 #define VFMACCVF_FLOAT vfmacc_vf_f32m2 #define VFNMSACVF_FLOAT vfnmsac_vf_f32m2 +#define VFMULVF_FLOAT vfmul_vf_f32m2 #else #define VSETVL(n) vsetvl_e64m2(n) #define VSETVL_MAX vsetvlmax_e64m2() #define FLOAT_V_T vfloat64m2_t #define VLEV_FLOAT vle64_v_f64m2 -#define VLSEV_FLOAT vlse64_v_f64m2 -#define VLSEG2_FLOAT vlseg2e64_v_f64m2 -#define VSEV_FLOAT vse64_v_f64m2 #define VSSEV_FLOAT vsse64_v_f64m2 +#define VSEV_FLOAT vse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 #define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSSEG2_FLOAT vssseg2e64_v_f64m2 +#define VFMVVF_FLOAT vfmv_v_f_f64m2 #define VFMACCVF_FLOAT vfmacc_vf_f64m2 #define VFNMSACVF_FLOAT vfnmsac_vf_f64m2 +#define VFMULVF_FLOAT vfmul_vf_f64m2 #endif - static FLOAT dm1 = -1.; #ifdef CONJ @@ -86,569 +90,99 @@ static FLOAT dm1 = -1.; #ifndef COMPLEX -#if GEMM_DEFAULT_UNROLL_N == 1 - -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa, bb; - FLOAT *pb, *pc; - BLASLONG stride_ldc = sizeof(FLOAT) * ldc; - int i, j, k; - size_t vl; - FLOAT_V_T vb, vc; - - for (i = 0; i < n; i++) - { - bb = *(b + i); - - for (j = 0; j < m; j ++) - { - aa = *(c + j + i * ldc); - aa *= bb; - *a = aa; - *(c + j + i * ldc) = aa; - a ++; - - pb = b + i + 1; - pc = c + j + (i + 1) *ldc; - for (k = (n - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc = VLSEV_FLOAT(pc, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - vc = VFNMSACVF_FLOAT(vc, aa, vb, vl); - VSSEV_FLOAT(pc, stride_ldc, vc, vl); - pb += vl; - pc ++; - } - } - b += n; - } -} - -#elif GEMM_DEFAULT_UNROLL_N == 2 - -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa0, aa1, bb; - FLOAT *pb, *pc; - FLOAT *pa0, *pa1, *pc0, *pc1; - BLASLONG stride_ldc = sizeof(FLOAT) * ldc; - int i, j, k; - size_t vl; - FLOAT_V_T vb, vc0, vc1; - - for (i = 0; i < n; i++) - { - bb = *(b + i); - pc = c + i * ldc; - for (j = 0; j < m/2; j ++) - { - pa0 = pc + j * 2; - pa1 = pc + j * 2 + 1; - aa0 = *pa0 * bb; - aa1 = *pa1 * bb; - - *pa0 = aa0; - *pa1 = aa1; - *a = aa0; - *(a + 1)= aa1; - a += 2; - - pb = b + i + 1; - pc0 = pa0 + ldc; - pc1 = pa1 + ldc; - for (k = (n - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vc1 = VLSEV_FLOAT(pc1, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - vc1 = VFNMSACVF_FLOAT(vc1, aa1, vb, vl); - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - VSSEV_FLOAT(pc1, stride_ldc, vc1, vl); - pb += vl; - pc0++; - pc1++; - } - } - pc += (m/2)*2; - if (m & 1) - { - pa0 = pc; - aa0 = *pa0 * bb; - - *pa0 = aa0; - *a = aa0; - a += 1; - - pb = b + i + 1; - pc0 = pa0 + ldc; - for (k = (n - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - pb += vl; - pc0++; - } - } - b += n; - } -} - -#elif GEMM_DEFAULT_UNROLL_N == 4 - static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, 
BLASLONG ldc) { FLOAT bb; - FLOAT aa0, aa1, aa2, aa3; - FLOAT *pb, *pc; - FLOAT *pa0, *pa1, *pa2, *pa3; - FLOAT *pc0, *pc1, *pc2, *pc3; - BLASLONG stride_ldc = sizeof(FLOAT) * ldc; + FLOAT *pci, *pcj; + int i, j, k; + FLOAT_V_T va, vc; + size_t vl; - FLOAT_V_T vb, vc0, vc1, vc2, vc3; + for (i = 0; i < n; i++) { - for (i = 0; i < n; i++) - { bb = *(b + i); - pc = c + i * ldc; - for (j = 0; j < m/4; j ++) - { - pa0 = pc + j * 4; - pa1 = pa0 + 1; - pa2 = pa1 + 1; - pa3 = pa2 + 1; - - aa0 = *pa0 * bb; - aa1 = *pa1 * bb; - aa2 = *pa2 * bb; - aa3 = *pa3 * bb; - - *pa0 = aa0; - *pa1 = aa1; - *pa2 = aa2; - *pa3 = aa3; - - *a = aa0; - *(a + 1)= aa1; - *(a + 2)= aa2; - *(a + 3)= aa3; - - a += 4; - - pb = b + i + 1; - pc0 = pa0 + ldc; - pc1 = pa1 + ldc; - pc2 = pa2 + ldc; - pc3 = pa3 + ldc; - for (k = (n - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vc1 = VLSEV_FLOAT(pc1, stride_ldc, vl); - vc2 = VLSEV_FLOAT(pc2, stride_ldc, vl); - vc3 = VLSEV_FLOAT(pc3, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - vc1 = VFNMSACVF_FLOAT(vc1, aa1, vb, vl); - vc2 = VFNMSACVF_FLOAT(vc2, aa2, vb, vl); - vc3 = VFNMSACVF_FLOAT(vc3, aa3, vb, vl); - - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - VSSEV_FLOAT(pc1, stride_ldc, vc1, vl); - VSSEV_FLOAT(pc2, stride_ldc, vc2, vl); - VSSEV_FLOAT(pc3, stride_ldc, vc3, vl); - - pb += vl; - pc0++; - pc1++; - pc2++; - pc3++; - } - } - pc += (m/4)*4; - - if (m & 2) - { - pa0 = pc; - pa1 = pa0 + 1; - - aa0 = *pa0 * bb; - aa1 = *pa1 * bb; - - *pa0 = aa0; - *pa1 = aa1; - - *a = aa0; - *(a + 1)= aa1; - - a += 2; - - pb = b + i + 1; - pc0 = pa0 + ldc; - pc1 = pa1 + ldc; - for (k = (n - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vc1 = VLSEV_FLOAT(pc1, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - vc1 = VFNMSACVF_FLOAT(vc1, aa1, vb, vl); - - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - VSSEV_FLOAT(pc1, stride_ldc, vc1, vl); - - pb += vl; - pc0++; - pc1++; - } - pc += 2; - } - - if (m & 1) - { - pa0 = pc; - aa0 = *pa0 * bb; - - *pa0 = aa0; - *a = aa0; - a += 1; - - pb = b + i + 1; - pc0 = pa0 + ldc; - for (k = (n - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - pb += vl; - pc0++; + pci = c + i * ldc; + pcj = c; + for (j = m; j > 0; j -= vl) { + vl = VSETVL(j); + va = VLEV_FLOAT(pci, vl); + va = VFMULVF_FLOAT(va, bb, vl); + VSEV_FLOAT(a, va, vl); + VSEV_FLOAT(pci, va, vl); + a += vl; + pci += vl; + for (k = i + 1; k < n; k ++){ + vc = VLEV_FLOAT(pcj + k * ldc, vl); + vc = VFNMSACVF_FLOAT(vc, *(b + k), va, vl); + VSEV_FLOAT(pcj + k * ldc, vc, vl); } + pcj += vl; } b += n; } } -#elif GEMM_DEFAULT_UNROLL_N == 8 +#else static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - FLOAT bb; - FLOAT aa0, aa1, aa2, aa3, aa4, aa5, aa6, aa7; - FLOAT *pb, *pc; - FLOAT *pa0, *pa1, *pa2, *pa3, *pa4, *pa5, *pa6, *pa7; - FLOAT *pc0, *pc1, *pc2, *pc3, *pc4, *pc5, *pc6, *pc7; - BLASLONG stride_ldc = sizeof(FLOAT) * ldc; + FLOAT bb1, bb2; + + FLOAT *pci, *pcj; + int i, j, k; + + FLOAT_V_T va1, va2, vs1, vs2, vc1, vc2; + size_t vl; - FLOAT_V_T vb, vc0, vc1, vc2, vc3, vc4, vc5, vc6, vc7; - for (i = 0; i < n; i++) - { - bb = *(b + i); - pc = c + i * ldc; - for (j = 0; j < m/8; j ++) - { - pa0 = pc + j * 8; - pa1 = pa0 + 1; - pa2 = pa1 + 1; - pa3 = pa2 + 1; - 
pa4 = pa3 + 1; - pa5 = pa4 + 1; - pa6 = pa5 + 1; - pa7 = pa6 + 1; - - aa0 = *pa0 * bb; - aa1 = *pa1 * bb; - aa2 = *pa2 * bb; - aa3 = *pa3 * bb; - aa4 = *pa4 * bb; - aa5 = *pa5 * bb; - aa6 = *pa6 * bb; - aa7 = *pa7 * bb; + for (i = 0; i < n; i++) { - *pa0 = aa0; - *pa1 = aa1; - *pa2 = aa2; - *pa3 = aa3; - *pa4 = aa4; - *pa5 = aa5; - *pa6 = aa6; - *pa7 = aa7; + bb1 = *(b + i * 2 + 0); + bb2 = *(b + i * 2 + 1); - *a = aa0; - *(a + 1)= aa1; - *(a + 2)= aa2; - *(a + 3)= aa3; - *(a + 4)= aa4; - *(a + 5)= aa5; - *(a + 6)= aa6; - *(a + 7)= aa7; - - a += 8; - - pb = b + i + 1; - pc0 = pa0 + ldc; - pc1 = pa1 + ldc; - pc2 = pa2 + ldc; - pc3 = pa3 + ldc; - pc4 = pa4 + ldc; - pc5 = pa5 + ldc; - pc6 = pa6 + ldc; - pc7 = pa7 + ldc; - for (k = (n - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vc1 = VLSEV_FLOAT(pc1, stride_ldc, vl); - vc2 = VLSEV_FLOAT(pc2, stride_ldc, vl); - vc3 = VLSEV_FLOAT(pc3, stride_ldc, vl); - vc4 = VLSEV_FLOAT(pc4, stride_ldc, vl); - vc5 = VLSEV_FLOAT(pc5, stride_ldc, vl); - vc6 = VLSEV_FLOAT(pc6, stride_ldc, vl); - vc7 = VLSEV_FLOAT(pc7, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - vc1 = VFNMSACVF_FLOAT(vc1, aa1, vb, vl); - vc2 = VFNMSACVF_FLOAT(vc2, aa2, vb, vl); - vc3 = VFNMSACVF_FLOAT(vc3, aa3, vb, vl); - vc4 = VFNMSACVF_FLOAT(vc4, aa4, vb, vl); - vc5 = VFNMSACVF_FLOAT(vc5, aa5, vb, vl); - vc6 = VFNMSACVF_FLOAT(vc6, aa6, vb, vl); - vc7 = VFNMSACVF_FLOAT(vc7, aa7, vb, vl); - - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - VSSEV_FLOAT(pc1, stride_ldc, vc1, vl); - VSSEV_FLOAT(pc2, stride_ldc, vc2, vl); - VSSEV_FLOAT(pc3, stride_ldc, vc3, vl); - VSSEV_FLOAT(pc4, stride_ldc, vc4, vl); - VSSEV_FLOAT(pc5, stride_ldc, vc5, vl); - VSSEV_FLOAT(pc6, stride_ldc, vc6, vl); - VSSEV_FLOAT(pc7, stride_ldc, vc7, vl); - - pb += vl; - pc0++; - pc1++; - pc2++; - pc3++; - pc4++; - pc5++; - pc6++; - pc7++; - } - } - pc += (m/8)*8; - - if (m & 4) - { - pa0 = pc; - pa1 = pa0 + 1; - pa2 = pa1 + 1; - pa3 = pa2 + 1; - - aa0 = *pa0 * bb; - aa1 = *pa1 * bb; - aa2 = *pa2 * bb; - aa3 = *pa3 * bb; - - *pa0 = aa0; - *pa1 = aa1; - *pa2 = aa2; - *pa3 = aa3; - - *a = aa0; - *(a + 1)= aa1; - *(a + 2)= aa2; - *(a + 3)= aa3; - - a += 4; - - pb = b + i + 1; - pc0 = pa0 + ldc; - pc1 = pa1 + ldc; - pc2 = pa2 + ldc; - pc3 = pa3 + ldc; - for (k = (n - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vc1 = VLSEV_FLOAT(pc1, stride_ldc, vl); - vc2 = VLSEV_FLOAT(pc2, stride_ldc, vl); - vc3 = VLSEV_FLOAT(pc3, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - vc1 = VFNMSACVF_FLOAT(vc1, aa1, vb, vl); - vc2 = VFNMSACVF_FLOAT(vc2, aa2, vb, vl); - vc3 = VFNMSACVF_FLOAT(vc3, aa3, vb, vl); - - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - VSSEV_FLOAT(pc1, stride_ldc, vc1, vl); - VSSEV_FLOAT(pc2, stride_ldc, vc2, vl); - VSSEV_FLOAT(pc3, stride_ldc, vc3, vl); - - pb += vl; - pc0++; - pc1++; - pc2++; - pc3++; - } - pc += 4; - } - - if (m & 2) - { - pa0 = pc; - pa1 = pa0 + 1; - - aa0 = *pa0 * bb; - aa1 = *pa1 * bb; - - *pa0 = aa0; - *pa1 = aa1; - - *a = aa0; - *(a + 1)= aa1; - - a += 2; - - pb = b + i + 1; - pc0 = pa0 + ldc; - pc1 = pa1 + ldc; - for (k = (n - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vc1 = VLSEV_FLOAT(pc1, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - vc1 = VFNMSACVF_FLOAT(vc1, aa1, vb, vl); - - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - VSSEV_FLOAT(pc1, stride_ldc, vc1, 
vl); - - pb += vl; - pc0++; - pc1++; - } - pc += 2; - } - - if (m & 1) - { - pa0 = pc; - aa0 = *pa0 * bb; - - *pa0 = aa0; - *a = aa0; - a += 1; - - pb = b + i + 1; - pc0 = pa0 + ldc; - for (k = (n - i - 1); k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - pb += vl; - pc0++; - } - } - b += n; - } -} -#else -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa, bb; - - int i, j, k; - - for (i = 0; i < n; i++) { - - bb = *(b + i); - - for (j = 0; j < m; j ++) { - aa = *(c + j + i * ldc); - aa *= bb; - *a = aa; - *(c + j + i * ldc) = aa; - a ++; - - for (k = i + 1; k < n; k ++){ - *(c + j + k * ldc) -= aa * *(b + k); - } - - } - b += n; - } -} - -#endif - -#else - -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa1, aa2; - FLOAT bb1, bb2; - FLOAT cc1, cc2; - - int i, j, k; - - ldc *= 2; - - for (i = 0; i < n; i++) { - - bb1 = *(b + i * 2 + 0); - bb2 = *(b + i * 2 + 1); - - for (j = 0; j < m; j ++) { - aa1 = *(c + j * 2 + 0 + i * ldc); - aa2 = *(c + j * 2 + 1 + i * ldc); + pci = c + i * ldc * 2; + pcj = c; + for (j = m; j > 0; j -= vl) { + vl = VSETVL(j); + VLSEG2_FLOAT(&va1, &va2, pci, vl); #ifndef CONJ - cc1 = aa1 * bb1 - aa2 * bb2; - cc2 = aa1 * bb2 + aa2 * bb1; + vs1 = VFMULVF_FLOAT(va1, bb1, vl); + vs1 = VFNMSACVF_FLOAT(vs1, bb2, va2, vl); + vs2 = VFMULVF_FLOAT(va1, bb2, vl); + vs2 = VFMACCVF_FLOAT(vs2, bb1, va2, vl); #else - cc1 = aa1 * bb1 + aa2 * bb2; - cc2 = -aa1 * bb2 + aa2 * bb1; + vs1 = VFMULVF_FLOAT(va1, bb1, vl); + vs1 = VFMACCVF_FLOAT(vs1, bb2, va2, vl); + vs2 = VFMULVF_FLOAT(va2, bb1, vl); + vs2 = VFNMSACVF_FLOAT(vs2, bb2, va1, vl); #endif + VSSEG2_FLOAT(a, vs1, vs2, vl); + VSSEG2_FLOAT(pci, vs1, vs2, vl); + a += vl * 2; + pci += vl * 2; - *(a + 0) = cc1; - *(a + 1) = cc2; - *(c + j * 2 + 0 + i * ldc) = cc1; - *(c + j * 2 + 1 + i * ldc) = cc2; - a += 2; - - for (k = i + 1; k < n; k ++){ + for (k = i + 1; k < n; k ++){ + VLSEG2_FLOAT(&vc1, &vc2, pcj + k * ldc * 2, vl); #ifndef CONJ - *(c + j * 2 + 0 + k * ldc) -= cc1 * *(b + k * 2 + 0) - cc2 * *(b + k * 2 + 1); - *(c + j * 2 + 1 + k * ldc) -= cc1 * *(b + k * 2 + 1) + cc2 * *(b + k * 2 + 0); + vc1 = VFMACCVF_FLOAT(vc1, *(b + k * 2 + 1), vs2, vl); + vc1 = VFNMSACVF_FLOAT(vc1, *(b + k * 2 + 0), vs1, vl); + vc2 = VFNMSACVF_FLOAT(vc2, *(b + k * 2 + 1), vs1, vl); + vc2 = VFNMSACVF_FLOAT(vc2, *(b + k * 2 + 0), vs2, vl); #else - *(c + j * 2 + 0 + k * ldc) -= cc1 * *(b + k * 2 + 0) + cc2 * *(b + k * 2 + 1); - *(c + j * 2 + 1 + k * ldc) -= - cc1 * *(b + k * 2 + 1) + cc2 * *(b + k * 2 + 0); + vc1 = VFNMSACVF_FLOAT(vc1, *(b + k * 2 + 0), vs1, vl); + vc1 = VFNMSACVF_FLOAT(vc1, *(b + k * 2 + 1), vs2, vl); + vc2 = VFMACCVF_FLOAT(vc2, *(b + k * 2 + 1), vs1, vl); + vc2 = VFNMSACVF_FLOAT(vc2, *(b + k * 2 + 0), vs2, vl); #endif - } - + VSSEG2_FLOAT(pcj + k * ldc * 2, vc1, vc2, vl); + } + pcj += vl * 2; + } + b += n * 2; } - b += n * 2; - } } #endif @@ -666,7 +200,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, size_t vl = VSETVL_MAX; - //fprintf(stderr, "%s , %s, m = %4ld n = %4ld k = %4ld offset = %4ld\n", __FILE__, __FUNCTION__, m, n, k, offset); // Debug + //fprintf(stderr, "%s , %s, m = %4ld n = %4ld k = %4ld offset = %4ld\n", __FILE__, __FUNCTION__, m, n, k, offset); // Debug j = (n >> GEMM_UNROLL_N_SHIFT); diff --git a/kernel/riscv64/trsm_kernel_RT_rvv_v1.c 
b/kernel/riscv64/trsm_kernel_RT_rvv_v1.c index 459c1663a..93a9e6916 100644 --- a/kernel/riscv64/trsm_kernel_RT_rvv_v1.c +++ b/kernel/riscv64/trsm_kernel_RT_rvv_v1.c @@ -32,25 +32,24 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define VSETVL_MAX vsetvlmax_e32m2() #define FLOAT_V_T vfloat32m2_t #define VLEV_FLOAT vle32_v_f32m2 -#define VLSEV_FLOAT vlse32_v_f32m2 -#define VLSEG2_FLOAT vlseg2e32_v_f32m2 #define VSEV_FLOAT vse32_v_f32m2 -#define VSSEV_FLOAT vsse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 #define VSSEG2_FLOAT vsseg2e32_v_f32m2 #define VFMACCVF_FLOAT vfmacc_vf_f32m2 #define VFNMSACVF_FLOAT vfnmsac_vf_f32m2 +#define VFMULVF_FLOAT vfmul_vf_f32m2 #else #define VSETVL(n) vsetvl_e64m2(n) #define VSETVL_MAX vsetvlmax_e64m2() #define FLOAT_V_T vfloat64m2_t #define VLEV_FLOAT vle64_v_f64m2 -#define VLSEV_FLOAT vlse64_v_f64m2 -#define VLSEG2_FLOAT vlseg2e64_v_f64m2 #define VSEV_FLOAT vse64_v_f64m2 -#define VSSEV_FLOAT vsse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 #define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VFMVVF_FLOAT vfmv_v_f_f64m2 #define VFMACCVF_FLOAT vfmacc_vf_f64m2 #define VFNMSACVF_FLOAT vfnmsac_vf_f64m2 +#define VFMULVF_FLOAT vfmul_vf_f64m2 #endif @@ -86,497 +85,38 @@ static FLOAT dm1 = -1.; #ifndef COMPLEX -#if GEMM_DEFAULT_UNROLL_N == 1 static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - FLOAT aa, bb; - FLOAT *pb, *pc; - BLASLONG stride_ldc = sizeof(FLOAT) * ldc; - - int i, j, k; - size_t vl; - FLOAT_V_T vb, vc; - - a += (n - 1) * m; - b += (n - 1) * n; - - for (i = n - 1; i >= 0; i--) { - - bb = *(b + i); - - for (j = 0; j < m; j ++) { - aa = *(c + j + i * ldc); - aa *= bb; - *a = aa; - *(c + j + i * ldc) = aa; - a ++; - - pb = b; - pc = c + j; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc = VLSEV_FLOAT(pc, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - vc = VFNMSACVF_FLOAT(vc, aa, vb, vl); - VSSEV_FLOAT(pc, stride_ldc, vc, vl); - pb += vl; - pc++; - } - } - b -= n; - a -= 2 * m; - } - -} -#elif GEMM_DEFAULT_UNROLL_N == 2 - -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa0, aa1, bb; - FLOAT *pb, *pc; - FLOAT *pa0, *pa1, *pc0, *pc1; - BLASLONG stride_ldc = sizeof(FLOAT) * ldc; - int i, j, k; - size_t vl; - FLOAT_V_T vb, vc0, vc1; - - a += (n - 1) * m; - b += (n - 1) * n; - - for (i = n - 1; i >= 0; i--) - { - bb = *(b + i); - pc = c + i * ldc; - for (j = 0; j < m/2; j ++) - { - pa0 = pc + j * 2; - pa1 = pc + j * 2 + 1; - aa0 = *pa0 * bb; - aa1 = *pa1 * bb; - - *pa0 = aa0; - *pa1 = aa1; - *a = aa0; - *(a + 1)= aa1; - a += 2; - - pb = b; - pc0 = c + j * 2; - pc1 = pc0 + 1; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vc1 = VLSEV_FLOAT(pc1, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - vc1 = VFNMSACVF_FLOAT(vc1, aa1, vb, vl); - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - VSSEV_FLOAT(pc1, stride_ldc, vc1, vl); - pb += vl; - pc0++; - pc1++; - } - } - pc += (m/2)*2; - - if (m & 1) - { - pa0 = pc; - aa0 = *pa0 * bb; - - *pa0 = aa0; - *a = aa0; - a += 1; - - pb = b; - pc0 = pc - i * ldc; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - pb += vl; - pc0++; - } - } - b -= n; - a -= 2 * m; - } -} - -#elif GEMM_DEFAULT_UNROLL_N == 4 - -static inline void solve(BLASLONG m, BLASLONG 
n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa0, aa1, aa2, aa3; FLOAT bb; - FLOAT *pb, *pc; - FLOAT *pa0, *pa1, *pa2, *pa3; - FLOAT *pc0, *pc1, *pc2, *pc3; - BLASLONG stride_ldc = sizeof(FLOAT) * ldc; + FLOAT *pci, *pcj; + int i, j, k; + FLOAT_V_T va, vc; + size_t vl; - FLOAT_V_T vb, vc0, vc1, vc2, vc3; a += (n - 1) * m; b += (n - 1) * n; - for (i = n - 1; i >= 0; i--) - { + for (i = n - 1; i >= 0; i--) { + bb = *(b + i); - pc = c + i * ldc; - for (j = 0; j < m/4; j ++) - { - pa0 = pc + j * 4; - pa1 = pa0 + 1; - pa2 = pa1 + 1; - pa3 = pa2 + 1; - - aa0 = *pa0 * bb; - aa1 = *pa1 * bb; - aa2 = *pa2 * bb; - aa3 = *pa3 * bb; - - *pa0 = aa0; - *pa1 = aa1; - *pa2 = aa2; - *pa3 = aa3; - - *a = aa0; - *(a + 1)= aa1; - *(a + 2)= aa2; - *(a + 3)= aa3; - a += 4; - - pb = b; - pc0 = c + j * 4; - pc1 = pc0 + 1; - pc2 = pc1 + 1; - pc3 = pc2 + 1; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vc1 = VLSEV_FLOAT(pc1, stride_ldc, vl); - vc2 = VLSEV_FLOAT(pc2, stride_ldc, vl); - vc3 = VLSEV_FLOAT(pc3, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - vc1 = VFNMSACVF_FLOAT(vc1, aa1, vb, vl); - vc2 = VFNMSACVF_FLOAT(vc2, aa2, vb, vl); - vc3 = VFNMSACVF_FLOAT(vc3, aa3, vb, vl); - - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - VSSEV_FLOAT(pc1, stride_ldc, vc1, vl); - VSSEV_FLOAT(pc2, stride_ldc, vc2, vl); - VSSEV_FLOAT(pc3, stride_ldc, vc3, vl); - - pb += vl; - pc0++; - pc1++; - pc2++; - pc3++; - } - } - pc += (m/4)*4; - - if (m & 2) - { - pa0 = pc + j * 2; - pa1 = pa0 + 1; - - aa0 = *pa0 * bb; - aa1 = *pa1 * bb; - - *pa0 = aa0; - *pa1 = aa1; - - *a = aa0; - *(a + 1)= aa1; - a += 2; - - pb = b; - pc0 = c + j * 4; - pc1 = pc0 + 1; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vc1 = VLSEV_FLOAT(pc1, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - vc1 = VFNMSACVF_FLOAT(vc1, aa1, vb, vl); - - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - VSSEV_FLOAT(pc1, stride_ldc, vc1, vl); - - pb += vl; - pc0++; - pc1++; - } - pc += 2; - } - - if (m & 1) - { - pa0 = pc; - aa0 = *pa0 * bb; - - *pa0 = aa0; - *a = aa0; - a += 1; - - pb = b; - pc0 = pc - i * ldc; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - pb += vl; - pc0++; - } - } - b -= n; - a -= 2 * m; - } -} -#elif GEMM_DEFAULT_UNROLL_N == 8 - -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa0, aa1, aa2, aa3, aa4, aa5, aa6, aa7; - FLOAT bb; - FLOAT *pb, *pc; - FLOAT *pa0, *pa1, *pa2, *pa3, *pa4, *pa5, *pa6, *pa7; - FLOAT *pc0, *pc1, *pc2, *pc3, *pc4, *pc5, *pc6, *pc7; - BLASLONG stride_ldc = sizeof(FLOAT) * ldc; - int i, j, k; - size_t vl; - FLOAT_V_T vb, vc0, vc1, vc2, vc3, vc4, vc5, vc6, vc7; - - a += (n - 1) * m; - b += (n - 1) * n; - - for (i = n - 1; i >= 0; i--) - { - bb = *(b + i); - pc = c + i * ldc; - for (j = 0; j < m/8; j ++) - { - pa0 = pc + j * 8; - pa1 = pa0 + 1; - pa2 = pa1 + 1; - pa3 = pa2 + 1; - pa4 = pa3 + 1; - pa5 = pa4 + 1; - pa6 = pa5 + 1; - pa7 = pa6 + 1; - - aa0 = *pa0 * bb; - aa1 = *pa1 * bb; - aa2 = *pa2 * bb; - aa3 = *pa3 * bb; - aa4 = *pa4 * bb; - aa5 = *pa5 * bb; - aa6 = *pa6 * bb; - aa7 = *pa7 * bb; - - *pa0 = aa0; - *pa1 = aa1; - *pa2 = aa2; - *pa3 = aa3; - *pa4 = aa4; - *pa5 = aa5; - *pa6 = aa6; - *pa7 = aa7; - - *a = aa0; - *(a + 1)= aa1; - *(a 
+ 2)= aa2; - *(a + 3)= aa3; - *(a + 4)= aa4; - *(a + 5)= aa5; - *(a + 6)= aa6; - *(a + 7)= aa7; - a += 8; - - pb = b; - pc0 = c + j * 8; - pc1 = pc0 + 1; - pc2 = pc1 + 1; - pc3 = pc2 + 1; - pc4 = pc3 + 1; - pc5 = pc4 + 1; - pc6 = pc5 + 1; - pc7 = pc6 + 1; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vc1 = VLSEV_FLOAT(pc1, stride_ldc, vl); - vc2 = VLSEV_FLOAT(pc2, stride_ldc, vl); - vc3 = VLSEV_FLOAT(pc3, stride_ldc, vl); - vc4 = VLSEV_FLOAT(pc4, stride_ldc, vl); - vc5 = VLSEV_FLOAT(pc5, stride_ldc, vl); - vc6 = VLSEV_FLOAT(pc6, stride_ldc, vl); - vc7 = VLSEV_FLOAT(pc7, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - vc1 = VFNMSACVF_FLOAT(vc1, aa1, vb, vl); - vc2 = VFNMSACVF_FLOAT(vc2, aa2, vb, vl); - vc3 = VFNMSACVF_FLOAT(vc3, aa3, vb, vl); - vc4 = VFNMSACVF_FLOAT(vc4, aa4, vb, vl); - vc5 = VFNMSACVF_FLOAT(vc5, aa5, vb, vl); - vc6 = VFNMSACVF_FLOAT(vc6, aa6, vb, vl); - vc7 = VFNMSACVF_FLOAT(vc7, aa7, vb, vl); - - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - VSSEV_FLOAT(pc1, stride_ldc, vc1, vl); - VSSEV_FLOAT(pc2, stride_ldc, vc2, vl); - VSSEV_FLOAT(pc3, stride_ldc, vc3, vl); - VSSEV_FLOAT(pc4, stride_ldc, vc4, vl); - VSSEV_FLOAT(pc5, stride_ldc, vc5, vl); - VSSEV_FLOAT(pc6, stride_ldc, vc6, vl); - VSSEV_FLOAT(pc7, stride_ldc, vc7, vl); - - pb += vl; - pc0++; - pc1++; - pc2++; - pc3++; - pc4++; - pc5++; - pc6++; - pc7++; - } - } - pc += (m/8)*8; - - if (m & 4) - { - pa0 = pc; - pa1 = pa0 + 1; - pa2 = pa1 + 1; - pa3 = pa2 + 1; - - aa0 = *pa0 * bb; - aa1 = *pa1 * bb; - aa2 = *pa2 * bb; - aa3 = *pa3 * bb; - - *pa0 = aa0; - *pa1 = aa1; - *pa2 = aa2; - *pa3 = aa3; - - *a = aa0; - *(a + 1)= aa1; - *(a + 2)= aa2; - *(a + 3)= aa3; - a += 4; - - pb = b; - pc0 = pc - i * ldc; - pc1 = pc0 + 1; - pc2 = pc1 + 1; - pc3 = pc2 + 1; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vc1 = VLSEV_FLOAT(pc1, stride_ldc, vl); - vc2 = VLSEV_FLOAT(pc2, stride_ldc, vl); - vc3 = VLSEV_FLOAT(pc3, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - vc1 = VFNMSACVF_FLOAT(vc1, aa1, vb, vl); - vc2 = VFNMSACVF_FLOAT(vc2, aa2, vb, vl); - vc3 = VFNMSACVF_FLOAT(vc3, aa3, vb, vl); - - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - VSSEV_FLOAT(pc1, stride_ldc, vc1, vl); - VSSEV_FLOAT(pc2, stride_ldc, vc2, vl); - VSSEV_FLOAT(pc3, stride_ldc, vc3, vl); - - pb += vl; - pc0++; - pc1++; - pc2++; - pc3++; - } - pc += 4; - } - - if (m & 2) - { - pa0 = pc; - pa1 = pa0 + 1; - - aa0 = *pa0 * bb; - aa1 = *pa1 * bb; - - *pa0 = aa0; - *pa1 = aa1; - - *a = aa0; - *(a + 1)= aa1; - a += 2; - - pb = b; - pc0 = pc - i * ldc; - pc1 = pc0 + 1; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vc1 = VLSEV_FLOAT(pc1, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - vc1 = VFNMSACVF_FLOAT(vc1, aa1, vb, vl); - - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - VSSEV_FLOAT(pc1, stride_ldc, vc1, vl); - - pb += vl; - pc0++; - pc1++; - } - pc += 2; - } - - if (m & 1) - { - pa0 = pc; - aa0 = *pa0 * bb; - - *pa0 = aa0; - *a = aa0; - a += 1; - - pb = b; - pc0 = pc - i * ldc; - for (k = i; k > 0; k -= vl) - { - vl = VSETVL(k); - vc0 = VLSEV_FLOAT(pc0, stride_ldc, vl); - vb = VLEV_FLOAT(pb, vl); - vc0 = VFNMSACVF_FLOAT(vc0, aa0, vb, vl); - VSSEV_FLOAT(pc0, stride_ldc, vc0, vl); - pb += vl; - pc0++; + pci = c + i * ldc; + pcj = c; + for (j = m; j > 0; j -= vl) { + vl = VSETVL(j); + va = VLEV_FLOAT(pci, vl); + va = 
VFMULVF_FLOAT(va, bb, vl); + VSEV_FLOAT(a, va, vl); + VSEV_FLOAT(pci, va, vl); + a += vl; + pci += vl; + for (k = 0; k < i; k ++){ + vc = VLEV_FLOAT(pcj + k * ldc, vl); + vc = VFNMSACVF_FLOAT(vc, *(b + k), va, vl); + VSEV_FLOAT(pcj + k * ldc, vc, vl); } + pcj += vl; } b -= n; a -= 2 * m; @@ -587,92 +127,65 @@ static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, B static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - FLOAT aa, bb; + FLOAT bb1, bb2; - int i, j, k; + FLOAT *pci, *pcj; - a += (n - 1) * m; - b += (n - 1) * n; + int i, j, k; - for (i = n - 1; i >= 0; i--) { + FLOAT_V_T va1, va2, vs1, vs2, vc1, vc2; - bb = *(b + i); + size_t vl; - for (j = 0; j < m; j ++) { - aa = *(c + j + i * ldc); - aa *= bb; - *a = aa; - *(c + j + i * ldc) = aa; - a ++; + a += (n - 1) * m * 2; + b += (n - 1) * n * 2; - for (k = 0; k < i; k ++){ - *(c + j + k * ldc) -= aa * *(b + k); - } + for (i = n - 1; i >= 0; i--) { - } - b -= n; - a -= 2 * m; - } - -} - -#endif - -#else - -static inline void solve(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) { - - FLOAT aa1, aa2; - FLOAT bb1, bb2; - FLOAT cc1, cc2; - - int i, j, k; - - ldc *= 2; - - a += (n - 1) * m * 2; - b += (n - 1) * n * 2; - - for (i = n - 1; i >= 0; i--) { - - bb1 = *(b + i * 2 + 0); - bb2 = *(b + i * 2 + 1); - - for (j = 0; j < m; j ++) { - - aa1 = *(c + j * 2 + 0 + i * ldc); - aa2 = *(c + j * 2 + 1 + i * ldc); + bb1 = *(b + i * 2 + 0); + bb2 = *(b + i * 2 + 1); + pci = c + i * ldc * 2; + pcj = c; + for (j = m; j > 0; j -= vl) { + vl = VSETVL(j); + VLSEG2_FLOAT(&va1, &va2, pci, vl); #ifndef CONJ - cc1 = aa1 * bb1 - aa2 * bb2; - cc2 = aa1 * bb2 + aa2 * bb1; + vs1 = VFMULVF_FLOAT(va1, bb1, vl); + vs1 = VFNMSACVF_FLOAT(vs1, bb2, va2, vl); + vs2 = VFMULVF_FLOAT(va1, bb2, vl); + vs2 = VFMACCVF_FLOAT(vs2, bb1, va2, vl); #else - cc1 = aa1 * bb1 + aa2 * bb2; - cc2 = - aa1 * bb2 + aa2 * bb1; + vs1 = VFMULVF_FLOAT(va1, bb1, vl); + vs1 = VFMACCVF_FLOAT(vs1, bb2, va2, vl); + vs2 = VFMULVF_FLOAT(va2, bb1, vl); + vs2 = VFNMSACVF_FLOAT(vs2, bb2, va1, vl); #endif + VSSEG2_FLOAT(a, vs1, vs2, vl); + VSSEG2_FLOAT(pci, vs1, vs2, vl); + a += vl * 2; + pci += vl * 2; - *(a + 0) = cc1; - *(a + 1) = cc2; - - *(c + j * 2 + 0 + i * ldc) = cc1; - *(c + j * 2 + 1 + i * ldc) = cc2; - a += 2; - - for (k = 0; k < i; k ++){ + for (k = 0; k < i; k ++){ + VLSEG2_FLOAT(&vc1, &vc2, pcj + k * ldc * 2, vl); #ifndef CONJ - *(c + j * 2 + 0 + k * ldc) -= cc1 * *(b + k * 2 + 0) - cc2 * *(b + k * 2 + 1); - *(c + j * 2 + 1 + k * ldc) -= cc1 * *(b + k * 2 + 1) + cc2 * *(b + k * 2 + 0); + vc1 = VFMACCVF_FLOAT(vc1, *(b + k * 2 + 1), vs2, vl); + vc1 = VFNMSACVF_FLOAT(vc1, *(b + k * 2 + 0), vs1, vl); + vc2 = VFNMSACVF_FLOAT(vc2, *(b + k * 2 + 1), vs1, vl); + vc2 = VFNMSACVF_FLOAT(vc2, *(b + k * 2 + 0), vs2, vl); #else - *(c + j * 2 + 0 + k * ldc) -= cc1 * *(b + k * 2 + 0) + cc2 * *(b + k * 2 + 1); - *(c + j * 2 + 1 + k * ldc) -= -cc1 * *(b + k * 2 + 1) + cc2 * *(b + k * 2 + 0); + vc1 = VFNMSACVF_FLOAT(vc1, *(b + k * 2 + 0), vs1, vl); + vc1 = VFNMSACVF_FLOAT(vc1, *(b + k * 2 + 1), vs2, vl); + vc2 = VFMACCVF_FLOAT(vc2, *(b + k * 2 + 1), vs1, vl); + vc2 = VFNMSACVF_FLOAT(vc2, *(b + k * 2 + 0), vs2, vl); #endif - } - + VSSEG2_FLOAT(pcj + k * ldc * 2, vc1, vc2, vl); + } + pcj += vl * 2; + } + b -= n * 2; + a -= 4 * m; } - b -= n * 2; - a -= 4 * m; - } - } #endif @@ -689,7 +202,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, size_t vl = VSETVL_MAX; - //fprintf(stderr, "%s , %s, m = %4ld n = %4ld 
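For reference, each lane of the vectorized complex solve above computes the same thing as the scalar code it replaces: cc = aa * bb for the current row, then c[k] -= cc * b[k] for the remaining rows. A minimal scalar sketch of that per-lane arithmetic (non-CONJ case; the struct, helper and values are illustrative, not part of the patch):

    /* Scalar reference for one lane of the vectorized complex solve step. */
    #include <stdio.h>

    typedef struct { double r, i; } cplx;

    static cplx cmul(cplx x, cplx y) {            /* complex x * y          */
        return (cplx){ x.r * y.r - x.i * y.i,
                       x.r * y.i + x.i * y.r };
    }

    int main(void) {
        cplx aa  = { 3.0, -1.0 };                 /* C entry being solved   */
        cplx bb  = { 0.5,  0.25 };                /* packed diagonal of B   */
        cplx b_k = { 2.0,  1.0 };                 /* an earlier row of B    */
        cplx c_k = { 7.0,  4.0 };                 /* C entry to be updated  */

        cplx cc  = cmul(aa, bb);                  /* vs1/vs2 in the kernel  */
        cplx upd = cmul(cc, b_k);
        c_k.r -= upd.r;                           /* vc1/vc2 update         */
        c_k.i -= upd.i;

        printf("cc = (%g, %g), c_k = (%g, %g)\n", cc.r, cc.i, c_k.r, c_k.i);
        return 0;
    }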
k = %4ld offset = %4ld\n", __FILE__, __FUNCTION__, m, n, k, offset); // Debug + //fprintf(stderr, "%s , %s, m = %4ld n = %4ld k = %4ld offset = %4ld\n", __FILE__, __FUNCTION__, m, n, k, offset); // Debug kk = n - offset; c += n * ldc * COMPSIZE; diff --git a/kernel/riscv64/zgemm_ncopy_4_rvv.c b/kernel/riscv64/zgemm_ncopy_4_rvv.c new file mode 100644 index 000000000..389ee5d57 --- /dev/null +++ b/kernel/riscv64/zgemm_ncopy_4_rvv.c @@ -0,0 +1,121 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m1(n) +#define FLOAT_V_T vfloat32m1_t +#define VLSEG2_FLOAT vlseg2e32_v_f32m1 +#define VSSEG2_FLOAT vsseg2e32_v_f32m1 +#define VSSEG4_FLOAT vsseg4e32_v_f32m1 +#define VSSEG8_FLOAT vsseg8e32_v_f32m1 +#else +#define VSETVL(n) vsetvl_e64m1(n) +#define FLOAT_V_T vfloat64m1_t +#define VLSEG2_FLOAT vlseg2e64_v_f64m1 +#define VSSEG2_FLOAT vsseg2e64_v_f64m1 +#define VSSEG4_FLOAT vsseg4e64_v_f64m1 +#define VSSEG8_FLOAT vsseg8e64_v_f64m1 +#endif + +// Optimizes the implementation in ../generic/zgemm_ncopy_4.c + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){ + BLASLONG i, j; + + FLOAT *aoffset; + FLOAT *aoffset1, *aoffset2, *aoffset3, *aoffset4; + + FLOAT *boffset; + + FLOAT_V_T v11, v12, v21, v22, v31, v32, v41, v42; + size_t vl; + + aoffset = a; + boffset = b; + lda *= 2; + + for (j = (n >> 2); j > 0; j--) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset += 4 * lda; + + for (i = m; i > 0; i -= vl) { + vl = VSETVL(i); + VLSEG2_FLOAT(&v11, &v12, aoffset1, vl); + VLSEG2_FLOAT(&v21, &v22, aoffset2, vl); + VLSEG2_FLOAT(&v31, &v32, aoffset3, vl); + VLSEG2_FLOAT(&v41, &v42, aoffset4, vl); + + VSSEG8_FLOAT(boffset, v11, v12, v21, v22, v31, v32, v41, v42, vl); + + aoffset1 += vl * 2; + aoffset2 += vl * 2; + aoffset3 += vl * 2; + aoffset4 += vl * 2; + boffset += vl * 8; + } + } + + if (n & 2) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset += 2 * lda; + + for (i = m; i > 0; i -= vl) { + vl = VSETVL(i); + VLSEG2_FLOAT(&v11, &v12, aoffset1, vl); + VLSEG2_FLOAT(&v21, &v22, aoffset2, vl); + + VSSEG4_FLOAT(boffset, v11, v12, v21, v22, vl); + + aoffset1 += vl * 2; + aoffset2 += vl * 2; + boffset += vl * 4; + } + } + + if (n & 1) { + aoffset1 = aoffset; + aoffset += lda; + + for (i = m; i > 0; i -= vl) { + vl = VSETVL(i); + VLSEG2_FLOAT(&v11, &v12, aoffset1, vl); + + VSSEG2_FLOAT(boffset, v11, v12, vl); + + aoffset1 += vl * 2; + boffset += vl * 2; + } + } + + return 0; +} diff --git a/kernel/riscv64/zgemm_ncopy_rvv_v1.c b/kernel/riscv64/zgemm_ncopy_rvv_v1.c new file mode 100644 index 000000000..df039bab6 --- /dev/null +++ b/kernel/riscv64/zgemm_ncopy_rvv_v1.c @@ -0,0 +1,74 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define FLOAT_V_T vfloat32m2_t +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define FLOAT_V_T vfloat64m2_t +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#endif + +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){ + + BLASLONG i, j; + + FLOAT *a_offset; + FLOAT *a_offset1; + FLOAT *b_offset; + + FLOAT_V_T v0, v1; + size_t vl; + + //fprintf(stderr, "%s, m=%ld n=%ld lda=%ld\n", __FUNCTION__, m, n, lda); + a_offset = a; + b_offset = b; + + for(j = n; j > 0; j -= vl) { + vl = VSETVL(j); + + a_offset1 = a_offset; + a_offset += vl * lda * 2; + + for(i = m; i > 0; i--) { + VLSSEG2_FLOAT(&v0, &v1, a_offset1, lda * sizeof(FLOAT) * 2, vl); + VSSEG2_FLOAT(b_offset, v0, v1, vl); + + a_offset1 += 2; + b_offset += vl * 2; + } + } + return 0; +} + diff --git a/kernel/riscv64/zgemm_tcopy_4_rvv.c b/kernel/riscv64/zgemm_tcopy_4_rvv.c new file mode 100644 index 000000000..1b34039c8 --- /dev/null +++ b/kernel/riscv64/zgemm_tcopy_4_rvv.c @@ -0,0 +1,181 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
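The VLSEG2 loads plus the single VSSEG8 store in zgemm_ncopy_4_rvv.c above interleave four complex columns row by row, so the packed panel holds (re, im) of columns 0..3 for one row, then the next row, and so on. A plain-C sketch of the same layout for one 4-column panel (double precision assumed; function name and bounds are illustrative):

    /* Scalar equivalent of the VLSEG2/VSSEG8 packing for one 4-column panel. */
    #include <stddef.h>

    static void pack_ncopy4(size_t m, const double *a, size_t lda, double *b) {
        /* a: column-major complex matrix, lda counted in complex elements */
        for (size_t i = 0; i < m; i++) {          /* one row at a time        */
            for (size_t j = 0; j < 4; j++) {      /* four columns interleaved */
                *b++ = a[2 * (i + j * lda) + 0];  /* real part                */
                *b++ = a[2 * (i + j * lda) + 1];  /* imaginary part           */
            }
        }
    }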
+*****************************************************************************/ + +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m1(n) +#define FLOAT_V_T vfloat32m1_t +#define VLEV_FLOAT vle32_v_f32m1 +#define VSEV_FLOAT vse32_v_f32m1 +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m1 +#define VLSSEG4_FLOAT vlsseg4e32_v_f32m1 +#define VLSSEG8_FLOAT vlsseg8e32_v_f32m1 +#define VSSEG2_FLOAT vsseg2e32_v_f32m1 +#define VSSEG4_FLOAT vsseg4e32_v_f32m1 +#define VSSEG8_FLOAT vsseg8e32_v_f32m1 +#else +#define VSETVL(n) vsetvl_e64m1(n) +#define FLOAT_V_T vfloat64m1_t +#define VLEV_FLOAT vle64_v_f64m1 +#define VSEV_FLOAT vse64_v_f64m1 +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m1 +#define VLSSEG4_FLOAT vlsseg4e64_v_f64m1 +#define VLSSEG8_FLOAT vlsseg8e64_v_f64m1 +#define VSSEG2_FLOAT vsseg2e64_v_f64m1 +#define VSSEG4_FLOAT vsseg4e64_v_f64m1 +#define VSSEG8_FLOAT vsseg8e64_v_f64m1 +#endif + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){ + + BLASLONG i, j; + + IFLOAT *aoffset; + IFLOAT *aoffset1; + + IFLOAT *boffset, *boffset1, *boffset2, *boffset3; + + FLOAT_V_T v0, v1, v2, v3, v4, v5, v6, v7; + size_t vl; + + //fprintf(stderr, "%s m=%ld n=%ld lda=%ld\n", __FUNCTION__, m, n, lda); + + aoffset = a; + boffset = b; + boffset2 = b + 2 * m * (n & ~3); + boffset3 = b + 2 * m * (n & ~1); + + for(j = (m >> 2); j > 0; j--) { + + aoffset1 = aoffset; + aoffset += 8 * lda; + + boffset1 = boffset; + boffset += 32; + + for(i = (n >> 2); i > 0; i--) { + vl = 4; + + VLSSEG8_FLOAT(&v0, &v1, &v2, &v3, &v4, &v5, &v6, &v7, aoffset1, lda * sizeof(FLOAT) * 2, vl); + VSSEG8_FLOAT(boffset1, v0, v1, v2, v3, v4, v5, v6, v7, vl); + + aoffset1 += 8; + boffset1 += m * 8; + } + + if (n & 2) { + vl = 4; + + VLSSEG4_FLOAT(&v0, &v1, &v2, &v3, aoffset1, lda * sizeof(FLOAT) * 2, vl); + VSSEG4_FLOAT(boffset2, v0, v1, v2, v3, vl); + + aoffset1 += 4; + boffset2 += 16; + } + + if (n & 1) { + vl = 4; + + VLSSEG2_FLOAT(&v0, &v1, aoffset1, lda * sizeof(FLOAT) * 2, vl); + VSSEG2_FLOAT(boffset3, v0, v1, vl); + + aoffset1 += 2; + boffset3 += 8; + } + } + + if (m & 2) { + aoffset1 = aoffset; + aoffset += 4 * lda; + + boffset1 = boffset; + boffset += 16; + + for(i = (n >> 2); i > 0; i--) { + vl = 2; + + VLSSEG8_FLOAT(&v0, &v1, &v2, &v3, &v4, &v5, &v6, &v7, aoffset1, lda * sizeof(FLOAT) * 2, vl); + VSSEG8_FLOAT(boffset1, v0, v1, v2, v3, v4, v5, v6, v7, vl); + + aoffset1 += 8; + boffset1 += m * 8; + } + + if (n & 2) { + vl = 2; + + VLSSEG4_FLOAT(&v0, &v1, &v2, &v3, aoffset1, lda * sizeof(FLOAT) * 2, vl); + VSSEG4_FLOAT(boffset2, v0, v1, v2, v3, vl); + + aoffset1 += 4; + boffset2 += 8; + } + + if (n & 1) { + vl = 2; + + VLSSEG2_FLOAT(&v0, &v1, aoffset1, lda * sizeof(FLOAT) * 2, vl); + VSSEG2_FLOAT(boffset3, v0, v1, vl); + + //aoffset1 += 2; + boffset3 += 4; + } + } + + if (m & 1) { + aoffset1 = aoffset; + boffset1 = boffset; + + for(i = (n >> 2); i > 0; i--) { + vl = 8; + + v0 = VLEV_FLOAT(aoffset1, vl); + VSEV_FLOAT(boffset1, v0, vl); + + aoffset1 += 8; + boffset1 += 8 * m; + } + + if (n & 2) { + vl = 4; + + v0 = VLEV_FLOAT(aoffset1, vl); + VSEV_FLOAT(boffset2, v0, vl); + + aoffset1 += 4; + //boffset2 += 4; + } + + if (n & 1) { + *(boffset3) = *(aoffset1); + *(boffset3 + 1) = *(aoffset1 + 1); + } + } + + return 0; +} diff --git a/kernel/riscv64/zgemm_tcopy_rvv_v1.c b/kernel/riscv64/zgemm_tcopy_rvv_v1.c new file mode 100644 index 000000000..7622fb810 --- /dev/null +++ b/kernel/riscv64/zgemm_tcopy_rvv_v1.c @@ -0,0 +1,74 @@ +/*************************************************************************** +Copyright (c) 
2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define FLOAT_V_T vfloat32m2_t +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define FLOAT_V_T vfloat64m2_t +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#endif + +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) +{ + BLASLONG i, j; + + IFLOAT *aoffset; + IFLOAT *aoffset1; + IFLOAT *boffset; + + FLOAT_V_T v0, v1; + size_t vl; + + //fprintf(stderr, "%s, m=%ld n=%ld lda=%ld\n", __FUNCTION__, m, n, lda); + + aoffset = a; + boffset = b; + + for(j = n; j > 0; j -= vl) { + vl = VSETVL(j); + + aoffset1 = aoffset; + aoffset += vl * 2; + + for(i = m; i > 0; i--) { + VLSEG2_FLOAT(&v0, &v1, aoffset1, vl); + VSSEG2_FLOAT(boffset, v0, v1, vl); + + aoffset1 += lda * 2; + boffset += vl * 2; + } + } + + return 0; +} diff --git a/kernel/riscv64/zgemmkernel_rvv_v1x4.c b/kernel/riscv64/zgemmkernel_rvv_v1x4.c new file mode 100644 index 000000000..50e29222f --- /dev/null +++ b/kernel/riscv64/zgemmkernel_rvv_v1x4.c @@ -0,0 +1,475 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. 
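zgemm_tcopy_4_rvv.c above leans on strided two-field segment loads (VLSSEG2 and friends), which gather complex values that sit a fixed number of bytes apart and split them into a real-part vector and an imaginary-part vector. A plain-C model of that primitive (double precision and the function name are assumptions for illustration):

    /* Plain-C model of a strided two-segment load (VLSSEG2_*): gather vl
       complex values spaced stride_bytes apart, de-interleaving real and
       imaginary parts.  Illustrative only. */
    #include <stddef.h>

    static void vlsseg2_model(const double *base, ptrdiff_t stride_bytes,
                              size_t vl, double *re, double *im) {
        for (size_t k = 0; k < vl; k++) {
            const double *p =
                (const double *)((const char *)base + k * stride_bytes);
            re[k] = p[0];   /* first segment register  */
            im[k] = p[1];   /* second segment register */
        }
    }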
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define FLOAT_V_T vfloat32m2_t +#define VLEV_FLOAT vle32_v_f32m2 +#define VSEV_FLOAT vse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define VFMVVF_FLOAT vfmv_v_f_f32m2 +#define VFMACCVF_FLOAT vfmacc_vf_f32m2 +#define VFNMSACVF_FLOAT vfnmsac_vf_f32m2 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define FLOAT_V_T vfloat64m2_t +#define VLEV_FLOAT vle64_v_f64m2 +#define VSEV_FLOAT vse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VFMVVF_FLOAT vfmv_v_f_f64m2 +#define VFMACCVF_FLOAT vfmacc_vf_f64m2 +#define VFNMSACVF_FLOAT vfnmsac_vf_f64m2 +#endif + +#if defined(NN) || defined(NT) || defined(TN) || defined(TT) +#define OP_rr VFMACCVF_FLOAT +#define OP_ir VFMACCVF_FLOAT +#define OP_ii VFNMSACVF_FLOAT +#define OP_ri VFMACCVF_FLOAT +#elif defined(NR) || defined(NC) || defined(TR) || defined(TC) +#define OP_rr VFMACCVF_FLOAT +#define OP_ir VFMACCVF_FLOAT +#define OP_ii VFMACCVF_FLOAT +#define OP_ri VFNMSACVF_FLOAT +#elif defined(RN) || defined(RT) || defined(CN) || defined(CT) +#define OP_rr VFMACCVF_FLOAT +#define OP_ir VFNMSACVF_FLOAT +#define OP_ii VFMACCVF_FLOAT +#define OP_ri VFMACCVF_FLOAT +#elif defined(RR) || defined(RC) || defined(CR) || defined(CC) +#define OP_rr VFMACCVF_FLOAT +#define OP_ir VFNMSACVF_FLOAT +#define OP_ii VFNMSACVF_FLOAT +#define OP_ri VFNMSACVF_FLOAT +#endif + +int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alphar,FLOAT alphai,FLOAT* ba,FLOAT* bb,FLOAT* C,BLASLONG ldc +#ifdef TRMMKERNEL + , BLASLONG offset +#endif + ) +{ + BLASLONG i,j,k; + FLOAT *C0, *C1, *C2, *C3, *ptrba,*ptrbb; + + FLOAT_V_T va0, va1, va2, va3, va4, va5, va6, va7; + FLOAT_V_T vres0, vres1, vres2, vres3, vres4, vres5, vres6, vres7; + + //fprintf(stderr, "%s, bn=%ld bm=%ld bk=%ld alphar=%f alphai=%f ldc=%ld\n", __FUNCTION__, bn, bm, bk, alphar, alphai, ldc); // Debug + + size_t vl; + for (j = bn/4; j > 0; j--) + { + C0 = C; + C1 = C0 + 2 * ldc; + C2 = C1 + 2 * ldc; + C3 = C2 + 2 * ldc; + ptrba = ba; + for (i = bm; i > 0; i -= vl) + { + vl = VSETVL(i); + ptrbb = bb; + + vres0 = VFMVVF_FLOAT(0.0, vl); + vres1 = VFMVVF_FLOAT(0.0, vl); + vres2 = VFMVVF_FLOAT(0.0, vl); + vres3 = VFMVVF_FLOAT(0.0, vl); + vres4 = VFMVVF_FLOAT(0.0, vl); + vres5 = VFMVVF_FLOAT(0.0, vl); + vres6 = VFMVVF_FLOAT(0.0, vl); + vres7 = VFMVVF_FLOAT(0.0, vl); + + for (k = bk/4; k > 0; k--) + { + VLSEG2_FLOAT(&va0, &va1, ptrba, vl); + ptrba += vl*2; + + VLSEG2_FLOAT(&va2, &va3, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va0, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va1, vl); + vres0 = OP_ii(vres0, 
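The OP_rr/OP_ir/OP_ii/OP_ri groups above encode the four sign patterns of a complex multiply-accumulate: plain product, conjugated A, conjugated B, and both conjugated. A scalar model of one such update (values and flags are illustrative only):

    /* Scalar model of the OP_* sign table: acc += a * b with optional
       conjugation of either factor. */
    #include <stdio.h>

    static void cmacc(double ar, double ai, double br, double bi,
                      int conj_a, int conj_b, double *acc_r, double *acc_i) {
        if (conj_a) ai = -ai;                     /* RN/RT/CN/CT group */
        if (conj_b) bi = -bi;                     /* NR/NC/TR/TC group */
        *acc_r += ar * br - ai * bi;
        *acc_i += ar * bi + ai * br;
    }

    int main(void) {
        double r = 0.0, i = 0.0;
        cmacc(1.0, 2.0, 3.0, 4.0, 0, 0, &r, &i);  /* NN-style update */
        printf("acc = (%g, %g)\n", r, i);         /* prints (-5, 10) */
        return 0;
    }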
*(ptrbb + 1), va1, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va0, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va0, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va1, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va1, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va0, vl); + + vres4 = OP_rr(vres4, *(ptrbb + 4), va0, vl); + vres5 = OP_ir(vres5, *(ptrbb + 4), va1, vl); + vres4 = OP_ii(vres4, *(ptrbb + 5), va1, vl); + vres5 = OP_ri(vres5, *(ptrbb + 5), va0, vl); + + vres6 = OP_rr(vres6, *(ptrbb + 6), va0, vl); + vres7 = OP_ir(vres7, *(ptrbb + 6), va1, vl); + vres6 = OP_ii(vres6, *(ptrbb + 7), va1, vl); + vres7 = OP_ri(vres7, *(ptrbb + 7), va0, vl); + + ptrbb += 8; + + VLSEG2_FLOAT(&va4, &va5, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va2, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va3, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va3, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va2, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va2, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va3, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va3, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va2, vl); + + vres4 = OP_rr(vres4, *(ptrbb + 4), va2, vl); + vres5 = OP_ir(vres5, *(ptrbb + 4), va3, vl); + vres4 = OP_ii(vres4, *(ptrbb + 5), va3, vl); + vres5 = OP_ri(vres5, *(ptrbb + 5), va2, vl); + + vres6 = OP_rr(vres6, *(ptrbb + 6), va2, vl); + vres7 = OP_ir(vres7, *(ptrbb + 6), va3, vl); + vres6 = OP_ii(vres6, *(ptrbb + 7), va3, vl); + vres7 = OP_ri(vres7, *(ptrbb + 7), va2, vl); + + ptrbb += 8; + + VLSEG2_FLOAT(&va6, &va7, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va4, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va5, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va5, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va4, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va4, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va5, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va5, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va4, vl); + + vres4 = OP_rr(vres4, *(ptrbb + 4), va4, vl); + vres5 = OP_ir(vres5, *(ptrbb + 4), va5, vl); + vres4 = OP_ii(vres4, *(ptrbb + 5), va5, vl); + vres5 = OP_ri(vres5, *(ptrbb + 5), va4, vl); + + vres6 = OP_rr(vres6, *(ptrbb + 6), va4, vl); + vres7 = OP_ir(vres7, *(ptrbb + 6), va5, vl); + vres6 = OP_ii(vres6, *(ptrbb + 7), va5, vl); + vres7 = OP_ri(vres7, *(ptrbb + 7), va4, vl); + ptrbb += 8; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va6, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va7, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va7, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va6, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va6, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va7, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va7, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va6, vl); + + vres4 = OP_rr(vres4, *(ptrbb + 4), va6, vl); + vres5 = OP_ir(vres5, *(ptrbb + 4), va7, vl); + vres4 = OP_ii(vres4, *(ptrbb + 5), va7, vl); + vres5 = OP_ri(vres5, *(ptrbb + 5), va6, vl); + + vres6 = OP_rr(vres6, *(ptrbb + 6), va6, vl); + vres7 = OP_ir(vres7, *(ptrbb + 6), va7, vl); + vres6 = OP_ii(vres6, *(ptrbb + 7), va7, vl); + vres7 = OP_ri(vres7, *(ptrbb + 7), va6, vl); + + ptrbb += 8; + } + + for (k = (bk & 3); k > 0; k--) + { + VLSEG2_FLOAT(&va0, &va1, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va0, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va1, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va1, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va0, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va0, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va1, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va1, vl); + vres3 = OP_ri(vres3, 
*(ptrbb + 3), va0, vl); + + vres4 = OP_rr(vres4, *(ptrbb + 4), va0, vl); + vres5 = OP_ir(vres5, *(ptrbb + 4), va1, vl); + vres4 = OP_ii(vres4, *(ptrbb + 5), va1, vl); + vres5 = OP_ri(vres5, *(ptrbb + 5), va0, vl); + + vres6 = OP_rr(vres6, *(ptrbb + 6), va0, vl); + vres7 = OP_ir(vres7, *(ptrbb + 6), va1, vl); + vres6 = OP_ii(vres6, *(ptrbb + 7), va1, vl); + vres7 = OP_ri(vres7, *(ptrbb + 7), va0, vl); + + ptrbb += 8; + } + + VLSEG2_FLOAT(&va0, &va1, C0, vl); + VLSEG2_FLOAT(&va2, &va3, C1, vl); + + va0 = VFMACCVF_FLOAT(va0, alphar, vres0, vl); + va1 = VFMACCVF_FLOAT(va1, alphar, vres1, vl); + va0 = VFNMSACVF_FLOAT(va0, alphai, vres1, vl); + va1 = VFMACCVF_FLOAT(va1, alphai, vres0, vl); + VSSEG2_FLOAT(C0, va0, va1, vl); + + va2 = VFMACCVF_FLOAT(va2, alphar, vres2, vl); + va3 = VFMACCVF_FLOAT(va3, alphar, vres3, vl); + va2 = VFNMSACVF_FLOAT(va2, alphai, vres3, vl); + va3 = VFMACCVF_FLOAT(va3, alphai, vres2, vl); + VSSEG2_FLOAT(C1, va2, va3, vl); + + VLSEG2_FLOAT(&va0, &va1, C2, vl); + VLSEG2_FLOAT(&va2, &va3, C3, vl); + + va0 = VFMACCVF_FLOAT(va0, alphar, vres4, vl); + va1 = VFMACCVF_FLOAT(va1, alphar, vres5, vl); + va0 = VFNMSACVF_FLOAT(va0, alphai, vres5, vl); + va1 = VFMACCVF_FLOAT(va1, alphai, vres4, vl); + VSSEG2_FLOAT(C2, va0, va1, vl); + + va2 = VFMACCVF_FLOAT(va2, alphar, vres6, vl); + va3 = VFMACCVF_FLOAT(va3, alphar, vres7, vl); + va2 = VFNMSACVF_FLOAT(va2, alphai, vres7, vl); + va3 = VFMACCVF_FLOAT(va3, alphai, vres6, vl); + VSSEG2_FLOAT(C3, va2, va3, vl); + + C0 += vl * 2; + C1 += vl * 2; + C2 += vl * 2; + C3 += vl * 2; + } + + bb += (bk << 3); + C += (ldc << 3); + } + + if (bn & 2) + { + C0 = C; + C1 = C0 + 2 * ldc; + ptrba = ba; + for (i = bm; i > 0; i -= vl) + { + vl = VSETVL(i); + ptrbb = bb; + + vres0 = VFMVVF_FLOAT(0.0, vl); + vres1 = VFMVVF_FLOAT(0.0, vl); + vres2 = VFMVVF_FLOAT(0.0, vl); + vres3 = VFMVVF_FLOAT(0.0, vl); + + for (k = bk/4; k > 0; k--) + { + VLSEG2_FLOAT(&va0, &va1, ptrba, vl); + ptrba += vl*2; + VLSEG2_FLOAT(&va2, &va3, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va0, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va1, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va1, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va0, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va0, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va1, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va1, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va0, vl); + + ptrbb += 4; + + VLSEG2_FLOAT(&va4, &va5, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va2, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va3, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va3, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va2, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va2, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va3, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va3, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va2, vl); + + ptrbb += 4; + + VLSEG2_FLOAT(&va6, &va7, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va4, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va5, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va5, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va4, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va4, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va5, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va5, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va4, vl); + + ptrbb += 4; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va6, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va7, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va7, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va6, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va6, 
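After accumulation, each C tile above is updated as C += alpha * acc, with complex alpha = (alphar, alphai) and acc held in the vres pairs. A scalar version of that update for a single entry (the numbers are made up for illustration):

    /* Scalar model of the alpha update: C += alpha * acc. */
    #include <stdio.h>

    int main(void) {
        double alphar = 2.0, alphai = 0.5;
        double acc_r  = 3.0, acc_i  = -1.0;   /* vres0 / vres1    */
        double c_r    = 10.0, c_i   = 20.0;   /* existing C entry */

        c_r += alphar * acc_r - alphai * acc_i;
        c_i += alphar * acc_i + alphai * acc_r;

        printf("C = (%g, %g)\n", c_r, c_i);   /* prints (16.5, 19.5) */
        return 0;
    }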
vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va7, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va7, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va6, vl); + + ptrbb += 4; + } + + for (k = (bk & 3); k > 0; k--) + { + VLSEG2_FLOAT(&va0, &va1, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va0, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va1, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va1, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va0, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va0, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va1, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va1, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va0, vl); + + ptrbb += 4; + } + + VLSEG2_FLOAT(&va0, &va1, C0, vl); + VLSEG2_FLOAT(&va2, &va3, C1, vl); + + va0 = VFMACCVF_FLOAT(va0, alphar, vres0, vl); + va1 = VFMACCVF_FLOAT(va1, alphar, vres1, vl); + va0 = VFNMSACVF_FLOAT(va0, alphai, vres1, vl); + va1 = VFMACCVF_FLOAT(va1, alphai, vres0, vl); + VSSEG2_FLOAT(C0, va0, va1, vl); + + va2 = VFMACCVF_FLOAT(va2, alphar, vres2, vl); + va3 = VFMACCVF_FLOAT(va3, alphar, vres3, vl); + va2 = VFNMSACVF_FLOAT(va2, alphai, vres3, vl); + va3 = VFMACCVF_FLOAT(va3, alphai, vres2, vl); + VSSEG2_FLOAT(C1, va2, va3, vl); + + C0 += vl * 2; + C1 += vl * 2; + } + + bb += (bk << 2); + C += (ldc << 2); + } + + if (bn & 1) + { + C0 = C; + ptrba = ba; + for (i = bm; i > 0; i -= vl) + { + vl = VSETVL(i); + ptrbb = bb; + + vres0 = VFMVVF_FLOAT(0.0, vl); + vres1 = VFMVVF_FLOAT(0.0, vl); + + for (k = bk/4; k > 0; k--) + { + VLSEG2_FLOAT(&va0, &va1, ptrba, vl); + ptrba += vl*2; + VLSEG2_FLOAT(&va2, &va3, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va0, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va1, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va1, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va0, vl); + ptrbb += 2; + + VLSEG2_FLOAT(&va4, &va5, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va2, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va3, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va3, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va2, vl); + + ptrbb += 2; + + VLSEG2_FLOAT(&va6, &va7, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va4, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va5, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va5, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va4, vl); + ptrbb += 2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va6, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va7, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va7, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va6, vl); + ptrbb += 2; + } + + for (k = (bk & 3); k > 0; k--) + { + VLSEG2_FLOAT(&va0, &va1, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va0, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va1, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va1, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va0, vl); + ptrbb += 2; + } + + VLSEG2_FLOAT(&va0, &va1, C0, vl); + va0 = VFMACCVF_FLOAT(va0, alphar, vres0, vl); + va1 = VFMACCVF_FLOAT(va1, alphar, vres1, vl); + va0 = VFNMSACVF_FLOAT(va0, alphai, vres1, vl); + va1 = VFMACCVF_FLOAT(va1, alphai, vres0, vl); + VSSEG2_FLOAT(C0, va0, va1, vl); + C0 += vl * 2; + } + + bb += bk << 1; + C += ldc << 1; + } + return 0; +} + diff --git a/kernel/riscv64/zhemm_ltcopy_rvv_v1.c b/kernel/riscv64/zhemm_ltcopy_rvv_v1.c new file mode 100644 index 000000000..cf466d3fa --- /dev/null +++ b/kernel/riscv64/zhemm_ltcopy_rvv_v1.c @@ -0,0 +1,124 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. 
+Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define VSETVL_MAX vsetvlmax_e32m2() +#define FLOAT_V_T vfloat32m2_t +#define VLEV_FLOAT vle32_v_f32m2 +#define VSEV_FLOAT vse32_v_f32m2 +#define VLSEV_FLOAT vlse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define INT_V_T vint32m2_t +#define VID_V_INT vid_v_i32m2 +#define VADD_VX_INT vadd_vx_i32m2 +#define VFRSUB_VF_FLOAT vfrsub_vf_f32m2 +#define VMSGT_VX_INT vmsgt_vx_i32m2_b16 +#define VMSLT_VX_INT vmslt_vx_i32m2_b16 +#define VMSEQ_VX_INT vmseq_vx_i32m2_b16 +#define VBOOL_T vbool16_t +#define VMERGE_VVM_FLOAT vmerge_vvm_f32m2 +#define VFMVVF_FLOAT vfmv_v_f_f32m2 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define VSETVL_MAX vsetvlmax_e64m2() +#define FLOAT_V_T vfloat64m2_t +#define VLEV_FLOAT vle64_v_f64m2 +#define VSEV_FLOAT vse64_v_f64m2 +#define VLSEV_FLOAT vlse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define INT_V_T vint64m2_t +#define VID_V_INT vid_v_i64m2 +#define VADD_VX_INT vadd_vx_i64m2 +#define VFRSUB_VF_FLOAT vfrsub_vf_f64m2 +#define VMSGT_VX_INT vmsgt_vx_i64m2_b32 +#define VMSLT_VX_INT vmslt_vx_i64m2_b32 +#define VMSEQ_VX_INT vmseq_vx_i64m2_b32 +#define VBOOL_T vbool32_t +#define VMERGE_VVM_FLOAT vmerge_vvm_f64m2 +#define VFMVVF_FLOAT vfmv_v_f_f64m2 +#endif + + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLONG posY, FLOAT *b) +{ + //fprintf(stderr, "%s, %s, m=%ld n=%ld lda=%ld posX=%ld posY=%ld\n", __FUNCTION__, __FILE__, m, n, lda, posX, posY); + + BLASLONG i, js, offset; + + FLOAT *ao1, *ao2; + + BLASLONG stride_lda = sizeof(FLOAT) * lda * 2; + + FLOAT_V_T vb0, vb1, vb2, va10, va11, va20, va21, vzero; + VBOOL_T vbool_gt0, vbool_lt0, vbool_eq0; + INT_V_T vindex_max, vindex; + + size_t vl = VSETVL_MAX; + vindex_max = VID_V_INT(vl); + vzero = VFMVVF_FLOAT(ZERO, vl); + + for (js = n; js > 
0; js -= vl, posX += vl) { + vl = VSETVL(js); + offset = posX - posY; + + ao1 = a + posX * 2 + posY * lda * 2; + ao2 = a + posY * 2 + posX * lda * 2; + + for (i = m; i > 0; i--, offset--) { + VLSSEG2_FLOAT(&va20, &va21, ao2, stride_lda, vl); + VLSEG2_FLOAT(&va10, &va11, ao1, vl); + + vindex = VADD_VX_INT(vindex_max, offset, vl); + vbool_gt0 = VMSGT_VX_INT(vindex, 0, vl); + vbool_lt0 = VMSLT_VX_INT(vindex, 0, vl); + vbool_eq0 = VMSEQ_VX_INT(vindex, 0, vl); + + vb0 = VMERGE_VVM_FLOAT(vbool_gt0, va20, va10, vl); + vb1 = VMERGE_VVM_FLOAT(vbool_gt0, va21, va11, vl); + + vb2 = VFRSUB_VF_FLOAT(vb1, ZERO, vl); + + vb1 = VMERGE_VVM_FLOAT(vbool_lt0, vb1, vb2, vl); + vb1 = VMERGE_VVM_FLOAT(vbool_eq0, vb1, vzero, vl); + VSSEG2_FLOAT(b, vb0, vb1, vl); + + b += vl * 2; + ao1 += lda * 2; + ao2 += 2; + } + } + + return 0; +} + diff --git a/kernel/riscv64/zhemm_utcopy_rvv_v1.c b/kernel/riscv64/zhemm_utcopy_rvv_v1.c new file mode 100644 index 000000000..6209f5417 --- /dev/null +++ b/kernel/riscv64/zhemm_utcopy_rvv_v1.c @@ -0,0 +1,120 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
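The zhemm copy kernels above expand the stored triangle into a full Hermitian panel: one side of the diagonal is copied directly, the mirrored side is read from the transposed location with its imaginary part negated (the VFRSUB/VMERGE pair), and the diagonal's imaginary part is forced to zero. A scalar sketch of that element rule (layout, names and the lower flag are illustrative):

    /* Scalar model of the Hermitian expansion done by the zhemm copies. */
    typedef struct { double r, i; } cplx;

    static cplx hemm_fetch(const cplx *a, int lda, int i, int j, int lower) {
        cplx v;
        if (i == j) {                       /* diagonal: imaginary part is 0 */
            v = a[i + j * lda];
            v.i = 0.0;
        } else if ((i > j) == (lower != 0)) {
            v = a[i + j * lda];             /* stored side: copy directly    */
        } else {
            v = a[j + i * lda];             /* mirrored side: conjugate      */
            v.i = -v.i;
        }
        return v;
    }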
+*****************************************************************************/ + +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define VSETVL_MAX vsetvlmax_e32m2() +#define FLOAT_V_T vfloat32m2_t +#define VLEV_FLOAT vle32_v_f32m2 +#define VSEV_FLOAT vse32_v_f32m2 +#define VLSEV_FLOAT vlse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define INT_V_T vint32m2_t +#define VID_V_INT vid_v_i32m2 +#define VADD_VX_INT vadd_vx_i32m2 +#define VFRSUB_VF_FLOAT vfrsub_vf_f32m2 +#define VMSGT_VX_INT vmsgt_vx_i32m2_b16 +#define VMSLT_VX_INT vmslt_vx_i32m2_b16 +#define VMSEQ_VX_INT vmseq_vx_i32m2_b16 +#define VBOOL_T vbool16_t +#define VMERGE_VVM_FLOAT vmerge_vvm_f32m2 +#define VFMVVF_FLOAT vfmv_v_f_f32m2 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define VSETVL_MAX vsetvlmax_e64m2() +#define FLOAT_V_T vfloat64m2_t +#define VLEV_FLOAT vle64_v_f64m2 +#define VSEV_FLOAT vse64_v_f64m2 +#define VLSEV_FLOAT vlse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define INT_V_T vint64m2_t +#define VID_V_INT vid_v_i64m2 +#define VADD_VX_INT vadd_vx_i64m2 +#define VFRSUB_VF_FLOAT vfrsub_vf_f64m2 +#define VMSGT_VX_INT vmsgt_vx_i64m2_b32 +#define VMSLT_VX_INT vmslt_vx_i64m2_b32 +#define VMSEQ_VX_INT vmseq_vx_i64m2_b32 +#define VBOOL_T vbool32_t +#define VMERGE_VVM_FLOAT vmerge_vvm_f64m2 +#define VFMVVF_FLOAT vfmv_v_f_f64m2 +#endif + + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLONG posY, FLOAT *b) +{ + BLASLONG i, js, offset; + + FLOAT *ao1, *ao2; + //fprintf(stderr, "%s, %s, m=%ld n=%ld lda=%ld posX=%ld posY=%ld\n", __FUNCTION__, __FILE__, m, n, lda, posX, posY); + BLASLONG stride_lda = sizeof(FLOAT) * lda * 2; + + FLOAT_V_T vb0, vb1, vb2, va10, va11, va20, va21, vzero; + VBOOL_T vbool_gt0, vbool_eq0; + INT_V_T vindex_max, vindex; + + size_t vl = VSETVL_MAX; + vindex_max = VID_V_INT(vl); + vzero = VFMVVF_FLOAT(ZERO, vl); + + for (js = n; js > 0; js -= vl, posX += vl) { + vl = VSETVL(js); + offset = posX - posY; + + ao1 = a + posY * 2 + posX * lda * 2; + ao2 = a + posX * 2 + posY * lda * 2; + + for (i = m; i > 0; i--, offset--) { + VLSSEG2_FLOAT(&va10, &va11, ao1, stride_lda, vl); + VLSEG2_FLOAT(&va20, &va21, ao2, vl); + + vindex = VADD_VX_INT(vindex_max, offset, vl); + vbool_gt0 = VMSGT_VX_INT(vindex, 0, vl); + vbool_eq0 = VMSEQ_VX_INT(vindex, 0, vl); + + vb0 = VMERGE_VVM_FLOAT(vbool_gt0, va20, va10, vl); + vb1 = VMERGE_VVM_FLOAT(vbool_gt0, va21, va11, vl); + + vb2 = VFRSUB_VF_FLOAT(vb1, ZERO, vl); + + vb1 = VMERGE_VVM_FLOAT(vbool_gt0, vb1, vb2, vl); + vb1 = VMERGE_VVM_FLOAT(vbool_eq0, vb1, vzero, vl); + VSSEG2_FLOAT(b, vb0, vb1, vl); + + b += vl * 2; + ao1 += 2; + ao2 += lda * 2; + } + } + + return 0; +} diff --git a/kernel/riscv64/zsymm_lcopy_rvv_v1.c b/kernel/riscv64/zsymm_lcopy_rvv_v1.c new file mode 100644 index 000000000..df5c916a5 --- /dev/null +++ b/kernel/riscv64/zsymm_lcopy_rvv_v1.c @@ -0,0 +1,106 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. 
Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define VSETVL_MAX vsetvlmax_e32m2() +#define FLOAT_V_T vfloat32m2_t +#define VLEV_FLOAT vle32_v_f32m2 +#define VSEV_FLOAT vse32_v_f32m2 +#define VLSEV_FLOAT vlse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define INT_V_T vint32m2_t +#define VID_V_INT vid_v_i32m2 +#define VADD_VX_INT vadd_vx_i32m2 +#define VMSGT_VX_INT vmsgt_vx_i32m2_b16 +#define VBOOL_T vbool16_t +#define VMERGE_VVM_FLOAT vmerge_vvm_f32m2 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define VSETVL_MAX vsetvlmax_e64m2() +#define FLOAT_V_T vfloat64m2_t +#define VLEV_FLOAT vle64_v_f64m2 +#define VSEV_FLOAT vse64_v_f64m2 +#define VLSEV_FLOAT vlse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define INT_V_T vint64m2_t +#define VID_V_INT vid_v_i64m2 +#define VADD_VX_INT vadd_vx_i64m2 +#define VMSGT_VX_INT vmsgt_vx_i64m2_b32 +#define VBOOL_T vbool32_t +#define VMERGE_VVM_FLOAT vmerge_vvm_f64m2 +#endif + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLONG posY, FLOAT *b) +{ + BLASLONG i, js, offset; + + FLOAT *ao1, *ao2; + + BLASLONG stride_lda = sizeof(FLOAT)*lda*2; + + FLOAT_V_T vb0, vb1, va10, va11, va20, va21; + VBOOL_T vbool; + INT_V_T vindex_max, vindex; + + size_t vl = VSETVL_MAX; + vindex_max = VID_V_INT(vl); + + for (js = n; js > 0; js -= vl, posX += vl) { + vl = VSETVL(js); + offset = posX - posY; + + ao1 = a + posX * 2 + posY * lda * 2; + ao2 = a + posY * 2 + (posX) * lda * 2; + + for (i = m; i > 0; i--, offset--) { + + VLSSEG2_FLOAT(&va20, &va21, ao2, stride_lda, vl); + VLSEG2_FLOAT(&va10, &va11, ao1, vl); + + vindex = VADD_VX_INT(vindex_max, offset, vl); + vbool = VMSGT_VX_INT(vindex, 0, vl); + + vb0 = VMERGE_VVM_FLOAT(vbool, va20, va10, vl); + vb1 = VMERGE_VVM_FLOAT(vbool, va21, va11, vl); + VSSEG2_FLOAT(b, vb0, vb1, vl); + + b += vl * 2; + ao1 += lda * 2; + ao2 += 2; + } + } + + return 0; +} + diff --git a/kernel/riscv64/zsymm_ucopy_rvv_v1.c b/kernel/riscv64/zsymm_ucopy_rvv_v1.c new file mode 100644 index 000000000..dcf2b081a --- /dev/null +++ b/kernel/riscv64/zsymm_ucopy_rvv_v1.c @@ -0,0 +1,106 @@ 
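The symm, hemm and trmm copy kernels in this patch all build their per-lane masks the same way: a vid() index vector plus a scalar offset that tracks the distance to the diagonal, compared against zero, and the resulting mask merges the two candidate loads. A plain-C model of the mask construction (types and names are illustrative):

    /* vindex = vid() + offset; the sign of each lane drives the VMERGE. */
    #include <stddef.h>

    static void build_masks(long offset, size_t vl, int *gt0, int *eq0) {
        for (size_t k = 0; k < vl; k++) {
            long idx = (long)k + offset;
            gt0[k] = (idx > 0);   /* VMSGT_VX_INT result */
            eq0[k] = (idx == 0);  /* VMSEQ_VX_INT result */
        }
    }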
+/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define VSETVL_MAX vsetvlmax_e32m2() +#define FLOAT_V_T vfloat32m2_t +#define VLEV_FLOAT vle32_v_f32m2 +#define VSEV_FLOAT vse32_v_f32m2 +#define VLSEV_FLOAT vlse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define INT_V_T vint32m2_t +#define VID_V_INT vid_v_i32m2 +#define VADD_VX_INT vadd_vx_i32m2 +#define VMSGT_VX_INT vmsgt_vx_i32m2_b16 +#define VBOOL_T vbool16_t +#define VMERGE_VVM_FLOAT vmerge_vvm_f32m2 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define VSETVL_MAX vsetvlmax_e64m2() +#define FLOAT_V_T vfloat64m2_t +#define VLEV_FLOAT vle64_v_f64m2 +#define VSEV_FLOAT vse64_v_f64m2 +#define VLSEV_FLOAT vlse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define INT_V_T vint64m2_t +#define VID_V_INT vid_v_i64m2 +#define VADD_VX_INT vadd_vx_i64m2 +#define VMSGT_VX_INT vmsgt_vx_i64m2_b32 +#define VBOOL_T vbool32_t +#define VMERGE_VVM_FLOAT vmerge_vvm_f64m2 +#endif + + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLONG posY, FLOAT *b) +{ + BLASLONG i, js, offset; + + FLOAT *ao1, *ao2; + + BLASLONG stride_lda = sizeof(FLOAT)*lda * 2; + + FLOAT_V_T vb0, vb1, va10, va11, va20, va21; + VBOOL_T vbool; + INT_V_T vindex_max, vindex; + + + size_t vl = VSETVL_MAX; + vindex_max = VID_V_INT(vl); + + for (js = n; js > 0; js -= vl, posX += vl) { + vl = VSETVL(js); + offset = posX - posY; + + ao1 = a + posY * 2 + (posX + 0) * lda * 2; + ao2 = a + posX * 2 + 0 + posY * lda * 2; + + for (i = m; i > 0; i--, offset--) { + VLSSEG2_FLOAT(&va10, &va11, ao1, stride_lda, vl); + VLSEG2_FLOAT(&va20, &va21, ao2, vl); + + vindex = VADD_VX_INT(vindex_max, offset, vl); + vbool = VMSGT_VX_INT(vindex, 
0, vl); + + vb0 = VMERGE_VVM_FLOAT(vbool, va20, va10, vl); + vb1 = VMERGE_VVM_FLOAT(vbool, va21, va11, vl); + VSSEG2_FLOAT(b, vb0, vb1, vl); + + b += vl * 2; + ao1 += 2; + ao2 += lda * 2; + } + } + + return 0; +} diff --git a/kernel/riscv64/ztrmm_lncopy_rvv_v1.c b/kernel/riscv64/ztrmm_lncopy_rvv_v1.c new file mode 100644 index 000000000..afd694408 --- /dev/null +++ b/kernel/riscv64/ztrmm_lncopy_rvv_v1.c @@ -0,0 +1,145 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + + +#include +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define FLOAT_V_T vfloat32m2_t +#define VLEV_FLOAT vle32_v_f32m2 +#define VSEV_FLOAT vse32_v_f32m2 +#define VLSEV_FLOAT vlse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define VBOOL_T vbool16_t +#define UINT_V_T vint32m2_t +#define VID_V_UINT vid_v_i32m2 +#define VMSGTU_VX_UINT vmsgt_vx_i32m2_b16 +#define VMSEQ_VX_UINT vmseq_vx_i32m2_b16 +#define VFMERGE_VFM_FLOAT vfmerge_vfm_f32m2 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define FLOAT_V_T vfloat64m2_t +#define VLEV_FLOAT vle64_v_f64m2 +#define VSEV_FLOAT vse64_v_f64m2 +#define VLSEV_FLOAT vlse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VBOOL_T vbool32_t +#define UINT_V_T vuint64m2_t +#define VID_V_UINT vid_v_u64m2 +#define VMSGTU_VX_UINT vmsgtu_vx_u64m2_b32 +#define VMSEQ_VX_UINT vmseq_vx_u64m2_b32 +#define VFMERGE_VFM_FLOAT vfmerge_vfm_f64m2 +#endif + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLONG posY, FLOAT *b){ + + BLASLONG i, js, X; + + FLOAT *ao; + + BLASLONG stride_lda = sizeof(FLOAT)*lda*2; + + FLOAT_V_T va0, va1; + + size_t vl; +#ifdef UNIT + VBOOL_T vbool_eq; +#endif + + VBOOL_T vbool_cmp; + UINT_V_T vindex; + + for (js = n; js > 0; js -= vl) + { + vl = VSETVL(js); + X = posX; + + if (posX <= posY) + { + ao = a + posY * 2 + posX * lda * 2; + } + else + { + ao = a + posX * 2 + posY * lda * 2; + } + + i = 0; + do + { + if (X > posY) + { + VLSSEG2_FLOAT(&va0, &va1, ao, stride_lda, vl); + VSSEG2_FLOAT(b, va0, va1, vl); + + ao += 2; + b += vl * 2; + + X ++; + i ++; + } + else if (X < posY) + { + ao += lda * 2; + b += vl * 2; + X ++; + i ++; + } + else + { + vindex = VID_V_UINT(vl); + for (unsigned int j = 0; j < vl; j++) + { + VLSSEG2_FLOAT(&va0, &va1, ao, stride_lda, vl); + vbool_cmp = VMSGTU_VX_UINT(vindex, j, vl); + va0 = VFMERGE_VFM_FLOAT(vbool_cmp, va0, ZERO, vl); + va1 = VFMERGE_VFM_FLOAT(vbool_cmp, va1, ZERO, vl); +#ifdef UNIT + vbool_eq = VMSEQ_VX_UINT(vindex, j, vl); + va0 = VFMERGE_VFM_FLOAT(vbool_eq, va0, ONE, vl); + va1 = VFMERGE_VFM_FLOAT(vbool_eq, va1, ZERO, vl); +#endif + VSSEG2_FLOAT(b, va0, va1, vl); + ao += 2; + b += vl * 2; + } + + X += vl; + i += vl; + } + } while (i < m); + + posY += vl; + } + + return 0; +} diff --git a/kernel/riscv64/ztrmm_ltcopy_rvv_v1.c b/kernel/riscv64/ztrmm_ltcopy_rvv_v1.c new file mode 100644 index 000000000..c7d593949 --- /dev/null +++ b/kernel/riscv64/ztrmm_ltcopy_rvv_v1.c @@ -0,0 +1,143 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. 
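Inside the diagonal block, ztrmm_lncopy_rvv_v1.c above clears the lanes that fall on the unused side of the triangle (vindex > j) and, when UNIT is defined, rewrites the diagonal lane as (1, 0); the transposed variants flip the comparison to vindex < j. A scalar model of that masking for one column of the block (lane count and the unit flag are illustrative):

    /* Scalar model of the diagonal-block masking in the ztrmm copies. */
    static void mask_diag_column(double *re, double *im, unsigned j,
                                 unsigned vl, int unit) {
        for (unsigned k = 0; k < vl; k++) {
            if (k > j) {                   /* VMSGTU_VX + VFMERGE with ZERO  */
                re[k] = 0.0;
                im[k] = 0.0;
            } else if (unit && k == j) {   /* VMSEQ_VX + VFMERGE: unit diag  */
                re[k] = 1.0;
                im[k] = 0.0;
            }
        }
    }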
Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + + +#include +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define FLOAT_V_T vfloat32m2_t +#define VLEV_FLOAT vle32_v_f32m2 +#define VSEV_FLOAT vse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define VBOOL_T vbool16_t +#define UINT_V_T vuint32m2_t +#define VID_V_UINT vid_v_u32m2 +#define VMSLTU_VX_UINT vmsltu_vx_u32m2_b16 +#define VMSEQ_VX_UINT vmseq_vx_u32m2_b16 +#define VFMERGE_VFM_FLOAT vfmerge_vfm_f32m2 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define FLOAT_V_T vfloat64m2_t +#define VLEV_FLOAT vle64_v_f64m2 +#define VSEV_FLOAT vse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VBOOL_T vbool32_t +#define UINT_V_T vuint64m2_t +#define VID_V_UINT vid_v_u64m2 +#define VMSLTU_VX_UINT vmsltu_vx_u64m2_b32 +#define VMSEQ_VX_UINT vmseq_vx_u64m2_b32 +#define VFMERGE_VFM_FLOAT vfmerge_vfm_f64m2 +#endif + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLONG posY, FLOAT *b){ + + BLASLONG i, js, X; + + FLOAT *ao; + + FLOAT_V_T va0, va1; + size_t vl; +#ifdef UNIT + VBOOL_T vbool_eq; +#endif + + VBOOL_T vbool_cmp; + UINT_V_T vindex; + + for (js = n; js > 0; js -= vl) + { + vl = VSETVL(js); + X = posX; + + if (posX <= posY) + { + ao = a + posY * 2 + posX * lda * 2; + } + else + { + ao = a + posX * 2 + posY * lda * 2; + } + + i = 0; + do + { + if (X > posY) + { + ao += 2; + b += vl * 2; + X++; + i++; + } + else if (X < posY) + { + //va1 = VLEV_FLOAT(ao, vl); + VLSEG2_FLOAT(&va0, &va1, ao, vl); + VSSEG2_FLOAT(b, va0, va1, vl); + + ao += lda * 2; + b += vl * 2; + X ++; + i ++; + } + else + { + vindex = VID_V_UINT(vl); + for (unsigned int j = 0; j < vl; j++) + { + //va1 = VLEV_FLOAT(ao, vl); + VLSEG2_FLOAT(&va0, &va1, ao, vl); + vbool_cmp = VMSLTU_VX_UINT(vindex, j, vl); + va0 = VFMERGE_VFM_FLOAT(vbool_cmp, va0, ZERO, vl); + va1 = VFMERGE_VFM_FLOAT(vbool_cmp, va1, ZERO, vl); +#ifdef UNIT + vbool_eq = VMSEQ_VX_UINT(vindex, j, vl); + va0 = VFMERGE_VFM_FLOAT(vbool_eq, va0, ONE, vl); + va1 = VFMERGE_VFM_FLOAT(vbool_eq, va1, ZERO, vl); +#endif + //VSEV_FLOAT(b, vb, vl); + VSSEG2_FLOAT(b, va0, va1, vl); + ao += lda * 2; + b += vl * 2; + } + X += vl; + i += vl; + + } + } while (i < m); + + posY += vl; + } + + return 0; +} + diff --git a/kernel/riscv64/ztrmm_uncopy_rvv_v1.c b/kernel/riscv64/ztrmm_uncopy_rvv_v1.c new file mode 100644 index 
000000000..3c70b6385 --- /dev/null +++ b/kernel/riscv64/ztrmm_uncopy_rvv_v1.c @@ -0,0 +1,144 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + + +#include +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define FLOAT_V_T vfloat32m2_t +#define VLEV_FLOAT vle32_v_f32m2 +#define VLSEV_FLOAT vlse32_v_f32m2 +#define VSEV_FLOAT vse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define VBOOL_T vbool16_t +#define UINT_V_T vuint32m2_t +#define VID_V_UINT vid_v_u32m2 +#define VMSLTU_VX_UINT vmsltu_vx_u32m2_b16 +#define VMSEQ_VX_UINT vmseq_vx_u32m2_b16 +#define VFMERGE_VFM_FLOAT vfmerge_vfm_f32m2 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define FLOAT_V_T vfloat64m2_t +#define VLEV_FLOAT vle64_v_f64m2 +#define VLSEV_FLOAT vlse64_v_f64m2 +#define VSEV_FLOAT vse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VBOOL_T vbool32_t +#define UINT_V_T vuint64m2_t +#define VID_V_UINT vid_v_u64m2 +#define VMSLTU_VX_UINT vmsltu_vx_u64m2_b32 +#define VMSEQ_VX_UINT vmseq_vx_u64m2_b32 +#define VFMERGE_VFM_FLOAT vfmerge_vfm_f64m2 +#endif + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLONG posY, FLOAT *b){ + + BLASLONG i, js, X; + BLASLONG stride_lda = sizeof(FLOAT) * lda * 2; + FLOAT *ao; + + FLOAT_V_T va0, va1; + size_t vl; + +#ifdef UNIT + VBOOL_T vbool_eq; +#endif + + VBOOL_T vbool_cmp; + UINT_V_T vindex; + + for (js = n; js > 0; js -= vl) + { + vl = VSETVL(js); + X = posX; + + if (posX <= posY) + { + ao = a + posX * 2 + posY * lda * 2; + } + else + { + ao = a + posY * 2 + posX * lda * 2; + } + + i = 0; + do + { + if (X < posY) + { + VLSSEG2_FLOAT(&va0, &va1, ao, stride_lda, vl); + VSSEG2_FLOAT(b, va0, va1, vl); + + ao += 2; + b += vl * 2; + + X++; + i++; + } + else if (X > posY) 
+ { + ao += lda * 2; + b += vl * 2; + + X++; + i++; + } + else + { + vindex = VID_V_UINT(vl); + for (unsigned int j = 0; j < vl; j++) + { + VLSSEG2_FLOAT(&va0, &va1, ao, stride_lda, vl); + vbool_cmp = VMSLTU_VX_UINT(vindex, j, vl); + va0 = VFMERGE_VFM_FLOAT(vbool_cmp, va0, ZERO, vl); + va1 = VFMERGE_VFM_FLOAT(vbool_cmp, va1, ZERO, vl); +#ifdef UNIT + vbool_eq = VMSEQ_VX_UINT(vindex, j, vl); + va0 = VFMERGE_VFM_FLOAT(vbool_eq, va0, ONE, vl); + va1 = VFMERGE_VFM_FLOAT(vbool_eq, va1, ZERO, vl); +#endif + VSSEG2_FLOAT(b, va0, va1, vl); + ao += 2; + b += vl * 2; + } + + X += vl; + i += vl; + } + }while (i < m); + + posY += vl; + } + + return 0; +} diff --git a/kernel/riscv64/ztrmm_utcopy_rvv_v1.c b/kernel/riscv64/ztrmm_utcopy_rvv_v1.c new file mode 100644 index 000000000..706782cf0 --- /dev/null +++ b/kernel/riscv64/ztrmm_utcopy_rvv_v1.c @@ -0,0 +1,140 @@ +/*********************************************************************/ +/* Copyright 2009, 2010 The University of Texas at Austin. */ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + + +#include +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define FLOAT_V_T vfloat32m2_t +#define VLEV_FLOAT vle32_v_f32m2 +#define VSEV_FLOAT vse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define VBOOL_T vbool16_t +#define UINT_V_T vuint32m2_t +#define VID_V_UINT vid_v_u32m2 +#define VMSGTU_VX_UINT vmsgtu_vx_u32m2_b16 +#define VMSEQ_VX_UINT vmseq_vx_u32m2_b16 +#define VFMERGE_VFM_FLOAT vfmerge_vfm_f32m2 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define FLOAT_V_T vfloat64m2_t +#define VLEV_FLOAT vle64_v_f64m2 +#define VSEV_FLOAT vse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VBOOL_T vbool32_t +#define UINT_V_T vuint64m2_t +#define VID_V_UINT vid_v_u64m2 +#define VMSGTU_VX_UINT vmsgtu_vx_u64m2_b32 +#define VMSEQ_VX_UINT vmseq_vx_u64m2_b32 +#define VFMERGE_VFM_FLOAT vfmerge_vfm_f64m2 +#endif + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG posX, BLASLONG posY, FLOAT *b){ + + BLASLONG i, j, js, X; + + FLOAT *ao; + FLOAT_V_T va0, va1; +#ifdef UNIT + VBOOL_T vbool_eq; +#endif + + VBOOL_T vbool_cmp; + UINT_V_T vindex; + + size_t vl; + + for (js = n; js > 0; js -= vl) + { + vl = VSETVL(js); + + X = posX; + + if (posX <= posY) + { + ao = a + posX * 2 + posY * lda * 2; + } + else + { + ao = a + posY * 2 + posX * lda * 2; + } + + i = 0; + do + { + if (X < posY) + { + ao += 2; + b += vl * 2; + X++; + i++; + } + else if (X > posY) + { + VLSEG2_FLOAT(&va0, &va1, ao, vl); + VSSEG2_FLOAT(b, va0, va1, vl); + ao += lda * 2; + b += vl * 2; + X++; + i++; + } + else + { + vindex = VID_V_UINT(vl); + for (j = 0; j < vl; j++) + { + VLSEG2_FLOAT(&va0, &va1, ao, vl); + vbool_cmp = VMSGTU_VX_UINT(vindex, j, vl); + va0 = VFMERGE_VFM_FLOAT(vbool_cmp, va0, ZERO, vl); + va1 = VFMERGE_VFM_FLOAT(vbool_cmp, va1, ZERO, vl); +#ifdef UNIT + vbool_eq = VMSEQ_VX_UINT(vindex, j, vl); + va0 = VFMERGE_VFM_FLOAT(vbool_eq, va0, ONE, vl); + va1 = VFMERGE_VFM_FLOAT(vbool_eq, va1, ZERO, vl); +#endif + VSSEG2_FLOAT(b, va0, va1, vl); + ao += lda * 2; + b += vl * 2; + } + X += vl; + i += vl; + } + }while (i < m); + posY += vl; + } + + return 0; +} + diff --git a/kernel/riscv64/ztrmmkernel_rvv_v1x4.c b/kernel/riscv64/ztrmmkernel_rvv_v1x4.c new file mode 100644 index 000000000..27409ec25 --- /dev/null +++ b/kernel/riscv64/ztrmmkernel_rvv_v1x4.c @@ -0,0 +1,574 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. 
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define FLOAT_V_T vfloat32m2_t +#define VLEV_FLOAT vle32_v_f32m2 +#define VSEV_FLOAT vse32_v_f32m2 +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define VFMVVF_FLOAT vfmv_v_f_f32m2 +#define VFMACCVF_FLOAT vfmacc_vf_f32m2 +#define VFNMSACVF_FLOAT vfnmsac_vf_f32m2 +#define VFMULVF_FLOAT vfmul_vf_f32m2 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define FLOAT_V_T vfloat64m2_t +#define VLEV_FLOAT vle64_v_f64m2 +#define VSEV_FLOAT vse64_v_f64m2 +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VFMVVF_FLOAT vfmv_v_f_f64m2 +#define VFMACCVF_FLOAT vfmacc_vf_f64m2 +#define VFNMSACVF_FLOAT vfnmsac_vf_f64m2 +#define VFMULVF_FLOAT vfmul_vf_f64m2 +#endif + +#if defined(NN) || defined(NT) || defined(TN) || defined(TT) +#define OP_rr VFMACCVF_FLOAT +#define OP_ir VFMACCVF_FLOAT +#define OP_ii VFNMSACVF_FLOAT +#define OP_ri VFMACCVF_FLOAT +#elif defined(NR) || defined(NC) || defined(TR) || defined(TC) +#define OP_rr VFMACCVF_FLOAT +#define OP_ir VFMACCVF_FLOAT +#define OP_ii VFMACCVF_FLOAT +#define OP_ri VFNMSACVF_FLOAT +#elif defined(RN) || defined(RT) || defined(CN) || defined(CT) +#define OP_rr VFMACCVF_FLOAT +#define OP_ir VFNMSACVF_FLOAT +#define OP_ii VFMACCVF_FLOAT +#define OP_ri VFMACCVF_FLOAT +#elif defined(RR) || defined(RC) || defined(CR) || defined(CC) +#define OP_rr VFMACCVF_FLOAT +#define OP_ir VFNMSACVF_FLOAT +#define OP_ii VFNMSACVF_FLOAT +#define OP_ri VFNMSACVF_FLOAT +#endif + +int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alphar,FLOAT alphai,FLOAT* ba,FLOAT* bb,FLOAT* C, BLASLONG ldc, BLASLONG offset) +{ + BLASLONG i,j,k; + FLOAT *C0, *C1, *C2, *C3, *ptrba,*ptrbb; + BLASLONG off, temp; + +#if defined(TRMMKERNEL) && !defined(LEFT) + off = -offset; +#else + off = 0; +#endif + + FLOAT_V_T va0, va1, va2, va3, va4, va5, va6, va7; + FLOAT_V_T vres0, vres1, vres2, vres3, vres4, vres5, vres6, vres7; + + //fprintf(stderr, "%s, bn=%ld bm=%ld bk=%ld alphar=%f alphai=%f ldc=%ld, offset=%ld\n", __FUNCTION__, bn, bm, bk, alphar, alphai, ldc, offset); // Debug + + size_t vl; + for (j = bn/4; j > 0; j--) + { + C0 = C; + C1 = C0 + 2 * ldc; + C2 = C1 + 2 * ldc; + C3 = C2 + 2 * ldc; +#if defined(TRMMKERNEL) && defined(LEFT) + off = offset; +#endif + ptrba = ba; + for (i = bm; i > 0; i -= vl) + { + vl = VSETVL(i); +#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + ptrbb = bb; +#else + ptrba += off*vl*2; + ptrbb = bb + off*4*2; +#endif + + vres0 = VFMVVF_FLOAT(0.0, vl); + vres1 = VFMVVF_FLOAT(0.0, vl); + vres2 = VFMVVF_FLOAT(0.0, vl); + vres3 = VFMVVF_FLOAT(0.0, vl); + vres4 = 
VFMVVF_FLOAT(0.0, vl); + vres5 = VFMVVF_FLOAT(0.0, vl); + vres6 = VFMVVF_FLOAT(0.0, vl); + vres7 = VFMVVF_FLOAT(0.0, vl); + +#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) + temp = bk-off; +#elif defined(LEFT) + temp = off+vl; // number of values in A +#else + temp = off+4; // number of values in B +#endif + + for (k = temp/4; k > 0; k--) + { + VLSEG2_FLOAT(&va0, &va1, ptrba, vl); + ptrba += vl*2; + + VLSEG2_FLOAT(&va2, &va3, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va0, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va1, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va1, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va0, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va0, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va1, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va1, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va0, vl); + + vres4 = OP_rr(vres4, *(ptrbb + 4), va0, vl); + vres5 = OP_ir(vres5, *(ptrbb + 4), va1, vl); + vres4 = OP_ii(vres4, *(ptrbb + 5), va1, vl); + vres5 = OP_ri(vres5, *(ptrbb + 5), va0, vl); + + vres6 = OP_rr(vres6, *(ptrbb + 6), va0, vl); + vres7 = OP_ir(vres7, *(ptrbb + 6), va1, vl); + vres6 = OP_ii(vres6, *(ptrbb + 7), va1, vl); + vres7 = OP_ri(vres7, *(ptrbb + 7), va0, vl); + + ptrbb += 8; + + VLSEG2_FLOAT(&va4, &va5, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va2, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va3, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va3, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va2, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va2, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va3, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va3, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va2, vl); + + vres4 = OP_rr(vres4, *(ptrbb + 4), va2, vl); + vres5 = OP_ir(vres5, *(ptrbb + 4), va3, vl); + vres4 = OP_ii(vres4, *(ptrbb + 5), va3, vl); + vres5 = OP_ri(vres5, *(ptrbb + 5), va2, vl); + + vres6 = OP_rr(vres6, *(ptrbb + 6), va2, vl); + vres7 = OP_ir(vres7, *(ptrbb + 6), va3, vl); + vres6 = OP_ii(vres6, *(ptrbb + 7), va3, vl); + vres7 = OP_ri(vres7, *(ptrbb + 7), va2, vl); + + ptrbb += 8; + + VLSEG2_FLOAT(&va6, &va7, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va4, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va5, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va5, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va4, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va4, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va5, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va5, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va4, vl); + + vres4 = OP_rr(vres4, *(ptrbb + 4), va4, vl); + vres5 = OP_ir(vres5, *(ptrbb + 4), va5, vl); + vres4 = OP_ii(vres4, *(ptrbb + 5), va5, vl); + vres5 = OP_ri(vres5, *(ptrbb + 5), va4, vl); + + vres6 = OP_rr(vres6, *(ptrbb + 6), va4, vl); + vres7 = OP_ir(vres7, *(ptrbb + 6), va5, vl); + vres6 = OP_ii(vres6, *(ptrbb + 7), va5, vl); + vres7 = OP_ri(vres7, *(ptrbb + 7), va4, vl); + + ptrbb += 8; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va6, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va7, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va7, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va6, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va6, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va7, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va7, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va6, vl); + + vres4 = OP_rr(vres4, *(ptrbb + 4), va6, vl); + vres5 = OP_ir(vres5, *(ptrbb + 4), va7, vl); + vres4 = OP_ii(vres4, *(ptrbb + 5), va7, vl); + vres5 = OP_ri(vres5, *(ptrbb + 5), va6, vl); + + vres6 = OP_rr(vres6, *(ptrbb + 6), va6, vl); 
+ vres7 = OP_ir(vres7, *(ptrbb + 6), va7, vl); + vres6 = OP_ii(vres6, *(ptrbb + 7), va7, vl); + vres7 = OP_ri(vres7, *(ptrbb + 7), va6, vl); + + ptrbb += 8; + } + + for (k = temp & 3; k > 0; k--) + { + VLSEG2_FLOAT(&va0, &va1, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va0, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va1, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va1, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va0, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va0, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va1, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va1, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va0, vl); + + vres4 = OP_rr(vres4, *(ptrbb + 4), va0, vl); + vres5 = OP_ir(vres5, *(ptrbb + 4), va1, vl); + vres4 = OP_ii(vres4, *(ptrbb + 5), va1, vl); + vres5 = OP_ri(vres5, *(ptrbb + 5), va0, vl); + + vres6 = OP_rr(vres6, *(ptrbb + 6), va0, vl); + vres7 = OP_ir(vres7, *(ptrbb + 6), va1, vl); + vres6 = OP_ii(vres6, *(ptrbb + 7), va1, vl); + vres7 = OP_ri(vres7, *(ptrbb + 7), va0, vl); + + ptrbb += 8; + } + va0 = VFMULVF_FLOAT(vres0, alphar, vl); + va1 = VFMULVF_FLOAT(vres1, alphar, vl); + va0 = VFNMSACVF_FLOAT(va0, alphai, vres1, vl); + va1 = VFMACCVF_FLOAT(va1, alphai, vres0, vl); + VSSEG2_FLOAT(C0, va0, va1, vl); + + va2 = VFMULVF_FLOAT(vres2, alphar, vl); + va3 = VFMULVF_FLOAT(vres3, alphar, vl); + va2 = VFNMSACVF_FLOAT(va2, alphai, vres3, vl); + va3 = VFMACCVF_FLOAT(va3, alphai, vres2, vl); + VSSEG2_FLOAT(C1, va2, va3, vl); + + va0 = VFMULVF_FLOAT(vres4, alphar, vl); + va1 = VFMULVF_FLOAT(vres5, alphar, vl); + va0 = VFNMSACVF_FLOAT(va0, alphai, vres5, vl); + va1 = VFMACCVF_FLOAT(va1, alphai, vres4, vl); + VSSEG2_FLOAT(C2, va0, va1, vl); + + va2 = VFMULVF_FLOAT(vres6, alphar, vl); + va3 = VFMULVF_FLOAT(vres7, alphar, vl); + va2 = VFNMSACVF_FLOAT(va2, alphai, vres7, vl); + va3 = VFMACCVF_FLOAT(va3, alphai, vres6, vl); + VSSEG2_FLOAT(C3, va2, va3, vl); + +#if ( defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + temp = bk - off; +#ifdef LEFT + temp -= vl; // number of values in A +#else + temp -= 4; // number of values in B +#endif + ptrba += temp*vl*2; + ptrbb += temp*4*2; +#endif + +#ifdef LEFT + off += vl; // number of values in A +#endif + + C0 += vl * 2; + C1 += vl * 2; + C2 += vl * 2; + C3 += vl * 2; + } +#if defined(TRMMKERNEL) && !defined(LEFT) + off += 4; +#endif + + bb += (bk << 3); + C += (ldc << 3); + } + + if (bn & 2) + { + C0 = C; + C1 = C0 + 2 * ldc; +#if defined(TRMMKERNEL) && defined(LEFT) + off = offset; +#endif + ptrba = ba; + for (i = bm; i > 0; i -= vl) + { + vl = VSETVL(i); +#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + ptrbb = bb; +#else + ptrba += off*vl*2; + ptrbb = bb + off*2*2; +#endif + + vres0 = VFMVVF_FLOAT(0.0, vl); + vres1 = VFMVVF_FLOAT(0.0, vl); + vres2 = VFMVVF_FLOAT(0.0, vl); + vres3 = VFMVVF_FLOAT(0.0, vl); + +#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) + temp = bk-off; +#elif defined(LEFT) + temp = off+vl; // number of values in A +#else + temp = off+2; // number of values in B +#endif + for (k = temp/4; k > 0; k--) + { + VLSEG2_FLOAT(&va0, &va1, ptrba, vl); + ptrba += vl*2; + + VLSEG2_FLOAT(&va2, &va3, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va0, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va1, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va1, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va0, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va0, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va1, vl); + vres2 = OP_ii(vres2, *(ptrbb 
+ 3), va1, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va0, vl); + + ptrbb += 4; + + VLSEG2_FLOAT(&va4, &va5, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va2, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va3, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va3, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va2, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va2, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va3, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va3, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va2, vl); + + ptrbb += 4; + + VLSEG2_FLOAT(&va6, &va7, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va4, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va5, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va5, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va4, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va4, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va5, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va5, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va4, vl); + + ptrbb += 4; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va6, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va7, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va7, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va6, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va6, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va7, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va7, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va6, vl); + + ptrbb += 4; + } + + for (k = temp & 3; k > 0; k--) + { + VLSEG2_FLOAT(&va0, &va1, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va0, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va1, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va1, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va0, vl); + + vres2 = OP_rr(vres2, *(ptrbb + 2), va0, vl); + vres3 = OP_ir(vres3, *(ptrbb + 2), va1, vl); + vres2 = OP_ii(vres2, *(ptrbb + 3), va1, vl); + vres3 = OP_ri(vres3, *(ptrbb + 3), va0, vl); + + ptrbb += 4; + } + + va0 = VFMULVF_FLOAT(vres0, alphar, vl); + va1 = VFMULVF_FLOAT(vres1, alphar, vl); + va0 = VFNMSACVF_FLOAT(va0, alphai, vres1, vl); + va1 = VFMACCVF_FLOAT(va1, alphai, vres0, vl); + VSSEG2_FLOAT(C0, va0, va1, vl); + + va2 = VFMULVF_FLOAT(vres2, alphar, vl); + va3 = VFMULVF_FLOAT(vres3, alphar, vl); + va2 = VFNMSACVF_FLOAT(va2, alphai, vres3, vl); + va3 = VFMACCVF_FLOAT(va3, alphai, vres2, vl); + VSSEG2_FLOAT(C1, va2, va3, vl); + +#if ( defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + temp = bk - off; +#ifdef LEFT + temp -= vl; // number of values in A +#else + temp -= 2; // number of values in B +#endif + ptrba += temp*vl*2; + ptrbb += temp*2*2; +#endif + +#ifdef LEFT + off += vl; // number of values in A +#endif + C0 += vl * 2; + C1 += vl * 2; + } + +#if defined(TRMMKERNEL) && !defined(LEFT) + off += 2; +#endif + bb += (bk << 2); + C += (ldc << 2); + } + + if (bn & 1) + { + C0 = C; +#if defined(TRMMKERNEL) && defined(LEFT) + off = offset; +#endif + ptrba = ba; + for (i = bm; i > 0; i -= vl) + { + vl = VSETVL(i); +#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + ptrbb = bb; +#else + ptrba += off*vl*2; + ptrbb = bb + off*2; +#endif + + vres0 = VFMVVF_FLOAT(0.0, vl); + vres1 = VFMVVF_FLOAT(0.0, vl); + +#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) + temp = bk-off; +#elif defined(LEFT) + temp = off+vl; // number of values in A +#else + temp = off+1; // number of values in B +#endif + for (k = temp/4; k > 0; k--) + { + VLSEG2_FLOAT(&va0, &va1, ptrba, vl); + ptrba += vl*2; + + VLSEG2_FLOAT(&va2, &va3, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, 
*(ptrbb + 0), va0, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va1, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va1, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va0, vl); + + ptrbb += 2; + + VLSEG2_FLOAT(&va4, &va5, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va2, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va3, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va3, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va2, vl); + + ptrbb += 2; + + VLSEG2_FLOAT(&va6, &va7, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va4, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va5, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va5, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va4, vl); + + ptrbb += 2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va6, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va7, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va7, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va6, vl); + + ptrbb += 2; + } + + for (k = temp & 3; k > 0; k--) + { + VLSEG2_FLOAT(&va0, &va1, ptrba, vl); + ptrba += vl*2; + + vres0 = OP_rr(vres0, *(ptrbb + 0), va0, vl); + vres1 = OP_ir(vres1, *(ptrbb + 0), va1, vl); + vres0 = OP_ii(vres0, *(ptrbb + 1), va1, vl); + vres1 = OP_ri(vres1, *(ptrbb + 1), va0, vl); + + ptrbb += 2; + } + + va0 = VFMULVF_FLOAT(vres0, alphar, vl); + va1 = VFMULVF_FLOAT(vres1, alphar, vl); + va0 = VFNMSACVF_FLOAT(va0, alphai, vres1, vl); + va1 = VFMACCVF_FLOAT(va1, alphai, vres0, vl); + VSSEG2_FLOAT(C0, va0, va1, vl); + +#if ( defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) + temp = bk - off; +#ifdef LEFT + temp -= vl; // number of values in A +#else + temp -= 1; // number of values in B +#endif + ptrba += temp*vl*2; + ptrbb += temp*2; +#endif + +#ifdef LEFT + off += vl; // number of values in A +#endif + C0 += vl * 2; + } + +#if defined(TRMMKERNEL) && !defined(LEFT) + off += 1; +#endif + bb += bk << 1; + C += ldc << 1; + } + return 0; +} diff --git a/kernel/riscv64/ztrsm_lncopy_rvv_v1.c b/kernel/riscv64/ztrsm_lncopy_rvv_v1.c new file mode 100644 index 000000000..b7ccb1eb3 --- /dev/null +++ b/kernel/riscv64/ztrsm_lncopy_rvv_v1.c @@ -0,0 +1,115 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define FLOAT_V_T vfloat32m2_t +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define VSSEG2_FLOAT_M vsseg2e32_v_f32m2_m +#define VBOOL_T vbool16_t +#define UINT_V_T vuint32m2_t +#define VID_V_UINT vid_v_u32m2 +#define VMSLTU_VX_UINT vmsltu_vx_u32m2_b16 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define FLOAT_V_T vfloat64m2_t +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VSSEG2_FLOAT_M vsseg2e64_v_f64m2_m +#define VBOOL_T vbool32_t +#define UINT_V_T vuint64m2_t +#define VID_V_UINT vid_v_u64m2 +#define VMSLTU_VX_UINT vmsltu_vx_u64m2_b32 + +#endif + + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ + + //fprintf(stderr, "%s , %s, m = %4ld n = %4ld lda = %4ld offset = %4ld\n", __FILE__, __FUNCTION__, m, n, lda, offset); // Debug + + BLASLONG i, ii, jj, js; + + FLOAT *ao; + + jj = offset; + + BLASLONG stride_lda = sizeof(FLOAT)*lda*2; + + FLOAT_V_T va0, va1; + VBOOL_T vbool_cmp; + UINT_V_T vindex; + size_t vl; + + for (js = n; js > 0; js -= vl) + { + vl = VSETVL(js); + ao = a; + + ii = 0; + for (i = 0; i < m;) + { + if (ii == jj) + { + vindex = VID_V_UINT(vl); + for (unsigned int j = 0; j < vl; j++) + { + VLSSEG2_FLOAT(&va0, &va1, ao, stride_lda, vl); + vbool_cmp = VMSLTU_VX_UINT(vindex, j, vl); + VSSEG2_FLOAT_M(vbool_cmp, b, va0, va1, vl); + + compinv((b + j * 2), *(ao + j * lda * 2), *(ao + j * lda * 2 + 1)); + ao += 2; + b += vl * 2; + } + i += vl; + ii += vl; + } + else + { + if (ii > jj) + { + VLSSEG2_FLOAT(&va0, &va1, ao, stride_lda, vl); + VSSEG2_FLOAT(b, va0, va1, vl); + } + ao += 2; + b += vl * 2; + i++; + ii++; + } + } + + a += vl * lda * 2; + jj += vl; + } + + return 0; +} diff --git a/kernel/riscv64/ztrsm_ltcopy_rvv_v1.c b/kernel/riscv64/ztrsm_ltcopy_rvv_v1.c new file mode 100644 index 000000000..911b81de5 --- /dev/null +++ b/kernel/riscv64/ztrsm_ltcopy_rvv_v1.c @@ -0,0 +1,114 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. 
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define FLOAT_V_T vfloat32m2_t +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define VSSEG2_FLOAT_M vsseg2e32_v_f32m2_m +#define VBOOL_T vbool16_t +#define UINT_V_T vuint32m2_t +#define VID_V_UINT vid_v_u32m2 +#define VMSGTU_VX_UINT vmsgtu_vx_u32m2_b16 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define FLOAT_V_T vfloat64m2_t +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VSSEG2_FLOAT_M vsseg2e64_v_f64m2_m +#define VBOOL_T vbool32_t +#define UINT_V_T vuint64m2_t +#define VID_V_UINT vid_v_u64m2 +#define VMSGTU_VX_UINT vmsgtu_vx_u64m2_b32 +#endif + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ + + //fprintf(stderr, "%s , %s, m = %4ld n = %4ld lda = %4ld offset = %4ld\n", __FILE__, __FUNCTION__, m, n, lda, offset); // Debug + + BLASLONG i, ii, jj, js; + + FLOAT *ao; + + jj = offset; + + FLOAT_V_T va0, va1; + VBOOL_T vbool_cmp; + UINT_V_T vindex; + + size_t vl; + + for (js = n; js > 0; js -= vl) + { + vl = VSETVL(js); + ao = a; + + ii = 0; + for (i = 0; i < m;) + { + + if (ii == jj) + { + vindex = VID_V_UINT(vl); + for (unsigned int j = 0; j < vl; j++) + { + compinv((b + j * 2), *(ao + j * 2), *(ao + j * 2 + 1)); + + VLSEG2_FLOAT(&va0, &va1, ao, vl); + vbool_cmp = VMSGTU_VX_UINT(vindex, j, vl); + VSSEG2_FLOAT_M(vbool_cmp, b, va0, va1, vl); + + b += vl * 2; + ao += lda * 2; + } + i += vl; + ii += vl; + } + else + { + if (ii < jj) + { + VLSEG2_FLOAT(&va0, &va1, ao, vl); + VSSEG2_FLOAT(b, va0, va1, vl); + } + ao += lda * 2; + b += vl * 2; + i ++; + ii ++; + } + } + + a += vl * 2; + jj += vl; + } + return 0; +} + diff --git a/kernel/riscv64/ztrsm_uncopy_rvv_v1.c b/kernel/riscv64/ztrsm_uncopy_rvv_v1.c new file mode 100644 index 000000000..db075c29b --- /dev/null +++ b/kernel/riscv64/ztrsm_uncopy_rvv_v1.c @@ -0,0 +1,113 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. 
Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + + +#include +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define FLOAT_V_T vfloat32m2_t +#define VLSSEG2_FLOAT vlsseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define VSSEG2_FLOAT_M vsseg2e32_v_f32m2_m +#define VBOOL_T vbool16_t +#define UINT_V_T vuint32m2_t +#define VID_V_UINT vid_v_u32m2 +#define VMSGTU_VX_UINT vmsgtu_vx_u32m2_b16 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define FLOAT_V_T vfloat64m2_t +#define VLSSEG2_FLOAT vlsseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VSSEG2_FLOAT_M vsseg2e64_v_f64m2_m +#define VBOOL_T vbool32_t +#define UINT_V_T vuint64m2_t +#define VID_V_UINT vid_v_u64m2 +#define VMSGTU_VX_UINT vmsgtu_vx_u64m2_b32 +#endif + + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ + + //fprintf(stderr, "%s , %s, m = %4ld n = %4ld lda = %4ld offset = %4ld\n", __FILE__, __FUNCTION__, m, n, lda, offset); // Debug + + BLASLONG i, ii, jj, js; + BLASLONG stride_lda = sizeof(FLOAT)*lda*2; + + FLOAT *ao; + jj = offset; + + FLOAT_V_T va0, va1; + VBOOL_T vbool_cmp; + UINT_V_T vindex; + + size_t vl; + + for (js = n; js > 0; js -= vl) + { + vl = VSETVL(js); + ao = a; + + i = 0; + ii = 0; + for (i = 0; i < m;) + { + if (ii == jj) + { + vindex = VID_V_UINT(vl); + for (unsigned int j = 0; j < vl; j++) + { + compinv((b + j * 2), *(ao + j * lda * 2), *(ao + j * lda * 2 + 1)); + VLSSEG2_FLOAT(&va0, &va1, ao, stride_lda, vl); + vbool_cmp = VMSGTU_VX_UINT(vindex, j, vl); + VSSEG2_FLOAT_M(vbool_cmp, b, va0, va1, vl); + ao += 2; + b += vl * 2; + } + i += vl; + ii += vl; + } + else + { + if (ii < jj) + { + VLSSEG2_FLOAT(&va0, &va1, ao, stride_lda, vl); + VSSEG2_FLOAT(b, va0, va1, vl); + } + ao += 2; + b += vl * 2; + i++; + ii++; + } + } + + a += vl * lda * 2; + jj += vl; + } + return 0; +} diff --git a/kernel/riscv64/ztrsm_utcopy_rvv_v1.c b/kernel/riscv64/ztrsm_utcopy_rvv_v1.c new file mode 100644 index 000000000..e121c6273 --- /dev/null +++ b/kernel/riscv64/ztrsm_utcopy_rvv_v1.c @@ -0,0 +1,115 @@ +/*************************************************************************** +Copyright (c) 2022, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. 
Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include +#include "common.h" + +#if !defined(DOUBLE) +#define VSETVL(n) vsetvl_e32m2(n) +#define FLOAT_V_T vfloat32m2_t +#define VLSEG2_FLOAT vlseg2e32_v_f32m2 +#define VSSEG2_FLOAT vsseg2e32_v_f32m2 +#define VSSEG2_FLOAT_M vsseg2e32_v_f32m2_m +#define VBOOL_T vbool16_t +#define UINT_V_T vuint32m2_t +#define VID_V_UINT vid_v_u32m2 +#define VMSLTU_VX_UINT vmsltu_vx_u32m2_b16 +#else +#define VSETVL(n) vsetvl_e64m2(n) +#define FLOAT_V_T vfloat64m2_t +#define VLSEG2_FLOAT vlseg2e64_v_f64m2 +#define VSSEG2_FLOAT vsseg2e64_v_f64m2 +#define VSSEG2_FLOAT_M vsseg2e64_v_f64m2_m +#define VBOOL_T vbool32_t +#define UINT_V_T vuint64m2_t +#define VID_V_UINT vid_v_u64m2 +#define VMSLTU_VX_UINT vmsltu_vx_u64m2_b32 +#endif + + +int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ + + //fprintf(stderr, "%s , %s, m = %4ld n = %4ld lda = %4ld offset = %4ld\n", __FILE__, __FUNCTION__, m, n, lda, offset); // Debug + + BLASLONG i, ii, jj, js; + + FLOAT *ao; + + jj = offset; + FLOAT_V_T va0, va1; + + VBOOL_T vbool_cmp; + UINT_V_T vindex; + + size_t vl; + + for (js = n; js > 0; js -= vl) + { + vl = VSETVL(js); + ao = a; + + ii = 0; + for (i = 0; i < m;) + { + + if (ii == jj) + { + vindex = VID_V_UINT(vl); + for (unsigned int j = 0; j < vl; j++) + { + VLSEG2_FLOAT(&va0, &va1, ao, vl); + vbool_cmp = VMSLTU_VX_UINT(vindex, j, vl); + VSSEG2_FLOAT_M(vbool_cmp, b, va0, va1, vl); + + compinv((b + j * 2), *(ao + j * 2), *(ao + j * 2 + 1)); + + ao += lda * 2; + b += vl * 2; + } + i += vl; + ii += vl; + } + else + { + if (ii > jj) + { + VLSEG2_FLOAT(&va0, &va1, ao, vl); + VSSEG2_FLOAT(b, va0, va1, vl); + } + ao += lda * 2; + b += vl * 2; + i ++; + ii ++; + } + } + + a += vl * 2; + jj += vl; + } + + return 0; +} diff --git a/param.h b/param.h index 4ece0b5a6..c5c70b78e 100644 --- a/param.h +++ b/param.h @@ -3055,11 +3055,13 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#define DGEMM_DEFAULT_UNROLL_N 8 //2 // 4 #define DGEMM_DEFAULT_UNROLL_MN 32 -#define CGEMM_DEFAULT_UNROLL_M 2 -#define CGEMM_DEFAULT_UNROLL_N 2 +#define CGEMM_DEFAULT_UNROLL_M 8 +#define CGEMM_DEFAULT_UNROLL_N 4 +#define CGEMM_DEFAULT_UNROLL_MN 16 -#define ZGEMM_DEFAULT_UNROLL_M 2 -#define ZGEMM_DEFAULT_UNROLL_N 2 +#define ZGEMM_DEFAULT_UNROLL_M 8 +#define ZGEMM_DEFAULT_UNROLL_N 4 +#define ZGEMM_DEFAULT_UNROLL_MN 16 #define SGEMM_DEFAULT_P 160 #define DGEMM_DEFAULT_P 160
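The new ztrmm_*copy_rvv_v1.c routines above all treat the block that straddles the diagonal the same way: VID_V_UINT numbers the lanes, a compare against the scalar loop counter j masks the lanes that fall strictly on the wrong side of the diagonal, VFMERGE zeroes them, and for -DUNIT builds a second merge forces 1+0i onto the diagonal lane before the segment store. The byte stride handed to the strided segment loads is sizeof(FLOAT)*lda*2 because each complex element occupies two FLOATs. Below is a minimal single-precision sketch of that idiom, assuming the same pre-1.0 non-overloaded RVV intrinsic names the patch itself uses; pack_diag_line is an invented name for illustration, not a function in the patch.

#include <stddef.h>
#include <riscv_vector.h>

/* Illustrative sketch (not part of the patch): pack one line of the block that
 * contains the diagonal, zeroing the lanes strictly past the diagonal and,
 * when unit_diag is set (the UNIT compile-time switch in the patch), forcing
 * a unit diagonal before the interleaved store into the packed buffer b.    */
static void pack_diag_line(const float *ao, float *b, ptrdiff_t stride_bytes,
                           unsigned int j, size_t vl, int unit_diag)
{
    vfloat32m2_t va0, va1;                         /* de-interleaved real / imaginary parts */
    vuint32m2_t  vindex = vid_v_u32m2(vl);         /* lane indices 0 .. vl-1                */

    /* strided segment load: vl complex elements, stride_bytes apart, split re/im */
    vlsseg2e32_v_f32m2(&va0, &va1, ao, stride_bytes, vl);

    /* lanes strictly beyond the diagonal become 0 + 0i */
    vbool16_t cut = vmsgtu_vx_u32m2_b16(vindex, j, vl);
    va0 = vfmerge_vfm_f32m2(cut, va0, 0.0f, vl);
    va1 = vfmerge_vfm_f32m2(cut, va1, 0.0f, vl);

    if (unit_diag) {                               /* UNIT: diagonal lane becomes 1 + 0i    */
        vbool16_t diag = vmseq_vx_u32m2_b16(vindex, j, vl);
        va0 = vfmerge_vfm_f32m2(diag, va0, 1.0f, vl);
        va1 = vfmerge_vfm_f32m2(diag, va1, 0.0f, vl);
    }

    vsseg2e32_v_f32m2(b, va0, va1, vl);            /* re-interleave and store the packed line */
}

Whether the comparison is vmsgtu or vmsltu, and whether ao advances along the row or down the column, is what distinguishes the ln/lt/un/ut variants; the masking pattern itself is shared.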
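ztrmmkernel_rvv_v1x4.c serves every NN/NT/.../CC variant from one body by binding OP_rr/OP_ir/OP_ii/OP_ri to either vfmacc or vfnmsac, so conjugation of either operand is expressed purely as the sign of the four partial products accumulated per k step. A scalar restatement of one such step for the NN-type block, in the same order as the vres0/vres1 updates in the kernel (cmadd_nn is an illustrative name, not part of the patch):

/* Scalar restatement of one k-step of the complex update in the NN/NT/TN/TT
 * case: acc += a * b, with a = ar + i*ai (one lane of the packed A vector)
 * and b = br + i*bi (a scalar pair read from the packed B panel).
 * The kernel forms the same four products vector-wide with
 *   OP_rr = vfmacc, OP_ir = vfmacc, OP_ii = vfnmsac, OP_ri = vfmacc.      */
static inline void cmadd_nn(float *acc_re, float *acc_im,
                            float ar, float ai, float br, float bi)
{
    *acc_re += br * ar;   /* OP_rr: vres_even += b_re * a_re              */
    *acc_im += br * ai;   /* OP_ir: vres_odd  += b_re * a_im              */
    *acc_re -= bi * ai;   /* OP_ii: vres_even -= b_im * a_im  (vfnmsac)   */
    *acc_im += bi * ar;   /* OP_ri: vres_odd  += b_im * a_re              */
}

The RN/RT/CN/CT block flips the sign of the two products that contain the imaginary part of the vector operand loaded from ptrba, the NR/NC/TR/TC block flips the two that contain the imaginary part of the packed B scalar, and RR/RC/CR/CC flips both, which is exactly what the #if chain at the top of the file selects.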
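After the k loop the kernel scales the accumulators by the complex alpha and writes them out with a segment store; being a TRMM kernel it overwrites C rather than accumulating into it, which is why there is no load of C0..C3 before VSSEG2_FLOAT. The per-lane scalar equivalent of that epilogue (scale_and_store is illustrative only):

/* Scalar equivalent of the store epilogue: C = alpha * acc, with
 * alpha = alphar + i*alphai. The kernel does the same per lane with
 * VFMULVF_FLOAT / VFNMSACVF_FLOAT / VFMACCVF_FLOAT before VSSEG2_FLOAT. */
static inline void scale_and_store(float *c_re, float *c_im,
                                   float acc_re, float acc_im,
                                   float alphar, float alphai)
{
    *c_re = acc_re * alphar - acc_im * alphai;
    *c_im = acc_im * alphar + acc_re * alphai;
}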
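The off/temp arithmetic in the same kernel is the usual OpenBLAS TRMM bookkeeping specialised to a vector-length-dependent M block: off tracks the distance to the triangular boundary, and temp is the K extent the inner-product loop actually has to cover before ptrba and ptrbb are advanced past the untouched remainder. A compact restatement of that #if ladder for one block of vl rows by unroll_n columns; the helper name and the boolean parameters are illustrative, since in the kernel these are LEFT/TRANSA compile-time branches:

/* Illustrative restatement of the TRMM "temp" computation for one block of
 * vl rows by unroll_n columns (unroll_n is 4, 2 or 1 in the three column
 * loops of ztrmmkernel_rvv_v1x4.c).                                        */
static long trmm_k_extent(long bk, long off, long vl, long unroll_n,
                          int left, int transa)
{
    if ((left && !transa) || (!left && transa))
        return bk - off;          /* remaining part of K                      */
    else if (left)
        return off + vl;          /* number of values in A (LEFT && TRANSA)   */
    else
        return off + unroll_n;    /* number of values in B (!LEFT && !TRANSA) */
}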
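The ztrsm_*copy_rvv_v1.c routines pack the panels consumed by the trsm_kernel_*_rvv_v1.c solvers. Off-diagonal blocks are copied verbatim with segment loads and stores; on the block containing the diagonal, a masked segment store (VSSEG2_FLOAT_M) writes only the lanes on the already-solved side, while the diagonal element itself is written through compinv(). compinv comes from OpenBLAS' shared sources and is not defined in this patch; as with the generic ztrsm copy kernels it is presumably storing the inverse of the complex diagonal entry so the solve step can multiply instead of divide, but that is an assumption here, not something the patch states. A single-precision sketch of the masked part only (pack_trsm_diag_line is an invented name):

#include <stddef.h>
#include <riscv_vector.h>

/* Illustrative only: masked packing of one line of the diagonal block in the
 * ztrsm copy routines. Lanes selected by the mask are stored as-is; the
 * diagonal slot is filled separately via compinv(), not reproduced here.   */
static void pack_trsm_diag_line(const float *ao, float *b, ptrdiff_t stride_bytes,
                                unsigned int j, size_t vl)
{
    vfloat32m2_t va0, va1;
    vuint32m2_t  vindex = vid_v_u32m2(vl);                     /* lane indices 0 .. vl-1        */
    vbool16_t    keep   = vmsltu_vx_u32m2_b16(vindex, j, vl);  /* lanes strictly before the diagonal */

    vlsseg2e32_v_f32m2(&va0, &va1, ao, stride_bytes, vl);      /* de-interleaved complex load   */
    vsseg2e32_v_f32m2_m(keep, b, va0, va1, vl);                /* store only the kept lanes     */
    /* compinv(b + 2*j, diag_re, diag_im) would then fill the diagonal slot.                    */
}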