x86_64: BFLOAT16: fix build warning
This commit is contained in:
parent
9f52abf12c
commit
ee5ca8a328
|
@ -56,25 +56,25 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
|
||||
|
||||
#define BF16_MATRIX_LOAD_8x16(regArray, a, lda, idx_m, idx_n) \
|
||||
regArray##_0 = _mm256_loadu_si256(&a[(idx_m+0)*lda + idx_n]); \
|
||||
regArray##_1 = _mm256_loadu_si256(&a[(idx_m+1)*lda + idx_n]); \
|
||||
regArray##_2 = _mm256_loadu_si256(&a[(idx_m+2)*lda + idx_n]); \
|
||||
regArray##_3 = _mm256_loadu_si256(&a[(idx_m+3)*lda + idx_n]); \
|
||||
regArray##_4 = _mm256_loadu_si256(&a[(idx_m+4)*lda + idx_n]); \
|
||||
regArray##_5 = _mm256_loadu_si256(&a[(idx_m+5)*lda + idx_n]); \
|
||||
regArray##_6 = _mm256_loadu_si256(&a[(idx_m+6)*lda + idx_n]); \
|
||||
regArray##_7 = _mm256_loadu_si256(&a[(idx_m+7)*lda + idx_n]);
|
||||
regArray##_0 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+0)*lda + idx_n])); \
|
||||
regArray##_1 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+1)*lda + idx_n])); \
|
||||
regArray##_2 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+2)*lda + idx_n])); \
|
||||
regArray##_3 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+3)*lda + idx_n])); \
|
||||
regArray##_4 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+4)*lda + idx_n])); \
|
||||
regArray##_5 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+5)*lda + idx_n])); \
|
||||
regArray##_6 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+6)*lda + idx_n])); \
|
||||
regArray##_7 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+7)*lda + idx_n]));
|
||||
|
||||
|
||||
#define BF16_MATRIX_LOAD_8x8(regArray, a, lda, idx_m, idx_n) \
|
||||
regArray##_0 = _mm_loadu_si128(&a[(idx_m+0)*lda + idx_n]); \
|
||||
regArray##_1 = _mm_loadu_si128(&a[(idx_m+1)*lda + idx_n]); \
|
||||
regArray##_2 = _mm_loadu_si128(&a[(idx_m+2)*lda + idx_n]); \
|
||||
regArray##_3 = _mm_loadu_si128(&a[(idx_m+3)*lda + idx_n]); \
|
||||
regArray##_4 = _mm_loadu_si128(&a[(idx_m+4)*lda + idx_n]); \
|
||||
regArray##_5 = _mm_loadu_si128(&a[(idx_m+5)*lda + idx_n]); \
|
||||
regArray##_6 = _mm_loadu_si128(&a[(idx_m+6)*lda + idx_n]); \
|
||||
regArray##_7 = _mm_loadu_si128(&a[(idx_m+7)*lda + idx_n]);
|
||||
regArray##_0 = _mm_loadu_si128((__m128i *)(&a[(idx_m+0)*lda + idx_n])); \
|
||||
regArray##_1 = _mm_loadu_si128((__m128i *)(&a[(idx_m+1)*lda + idx_n])); \
|
||||
regArray##_2 = _mm_loadu_si128((__m128i *)(&a[(idx_m+2)*lda + idx_n])); \
|
||||
regArray##_3 = _mm_loadu_si128((__m128i *)(&a[(idx_m+3)*lda + idx_n])); \
|
||||
regArray##_4 = _mm_loadu_si128((__m128i *)(&a[(idx_m+4)*lda + idx_n])); \
|
||||
regArray##_5 = _mm_loadu_si128((__m128i *)(&a[(idx_m+5)*lda + idx_n])); \
|
||||
regArray##_6 = _mm_loadu_si128((__m128i *)(&a[(idx_m+6)*lda + idx_n])); \
|
||||
regArray##_7 = _mm_loadu_si128((__m128i *)(&a[(idx_m+7)*lda + idx_n]));
|
||||
|
||||
|
||||
#define BF16_MATRIX_LOAD_1x32(regArray, a, lda, idx_m, idx_n) \
|
||||
|
@ -153,11 +153,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
|
||||
|
||||
#define BF16_VECTOR_LOAD_1x16(reg, x, idx_n) \
|
||||
reg = _mm256_loadu_si256(x + idx_n);
|
||||
reg = _mm256_loadu_si256((__m256i *)(x + idx_n));
|
||||
|
||||
|
||||
#define BF16_VECTOR_LOAD_1x8(reg, x, idx_n) \
|
||||
reg = _mm_loadu_si128(x + idx_n);
|
||||
reg = _mm_loadu_si128((__m128i *)(x + idx_n));
|
||||
|
||||
|
||||
#define BF16_VECTOR_MASKZ_LOAD_1x32(reg, x, idx_n, mask) \
|
||||
|
|
|
@ -79,21 +79,21 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
|
|||
__m256 accum256_1 = _mm256_setzero_ps();
|
||||
int tail_index_32 = n&(~31);
|
||||
for (int j = 0; j < tail_index_32; j += 32) {
|
||||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[j+ 0]), (__m256bh) _mm256_loadu_si256(&y[j+ 0]));
|
||||
accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256(&x[j+16]), (__m256bh) _mm256_loadu_si256(&y[j+16]));
|
||||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((__m256i *)&x[j+ 0]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[j+ 0]));
|
||||
accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256((__m256i *)&x[j+16]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[j+16]));
|
||||
}
|
||||
accum256 = _mm256_add_ps(accum256, accum256_1);
|
||||
|
||||
/* Processing the remaining <32 chunk with 16-elements processing */
|
||||
if ((n&16) != 0) {
|
||||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[tail_index_32]), (__m256bh) _mm256_loadu_si256(&y[tail_index_32]));
|
||||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((__m256i *)&x[tail_index_32]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[tail_index_32]));
|
||||
}
|
||||
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
|
||||
|
||||
/* Processing the remaining <16 chunk with 8-elements processing */
|
||||
if ((n&8) != 0) {
|
||||
int tail_index_16 = n&(~15);
|
||||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16]));
|
||||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128((__m128i *)&x[tail_index_16]), (__m128bh) _mm_loadu_si128((__m128i *)&y[tail_index_16]));
|
||||
}
|
||||
|
||||
/* Processing the remaining <8 chunk with masked 8-elements processing */
|
||||
|
@ -108,13 +108,13 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
|
|||
} else if (n > 15) { /* n range from 16 to 31 */
|
||||
/* Processing <32 chunk with 16-elements processing */
|
||||
__m256 accum256 = _mm256_setzero_ps();
|
||||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[0]), (__m256bh) _mm256_loadu_si256(&y[0]));
|
||||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((__m256i *)&x[0]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[0]));
|
||||
accum128 += _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
|
||||
|
||||
/* Processing the remaining <16 chunk with 8-elements processing */
|
||||
if ((n&8) != 0) {
|
||||
int tail_index_16 = n&(~15);
|
||||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16]));
|
||||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128((__m128i *)&x[tail_index_16]), (__m128bh) _mm_loadu_si128((__m128i *)&y[tail_index_16]));
|
||||
}
|
||||
|
||||
/* Processing the remaining <8 chunk with masked 8-elements processing */
|
||||
|
@ -128,7 +128,7 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
|
|||
}
|
||||
} else if (n > 7) { /* n range from 8 to 15 */
|
||||
/* Processing <16 chunk with 8-elements processing */
|
||||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[0]), (__m128bh) _mm_loadu_si128(&y[0]));
|
||||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128((__m128i *)&x[0]), (__m128bh) _mm_loadu_si128((__m128i *)&y[0]));
|
||||
|
||||
/* Processing the remaining <8 chunk with masked 8-elements processing */
|
||||
if ((n&7) != 0) {
|
||||
|
|
|
@ -1246,7 +1246,7 @@ void COL_MAJOR_ITCOPY_KERNEL_Kx16(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat
|
|||
// K=Any number but will be processed based on 32, M<=16
|
||||
void COL_MAJOR_ITCOPY_KERNEL_Kx16m(BLASLONG m, BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
|
||||
{
|
||||
bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
|
||||
bfloat16 * src_addr0;
|
||||
bfloat16 * dst_addr0, * dst_addr1;
|
||||
|
||||
BLASLONG tag_k_32x = k & (~31);
|
||||
|
|
|
@ -30,6 +30,13 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
// Include common macros for BF16 based operations with IA intrinsics
|
||||
#include "bf16_common_macros.h"
|
||||
|
||||
#undef STORE16_COMPLETE_RESULT
|
||||
#undef STORE16_MASK_COMPLETE_RESULT
|
||||
#undef STORE8_COMPLETE_RESULT
|
||||
#undef STORE8_MASK_COMPLETE_RESULT
|
||||
#undef STORE4_COMPLETE_RESULT
|
||||
#undef STORE4_MASK_COMPLETE_RESULT
|
||||
|
||||
#ifndef ZERO_BETA // Beta is non-zero
|
||||
|
||||
#ifndef ONE_BETA // BETA is not ONE
|
||||
|
@ -103,7 +110,9 @@ static int sbgemv_kernel_32xN_lda_direct(BLASLONG m, BLASLONG n, float alpha, bf
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i matrixArray_seed_0, matrixArray_seed_1, matrixArray_seed_2, matrixArray_seed_3;
|
||||
|
@ -202,7 +211,7 @@ static int sbgemv_kernel_32xN_lda_direct(BLASLONG m, BLASLONG n, float alpha, bf
|
|||
unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(m&31)));
|
||||
__mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
|
||||
|
||||
unsigned short store_tail_mask_value = (((unsigned int)0xffff) >> (16-(m&15)));
|
||||
unsigned int store_tail_mask_value = (((unsigned int)0xffff) >> (16-(m&15)));
|
||||
__mmask32 store_tail_mask = *((__mmask32*) &store_tail_mask_value);
|
||||
|
||||
accum512_0 = _mm512_setzero_ps();
|
||||
|
|
|
@ -29,6 +29,13 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
// Include common macros for BF16 based operations with IA intrinsics
|
||||
#include "bf16_common_macros.h"
|
||||
|
||||
#undef STORE16_COMPLETE_RESULT
|
||||
#undef STORE16_MASK_COMPLETE_RESULT
|
||||
#undef STORE8_COMPLETE_RESULT
|
||||
#undef STORE8_MASK_COMPLETE_RESULT
|
||||
#undef STORE4_COMPLETE_RESULT
|
||||
#undef STORE4_MASK_COMPLETE_RESULT
|
||||
|
||||
#ifndef ZERO_BETA // Beta is non-zero
|
||||
|
||||
#ifndef ONE_BETA // BETA is not ONE
|
||||
|
@ -231,7 +238,9 @@ static int sbgemv_kernel_32x2(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
unsigned char load_mask_value = (((unsigned char)0xff) >> 6);
|
||||
|
@ -280,7 +289,7 @@ static int sbgemv_kernel_32x2(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
} else if (tail_num == 8) {
|
||||
__m256 result256 = _mm256_setzero_ps();
|
||||
|
||||
__m256i matrixArray256 = _mm256_loadu_si256(&a[(tag_m_32x)*2]); // Load 8 rows with n=2
|
||||
__m256i matrixArray256 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*2]); // Load 8 rows with n=2
|
||||
__m256i xArray256 = _mm512_castsi512_si256(xArray);
|
||||
result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) xArray256);
|
||||
|
||||
|
@ -323,7 +332,9 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5);
|
||||
|
@ -395,9 +406,9 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
result256_0 = _mm256_setzero_ps();
|
||||
result256_1 = _mm256_setzero_ps();
|
||||
|
||||
matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element
|
||||
matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element
|
||||
matrixArray256_2 = _mm256_loadu_si256(&a[((tag_m_32x+10)*3 + 2)]); // Load 5 rows with n=3 plus 1 element
|
||||
matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element
|
||||
matrixArray256_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element
|
||||
matrixArray256_2 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+10)*3 + 2)]); // Load 5 rows with n=3 plus 1 element
|
||||
|
||||
matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row
|
||||
matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2); // Select the first 2 elements for each row
|
||||
|
@ -423,8 +434,8 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
if (tail_num > 10) {
|
||||
unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-10-1)*3+1)));
|
||||
__mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
|
||||
matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element
|
||||
matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element
|
||||
matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element
|
||||
matrixArray256_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element
|
||||
matrixArray256_2 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+10)*3 + 2)]); // Load m-tag_m_32x-10 rows
|
||||
|
||||
matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row
|
||||
|
@ -439,7 +450,7 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
} else if (tail_num > 5) {
|
||||
unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-5-1)*3+2)));
|
||||
__mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
|
||||
matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element
|
||||
matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element
|
||||
matrixArray256_1 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+5)*3+1)]); // Load m-tag_m_32x-5 rows
|
||||
matrixArray256_2 = _mm256_setzero_si256();
|
||||
|
||||
|
@ -499,7 +510,9 @@ static int sbgemv_kernel_16x4(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i M512_EPI32_1 = _mm512_set1_epi32(1);
|
||||
|
@ -591,7 +604,9 @@ static int sbgemv_kernel_30x5(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512 result_0, result_1;
|
||||
|
@ -782,7 +797,9 @@ static int sbgemv_kernel_16x6(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i M512_EPI32_1 = _mm512_set1_epi32(1);
|
||||
|
@ -866,9 +883,9 @@ static int sbgemv_kernel_16x6(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
|
||||
result256_0 = _mm256_setzero_ps();
|
||||
|
||||
matrixArray_0 = _mm256_loadu_si256(&a[(tag_m_16x)*6]); // Load 2 rows with n=6 plus 4 element
|
||||
matrixArray_1 = _mm256_loadu_si256(&a[((tag_m_16x+2)*6 + 4)]); // Load 2 rows with n=6 plus 4 element
|
||||
matrixArray_2 = _mm256_loadu_si256(&a[((tag_m_16x+5)*6 + 2)]); // Load 2 rows with n=6 plus 4 element
|
||||
matrixArray_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_16x)*6]); // Load 2 rows with n=6 plus 4 element
|
||||
matrixArray_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_16x+2)*6 + 4)]); // Load 2 rows with n=6 plus 4 element
|
||||
matrixArray_2 = _mm256_loadu_si256((__m256i *)&a[((tag_m_16x+5)*6 + 2)]); // Load 2 rows with n=6 plus 4 element
|
||||
|
||||
// Process the 0|1 elements
|
||||
// Select the 0|1 elements for each row
|
||||
|
@ -957,7 +974,9 @@ static int sbgemv_kernel_16x7(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i M512_EPI32_2 = _mm512_set1_epi32(2);
|
||||
|
@ -1110,7 +1129,7 @@ static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
{
|
||||
BLASLONG tag_m_16x = m & (~15);
|
||||
|
||||
__m128i x128 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7|
|
||||
__m128i x128 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7|
|
||||
|
||||
if (tag_m_16x > 0) {
|
||||
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
|
||||
|
@ -1122,7 +1141,9 @@ static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i M512_EPI32_2 = _mm512_set1_epi32(2);
|
||||
|
@ -1214,7 +1235,7 @@ static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
__m128 result128, tmp128;
|
||||
for (BLASLONG i = tag_m_16x; i < m; i++) {
|
||||
result128 = _mm_setzero_ps();
|
||||
matrixArray128 = _mm_loadu_si128(&a[(i)*8]); // Load 1 rows with n=8
|
||||
matrixArray128 = _mm_loadu_si128((__m128i *)&a[(i)*8]); // Load 1 rows with n=8
|
||||
result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
|
||||
tmp128 = _mm_shuffle_ps(result128, result128, 14);
|
||||
result128 = _mm_add_ps(result128, tmp128);
|
||||
|
@ -1258,7 +1279,7 @@ static int sbgemv_kernel_14x9(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
|
||||
unsigned char x_load_mask_value = (((unsigned char)0xff) >> 7);
|
||||
__mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
|
||||
__m128i x128_0 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7|
|
||||
__m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7|
|
||||
__m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|0 |0 | 0| 0| 0| 0| 0|
|
||||
|
||||
if (tag_m_14x > 0) {
|
||||
|
@ -1271,7 +1292,9 @@ static int sbgemv_kernel_14x9(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x,
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m256i M256_EPI16_2 = _mm256_set1_epi16(2);
|
||||
|
@ -1390,7 +1413,7 @@ static int sbgemv_kernel_12x10(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x
|
|||
|
||||
unsigned char x_load_mask_value = (((unsigned char)0xf) >> 3);
|
||||
__mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
|
||||
__m128i x128_0 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7|
|
||||
__m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7|
|
||||
__m128i x128_1 = _mm_maskz_loadu_epi32(x_load_mask, (x+8)); // |x8|x9|0 | 0| 0| 0| 0| 0|
|
||||
|
||||
if (tag_m_12x > 0) {
|
||||
|
@ -1403,7 +1426,9 @@ static int sbgemv_kernel_12x10(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m256i M256_EPI32_1 = _mm256_set1_epi32(1);
|
||||
|
@ -1522,7 +1547,7 @@ static int sbgemv_kernel_15x11(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x
|
|||
|
||||
unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5);
|
||||
__mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
|
||||
__m128i x128_0 = _mm_loadu_si128(x); // |x0|x1| x2|x3|x4|x5|x6|x7|
|
||||
__m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1| x2|x3|x4|x5|x6|x7|
|
||||
__m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10| 0| 0| 0| 0| 0|
|
||||
|
||||
if (tag_m_15x > 0) {
|
||||
|
@ -1535,7 +1560,9 @@ static int sbgemv_kernel_15x11(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5;
|
||||
|
@ -1690,7 +1717,7 @@ static int sbgemv_kernel_15x12(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x
|
|||
|
||||
unsigned char x_load_mask_value = (((unsigned char)0xff) >> 4);
|
||||
__mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
|
||||
__m128i x128_0 = _mm_loadu_si128(x); // |x0|x1| x2| x3|x4|x5|x6|x7|
|
||||
__m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1| x2| x3|x4|x5|x6|x7|
|
||||
__m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10|x11| 0| 0| 0| 0|
|
||||
|
||||
if (tag_m_15x > 0) {
|
||||
|
@ -1703,7 +1730,9 @@ static int sbgemv_kernel_15x12(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5;
|
||||
|
@ -1873,16 +1902,15 @@ static int sbgemv_kernel_16x13(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i M512_EPI32_4 = _mm512_set1_epi32(4);
|
||||
__m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
|
||||
__m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
|
||||
|
||||
unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 6);
|
||||
__mmask32 load_mask = *((__mmask32*) &load_mask_value);
|
||||
|
||||
// Prepare X with 2-step interleave way
|
||||
xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
|
||||
BF16_INTERLEAVE_1x32(xArray)
|
||||
|
@ -2045,7 +2073,9 @@ static int sbgemv_kernel_16x14(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i M512_EPI32_4 = _mm512_set1_epi32(4);
|
||||
|
@ -2207,16 +2237,15 @@ static int sbgemv_kernel_16x15(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i M512_EPI32_4 = _mm512_set1_epi32(4);
|
||||
__m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
|
||||
__m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
|
||||
|
||||
unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 2);
|
||||
__mmask32 load_mask = *((__mmask32*) &load_mask_value);
|
||||
|
||||
// Prepare X with 2-step interleave way
|
||||
xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
|
||||
BF16_INTERLEAVE_1x32(xArray)
|
||||
|
@ -2364,7 +2393,7 @@ static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x
|
|||
{
|
||||
BLASLONG tag_m_16x = m & (~15);
|
||||
|
||||
__m256i x256 = _mm256_loadu_si256(x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15|
|
||||
__m256i x256 = _mm256_loadu_si256((__m256i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15|
|
||||
|
||||
if (tag_m_16x > 0) {
|
||||
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
|
||||
|
@ -2377,7 +2406,9 @@ static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i M512_EPI32_4 = _mm512_set1_epi32(4);
|
||||
|
@ -2484,7 +2515,7 @@ static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x
|
|||
__m128 accum128, tmp128;
|
||||
for (BLASLONG i = tag_m_16x; i < m; i++) {
|
||||
accum256 = _mm256_setzero_ps();
|
||||
matrixArray256 = _mm256_loadu_si256(&a[(i)*16]); // Load 1 rows with n=16
|
||||
matrixArray256 = _mm256_loadu_si256((__m256i *)&a[(i)*16]); // Load 1 rows with n=16
|
||||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
|
||||
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
|
||||
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
|
||||
|
@ -2535,7 +2566,9 @@ static int sbgemv_kernel_8x16p_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
|
||||
|
@ -2647,8 +2680,6 @@ static int sbgemv_kernel_1x128_lda_direct(BLASLONG m, BLASLONG n, float alpha, b
|
|||
BLASLONG tag_n_32x = n & (~31);
|
||||
BLASLONG tag_n_128x = n & (~127);
|
||||
|
||||
__m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7, \
|
||||
accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15;
|
||||
__m512 accum512_bridge[8];
|
||||
__m512 accum512_t_0, accum512_t_1, accum512_t_2, accum512_t_3;
|
||||
__m256 accum256_0;
|
||||
|
@ -2658,7 +2689,9 @@ static int sbgemv_kernel_1x128_lda_direct(BLASLONG m, BLASLONG n, float alpha, b
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
|
||||
|
@ -2825,7 +2858,9 @@ static int sbgemv_kernel_8x32_lda_direct(BLASLONG m, BLASLONG n, float alpha, bf
|
|||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_set1_ps(beta);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
|
||||
|
@ -2961,7 +2996,9 @@ static int sbgemv_kernel_8x16m_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16
|
|||
__m512 ALPHAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(alpha));
|
||||
#endif
|
||||
#ifndef ZERO_BETA
|
||||
#ifndef ONE_BETA
|
||||
__m512 BETAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(beta));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
__m256 accum256_0, accum256_1, accum256_2, accum256_3, accum256_4, accum256_5, accum256_6, accum256_7, \
|
||||
|
@ -3012,7 +3049,7 @@ static int sbgemv_kernel_8x16m_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16
|
|||
__m128 accum128, tmp128;
|
||||
for (BLASLONG i = tag_m_8x; i < m; i++) {
|
||||
accum256_0 = _mm256_setzero_ps();
|
||||
matrixArray_0 = _mm256_loadu_si256(&a[(i)*lda]); // Load 1 rows with n=16
|
||||
matrixArray_0 = _mm256_loadu_si256((__m256i *)&a[(i)*lda]); // Load 1 rows with n=16
|
||||
accum256_0 = _mm256_dpbf16_ps(accum256_0, (__m256bh) matrixArray_0, (__m256bh) xArray256);
|
||||
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
|
||||
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
|
||||
|
|
Loading…
Reference in New Issue