diff --git a/Makefile.x86_64 b/Makefile.x86_64
index 00975b25a..65b67bba1 100644
--- a/Makefile.x86_64
+++ b/Makefile.x86_64
@@ -8,6 +8,11 @@ endif
 endif
 endif
 
+ifdef HAVE_SSE3
+CCOMMON_OPT += -msse3
+FCOMMON_OPT += -msse3
+endif
+
 ifeq ($(CORE), SKYLAKEX)
 ifndef DYNAMIC_ARCH
 ifndef NO_AVX512
diff --git a/cmake/system.cmake b/cmake/system.cmake
index 0734065df..3729f6c62 100644
--- a/cmake/system.cmake
+++ b/cmake/system.cmake
@@ -70,6 +70,9 @@ if (DEFINED TARGET)
      set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -mavx2")
    endif()
  endif()
+  if (DEFINED HAVE_SSE3)
+    set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -msse3")
+  endif()
 endif()
 
 if (DEFINED TARGET)
diff --git a/kernel/Makefile b/kernel/Makefile
index 16211218f..0f0fa5a5e 100644
--- a/kernel/Makefile
+++ b/kernel/Makefile
@@ -5,6 +5,9 @@ endif
 TOPDIR = ..
 include $(TOPDIR)/Makefile.system
 
+ifdef HAVE_SSE3
+CFLAGS += -msse3
+endif
 
 ifeq ($(C_COMPILER), GCC)
 GCCVERSIONGTEQ9 := $(shell expr `$(CC) -dumpversion | cut -f1 -d.` \>= 9)
diff --git a/kernel/arm64/KERNEL.ARMV8 b/kernel/arm64/KERNEL.ARMV8
index fe32d3137..603e47d87 100644
--- a/kernel/arm64/KERNEL.ARMV8
+++ b/kernel/arm64/KERNEL.ARMV8
@@ -97,7 +97,7 @@ CNRM2KERNEL = znrm2.S
 ZNRM2KERNEL = znrm2.S
 
 DDOTKERNEL  = dot.S
-SDOTKERNEL  = dot.S
+SDOTKERNEL  = ../generic/dot.c
 CDOTKERNEL  = zdot.S
 ZDOTKERNEL  = zdot.S
 DSDOTKERNEL = dot.S
diff --git a/kernel/arm64/KERNEL.CORTEXA53 b/kernel/arm64/KERNEL.CORTEXA53
index eba38a92e..e23133e52 100644
--- a/kernel/arm64/KERNEL.CORTEXA53
+++ b/kernel/arm64/KERNEL.CORTEXA53
@@ -97,7 +97,7 @@ CNRM2KERNEL = znrm2.S
 ZNRM2KERNEL = znrm2.S
 
 DDOTKERNEL  = dot.S
-SDOTKERNEL  = dot.S
+SDOTKERNEL  = ../generic/dot.c
 CDOTKERNEL  = zdot.S
 ZDOTKERNEL  = zdot.S
 DSDOTKERNEL = dot.S
diff --git a/kernel/arm64/KERNEL.CORTEXA57 b/kernel/arm64/KERNEL.CORTEXA57
index 04d6940d7..dcf2383a9 100644
--- a/kernel/arm64/KERNEL.CORTEXA57
+++ b/kernel/arm64/KERNEL.CORTEXA57
@@ -70,7 +70,7 @@ DCOPYKERNEL = copy.S
 CCOPYKERNEL = copy.S
 ZCOPYKERNEL = copy.S
 
-SDOTKERNEL  = dot.S
+SDOTKERNEL  = ../generic/dot.c
 DDOTKERNEL  = dot.S
 CDOTKERNEL  = zdot.S
 ZDOTKERNEL  = zdot.S
diff --git a/kernel/generic/dot.c b/kernel/generic/dot.c
index bc07bc78f..5abbb735c 100644
--- a/kernel/generic/dot.c
+++ b/kernel/generic/dot.c
@@ -27,7 +27,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include "common.h"
-
+#include "../simd/intrin.h"
 #if defined(DSDOT)
 double CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y)
 #else
@@ -47,27 +47,59 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y)
 
     if ( (inc_x == 1) && (inc_y == 1) )
     {
-
-        int n1 = n & -4;
-
-        while(i < n1)
+        int n1 = n & -4;
+#if V_SIMD && !defined(DSDOT)
+        const int vstep = v_nlanes_f32;
+        const int unrollx4 = n & (-vstep * 4);
+        const int unrollx = n & -vstep;
+        v_f32 vsum0 = v_zero_f32();
+        v_f32 vsum1 = v_zero_f32();
+        v_f32 vsum2 = v_zero_f32();
+        v_f32 vsum3 = v_zero_f32();
+        while(i < unrollx4)
+        {
+            vsum0 = v_muladd_f32(
+                v_loadu_f32(x + i), v_loadu_f32(y + i), vsum0
+            );
+            vsum1 = v_muladd_f32(
+                v_loadu_f32(x + i + vstep), v_loadu_f32(y + i + vstep), vsum1
+            );
+            vsum2 = v_muladd_f32(
+                v_loadu_f32(x + i + vstep*2), v_loadu_f32(y + i + vstep*2), vsum2
+            );
+            vsum3 = v_muladd_f32(
+                v_loadu_f32(x + i + vstep*3), v_loadu_f32(y + i + vstep*3), vsum3
+            );
+            i += vstep*4;
+        }
+        vsum0 = v_add_f32(
+            v_add_f32(vsum0, vsum1), v_add_f32(vsum2, vsum3)
+        );
+        while(i < unrollx)
+        {
+            vsum0 = v_muladd_f32(
+                v_loadu_f32(x + i), v_loadu_f32(y + i), vsum0
+            );
+            i += vstep;
+        }
+        dot = v_sum_f32(vsum0);
+#elif defined(DSDOT)
+        for (; i < n1; i += 4)
         {
-
-#if defined(DSDOT)
             dot += (double) y[i]   * (double) x[i]
                  + (double) y[i+1] * (double) x[i+1]
                  + (double) y[i+2] * (double) x[i+2]
                  + (double) y[i+3] * (double) x[i+3] ;
+        }
 #else
+        for (; i < n1; i += 4)
+        {
             dot += y[i]   * x[i]
                  + y[i+1] * x[i+1]
                  + y[i+2] * x[i+2]
                  + y[i+3] * x[i+3] ;
-#endif
-            i+=4 ;
-
         }
-
+#endif
         while(i < n)
         {
 
diff --git a/kernel/simd/intrin.h b/kernel/simd/intrin.h
index 5997bb6ac..ef8fcb865 100644
--- a/kernel/simd/intrin.h
+++ b/kernel/simd/intrin.h
@@ -51,6 +51,11 @@ extern "C" {
 #include <immintrin.h>
 #endif
 
+/** NEON **/
+#ifdef HAVE_NEON
+#include <arm_neon.h>
+#endif
+
 // distribute
 #if defined(HAVE_AVX512VL) || defined(HAVE_AVX512BF16)
 #include "intrin_avx512.h"
@@ -60,6 +65,10 @@ extern "C" {
 #include "intrin_sse.h"
 #endif
 
+#ifdef HAVE_NEON
+#include "intrin_neon.h"
+#endif
+
 #ifndef V_SIMD
 #define V_SIMD 0
 #define V_SIMD_F64 0
diff --git a/kernel/simd/intrin_avx.h b/kernel/simd/intrin_avx.h
index f6257ae98..f36a3dbf0 100644
--- a/kernel/simd/intrin_avx.h
+++ b/kernel/simd/intrin_avx.h
@@ -1,13 +1,13 @@
 #define V_SIMD 256
 #define V_SIMD_F64 1
-/*
-Data Type
-*/
+/***************************
+ * Data Type
+ ***************************/
 typedef __m256 v_f32;
 #define v_nlanes_f32 8
-/*
-arithmetic
-*/
+/***************************
+ * Arithmetic
+ ***************************/
 #define v_add_f32 _mm256_add_ps
 #define v_mul_f32 _mm256_mul_ps
 
@@ -20,10 +20,22 @@ arithmetic
     { return v_add_f32(v_mul_f32(a, b), c); }
 #endif // !HAVE_FMA3
-/*
-memory
-*/
+// Horizontal add: Calculates the sum of all vector elements.
+BLAS_FINLINE float v_sum_f32(__m256 a)
+{
+    __m256 sum_halves = _mm256_hadd_ps(a, a);
+    sum_halves = _mm256_hadd_ps(sum_halves, sum_halves);
+    __m128 lo = _mm256_castps256_ps128(sum_halves);
+    __m128 hi = _mm256_extractf128_ps(sum_halves, 1);
+    __m128 sum = _mm_add_ps(lo, hi);
+    return _mm_cvtss_f32(sum);
+}
+
+/***************************
+ * memory
+ ***************************/
 // unaligned load
 #define v_loadu_f32 _mm256_loadu_ps
 #define v_storeu_f32 _mm256_storeu_ps
-#define v_setall_f32(VAL) _mm256_set1_ps(VAL)
\ No newline at end of file
+#define v_setall_f32(VAL) _mm256_set1_ps(VAL)
+#define v_zero_f32 _mm256_setzero_ps
\ No newline at end of file
diff --git a/kernel/simd/intrin_avx512.h b/kernel/simd/intrin_avx512.h
index cb116a9a3..70e5f72e3 100644
--- a/kernel/simd/intrin_avx512.h
+++ b/kernel/simd/intrin_avx512.h
@@ -1,21 +1,35 @@
 #define V_SIMD 512
 #define V_SIMD_F64 1
-/*
-Data Type
-*/
+/***************************
+ * Data Type
+ ***************************/
 typedef __m512 v_f32;
 #define v_nlanes_f32 16
-/*
-arithmetic
-*/
+/***************************
+ * Arithmetic
+ ***************************/
 #define v_add_f32 _mm512_add_ps
 #define v_mul_f32 _mm512_mul_ps
 // multiply and add, a*b + c
 #define v_muladd_f32 _mm512_fmadd_ps
-/*
-memory
-*/
+
+BLAS_FINLINE float v_sum_f32(v_f32 a)
+{
+    __m512 h64 = _mm512_shuffle_f32x4(a, a, _MM_SHUFFLE(3, 2, 3, 2));
+    __m512 sum32 = _mm512_add_ps(a, h64);
+    __m512 h32 = _mm512_shuffle_f32x4(sum32, sum32, _MM_SHUFFLE(1, 0, 3, 2));
+    __m512 sum16 = _mm512_add_ps(sum32, h32);
+    __m512 h16 = _mm512_permute_ps(sum16, _MM_SHUFFLE(1, 0, 3, 2));
+    __m512 sum8 = _mm512_add_ps(sum16, h16);
+    __m512 h4 = _mm512_permute_ps(sum8, _MM_SHUFFLE(2, 3, 0, 1));
+    __m512 sum4 = _mm512_add_ps(sum8, h4);
+    return _mm_cvtss_f32(_mm512_castps512_ps128(sum4));
+}
+/***************************
+ * memory
+ ***************************/
 // unaligned load
 #define v_loadu_f32(PTR) _mm512_loadu_ps((const __m512*)(PTR))
 #define v_storeu_f32 _mm512_storeu_ps
 #define v_setall_f32(VAL) _mm512_set1_ps(VAL)
+#define v_zero_f32 _mm512_setzero_ps
diff --git a/kernel/simd/intrin_neon.h b/kernel/simd/intrin_neon.h
new file mode 100644
index 000000000..5875c0e4e
--- /dev/null
+++ b/kernel/simd/intrin_neon.h
@@ -0,0 +1,42 @@
+#define V_SIMD 128
+#ifdef __aarch64__
+    #define V_SIMD_F64 1
+#else
+    #define V_SIMD_F64 0
+#endif
+/***************************
+ * Data Type
+ ***************************/
+typedef float32x4_t v_f32;
+#define v_nlanes_f32 4
+/***************************
+ * Arithmetic
+ ***************************/
+#define v_add_f32 vaddq_f32
+#define v_mul_f32 vmulq_f32
+
+// FUSED F32
+#ifdef HAVE_VFPV4 // FMA
+    // multiply and add, a*b + c
+    BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
+    { return vfmaq_f32(c, a, b); }
+#else
+    // multiply and add, a*b + c
+    BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
+    { return vmlaq_f32(c, a, b); }
+#endif
+
+// Horizontal add: Calculates the sum of all vector elements.
+BLAS_FINLINE float v_sum_f32(float32x4_t a)
+{
+    float32x2_t r = vadd_f32(vget_high_f32(a), vget_low_f32(a));
+    return vget_lane_f32(vpadd_f32(r, r), 0);
+}
+/***************************
+ * memory
+ ***************************/
+// unaligned load
+#define v_loadu_f32(a) vld1q_f32((const float*)a)
+#define v_storeu_f32 vst1q_f32
+#define v_setall_f32(VAL) vdupq_n_f32(VAL)
+#define v_zero_f32() vdupq_n_f32(0.0f)
\ No newline at end of file
diff --git a/kernel/simd/intrin_sse.h b/kernel/simd/intrin_sse.h
index 260112028..9de7e1b27 100644
--- a/kernel/simd/intrin_sse.h
+++ b/kernel/simd/intrin_sse.h
@@ -1,13 +1,13 @@
 #define V_SIMD 128
 #define V_SIMD_F64 1
-/*
-Data Type
-*/
+/***************************
+ * Data Type
+ ***************************/
 typedef __m128 v_f32;
 #define v_nlanes_f32 4
-/*
-arithmetic
-*/
+/***************************
+ * Arithmetic
+ ***************************/
 #define v_add_f32 _mm_add_ps
 #define v_mul_f32 _mm_mul_ps
 #ifdef HAVE_FMA3
@@ -21,10 +21,26 @@ arithmetic
     BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
     { return v_add_f32(v_mul_f32(a, b), c); }
 #endif // HAVE_FMA3
-/*
-memory
-*/
+
+// Horizontal add: Calculates the sum of all vector elements.
+BLAS_FINLINE float v_sum_f32(__m128 a)
+{
+#ifdef HAVE_SSE3
+    __m128 sum_halves = _mm_hadd_ps(a, a);
+    return _mm_cvtss_f32(_mm_hadd_ps(sum_halves, sum_halves));
+#else
+    __m128 t1 = _mm_movehl_ps(a, a);
+    __m128 t2 = _mm_add_ps(a, t1);
+    __m128 t3 = _mm_shuffle_ps(t2, t2, 1);
+    __m128 t4 = _mm_add_ss(t2, t3);
+    return _mm_cvtss_f32(t4);
+#endif
+}
+/***************************
+ * memory
+ ***************************/
 // unaligned load
 #define v_loadu_f32 _mm_loadu_ps
 #define v_storeu_f32 _mm_storeu_ps
-#define v_setall_f32(VAL) _mm_set1_ps(VAL)
\ No newline at end of file
+#define v_setall_f32(VAL) _mm_set1_ps(VAL)
+#define v_zero_f32 _mm_setzero_ps
\ No newline at end of file
diff --git a/utest/test_dsdot.c b/utest/test_dsdot.c
index d58b398a8..57da7101e 100644
--- a/utest/test_dsdot.c
+++ b/utest/test_dsdot.c
@@ -47,3 +47,17 @@ CTEST(dsdot,dsdot_n_1)
 
     ASSERT_DBL_NEAR_TOL(res2, res1, DOUBLE_EPS);
 }
+
+CTEST(dsdot,dsdot_n_2)
+{
+    float x[] = {0.1F, 0.2F, 0.3F, 0.4F, 0.5F, 0.6F, 0.7F, 0.8F};
+    float y[] = {0.1F, 0.2F, 0.3F, 0.4F, 0.5F, 0.6F, 0.7F, 0.8F};
+    blasint incx = 1;
+    blasint incy = 1;
+    blasint n = 8;
+
+    double res1 = 0.0, res2 = 2.0400000444054616;
+
+    res1 = BLASFUNC(dsdot)(&n, x, &incx, y, &incy);
+    ASSERT_DBL_NEAR_TOL(res2, res1, DOUBLE_EPS);
+}
\ No newline at end of file
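
For reference, the vectorized path added to kernel/generic/dot.c above follows a standard pattern: four independent accumulators (vsum0..vsum3) keep consecutive v_muladd_f32 operations from depending on each other, a pairwise v_add_f32 merges them, the new v_sum_f32 horizontal add collapses the final vector, and the existing scalar while(i < n) loop handles the remainder. The standalone C sketch below models that structure with plain scalars; the function name dot_unrolled, the array sizes, and the main() driver are illustrative only and are not part of the patch.

#include <stdio.h>

/* Scalar model of the vectorized loop: s0..s3 play the role of vsum0..vsum3,
 * and the final (s0 + s1) + (s2 + s3) mirrors v_add_f32 followed by v_sum_f32. */
static float dot_unrolled(const float *x, const float *y, int n)
{
    float s0 = 0.0f, s1 = 0.0f, s2 = 0.0f, s3 = 0.0f;
    int i = 0;
    int n1 = n & -4;                    /* largest multiple of 4 <= n, as in the kernel */
    for (; i < n1; i += 4) {
        s0 += x[i]     * y[i];
        s1 += x[i + 1] * y[i + 1];
        s2 += x[i + 2] * y[i + 2];
        s3 += x[i + 3] * y[i + 3];
    }
    float dot = (s0 + s1) + (s2 + s3);  /* horizontal reduction */
    for (; i < n; i++)                  /* scalar tail for the remainder */
        dot += x[i] * y[i];
    return dot;
}

int main(void)
{
    float x[11], y[11];
    for (int i = 0; i < 11; i++) {
        x[i] = 0.1f * (float)(i + 1);
        y[i] = x[i];
    }
    /* For n = 11 the unrolled loop covers 8 elements and the tail covers 3. */
    printf("dot = %f\n", dot_unrolled(x, y, 11));
    return 0;
}

Because the accumulation order differs from a strict left-to-right sum, the result can differ from the naive loop in the last few bits; the patch accepts this for the single-precision sdot path, while the dsdot path (#elif defined(DSDOT)) keeps the scalar double-precision accumulation and is excluded from the V_SIMD branch.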