Merge pull request #2897 from Qiyu8/usimd-double
Add double precision universal intrinsics for X86/ARM
This commit is contained in:
commit
9b9ee92d5f
|
@ -43,6 +43,26 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
|
||||||
if (inc_x == 1)
|
if (inc_x == 1)
|
||||||
{
|
{
|
||||||
#if V_SIMD
|
#if V_SIMD
|
||||||
|
#ifdef DOUBLE
|
||||||
|
const int vstep = v_nlanes_f64;
|
||||||
|
const int unrollx2 = n & (-vstep * 2);
|
||||||
|
const int unrollx = n & -vstep;
|
||||||
|
v_f64 vsum0 = v_zero_f64();
|
||||||
|
v_f64 vsum1 = v_zero_f64();
|
||||||
|
while (i < unrollx2)
|
||||||
|
{
|
||||||
|
vsum0 = v_add_f64(vsum0, v_loadu_f64(x));
|
||||||
|
vsum1 = v_add_f64(vsum1, v_loadu_f64(x + vstep));
|
||||||
|
i += vstep * 2;
|
||||||
|
}
|
||||||
|
vsum0 = v_add_f64(vsum0, vsum1);
|
||||||
|
while (i < unrollx)
|
||||||
|
{
|
||||||
|
vsum0 = v_add_f64(vsum0, v_loadu_f64(x + i));
|
||||||
|
i += vstep;
|
||||||
|
}
|
||||||
|
sumf = v_sum_f64(vsum0);
|
||||||
|
#else
|
||||||
const int vstep = v_nlanes_f32;
|
const int vstep = v_nlanes_f32;
|
||||||
const int unrollx4 = n & (-vstep * 4);
|
const int unrollx4 = n & (-vstep * 4);
|
||||||
const int unrollx = n & -vstep;
|
const int unrollx = n & -vstep;
|
||||||
|
@ -66,6 +86,7 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
|
||||||
i += vstep;
|
i += vstep;
|
||||||
}
|
}
|
||||||
sumf = v_sum_f32(vsum0);
|
sumf = v_sum_f32(vsum0);
|
||||||
|
#endif
|
||||||
#else
|
#else
|
||||||
int n1 = n & -4;
|
int n1 = n & -4;
|
||||||
for (; i < n1; i += 4)
|
for (; i < n1; i += 4)
|
||||||
|
|
|
@ -4,20 +4,27 @@
|
||||||
* Data Type
|
* Data Type
|
||||||
***************************/
|
***************************/
|
||||||
typedef __m256 v_f32;
|
typedef __m256 v_f32;
|
||||||
|
typedef __m256d v_f64;
|
||||||
#define v_nlanes_f32 8
|
#define v_nlanes_f32 8
|
||||||
|
#define v_nlanes_f64 4
|
||||||
/***************************
|
/***************************
|
||||||
* Arithmetic
|
* Arithmetic
|
||||||
***************************/
|
***************************/
|
||||||
#define v_add_f32 _mm256_add_ps
|
#define v_add_f32 _mm256_add_ps
|
||||||
|
#define v_add_f64 _mm256_add_pd
|
||||||
#define v_mul_f32 _mm256_mul_ps
|
#define v_mul_f32 _mm256_mul_ps
|
||||||
|
#define v_mul_f64 _mm256_mul_pd
|
||||||
|
|
||||||
#ifdef HAVE_FMA3
|
#ifdef HAVE_FMA3
|
||||||
// multiply and add, a*b + c
|
// multiply and add, a*b + c
|
||||||
#define v_muladd_f32 _mm256_fmadd_ps
|
#define v_muladd_f32 _mm256_fmadd_ps
|
||||||
|
#define v_muladd_f64 _mm256_fmadd_pd
|
||||||
#else
|
#else
|
||||||
// multiply and add, a*b + c
|
// multiply and add, a*b + c
|
||||||
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
|
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
|
||||||
{ return v_add_f32(v_mul_f32(a, b), c); }
|
{ return v_add_f32(v_mul_f32(a, b), c); }
|
||||||
|
BLAS_FINLINE v_f64 v_muladd_f64(v_f64 a, v_f64 b, v_f64 c)
|
||||||
|
{ return v_add_f64(v_mul_f64(a, b), c); }
|
||||||
#endif // !HAVE_FMA3
|
#endif // !HAVE_FMA3
|
||||||
|
|
||||||
// Horizontal add: Calculates the sum of all vector elements.
|
// Horizontal add: Calculates the sum of all vector elements.
|
||||||
|
@ -31,11 +38,23 @@ BLAS_FINLINE float v_sum_f32(__m256 a)
|
||||||
return _mm_cvtss_f32(sum);
|
return _mm_cvtss_f32(sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BLAS_FINLINE double v_sum_f64(__m256d a)
|
||||||
|
{
|
||||||
|
__m256d sum_halves = _mm256_hadd_pd(a, a);
|
||||||
|
__m128d lo = _mm256_castpd256_pd128(sum_halves);
|
||||||
|
__m128d hi = _mm256_extractf128_pd(sum_halves, 1);
|
||||||
|
__m128d sum = _mm_add_pd(lo, hi);
|
||||||
|
return _mm_cvtsd_f64(sum);
|
||||||
|
}
|
||||||
/***************************
|
/***************************
|
||||||
* memory
|
* memory
|
||||||
***************************/
|
***************************/
|
||||||
// unaligned load
|
// unaligned load
|
||||||
#define v_loadu_f32 _mm256_loadu_ps
|
#define v_loadu_f32 _mm256_loadu_ps
|
||||||
|
#define v_loadu_f64 _mm256_loadu_pd
|
||||||
#define v_storeu_f32 _mm256_storeu_ps
|
#define v_storeu_f32 _mm256_storeu_ps
|
||||||
|
#define v_storeu_f64 _mm256_storeu_pd
|
||||||
#define v_setall_f32(VAL) _mm256_set1_ps(VAL)
|
#define v_setall_f32(VAL) _mm256_set1_ps(VAL)
|
||||||
#define v_zero_f32 _mm256_setzero_ps
|
#define v_setall_f64(VAL) _mm256_set1_pd(VAL)
|
||||||
|
#define v_zero_f32 _mm256_setzero_ps
|
||||||
|
#define v_zero_f64 _mm256_setzero_pd
|
|
@ -4,15 +4,19 @@
|
||||||
* Data Type
|
* Data Type
|
||||||
***************************/
|
***************************/
|
||||||
typedef __m512 v_f32;
|
typedef __m512 v_f32;
|
||||||
|
typedef __m512d v_f64;
|
||||||
#define v_nlanes_f32 16
|
#define v_nlanes_f32 16
|
||||||
|
#define v_nlanes_f64 8
|
||||||
/***************************
|
/***************************
|
||||||
* Arithmetic
|
* Arithmetic
|
||||||
***************************/
|
***************************/
|
||||||
#define v_add_f32 _mm512_add_ps
|
#define v_add_f32 _mm512_add_ps
|
||||||
|
#define v_add_f64 _mm512_add_pd
|
||||||
#define v_mul_f32 _mm512_mul_ps
|
#define v_mul_f32 _mm512_mul_ps
|
||||||
|
#define v_mul_f64 _mm512_mul_pd
|
||||||
// multiply and add, a*b + c
|
// multiply and add, a*b + c
|
||||||
#define v_muladd_f32 _mm512_fmadd_ps
|
#define v_muladd_f32 _mm512_fmadd_ps
|
||||||
|
#define v_muladd_f64 _mm512_fmadd_pd
|
||||||
BLAS_FINLINE float v_sum_f32(v_f32 a)
|
BLAS_FINLINE float v_sum_f32(v_f32 a)
|
||||||
{
|
{
|
||||||
__m512 h64 = _mm512_shuffle_f32x4(a, a, _MM_SHUFFLE(3, 2, 3, 2));
|
__m512 h64 = _mm512_shuffle_f32x4(a, a, _MM_SHUFFLE(3, 2, 3, 2));
|
||||||
|
@ -25,11 +29,26 @@ BLAS_FINLINE float v_sum_f32(v_f32 a)
|
||||||
__m512 sum4 = _mm512_add_ps(sum8, h4);
|
__m512 sum4 = _mm512_add_ps(sum8, h4);
|
||||||
return _mm_cvtss_f32(_mm512_castps512_ps128(sum4));
|
return _mm_cvtss_f32(_mm512_castps512_ps128(sum4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BLAS_FINLINE double v_sum_f64(v_f64 a)
|
||||||
|
{
|
||||||
|
__m512d h64 = _mm512_shuffle_f64x2(a, a, _MM_SHUFFLE(3, 2, 3, 2));
|
||||||
|
__m512d sum32 = _mm512_add_pd(a, h64);
|
||||||
|
__m512d h32 = _mm512_permutex_pd(sum32, _MM_SHUFFLE(1, 0, 3, 2));
|
||||||
|
__m512d sum16 = _mm512_add_pd(sum32, h32);
|
||||||
|
__m512d h16 = _mm512_permute_pd(sum16, _MM_SHUFFLE(2, 3, 0, 1));
|
||||||
|
__m512d sum8 = _mm512_add_pd(sum16, h16);
|
||||||
|
return _mm_cvtsd_f64(_mm512_castpd512_pd128(sum8));
|
||||||
|
}
|
||||||
/***************************
|
/***************************
|
||||||
* memory
|
* memory
|
||||||
***************************/
|
***************************/
|
||||||
// unaligned load
|
// unaligned load
|
||||||
#define v_loadu_f32(PTR) _mm512_loadu_ps((const __m512*)(PTR))
|
#define v_loadu_f32(PTR) _mm512_loadu_ps((const __m512*)(PTR))
|
||||||
|
#define v_loadu_f64(PTR) _mm512_loadu_pd((const __m512*)(PTR))
|
||||||
#define v_storeu_f32 _mm512_storeu_ps
|
#define v_storeu_f32 _mm512_storeu_ps
|
||||||
|
#define v_storeu_f64 _mm512_storeu_pd
|
||||||
#define v_setall_f32(VAL) _mm512_set1_ps(VAL)
|
#define v_setall_f32(VAL) _mm512_set1_ps(VAL)
|
||||||
|
#define v_setall_f64(VAL) _mm512_set1_pd(VAL)
|
||||||
#define v_zero_f32 _mm512_setzero_ps
|
#define v_zero_f32 _mm512_setzero_ps
|
||||||
|
#define v_zero_f64 _mm512_setzero_pd
|
||||||
|
|
|
@ -8,12 +8,18 @@
|
||||||
* Data Type
|
* Data Type
|
||||||
***************************/
|
***************************/
|
||||||
typedef float32x4_t v_f32;
|
typedef float32x4_t v_f32;
|
||||||
|
#if V_SIMD_F64
|
||||||
|
typedef float64x2_t v_f64;
|
||||||
|
#endif
|
||||||
#define v_nlanes_f32 4
|
#define v_nlanes_f32 4
|
||||||
|
#define v_nlanes_f64 2
|
||||||
/***************************
|
/***************************
|
||||||
* Arithmetic
|
* Arithmetic
|
||||||
***************************/
|
***************************/
|
||||||
#define v_add_f32 vaddq_f32
|
#define v_add_f32 vaddq_f32
|
||||||
|
#define v_add_f64 vaddq_f64
|
||||||
#define v_mul_f32 vmulq_f32
|
#define v_mul_f32 vmulq_f32
|
||||||
|
#define v_mul_f64 vmulq_f64
|
||||||
|
|
||||||
// FUSED F32
|
// FUSED F32
|
||||||
#ifdef HAVE_VFPV4 // FMA
|
#ifdef HAVE_VFPV4 // FMA
|
||||||
|
@ -26,12 +32,26 @@ typedef float32x4_t v_f32;
|
||||||
{ return vmlaq_f32(c, a, b); }
|
{ return vmlaq_f32(c, a, b); }
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// FUSED F64
|
||||||
|
#if V_SIMD_F64
|
||||||
|
BLAS_FINLINE v_f64 v_muladd_f64(v_f64 a, v_f64 b, v_f64 c)
|
||||||
|
{ return vfmaq_f64(c, a, b); }
|
||||||
|
#endif
|
||||||
|
|
||||||
// Horizontal add: Calculates the sum of all vector elements.
|
// Horizontal add: Calculates the sum of all vector elements.
|
||||||
BLAS_FINLINE float v_sum_f32(float32x4_t a)
|
BLAS_FINLINE float v_sum_f32(float32x4_t a)
|
||||||
{
|
{
|
||||||
float32x2_t r = vadd_f32(vget_high_f32(a), vget_low_f32(a));
|
float32x2_t r = vadd_f32(vget_high_f32(a), vget_low_f32(a));
|
||||||
return vget_lane_f32(vpadd_f32(r, r), 0);
|
return vget_lane_f32(vpadd_f32(r, r), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if V_SIMD_F64
|
||||||
|
BLAS_FINLINE double v_sum_f64(float64x2_t a)
|
||||||
|
{
|
||||||
|
return vget_lane_f64(vget_low_f64(a) + vget_high_f64(a), 0);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
/***************************
|
/***************************
|
||||||
* memory
|
* memory
|
||||||
***************************/
|
***************************/
|
||||||
|
@ -39,4 +59,10 @@ BLAS_FINLINE float v_sum_f32(float32x4_t a)
|
||||||
#define v_loadu_f32(a) vld1q_f32((const float*)a)
|
#define v_loadu_f32(a) vld1q_f32((const float*)a)
|
||||||
#define v_storeu_f32 vst1q_f32
|
#define v_storeu_f32 vst1q_f32
|
||||||
#define v_setall_f32(VAL) vdupq_n_f32(VAL)
|
#define v_setall_f32(VAL) vdupq_n_f32(VAL)
|
||||||
#define v_zero_f32() vdupq_n_f32(0.0f)
|
#define v_zero_f32() vdupq_n_f32(0.0f)
|
||||||
|
#if V_SIMD_F64
|
||||||
|
#define v_loadu_f64(a) vld1q_f64((const double*)a)
|
||||||
|
#define v_storeu_f64 vst1q_f64
|
||||||
|
#define v_setall_f64 vdupq_n_f64
|
||||||
|
#define v_zero_f64() vdupq_n_f64(0.0)
|
||||||
|
#endif
|
|
@ -4,22 +4,30 @@
|
||||||
* Data Type
|
* Data Type
|
||||||
***************************/
|
***************************/
|
||||||
typedef __m128 v_f32;
|
typedef __m128 v_f32;
|
||||||
|
typedef __m128d v_f64;
|
||||||
#define v_nlanes_f32 4
|
#define v_nlanes_f32 4
|
||||||
|
#define v_nlanes_f64 2
|
||||||
/***************************
|
/***************************
|
||||||
* Arithmetic
|
* Arithmetic
|
||||||
***************************/
|
***************************/
|
||||||
#define v_add_f32 _mm_add_ps
|
#define v_add_f32 _mm_add_ps
|
||||||
|
#define v_add_f64 _mm_add_pd
|
||||||
#define v_mul_f32 _mm_mul_ps
|
#define v_mul_f32 _mm_mul_ps
|
||||||
|
#define v_mul_f64 _mm_mul_pd
|
||||||
#ifdef HAVE_FMA3
|
#ifdef HAVE_FMA3
|
||||||
// multiply and add, a*b + c
|
// multiply and add, a*b + c
|
||||||
#define v_muladd_f32 _mm_fmadd_ps
|
#define v_muladd_f32 _mm_fmadd_ps
|
||||||
|
#define v_muladd_f64 _mm_fmadd_pd
|
||||||
#elif defined(HAVE_FMA4)
|
#elif defined(HAVE_FMA4)
|
||||||
// multiply and add, a*b + c
|
// multiply and add, a*b + c
|
||||||
#define v_muladd_f32 _mm_macc_ps
|
#define v_muladd_f32 _mm_macc_ps
|
||||||
|
#define v_muladd_f64 _mm_macc_pd
|
||||||
#else
|
#else
|
||||||
// multiply and add, a*b + c
|
// multiply and add, a*b + c
|
||||||
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
|
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
|
||||||
{ return v_add_f32(v_mul_f32(a, b), c); }
|
{ return v_add_f32(v_mul_f32(a, b), c); }
|
||||||
|
BLAS_FINLINE v_f64 v_muladd_f64(v_f64 a, v_f64 b, v_f64 c)
|
||||||
|
{ return v_add_f64(v_mul_f64(a, b), c); }
|
||||||
#endif // HAVE_FMA3
|
#endif // HAVE_FMA3
|
||||||
|
|
||||||
// Horizontal add: Calculates the sum of all vector elements.
|
// Horizontal add: Calculates the sum of all vector elements.
|
||||||
|
@ -36,11 +44,24 @@ BLAS_FINLINE float v_sum_f32(__m128 a)
|
||||||
return _mm_cvtss_f32(t4);
|
return _mm_cvtss_f32(t4);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BLAS_FINLINE double v_sum_f64(__m128d a)
|
||||||
|
{
|
||||||
|
#ifdef HAVE_SSE3
|
||||||
|
return _mm_cvtsd_f64(_mm_hadd_pd(a, a));
|
||||||
|
#else
|
||||||
|
return _mm_cvtsd_f64(_mm_add_pd(a, _mm_unpackhi_pd(a, a)));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
/***************************
|
/***************************
|
||||||
* memory
|
* memory
|
||||||
***************************/
|
***************************/
|
||||||
// unaligned load
|
// unaligned load
|
||||||
#define v_loadu_f32 _mm_loadu_ps
|
#define v_loadu_f32 _mm_loadu_ps
|
||||||
|
#define v_loadu_f64 _mm_loadu_pd
|
||||||
#define v_storeu_f32 _mm_storeu_ps
|
#define v_storeu_f32 _mm_storeu_ps
|
||||||
|
#define v_storeu_f64 _mm_storeu_pd
|
||||||
#define v_setall_f32(VAL) _mm_set1_ps(VAL)
|
#define v_setall_f32(VAL) _mm_set1_ps(VAL)
|
||||||
#define v_zero_f32 _mm_setzero_ps
|
#define v_setall_f64(VAL) _mm_set1_pd(VAL)
|
||||||
|
#define v_zero_f32 _mm_setzero_ps
|
||||||
|
#define v_zero_f64 _mm_setzero_pd
|
|
@ -53,6 +53,15 @@ static void daxpy_kernel_8(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *alpha)
|
||||||
BLASLONG register i = 0;
|
BLASLONG register i = 0;
|
||||||
FLOAT a = *alpha;
|
FLOAT a = *alpha;
|
||||||
#if V_SIMD
|
#if V_SIMD
|
||||||
|
#ifdef DOUBLE
|
||||||
|
v_f64 __alpha, tmp;
|
||||||
|
__alpha = v_setall_f64(*alpha);
|
||||||
|
const int vstep = v_nlanes_f64;
|
||||||
|
for (; i < n; i += vstep) {
|
||||||
|
tmp = v_muladd_f64(__alpha, v_loadu_f64( x + i ), v_loadu_f64(y + i));
|
||||||
|
v_storeu_f64(y + i, tmp);
|
||||||
|
}
|
||||||
|
#else
|
||||||
v_f32 __alpha, tmp;
|
v_f32 __alpha, tmp;
|
||||||
__alpha = v_setall_f32(*alpha);
|
__alpha = v_setall_f32(*alpha);
|
||||||
const int vstep = v_nlanes_f32;
|
const int vstep = v_nlanes_f32;
|
||||||
|
@ -60,6 +69,7 @@ static void daxpy_kernel_8(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *alpha)
|
||||||
tmp = v_muladd_f32(__alpha, v_loadu_f32( x + i ), v_loadu_f32(y + i));
|
tmp = v_muladd_f32(__alpha, v_loadu_f32( x + i ), v_loadu_f32(y + i));
|
||||||
v_storeu_f32(y + i, tmp);
|
v_storeu_f32(y + i, tmp);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
#else
|
#else
|
||||||
while(i < n)
|
while(i < n)
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in New Issue