848 lines
46 KiB
C
848 lines
46 KiB
C
/***************************************************************************
|
|
Copyright (c) 2014, The OpenBLAS Project
|
|
All rights reserved.
|
|
Redistribution and use in source and binary forms, with or without
|
|
modification, are permitted provided that the following conditions are
|
|
met:
|
|
1. Redistributions of source code must retain the above copyright
|
|
notice, this list of conditions and the following disclaimer.
|
|
2. Redistributions in binary form must reproduce the above copyright
|
|
notice, this list of conditions and the following disclaimer in
|
|
the documentation and/or other materials provided with the
|
|
distribution.
|
|
3. Neither the name of the OpenBLAS project nor the names of
|
|
its contributors may be used to endorse or promote products
|
|
derived from this software without specific prior written permission.
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
|
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
|
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
|
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*****************************************************************************/
|
|
#ifndef __BF16_COMMON_MACROS
|
|
#define __BF16_COMMON_MACROS
|
|
|
|
#include <immintrin.h>
|
|
|
|
/* Broadcast the 32-bit value stored at `addr` into the vector register bound
   to `zmm`, using VPBROADCASTD with a memory source operand.
     addr - pointer, passed in a general-purpose register ("r" constraint)
     zmm  - lvalue receiving the result ("=v" constraint); presumably a
            512-bit vector, judging by the macro name
   The asm dereferences `addr` itself, so the "memory" clobber is required to
   tell the compiler that memory not listed as an operand is read here.
   No trailing semicolon: the caller supplies it. */
#define _MM512_BROADCASTD_EPI32(addr, zmm) \
    __asm__ ("vpbroadcastd (%1), %0;" \
             : "=v" (zmm) \
             : "r" (addr) \
             : "memory")
|
|
|
|
/* Issue a PREFETCHT0 instruction for the address `addr`.
     addr - pointer, passed in a general-purpose register ("r" constraint)
   The asm has no outputs. No trailing semicolon: the caller supplies it. */
#define PREFETCH_T0(addr) \
    __asm__ ("prefetcht0 (%0);" \
             : \
             : "r" (addr) )
|
|
|
|
/* Take the low 256 bits of reg512_0 and reg512_1 (via
   _mm512_castps512_ps256) and store them in reg256_0 and reg256_1.
     reg256, reg512 - name prefixes; "_0" / "_1" are token-pasted onto them
   Expands to two statements ending in ';' (not do/while-wrapped), so it must
   not be used as the body of an unbraced if/else. */
#define EXTRACT_LOW_256_FROM_512_2X(reg256, reg512) \
    reg256##_0 = _mm512_castps512_ps256(reg512##_0); \
    reg256##_1 = _mm512_castps512_ps256(reg512##_1);
|
|
|
|
|
|
/* Load 8 rows of 512 bits each (unaligned) from matrix `a` into
   regArray_0 .. regArray_7; row r is read from &a[(idx_m + r) * lda + idx_n].
     regArray     - name prefix; "_0" .. "_7" are token-pasted onto it
     a            - pointer to the matrix data
     lda          - row stride, in elements of `a`
     idx_m, idx_n - first row / column offset
   All value arguments are parenthesized so compound expressions (e.g. an
   `lda` of `n + 1`) expand correctly.
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_MATRIX_LOAD_8x32(regArray, a, lda, idx_m, idx_n) \
    regArray##_0 = _mm512_loadu_si512(&(a)[((idx_m)+0)*(lda) + (idx_n)]); \
    regArray##_1 = _mm512_loadu_si512(&(a)[((idx_m)+1)*(lda) + (idx_n)]); \
    regArray##_2 = _mm512_loadu_si512(&(a)[((idx_m)+2)*(lda) + (idx_n)]); \
    regArray##_3 = _mm512_loadu_si512(&(a)[((idx_m)+3)*(lda) + (idx_n)]); \
    regArray##_4 = _mm512_loadu_si512(&(a)[((idx_m)+4)*(lda) + (idx_n)]); \
    regArray##_5 = _mm512_loadu_si512(&(a)[((idx_m)+5)*(lda) + (idx_n)]); \
    regArray##_6 = _mm512_loadu_si512(&(a)[((idx_m)+6)*(lda) + (idx_n)]); \
    regArray##_7 = _mm512_loadu_si512(&(a)[((idx_m)+7)*(lda) + (idx_n)]);
|
|
|
|
|
|
/* Load 8 rows of 256 bits each (unaligned) from matrix `a` into
   regArray_0 .. regArray_7; row r is read from &a[(idx_m + r) * lda + idx_n].
     regArray     - name prefix; "_0" .. "_7" are token-pasted onto it
     a            - pointer to the matrix data
     lda          - row stride, in elements of `a`
     idx_m, idx_n - first row / column offset
   All value arguments are parenthesized so compound expressions expand
   correctly.
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_MATRIX_LOAD_8x16(regArray, a, lda, idx_m, idx_n) \
    regArray##_0 = _mm256_loadu_si256((__m256i *)(&(a)[((idx_m)+0)*(lda) + (idx_n)])); \
    regArray##_1 = _mm256_loadu_si256((__m256i *)(&(a)[((idx_m)+1)*(lda) + (idx_n)])); \
    regArray##_2 = _mm256_loadu_si256((__m256i *)(&(a)[((idx_m)+2)*(lda) + (idx_n)])); \
    regArray##_3 = _mm256_loadu_si256((__m256i *)(&(a)[((idx_m)+3)*(lda) + (idx_n)])); \
    regArray##_4 = _mm256_loadu_si256((__m256i *)(&(a)[((idx_m)+4)*(lda) + (idx_n)])); \
    regArray##_5 = _mm256_loadu_si256((__m256i *)(&(a)[((idx_m)+5)*(lda) + (idx_n)])); \
    regArray##_6 = _mm256_loadu_si256((__m256i *)(&(a)[((idx_m)+6)*(lda) + (idx_n)])); \
    regArray##_7 = _mm256_loadu_si256((__m256i *)(&(a)[((idx_m)+7)*(lda) + (idx_n)]));
|
|
|
|
|
|
/* Load 8 rows of 128 bits each (unaligned) from matrix `a` into
   regArray_0 .. regArray_7; row r is read from &a[(idx_m + r) * lda + idx_n].
     regArray     - name prefix; "_0" .. "_7" are token-pasted onto it
     a            - pointer to the matrix data
     lda          - row stride, in elements of `a`
     idx_m, idx_n - first row / column offset
   All value arguments are parenthesized so compound expressions expand
   correctly.
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_MATRIX_LOAD_8x8(regArray, a, lda, idx_m, idx_n) \
    regArray##_0 = _mm_loadu_si128((__m128i *)(&(a)[((idx_m)+0)*(lda) + (idx_n)])); \
    regArray##_1 = _mm_loadu_si128((__m128i *)(&(a)[((idx_m)+1)*(lda) + (idx_n)])); \
    regArray##_2 = _mm_loadu_si128((__m128i *)(&(a)[((idx_m)+2)*(lda) + (idx_n)])); \
    regArray##_3 = _mm_loadu_si128((__m128i *)(&(a)[((idx_m)+3)*(lda) + (idx_n)])); \
    regArray##_4 = _mm_loadu_si128((__m128i *)(&(a)[((idx_m)+4)*(lda) + (idx_n)])); \
    regArray##_5 = _mm_loadu_si128((__m128i *)(&(a)[((idx_m)+5)*(lda) + (idx_n)])); \
    regArray##_6 = _mm_loadu_si128((__m128i *)(&(a)[((idx_m)+6)*(lda) + (idx_n)])); \
    regArray##_7 = _mm_loadu_si128((__m128i *)(&(a)[((idx_m)+7)*(lda) + (idx_n)]));
|
|
|
|
|
|
/* Load one row of 512 bits (unaligned) from &a[idx_m * lda + idx_n] into
   `regArray` (used as-is, no suffix is pasted).
     a            - pointer to the matrix data
     lda          - row stride, in elements of `a`
     idx_m, idx_n - row / column offset
   idx_m and lda are parenthesized: without that, an idx_m of `i + 1` would
   expand to `i + 1*lda`.
   Expands to one statement ending in ';'. */
#define BF16_MATRIX_LOAD_1x32(regArray, a, lda, idx_m, idx_n) \
    regArray = _mm512_loadu_si512(&(a)[(idx_m)*(lda) + (idx_n)]);
|
|
|
|
|
|
/* Zero-masked load of 8 rows (512-bit, 16-bit elements, unaligned) from
   matrix `a` into regArray_0 .. regArray_7; row r is read from
   &a[(idx_m + r) * lda + idx_n] through _mm512_maskz_loadu_epi16.
     regArray     - name prefix; "_0" .. "_7" are token-pasted onto it
     a            - pointer to the matrix data
     lda          - row stride, in elements of `a`
     idx_m, idx_n - first row / column offset
     mask         - element mask, the same for every row
   All value arguments are parenthesized so compound expressions expand
   correctly.
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_MATRIX_MASKZ_LOAD_8x32(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+0)*(lda) + (idx_n)]); \
    regArray##_1 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+1)*(lda) + (idx_n)]); \
    regArray##_2 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+2)*(lda) + (idx_n)]); \
    regArray##_3 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+3)*(lda) + (idx_n)]); \
    regArray##_4 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+4)*(lda) + (idx_n)]); \
    regArray##_5 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+5)*(lda) + (idx_n)]); \
    regArray##_6 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+6)*(lda) + (idx_n)]); \
    regArray##_7 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+7)*(lda) + (idx_n)]);
|
|
|
|
|
|
/* Zero-masked load of 8 rows (256-bit, 16-bit elements, unaligned) from
   matrix `a` into regArray_0 .. regArray_7; row r is read from
   &a[(idx_m + r) * lda + idx_n] through _mm256_maskz_loadu_epi16.
     regArray     - name prefix; "_0" .. "_7" are token-pasted onto it
     a            - pointer to the matrix data
     lda          - row stride, in elements of `a`
     idx_m, idx_n - first row / column offset
     mask         - element mask, the same for every row
   All value arguments are parenthesized so compound expressions expand
   correctly.
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_MATRIX_MASKZ_LOAD_8x16(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m)+0)*(lda) + (idx_n)]); \
    regArray##_1 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m)+1)*(lda) + (idx_n)]); \
    regArray##_2 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m)+2)*(lda) + (idx_n)]); \
    regArray##_3 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m)+3)*(lda) + (idx_n)]); \
    regArray##_4 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m)+4)*(lda) + (idx_n)]); \
    regArray##_5 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m)+5)*(lda) + (idx_n)]); \
    regArray##_6 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m)+6)*(lda) + (idx_n)]); \
    regArray##_7 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m)+7)*(lda) + (idx_n)]);
|
|
|
|
|
|
/* Zero-masked load of 8 rows (128-bit, 16-bit elements, unaligned) from
   matrix `a` into regArray_0 .. regArray_7; row r is read from
   &a[(idx_m + r) * lda + idx_n] through _mm_maskz_loadu_epi16.
     regArray     - name prefix; "_0" .. "_7" are token-pasted onto it
     a            - pointer to the matrix data
     lda          - row stride, in elements of `a`
     idx_m, idx_n - first row / column offset
     mask         - element mask, the same for every row
   All value arguments are parenthesized so compound expressions expand
   correctly.
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_MATRIX_MASKZ_LOAD_8x8(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m)+0)*(lda) + (idx_n)]); \
    regArray##_1 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m)+1)*(lda) + (idx_n)]); \
    regArray##_2 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m)+2)*(lda) + (idx_n)]); \
    regArray##_3 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m)+3)*(lda) + (idx_n)]); \
    regArray##_4 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m)+4)*(lda) + (idx_n)]); \
    regArray##_5 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m)+5)*(lda) + (idx_n)]); \
    regArray##_6 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m)+6)*(lda) + (idx_n)]); \
    regArray##_7 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m)+7)*(lda) + (idx_n)]);
|
|
|
|
|
|
/* Zero-masked load of 4 rows (512-bit, 16-bit elements, unaligned) from
   matrix `a` into regArray_0 .. regArray_3; row r is read from
   &a[(idx_m + r) * lda + idx_n] through _mm512_maskz_loadu_epi16.
     regArray     - name prefix; "_0" .. "_3" are token-pasted onto it
     a            - pointer to the matrix data
     lda          - row stride, in elements of `a`
     idx_m, idx_n - first row / column offset
     mask         - element mask, the same for every row
   All value arguments are parenthesized so compound expressions expand
   correctly.
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_MATRIX_MASKZ_LOAD_4x32(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+0)*(lda) + (idx_n)]); \
    regArray##_1 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+1)*(lda) + (idx_n)]); \
    regArray##_2 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+2)*(lda) + (idx_n)]); \
    regArray##_3 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+3)*(lda) + (idx_n)]);
|
|
|
|
|
|
/* Zero-masked load of 4 rows (256-bit, 16-bit elements, unaligned) from
   matrix `a` into regArray_0 .. regArray_3; row r is read from
   &a[(idx_m + r) * lda + idx_n] through _mm256_maskz_loadu_epi16.
     regArray     - name prefix; "_0" .. "_3" are token-pasted onto it
     a            - pointer to the matrix data
     lda          - row stride, in elements of `a`
     idx_m, idx_n - first row / column offset
     mask         - element mask, the same for every row
   All value arguments are parenthesized so compound expressions expand
   correctly.
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_MATRIX_MASKZ_LOAD_4x16(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m)+0)*(lda) + (idx_n)]); \
    regArray##_1 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m)+1)*(lda) + (idx_n)]); \
    regArray##_2 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m)+2)*(lda) + (idx_n)]); \
    regArray##_3 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m)+3)*(lda) + (idx_n)]);
|
|
|
|
|
|
/* Zero-masked load of 8 rows (512-bit, 16-bit elements, unaligned) taken
   from every second row of matrix `a`: regArray_k is read from
   &a[(idx_m + 2*k) * lda + idx_n] for k = 0 .. 7 (rows idx_m, idx_m+2, ...,
   idx_m+14).
     regArray     - name prefix; "_0" .. "_7" are token-pasted onto it
     a            - pointer to the matrix data
     lda          - row stride, in elements of `a`
     idx_m, idx_n - first row / column offset
     mask         - element mask, the same for every row
   All value arguments are parenthesized so compound expressions expand
   correctly.
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_MATRIX_MASKZ_LOAD_8x32_2(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+0)*(lda) + (idx_n)]); \
    regArray##_1 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+2)*(lda) + (idx_n)]); \
    regArray##_2 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+4)*(lda) + (idx_n)]); \
    regArray##_3 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+6)*(lda) + (idx_n)]); \
    regArray##_4 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+8)*(lda) + (idx_n)]); \
    regArray##_5 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+10)*(lda) + (idx_n)]); \
    regArray##_6 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+12)*(lda) + (idx_n)]); \
    regArray##_7 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+14)*(lda) + (idx_n)]);
|
|
|
|
|
|
/* Zero-masked load of 4 rows (512-bit, 16-bit elements, unaligned) taken
   from every second row of matrix `a`: regArray_k is read from
   &a[(idx_m + 2*k) * lda + idx_n] for k = 0 .. 3 (rows idx_m, idx_m+2,
   idx_m+4, idx_m+6).
     regArray     - name prefix; "_0" .. "_3" are token-pasted onto it
     a            - pointer to the matrix data
     lda          - row stride, in elements of `a`
     idx_m, idx_n - first row / column offset
     mask         - element mask, the same for every row
   All value arguments are parenthesized so compound expressions expand
   correctly.
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_MATRIX_MASKZ_LOAD_4x32_2(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+0)*(lda) + (idx_n)]); \
    regArray##_1 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+2)*(lda) + (idx_n)]); \
    regArray##_2 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+4)*(lda) + (idx_n)]); \
    regArray##_3 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m)+6)*(lda) + (idx_n)]);
|
|
|
|
/* Zero-masked load of one row (512-bit, 16-bit elements, unaligned) from
   &a[idx_m * lda + idx_n] into `regArray` (used as-is, no suffix is pasted).
     a            - pointer to the matrix data
     lda          - row stride, in elements of `a`
     idx_m, idx_n - row / column offset
     mask         - element mask
   idx_m and lda are parenthesized: without that, an idx_m of `i + 1` would
   expand to `i + 1*lda`.
   Expands to one statement ending in ';'. */
#define BF16_MATRIX_MASKZ_LOAD_1x32(regArray, a, lda, idx_m, idx_n, mask) \
    regArray = _mm512_maskz_loadu_epi16((mask), &(a)[(idx_m)*(lda) + (idx_n)]);
|
|
|
|
/* Load 512 bits (unaligned) from x + idx_n into `reg`.
     reg   - destination lvalue
     x     - pointer to the vector data
     idx_n - offset, in elements of `x`
   x and idx_n are parenthesized so compound expressions expand correctly.
   Expands to one statement ending in ';'. */
#define BF16_VECTOR_LOAD_1x32(reg, x, idx_n) \
    reg = _mm512_loadu_si512((x) + (idx_n));
|
|
|
|
|
|
/* Load 256 bits (unaligned) from x + idx_n into `reg`.
     reg   - destination lvalue
     x     - pointer to the vector data
     idx_n - offset, in elements of `x`
   x and idx_n are parenthesized so compound expressions expand correctly.
   Expands to one statement ending in ';'. */
#define BF16_VECTOR_LOAD_1x16(reg, x, idx_n) \
    reg = _mm256_loadu_si256((__m256i *)((x) + (idx_n)));
|
|
|
|
|
|
/* Load 128 bits (unaligned) from x + idx_n into `reg`.
     reg   - destination lvalue
     x     - pointer to the vector data
     idx_n - offset, in elements of `x`
   x and idx_n are parenthesized so compound expressions expand correctly.
   Expands to one statement ending in ';'. */
#define BF16_VECTOR_LOAD_1x8(reg, x, idx_n) \
    reg = _mm_loadu_si128((__m128i *)((x) + (idx_n)));
|
|
|
|
|
|
/* Zero-masked load (512-bit, 16-bit elements, unaligned) from x + idx_n
   into `reg`.
     reg   - destination lvalue
     x     - pointer to the vector data
     idx_n - offset, in elements of `x`
     mask  - element mask
   x and idx_n are parenthesized so compound expressions expand correctly.
   Expands to one statement ending in ';'. */
#define BF16_VECTOR_MASKZ_LOAD_1x32(reg, x, idx_n, mask) \
    reg = _mm512_maskz_loadu_epi16((mask), (x) + (idx_n));
|
|
|
|
|
|
/* Zero-masked load (256-bit, 16-bit elements, unaligned) from x + idx_n
   into `reg`.
     reg   - destination lvalue
     x     - pointer to the vector data
     idx_n - offset, in elements of `x`
     mask  - element mask
   x and idx_n are parenthesized so compound expressions expand correctly.
   Expands to one statement ending in ';'. */
#define BF16_VECTOR_MASKZ_LOAD_1x16(reg, x, idx_n, mask) \
    reg = _mm256_maskz_loadu_epi16((mask), (x) + (idx_n));
|
|
|
|
|
|
/* Zero-masked load (128-bit, 16-bit elements, unaligned) from x + idx_n
   into `reg`.
     reg   - destination lvalue
     x     - pointer to the vector data
     idx_n - offset, in elements of `x`
     mask  - element mask
   x and idx_n are parenthesized so compound expressions expand correctly.
   Expands to one statement ending in ';'. */
#define BF16_VECTOR_MASKZ_LOAD_1x8(reg, x, idx_n, mask) \
    reg = _mm_maskz_loadu_epi16((mask), (x) + (idx_n));
|
|
|
|
|
|
/* 2-step interleave for matrix against 8 rows with 32 BF16 elements per row
|
|
Input - register array of 8 rows of row-major matrix
|
|
Output - the output of Step 2
|
|
|
|
Step 1: 2-element interleave for matrix
|
|
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
|
|
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
|
|
|e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11|e16|e17|f16|f17|e18|e19|f18|f19|e24|e25|f24|f25|e26|e27|f26|f27
|
|
|g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11|g16|g17|h16|h17|g18|g19|h18|h19|g24|g25|h24|h25|g26|g27|h26|h27
|
|
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
|
|
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31
|
|
|e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15|e20|e21|f20|f21|e22|e23|f22|f23|e28|e29|f28|f29|e30|e31|f30|f31
|
|
|g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15|g20|g21|h20|h21|g22|g23|h22|h23|g28|g29|h28|h29|g30|g31|h30|h31
|
|
|
|
Step 2: 4-element interleave for matrix
|
|
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
|
|
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
|
|
|e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9 |e16|e17|f16|f17|g16|g17|h16|h17|e24|e25|f24|f25|g24|g25|h24|h25
|
|
|e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11|e18|e19|f18|f19|g18|g19|h18|h19|e26|e27|f26|f27|g26|g27|h26|h27
|
|
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
|
|
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
|
|
|e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13|e20|e21|f20|f21|g20|g21|h20|h21|e28|e29|f28|f29|g28|g29|h28|h29
|
|
|e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15|e22|e23|f22|f23|g22|g23|h22|h23|e30|e31|f30|f31|g30|g31|h30|h31
|
|
*/
|
|
/* Two-step interleave of eight 512-bit registers regArray_0 .. regArray_7.
   Step 1 pairs adjacent inputs with unpacklo/unpackhi_epi32 and writes the
   temporaries regArray_8 .. regArray_15 (low halves in _8.._11, high halves
   in _12.._15). Step 2 pairs those temporaries with unpacklo/unpackhi_epi64
   and overwrites regArray_0 .. regArray_7 with the result.
     regArray - name prefix; "_0" .. "_15" are token-pasted onto it, so all
                sixteen variables must exist at the call site
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_INTERLEAVE_8x32(regArray) \
    regArray##_8  = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \
    regArray##_9  = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \
    regArray##_10 = _mm512_unpacklo_epi32(regArray##_4, regArray##_5); \
    regArray##_11 = _mm512_unpacklo_epi32(regArray##_6, regArray##_7); \
    regArray##_12 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \
    regArray##_13 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \
    regArray##_14 = _mm512_unpackhi_epi32(regArray##_4, regArray##_5); \
    regArray##_15 = _mm512_unpackhi_epi32(regArray##_6, regArray##_7); \
                                                                        \
    regArray##_0 = _mm512_unpacklo_epi64(regArray##_8,  regArray##_9);  \
    regArray##_1 = _mm512_unpackhi_epi64(regArray##_8,  regArray##_9);  \
    regArray##_2 = _mm512_unpacklo_epi64(regArray##_10, regArray##_11); \
    regArray##_3 = _mm512_unpackhi_epi64(regArray##_10, regArray##_11); \
    regArray##_4 = _mm512_unpacklo_epi64(regArray##_12, regArray##_13); \
    regArray##_5 = _mm512_unpackhi_epi64(regArray##_12, regArray##_13); \
    regArray##_6 = _mm512_unpacklo_epi64(regArray##_14, regArray##_15); \
    regArray##_7 = _mm512_unpackhi_epi64(regArray##_14, regArray##_15);
|
|
|
|
|
|
/* 2-step interleave for matrix against 8 rows with 16 BF16 elements per row
|
|
Input - register array of 8 rows of row-major matrix
|
|
Output - the output of Step 2
|
|
|
|
Step 1: 2-element interleave for matrix
|
|
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
|
|
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
|
|
|e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11
|
|
|g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11
|
|
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
|
|
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15
|
|
|e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15
|
|
|g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15
|
|
|
|
Step 2: 4-element interleave for matrix
|
|
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
|
|
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
|
|
|e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9
|
|
|e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11
|
|
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
|
|
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
|
|
|e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13
|
|
|e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15
|
|
*/
|
|
/* Two-step interleave of eight 256-bit registers regArray_0 .. regArray_7.
   Step 1 pairs adjacent inputs with unpacklo/unpackhi_epi32 and writes the
   temporaries regArray_8 .. regArray_15 (low halves in _8.._11, high halves
   in _12.._15). Step 2 pairs those temporaries with unpacklo/unpackhi_epi64
   and overwrites regArray_0 .. regArray_7 with the result.
     regArray - name prefix; "_0" .. "_15" are token-pasted onto it, so all
                sixteen variables must exist at the call site
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_INTERLEAVE_8x16(regArray) \
    regArray##_8  = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \
    regArray##_9  = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \
    regArray##_10 = _mm256_unpacklo_epi32(regArray##_4, regArray##_5); \
    regArray##_11 = _mm256_unpacklo_epi32(regArray##_6, regArray##_7); \
    regArray##_12 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \
    regArray##_13 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \
    regArray##_14 = _mm256_unpackhi_epi32(regArray##_4, regArray##_5); \
    regArray##_15 = _mm256_unpackhi_epi32(regArray##_6, regArray##_7); \
                                                                        \
    regArray##_0 = _mm256_unpacklo_epi64(regArray##_8,  regArray##_9);  \
    regArray##_1 = _mm256_unpackhi_epi64(regArray##_8,  regArray##_9);  \
    regArray##_2 = _mm256_unpacklo_epi64(regArray##_10, regArray##_11); \
    regArray##_3 = _mm256_unpackhi_epi64(regArray##_10, regArray##_11); \
    regArray##_4 = _mm256_unpacklo_epi64(regArray##_12, regArray##_13); \
    regArray##_5 = _mm256_unpackhi_epi64(regArray##_12, regArray##_13); \
    regArray##_6 = _mm256_unpacklo_epi64(regArray##_14, regArray##_15); \
    regArray##_7 = _mm256_unpackhi_epi64(regArray##_14, regArray##_15);
|
|
|
|
/* 2-step interleave for matrix against 4 rows with 32 BF16 elements per row
|
|
Input - register array of 4 rows of row-major matrix
|
|
Output - the output of Step 2
|
|
|
|
Step 1: 2-element interleave for matrix
|
|
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
|
|
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
|
|
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
|
|
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31
|
|
|
|
Step 2: 4-element interleave for matrix
|
|
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
|
|
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
|
|
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
|
|
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
|
|
*/
|
|
/* Two-step interleave of four 512-bit registers regArray_0 .. regArray_3.
   Step 1 pairs adjacent inputs with unpacklo/unpackhi_epi32 and writes the
   temporaries regArray_4 .. regArray_7 (low halves in _4/_5, high halves in
   _6/_7). Step 2 pairs those temporaries with unpacklo/unpackhi_epi64 and
   overwrites regArray_0 .. regArray_3 with the result.
     regArray - name prefix; "_0" .. "_7" are token-pasted onto it, so all
                eight variables must exist at the call site
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_INTERLEAVE_4x32(regArray) \
    regArray##_4 = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \
    regArray##_5 = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \
    regArray##_6 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \
    regArray##_7 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \
                                                                       \
    regArray##_0 = _mm512_unpacklo_epi64(regArray##_4, regArray##_5); \
    regArray##_1 = _mm512_unpackhi_epi64(regArray##_4, regArray##_5); \
    regArray##_2 = _mm512_unpacklo_epi64(regArray##_6, regArray##_7); \
    regArray##_3 = _mm512_unpackhi_epi64(regArray##_6, regArray##_7);
|
|
|
|
|
|
/* 2-step interleave for matrix against 4 rows with 16 BF16 elements per row
|
|
Input - register array of 4 rows of row-major matrix
|
|
Output - the output of Step 2
|
|
|
|
Step 1: 2-element interleave for matrix
|
|
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
|
|
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
|
|
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
|
|
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15
|
|
|
|
Step 2: 4-element interleave for matrix
|
|
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
|
|
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
|
|
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
|
|
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
|
|
*/
|
|
/* Two-step interleave of four 256-bit registers regArray_0 .. regArray_3.
   Step 1 pairs adjacent inputs with unpacklo/unpackhi_epi32 and writes the
   temporaries regArray_4 .. regArray_7 (low halves in _4/_5, high halves in
   _6/_7). Step 2 pairs those temporaries with unpacklo/unpackhi_epi64 and
   overwrites regArray_0 .. regArray_3 with the result.
     regArray - name prefix; "_0" .. "_7" are token-pasted onto it, so all
                eight variables must exist at the call site
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_INTERLEAVE_4x16(regArray) \
    regArray##_4 = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \
    regArray##_5 = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \
    regArray##_6 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \
    regArray##_7 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \
                                                                       \
    regArray##_0 = _mm256_unpacklo_epi64(regArray##_4, regArray##_5); \
    regArray##_1 = _mm256_unpackhi_epi64(regArray##_4, regArray##_5); \
    regArray##_2 = _mm256_unpacklo_epi64(regArray##_6, regArray##_7); \
    regArray##_3 = _mm256_unpackhi_epi64(regArray##_6, regArray##_7);
|
|
|
|
|
|
/* 2-step interleave for x with 32 BF16 elements
|
|
Input - original vector
|
|
Output - the output of Step 2
|
|
|
|
Step 1: 2-element interleave for x:
|
|
|x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11|x16|x17|x16|x17|x18|x19|x18|x19|x24|x25|x24|x25|x26|x27|x26|x27
|
|
|x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15|x20|x21|x20|x21|x22|x23|x22|x23|x28|x29|x28|x29|x30|x31|x30|x31
|
|
|
|
Step 2: 4-element interleave for x:
|
|
|x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9 |x16|x17|x16|x17|x16|x17|x16|x17|x24|x25|x24|x25|x24|x25|x24|x25
|
|
|x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11|x18|x19|x18|x19|x18|x19|x18|x19|x26|x27|x26|x27|x26|x27|x26|x27
|
|
|x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13|x20|x21|x20|x21|x20|x21|x20|x21|x28|x29|x28|x29|x28|x29|x28|x29
|
|
|x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15|x22|x23|x22|x23|x22|x23|x22|x23|x30|x31|x30|x31|x30|x31|x30|x31
|
|
*/
|
|
/* Two-step self-interleave of the 512-bit register regArray_0, producing
   regArray_0 .. regArray_3.
   Step 1 unpacks regArray_0 with itself (epi32) into regArray_1 (low) and
   regArray_3 (high). Step 2 unpacks each of those with itself (epi64).
   Statement order matters in step 2: regArray_1 and regArray_3 are read as
   inputs before being overwritten by the "hi" result of the same pair.
     regArray - name prefix; "_0" .. "_3" are token-pasted onto it
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_INTERLEAVE_1x32(regArray) \
    regArray##_1 = _mm512_unpacklo_epi32(regArray##_0, regArray##_0); \
    regArray##_3 = _mm512_unpackhi_epi32(regArray##_0, regArray##_0); \
                                                                       \
    regArray##_0 = _mm512_unpacklo_epi64(regArray##_1, regArray##_1); \
    regArray##_1 = _mm512_unpackhi_epi64(regArray##_1, regArray##_1); \
    regArray##_2 = _mm512_unpacklo_epi64(regArray##_3, regArray##_3); \
    regArray##_3 = _mm512_unpackhi_epi64(regArray##_3, regArray##_3);
|
|
|
|
|
|
/* 2-step interleave for x with 16 BF16 elements
|
|
Input - original vector
|
|
Output - the output of Step 2
|
|
|
|
Step 1: 2-element interleave for x:
|
|
|x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11
|
|
|x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15
|
|
|
|
Step 2: 4-element interleave for x:
|
|
|x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9
|
|
|x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11
|
|
|x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13
|
|
|x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15
|
|
*/
|
|
/* Two-step self-interleave of the 256-bit register regArray_0, producing
   regArray_0 .. regArray_3.
   Step 1 unpacks regArray_0 with itself (epi32) into regArray_1 (low) and
   regArray_3 (high). Step 2 unpacks each of those with itself (epi64).
   Statement order matters in step 2: regArray_1 and regArray_3 are read as
   inputs before being overwritten by the "hi" result of the same pair.
     regArray - name prefix; "_0" .. "_3" are token-pasted onto it
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_INTERLEAVE_1x16(regArray) \
    regArray##_1 = _mm256_unpacklo_epi32(regArray##_0, regArray##_0); \
    regArray##_3 = _mm256_unpackhi_epi32(regArray##_0, regArray##_0); \
                                                                       \
    regArray##_0 = _mm256_unpacklo_epi64(regArray##_1, regArray##_1); \
    regArray##_1 = _mm256_unpackhi_epi64(regArray##_1, regArray##_1); \
    regArray##_2 = _mm256_unpacklo_epi64(regArray##_3, regArray##_3); \
    regArray##_3 = _mm256_unpackhi_epi64(regArray##_3, regArray##_3);
|
|
|
|
/* 1-step interleave to exchange the high 256 bits and low 256 bits of 4 pairs of registers
|
|
|a0|a1|...|a14|a15|i0|i1|...|i14|i15|
|
|
|b0|b1|...|b14|b15|j0|j1|...|j14|j15|
|
|
|c0|c1|...|c14|c15|k0|k1|...|k14|k15|
|
|
|d0|d1|...|d14|d15|l0|l1|...|l14|l15|
|
|
|e0|e1|...|e14|e15|m0|m1|...|m14|m15|
|
|
|f0|f1|...|f14|f15|n0|n1|...|n14|n15|
|
|
|g0|g1|...|g14|g15|o0|o1|...|o14|o15|
|
|
|h0|h1|...|h14|h15|p0|p1|...|p14|p15|
|
|
*/
|
|
/* Recombine the 256-bit halves of four register pairs (_8,_12), (_9,_13),
   (_10,_14), (_11,_15) into regArray_0 .. regArray_7.
   For each pair (p, q): immediate 0x44 selects 128-bit lanes 0,1 of p then
   lanes 0,1 of q (the two low halves); 0xee selects lanes 2,3 of p then
   lanes 2,3 of q (the two high halves).
     regArray - name prefix; reads "_8" .. "_15", writes "_0" .. "_7"
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_INTERLEAVE256_8x32(regArray) \
    regArray##_0 = _mm512_shuffle_i32x4(regArray##_8,  regArray##_12, 0x44); \
    regArray##_1 = _mm512_shuffle_i32x4(regArray##_8,  regArray##_12, 0xee); \
    regArray##_2 = _mm512_shuffle_i32x4(regArray##_9,  regArray##_13, 0x44); \
    regArray##_3 = _mm512_shuffle_i32x4(regArray##_9,  regArray##_13, 0xee); \
    regArray##_4 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0x44); \
    regArray##_5 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0xee); \
    regArray##_6 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0x44); \
    regArray##_7 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0xee);
|
|
|
|
|
|
/* 1-step interleave to exchange the high 256 bits and low 256 bits of 2 pairs of registers
|
|
|a0|a1|...|a14|a15|e0|e1|...|e14|e15|
|
|
|b0|b1|...|b14|b15|f0|f1|...|f14|f15|
|
|
|c0|c1|...|c14|c15|g0|g1|...|g14|g15|
|
|
|d0|d1|...|d14|d15|h0|h1|...|h14|h15|
|
|
*/
|
|
/* Recombine the 256-bit halves of two register pairs (_4,_6) and (_5,_7)
   into regArray_0 .. regArray_3.
   For each pair (p, q): immediate 0x44 selects 128-bit lanes 0,1 of p then
   lanes 0,1 of q (the two low halves); 0xee selects lanes 2,3 of p then
   lanes 2,3 of q (the two high halves).
     regArray - name prefix; reads "_4" .. "_7", writes "_0" .. "_3"
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_INTERLEAVE256_4x32(regArray) \
    regArray##_0 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0x44); \
    regArray##_1 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0xee); \
    regArray##_2 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0x44); \
    regArray##_3 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0xee);
|
|
|
|
|
|
/* Permute the 16-bit elements of regArray_0 .. regArray_7 with the same
   index vector `idx` (_mm512_permutexvar_epi16) and store the results in
   regArray_8 .. regArray_15; the inputs are left unchanged.
     idx      - index vector passed to every permute
     regArray - name prefix; "_0" .. "_15" are token-pasted onto it
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_PERMUTE_8x32(idx, regArray) \
    regArray##_8  = _mm512_permutexvar_epi16((idx), regArray##_0); \
    regArray##_9  = _mm512_permutexvar_epi16((idx), regArray##_1); \
    regArray##_10 = _mm512_permutexvar_epi16((idx), regArray##_2); \
    regArray##_11 = _mm512_permutexvar_epi16((idx), regArray##_3); \
    regArray##_12 = _mm512_permutexvar_epi16((idx), regArray##_4); \
    regArray##_13 = _mm512_permutexvar_epi16((idx), regArray##_5); \
    regArray##_14 = _mm512_permutexvar_epi16((idx), regArray##_6); \
    regArray##_15 = _mm512_permutexvar_epi16((idx), regArray##_7);
|
|
|
|
|
|
/* Permute the 32-bit elements of regArray_0 .. regArray_7 with the same
   index vector `idx` (_mm512_permutexvar_epi32) and store the results in
   regArray_8 .. regArray_15; the inputs are left unchanged.
     idx      - index vector passed to every permute
     regArray - name prefix; "_0" .. "_15" are token-pasted onto it
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_PERMUTE_8x32_2(idx, regArray) \
    regArray##_8  = _mm512_permutexvar_epi32((idx), regArray##_0); \
    regArray##_9  = _mm512_permutexvar_epi32((idx), regArray##_1); \
    regArray##_10 = _mm512_permutexvar_epi32((idx), regArray##_2); \
    regArray##_11 = _mm512_permutexvar_epi32((idx), regArray##_3); \
    regArray##_12 = _mm512_permutexvar_epi32((idx), regArray##_4); \
    regArray##_13 = _mm512_permutexvar_epi32((idx), regArray##_5); \
    regArray##_14 = _mm512_permutexvar_epi32((idx), regArray##_6); \
    regArray##_15 = _mm512_permutexvar_epi32((idx), regArray##_7);
|
|
|
|
|
|
/* Permute the 16-bit elements of regArray_0 .. regArray_3 with the same
   index vector `idx` (_mm512_permutexvar_epi16) and store the results in
   regArray_4 .. regArray_7; the inputs are left unchanged.
     idx      - index vector passed to every permute
     regArray - name prefix; "_0" .. "_7" are token-pasted onto it
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_PERMUTE_4x32(idx, regArray) \
    regArray##_4 = _mm512_permutexvar_epi16((idx), regArray##_0); \
    regArray##_5 = _mm512_permutexvar_epi16((idx), regArray##_1); \
    regArray##_6 = _mm512_permutexvar_epi16((idx), regArray##_2); \
    regArray##_7 = _mm512_permutexvar_epi16((idx), regArray##_3);
|
|
|
|
|
|
/* Permute the 32-bit elements of regArray_0 .. regArray_3 with the same
   index vector `idx` (_mm512_permutexvar_epi32) and store the results in
   regArray_4 .. regArray_7; the inputs are left unchanged.
     idx      - index vector passed to every permute
     regArray - name prefix; "_0" .. "_7" are token-pasted onto it
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_PERMUTE_4x32_2(idx, regArray) \
    regArray##_4 = _mm512_permutexvar_epi32((idx), regArray##_0); \
    regArray##_5 = _mm512_permutexvar_epi32((idx), regArray##_1); \
    regArray##_6 = _mm512_permutexvar_epi32((idx), regArray##_2); \
    regArray##_7 = _mm512_permutexvar_epi32((idx), regArray##_3);
|
|
|
|
|
|
/* Calculate the dot result for 2-step interleaved matrix and vector
|
|
(Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
|
|
*/
|
|
/* Accumulate BF16 dot products (_mm512_dpbf16_ps) of eight matrix registers
   against four vector registers into two accumulators:
     accumArray_0 += mat_0*x_0 + mat_1*x_1 + mat_4*x_2 + mat_5*x_3
     accumArray_1 += mat_2*x_0 + mat_3*x_1 + mat_6*x_2 + mat_7*x_3
   Updates of the two accumulators alternate, so consecutive dpbf16
   operations never depend on each other.
     accumArray - name prefix of the accumulators ("_0", "_1")
     matArray   - name prefix of the matrix registers ("_0" .. "_7")
     xArray     - name prefix of the vector registers ("_0" .. "_3")
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_2STEP_INTERLEAVED_DOT_8x32(accumArray, matArray, xArray) \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_2, (__m512bh) xArray##_0); \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_1); \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_4, (__m512bh) xArray##_2); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_6, (__m512bh) xArray##_2); \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_5, (__m512bh) xArray##_3); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_7, (__m512bh) xArray##_3);
|
|
|
|
|
|
/* Calculate the dot result for 2-step interleaved matrix and vector
|
|
(Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
|
|
*/
|
|
/* Accumulate BF16 dot products (_mm256_dpbf16_ps) of eight matrix registers
   against four vector registers into two accumulators:
     accumArray_0 += mat_0*x_0 + mat_1*x_1 + mat_4*x_2 + mat_5*x_3
     accumArray_1 += mat_2*x_0 + mat_3*x_1 + mat_6*x_2 + mat_7*x_3
   Updates of the two accumulators alternate, so consecutive dpbf16
   operations never depend on each other.
     accumArray - name prefix of the accumulators ("_0", "_1")
     matArray   - name prefix of the matrix registers ("_0" .. "_7")
     xArray     - name prefix of the vector registers ("_0" .. "_3")
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_2STEP_INTERLEAVED_DOT_8x16(accumArray, matArray, xArray) \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_2, (__m256bh) xArray##_0); \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_1); \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_4, (__m256bh) xArray##_2); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_6, (__m256bh) xArray##_2); \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_5, (__m256bh) xArray##_3); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_7, (__m256bh) xArray##_3);
|
|
|
|
/* Calculate the dot result for 2-step interleaved matrix and vector
|
|
(Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
|
|
*/
|
|
/* Accumulate BF16 dot products (_mm512_dpbf16_ps) of four matrix registers
   against four vector registers into two accumulators:
     accumArray_0 += mat_0*x_0 + mat_2*x_2
     accumArray_1 += mat_1*x_1 + mat_3*x_3
   Updates of the two accumulators alternate, so consecutive dpbf16
   operations never depend on each other.
     accumArray - name prefix of the accumulators ("_0", "_1")
     matArray   - name prefix of the matrix registers ("_0" .. "_3")
     xArray     - name prefix of the vector registers ("_0" .. "_3")
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_2STEP_INTERLEAVED_DOT_4x32(accumArray, matArray, xArray) \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_2, (__m512bh) xArray##_2); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_3);
|
|
|
|
|
|
/* Calculate the dot result for 2-step interleaved matrix and vector
|
|
(Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
|
|
*/
|
|
/* Accumulate BF16 dot products (_mm256_dpbf16_ps) of four matrix registers
   against four vector registers into two accumulators:
     accumArray_0 += mat_0*x_0 + mat_2*x_2
     accumArray_1 += mat_1*x_1 + mat_3*x_3
   Updates of the two accumulators alternate, so consecutive dpbf16
   operations never depend on each other.
     accumArray - name prefix of the accumulators ("_0", "_1")
     matArray   - name prefix of the matrix registers ("_0" .. "_3")
     xArray     - name prefix of the vector registers ("_0" .. "_3")
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_2STEP_INTERLEAVED_DOT_4x16(accumArray, matArray, xArray) \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_2, (__m256bh) xArray##_2); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_3);
|
|
|
|
|
|
/* Calculate the dot result for matrix and vector at 32 elements per row
|
|
(Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
|
|
*/
|
|
/* For each of 8 rows k: accumArray_k = dpbf16(accumArray_k, matArray_k,
   xArray), using _mm512_dpbf16_ps. Every row is multiplied by the same
   vector register.
     accumArray - name prefix of the accumulators ("_0" .. "_7")
     matArray   - name prefix of the matrix registers ("_0" .. "_7")
     xArray     - vector operand, used as-is (parenthesized before the cast)
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_DOT_8x32(accumArray, matArray, xArray) \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) (xArray)); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) (xArray)); \
    accumArray##_2 = _mm512_dpbf16_ps(accumArray##_2, (__m512bh) matArray##_2, (__m512bh) (xArray)); \
    accumArray##_3 = _mm512_dpbf16_ps(accumArray##_3, (__m512bh) matArray##_3, (__m512bh) (xArray)); \
    accumArray##_4 = _mm512_dpbf16_ps(accumArray##_4, (__m512bh) matArray##_4, (__m512bh) (xArray)); \
    accumArray##_5 = _mm512_dpbf16_ps(accumArray##_5, (__m512bh) matArray##_5, (__m512bh) (xArray)); \
    accumArray##_6 = _mm512_dpbf16_ps(accumArray##_6, (__m512bh) matArray##_6, (__m512bh) (xArray)); \
    accumArray##_7 = _mm512_dpbf16_ps(accumArray##_7, (__m512bh) matArray##_7, (__m512bh) (xArray));
|
|
|
|
/* Calculate the dot result for matrix and vector at 32 elements per row
|
|
(Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
|
|
*/
|
|
/* accumArray = dpbf16(accumArray, matArray, xArray), using
   _mm512_dpbf16_ps on a single row. No suffix is pasted onto any argument.
     accumArray - accumulator lvalue, read and then overwritten
     matArray   - matrix operand (parenthesized before the cast)
     xArray     - vector operand (parenthesized before the cast)
   Expands to one statement ending in ';'. */
#define BF16_DOT_1x32(accumArray, matArray, xArray) \
    accumArray = _mm512_dpbf16_ps(accumArray, (__m512bh) (matArray), (__m512bh) (xArray));
|
|
|
|
/* Calculate the dot result for matrix and vector at 16 elements per row
|
|
(Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
|
|
*/
|
|
/* For each of 8 rows k: accumArray_k = dpbf16(accumArray_k, matArray_k,
   xArray), using _mm256_dpbf16_ps. Every row is multiplied by the same
   vector register.
     accumArray - name prefix of the accumulators ("_0" .. "_7")
     matArray   - name prefix of the matrix registers ("_0" .. "_7")
     xArray     - vector operand, used as-is (parenthesized before the cast)
   Expands to a statement sequence ending in ';' (not do/while-wrapped). */
#define BF16_DOT_8x16(accumArray, matArray, xArray) \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) (xArray)); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) (xArray)); \
    accumArray##_2 = _mm256_dpbf16_ps(accumArray##_2, (__m256bh) matArray##_2, (__m256bh) (xArray)); \
    accumArray##_3 = _mm256_dpbf16_ps(accumArray##_3, (__m256bh) matArray##_3, (__m256bh) (xArray)); \
    accumArray##_4 = _mm256_dpbf16_ps(accumArray##_4, (__m256bh) matArray##_4, (__m256bh) (xArray)); \
    accumArray##_5 = _mm256_dpbf16_ps(accumArray##_5, (__m256bh) matArray##_5, (__m256bh) (xArray)); \
    accumArray##_6 = _mm256_dpbf16_ps(accumArray##_6, (__m256bh) matArray##_6, (__m256bh) (xArray)); \
    accumArray##_7 = _mm256_dpbf16_ps(accumArray##_7, (__m256bh) matArray##_7, (__m256bh) (xArray));
|
|
|
|
|
|
/* 2-step interleave for matrix against 8 rows with 16 fp32 elements per row
|
|
Input - register array of 8 rows of row-major matrix
|
|
Output - the output of Step 2
|
|
|
|
Step 1: 2-element interleave for matrix
|
|
|a0|b0|a1|b1|a4|b4|a5|b5|a8 |b8 |a9 |b9 |a12|b12|a13|b13|
|
|
|c0|d0|c1|d1|c4|d4|c5|d5|c8 |d8 |c9 |d9 |c12|d12|c13|d13|
|
|
|e0|f0|e1|f1|e4|f4|e5|f5|e8 |f8 |e9 |f9 |e12|f12|e13|f13|
|
|
|g0|h0|g1|h1|g4|h4|g5|h5|g8 |h8 |g9 |h9 |g12|h12|g13|h13|
|
|
|a2|b2|a3|b3|a6|b6|a7|b7|a10|b10|a11|b11|a14|b14|a15|b15|
|
|
|c2|d2|c3|d3|c6|d6|c7|d7|c10|d10|c11|d11|c14|d14|c15|d15|
|
|
|e2|f2|e3|f3|e6|f6|e7|f7|e10|f10|e11|f11|e14|f14|e15|f15|
|
|
|g2|h2|g3|h3|g6|h6|g7|h7|g10|h10|g11|h11|g14|h14|g15|h15|
|
|
|
|
Step 2: 4-element interleave for matrix
|
|
|a0|b0|c0|d0|a4|b4|c4|d4|a8 |b8 |c8 |d8 |a12|b12|c12|d12|
|
|
|a1|b1|c1|d1|a5|b5|c5|d5|a9 |b9 |c9 |d9 |a13|b13|c13|d13|
|
|
|e0|f0|g0|h0|e4|f4|g4|h4|e8 |f8 |g8 |h8 |e12|f12|g12|h12|
|
|
|e1|f1|g1|h1|e5|f5|g5|h5|e9 |f9 |g9 |h9 |e13|f13|g13|h13|
|
|
|a2|b2|c2|d2|a6|b6|c6|d6|a10|b10|c10|d10|a14|b14|c14|d14|
|
|
|a3|b3|c3|d3|a7|b7|c7|d7|a11|b11|c11|d11|a15|b15|c15|d15|
|
|
|e2|f2|g2|h2|e6|f6|g6|h6|e10|f10|g10|h10|e14|f14|g14|h14|
|
|
|e3|f3|g3|h3|e7|f7|g7|h7|e11|f11|g11|h11|e15|f15|g15|h15|
|
|
*/
|
|
#define FP32_INTERLEAVE_8x16(regArray) \
|
|
regArray##_8 = _mm512_unpacklo_ps(regArray##_0, regArray##_1); \
|
|
regArray##_9 = _mm512_unpacklo_ps(regArray##_2, regArray##_3); \
|
|
regArray##_10 = _mm512_unpacklo_ps(regArray##_4, regArray##_5); \
|
|
regArray##_11 = _mm512_unpacklo_ps(regArray##_6, regArray##_7); \
|
|
regArray##_12 = _mm512_unpackhi_ps(regArray##_0, regArray##_1); \
|
|
regArray##_13 = _mm512_unpackhi_ps(regArray##_2, regArray##_3); \
|
|
regArray##_14 = _mm512_unpackhi_ps(regArray##_4, regArray##_5); \
|
|
regArray##_15 = _mm512_unpackhi_ps(regArray##_6, regArray##_7); \
|
|
\
|
|
regArray##_0 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_8, (__m512d) regArray##_9); \
|
|
regArray##_1 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_8, (__m512d) regArray##_9); \
|
|
regArray##_4 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \
|
|
regArray##_5 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \
|
|
regArray##_2 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \
|
|
regArray##_3 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \
|
|
regArray##_6 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_14, (__m512d) regArray##_15); \
|
|
regArray##_7 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_14, (__m512d) regArray##_15);
|
|
|
|
/* Array form of FP32_INTERLEAVE_8x16: 2-step interleave of 8 rows with
 * 16 fp32 elements per row, operating on regArray[0] .. regArray[15].
 *
 * Input   - regArray[0] .. regArray[7]: 8 rows (a .. h) of a row-major matrix
 * Output  - regArray[0] .. regArray[7]; for j = 0..3 and 128-bit lane k:
 *             regArray[j]     lane k = element 4k+j of rows a, b, c, d
 *             regArray[4 + j] lane k = element 4k+j of rows e, f, g, h
 * Scratch - regArray[8] .. regArray[15] receive the intermediate 2-element
 *           interleave and are clobbered
 *
 * regArray must be indexable with at least 16 __m512 elements.
 * The expansion is a plain statement sequence that already ends in ';'
 * (there is no do { } while (0) wrapper), so it must not be used as the
 * body of an unbraced if/else/loop.
 */
#define FP32_INTERLEAVE_8x16_ARRAY(regArray) \
    regArray[8]  = _mm512_unpacklo_ps(regArray[0], regArray[1]); \
    regArray[9]  = _mm512_unpacklo_ps(regArray[2], regArray[3]); \
    regArray[10] = _mm512_unpacklo_ps(regArray[4], regArray[5]); \
    regArray[11] = _mm512_unpacklo_ps(regArray[6], regArray[7]); \
    regArray[12] = _mm512_unpackhi_ps(regArray[0], regArray[1]); \
    regArray[13] = _mm512_unpackhi_ps(regArray[2], regArray[3]); \
    regArray[14] = _mm512_unpackhi_ps(regArray[4], regArray[5]); \
    regArray[15] = _mm512_unpackhi_ps(regArray[6], regArray[7]); \
    \
    regArray[0] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[8],  (__m512d) regArray[9]);  \
    regArray[1] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[8],  (__m512d) regArray[9]);  \
    regArray[4] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
    regArray[5] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
    regArray[2] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
    regArray[3] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
    regArray[6] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[14], (__m512d) regArray[15]); \
    regArray[7] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[14], (__m512d) regArray[15]);

/* 2-step interleave of 8 rows with 8 fp32 elements per row.
 *
 * Input   - regArray_0 .. regArray_7: 8 rows (a .. h) of a row-major matrix
 * Output  - the result of Step 2, written back to regArray_0 .. regArray_7
 * Scratch - regArray_8 .. regArray_15 receive the Step 1 result (clobbered)
 *
 * regArray is a name prefix: sixteen __m256 variables named regArray_0 ..
 * regArray_15 must be in scope where the macro is expanded.
 * The expansion is a plain statement sequence that already ends in ';'
 * (there is no do { } while (0) wrapper), so it must not be used as the
 * body of an unbraced if/else/loop.
 *
 * Step 1: 2-element interleave (rows are regArray_8 .. regArray_15)
 *   |a0|b0|a1|b1|a4|b4|a5|b5|
 *   |c0|d0|c1|d1|c4|d4|c5|d5|
 *   |e0|f0|e1|f1|e4|f4|e5|f5|
 *   |g0|h0|g1|h1|g4|h4|g5|h5|
 *   |a2|b2|a3|b3|a6|b6|a7|b7|
 *   |c2|d2|c3|d3|c6|d6|c7|d7|
 *   |e2|f2|e3|f3|e6|f6|e7|f7|
 *   |g2|h2|g3|h3|g6|h6|g7|h7|
 *
 * Step 2: 4-element interleave (rows listed in the order they are computed)
 *   |a0|b0|c0|d0|a4|b4|c4|d4|  -> regArray_0
 *   |a1|b1|c1|d1|a5|b5|c5|d5|  -> regArray_1
 *   |e0|f0|g0|h0|e4|f4|g4|h4|  -> regArray_4
 *   |e1|f1|g1|h1|e5|f5|g5|h5|  -> regArray_5
 *   |a2|b2|c2|d2|a6|b6|c6|d6|  -> regArray_2
 *   |a3|b3|c3|d3|a7|b7|c7|d7|  -> regArray_3
 *   |e2|f2|g2|h2|e6|f6|g6|h6|  -> regArray_6
 *   |e3|f3|g3|h3|e7|f7|g7|h7|  -> regArray_7
 */
#define FP32_INTERLEAVE_8x8(regArray) \
    regArray##_8  = _mm256_unpacklo_ps(regArray##_0, regArray##_1); \
    regArray##_9  = _mm256_unpacklo_ps(regArray##_2, regArray##_3); \
    regArray##_10 = _mm256_unpacklo_ps(regArray##_4, regArray##_5); \
    regArray##_11 = _mm256_unpacklo_ps(regArray##_6, regArray##_7); \
    regArray##_12 = _mm256_unpackhi_ps(regArray##_0, regArray##_1); \
    regArray##_13 = _mm256_unpackhi_ps(regArray##_2, regArray##_3); \
    regArray##_14 = _mm256_unpackhi_ps(regArray##_4, regArray##_5); \
    regArray##_15 = _mm256_unpackhi_ps(regArray##_6, regArray##_7); \
    \
    regArray##_0 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_8,  (__m256d) regArray##_9);  \
    regArray##_1 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_8,  (__m256d) regArray##_9);  \
    regArray##_4 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
    regArray##_5 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
    regArray##_2 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
    regArray##_3 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
    regArray##_6 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_14, (__m256d) regArray##_15); \
    regArray##_7 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_14, (__m256d) regArray##_15);

/* Accumulate two batches of four 16 x fp32 registers.
 *
 * regArray is a name prefix for __m512 variables regArray_0 .. regArray_7.
 * After the macro:
 *   regArray_0 = regArray_0 + regArray_1 + regArray_2 + regArray_3
 *   regArray_4 = regArray_4 + regArray_5 + regArray_6 + regArray_7
 *   regArray_2 = regArray_2 + regArray_3   (partial sum, left behind)
 *   regArray_6 = regArray_6 + regArray_7   (partial sum, left behind)
 * regArray_1, _3, _5 and _7 are not modified.
 *
 * The expansion is a plain statement sequence that already ends in ';'
 * (there is no do { } while (0) wrapper), so it must not be used as the
 * body of an unbraced if/else/loop.
 */
#define FP32_ACCUM2_8x16(regArray) \
    regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_1); \
    regArray##_2 = _mm512_add_ps(regArray##_2, regArray##_3); \
    regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_5); \
    regArray##_6 = _mm512_add_ps(regArray##_6, regArray##_7); \
    regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_2); \
    regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_6);

/* Array form of FP32_ACCUM2_8x16: accumulate two batches of four
 * 16 x fp32 registers held in regArray[0] .. regArray[7].
 *
 * After the macro:
 *   regArray[0] = regArray[0] + regArray[1] + regArray[2] + regArray[3]
 *   regArray[4] = regArray[4] + regArray[5] + regArray[6] + regArray[7]
 *   regArray[2] = regArray[2] + regArray[3]   (partial sum, left behind)
 *   regArray[6] = regArray[6] + regArray[7]   (partial sum, left behind)
 * regArray[1], [3], [5] and [7] are not modified.
 *
 * The expansion is a plain statement sequence that already ends in ';'
 * (there is no do { } while (0) wrapper), so it must not be used as the
 * body of an unbraced if/else/loop.
 */
#define FP32_ACCUM2_8x16_ARRAY(regArray) \
    regArray[0] = _mm512_add_ps(regArray[0], regArray[1]); \
    regArray[2] = _mm512_add_ps(regArray[2], regArray[3]); \
    regArray[4] = _mm512_add_ps(regArray[4], regArray[5]); \
    regArray[6] = _mm512_add_ps(regArray[6], regArray[7]); \
    regArray[0] = _mm512_add_ps(regArray[0], regArray[2]); \
    regArray[4] = _mm512_add_ps(regArray[4], regArray[6]);

/* Accumulate two batches of four 8 x fp32 registers.
 *
 * regArray is a name prefix for __m256 variables regArray_0 .. regArray_7.
 * After the macro:
 *   regArray_0 = regArray_0 + regArray_1 + regArray_2 + regArray_3
 *   regArray_4 = regArray_4 + regArray_5 + regArray_6 + regArray_7
 *   regArray_2 = regArray_2 + regArray_3   (partial sum, left behind)
 *   regArray_6 = regArray_6 + regArray_7   (partial sum, left behind)
 * regArray_1, _3, _5 and _7 are not modified.
 *
 * The expansion is a plain statement sequence that already ends in ';'
 * (there is no do { } while (0) wrapper), so it must not be used as the
 * body of an unbraced if/else/loop.
 */
#define FP32_ACCUM2_8x8(regArray) \
    regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_1); \
    regArray##_2 = _mm256_add_ps(regArray##_2, regArray##_3); \
    regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_5); \
    regArray##_6 = _mm256_add_ps(regArray##_6, regArray##_7); \
    regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_2); \
    regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_6);

/* Store 16 fp32 values (alpha * result + beta * y) to y.
 *
 *   regResult  - __m512 lvalue with the results; overwritten with the
 *                value that is stored
 *   targetAddr - address of 16 floats in y (unaligned load/store is used)
 *
 * ALPHAVECTOR and BETAVECTOR are not parameters; they must name __m512
 * values in scope at the expansion site.
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE16_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
    regResult = _mm512_fmadd_ps( \
        ALPHAVECTOR, regResult, \
        _mm512_mul_ps(BETAVECTOR, _mm512_loadu_ps(targetAddr))); \
    _mm512_storeu_ps(targetAddr, regResult);

/* Masked store of up to 16 fp32 values (alpha * result + beta * y) to y.
 *
 *   regResult  - __m512 lvalue with the results; overwritten in every lane
 *                (lanes whose mask bit is clear see y as 0 in the blend)
 *   targetAddr - address in y (unaligned masked load/store is used)
 *   mask       - lane mask; only lanes whose bit is set are loaded from
 *                and written back to targetAddr
 *
 * ALPHAVECTOR and BETAVECTOR are not parameters; they must name __m512
 * values in scope at the expansion site.
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
    regResult = _mm512_fmadd_ps( \
        ALPHAVECTOR, regResult, \
        _mm512_mul_ps(BETAVECTOR, _mm512_maskz_loadu_ps(mask, targetAddr))); \
    _mm512_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 8 fp32 values (alpha * result + beta * y) to y.
 *
 *   regResult  - __m256 lvalue with the results; overwritten with the
 *                value that is stored
 *   targetAddr - address of 8 floats in y (unaligned load/store is used)
 *
 * ALPHAVECTOR and BETAVECTOR are not parameters; they must name __m512
 * values in scope at the expansion site. Only their low 256 bits are used.
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE8_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
    regResult = _mm256_fmadd_ps( \
        _mm512_castps512_ps256(ALPHAVECTOR), regResult, \
        _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), \
                      _mm256_loadu_ps(targetAddr))); \
    _mm256_storeu_ps(targetAddr, regResult);

/* Masked store of up to 8 fp32 values (alpha * result + beta * y) to y.
 *
 *   regResult  - __m256 lvalue with the results; overwritten in every lane
 *                (lanes whose mask bit is clear see y as 0 in the blend)
 *   targetAddr - address in y (unaligned masked load/store is used)
 *   mask       - lane mask; only lanes whose bit is set are loaded from
 *                and written back to targetAddr
 *
 * ALPHAVECTOR and BETAVECTOR are not parameters; they must name __m512
 * values in scope at the expansion site. Only their low 256 bits are used.
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
    regResult = _mm256_fmadd_ps( \
        _mm512_castps512_ps256(ALPHAVECTOR), regResult, \
        _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), \
                      _mm256_maskz_loadu_ps(mask, targetAddr))); \
    _mm256_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 4 fp32 values (alpha * result + beta * y) to y.
 *
 *   regResult  - __m128 lvalue with the results; overwritten with the
 *                value that is stored
 *   targetAddr - address of 4 floats in y (unaligned load/store is used)
 *
 * ALPHAVECTOR and BETAVECTOR are not parameters; they must name __m512
 * values in scope at the expansion site. Only their low 128 bits are used.
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE4_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
    regResult = _mm_fmadd_ps( \
        _mm512_castps512_ps128(ALPHAVECTOR), regResult, \
        _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), \
                   _mm_loadu_ps(targetAddr))); \
    _mm_storeu_ps(targetAddr, regResult);

/* Masked store of up to 4 fp32 values (alpha * result + beta * y) to y.
 *
 *   regResult  - __m128 lvalue with the results; overwritten in every lane
 *                (lanes whose mask bit is clear see y as 0 in the blend)
 *   targetAddr - address in y (unaligned masked load/store is used)
 *   mask       - lane mask; only lanes whose bit is set are loaded from
 *                and written back to targetAddr
 *
 * ALPHAVECTOR and BETAVECTOR are not parameters; they must name __m512
 * values in scope at the expansion site. Only their low 128 bits are used.
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
    regResult = _mm_fmadd_ps( \
        _mm512_castps512_ps128(ALPHAVECTOR), regResult, \
        _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), \
                   _mm_maskz_loadu_ps(mask, targetAddr))); \
    _mm_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 16 fp32 values (alpha * result + y) to y.
 *
 *   regResult  - __m512 lvalue with the results; overwritten with the
 *                value that is stored
 *   targetAddr - address of 16 floats in y (unaligned load/store is used)
 *
 * ALPHAVECTOR is not a parameter; it must name a __m512 value in scope at
 * the expansion site.
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE16_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
    regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, \
                                _mm512_loadu_ps(targetAddr)); \
    _mm512_storeu_ps(targetAddr, regResult);

/* Masked store of up to 16 fp32 values (alpha * result + y) to y.
 *
 *   regResult  - __m512 lvalue with the results; overwritten in every lane
 *                (lanes whose mask bit is clear see y as 0 in the sum)
 *   targetAddr - address in y (unaligned masked load/store is used)
 *   mask       - lane mask; only lanes whose bit is set are loaded from
 *                and written back to targetAddr
 *
 * ALPHAVECTOR is not a parameter; it must name a __m512 value in scope at
 * the expansion site.
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
    regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, \
                                _mm512_maskz_loadu_ps(mask, targetAddr)); \
    _mm512_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 8 fp32 values (alpha * result + y) to y.
 *
 *   regResult  - __m256 lvalue with the results; overwritten with the
 *                value that is stored
 *   targetAddr - address of 8 floats in y (unaligned load/store is used)
 *
 * ALPHAVECTOR is not a parameter; it must name a __m512 value in scope at
 * the expansion site. Only its low 256 bits are used.
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE8_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
    regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, \
                                _mm256_loadu_ps(targetAddr)); \
    _mm256_storeu_ps(targetAddr, regResult);

/* Masked store of up to 8 fp32 values (alpha * result + y) to y.
 *
 *   regResult  - __m256 lvalue with the results; overwritten in every lane
 *                (lanes whose mask bit is clear see y as 0 in the sum)
 *   targetAddr - address in y (unaligned masked load/store is used)
 *   mask       - lane mask; only lanes whose bit is set are loaded from
 *                and written back to targetAddr
 *
 * ALPHAVECTOR is not a parameter; it must name a __m512 value in scope at
 * the expansion site. Only its low 256 bits are used.
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
    regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, \
                                _mm256_maskz_loadu_ps(mask, targetAddr)); \
    _mm256_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 4 fp32 values (alpha * result + y) to y.
 *
 *   regResult  - __m128 lvalue with the results; overwritten with the
 *                value that is stored
 *   targetAddr - address of 4 floats in y (unaligned load/store is used)
 *
 * ALPHAVECTOR is not a parameter; it must name a __m512 value in scope at
 * the expansion site. Only its low 128 bits are used.
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE4_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
    regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, \
                             _mm_loadu_ps(targetAddr)); \
    _mm_storeu_ps(targetAddr, regResult);

/* Masked store of up to 4 fp32 values (alpha * result + y) to y.
 *
 *   regResult  - __m128 lvalue with the results; overwritten in every lane
 *                (lanes whose mask bit is clear see y as 0 in the sum)
 *   targetAddr - address in y (unaligned masked load/store is used)
 *   mask       - lane mask; only lanes whose bit is set are loaded from
 *                and written back to targetAddr
 *
 * ALPHAVECTOR is not a parameter; it must name a __m512 value in scope at
 * the expansion site. Only its low 128 bits are used.
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
    regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, \
                             _mm_maskz_loadu_ps(mask, targetAddr)); \
    _mm_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 16 fp32 values (result + y) to y.
 *
 *   regResult  - __m512 lvalue with the results; overwritten with the
 *                value that is stored
 *   targetAddr - address of 16 floats in y (unaligned load/store is used)
 *
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE16_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
    regResult = _mm512_add_ps(regResult, _mm512_loadu_ps(targetAddr)); \
    _mm512_storeu_ps(targetAddr, regResult);

/* Masked store of up to 16 fp32 values (result + y) to y.
 *
 *   regResult  - __m512 lvalue with the results; overwritten in every lane
 *                (lanes whose mask bit is clear see y as 0 in the sum)
 *   targetAddr - address in y (unaligned masked load/store is used)
 *   mask       - lane mask; only lanes whose bit is set are loaded from
 *                and written back to targetAddr
 *
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE16_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
    regResult = _mm512_add_ps(regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \
    _mm512_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 8 fp32 values (result + y) to y.
 *
 *   regResult  - __m256 lvalue with the results; overwritten with the
 *                value that is stored
 *   targetAddr - address of 8 floats in y (unaligned load/store is used)
 *
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE8_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
    regResult = _mm256_add_ps(regResult, _mm256_loadu_ps(targetAddr)); \
    _mm256_storeu_ps(targetAddr, regResult);

/* Masked store of up to 8 fp32 values (result + y) to y.
 *
 *   regResult  - __m256 lvalue with the results; overwritten in every lane
 *                (lanes whose mask bit is clear see y as 0 in the sum)
 *   targetAddr - address in y (unaligned masked load/store is used)
 *   mask       - lane mask; only lanes whose bit is set are loaded from
 *                and written back to targetAddr
 *
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE8_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
    regResult = _mm256_add_ps(regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \
    _mm256_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 4 fp32 values (result + y) to y.
 *
 *   regResult  - __m128 lvalue with the results; overwritten with the
 *                value that is stored
 *   targetAddr - address of 4 floats in y (unaligned load/store is used)
 *
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE4_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
    regResult = _mm_add_ps(regResult, _mm_loadu_ps(targetAddr)); \
    _mm_storeu_ps(targetAddr, regResult);

/* Masked store of up to 4 fp32 values (result + y) to y.
 *
 *   regResult  - __m128 lvalue with the results; overwritten in every lane
 *                (lanes whose mask bit is clear see y as 0 in the sum)
 *   targetAddr - address in y (unaligned masked load/store is used)
 *   mask       - lane mask; only lanes whose bit is set are loaded from
 *                and written back to targetAddr
 *
 * Expands to two statements, the last already terminated by ';'.
 */
#define STORE4_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
    regResult = _mm_add_ps(regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \
    _mm_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 16 fp32 values (alpha * result) to y; the old contents of y are
 * not read.
 *
 *   regResult  - __m512 value with the results (not modified)
 *   targetAddr - address of 16 floats in y (unaligned store is used)
 *
 * ALPHAVECTOR is not a parameter; it must name a __m512 value in scope at
 * the expansion site.
 * Expands to one statement that is already terminated by ';'.
 */
#define STORE16_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
    _mm512_storeu_ps(targetAddr, _mm512_mul_ps(ALPHAVECTOR, regResult));

/* Masked store of up to 16 fp32 values (alpha * result) to y; the old
 * contents of y are not read.
 *
 *   regResult  - __m512 value with the results (not modified)
 *   targetAddr - address in y (unaligned masked store is used)
 *   mask       - lane mask; only lanes whose bit is set are written
 *
 * ALPHAVECTOR is not a parameter; it must name a __m512 value in scope at
 * the expansion site.
 * Expands to one statement that is already terminated by ';'.
 */
#define STORE16_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
    _mm512_mask_storeu_ps(targetAddr, mask, _mm512_mul_ps(ALPHAVECTOR, regResult));

/* Store 8 fp32 values (alpha * result) to y; the old contents of y are
 * not read.
 *
 *   regResult  - __m256 value with the results (not modified)
 *   targetAddr - address of 8 floats in y (unaligned store is used)
 *
 * ALPHAVECTOR is not a parameter; it must name a __m512 value in scope at
 * the expansion site. Only its low 256 bits are used.
 * Expands to one statement that is already terminated by ';'.
 */
#define STORE8_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
    _mm256_storeu_ps(targetAddr, \
                     _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));

/* Masked store of up to 8 fp32 values (alpha * result) to y; the old
 * contents of y are not read.
 *
 *   regResult  - __m256 value with the results (not modified)
 *   targetAddr - address in y (unaligned masked store is used)
 *   mask       - lane mask; only lanes whose bit is set are written
 *
 * ALPHAVECTOR is not a parameter; it must name a __m512 value in scope at
 * the expansion site. Only its low 256 bits are used.
 * Expands to one statement that is already terminated by ';'.
 */
#define STORE8_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
    _mm256_mask_storeu_ps(targetAddr, mask, \
                          _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));

/* Store 4 fp32 values (alpha * result) to y; the old contents of y are
 * not read.
 *
 *   regResult  - __m128 value with the results (not modified)
 *   targetAddr - address of 4 floats in y (unaligned store is used)
 *
 * ALPHAVECTOR is not a parameter; it must name a __m512 value in scope at
 * the expansion site. Only its low 128 bits are used.
 * Expands to one statement that is already terminated by ';'.
 */
#define STORE4_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
    _mm_storeu_ps(targetAddr, \
                  _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));

/* Masked store of up to 4 fp32 values (alpha * result) to y; the old
 * contents of y are not read.
 *
 *   regResult  - __m128 value with the results (not modified)
 *   targetAddr - address in y (unaligned masked store is used)
 *   mask       - lane mask; only lanes whose bit is set are written
 *
 * ALPHAVECTOR is not a parameter; it must name a __m512 value in scope at
 * the expansion site. Only its low 128 bits are used.
 * Expands to one statement that is already terminated by ';'.
 */
#define STORE4_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
    _mm_mask_storeu_ps(targetAddr, mask, \
                       _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));

/* Store 16 fp32 results to y unchanged (no alpha/beta scaling; the old
 * contents of y are not read).
 *
 *   regResult  - __m512 value to store (not modified)
 *   targetAddr - address of 16 floats in y (unaligned store is used)
 *
 * Expands to one statement that is already terminated by ';'.
 */
#define STORE16_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
    _mm512_storeu_ps(targetAddr, regResult);

/* Masked store of up to 16 fp32 results to y unchanged (no alpha/beta
 * scaling; the old contents of y are not read).
 *
 *   regResult  - __m512 value to store (not modified)
 *   targetAddr - address in y (unaligned masked store is used)
 *   mask       - lane mask; only lanes whose bit is set are written
 *
 * Expands to one statement that is already terminated by ';'.
 */
#define STORE16_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
    _mm512_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 8 fp32 results to y unchanged (no alpha/beta scaling; the old
 * contents of y are not read).
 *
 *   regResult  - __m256 value to store (not modified)
 *   targetAddr - address of 8 floats in y (unaligned store is used)
 *
 * Expands to one statement that is already terminated by ';'.
 */
#define STORE8_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
    _mm256_storeu_ps(targetAddr, regResult);

/* Masked store of up to 8 fp32 results to y unchanged (no alpha/beta
 * scaling; the old contents of y are not read).
 *
 *   regResult  - __m256 value to store (not modified)
 *   targetAddr - address in y (unaligned masked store is used)
 *   mask       - lane mask; only lanes whose bit is set are written
 *
 * Expands to one statement that is already terminated by ';'.
 */
#define STORE8_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
    _mm256_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 4 fp32 results to y unchanged (no alpha/beta scaling; the old
 * contents of y are not read).
 *
 *   regResult  - __m128 value to store (not modified)
 *   targetAddr - address of 4 floats in y (unaligned store is used)
 *
 * Expands to one statement that is already terminated by ';'.
 */
#define STORE4_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
    _mm_storeu_ps(targetAddr, regResult);

/* Masked store of up to 4 fp32 results to y unchanged (no alpha/beta
 * scaling; the old contents of y are not read).
 *
 *   regResult  - __m128 value to store (not modified)
 *   targetAddr - address in y (unaligned masked store is used)
 *   mask       - lane mask; only lanes whose bit is set are written
 *
 * Expands to one statement that is already terminated by ';'.
 */
#define STORE4_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
    _mm_mask_storeu_ps(targetAddr, mask, regResult);

#endif
|