Merge pull request #2867 from Qiyu8/usimd-floatdot
Optimize the performance of dot by using universal intrinsics in X86/ARM
This commit is contained in:
commit
e1b7123bbe
|
@ -8,6 +8,11 @@ endif
|
|||
endif
|
||||
endif
|
||||
|
||||
ifdef HAVE_SSE3
|
||||
CCOMMON_OPT += -msse3
|
||||
FCOMMON_OPT += -msse3
|
||||
endif
|
||||
|
||||
ifeq ($(CORE), SKYLAKEX)
|
||||
ifndef DYNAMIC_ARCH
|
||||
ifndef NO_AVX512
|
||||
|
|
|
@ -70,6 +70,9 @@ if (DEFINED TARGET)
|
|||
set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -mavx2")
|
||||
endif()
|
||||
endif()
|
||||
if (DEFINED HAVE_SSE3)
|
||||
set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -msse3")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (DEFINED TARGET)
|
||||
|
|
|
@ -5,6 +5,9 @@ endif
|
|||
TOPDIR = ..
|
||||
include $(TOPDIR)/Makefile.system
|
||||
|
||||
ifdef HAVE_SSE3
|
||||
CFLAGS += -msse3
|
||||
endif
|
||||
|
||||
ifeq ($(C_COMPILER), GCC)
|
||||
GCCVERSIONGTEQ9 := $(shell expr `$(CC) -dumpversion | cut -f1 -d.` \>= 9)
|
||||
|
|
|
@ -97,7 +97,7 @@ CNRM2KERNEL = znrm2.S
|
|||
ZNRM2KERNEL = znrm2.S
|
||||
|
||||
DDOTKERNEL = dot.S
|
||||
SDOTKERNEL = dot.S
|
||||
SDOTKERNEL = ../generic/dot.c
|
||||
CDOTKERNEL = zdot.S
|
||||
ZDOTKERNEL = zdot.S
|
||||
DSDOTKERNEL = dot.S
|
||||
|
|
|
@ -97,7 +97,7 @@ CNRM2KERNEL = znrm2.S
|
|||
ZNRM2KERNEL = znrm2.S
|
||||
|
||||
DDOTKERNEL = dot.S
|
||||
SDOTKERNEL = dot.S
|
||||
SDOTKERNEL = ../generic/dot.c
|
||||
CDOTKERNEL = zdot.S
|
||||
ZDOTKERNEL = zdot.S
|
||||
DSDOTKERNEL = dot.S
|
||||
|
|
|
@ -70,7 +70,7 @@ DCOPYKERNEL = copy.S
|
|||
CCOPYKERNEL = copy.S
|
||||
ZCOPYKERNEL = copy.S
|
||||
|
||||
SDOTKERNEL = dot.S
|
||||
SDOTKERNEL = ../generic/dot.c
|
||||
DDOTKERNEL = dot.S
|
||||
CDOTKERNEL = zdot.S
|
||||
ZDOTKERNEL = zdot.S
|
||||
|
|
|
@ -27,7 +27,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#include "../simd/intrin.h"
|
||||
#if defined(DSDOT)
|
||||
double CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y)
|
||||
#else
|
||||
|
@ -47,27 +47,59 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y)
|
|||
|
||||
if ( (inc_x == 1) && (inc_y == 1) )
|
||||
{
|
||||
|
||||
int n1 = n & -4;
|
||||
|
||||
while(i < n1)
|
||||
int n1 = n & -4;
|
||||
#if V_SIMD && !defined(DSDOT)
|
||||
const int vstep = v_nlanes_f32;
|
||||
const int unrollx4 = n & (-vstep * 4);
|
||||
const int unrollx = n & -vstep;
|
||||
v_f32 vsum0 = v_zero_f32();
|
||||
v_f32 vsum1 = v_zero_f32();
|
||||
v_f32 vsum2 = v_zero_f32();
|
||||
v_f32 vsum3 = v_zero_f32();
|
||||
while(i < unrollx4)
|
||||
{
|
||||
vsum0 = v_muladd_f32(
|
||||
v_loadu_f32(x + i), v_loadu_f32(y + i), vsum0
|
||||
);
|
||||
vsum1 = v_muladd_f32(
|
||||
v_loadu_f32(x + i + vstep), v_loadu_f32(y + i + vstep), vsum1
|
||||
);
|
||||
vsum2 = v_muladd_f32(
|
||||
v_loadu_f32(x + i + vstep*2), v_loadu_f32(y + i + vstep*2), vsum2
|
||||
);
|
||||
vsum3 = v_muladd_f32(
|
||||
v_loadu_f32(x + i + vstep*3), v_loadu_f32(y + i + vstep*3), vsum3
|
||||
);
|
||||
i += vstep*4;
|
||||
}
|
||||
vsum0 = v_add_f32(
|
||||
v_add_f32(vsum0, vsum1), v_add_f32(vsum2 , vsum3)
|
||||
);
|
||||
while(i < unrollx)
|
||||
{
|
||||
vsum0 = v_muladd_f32(
|
||||
v_loadu_f32(x + i), v_loadu_f32(y + i), vsum0
|
||||
);
|
||||
i += vstep;
|
||||
}
|
||||
dot = v_sum_f32(vsum0);
|
||||
#elif defined(DSDOT)
|
||||
for (; i < n1; i += 4)
|
||||
{
|
||||
|
||||
#if defined(DSDOT)
|
||||
dot += (double) y[i] * (double) x[i]
|
||||
+ (double) y[i+1] * (double) x[i+1]
|
||||
+ (double) y[i+2] * (double) x[i+2]
|
||||
+ (double) y[i+3] * (double) x[i+3] ;
|
||||
}
|
||||
#else
|
||||
for (; i < n1; i += 4)
|
||||
{
|
||||
dot += y[i] * x[i]
|
||||
+ y[i+1] * x[i+1]
|
||||
+ y[i+2] * x[i+2]
|
||||
+ y[i+3] * x[i+3] ;
|
||||
#endif
|
||||
i+=4 ;
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
while(i < n)
|
||||
{
|
||||
|
||||
|
|
|
@ -51,6 +51,11 @@ extern "C" {
|
|||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
/** NEON **/
|
||||
#ifdef HAVE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
// distribute
|
||||
#if defined(HAVE_AVX512VL) || defined(HAVE_AVX512BF16)
|
||||
#include "intrin_avx512.h"
|
||||
|
@ -60,6 +65,10 @@ extern "C" {
|
|||
#include "intrin_sse.h"
|
||||
#endif
|
||||
|
||||
#ifdef HAVE_NEON
|
||||
#include "intrin_neon.h"
|
||||
#endif
|
||||
|
||||
#ifndef V_SIMD
|
||||
#define V_SIMD 0
|
||||
#define V_SIMD_F64 0
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
#define V_SIMD 256
|
||||
#define V_SIMD_F64 1
|
||||
/*
|
||||
Data Type
|
||||
*/
|
||||
/***************************
|
||||
* Data Type
|
||||
***************************/
|
||||
typedef __m256 v_f32;
|
||||
#define v_nlanes_f32 8
|
||||
/*
|
||||
arithmetic
|
||||
*/
|
||||
/***************************
|
||||
* Arithmetic
|
||||
***************************/
|
||||
#define v_add_f32 _mm256_add_ps
|
||||
#define v_mul_f32 _mm256_mul_ps
|
||||
|
||||
|
@ -20,10 +20,22 @@ arithmetic
|
|||
{ return v_add_f32(v_mul_f32(a, b), c); }
|
||||
#endif // !HAVE_FMA3
|
||||
|
||||
/*
|
||||
memory
|
||||
*/
|
||||
// Horizontal add: Calculates the sum of all vector elements.
|
||||
BLAS_FINLINE float v_sum_f32(__m256 a)
|
||||
{
|
||||
__m256 sum_halves = _mm256_hadd_ps(a, a);
|
||||
sum_halves = _mm256_hadd_ps(sum_halves, sum_halves);
|
||||
__m128 lo = _mm256_castps256_ps128(sum_halves);
|
||||
__m128 hi = _mm256_extractf128_ps(sum_halves, 1);
|
||||
__m128 sum = _mm_add_ps(lo, hi);
|
||||
return _mm_cvtss_f32(sum);
|
||||
}
|
||||
|
||||
/***************************
|
||||
* memory
|
||||
***************************/
|
||||
// unaligned load
|
||||
#define v_loadu_f32 _mm256_loadu_ps
|
||||
#define v_storeu_f32 _mm256_storeu_ps
|
||||
#define v_setall_f32(VAL) _mm256_set1_ps(VAL)
|
||||
#define v_setall_f32(VAL) _mm256_set1_ps(VAL)
|
||||
#define v_zero_f32 _mm256_setzero_ps
|
|
@ -1,21 +1,35 @@
|
|||
#define V_SIMD 512
|
||||
#define V_SIMD_F64 1
|
||||
/*
|
||||
Data Type
|
||||
*/
|
||||
/***************************
|
||||
* Data Type
|
||||
***************************/
|
||||
typedef __m512 v_f32;
|
||||
#define v_nlanes_f32 16
|
||||
/*
|
||||
arithmetic
|
||||
*/
|
||||
/***************************
|
||||
* Arithmetic
|
||||
***************************/
|
||||
#define v_add_f32 _mm512_add_ps
|
||||
#define v_mul_f32 _mm512_mul_ps
|
||||
// multiply and add, a*b + c
|
||||
#define v_muladd_f32 _mm512_fmadd_ps
|
||||
/*
|
||||
memory
|
||||
*/
|
||||
|
||||
BLAS_FINLINE float v_sum_f32(v_f32 a)
|
||||
{
|
||||
__m512 h64 = _mm512_shuffle_f32x4(a, a, _MM_SHUFFLE(3, 2, 3, 2));
|
||||
__m512 sum32 = _mm512_add_ps(a, h64);
|
||||
__m512 h32 = _mm512_shuffle_f32x4(sum32, sum32, _MM_SHUFFLE(1, 0, 3, 2));
|
||||
__m512 sum16 = _mm512_add_ps(sum32, h32);
|
||||
__m512 h16 = _mm512_permute_ps(sum16, _MM_SHUFFLE(1, 0, 3, 2));
|
||||
__m512 sum8 = _mm512_add_ps(sum16, h16);
|
||||
__m512 h4 = _mm512_permute_ps(sum8, _MM_SHUFFLE(2, 3, 0, 1));
|
||||
__m512 sum4 = _mm512_add_ps(sum8, h4);
|
||||
return _mm_cvtss_f32(_mm512_castps512_ps128(sum4));
|
||||
}
|
||||
/***************************
|
||||
* memory
|
||||
***************************/
|
||||
// unaligned load
|
||||
#define v_loadu_f32(PTR) _mm512_loadu_ps((const __m512*)(PTR))
|
||||
#define v_storeu_f32 _mm512_storeu_ps
|
||||
#define v_setall_f32(VAL) _mm512_set1_ps(VAL)
|
||||
#define v_zero_f32 _mm512_setzero_ps
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
#define V_SIMD 128
|
||||
#ifdef __aarch64__
|
||||
#define V_SIMD_F64 1
|
||||
#else
|
||||
#define V_SIMD_F64 0
|
||||
#endif
|
||||
/***************************
|
||||
* Data Type
|
||||
***************************/
|
||||
typedef float32x4_t v_f32;
|
||||
#define v_nlanes_f32 4
|
||||
/***************************
|
||||
* Arithmetic
|
||||
***************************/
|
||||
#define v_add_f32 vaddq_f32
|
||||
#define v_mul_f32 vmulq_f32
|
||||
|
||||
// FUSED F32
|
||||
#ifdef HAVE_VFPV4 // FMA
|
||||
// multiply and add, a*b + c
|
||||
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
|
||||
{ return vfmaq_f32(c, a, b); }
|
||||
#else
|
||||
// multiply and add, a*b + c
|
||||
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
|
||||
{ return vmlaq_f32(c, a, b); }
|
||||
#endif
|
||||
|
||||
// Horizontal add: Calculates the sum of all vector elements.
|
||||
BLAS_FINLINE float v_sum_f32(float32x4_t a)
|
||||
{
|
||||
float32x2_t r = vadd_f32(vget_high_f32(a), vget_low_f32(a));
|
||||
return vget_lane_f32(vpadd_f32(r, r), 0);
|
||||
}
|
||||
/***************************
|
||||
* memory
|
||||
***************************/
|
||||
// unaligned load
|
||||
#define v_loadu_f32(a) vld1q_f32((const float*)a)
|
||||
#define v_storeu_f32 vst1q_f32
|
||||
#define v_setall_f32(VAL) vdupq_n_f32(VAL)
|
||||
#define v_zero_f32() vdupq_n_f32(0.0f)
|
|
@ -1,13 +1,13 @@
|
|||
#define V_SIMD 128
|
||||
#define V_SIMD_F64 1
|
||||
/*
|
||||
Data Type
|
||||
*/
|
||||
/***************************
|
||||
* Data Type
|
||||
***************************/
|
||||
typedef __m128 v_f32;
|
||||
#define v_nlanes_f32 4
|
||||
/*
|
||||
arithmetic
|
||||
*/
|
||||
/***************************
|
||||
* Arithmetic
|
||||
***************************/
|
||||
#define v_add_f32 _mm_add_ps
|
||||
#define v_mul_f32 _mm_mul_ps
|
||||
#ifdef HAVE_FMA3
|
||||
|
@ -21,10 +21,26 @@ arithmetic
|
|||
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
|
||||
{ return v_add_f32(v_mul_f32(a, b), c); }
|
||||
#endif // HAVE_FMA3
|
||||
/*
|
||||
memory
|
||||
*/
|
||||
|
||||
// Horizontal add: Calculates the sum of all vector elements.
|
||||
BLAS_FINLINE float v_sum_f32(__m128 a)
|
||||
{
|
||||
#ifdef HAVE_SSE3
|
||||
__m128 sum_halves = _mm_hadd_ps(a, a);
|
||||
return _mm_cvtss_f32(_mm_hadd_ps(sum_halves, sum_halves));
|
||||
#else
|
||||
__m128 t1 = _mm_movehl_ps(a, a);
|
||||
__m128 t2 = _mm_add_ps(a, t1);
|
||||
__m128 t3 = _mm_shuffle_ps(t2, t2, 1);
|
||||
__m128 t4 = _mm_add_ss(t2, t3);
|
||||
return _mm_cvtss_f32(t4);
|
||||
#endif
|
||||
}
|
||||
/***************************
|
||||
* memory
|
||||
***************************/
|
||||
// unaligned load
|
||||
#define v_loadu_f32 _mm_loadu_ps
|
||||
#define v_storeu_f32 _mm_storeu_ps
|
||||
#define v_setall_f32(VAL) _mm_set1_ps(VAL)
|
||||
#define v_setall_f32(VAL) _mm_set1_ps(VAL)
|
||||
#define v_zero_f32 _mm_setzero_ps
|
|
@ -47,3 +47,17 @@ CTEST(dsdot,dsdot_n_1)
|
|||
ASSERT_DBL_NEAR_TOL(res2, res1, DOUBLE_EPS);
|
||||
|
||||
}
|
||||
|
||||
CTEST(dsdot,dsdot_n_2)
|
||||
{
|
||||
float x[] = {0.1F, 0.2F, 0.3F, 0.4F, 0.5F, 0.6F, 0.7F, 0.8F};
|
||||
float y[] = {0.1F, 0.2F, 0.3F, 0.4F, 0.5F, 0.6F, 0.7F, 0.8F};
|
||||
blasint incx=1;
|
||||
blasint incy=1;
|
||||
blasint n=8;
|
||||
|
||||
double res1=0.0f, res2= 2.0400000444054616;
|
||||
|
||||
res1=BLASFUNC(dsdot)(&n, &x, &incx, &y, &incy);
|
||||
ASSERT_DBL_NEAR_TOL(res2, res1, DOUBLE_EPS);
|
||||
}
|
Loading…
Reference in New Issue