[WIP] Refactor the driver code for direct SGEMM (#2782)
Move "direct SGEMM" functionality out of the SkylakeX SGEMM kernel and make it available (on x86_64 targets only for now) in DYNAMIC_ARCH builds * Add sgemm_direct targets in the kernel Makefile.L3 and CMakeLists.txt * Add direct_sgemm functions to the gotoblas struct in common_param.h * Move sgemm_direct_performant helper to a separate file * Update gemm.c to use macros for sgemm_direct to support dynamic_arch naming via common_s.h * (Conditionally) add sgemm_direct functions in setparam-ref.c
This commit is contained in:
parent
2c72972570
commit
75eeb265d7
|
@ -47,12 +47,12 @@ __global__ void cuda_dgemm_kernel(int, int, int, double *, double *, double *);
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
extern void sgemm_kernel_direct(BLASLONG M, BLASLONG N, BLASLONG K,
|
||||
void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,
|
||||
float * A, BLASLONG strideA,
|
||||
float * B, BLASLONG strideB,
|
||||
float * R, BLASLONG strideR);
|
||||
|
||||
extern int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
|
||||
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
|
||||
|
||||
|
||||
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
|
||||
|
|
|
@ -175,6 +175,11 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
|
|||
int (*ssymv_L) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
|
||||
int (*ssymv_U) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
|
||||
|
||||
#ifdef ARCH_X86_64
|
||||
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
|
||||
int (*sgemm_direct_performant) (BLASLONG M, BLASLONG N, BLASLONG K);
|
||||
#endif
|
||||
|
||||
int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
|
||||
int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
|
||||
|
||||
|
|
12
common_s.h
12
common_s.h
|
@ -45,6 +45,10 @@
|
|||
#define SSYMV_THREAD_U ssymv_thread_U
|
||||
#define SSYMV_THREAD_L ssymv_thread_L
|
||||
|
||||
|
||||
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
|
||||
#define SGEMM_DIRECT sgemm_direct
|
||||
|
||||
#define SGEMM_ONCOPY sgemm_oncopy
|
||||
#define SGEMM_OTCOPY sgemm_otcopy
|
||||
|
||||
|
@ -204,6 +208,14 @@
|
|||
#define SSYMV_THREAD_U ssymv_thread_U
|
||||
#define SSYMV_THREAD_L ssymv_thread_L
|
||||
|
||||
#ifdef ARCH_X86_64
|
||||
#define SGEMM_DIRECT_PERFORMANT gotoblas -> sgemm_direct_performant
|
||||
#define SGEMM_DIRECT gotoblas -> sgemm_direct
|
||||
#else
|
||||
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
|
||||
#define SGEMM_DIRECT sgemm_direct
|
||||
#endif
|
||||
|
||||
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy
|
||||
#define SGEMM_OTCOPY gotoblas -> sgemm_otcopy
|
||||
#define SGEMM_INCOPY gotoblas -> sgemm_incopy
|
||||
|
|
|
@ -275,8 +275,8 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
|
|||
#ifdef DYNAMIC_ARCH
|
||||
if (support_avx512() )
|
||||
#endif
|
||||
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && sgemm_kernel_direct_performant(m,n,k)) {
|
||||
sgemm_kernel_direct(m, n, k, a, lda, b, ldb, c, ldc);
|
||||
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) {
|
||||
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -134,6 +134,20 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
|
|||
set(USE_TRMM true)
|
||||
endif ()
|
||||
|
||||
set(USE_DIRECT_SGEMM false)
|
||||
if (X86_64)
|
||||
set(USE_DIRECT_SGEMM true)
|
||||
endif()
|
||||
|
||||
if (USE_DIRECT_SGEMM)
|
||||
# if (NOT DEFINED SGEMMDIRECTKERNEL)
|
||||
set (SGEMMDIRECTKERNEL sgemm_direct_skylakex.c)
|
||||
set (SGEMMDIRECTPERFORMANT sgemm_direct_performant.c)
|
||||
# endif()
|
||||
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE)
|
||||
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE)
|
||||
endif()
|
||||
|
||||
foreach (float_type SINGLE DOUBLE HALF)
|
||||
string(SUBSTRING ${float_type} 0 1 float_char)
|
||||
if (${float_type} STREQUAL "HALF")
|
||||
|
|
|
@ -9,6 +9,10 @@ ifeq ($(ARCH), x86_64)
|
|||
USE_GEMM3M = 1
|
||||
endif
|
||||
|
||||
ifeq ($(ARCH), x86_64)
|
||||
USE_DIRECT_SGEMM = 1
|
||||
endif
|
||||
|
||||
ifeq ($(ARCH), ia64)
|
||||
USE_GEMM3M = 1
|
||||
endif
|
||||
|
@ -65,6 +69,13 @@ ifeq ($(CORE), Z14)
|
|||
USE_TRMM = 1
|
||||
endif
|
||||
|
||||
ifdef USE_DIRECT_SGEMM
|
||||
ifndef SGEMMDIRECTKERNEL
|
||||
SGEMMDIRECTKERNEL = sgemm_direct_skylakex.c
|
||||
SGEMMDIRECTPERFORMANT = sgemm_direct_performant.c
|
||||
endif
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HALF), 1)
|
||||
ifndef SHGEMMKERNEL
|
||||
SHGEMM_BETA = ../generic/gemm_beta.c
|
||||
|
@ -90,6 +101,12 @@ SKERNELOBJS += \
|
|||
$(SGEMMINCOPYOBJ) $(SGEMMITCOPYOBJ) \
|
||||
$(SGEMMONCOPYOBJ) $(SGEMMOTCOPYOBJ)
|
||||
|
||||
ifdef USE_DIRECT_SGEMM
|
||||
SKERNELOBJS += \
|
||||
sgemm_direct$(TSUFFIX).$(SUFFIX) \
|
||||
sgemm_direct_performant$(TSUFFIX).$(SUFFIX)
|
||||
endif
|
||||
|
||||
DKERNELOBJS += \
|
||||
dgemm_kernel$(TSUFFIX).$(SUFFIX) \
|
||||
$(DGEMMINCOPYOBJ) $(DGEMMITCOPYOBJ) \
|
||||
|
@ -668,6 +685,13 @@ else
|
|||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
|
||||
endif
|
||||
|
||||
ifdef USE_DIRECT_SGEMM
|
||||
$(KDIR)sgemm_direct_performant$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTPERFORMANT)
|
||||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
|
||||
$(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL)
|
||||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_HALF), 1)
|
||||
|
||||
$(KDIR)shgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
|
||||
|
|
|
@ -135,6 +135,11 @@ gotoblas_t TABLE_NAME = {
|
|||
sgemv_nTS, sgemv_tTS, sger_kTS,
|
||||
ssymv_LTS, ssymv_UTS,
|
||||
|
||||
#ifdef ARCH_X86_64
|
||||
sgemm_directTS,
|
||||
sgemm_direct_performantTS,
|
||||
#endif
|
||||
|
||||
sgemm_kernelTS, sgemm_betaTS,
|
||||
#if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N
|
||||
sgemm_incopyTS, sgemm_itcopyTS,
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
#include "common.h"
|
||||
/* helper for the direct sgemm code written by Arjan van der Ven */
|
||||
|
||||
|
||||
|
||||
|
||||
/*
 * Heuristic that tells the caller whether the "direct" sgemm code should be
 * used for a problem with the given dimensions.
 *
 * M, N, K: problem dimensions; only their product and the low two bits of N
 *          are inspected here.
 *
 * Returns 1 when the direct path should be taken, 0 when the caller should
 * fall back to the regular sgemm path.
 */
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K)
{
	/*
	 * Widen every factor before multiplying. Multiplying in (signed)
	 * BLASLONG first and converting afterwards would make an overflowing
	 * product undefined behaviour; unsigned arithmetic wraps instead.
	 */
	unsigned long long mnk = (unsigned long long)M
	                       * (unsigned long long)N
	                       * (unsigned long long)K;

	/* large matrices -> not performant */
	if (mnk >= 28 * 512 * 512)
		return 0;

	/*
	 * if the B matrix is not a nice multiple of 4 we get many unaligned accesses,
	 * and the regular sgemm copy/realignment of data pays off much quicker
	 */
	if ((N & 3) != 0 && (mnk >= 8 * 512 * 512))
		return 0;

#ifdef SMP
	/* if we can run multithreaded, the threading changes the base threshold */
	if (mnk > 2 * 350 * 512 && num_cpu_avail(3) > 1)
		return 0;
#endif

	return 1;
}
|
||||
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
#if defined(SKYLAKEX) || defined (COOPERLAKE)
|
||||
/* the direct sgemm code written by Arjan van der Ven */
|
||||
//#include <immintrin.h>
|
||||
|
||||
#include <immintrin.h>
|
||||
#include "common.h"
|
||||
/*
|
||||
* "Direct sgemm" code. This code operates directly on the inputs and outputs
|
||||
* of the sgemm call, avoiding the copies, memory realignments and threading,
|
||||
|
@ -38,6 +38,7 @@
|
|||
#define MATMUL_SCALAR(N,M) result##N##M += Aval##M * Bval##N;
|
||||
#define STORE_SCALAR(N,M) R[(i+M) * strideR + j + N] = result##N##M;
|
||||
|
||||
#if 0
|
||||
int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K)
|
||||
{
|
||||
unsigned long long mnk = M * N * K;
|
||||
|
@ -61,9 +62,10 @@ int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K)
|
|||
return 1;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR)
|
||||
//void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR)
|
||||
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR)
|
||||
{
|
||||
int i, j, k;
|
||||
|
||||
|
@ -465,3 +467,8 @@ void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict
|
|||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
#include "common.h"
|
||||
/*
 * Fallback definition on the #else side of this file's conditional
 * compilation. The body is empty, so a call does nothing and R is left
 * untouched.
 * NOTE(review): presumably this #else pairs with the
 * SKYLAKEX / COOPERLAKE check at the top of the file - confirm.
 */
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR)
|
||||
{}
|
||||
#endif
|
||||
|
|
|
@ -512,4 +512,4 @@ CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float * __restrict__ A, f
|
|||
return 0;
|
||||
}
|
||||
#include <immintrin.h>
|
||||
#include "sgemm_direct_skylakex.c"
|
||||
//#include "sgemm_direct_skylakex.c"
|
||||
|
|
Loading…
Reference in New Issue