Optimize GEMV forwarding on ARM64 systems
This commit is contained in:
parent
72461f1c8c
commit
cb48505251
|
@ -226,3 +226,6 @@ In chronological order:
|
||||||
|
|
||||||
* Dirreke <https://github.com/mseminatore>
|
* Dirreke <https://github.com/mseminatore>
|
||||||
* [2024-01-16] Add basic support for the CSKY architecture
|
* [2024-01-16] Add basic support for the CSKY architecture
|
||||||
|
|
||||||
|
* Christopher Daley <https://github.com/cdaley>
|
||||||
|
* [2024-01-24] Optimize GEMV forwarding on ARM64 systems
|
||||||
|
|
|
@ -39,6 +39,7 @@
|
||||||
|
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#ifdef FUNCTION_PROFILE
|
#ifdef FUNCTION_PROFILE
|
||||||
#include "functable.h"
|
#include "functable.h"
|
||||||
|
@ -499,6 +500,15 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16))
|
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16))
|
||||||
|
#if defined(ARCH_ARM64)
|
||||||
|
// The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c}
|
||||||
|
// perform poorly in certain circumstances. We use the following boolean
|
||||||
|
// variable along with the gemv argument values to avoid these inefficient
|
||||||
|
// gemv cases; see GitHub issue #4951.
|
||||||
|
bool have_tuned_gemv = false;
|
||||||
|
#else
|
||||||
|
bool have_tuned_gemv = true;
|
||||||
|
#endif
|
||||||
// Check if we can convert GEMM -> GEMV
|
// Check if we can convert GEMM -> GEMV
|
||||||
if (args.k != 0) {
|
if (args.k != 0) {
|
||||||
if (args.n == 1) {
|
if (args.n == 1) {
|
||||||
|
@ -518,8 +528,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
|
||||||
if (transb & 1) {
|
if (transb & 1) {
|
||||||
inc_x = args.ldb;
|
inc_x = args.ldb;
|
||||||
}
|
}
|
||||||
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
|
bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N') || (NT == 'T' && inc_x == 1));
|
||||||
return;
|
if (is_efficient_gemv) {
|
||||||
|
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (args.m == 1) {
|
if (args.m == 1) {
|
||||||
blasint inc_x = args.lda;
|
blasint inc_x = args.lda;
|
||||||
|
@ -538,8 +551,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
|
||||||
m = args.n;
|
m = args.n;
|
||||||
n = args.k;
|
n = args.k;
|
||||||
}
|
}
|
||||||
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
|
bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N' && inc_y == 1) || (NT == 'T' && inc_x == 1));
|
||||||
return;
|
if (is_efficient_gemv) {
|
||||||
|
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
Loading…
Reference in New Issue