Merge pull request #4814 from Mousius/gemv-proxy
Forward GEMM to GEMV when one argument is actually a vector
This commit is contained in:
commit
9afd0c8afd
|
@ -274,9 +274,18 @@ endif
|
|||
ifeq ($(ARCH), loongarch64)
|
||||
SMALL_MATRIX_OPT = 1
|
||||
endif
|
||||
ifeq ($(ARCH), arm64)
|
||||
GEMM_GEMV_FORWARD = 1
|
||||
endif
|
||||
|
||||
ifeq ($(SMALL_MATRIX_OPT), 1)
|
||||
CCOMMON_OPT += -DSMALL_MATRIX_OPT
|
||||
endif
|
||||
ifeq ($(GEMM_GEMV_FORWARD), 1)
|
||||
ifneq ($(ONLY_CBLAS), 1)
|
||||
CCOMMON_OPT += -DGEMM_GEMV_FORWARD
|
||||
endif
|
||||
endif
|
||||
|
||||
# This operation is expensive, so execution should be once.
|
||||
ifndef GOTOBLAS_MAKEFILE
|
||||
|
|
|
@ -391,6 +391,13 @@ endif ()
|
|||
if (X86_64 OR ${CORE} STREQUAL POWER10)
|
||||
set(SMALL_MATRIX_OPT TRUE)
|
||||
endif ()
|
||||
if (ARM64)
|
||||
set(GEMM_GEMV_FORWARD TRUE)
|
||||
endif ()
|
||||
|
||||
if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS)
|
||||
set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD")
|
||||
endif ()
|
||||
if (SMALL_MATRIX_OPT)
|
||||
set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT")
|
||||
endif ()
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
/*********************************************************************/
|
||||
/* Copyright 2024 The OpenBLAS Project */
|
||||
/* Copyright 2009, 2010 The University of Texas at Austin. */
|
||||
/* All rights reserved. */
|
||||
/* */
|
||||
|
@ -47,12 +48,16 @@
|
|||
#define SMP_THRESHOLD_MIN 65536.0
|
||||
#ifdef XDOUBLE
|
||||
#define ERROR_NAME "QGEMM "
|
||||
#define GEMV BLASFUNC(qgemv)
|
||||
#elif defined(DOUBLE)
|
||||
#define ERROR_NAME "DGEMM "
|
||||
#define GEMV BLASFUNC(dgemv)
|
||||
#elif defined(BFLOAT16)
|
||||
#define ERROR_NAME "SBGEMM "
|
||||
#define GEMV BLASFUNC(sbgemv)
|
||||
#else
|
||||
#define ERROR_NAME "SGEMM "
|
||||
#define GEMV BLASFUNC(sgemv)
|
||||
#endif
|
||||
#else
|
||||
#define SMP_THRESHOLD_MIN 8192.0
|
||||
|
@ -493,6 +498,52 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
|
|||
args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
|
||||
#endif
|
||||
|
||||
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX)
|
||||
// Check if we can convert GEMM -> GEMV
|
||||
if (args.k != 0) {
|
||||
if (args.n == 1) {
|
||||
blasint inc_x = 1;
|
||||
blasint inc_y = 1;
|
||||
// These were passed in as blasint, but the struct translates them to blaslong
|
||||
blasint m = args.m;
|
||||
blasint n = args.k;
|
||||
blasint lda = args.lda;
|
||||
// Create new transpose parameters
|
||||
char NT = 'N';
|
||||
if (transa & 1) {
|
||||
NT = 'T';
|
||||
m = args.k;
|
||||
n = args.m;
|
||||
}
|
||||
if (transb & 1) {
|
||||
inc_x = args.ldb;
|
||||
}
|
||||
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
|
||||
return;
|
||||
}
|
||||
if (args.m == 1) {
|
||||
blasint inc_x = args.lda;
|
||||
blasint inc_y = args.ldc;
|
||||
// These were passed in as blasint, but the struct translates them to blaslong
|
||||
blasint m = args.k;
|
||||
blasint n = args.n;
|
||||
blasint ldb = args.ldb;
|
||||
// Create new transpose parameters
|
||||
char NT = 'T';
|
||||
if (transa & 1) {
|
||||
inc_x = 1;
|
||||
}
|
||||
if (transb & 1) {
|
||||
NT = 'N';
|
||||
m = args.n;
|
||||
n = args.k;
|
||||
}
|
||||
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
IDEBUG_START;
|
||||
|
||||
FUNCTION_PROFILE_START();
|
||||
|
|
|
@ -1 +1,4 @@
|
|||
include $(KERNELDIR)/KERNEL.ARMV8SVE
|
||||
|
||||
SGEMVTKERNEL = gemv_t_sve.c
|
||||
DGEMVTKERNEL = gemv_t_sve.c
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
Copyright (c) 2015, The OpenBLAS Project
|
||||
Copyright (c) 2015, 2024 The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
|
@ -170,39 +170,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
|
||||
.macro KERNEL_F32_FINALIZE
|
||||
#if !defined(DOUBLE)
|
||||
fadd v1.4s, v1.4s, v2.4s
|
||||
// F8 only has 2 accumulators
|
||||
// so add into those pairs
|
||||
fadd v1.4s, v1.4s, v3.4s
|
||||
fadd v1.4s, v1.4s, v4.4s
|
||||
#else
|
||||
fadd v1.2d, v1.2d, v2.2d
|
||||
fadd v1.2d, v1.2d, v3.2d
|
||||
fadd v1.2d, v1.2d, v4.2d
|
||||
fadd v2.4s, v2.4s, v4.4s
|
||||
#endif
|
||||
.endm
|
||||
|
||||
.macro KERNEL_F4
|
||||
.macro KERNEL_F8
|
||||
#if !defined(DOUBLE)
|
||||
ld1 {v2.4s}, [A_PTR], #16
|
||||
ld1 {v3.4s}, [X_PTR], #16
|
||||
fmla v1.4s, v2.4s, v3.4s
|
||||
#else
|
||||
ld1 {v2.2d}, [A_PTR], #16
|
||||
ld1 {v3.2d}, [X_PTR], #16
|
||||
fmla v1.2d, v2.2d, v3.2d
|
||||
|
||||
ld1 {v4.2d}, [A_PTR], #16
|
||||
ld1 {v5.2d}, [X_PTR], #16
|
||||
fmla v1.2d, v4.2d, v5.2d
|
||||
ld1 {v13.4s, v14.4s}, [A_PTR], #32
|
||||
ld1 {v17.4s, v18.4s}, [X_PTR], #32
|
||||
fmla v1.4s, v13.4s, v17.4s
|
||||
fmla v2.4s, v14.4s, v18.4s
|
||||
#else
|
||||
ld1 {v13.2d, v14.2d, v15.2d, v16.2d}, [A_PTR], #64
|
||||
ld1 {v17.2d, v18.2d, v19.2d, v20.2d}, [X_PTR], #64
|
||||
fmla v1.2d, v13.2d, v17.2d
|
||||
fmla v2.2d, v14.2d, v18.2d
|
||||
fmla v3.2d, v15.2d, v19.2d
|
||||
fmla v4.2d, v16.2d, v20.2d
|
||||
#endif
|
||||
.endm
|
||||
|
||||
.macro KERNEL_F4_FINALIZE
|
||||
.macro KERNEL_F8_FINALIZE
|
||||
#if !defined(DOUBLE)
|
||||
ext v2.16b, v1.16b, v1.16b, #8
|
||||
// Take the top two elements of v1 and
|
||||
// put them into the first two lanes of v3
|
||||
ext v3.16b, v1.16b, v1.16b, #8
|
||||
fadd v1.2s, v1.2s, v3.2s
|
||||
ext v4.16b, v2.16b, v2.16b, #8
|
||||
fadd v2.2s, v2.2s, v4.2s
|
||||
// Final pair
|
||||
fadd v1.2s, v1.2s, v2.2s
|
||||
faddp TEMP, v1.2s
|
||||
#else
|
||||
faddp TEMP, v1.2d
|
||||
faddp TEMP1, v2.2d
|
||||
faddp TEMP2, v3.2d
|
||||
faddp TEMP3, v4.2d
|
||||
fadd TEMP, TEMP, TEMP1
|
||||
fadd TEMP2, TEMP2, TEMP3
|
||||
fadd TEMP, TEMP, TEMP2
|
||||
#endif
|
||||
.endm
|
||||
|
||||
|
@ -258,7 +267,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
|
||||
asr I, M, #5
|
||||
cmp I, xzr
|
||||
beq .Lgemv_t_kernel_F4
|
||||
beq .Lgemv_t_kernel_F8
|
||||
|
||||
.Lgemv_t_kernel_F320:
|
||||
|
||||
|
@ -269,24 +278,24 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
|
||||
KERNEL_F32_FINALIZE
|
||||
|
||||
.Lgemv_t_kernel_F4:
|
||||
.Lgemv_t_kernel_F8:
|
||||
ands I, M, #31
|
||||
asr I, I, #2
|
||||
asr I, I, #3
|
||||
cmp I, xzr
|
||||
beq .Lgemv_t_kernel_F1
|
||||
|
||||
.Lgemv_t_kernel_F40:
|
||||
.Lgemv_t_kernel_F80:
|
||||
|
||||
KERNEL_F4
|
||||
KERNEL_F8
|
||||
|
||||
subs I, I, #1
|
||||
bne .Lgemv_t_kernel_F40
|
||||
bne .Lgemv_t_kernel_F80
|
||||
|
||||
.Lgemv_t_kernel_F1:
|
||||
|
||||
KERNEL_F4_FINALIZE
|
||||
KERNEL_F8_FINALIZE
|
||||
|
||||
ands I, M, #3
|
||||
ands I, M, #7
|
||||
ble .Lgemv_t_kernel_F_END
|
||||
|
||||
.Lgemv_t_kernel_F10:
|
||||
|
|
|
@ -59,20 +59,46 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
|
|||
a_ptr = a;
|
||||
|
||||
if (inc_x == 1) {
|
||||
svbool_t pg_true = SV_TRUE();
|
||||
uint64_t sve_size = SV_COUNT();
|
||||
uint64_t sve_size2 = sve_size * 2;
|
||||
BLASLONG m1 = m & -sve_size;
|
||||
BLASLONG m2 = m & -sve_size2;
|
||||
|
||||
for (j = 0; j < n; j++) {
|
||||
BLASLONG i = 0;
|
||||
|
||||
SV_TYPE temp_vec_v2_0 = SV_DUP(0.0);
|
||||
SV_TYPE temp_vec_v2_1 = SV_DUP(0.0);
|
||||
for (; i < m2; i += sve_size2) {
|
||||
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
|
||||
SV_TYPE x_vec0 = svld1(pg_true, x + i);
|
||||
SV_TYPE a_vec1 = svld1(pg_true, a_ptr + i + sve_size);
|
||||
SV_TYPE x_vec1 = svld1(pg_true, x + i + sve_size);
|
||||
temp_vec_v2_0 = svmla_m(pg_true, temp_vec_v2_0, a_vec0, x_vec0);
|
||||
temp_vec_v2_1 = svmla_m(pg_true, temp_vec_v2_1, a_vec1, x_vec1);
|
||||
}
|
||||
|
||||
SV_TYPE temp_vec_v1 = SV_DUP(0.0);
|
||||
for (; i < m1; i += sve_size) {
|
||||
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
|
||||
SV_TYPE x_vec0 = svld1(pg_true, x + i);
|
||||
temp_vec_v1 = svmla_m(pg_true, temp_vec_v1, a_vec0, x_vec0);
|
||||
}
|
||||
|
||||
SV_TYPE temp_vec = SV_DUP(0.0);
|
||||
i = 0;
|
||||
svbool_t pg = SV_WHILE(i, m);
|
||||
while (svptest_any(SV_TRUE(), pg)) {
|
||||
for (; i < m; i += sve_size) {
|
||||
svbool_t pg = SV_WHILE(i, m);
|
||||
SV_TYPE a_vec = svld1(pg, a_ptr + i);
|
||||
SV_TYPE x_vec = svld1(pg, x + i);
|
||||
temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec);
|
||||
i += sve_size;
|
||||
pg = SV_WHILE(i, m);
|
||||
}
|
||||
temp = svaddv(SV_TRUE(), temp_vec);
|
||||
y[iy] += alpha * temp;
|
||||
|
||||
y[iy] += alpha * (
|
||||
(svaddv(SV_TRUE(), temp_vec_v2_0) + svaddv(SV_TRUE(), temp_vec)) +
|
||||
(svaddv(SV_TRUE(), temp_vec_v2_1) + svaddv(SV_TRUE(), temp_vec_v1))
|
||||
);
|
||||
|
||||
iy += inc_y;
|
||||
a_ptr += lda;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue