Convert dscal_haswell to intrinsics and add AVX512 support

dscal is a relatively simple function; rewriting it with C intrinsics and adding an
AVX512 code path makes it more readable and 50% faster.
Arjan van de Ven 2018-08-05 19:19:49 +00:00
parent 93aa18b1a8
commit b1cc69e7a8
1 changed file with 37 additions and 164 deletions
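
For orientation before the diff: the two kernels touched here are not exported BLAS entry points. They are typically driven by OpenBLAS's generic dscal wrapper, which handles non-unit strides and tail elements and only hands the kernels a unit-stride vector whose length is a multiple of 8, with alpha passed by address. The sketch below is illustrative only; the wrapper name dscal_unit_stride is made up, the wrapper itself is not part of this commit, and BLASLONG/FLOAT are assumed to come from OpenBLAS's common.h.

/* Illustrative sketch, not part of this commit: how a dscal.c-style wrapper
 * would call the two kernels changed in the diff below.  Assumes unit
 * stride, BLASLONG/FLOAT from common.h, and kernels that need n % 8 == 0. */
static void dscal_unit_stride(BLASLONG n, FLOAT alpha, FLOAT *x)
{
    BLASLONG n1 = n & ~((BLASLONG)7);            /* largest multiple of 8 <= n */
    BLASLONG i  = n1;

    if (alpha == 0.0) {
        if (n1 > 0)
            dscal_kernel_8_zero(n1, &alpha, x);  /* vectorized x[] = 0 */
        for (; i < n; i++)                       /* scalar tail */
            x[i] = 0.0;
    } else {
        if (n1 > 0)
            dscal_kernel_8(n1, &alpha, x);       /* vectorized x[] *= alpha */
        for (; i < n; i++)
            x[i] *= alpha;
    }
}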


@@ -25,182 +25,55 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
-#define HAVE_KERNEL_8 1
-
-static void dscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x) __attribute__ ((noinline));
+#ifndef __AVX512CD__
+#pragma GCC target("avx2,fma")
+#endif
+
+#ifdef __AVX2__
+
+#include <immintrin.h>
+
+#define HAVE_KERNEL_8 1
 
 static void dscal_kernel_8( BLASLONG n, FLOAT *alpha, FLOAT *x)
 {
-    BLASLONG n1 = n >> 4 ;
-    BLASLONG n2 = n & 8 ;
-
-    __asm__ __volatile__
-    (
-    "vmovddup (%2), %%xmm0 \n\t" // alpha
-
-    "addq $128, %1 \n\t"
-
-    "cmpq $0, %0 \n\t"
-    "je 4f \n\t"
-
-    "vmulpd -128(%1), %%xmm0, %%xmm4 \n\t"
-    "vmulpd -112(%1), %%xmm0, %%xmm5 \n\t"
-    "vmulpd -96(%1), %%xmm0, %%xmm6 \n\t"
-    "vmulpd -80(%1), %%xmm0, %%xmm7 \n\t"
-    "vmulpd -64(%1), %%xmm0, %%xmm8 \n\t"
-    "vmulpd -48(%1), %%xmm0, %%xmm9 \n\t"
-    "vmulpd -32(%1), %%xmm0, %%xmm10 \n\t"
-    "vmulpd -16(%1), %%xmm0, %%xmm11 \n\t"
-
-    "subq $1 , %0 \n\t"
-    "jz 2f \n\t"
-
-    ".p2align 4 \n\t"
-    "1: \n\t"
-    // "prefetcht0 640(%1) \n\t"
-    "vmovups %%xmm4 ,-128(%1) \n\t"
-    "vmovups %%xmm5 ,-112(%1) \n\t"
-    "vmulpd 0(%1), %%xmm0, %%xmm4 \n\t"
-    "vmovups %%xmm6 , -96(%1) \n\t"
-    "vmulpd 16(%1), %%xmm0, %%xmm5 \n\t"
-    "vmovups %%xmm7 , -80(%1) \n\t"
-    "vmulpd 32(%1), %%xmm0, %%xmm6 \n\t"
-    // "prefetcht0 704(%1) \n\t"
-    "vmovups %%xmm8 , -64(%1) \n\t"
-    "vmulpd 48(%1), %%xmm0, %%xmm7 \n\t"
-    "vmovups %%xmm9 , -48(%1) \n\t"
-    "vmulpd 64(%1), %%xmm0, %%xmm8 \n\t"
-    "vmovups %%xmm10 , -32(%1) \n\t"
-    "vmulpd 80(%1), %%xmm0, %%xmm9 \n\t"
-    "vmovups %%xmm11 , -16(%1) \n\t"
-    "vmulpd 96(%1), %%xmm0, %%xmm10 \n\t"
-    "vmulpd 112(%1), %%xmm0, %%xmm11 \n\t"
-
-    "addq $128, %1 \n\t"
-    "subq $1 , %0 \n\t"
-    "jnz 1b \n\t"
-
-    "2: \n\t"
-
-    "vmovups %%xmm4 ,-128(%1) \n\t"
-    "vmovups %%xmm5 ,-112(%1) \n\t"
-    "vmovups %%xmm6 , -96(%1) \n\t"
-    "vmovups %%xmm7 , -80(%1) \n\t"
-    "vmovups %%xmm8 , -64(%1) \n\t"
-    "vmovups %%xmm9 , -48(%1) \n\t"
-    "vmovups %%xmm10 , -32(%1) \n\t"
-    "vmovups %%xmm11 , -16(%1) \n\t"
-
-    "addq $128, %1 \n\t"
-
-    "4: \n\t"
-    "cmpq $8 ,%3 \n\t"
-    "jne 5f \n\t"
-
-    "vmulpd -128(%1), %%xmm0, %%xmm4 \n\t"
-    "vmulpd -112(%1), %%xmm0, %%xmm5 \n\t"
-    "vmulpd -96(%1), %%xmm0, %%xmm6 \n\t"
-    "vmulpd -80(%1), %%xmm0, %%xmm7 \n\t"
-
-    "vmovups %%xmm4 ,-128(%1) \n\t"
-    "vmovups %%xmm5 ,-112(%1) \n\t"
-    "vmovups %%xmm6 , -96(%1) \n\t"
-    "vmovups %%xmm7 , -80(%1) \n\t"
-
-    "5: \n\t"
-
-    "vzeroupper \n\t"
-    :
-    :
-    "r" (n1), // 0
-    "r" (x), // 1
-    "r" (alpha), // 2
-    "r" (n2) // 3
-    : "cc",
-    "%xmm0", "%xmm1", "%xmm2", "%xmm3",
-    "%xmm4", "%xmm5", "%xmm6", "%xmm7",
-    "%xmm8", "%xmm9", "%xmm10", "%xmm11",
-    "%xmm12", "%xmm13", "%xmm14", "%xmm15",
-    "memory"
-    );
+    int i = 0;
+
+#ifdef __AVX512CD__
+    __m512d __alpha5 = _mm512_broadcastsd_pd(_mm_load_sd(alpha));
+    for (; i < n; i += 8) {
+        _mm512_storeu_pd(&x[i + 0], __alpha5 * _mm512_loadu_pd(&x[i + 0]));
+    }
+#else
+    __m256d __alpha = _mm256_broadcastsd_pd(_mm_load_sd(alpha));
+    for (; i < n; i += 8) {
+        _mm256_storeu_pd(&x[i + 0], __alpha * _mm256_loadu_pd(&x[i + 0]));
+        _mm256_storeu_pd(&x[i + 4], __alpha * _mm256_loadu_pd(&x[i + 4]));
+    }
+#endif
 }
 
-static void dscal_kernel_8_zero( BLASLONG n, FLOAT *alpha, FLOAT *x) __attribute__ ((noinline));
-
 static void dscal_kernel_8_zero( BLASLONG n, FLOAT *alpha, FLOAT *x)
 {
-    BLASLONG n1 = n >> 4 ;
-    BLASLONG n2 = n & 8 ;
-
-    __asm__ __volatile__
-    (
-    "vxorpd %%xmm0, %%xmm0 , %%xmm0 \n\t"
-
-    "addq $128, %1 \n\t"
-
-    "cmpq $0, %0 \n\t"
-    "je 2f \n\t"
-
-    ".p2align 4 \n\t"
-    "1: \n\t"
-    "vmovups %%xmm0 ,-128(%1) \n\t"
-    "vmovups %%xmm0 ,-112(%1) \n\t"
-    "vmovups %%xmm0 , -96(%1) \n\t"
-    "vmovups %%xmm0 , -80(%1) \n\t"
-    "vmovups %%xmm0 , -64(%1) \n\t"
-    "vmovups %%xmm0 , -48(%1) \n\t"
-    "vmovups %%xmm0 , -32(%1) \n\t"
-    "vmovups %%xmm0 , -16(%1) \n\t"
-
-    "addq $128, %1 \n\t"
-    "subq $1 , %0 \n\t"
-    "jnz 1b \n\t"
-
-    "2: \n\t"
-
-    "cmpq $8 ,%3 \n\t"
-    "jne 4f \n\t"
-
-    "vmovups %%xmm0 ,-128(%1) \n\t"
-    "vmovups %%xmm0 ,-112(%1) \n\t"
-    "vmovups %%xmm0 , -96(%1) \n\t"
-    "vmovups %%xmm0 , -80(%1) \n\t"
-
-    "4: \n\t"
-
-    "vzeroupper \n\t"
-    :
-    :
-    "r" (n1), // 0
-    "r" (x), // 1
-    "r" (alpha), // 2
-    "r" (n2) // 3
-    : "cc",
-    "%xmm0", "%xmm1", "%xmm2", "%xmm3",
-    "%xmm4", "%xmm5", "%xmm6", "%xmm7",
-    "%xmm8", "%xmm9", "%xmm10", "%xmm11",
-    "%xmm12", "%xmm13", "%xmm14", "%xmm15",
-    "memory"
-    );
+    int i = 0;
+
+    /* question to self: Why is this not just memset() */
+#ifdef __AVX512CD__
+    __m512d zero = _mm512_setzero_pd();
+    for (; i < n; i += 8) {
+        _mm512_storeu_pd(&x[i], zero);
+    }
+#else
+    __m256d zero = _mm256_setzero_pd();
+    for (; i < n; i += 8) {
+        _mm256_storeu_pd(&x[i + 0], zero);
+        _mm256_storeu_pd(&x[i + 4], zero);
+    }
+#endif
 }
+
+#endif
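
A note on how the two paths get selected, plus a small standalone check. If this file is compiled with an AVX512-capable -march (for example -march=skylake-avx512), GCC defines __AVX512CD__ and the 512-bit path is compiled; otherwise the #pragma GCC target("avx2,fma") at the top enables AVX2/FMA code generation for the 256-bit path. The program below is a sketch, not part of this commit (file name, build flags and helper names are assumptions): it puts the new kernel's AVX2 pattern, written with the portable _mm256_mul_pd call instead of the GCC vector-operator multiply, next to a plain scalar loop so the two can be checked against each other; timing those same two calls is the natural way to probe the commit's 50% figure.

/* Sketch, not part of this commit.  Build (assumed): gcc -O2 -mavx2 -mfma check_dscal.c */
#include <immintrin.h>
#include <stdio.h>
#include <stdlib.h>

/* Same shape as the new AVX2 path: two 256-bit multiplies per 8 doubles. */
static void scale_avx2(long n, double alpha, double *x)
{
    __m256d va = _mm256_set1_pd(alpha);
    for (long i = 0; i < n; i += 8) {            /* n must be a multiple of 8 */
        _mm256_storeu_pd(&x[i + 0], _mm256_mul_pd(va, _mm256_loadu_pd(&x[i + 0])));
        _mm256_storeu_pd(&x[i + 4], _mm256_mul_pd(va, _mm256_loadu_pd(&x[i + 4])));
    }
}

/* Scalar reference: one double multiply per element, same rounding. */
static void scale_scalar(long n, double alpha, double *x)
{
    for (long i = 0; i < n; i++)
        x[i] *= alpha;
}

int main(void)
{
    const long n = 1 << 20;                      /* multiple of 8 */
    double *a = malloc(n * sizeof(double));
    double *b = malloc(n * sizeof(double));
    for (long i = 0; i < n; i++)
        a[i] = b[i] = (double)(i % 1000) * 0.5;

    scale_avx2(n, 3.25, a);
    scale_scalar(n, 3.25, b);

    for (long i = 0; i < n; i++) {
        if (a[i] != b[i]) {                      /* identical multiply, so results must match exactly */
            printf("mismatch at %ld\n", i);
            return 1;
        }
    }
    printf("vector and scalar results match\n");
    free(a);
    free(b);
    return 0;
}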