Merge pull request #4814 from Mousius/gemv-proxy

Forward GEMM to GEMV when one argument is actually a vector
This commit is contained in:
Martin Kroeker 2024-07-31 23:18:01 +02:00 committed by GitHub
commit 9afd0c8afd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 141 additions and 36 deletions

View File

@ -274,9 +274,18 @@ endif
ifeq ($(ARCH), loongarch64)
SMALL_MATRIX_OPT = 1
endif
ifeq ($(ARCH), arm64)
GEMM_GEMV_FORWARD = 1
endif
ifeq ($(SMALL_MATRIX_OPT), 1)
CCOMMON_OPT += -DSMALL_MATRIX_OPT
endif
ifeq ($(GEMM_GEMV_FORWARD), 1)
ifneq ($(ONLY_CBLAS), 1)
CCOMMON_OPT += -DGEMM_GEMV_FORWARD
endif
endif
# This operation is expensive, so execution should be once.
ifndef GOTOBLAS_MAKEFILE

View File

@ -391,6 +391,13 @@ endif ()
if (X86_64 OR ${CORE} STREQUAL POWER10)
set(SMALL_MATRIX_OPT TRUE)
endif ()
if (ARM64)
set(GEMM_GEMV_FORWARD TRUE)
endif ()
if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS)
set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD")
endif ()
if (SMALL_MATRIX_OPT)
set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT")
endif ()

View File

@ -1,4 +1,5 @@
/*********************************************************************/
/* Copyright 2024 The OpenBLAS Project */
/* Copyright 2009, 2010 The University of Texas at Austin. */
/* All rights reserved. */
/* */
@ -47,12 +48,16 @@
#define SMP_THRESHOLD_MIN 65536.0
#ifdef XDOUBLE
#define ERROR_NAME "QGEMM "
#define GEMV BLASFUNC(qgemv)
#elif defined(DOUBLE)
#define ERROR_NAME "DGEMM "
#define GEMV BLASFUNC(dgemv)
#elif defined(BFLOAT16)
#define ERROR_NAME "SBGEMM "
#define GEMV BLASFUNC(sbgemv)
#else
#define ERROR_NAME "SGEMM "
#define GEMV BLASFUNC(sgemv)
#endif
#else
#define SMP_THRESHOLD_MIN 8192.0
@ -493,6 +498,52 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
#endif
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX)
// Check if we can convert GEMM -> GEMV
if (args.k != 0) {
if (args.n == 1) {
blasint inc_x = 1;
blasint inc_y = 1;
// These were passed in as blasint, but the struct translates them to blaslong
blasint m = args.m;
blasint n = args.k;
blasint lda = args.lda;
// Create new transpose parameters
char NT = 'N';
if (transa & 1) {
NT = 'T';
m = args.k;
n = args.m;
}
if (transb & 1) {
inc_x = args.ldb;
}
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
return;
}
if (args.m == 1) {
blasint inc_x = args.lda;
blasint inc_y = args.ldc;
// These were passed in as blasint, but the struct translates them to blaslong
blasint m = args.k;
blasint n = args.n;
blasint ldb = args.ldb;
// Create new transpose parameters
char NT = 'T';
if (transa & 1) {
inc_x = 1;
}
if (transb & 1) {
NT = 'N';
m = args.n;
n = args.k;
}
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
return;
}
}
#endif
IDEBUG_START;
FUNCTION_PROFILE_START();

View File

@ -1 +1,4 @@
include $(KERNELDIR)/KERNEL.ARMV8SVE
SGEMVTKERNEL = gemv_t_sve.c
DGEMVTKERNEL = gemv_t_sve.c

View File

@ -1,5 +1,5 @@
/*******************************************************************************
Copyright (c) 2015, The OpenBLAS Project
Copyright (c) 2015, 2024 The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@ -170,39 +170,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
.macro KERNEL_F32_FINALIZE
#if !defined(DOUBLE)
fadd v1.4s, v1.4s, v2.4s
// F8 only has 2 accumulators
// so add into those pairs
fadd v1.4s, v1.4s, v3.4s
fadd v1.4s, v1.4s, v4.4s
#else
fadd v1.2d, v1.2d, v2.2d
fadd v1.2d, v1.2d, v3.2d
fadd v1.2d, v1.2d, v4.2d
fadd v2.4s, v2.4s, v4.4s
#endif
.endm
.macro KERNEL_F4
.macro KERNEL_F8
#if !defined(DOUBLE)
ld1 {v2.4s}, [A_PTR], #16
ld1 {v3.4s}, [X_PTR], #16
fmla v1.4s, v2.4s, v3.4s
#else
ld1 {v2.2d}, [A_PTR], #16
ld1 {v3.2d}, [X_PTR], #16
fmla v1.2d, v2.2d, v3.2d
ld1 {v4.2d}, [A_PTR], #16
ld1 {v5.2d}, [X_PTR], #16
fmla v1.2d, v4.2d, v5.2d
ld1 {v13.4s, v14.4s}, [A_PTR], #32
ld1 {v17.4s, v18.4s}, [X_PTR], #32
fmla v1.4s, v13.4s, v17.4s
fmla v2.4s, v14.4s, v18.4s
#else
ld1 {v13.2d, v14.2d, v15.2d, v16.2d}, [A_PTR], #64
ld1 {v17.2d, v18.2d, v19.2d, v20.2d}, [X_PTR], #64
fmla v1.2d, v13.2d, v17.2d
fmla v2.2d, v14.2d, v18.2d
fmla v3.2d, v15.2d, v19.2d
fmla v4.2d, v16.2d, v20.2d
#endif
.endm
.macro KERNEL_F4_FINALIZE
.macro KERNEL_F8_FINALIZE
#if !defined(DOUBLE)
ext v2.16b, v1.16b, v1.16b, #8
// Take the top two elements of v1 and
// put them into the first two lanes of v3
ext v3.16b, v1.16b, v1.16b, #8
fadd v1.2s, v1.2s, v3.2s
ext v4.16b, v2.16b, v2.16b, #8
fadd v2.2s, v2.2s, v4.2s
// Final pair
fadd v1.2s, v1.2s, v2.2s
faddp TEMP, v1.2s
#else
faddp TEMP, v1.2d
faddp TEMP1, v2.2d
faddp TEMP2, v3.2d
faddp TEMP3, v4.2d
fadd TEMP, TEMP, TEMP1
fadd TEMP2, TEMP2, TEMP3
fadd TEMP, TEMP, TEMP2
#endif
.endm
@ -258,7 +267,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
asr I, M, #5
cmp I, xzr
beq .Lgemv_t_kernel_F4
beq .Lgemv_t_kernel_F8
.Lgemv_t_kernel_F320:
@ -269,24 +278,24 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
KERNEL_F32_FINALIZE
.Lgemv_t_kernel_F4:
.Lgemv_t_kernel_F8:
ands I, M, #31
asr I, I, #2
asr I, I, #3
cmp I, xzr
beq .Lgemv_t_kernel_F1
.Lgemv_t_kernel_F40:
.Lgemv_t_kernel_F80:
KERNEL_F4
KERNEL_F8
subs I, I, #1
bne .Lgemv_t_kernel_F40
bne .Lgemv_t_kernel_F80
.Lgemv_t_kernel_F1:
KERNEL_F4_FINALIZE
KERNEL_F8_FINALIZE
ands I, M, #3
ands I, M, #7
ble .Lgemv_t_kernel_F_END
.Lgemv_t_kernel_F10:

View File

@ -59,20 +59,46 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
a_ptr = a;
if (inc_x == 1) {
svbool_t pg_true = SV_TRUE();
uint64_t sve_size = SV_COUNT();
uint64_t sve_size2 = sve_size * 2;
BLASLONG m1 = m & -sve_size;
BLASLONG m2 = m & -sve_size2;
for (j = 0; j < n; j++) {
BLASLONG i = 0;
SV_TYPE temp_vec_v2_0 = SV_DUP(0.0);
SV_TYPE temp_vec_v2_1 = SV_DUP(0.0);
for (; i < m2; i += sve_size2) {
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
SV_TYPE x_vec0 = svld1(pg_true, x + i);
SV_TYPE a_vec1 = svld1(pg_true, a_ptr + i + sve_size);
SV_TYPE x_vec1 = svld1(pg_true, x + i + sve_size);
temp_vec_v2_0 = svmla_m(pg_true, temp_vec_v2_0, a_vec0, x_vec0);
temp_vec_v2_1 = svmla_m(pg_true, temp_vec_v2_1, a_vec1, x_vec1);
}
SV_TYPE temp_vec_v1 = SV_DUP(0.0);
for (; i < m1; i += sve_size) {
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
SV_TYPE x_vec0 = svld1(pg_true, x + i);
temp_vec_v1 = svmla_m(pg_true, temp_vec_v1, a_vec0, x_vec0);
}
SV_TYPE temp_vec = SV_DUP(0.0);
i = 0;
svbool_t pg = SV_WHILE(i, m);
while (svptest_any(SV_TRUE(), pg)) {
for (; i < m; i += sve_size) {
svbool_t pg = SV_WHILE(i, m);
SV_TYPE a_vec = svld1(pg, a_ptr + i);
SV_TYPE x_vec = svld1(pg, x + i);
temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec);
i += sve_size;
pg = SV_WHILE(i, m);
}
temp = svaddv(SV_TRUE(), temp_vec);
y[iy] += alpha * temp;
y[iy] += alpha * (
(svaddv(SV_TRUE(), temp_vec_v2_0) + svaddv(SV_TRUE(), temp_vec)) +
(svaddv(SV_TRUE(), temp_vec_v2_1) + svaddv(SV_TRUE(), temp_vec_v1))
);
iy += inc_y;
a_ptr += lda;
}