OpenBLAS/kernel/x86_64/sbgemv_t_microk_cooperlake_...

3120 lines
157 KiB
C

/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include <immintrin.h>
#include "common.h"
// Include common macros for BF16 based operations with IA intrinsics
#include "bf16_common_macros.h"
// Rebind the generic STOREn[_MASK]_COMPLETE_RESULT names to the variant that
// matches the alpha/beta specialization currently being compiled.
// The names are #undef'd first so that a previous binding never leaks through
// (presumably this file is compiled/included once per specialization -- confirm).
#undef STORE16_COMPLETE_RESULT
#undef STORE16_MASK_COMPLETE_RESULT
#undef STORE8_COMPLETE_RESULT
#undef STORE8_MASK_COMPLETE_RESULT
#undef STORE4_COMPLETE_RESULT
#undef STORE4_MASK_COMPLETE_RESULT

#if !defined(ZERO_BETA) && !defined(ONE_BETA)
// BETA is neither zero nor one
#define STORE16_COMPLETE_RESULT      STORE16_COMPLETE_RESULT_ALPHA_BETA
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA
#define STORE8_COMPLETE_RESULT       STORE8_COMPLETE_RESULT_ALPHA_BETA
#define STORE8_MASK_COMPLETE_RESULT  STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA
#define STORE4_COMPLETE_RESULT       STORE4_COMPLETE_RESULT_ALPHA_BETA
#define STORE4_MASK_COMPLETE_RESULT  STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA
#elif !defined(ZERO_BETA)
// BETA is one
#define STORE16_COMPLETE_RESULT      STORE16_COMPLETE_RESULT_ALPHA_ONE
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE
#define STORE8_COMPLETE_RESULT       STORE8_COMPLETE_RESULT_ALPHA_ONE
#define STORE8_MASK_COMPLETE_RESULT  STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE
#define STORE4_COMPLETE_RESULT       STORE4_COMPLETE_RESULT_ALPHA_ONE
#define STORE4_MASK_COMPLETE_RESULT  STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE
#elif !defined(ONE_ALPHA)
// BETA is zero, ALPHA is not one
#define STORE16_COMPLETE_RESULT      STORE16_COMPLETE_RESULT_ALPHA
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA
#define STORE8_COMPLETE_RESULT       STORE8_COMPLETE_RESULT_ALPHA
#define STORE8_MASK_COMPLETE_RESULT  STORE8_MASK_COMPLETE_RESULT_ALPHA
#define STORE4_COMPLETE_RESULT       STORE4_COMPLETE_RESULT_ALPHA
#define STORE4_MASK_COMPLETE_RESULT  STORE4_MASK_COMPLETE_RESULT_ALPHA
#else
// BETA is zero, ALPHA is one
#define STORE16_COMPLETE_RESULT      STORE16_COMPLETE_RESULT_DIRECT
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_DIRECT
#define STORE8_COMPLETE_RESULT       STORE8_COMPLETE_RESULT_DIRECT
#define STORE8_MASK_COMPLETE_RESULT  STORE8_MASK_COMPLETE_RESULT_DIRECT
#define STORE4_COMPLETE_RESULT       STORE4_COMPLETE_RESULT_DIRECT
#define STORE4_MASK_COMPLETE_RESULT  STORE4_MASK_COMPLETE_RESULT_DIRECT
#endif
// BF16 GEMV kernel for n=1 with rows stored back-to-back (a[i] is row i);
// up to 32 rows are processed per step.
//
// For each row i in [0, m) the product a[i]*x[0] is formed with the BF16
// dot-product intrinsics and passed, together with y+i, to the STORE* macro
// selected by the ZERO_BETA / ONE_BETA / ONE_ALPHA configuration.
//
//   m     - number of rows
//   alpha - broadcast into ALPHAVECTOR unless ONE_ALPHA is defined
//   a     - m consecutive bf16 values, one per row
//   x     - only x[0] is read
//   beta  - (beta variants only) broadcast into BETAVECTOR unless ONE_BETA is defined
//   y     - m floats, read/written only through the STORE* macros
// Always returns 0.
//
// The dot-product instruction multiplies BOTH 16-bit halves of every 32-bit
// lane. The unused (odd) half of the matrix operand must therefore be a real
// zero: pairing an arbitrary matrix element with the zero in x would yield
// Inf*0 = NaN whenever that element is Inf/NaN and would corrupt unrelated
// rows. The expansion permutes are zero-masked for this reason.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_32x1_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_32x1_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_32x1_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_32x1(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    BLASLONG tag_m_32x = m & (~31);   // rows covered by full 32-row steps

    __m512i matrixArray_0, matrixArray_1, matrixArray_2;
    __m512i xArray;
    __m512  result_0, result_1;

    // ALPHAVECTOR / BETAVECTOR are not referenced directly below; presumably
    // the STORE* macros pick them up by name.
#ifndef ONE_ALPHA
    __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
    __m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif

    // Even 16-bit lane 2k selects source element k (lo) or 16+k (hi);
    // odd lanes are zeroed by the mask below.
    __m512i load_idx_lo = _mm512_set_epi16(0, 15, 0, 14, 0, 13, 0, 12, 0, 11, 0, 10, 0,  9, 0,  8,
                                           0,  7, 0,  6, 0,  5, 0,  4, 0,  3, 0,  2, 0,  1, 0,  0);
    __m512i M512_EPI16_16 = _mm512_set1_epi16(16);
    __m512i load_idx_hi = _mm512_add_epi16(load_idx_lo, M512_EPI16_16);

    // Selects the even 16-bit lanes.
    __mmask32 interleave_mask = (__mmask32) 0x55555555;

    // x0 in every even 16-bit lane, zero in every odd lane.
    xArray = _mm512_set1_epi16((short) x[0]);
    xArray = _mm512_mask_blend_epi16(interleave_mask, _mm512_setzero_si512(), xArray);

    if (tag_m_32x > 0) {
        for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m+=32) {
            result_0 = _mm512_setzero_ps();
            result_1 = _mm512_setzero_ps();

            matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)]);                                              // Load 32 rows with n=1
            matrixArray_1 = _mm512_maskz_permutexvar_epi16(interleave_mask, load_idx_lo, matrixArray_0);  // Expand the low 16 elements, odd lanes = 0
            matrixArray_2 = _mm512_maskz_permutexvar_epi16(interleave_mask, load_idx_hi, matrixArray_0);  // Expand the high 16 elements, odd lanes = 0

            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_1, (__m512bh) xArray);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_2, (__m512bh) xArray);

            STORE16_COMPLETE_RESULT(result_0, y+idx_m)
            STORE16_COMPLETE_RESULT(result_1, y+idx_m+16)
        }
    }

    BLASLONG tail_num = m - tag_m_32x;   // 0..31 remaining rows
    if (tail_num > 16) {
        // 17..31 rows: one masked 512-bit load, second store masked.
        result_0 = _mm512_setzero_ps();
        result_1 = _mm512_setzero_ps();

        unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-tail_num));
        __mmask32 tail_mask = (__mmask32) tail_mask_value;
        matrixArray_0 = _mm512_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)]);                         // Load the remaining rows with n=1
        matrixArray_1 = _mm512_maskz_permutexvar_epi16(interleave_mask, load_idx_lo, matrixArray_0);  // Expand the low 16 elements, odd lanes = 0
        matrixArray_2 = _mm512_maskz_permutexvar_epi16(interleave_mask, load_idx_hi, matrixArray_0);  // Expand the high 16 elements, odd lanes = 0

        result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_1, (__m512bh) xArray);
        result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_2, (__m512bh) xArray);

        // tail_num-16 low bits set
        unsigned short store_mask_value = (((unsigned short)0xffff) >> (32-tail_num));
        __mmask16 store_mask = (__mmask16) store_mask_value;
        STORE16_COMPLETE_RESULT(result_0, y+tag_m_32x)
        STORE16_MASK_COMPLETE_RESULT(result_1, y+tag_m_32x+16, store_mask)
    } else if (tail_num > 8) {
        // 9..16 rows: same scheme on 256-bit registers.
        __m256 result256_0 = _mm256_setzero_ps();
        __m256 result256_1 = _mm256_setzero_ps();

        __mmask16 interleave_mask256 = (__mmask16) 0x5555;
        __m256i load_idx_lo256 = _mm512_castsi512_si256(load_idx_lo);             // indices 0..7 in the even lanes
        __m256i load_idx_hi256 = _mm512_extracti32x8_epi32(load_idx_lo, 0x1);     // indices 8..15 in the even lanes
        __m256i xArray256 = _mm512_castsi512_si256(xArray);

        unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-tail_num));
        __mmask16 tail_mask = (__mmask16) tail_mask_value;
        __m256i matrixArray256_0 = _mm256_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)]);                                 // Load up to 16 rows with n=1
        __m256i matrixArray256_1 = _mm256_maskz_permutexvar_epi16(interleave_mask256, load_idx_lo256, matrixArray256_0); // Expand the low 8 elements, odd lanes = 0
        __m256i matrixArray256_2 = _mm256_maskz_permutexvar_epi16(interleave_mask256, load_idx_hi256, matrixArray256_0); // Expand the high 8 elements, odd lanes = 0

        result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_1, (__m256bh) xArray256);
        result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_2, (__m256bh) xArray256);

        // tail_num-8 low bits set
        unsigned char store_mask_value = (((unsigned char)0xff) >> (16-tail_num));
        __mmask8 store_mask = (__mmask8) store_mask_value;
        STORE8_COMPLETE_RESULT(result256_0, y+tag_m_32x)
        STORE8_MASK_COMPLETE_RESULT(result256_1, y+tag_m_32x+8, store_mask)
    } else {
        // 0..8 rows: 128-bit registers. With tail_num == 0 every mask below is
        // empty, so nothing is loaded from a and the store mask selects no element.
        __m128 result128_0 = _mm_setzero_ps();
        __m128 result128_1 = _mm_setzero_ps();

        __mmask8 interleave_mask128 = (__mmask8) 0x55;
        __m128i load_idx_lo128 = _mm_set_epi16(0, 3, 0, 2, 0, 1, 0, 0);
        __m128i M128_EPI16_4 = _mm_set1_epi16(4);
        __m128i load_idx_hi128 = _mm_add_epi16(load_idx_lo128, M128_EPI16_4);

        __m128i xArray128 = _mm512_castsi512_si128(xArray);

        unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
        __mmask8 tail_mask = (__mmask8) tail_mask_value;
        __m128i matrixArray128_0 = _mm_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)]);                                 // Load up to 8 rows with n=1
        __m128i matrixArray128_1 = _mm_maskz_permutexvar_epi16(interleave_mask128, load_idx_lo128, matrixArray128_0); // Expand the low 4 elements, odd lanes = 0
        __m128i matrixArray128_2 = _mm_maskz_permutexvar_epi16(interleave_mask128, load_idx_hi128, matrixArray128_0); // Expand the high 4 elements, odd lanes = 0

        result128_0 = _mm_dpbf16_ps(result128_0, (__m128bh) matrixArray128_1, (__m128bh) xArray128);
        result128_1 = _mm_dpbf16_ps(result128_1, (__m128bh) matrixArray128_2, (__m128bh) xArray128);

        if (tail_num > 4) {
            // tail_num-4 low bits set
            unsigned char store_mask_value = (((unsigned char)0xf) >> (8-tail_num));
            __mmask8 store_mask = (__mmask8) store_mask_value;
            STORE4_COMPLETE_RESULT(result128_0, y+tag_m_32x)
            STORE4_MASK_COMPLETE_RESULT(result128_1, y+tag_m_32x+4, store_mask)
        } else {
            // tail_num low bits set
            unsigned char store_mask_value = (((unsigned char)0xf) >> (4-tail_num));
            __mmask8 store_mask = (__mmask8) store_mask_value;
            STORE4_MASK_COMPLETE_RESULT(result128_0, y+tag_m_32x, store_mask)
        }
    }

    return 0;
}
// BF16 GEMV kernel for n=2 with rows stored back-to-back (row i occupies
// a[2*i], a[2*i+1]); the main loop handles 32 rows per step.
//
// Each row is exactly one 32-bit lane, so a 512-bit load already holds 16 rows
// in the layout the BF16 dot-product intrinsic expects; x0|x1 is broadcast to
// every lane. Results go to y through the STORE* macro chosen by the
// ZERO_BETA / ONE_BETA / ONE_ALPHA configuration. Always returns 0.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_32x2_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_32x2_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_32x2_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_32x2(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    // Presumably picked up by name inside the STORE* macros.
#ifndef ONE_ALPHA
    __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
    __m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif

    __m512i rows_lo, rows_hi;
    __m512  acc_lo, acc_hi;

    // Number of rows finished so far; starts at the multiple-of-32 prefix.
    BLASLONG done = m & (~31);

    // Read exactly x0 and x1 (two 16-bit elements) and replicate the pair
    // into every 32-bit lane.
    __mmask8 x_pair_mask = (__mmask8) (0xff >> 6);
    __m512i  x_pair = _mm512_broadcastd_epi32(_mm_maskz_loadu_epi16(x_pair_mask, x));

    // Full 32-row steps: two registers of 16 rows each.
    for (BLASLONG row = 0; row < done; row += 32) {
        rows_lo = _mm512_loadu_si512(&a[row*2]);
        rows_hi = _mm512_loadu_si512(&a[(row+16)*2]);

        acc_lo = _mm512_dpbf16_ps(_mm512_setzero_ps(), (__m512bh) rows_lo, (__m512bh) x_pair);
        acc_hi = _mm512_dpbf16_ps(_mm512_setzero_ps(), (__m512bh) rows_hi, (__m512bh) x_pair);

        STORE16_COMPLETE_RESULT(acc_lo, y+row)
        STORE16_COMPLETE_RESULT(acc_hi, y+row+16)
    }

    // One more complete 16-row register if at least 16 rows remain.
    if (m - done >= 16) {
        rows_lo = _mm512_loadu_si512(&a[done*2]);
        acc_lo = _mm512_dpbf16_ps(_mm512_setzero_ps(), (__m512bh) rows_lo, (__m512bh) x_pair);
        STORE16_COMPLETE_RESULT(acc_lo, y+done)
        done += 16;
    }

    // 0..15 rows are left; `left` equals m&15 here (and m&7 when it is below 8).
    BLASLONG left = m - done;
    if (left > 8) {
        // 9..15 rows: one 32-bit lane per row, so the same mask serves load and store.
        __mmask16 row_mask = (__mmask16) (0xffff >> (16 - left));
        rows_lo = _mm512_maskz_loadu_epi32(row_mask, &a[done*2]);
        acc_lo = _mm512_dpbf16_ps(_mm512_setzero_ps(), (__m512bh) rows_lo, (__m512bh) x_pair);
        STORE16_MASK_COMPLETE_RESULT(acc_lo, y+done, row_mask)
    } else if (left == 8) {
        // Exactly 8 rows: one unmasked 256-bit register.
        __m256i rows256 = _mm256_loadu_si256((__m256i *)&a[done*2]);
        __m256i x_pair256 = _mm512_castsi512_si256(x_pair);
        __m256  acc256 = _mm256_dpbf16_ps(_mm256_setzero_ps(), (__m256bh) rows256, (__m256bh) x_pair256);
        STORE8_COMPLETE_RESULT(acc256, y+done)
    } else {
        // 0..7 rows: masked 256-bit register (the mask is empty when left == 0).
        __mmask8 row_mask = (__mmask8) (0xff >> (8 - left));
        __m256i rows256 = _mm256_maskz_loadu_epi32(row_mask, &a[done*2]);
        __m256i x_pair256 = _mm512_castsi512_si256(x_pair);
        __m256  acc256 = _mm256_dpbf16_ps(_mm256_setzero_ps(), (__m256bh) rows256, (__m256bh) x_pair256);
        STORE8_MASK_COMPLETE_RESULT(acc256, y+done, row_mask)
    }

    return 0;
}
// BF16 GEMV kernel for n=3 with rows stored back-to-back (row i occupies
// a[3*i .. 3*i+2]); the main loop handles 32 rows per step.
//
// Rows are 3 elements wide, so they do not line up with the 32-bit lanes the
// BF16 dot-product intrinsic works on. For each group of rows the kernel
// therefore builds two operands with 16-bit permutes:
//   * (a_r0, a_r1) per lane, multiplied with (x0, x1)
//   * (a_r2, 0)    per lane, multiplied with (x2, 0)
// and accumulates both products. Results go to y through the STORE* macro
// chosen by the ZERO_BETA / ONE_BETA / ONE_ALPHA configuration.
//
// The second operand's odd 16-bit half is forced to zero with a zero-masked
// permute. The index vector would otherwise put the NEXT row's first element
// there; paired with the zero next to x2 that is harmless for finite data but
// gives Inf*0 = NaN when that neighbouring element is Inf/NaN, corrupting the
// current row.
//
// Always returns 0.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_32x3_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_32x3_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_32x3_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    BLASLONG tag_m_32x = m & (~31);   // rows covered by full 32-row steps

    __m512 result_0, result_1;

    // Presumably picked up by name inside the STORE* macros.
#ifndef ONE_ALPHA
    __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
    __m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif

    // Read exactly three elements of x.
    unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5);
    __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
    __m128i xTmp = _mm_maskz_loadu_epi16(x_load_mask, x);                         // x0|x1|x2|0|0|0|0|0|
    __m512i xArray_0 = _mm512_broadcastd_epi32(xTmp);                             // x0|x1|x0|x1|...|x0|x1|
    __m512i xArray_1 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(xTmp, 0x1));     // x2| 0|x2| 0|...|x2| 0|

    // Even 16-bit lanes only: keeps the third element, zeroes its partner.
    __mmask32 third_elem_mask = (__mmask32) 0x55555555;

    __m512i load_idx_base;
    __m512i M512_EPI16_2, M512_EPI16_8, M512_EPI16_16;
    M512_EPI16_2 = _mm512_set1_epi16(2);
    M512_EPI16_8 = _mm512_add_epi16(M512_EPI16_2, M512_EPI16_2);    // 4
    M512_EPI16_8 = _mm512_add_epi16(M512_EPI16_8, M512_EPI16_8);    // 8
    M512_EPI16_16 = _mm512_add_epi16(M512_EPI16_8, M512_EPI16_8);   // 16

    // Lane pair r holds (3r, 3r+1): the first two elements of row r.
    load_idx_base = _mm512_set_epi16(46, 45, 43, 42, 40, 39, 37, 36, 34, 33, 31, 30, 28, 27, 25, 24,
                                     22, 21, 19, 18, 16, 15, 13, 12, 10,  9,  7,  6,  4,  3,  1,  0);

    if (tag_m_32x > 0) {
        __m512i load_idx01_1st, load_idx01_2nd, load_idx2_1st, load_idx2_2nd;
        __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6;

        unsigned int idx_blend_mask_value = ((unsigned int)0x80000000);
        __mmask32 idx_blend_mask = *((__mmask32*) &idx_blend_mask_value);

        load_idx01_1st = load_idx_base;                                       // rows 0..15, relative to matrixArray_0
        load_idx01_2nd = _mm512_add_epi16(load_idx01_1st, M512_EPI16_16);     // rows 16..31, relative to matrixArray_1
        load_idx2_1st = _mm512_add_epi16(load_idx01_1st, M512_EPI16_2);       // third element (+ following element)
        load_idx2_2nd = _mm512_add_epi16(load_idx01_2nd, M512_EPI16_2);
        // The top index would be 64; replace it with 0. That lane is zeroed by third_elem_mask anyway.
        load_idx2_2nd = _mm512_mask_blend_epi16(idx_blend_mask, load_idx2_2nd, _mm512_setzero_si512());

        for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m+=32) {
            result_0 = _mm512_setzero_ps();
            result_1 = _mm512_setzero_ps();

            // 32 rows * 3 = 96 consecutive elements, fetched as three 32-element registers.
            matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*3]);               // elements  0..31 (rows 0..9 plus 2 elements of row 10)
            matrixArray_1 = _mm512_loadu_si512(&a[((idx_m+10)*3 + 2)]);      // elements 32..63
            matrixArray_2 = _mm512_loadu_si512(&a[((idx_m+21)*3 + 1)]);      // elements 64..95

            matrixArray_3 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_1st, matrixArray_1);                         // Select the first 2 elements for each row
            matrixArray_4 = _mm512_permutex2var_epi16(matrixArray_1, load_idx01_2nd, matrixArray_2);                         // Select the first 2 elements for each row
            matrixArray_5 = _mm512_maskz_permutex2var_epi16(third_elem_mask, matrixArray_0, load_idx2_1st, matrixArray_1);   // Select the third element for each row, partner lane = 0
            matrixArray_6 = _mm512_maskz_permutex2var_epi16(third_elem_mask, matrixArray_1, load_idx2_2nd, matrixArray_2);   // Select the third element for each row, partner lane = 0

            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_3, (__m512bh) xArray_0);
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_5, (__m512bh) xArray_1);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_4, (__m512bh) xArray_0);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_6, (__m512bh) xArray_1);

            STORE16_COMPLETE_RESULT(result_0, y+idx_m)
            STORE16_COMPLETE_RESULT(result_1, y+idx_m+16)
        }
    }

    if (tag_m_32x != m) {
        // Remaining 1..31 rows: same scheme on 256-bit registers (16 elements each, 8 rows per result).
        __m256i load256_idx01_1st, load256_idx01_2nd, load256_idx2_1st, load256_idx2_2nd;
        __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6;
        __m256 result256_0, result256_1;

        __mmask16 third_elem_mask256 = (__mmask16) 0x5555;   // even 16-bit lanes only

        unsigned short idx256_blend_mask_value = ((unsigned short)0x8000);
        __mmask16 idx256_blend_mask = *((__mmask16*) &idx256_blend_mask_value);

        load256_idx01_1st = _mm512_castsi512_si256(load_idx_base);                                              // rows 0..7, relative to matrixArray256_0
        load256_idx01_2nd = _mm256_add_epi16(load256_idx01_1st, _mm512_castsi512_si256(M512_EPI16_8));          // rows 8..15, relative to matrixArray256_1
        load256_idx2_1st = _mm256_add_epi16(load256_idx01_1st, _mm512_castsi512_si256(M512_EPI16_2));
        load256_idx2_2nd = _mm256_add_epi16(load256_idx01_2nd, _mm512_castsi512_si256(M512_EPI16_2));
        // The top index would be 32; replace it with 0. That lane is zeroed by third_elem_mask256 anyway.
        load256_idx2_2nd = _mm256_mask_blend_epi16(idx256_blend_mask, load256_idx2_2nd, _mm256_setzero_si256());

        if (m - tag_m_32x > 15) {
            // A complete group of 16 rows = 48 consecutive elements.
            result256_0 = _mm256_setzero_ps();
            result256_1 = _mm256_setzero_ps();

            matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]);                 // elements  0..15 (rows 0..4 plus 1 element of row 5)
            matrixArray256_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+5)*3 + 1)]);         // elements 16..31
            matrixArray256_2 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+10)*3 + 2)]);        // elements 32..47

            matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1);                            // Select the first 2 elements for each row
            matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2);                            // Select the first 2 elements for each row
            matrixArray256_5 = _mm256_maskz_permutex2var_epi16(third_elem_mask256, matrixArray256_0, load256_idx2_1st, matrixArray256_1);   // Select the third element for each row, partner lane = 0
            matrixArray256_6 = _mm256_maskz_permutex2var_epi16(third_elem_mask256, matrixArray256_1, load256_idx2_2nd, matrixArray256_2);   // Select the third element for each row, partner lane = 0

            result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
            result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
            result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_4, (__m256bh) _mm512_castsi512_si256(xArray_0));
            result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_6, (__m256bh) _mm512_castsi512_si256(xArray_1));

            STORE8_COMPLETE_RESULT(result256_0, y+tag_m_32x)
            STORE8_COMPLETE_RESULT(result256_1, y+tag_m_32x+8)
            tag_m_32x += 16;
        }

        if (tag_m_32x != m) {
            // Final 1..15 rows. Only the register containing the last valid
            // element is loaded with a mask; registers before it are full.
            result256_0 = _mm256_setzero_ps();
            result256_1 = _mm256_setzero_ps();
            BLASLONG tail_num = m-tag_m_32x;

            if (tail_num > 10) {
                // 11..15 rows: third register holds (tail_num-11)*3+1 valid elements.
                unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-10-1)*3+1)));
                __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
                matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]);                        // elements  0..15
                matrixArray256_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+5)*3 + 1)]);                // elements 16..31
                matrixArray256_2 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+10)*3 + 2)]);         // valid part of elements 32..47

                matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1);                            // Select the first 2 elements for each row
                matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2);                            // Select the first 2 elements for each row
                matrixArray256_5 = _mm256_maskz_permutex2var_epi16(third_elem_mask256, matrixArray256_0, load256_idx2_1st, matrixArray256_1);   // Select the third element for each row, partner lane = 0
                matrixArray256_6 = _mm256_maskz_permutex2var_epi16(third_elem_mask256, matrixArray256_1, load256_idx2_2nd, matrixArray256_2);   // Select the third element for each row, partner lane = 0

                result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
                result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
                result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_4, (__m256bh) _mm512_castsi512_si256(xArray_0));
                result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_6, (__m256bh) _mm512_castsi512_si256(xArray_1));
            } else if (tail_num > 5) {
                // 6..10 rows: second register holds (tail_num-6)*3+2 valid elements, third is empty.
                unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-5-1)*3+2)));
                __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
                matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]);                        // elements  0..15
                matrixArray256_1 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+5)*3+1)]);            // valid part of elements 16..31
                matrixArray256_2 = _mm256_setzero_si256();

                matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1);                            // Select the first 2 elements for each row
                matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2);                            // Select the first 2 elements for each row
                matrixArray256_5 = _mm256_maskz_permutex2var_epi16(third_elem_mask256, matrixArray256_0, load256_idx2_1st, matrixArray256_1);   // Select the third element for each row, partner lane = 0
                matrixArray256_6 = _mm256_maskz_permutex2var_epi16(third_elem_mask256, matrixArray256_1, load256_idx2_2nd, matrixArray256_2);   // Select the third element for each row, partner lane = 0

                result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
                result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
                result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_4, (__m256bh) _mm512_castsi512_si256(xArray_0));
                result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_6, (__m256bh) _mm512_castsi512_si256(xArray_1));
            } else {
                // 1..5 rows: everything fits in the first register; result256_1 stays zero.
                unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-(tail_num*3)));
                __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
                matrixArray256_0 = _mm256_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)*3]);                  // Load m-tag_m_32x rows
                matrixArray256_1 = _mm256_setzero_si256();

                matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1);                            // Select the first 2 elements for each row
                matrixArray256_5 = _mm256_maskz_permutex2var_epi16(third_elem_mask256, matrixArray256_0, load256_idx2_1st, matrixArray256_1);   // Select the third element for each row, partner lane = 0

                result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
                result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
            }

            // Join both halves and store only the tail_num valid rows.
            unsigned short store_tail_mask_value = (((unsigned short)0xffff) >> (16-(tail_num)));
            __mmask16 store_tail_mask = *((__mmask16*) &store_tail_mask_value);
            __m512 result512 = _mm512_insertf32x8(_mm512_castps256_ps512(result256_0), result256_1, 0x1);
            STORE16_MASK_COMPLETE_RESULT(result512, y+tag_m_32x, store_tail_mask)
        }
    }

    return 0;
}
// BF16 GEMV kernel for n=4 with rows stored back-to-back (row i occupies
// a[4*i .. 4*i+3], i.e. two 32-bit lanes); the main loop handles 16 rows per step.
//
// Main loop: two registers (8 rows each) are de-interleaved into a register of
// (col0,col1) pairs and a register of (col2,col3) pairs for 16 rows, which are
// multiplied with (x0,x1) and (x2,x3) and accumulated.
// 8-row / tail path: a single register is rearranged so its low half holds the
// (col0,col1) pairs and its high half the (col2,col3) pairs; one dot product
// against the matching x layout is followed by adding the two halves.
// Results go to y through the STORE* macro chosen by the
// ZERO_BETA / ONE_BETA / ONE_ALPHA configuration. Always returns 0.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_16x4_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_16x4_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_16x4_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_16x4(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    // Presumably picked up by name inside the STORE* macros.
#ifndef ONE_ALPHA
    __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
    __m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif

    __m512i block_lo, block_hi, cols01, cols23;
    __m512  acc;

    // Number of rows finished so far; starts at the multiple-of-16 prefix.
    BLASLONG done = m & (~15);

    // 32-bit lane selectors. Row r keeps (col0,col1) in lane 2r and (col2,col3) in lane 2r+1.
    __m512i pick_cols01 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
    __m512i pick_cols23 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1);
    // Single-register variant: even lanes to the low half, odd lanes to the high half.
    __m512i pick_remix  = _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);

    // Read exactly x0..x3 (two 32-bit lanes).
    __mmask8 x_quad_mask = (__mmask8) (0xf >> 2);
    __m128i  x_quad = _mm_maskz_loadu_epi32(x_quad_mask, x);                         // |x0|x1|x2|x3|0|0|0|0|
    __m512i  x01 = _mm512_broadcastd_epi32(x_quad);                                  // |x0|x1|x0|x1|...|x0|x1|
    __m512i  x23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x_quad, 0x1));          // |x2|x3|x2|x3|...|x2|x3|
    // Low eight lanes (x0,x1), high eight lanes (x2,x3): matches pick_remix.
    __m512i  x_remix = _mm512_mask_blend_epi32((__mmask16) 0xff00, x01, x23);

    // Full 16-row steps.
    for (BLASLONG row = 0; row < done; row += 16) {
        block_lo = _mm512_loadu_si512(&a[row*4]);                                    // rows row   .. row+7
        block_hi = _mm512_loadu_si512(&a[(row+8)*4]);                                // rows row+8 .. row+15

        cols01 = _mm512_permutex2var_epi32(block_lo, pick_cols01, block_hi);         // (col0,col1) of the 16 rows
        cols23 = _mm512_permutex2var_epi32(block_lo, pick_cols23, block_hi);         // (col2,col3) of the 16 rows

        acc = _mm512_dpbf16_ps(_mm512_setzero_ps(), (__m512bh) cols01, (__m512bh) x01);
        acc = _mm512_dpbf16_ps(acc, (__m512bh) cols23, (__m512bh) x23);

        STORE16_COMPLETE_RESULT(acc, y+row)
    }

    // One complete group of 8 rows, if present.
    if (m - done > 7) {
        block_lo = _mm512_loadu_si512(&a[done*4]);
        cols01 = _mm512_permutexvar_epi32(pick_remix, block_lo);                     // low half: (col0,col1); high half: (col2,col3)
        acc = _mm512_dpbf16_ps(_mm512_setzero_ps(), (__m512bh) cols01, (__m512bh) x_remix);

        // Row r's two partial sums sit in lane r of each half; add the halves.
        __m256 sum8 = _mm256_add_ps(_mm512_castps512_ps256(acc), _mm512_extractf32x8_ps(acc, 1));
        STORE8_COMPLETE_RESULT(sum8, y+done)
        done += 8;
    }

    // Final 1..7 rows: two 32-bit lanes per row on load, one float per row on store.
    BLASLONG left = m - done;
    if (left != 0) {
        __mmask16 load_mask = (__mmask16) (0xffff >> (16 - left*2));
        block_lo = _mm512_maskz_loadu_epi32(load_mask, &a[done*4]);
        cols01 = _mm512_permutexvar_epi32(pick_remix, block_lo);
        acc = _mm512_dpbf16_ps(_mm512_setzero_ps(), (__m512bh) cols01, (__m512bh) x_remix);

        __m256 sum8 = _mm256_add_ps(_mm512_castps512_ps256(acc), _mm512_extractf32x8_ps(acc, 1));
        __mmask8 store_mask = (__mmask8) (0xff >> (8 - left));
        STORE8_MASK_COMPLETE_RESULT(sum8, y+done, store_mask)
    }

    return 0;
}
// 30 rows parallel processing BF16 GEMV kernel for n=5 && lda ineffective scenario
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_30x5_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_30x5_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_30x5_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_30x5(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
BLASLONG tag_m_30x = m - (m%30);
unsigned char x_load_mask_value = (((unsigned char)0xff) >> 3);
__mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
__m128i x128 = _mm_maskz_loadu_epi16(x_load_mask, x); // x0|x1|x2|x3|x4|0|0|0|
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
__m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif
__m512 result_0, result_1;
__m512i xArray_01 = _mm512_broadcastd_epi32(x128); // x0|x1|x0|x1|...|x0|x1|
__m512i xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)); // x2|x3|x2|x3|...|x2|x3|
__m512i xArray_4 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)); // x4| 0|x4| 0|...|x4| 0|
__m512i M512_EPI16_2 = _mm512_set1_epi16(2);
__m512i load_idx01_stage1_1st = _mm512_set_epi16( 0, 0, 0, 0, 0, 0, 0, 0, 58, 57, 53, 52, 48, 47, 43, 42,
38, 37, 33, 32, 26, 25, 21, 20, 16, 15, 11, 10, 6, 5, 1, 0);
__m512i load_idx01_stage1_2nd = _mm512_shuffle_i32x4(load_idx01_stage1_1st, load_idx01_stage1_1st, 0x39);
__m512i load_idx01_stage1_3rd = _mm512_shuffle_i32x4(load_idx01_stage1_1st, load_idx01_stage1_1st, 0x4f);
__m512i load_idx23_stage1_1st = _mm512_add_epi16(load_idx01_stage1_1st, M512_EPI16_2);
__m512i load_idx23_stage1_2nd = _mm512_add_epi16(load_idx01_stage1_2nd, M512_EPI16_2);
__m512i load_idx23_stage1_3rd = _mm512_add_epi16(load_idx01_stage1_3rd, M512_EPI16_2);
__m512i load_idx4_stage1_1st = _mm512_add_epi16(load_idx23_stage1_1st, M512_EPI16_2);
__m512i load_idx4_stage1_2nd = _mm512_add_epi16(load_idx23_stage1_2nd, M512_EPI16_2);
__m512i load_idx4_stage1_3rd = _mm512_add_epi16(load_idx23_stage1_3rd, M512_EPI16_2);
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4;
__m512i matrixArray_stage1_0, matrixArray_stage1_1, matrixArray_stage1_2;
__m512i matrixArray_stage2_0, matrixArray_stage2_1;
unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 2);
__mmask32 load_mask = *((__mmask32*) &load_mask_value);
unsigned short store_mask_value = (((unsigned short)0xffff) >> 2);
__mmask16 store_mask = *((__mmask16*) &store_mask_value);
if (tag_m_30x > 0) {
unsigned short blend_mask_value_0 = ((unsigned short)0xf000);
__mmask16 blend_mask_0 = *((__mmask16*) &blend_mask_value_0);
unsigned short blend_mask_value_1 = ((unsigned short)0x3f00);
__mmask16 blend_mask_1 = *((__mmask16*) &blend_mask_value_1);
for (BLASLONG idx_m = 0; idx_m < tag_m_30x; idx_m+=30) {
result_0 = _mm512_setzero_ps();
result_1 = _mm512_setzero_ps();
matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m)*5]); // Load 6 rows with n=5
matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+6)*5)]); // Load 6 rows with n=5
matrixArray_2 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+12)*5)]); // Load 6 rows with n=5
matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+18)*5)]); // Load 6 rows with n=5
matrixArray_4 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+24)*5)]); // Load 6 rows with n=5
// Process the 0|1 elements
// Stage 1: Select the 0|1 elements for each row
matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_stage1_1st, matrixArray_1);
matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_2, load_idx01_stage1_2nd, matrixArray_3);
matrixArray_stage1_2 = _mm512_permutexvar_epi16(load_idx01_stage1_3rd, matrixArray_4);
// Stage 2: Reorder and compress all the 0|1 elements
matrixArray_stage2_0 = _mm512_mask_blend_epi32(blend_mask_0, matrixArray_stage1_0, matrixArray_stage1_1);
matrixArray_stage2_1 = _mm512_mask_blend_epi32(blend_mask_1, matrixArray_stage1_1, matrixArray_stage1_2);
// Calculate the result of the 0|1 elements
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage2_0, (__m512bh) xArray_01);
result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage2_1, (__m512bh) xArray_01);
// Process the 2|3 elements
// Stage 1: Select the 2|3 elements for each row
matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx23_stage1_1st, matrixArray_1);
matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_2, load_idx23_stage1_2nd, matrixArray_3);
matrixArray_stage1_2 = _mm512_permutexvar_epi16(load_idx23_stage1_3rd, matrixArray_4);
// Stage 2: Reorder and compress all the 2|3 elements
matrixArray_stage2_0 = _mm512_mask_blend_epi32(blend_mask_0, matrixArray_stage1_0, matrixArray_stage1_1);
matrixArray_stage2_1 = _mm512_mask_blend_epi32(blend_mask_1, matrixArray_stage1_1, matrixArray_stage1_2);
// Calculate the result of the 2|3 elements and accumulate the result of 0|1 elements
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage2_0, (__m512bh) xArray_23);
result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage2_1, (__m512bh) xArray_23);
// Process the for 4 elements
// Stage 1: Select the 4 elements for each row
matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx4_stage1_1st, matrixArray_1);
matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_2, load_idx4_stage1_2nd, matrixArray_3);
matrixArray_stage1_2 = _mm512_permutexvar_epi16(load_idx4_stage1_3rd, matrixArray_4);
// Stage 2: Reorder and compress all the 4 elements
matrixArray_stage2_0 = _mm512_mask_blend_epi32(blend_mask_0, matrixArray_stage1_0, matrixArray_stage1_1);
matrixArray_stage2_1 = _mm512_mask_blend_epi32(blend_mask_1, matrixArray_stage1_1, matrixArray_stage1_2);
// Calculate the result of the 4 element and accumulate the result of 0|1 and 2|3 elements
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage2_0, (__m512bh) xArray_4);
result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage2_1, (__m512bh) xArray_4);
STORE16_COMPLETE_RESULT(result_0, y+idx_m)
STORE16_MASK_COMPLETE_RESULT(result_1, y+idx_m+16, store_mask)
}
}
if (m - tag_m_30x > 11) {
BLASLONG tag_m_12x = m - ((m-tag_m_30x)%12);
for (BLASLONG idx_m = tag_m_30x; idx_m < tag_m_12x; idx_m+=12) {
unsigned short store_less_mask_value = (((unsigned short)0xffff) >> 4);
__mmask16 store_less_mask = *((__mmask16*) &store_less_mask_value);
result_0 = _mm512_setzero_ps();
matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m)*5]); // Load 6 rows with n=5
matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+6)*5)]); // Load 6 rows with n=5
// Interleave the elements
matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_stage1_1st, matrixArray_1);
matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_0, load_idx23_stage1_1st, matrixArray_1);
matrixArray_stage1_2 = _mm512_permutex2var_epi16(matrixArray_0, load_idx4_stage1_1st, matrixArray_1);
// Calculate and accumulate the result
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_0, (__m512bh) xArray_01);
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_1, (__m512bh) xArray_23);
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_2, (__m512bh) xArray_4);
STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_less_mask)
tag_m_30x += 12;
}
}
BLASLONG tail_num = m - tag_m_30x;
if (tail_num > 6) {
unsigned short store_less_mask_value = (((unsigned short)0xffff) >> (4+(12-tail_num)));
__mmask16 store_less_mask = *((__mmask16*) &store_less_mask_value);
unsigned int load_less_mask_value = (((unsigned int)0xffffffff) >> (2+(12-tail_num)*5));
__mmask32 load_less_mask = *((__mmask32*) &load_less_mask_value);
result_0 = _mm512_setzero_ps();
matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_30x)*5]); // Load 6 rows with n=5
matrixArray_1 = _mm512_maskz_loadu_epi16(load_less_mask, &a[((tag_m_30x+6)*5)]); // Load x rows with n=5
// Interleave the elements
matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_stage1_1st, matrixArray_1);
matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_0, load_idx23_stage1_1st, matrixArray_1);
matrixArray_stage1_2 = _mm512_permutex2var_epi16(matrixArray_0, load_idx4_stage1_1st, matrixArray_1);
// Calculate and accumulate the result
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_0, (__m512bh) xArray_01);
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_1, (__m512bh) xArray_23);
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_2, (__m512bh) xArray_4);
STORE16_MASK_COMPLETE_RESULT(result_0, y+tag_m_30x, store_less_mask)
} else {
__m128i matrixArray128;
__m128 result128, tmp128;
for (BLASLONG i = tag_m_30x; i < m; i++) {
result128 = _mm_setzero_ps();
matrixArray128 = _mm_maskz_loadu_epi16(x_load_mask, &a[(i)*5]); // Load 1 rows with n=5
result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
tmp128 = _mm_shuffle_ps(result128, result128, 14);
result128 = _mm_add_ps(result128, tmp128);
tmp128 = _mm_shuffle_ps(result128, result128, 1);
result128 = _mm_add_ps(result128, tmp128);
#ifndef ZERO_BETA
#ifndef ONE_BETA
y[i] = alpha * result128[0] + beta * y[i];
#else
y[i] = alpha * result128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
y[i] = result128[0] * alpha;
#else
y[i] = result128[0];
#endif
#endif
}
}
return 0;
}
// 16 rows parallel processing BF16 GEMV kernel for n=6 && lda ineffective scenario
//
// For each row i in [0, m) the kernel forms the dot product of the six BF16
// values a[i*6 .. i*6+5] with x[0..5] and combines it with y[i]. Rows are read
// back to back (row stride 6); no separate lda is taken.
//
// ZERO_BETA / ONE_BETA / ONE_ALPHA select the compiled variant. As written in
// the scalar tail below:
//   neither ZERO_BETA nor ONE_BETA : y[i] = alpha * dot + beta * y[i]
//   ONE_BETA                       : y[i] = alpha * dot + y[i]
//   ZERO_BETA without ONE_ALPHA    : y[i] = dot * alpha
//   ZERO_BETA with ONE_ALPHA       : y[i] = dot
// The vector paths store through the STORE*_COMPLETE_RESULT macros, which are
// defined outside this block; presumably they apply the same scaling through
// ALPHAVECTOR / BETAVECTOR -- confirm against the macro definitions.
//
// Always returns 0.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_16x6_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_16x6_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_16x6_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_16x6(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
// Number of rows covered by whole 16-row tiles (low 4 bits of m cleared)
BLASLONG tag_m_16x = m & (~15);
// 0x3f: keep the low 6 of the 8 16-bit lanes, the rest are zeroed by the masked load
unsigned char x_load_mask_value = (((unsigned char)0xff) >> 2);
__mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
__m128i x128 = _mm_maskz_loadu_epi16(x_load_mask, x); // x0|x1|x2|x3|x4|x5|0|0|
// All vector paths (16-row tiles and the 8-row step) run only when m >= 16;
// smaller m is handled entirely by the scalar loop at the bottom.
if (tag_m_16x > 0) {
__m512 result_0;
// ALPHAVECTOR / BETAVECTOR are not referenced by name in this block;
// presumably they are consumed by the STORE* macros -- confirm there.
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
__m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif
__m512i M512_EPI32_1 = _mm512_set1_epi32(1);
// Dword gather tables (_mm512_set_epi32 lists lane 15 first, lane 0 last).
// A row is 3 dwords (6 bf16), so row r's 0|1 pair lives at dword 3*r of the tile.
// *_1st: lanes 0..10 pick rows 0..10 out of matrixArray_0/matrixArray_1 (dwords 0..31).
// *_2nd: lanes 11..15 pick rows 11..15 out of matrixArray_2 (tile dwords 32..47).
__m512i load_idx01_1st = _mm512_set_epi32( 0, 0, 0, 0, 0, 30, 27, 24, 21, 18, 15, 12, 9, 6, 3, 0);
__m512i load_idx01_2nd = _mm512_set_epi32(13, 10, 7, 4, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
// The 2|3 and 4|5 pairs sit one and two dwords after the 0|1 pair
__m512i load_idx23_1st = _mm512_add_epi32(load_idx01_1st, M512_EPI32_1);
__m512i load_idx23_2nd = _mm512_add_epi32(load_idx01_2nd, M512_EPI32_1);
__m512i load_idx45_1st = _mm512_add_epi32(load_idx23_1st, M512_EPI32_1);
__m512i load_idx45_2nd = _mm512_add_epi32(load_idx23_2nd, M512_EPI32_1);
// 0x0400 selects lane 10 (the 11th lane)
unsigned short blend_mask_value = ((unsigned short)0x0400);
__mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
// Row 10's 4|5 pair is tile dword 32, i.e. dword 0 of matrixArray_2, which the
// two-source permute over matrixArray_0/1 cannot address: replace the index 32
// in lane 10 with 0 (stage 2 overwrites that lane anyway)
load_idx45_1st = _mm512_mask_blend_epi32(blend_mask, load_idx45_1st, load_idx01_2nd);
// Lane 10 of the second-stage table gets 0, which is the correct dword of matrixArray_2
load_idx45_2nd = _mm512_mask_blend_epi32(blend_mask, load_idx45_2nd, load_idx01_2nd);
__m512i xArray_01 = _mm512_broadcastd_epi32(x128); // x0|x1|x0|x1|...|x0|x1|
__m512i xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)); // x2|x3|x2|x3|...|x2|x3|
__m512i xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)); // x4|x5|x4|x5|...|x4|x5|
// 0xf800: lanes 11..15 are taken from matrixArray_2 for the 0|1 and 2|3 pairs
unsigned short permute_mask01_uint = (((unsigned short)0xf800));
__mmask16 permute_mask01 = *((__mmask16*) &permute_mask01_uint);
// 0xfc00: lanes 10..15 are taken from matrixArray_2 for the 4|5 pairs
unsigned short permute_mask45_uint = (((unsigned short)0xfc00));
__mmask16 permute_mask45 = *((__mmask16*) &permute_mask45_uint);
__m512i matrixArray_0, matrixArray_1, matrixArray_2;
__m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2;
for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
result_0 = _mm512_setzero_ps();
// Three back-to-back 32-element loads cover the 96 bf16 values of a 16x6 tile
matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*6]); // Tile elements 0..31: rows 0-4 plus 2 elements of row 5
matrixArray_1 = _mm512_loadu_si512(&a[((idx_m+5)*6 + 2)]); // Tile elements 32..63: rest of row 5, rows 6-9, 4 elements of row 10
matrixArray_2 = _mm512_loadu_si512(&a[((idx_m+10)*6 + 4)]); // Tile elements 64..95: rest of row 10, rows 11-15
// Stage 1: interleave for the a..k elements
matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx01_1st, matrixArray_1);
matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx23_1st, matrixArray_1);
matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_0, load_idx45_1st, matrixArray_1);
// Stage 2: interleave for the l..p elements and remix together
matrixArray_stage_0 = _mm512_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask01, load_idx01_2nd, matrixArray_2);
matrixArray_stage_1 = _mm512_mask_permutexvar_epi32(matrixArray_stage_1, permute_mask01, load_idx23_2nd, matrixArray_2);
matrixArray_stage_2 = _mm512_mask_permutexvar_epi32(matrixArray_stage_2, permute_mask45, load_idx45_2nd, matrixArray_2);
// Accumulate the 0|1, 2|3 and 4|5 pair products; lane r ends up holding the dot product of row r
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_01);
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_23);
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_2, (__m512bh) xArray_45);
STORE16_COMPLETE_RESULT(result_0, y+idx_m)
}
// At least 8 rows left: handle 8 of them with the same scheme on 256-bit registers.
// The declarations below shadow the 512-bit tables of the enclosing scope.
if (m - tag_m_16x > 7) {
__m256i M256_EPI32_1 = _mm512_castsi512_si256(M512_EPI32_1);
// Lanes 0..5 pick rows 0..5 from matrixArray_0/1, lanes 6..7 pick rows 6..7 from matrixArray_2
__m256i load_idx01_1st = _mm256_set_epi32( 0, 0, 15, 12, 9, 6, 3, 0);
__m256i load_idx01_2nd = _mm256_set_epi32( 5, 2, 0, 0, 0, 0, 0, 0);
__m256i load_idx23_1st = _mm256_add_epi32(load_idx01_1st, M256_EPI32_1);
__m256i load_idx23_2nd = _mm256_add_epi32(load_idx01_2nd, M256_EPI32_1);
// 0x20 selects lane 5 (the 6th lane)
unsigned char blend_mask_value = ((unsigned char)0x20);
__mmask8 blend_mask = *((__mmask8*) &blend_mask_value);
// Row 5's 2|3 pair is tile dword 16, i.e. dword 0 of matrixArray_2, out of range for the
// two-source 256-bit permute: put 0 in lane 5 (the masked permute overwrites that lane anyway)
load_idx23_1st = _mm256_mask_blend_epi32(blend_mask, load_idx23_1st, load_idx01_2nd);
// Set the 6th element to be 0 as 0 is the correct index
load_idx23_2nd = _mm256_mask_blend_epi32(blend_mask, load_idx23_2nd, load_idx01_2nd);
// Derived after the lane-5 fix-up, so lane 5 of load_idx45_2nd becomes 1 (row 5's 4|5 pair in matrixArray_2)
__m256i load_idx45_1st = _mm256_add_epi32(load_idx23_1st, M256_EPI32_1);
__m256i load_idx45_2nd = _mm256_add_epi32(load_idx23_2nd, M256_EPI32_1);
// 0xc0: lanes 6..7 come from matrixArray_2 (used for the 0|1 pairs)
unsigned char permute_mask01_uint = (((unsigned char)0xc0));
__mmask8 permute_mask01 = *((__mmask8*) &permute_mask01_uint);
// 0xe0: lanes 5..7 come from matrixArray_2 (used for both the 2|3 and the 4|5 pairs)
unsigned char permute_mask45_uint = (((unsigned char)0xe0));
__mmask8 permute_mask45 = *((__mmask8*) &permute_mask45_uint);
__m256i matrixArray_0, matrixArray_1, matrixArray_2;
__m256i matrixArray_stage_0;
__m256 result256_0;
result256_0 = _mm256_setzero_ps();
// Three back-to-back 16-element loads cover the 48 bf16 values of an 8x6 tile
matrixArray_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_16x)*6]); // Tile elements 0..15: rows 0-1 plus 4 elements of row 2
matrixArray_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_16x+2)*6 + 4)]); // Tile elements 16..31: rest of row 2, rows 3-4, 2 elements of row 5
matrixArray_2 = _mm256_loadu_si256((__m256i *)&a[((tag_m_16x+5)*6 + 2)]); // Tile elements 32..47: rest of row 5, rows 6-7
// Process the 0|1 elements
// Select the 0|1 elements for each row
matrixArray_stage_0 = _mm256_permutex2var_epi32(matrixArray_0, load_idx01_1st, matrixArray_1);
matrixArray_stage_0 = _mm256_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask01, load_idx01_2nd, matrixArray_2);
// Calculate the result of the 0|1 elements
result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray_stage_0, (__m256bh) _mm512_castsi512_si256(xArray_01));
// Process the 2|3 elements
// Select the 2|3 elements for each row
matrixArray_stage_0 = _mm256_permutex2var_epi32(matrixArray_0, load_idx23_1st, matrixArray_1);
matrixArray_stage_0 = _mm256_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask45, load_idx23_2nd, matrixArray_2);
// Accumulate the result of the 2|3 elements
result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray_stage_0, (__m256bh) _mm512_castsi512_si256(xArray_23));
// Process the 4|5 elements
// Select the 4|5 elements for each row
matrixArray_stage_0 = _mm256_permutex2var_epi32(matrixArray_0, load_idx45_1st, matrixArray_1);
matrixArray_stage_0 = _mm256_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask45, load_idx45_2nd, matrixArray_2);
// Accumulate the result of the 4|5 elements
result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray_stage_0, (__m256bh) _mm512_castsi512_si256(xArray_45));
STORE8_COMPLETE_RESULT(result256_0, y+tag_m_16x)
// The 8 rows just stored are done; the scalar loop starts after them
tag_m_16x += 8;
}
}
// Scalar tail: one row per iteration for whatever the vector paths left over
if (tag_m_16x != m) {
__m128i matrixArray128;
__m128 result128, tmp128;
for (BLASLONG i = tag_m_16x; i < m; i++) {
result128 = _mm_setzero_ps();
matrixArray128 = _mm_maskz_loadu_epi16(x_load_mask, &a[(i)*6]); // Load 1 rows with n=6
// Each float lane receives the product sum of one bf16 pair (lane 3 stays 0 because of the masked loads)
result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
// Horizontal add: fold lanes 2,3 onto lanes 0,1, then lane 1 onto lane 0
tmp128 = _mm_shuffle_ps(result128, result128, 14);
result128 = _mm_add_ps(result128, tmp128);
tmp128 = _mm_shuffle_ps(result128, result128, 1);
result128 = _mm_add_ps(result128, tmp128);
#ifndef ZERO_BETA
#ifndef ONE_BETA
y[i] = alpha * result128[0] + beta * y[i];
#else
y[i] = alpha * result128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
y[i] = result128[0] * alpha;
#else
y[i] = result128[0];
#endif
#endif
}
}
return 0;
}
// 16 rows parallel processing BF16 GEMV kernel for n=7 && lda ineffective scenario
//
// For each row i in [0, m) the kernel forms the dot product of the seven BF16
// values a[i*7 .. i*7+6] with x[0..6] and combines it with y[i]. Rows are read
// back to back (row stride 7); no separate lda is taken.
//
// ZERO_BETA / ONE_BETA / ONE_ALPHA select the compiled variant. As written in
// the scalar tail below:
//   neither ZERO_BETA nor ONE_BETA : y[i] = alpha * dot + beta * y[i]
//   ONE_BETA                       : y[i] = alpha * dot + y[i]
//   ZERO_BETA without ONE_ALPHA    : y[i] = dot * alpha
//   ZERO_BETA with ONE_ALPHA       : y[i] = dot
// The vector paths store through the STORE*_COMPLETE_RESULT macros, which are
// defined outside this block; presumably they apply the same scaling through
// ALPHAVECTOR / BETAVECTOR -- confirm against the macro definitions.
//
// Always returns 0.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_16x7_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_16x7_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_16x7_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_16x7(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    // Number of rows covered by whole 16-row tiles (low 4 bits of m cleared)
    BLASLONG tag_m_16x = m & (~15);

    // 0x7f: keep the low 7 of the 8 16-bit lanes, lane 7 is zeroed by the masked load
    unsigned char x_load_mask_value = (((unsigned char)0xff) >> 1);
    __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
    __m128i x128 = _mm_maskz_loadu_epi16(x_load_mask, x);                    // |x0|x1|x2|x3|x4|x5|x6|0|

    // All vector paths (16-row tiles, the 8-row step and the 4..7-row step) run
    // only when m >= 16; smaller m is handled entirely by the scalar loop below.
    if (tag_m_16x > 0) {
        __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
        __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3;
        __m512i xArray_0123, xArray_4567;
        __m512 result_0, result_1, result_2, result_3;

        // ALPHAVECTOR / BETAVECTOR are not referenced by name in this block;
        // presumably they are consumed by the STORE* macros -- confirm there.
#ifndef ONE_ALPHA
        __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
        __m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif

        __m512i M512_EPI32_2 = _mm512_set1_epi32(2);
        // Stage 1 table (16-bit lanes, listed lane 31 first): spreads 4 packed
        // 7-element rows to 4 rows of 8 lanes. The 8th lane of every row reads
        // source lane 31, which the 28-lane masked loads below leave at zero.
        __m512i load_idx_stage1_0 = _mm512_set_epi16(31, 27, 26, 25, 24, 23, 22, 21, 31, 20, 19, 18, 17, 16, 15, 14,
                                                     31, 13, 12, 11, 10, 9, 8, 7, 31, 6, 5, 4, 3, 2, 1, 0);
        // Stage 2 tables (32-bit lanes, listed lane 15 first): a padded row is
        // 4 dwords, so lanes 0..7 gather dword 0 (pair 0|1) of 8 rows and lanes
        // 8..15 gather dword 1 (pair 2|3). Adding 2 yields pairs 4|5 and 6|7.
        __m512i load_idx_stage2_0 = _mm512_set_epi32(29, 25, 21, 17, 13, 9, 5, 1, 28, 24, 20, 16, 12, 8, 4, 0);
        __m512i load_idx_stage2_1 = _mm512_add_epi32(load_idx_stage2_0, M512_EPI32_2);

        // 0xff00: low 8 dword lanes take the first operand, high 8 the second
        unsigned short x_blend_mask_value = ((unsigned short)0xff00);
        __mmask16 x_blend_mask = *((__mmask16*) &x_blend_mask_value);
        // |x0|x1| x8 times, then |x2|x3| x8 times
        xArray_0123 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(x128), \
                                              _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)));
        // |x4|x5| x8 times, then |x6|0| x8 times
        xArray_4567 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)), \
                                              _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x3)));

        // 28 of 32 16-bit lanes: exactly 4 rows of 7 elements
        unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 4);
        __mmask32 load_mask = *((__mmask32*) &load_mask_value);

        for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
            result_0 = _mm512_setzero_ps();
            result_1 = _mm512_setzero_ps();

            matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m)*7]);       // Load 4 rows with n=7
            matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+4)*7]);     // Load 4 rows with n=7
            matrixArray_2 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+8)*7]);     // Load 4 rows with n=7
            matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+12)*7]);    // Load 4 rows with n=7

            // Stage 1: padding (the x7 slot of every row is zero)
            matrixArray_0 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_0);    // |a0|a1|a2|a3|...|b6|b7|c0|c1|c2|c3|...|d6|d7|
            matrixArray_1 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_1);    // |e0|e1|e2|e3|...|f6|f7|g0|g1|g2|g3|...|h6|h7|
            matrixArray_2 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_2);    // |i0|i1|i2|i3|...|j6|j7|k0|k1|k2|k3|...|l6|l7|
            matrixArray_3 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_3);    // |m0|m1|m2|m3|...|n6|n7|o0|o1|o2|o3|...|p6|p7|

            // Stage 2: interleave per 32 bits
            matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1);    // |a0|a1|...|h0|h1|a2|a3|...|h2|h3|
            matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1);    // |a4|a5|...|h4|h5|a6|a7|...|h6|h7|
            matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_0, matrixArray_3);    // |i0|i1|...|p0|p1|i2|i3|...|p2|p3|
            matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_1, matrixArray_3);    // |i4|i5|...|p4|p5|i6|i7|...|p6|p7|

            // result_0: rows a..h, result_1: rows i..p. In each, the low 8 lanes
            // hold the 0|1 + 4|5 partial sums and the high 8 lanes the 2|3 + 6|7 ones.
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_2, (__m512bh) xArray_0123);
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_3, (__m512bh) xArray_4567);

            // Stage 3: interleave per 256 bits
            // Pair the low halves and the high halves of both accumulators, then
            // add them so that lane r holds the full dot product of row r.
            result_2 = _mm512_shuffle_f32x4(result_0, result_1, 0x44);
            result_3 = _mm512_shuffle_f32x4(result_0, result_1, 0xee);

            result_2 = _mm512_add_ps(result_2, result_3);

            STORE16_COMPLETE_RESULT(result_2, y+idx_m)
        }

        // At least 8 rows left: process 8 of them with a single accumulator
        if (m - tag_m_16x > 7) {
            result_0 = _mm512_setzero_ps();

            matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_16x)*7]);       // Load 4 rows with n=7
            matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_16x+4)*7]);     // Load 4 rows with n=7

            // Stage 1: padding
            matrixArray_0 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_0);    // |a0|a1|a2|a3|...|b6|b7|c0|c1|c2|c3|...|d6|d7|
            matrixArray_1 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_1);    // |e0|e1|e2|e3|...|f6|f7|g0|g1|g2|g3|...|h6|h7|

            // Stage 2: interleave per 32 bits
            matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1);    // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
            matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1);    // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|

            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);

            // Low half + high half gives the 8 complete row sums
            __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));

            STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
            tag_m_16x += 8;
        }

        // 4..7 rows left: the first 4 rows are loaded in full, the remaining
        // tail_num-4 rows through a partial mask, and only tail_num results are stored.
        BLASLONG tail_num = m - tag_m_16x;
        if (tail_num > 3) {
            result_0 = _mm512_setzero_ps();

            matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_16x)*7]);          // Load 4 rows with n=7
            // The shift count is 4+(8-tail_num)*7, which equals 32 when tail_num == 4.
            // Shifting a 32-bit operand by 32 is undefined in C, and x86 shifts take
            // the count modulo 32, which would typically give an all-ones mask and
            // read 32 elements starting at the end of the matrix. Shift a 64-bit
            // value instead so that tail_num == 4 produces an empty mask.
            unsigned int tail_load_mask_value = (unsigned int)(((unsigned long long)0xffffffff) >> (4+(8-tail_num)*7));
            __mmask32 tail_load_mask = *((__mmask32*) &tail_load_mask_value);
            matrixArray_1 = _mm512_maskz_loadu_epi16(tail_load_mask, &a[(tag_m_16x+4)*7]);     // Load the remaining tail_num-4 rows with n=7

            // Stage 1: padding
            matrixArray_0 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_0);    // |a0|a1|a2|a3|...|b6|b7|c0|c1|c2|c3|...|d6|d7|
            matrixArray_1 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_1);    // |e0|e1|e2|e3|...|f6|f7|g0|g1|g2|g3|...|h6|h7|

            // Stage 2: interleave per 32 bits
            matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1);    // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
            matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1);    // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|

            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);

            __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));

            // Low tail_num bits set: only the valid rows are written back
            unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
            __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);
            STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_16x, tail_mask)
            // Everything is done; the scalar loop below is skipped
            tag_m_16x = m;
        }
    }

    // Scalar tail: one row per iteration for whatever the vector paths left over
    if (tag_m_16x != m) {
        __m128i matrixArray128;
        __m128 result128, tmp128;
        for (BLASLONG i = tag_m_16x; i < m; i++) {
            result128 = _mm_setzero_ps();
            matrixArray128 = _mm_maskz_loadu_epi16(x_load_mask, &a[(i)*7]);    // Load 1 rows with n=7
            // Each float lane receives the product sum of one bf16 pair
            result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
            // Horizontal add: fold lanes 2,3 onto lanes 0,1, then lane 1 onto lane 0
            tmp128 = _mm_shuffle_ps(result128, result128, 14);
            result128 = _mm_add_ps(result128, tmp128);
            tmp128 = _mm_shuffle_ps(result128, result128, 1);
            result128 = _mm_add_ps(result128, tmp128);
#ifndef ZERO_BETA
#ifndef ONE_BETA
            y[i] = alpha * result128[0] + beta * y[i];
#else
            y[i] = alpha * result128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
            y[i] = result128[0] * alpha;
#else
            y[i] = result128[0];
#endif
#endif
        }
    }

    return 0;
}
// 16 rows parallel processing BF16 GEMV kernel for n=8 && lda ineffective scenario
//
// For each row i in [0, m) the kernel forms the dot product of the eight BF16
// values a[i*8 .. i*8+7] with x[0..7] and combines it with y[i]. Rows are read
// back to back (row stride 8); no separate lda is taken.
//
// ZERO_BETA / ONE_BETA / ONE_ALPHA select the compiled variant. As written in
// the scalar tail below:
//   neither ZERO_BETA nor ONE_BETA : y[i] = alpha * dot + beta * y[i]
//   ONE_BETA                       : y[i] = alpha * dot + y[i]
//   ZERO_BETA without ONE_ALPHA    : y[i] = dot * alpha
//   ZERO_BETA with ONE_ALPHA       : y[i] = dot
// The vector paths store through the STORE*_COMPLETE_RESULT macros, which are
// defined outside this block; presumably they apply the same scaling through
// ALPHAVECTOR / BETAVECTOR -- confirm against the macro definitions.
//
// Always returns 0.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_16x8_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_16x8_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_16x8_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
// Number of rows covered by whole 16-row tiles (low 4 bits of m cleared)
BLASLONG tag_m_16x = m & (~15);
__m128i x128 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7|
// All vector paths (16-row tiles, the 8-row step and the 4..7-row step) run
// only when m >= 16; smaller m is handled entirely by the scalar loop below.
if (tag_m_16x > 0) {
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
__m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3;
__m512i xArray_0123, xArray_4567;
__m512 result_0, result_1, result_2, result_3;
// ALPHAVECTOR / BETAVECTOR are not referenced by name in this block;
// presumably they are consumed by the STORE* macros -- confirm there.
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
__m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif
__m512i M512_EPI32_2 = _mm512_set1_epi32(2);
// Dword gather tables (_mm512_set_epi32 lists lane 15 first, lane 0 last).
// A row is 4 dwords, so lanes 0..7 gather dword 0 (pair 0|1) of 8 rows and
// lanes 8..15 gather dword 1 (pair 2|3). Adding 2 yields pairs 4|5 and 6|7.
__m512i load_idx_stage2_0 = _mm512_set_epi32(29, 25, 21, 17, 13, 9, 5, 1, 28, 24, 20, 16, 12, 8, 4, 0);
__m512i load_idx_stage2_1 = _mm512_add_epi32(load_idx_stage2_0, M512_EPI32_2);
// 0xff00: low 8 dword lanes take the first operand, high 8 the second
unsigned short x_blend_mask_value = ((unsigned short)0xff00);
__mmask16 x_blend_mask = *((__mmask16*) &x_blend_mask_value);
// xArray_0123: |x0|x1| x8 times, then |x2|x3| x8 times
xArray_0123 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(x128), \
_mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)));
// xArray_4567: |x4|x5| x8 times, then |x6|x7| x8 times
xArray_4567 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)), \
_mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x3)));
for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
result_0 = _mm512_setzero_ps();
result_1 = _mm512_setzero_ps();
matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*8]); // Load 4 rows with n=8
matrixArray_1 = _mm512_loadu_si512(&a[(idx_m+4)*8]); // Load 4 rows with n=8
matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+8)*8]); // Load 4 rows with n=8
matrixArray_3 = _mm512_loadu_si512(&a[(idx_m+12)*8]); // Load 4 rows with n=8
// Stage 1: interleave per 32 bits
matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|...|h0|h1|a2|a3|...|h2|h3|
matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|...|h4|h5|a6|a7|...|h6|h7|
matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_0, matrixArray_3); // |i0|i1|...|p0|p1|i2|i3|...|p2|p3|
matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_1, matrixArray_3); // |i4|i5|...|p4|p5|i6|i7|...|p6|p7|
// result_0: rows a..h, result_1: rows i..p. In each, the low 8 lanes hold
// the 0|1 + 4|5 partial sums and the high 8 lanes the 2|3 + 6|7 ones.
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_2, (__m512bh) xArray_0123);
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_3, (__m512bh) xArray_4567);
// Stage 2: interleave per 256 bits
// Pair the low halves and the high halves of both accumulators, then add
// them so that lane r holds the full dot product of row r.
result_2 = _mm512_shuffle_f32x4(result_0, result_1, 0x44);
result_3 = _mm512_shuffle_f32x4(result_0, result_1, 0xee);
result_2 = _mm512_add_ps(result_2, result_3);
STORE16_COMPLETE_RESULT(result_2, y+idx_m)
}
// At least 8 rows left: process 8 of them with a single accumulator
if (m - tag_m_16x > 7) {
result_0 = _mm512_setzero_ps();
matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x)*8]); // Load 4 rows with n=8
matrixArray_1 = _mm512_loadu_si512(&a[(tag_m_16x+4)*8]); // Load 4 rows with n=8
// Stage 1: interleave per 32 bits
matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
// Low half + high half gives the 8 complete row sums
__m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
tag_m_16x += 8;
}
// 4..7 rows left: the first 4 rows are loaded in full, the remaining
// tail_num-4 rows through a partial mask, and only tail_num results are stored.
BLASLONG tail_num = m - tag_m_16x;
if (tail_num > 3) {
result_0 = _mm512_setzero_ps();
matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x)*8]); // Load 4 rows with n=8
// Dword-granular mask: keeps (tail_num-4)*4 dwords, i.e. tail_num-4 rows (empty when tail_num == 4)
unsigned short tail_load_mask_value = (((unsigned int)0xffff) >> ((8-tail_num)*4));
__mmask16 tail_load_mask = *((__mmask16*) &tail_load_mask_value);
matrixArray_1 = _mm512_maskz_loadu_epi32(tail_load_mask, &a[(tag_m_16x+4)*8]); // Load the remaining tail_num-4 rows with n=8
// Stage 1: interleave per 32 bits
matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
__m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
// Low tail_num bits set: only the valid rows are written back
unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
__mmask8 tail_mask = *((__mmask8*) &tail_mask_value);
STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_16x, tail_mask)
// Everything is done; the scalar loop below is skipped
tag_m_16x = m;
}
}
// Scalar tail: one row per iteration for whatever the vector paths left over
if (tag_m_16x != m) {
__m128i matrixArray128;
__m128 result128, tmp128;
for (BLASLONG i = tag_m_16x; i < m; i++) {
result128 = _mm_setzero_ps();
matrixArray128 = _mm_loadu_si128((__m128i *)&a[(i)*8]); // Load 1 rows with n=8
// Each float lane receives the product sum of one bf16 pair
result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
// Horizontal add: fold lanes 2,3 onto lanes 0,1, then lane 1 onto lane 0
tmp128 = _mm_shuffle_ps(result128, result128, 14);
result128 = _mm_add_ps(result128, tmp128);
tmp128 = _mm_shuffle_ps(result128, result128, 1);
result128 = _mm_add_ps(result128, tmp128);
#ifndef ZERO_BETA
#ifndef ONE_BETA
y[i] = alpha * result128[0] + beta * y[i];
#else
y[i] = alpha * result128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
y[i] = result128[0] * alpha;
#else
y[i] = result128[0];
#endif
#endif
}
}
return 0;
}
// 14 rows parallel processing BF16 GEMV kernel for n=9 && lda ineffective scenario
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_14x9_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_14x9_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_14x9_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_14x9(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
BLASLONG tag_m_14x = m - (m%14);
unsigned char x_load_mask_value = (((unsigned char)0xff) >> 7);
__mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
__m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7|
__m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|0 |0 | 0| 0| 0| 0| 0|
if (tag_m_14x > 0) {
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5;
__m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3;
__m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89;
__m512 result_0, result_1;
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
__m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif
__m256i M256_EPI16_2 = _mm256_set1_epi16(2);
__m256i idx_base_0 = _mm256_set_epi16( 0, 0, 55, 54, 46, 45, 37, 36, 28, 27, 19, 18, 10, 9, 1, 0);
__m256i idx_base_1 = _mm256_add_epi16(idx_base_0, M256_EPI16_2);
__m256i idx_base_2 = _mm256_add_epi16(idx_base_1, M256_EPI16_2);
__m256i idx_base_3 = _mm256_add_epi16(idx_base_2, M256_EPI16_2);
__m256i idx_base_4 = _mm256_add_epi16(idx_base_3, M256_EPI16_2);
__m512i idx_idx = _mm512_set_epi32( 0, 0, 22, 21, 20, 19, 18, 17, 16, 6, 5, 4, 3, 2, 1, 0);
__m512i load_idx_stage1_0 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_0), idx_idx, _mm512_castsi256_si512(idx_base_1));
__m512i load_idx_stage1_1 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_2), idx_idx, _mm512_castsi256_si512(idx_base_3));
__m512i load_idx_stage1_2 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_1), idx_idx, _mm512_castsi256_si512(idx_base_0));
__m512i load_idx_stage1_3 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_3), idx_idx, _mm512_castsi256_si512(idx_base_2));
__m512i load_idx_stage1_4 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_4), idx_idx, _mm512_castsi256_si512(idx_base_4));
__m512i load_idx_stage2_0 = _mm512_set_epi32( 0, 0, 22, 21, 20, 19, 18, 17, 16, 13, 12, 11, 10, 9, 8, 7);
xArray_01 = _mm512_broadcastd_epi32(x128_0); // |x0|x1|x0|x1| ... |x0|x1|
xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1)); // |x2|x3|x2|x3| ... |x2|x3|
xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2)); // |x4|x5|x4|x5| ... |x4|x5|
xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3)); // |x6|x7|x6|x7| ... |x6|x7|
xArray_89 = _mm512_broadcastd_epi32(x128_1); // |x8|0 |x8| 0| ... |x8| 0|
unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 1);
__mmask32 load_mask = *((__mmask32*) &load_mask_value);
unsigned short blend_mask_value = ((unsigned short)0x3f80);
__mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
unsigned short store_mask_value = (((unsigned short)0xffff) >> 2);
__mmask16 store_mask = *((__mmask16*) &store_mask_value);
for (BLASLONG idx_m = 0; idx_m < tag_m_14x; idx_m+=14) {
result_0 = _mm512_setzero_ps();
result_1 = _mm512_setzero_ps();
matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*9]); // Load 3 rows with n=9 plus 5 elements
matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+3)*9 + 5]); // Load 3 rows with n=9 plus 4 elements
matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+7)*9]); // Load 3 rows with n=9 plus 5 elements
matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+10)*9 + 5]); // Load 3 rows with n=9 plus 4 elements
// Stage 1: interleave per 16 bits
matrixArray_stage_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx_stage1_0, matrixArray_1); // |a0|a1|...|g0|g1|a2|a3|...|g2|g3|x|x|x|x|
matrixArray_stage_1 = _mm512_permutex2var_epi16(matrixArray_0, load_idx_stage1_1, matrixArray_1); // |a4|a5|...|g4|g5|a6|a7|...|g6|g7|x|x|x|x|
matrixArray_stage_2 = _mm512_permutex2var_epi16(matrixArray_2, load_idx_stage1_2, matrixArray_3); // |h2|h3|...|n2|n3|h0|h1|...|n0|n1|x|x|x|x|
matrixArray_stage_3 = _mm512_permutex2var_epi16(matrixArray_2, load_idx_stage1_3, matrixArray_3); // |h6|h7|...|n6|n7|h4|h5|...|n4|n5|x|x|x|x|
matrixArray_4 = _mm512_permutex2var_epi16(matrixArray_0, load_idx_stage1_4, matrixArray_1); // |a8| x|...|g8| x| x| x|...| x| x|x|x|x|x|
matrixArray_5 = _mm512_permutex2var_epi16(matrixArray_2, load_idx_stage1_4, matrixArray_3); // | x| x|...| x| x|h8| x|...|n8| x|x|x|x|x|
// Stage 2: interleave per 32 bits
matrixArray_0 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_0, matrixArray_stage_2); // |a0|a1|b0|b1|...|h0|h1|i0|i1|j0|j1|...|n0|n1|x|x|x|x|
matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, load_idx_stage2_0, matrixArray_stage_2); // |a2|a3|b2|b3|...|h2|h3|i2|i3|j2|j3|...|n2|n3|x|x|x|x|
matrixArray_2 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_1, matrixArray_stage_3); // |a4|a5|b4|b5|...|h4|h5|i4|i5|j4|j5|...|n4|n5|x|x|x|x|
matrixArray_3 = _mm512_permutex2var_epi32(matrixArray_stage_1, load_idx_stage2_0, matrixArray_stage_3); // |a6|a7|b6|b7|...|h6|h7|i6|i7|j6|j7|...|n6|n7|x|x|x|x|
matrixArray_4 = _mm512_mask_blend_epi32(blend_mask, matrixArray_4, matrixArray_5); // |a8| x|b8| x|...|h8| x|i8| x|j8| x|...|n8| x|x|x|x|x|
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
result_0 = _mm512_add_ps(result_0, result_1);
STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
}
}
if (tag_m_14x != m) {
__m256i matrixArray256;
__m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
__m256 result256;
__m128 result128, tmp128;
unsigned short load256_mask_value = (((unsigned short)0xffff) >> 7);
__mmask16 load256_mask = *((__mmask16*) &load256_mask_value);
for (BLASLONG i = tag_m_14x; i < m; i++) {
result256 = _mm256_setzero_ps();
matrixArray256 = _mm256_maskz_loadu_epi16(load256_mask, &a[(i)*9]);
result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
tmp128 = _mm_shuffle_ps(result128, result128, 14);
result128 = _mm_add_ps(result128, tmp128);
tmp128 = _mm_shuffle_ps(result128, result128, 1);
result128 = _mm_add_ps(result128, tmp128);
#ifndef ZERO_BETA
#ifndef ONE_BETA
y[i] = alpha * result128[0] + beta * y[i];
#else
y[i] = alpha * result128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
y[i] = result128[0] * alpha;
#else
y[i] = result128[0];
#endif
#endif
}
}
return 0;
}
// BF16 GEMV kernel for n=10 with rows stored back to back (lda ineffective scenario), 12 rows per iteration.
// For every row i of a (10 bf16 values, row stride 10) the dot product with x[0..9] is accumulated in fp32
// and merged into y[i] according to the alpha/beta variant selected by ZERO_BETA / ONE_BETA / ONE_ALPHA.
// Always returns 0.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_12x10_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_12x10_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_12x10_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_12x10(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    // Rows handled by the 12-row blocked path; the remainder goes through the row-by-row path below
    const BLASLONG m_blocked = m - (m % 12);

    // x0..x7 come from one full 128-bit load, x8/x9 from a single zero-masked dword
    const __mmask8 x_tail_mask = (__mmask8) 0x01;
    __m128i x_lo = _mm_loadu_si128((__m128i *)x);                  // |x0|x1|x2|x3|x4|x5|x6|x7|
    __m128i x_hi = _mm_maskz_loadu_epi32(x_tail_mask, (x+8));      // |x8|x9|0 | 0| 0| 0| 0| 0|

    if (m_blocked > 0) {
        __m512 acc_0, acc_1;

#ifndef ONE_ALPHA
        __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
        __m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif

        // Stage-1 dword gather indices over two 3-row loads (indices >= 16 address the second source).
        // Lanes 0-5 and 6-11 each collect one column pair of six rows; lanes 12-15 are don't-care.
        const __m512i gather_01_23 = _mm512_set_epi32( 0,  0,  0,  0, 27, 22, 17, 11,  6,  1, 26, 21, 16, 10,  5,  0);
        const __m512i gather_45_67 = _mm512_set_epi32( 2,  2,  2,  2, 29, 24, 19, 13,  8,  3, 28, 23, 18, 12,  7,  2);
        const __m512i gather_23_01 = _mm512_set_epi32( 1,  1,  1,  1, 26, 21, 16, 10,  5,  0, 27, 22, 17, 11,  6,  1);
        const __m512i gather_67_45 = _mm512_set_epi32( 3,  3,  3,  3, 28, 23, 18, 12,  7,  2, 29, 24, 19, 13,  8,  3);
        const __m512i gather_89_89 = _mm512_set_epi32( 4,  4,  4,  4, 30, 25, 20, 14,  9,  4, 30, 25, 20, 14,  9,  4);

        // Stage-2 index: lanes 6-11 of the first source followed by lanes 0-5 of the second source
        const __m512i splice_idx = _mm512_set_epi32( 0,  0,  0,  0, 21, 20, 19, 18, 17, 16, 11, 10,  9,  8,  7,  6);

        // Each x pair replicated into all 16 dword lanes
        const __m512i xpair_01 = _mm512_broadcastd_epi32(x_lo);                            // |x0|x1|x0|x1| ... |x0|x1|
        const __m512i xpair_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x_lo, 0x1));    // |x2|x3|x2|x3| ... |x2|x3|
        const __m512i xpair_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x_lo, 0x2));    // |x4|x5|x4|x5| ... |x4|x5|
        const __m512i xpair_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x_lo, 0x3));    // |x6|x7|x6|x7| ... |x6|x7|
        const __m512i xpair_89 = _mm512_broadcastd_epi32(x_hi);                            // |x8|x9|x8|x9| ... |x8|x9|

        const __mmask16 upper_six_mask  = (__mmask16) 0x0fc0;   // dword lanes 6-11
        const __mmask16 three_rows_mask = (__mmask16) 0x7fff;   // 15 dwords = 3 rows of 10 bf16
        const __mmask16 y_store_mask    = (__mmask16) 0x0fff;   // 12 results per iteration

        for (BLASLONG row = 0; row < m_blocked; row += 12) {
            acc_0 = _mm512_setzero_ps();
            acc_1 = _mm512_setzero_ps();

            // Rows a-c, d-f, g-i and j-l of the current block (3 rows of n=10 per load)
            __m512i rows_abc = _mm512_maskz_loadu_epi32(three_rows_mask, &a[(row)*10]);
            __m512i rows_def = _mm512_maskz_loadu_epi32(three_rows_mask, &a[(row+3)*10]);
            __m512i rows_ghi = _mm512_maskz_loadu_epi32(three_rows_mask, &a[(row+6)*10]);
            __m512i rows_jkl = _mm512_maskz_loadu_epi32(three_rows_mask, &a[(row+9)*10]);

            // Stage 1: gather column pairs per 32 bits, six rows at a time
            __m512i mix_0 = _mm512_permutex2var_epi32(rows_abc, gather_01_23, rows_def);   // |a0|a1|...|f0|f1|a2|a3|...|f2|f3|x|x|x|x|x|x|x|x|
            __m512i mix_1 = _mm512_permutex2var_epi32(rows_abc, gather_45_67, rows_def);   // |a4|a5|...|f4|f5|a6|a7|...|f6|f7|x|x|x|x|x|x|x|x|
            __m512i mix_2 = _mm512_permutex2var_epi32(rows_ghi, gather_23_01, rows_jkl);   // |g2|g3|...|l2|l3|g0|g1|...|l0|l1|x|x|x|x|x|x|x|x|
            __m512i mix_3 = _mm512_permutex2var_epi32(rows_ghi, gather_67_45, rows_jkl);   // |g6|g7|...|l6|l7|g4|g5|...|l4|l5|x|x|x|x|x|x|x|x|
            __m512i mix_4 = _mm512_permutex2var_epi32(rows_abc, gather_89_89, rows_def);   // |a8|a9|...|f8|f9|a8|a9|...|f8|f9|x|x|x|x|x|x|x|x|
            __m512i mix_5 = _mm512_permutex2var_epi32(rows_ghi, gather_89_89, rows_jkl);   // |g8|g9|...|l8|l9|g8|g9|...|l8|l9|x|x|x|x|x|x|x|x|

            // Stage 2: join rows a-f with rows g-l so that each vector carries one column pair of all 12 rows
            __m512i cols_01 = _mm512_mask_blend_epi32(upper_six_mask, mix_0, mix_2);       // |a0|a1|...|l0|l1|x|x|x|x|x|x|x|x|
            __m512i cols_23 = _mm512_permutex2var_epi32(mix_0, splice_idx, mix_2);         // |a2|a3|...|l2|l3|x|x|x|x|x|x|x|x|
            __m512i cols_45 = _mm512_mask_blend_epi32(upper_six_mask, mix_1, mix_3);       // |a4|a5|...|l4|l5|x|x|x|x|x|x|x|x|
            __m512i cols_67 = _mm512_permutex2var_epi32(mix_1, splice_idx, mix_3);         // |a6|a7|...|l6|l7|x|x|x|x|x|x|x|x|
            __m512i cols_89 = _mm512_mask_blend_epi32(upper_six_mask, mix_4, mix_5);       // |a8|a9|...|l8|l9|x|x|x|x|x|x|x|x|

            // Two independent accumulators, summed at the end
            acc_0 = _mm512_dpbf16_ps(acc_0, (__m512bh) cols_01, (__m512bh) xpair_01);
            acc_1 = _mm512_dpbf16_ps(acc_1, (__m512bh) cols_23, (__m512bh) xpair_23);
            acc_0 = _mm512_dpbf16_ps(acc_0, (__m512bh) cols_45, (__m512bh) xpair_45);
            acc_1 = _mm512_dpbf16_ps(acc_1, (__m512bh) cols_67, (__m512bh) xpair_67);
            acc_0 = _mm512_dpbf16_ps(acc_0, (__m512bh) cols_89, (__m512bh) xpair_89);
            acc_0 = _mm512_add_ps(acc_0, acc_1);

            STORE16_MASK_COMPLETE_RESULT(acc_0, y+row, y_store_mask)
        }
    }

    if (m_blocked != m) {
        // Remaining rows, one at a time: |x0..x7|x8|x9|0...| against one zero-padded matrix row
        const __m256i x_full = _mm256_insertf128_si256(_mm256_castsi128_si256(x_lo), x_hi, 0x1);
        const __mmask8 one_row_mask = (__mmask8) 0x1f;          // 5 dwords = 10 bf16

        for (BLASLONG row = m_blocked; row < m; row++) {
            __m256i row_vals = _mm256_maskz_loadu_epi32(one_row_mask, &a[(row)*10]);
            __m256 partial = _mm256_dpbf16_ps(_mm256_setzero_ps(), (__m256bh) row_vals, (__m256bh) x_full);

            // Horizontal reduction of the 8 partial sums: 8 -> 4 -> 2 -> 1
            __m128 sum = _mm_add_ps(_mm256_castps256_ps128(partial), _mm256_extractf128_ps(partial, 0x1));
            sum = _mm_add_ps(sum, _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)));
            sum = _mm_add_ps(sum, _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 0, 1)));
            const float dot = _mm_cvtss_f32(sum);

#ifndef ZERO_BETA
#ifndef ONE_BETA
            y[row] = alpha * dot + beta * y[row];
#else
            y[row] = alpha * dot + y[row];
#endif
#else
#ifndef ONE_ALPHA
            y[row] = dot * alpha;
#else
            y[row] = dot;
#endif
#endif
        }
    }

    return 0;
}
// 15 rows parallel processing BF16 GEMV kernel for n=11 && lda ineffective scenario
// (rows of a are addressed with a stride of 11 elements).
// For every row i the dot product of a[i*11 .. i*11+10] with x[0..10] is accumulated in fp32 and merged
// into y[i] according to the alpha/beta variant selected by ZERO_BETA / ONE_BETA / ONE_ALPHA.
// Always returns 0.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_15x11_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_15x11_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_15x11_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_15x11(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    // Rows covered by complete 15-row blocks; the rest is handled row by row at the end
    BLASLONG tag_m_15x = m - (m%15);

    // Only x8..x10 exist in the second half of x, so it is loaded with a 3-element zeroing mask
    unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5);
    __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
    __m128i x128_0 = _mm_loadu_si128((__m128i *)x);                      // |x0|x1| x2|x3|x4|x5|x6|x7|
    __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8));          // |x8|x9|x10| 0| 0| 0| 0| 0|

    if (tag_m_15x > 0) {
        __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5;
        __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3, matrixArray_stage_4, matrixArray_stage_5;
        __m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89, xArray_10;
        __m512 result_0, result_1;

#ifndef ONE_ALPHA
        __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
        __m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif

        __m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5;
        __m512i idx_stage2_base_0, idx_stage2_base_1, idx_stage2_base_2, idx_stage2_base_3;

        __m512i M512_EPI16_2, M512_EPI16_4, M512_EPI16_6, M512_EPI32_5;
        M512_EPI16_2 = _mm512_set1_epi16(2);
        M512_EPI16_4 = _mm512_add_epi16(M512_EPI16_2, M512_EPI16_2);
        M512_EPI16_6 = _mm512_add_epi16(M512_EPI16_4, M512_EPI16_2);
        M512_EPI32_5 = _mm512_set1_epi32(5);

        // 16-bit lane groups 0-9, 10-19 and 20-29 (one group = one column pair of 5 rows)
        unsigned int BASE_MASK_10_value = ((unsigned int)0x000003ff);
        __mmask32 BASE_MASK_10 = *((__mmask32*) &BASE_MASK_10_value);
        unsigned int BASE_MASK_20_value = ((unsigned int)0x000ffc00);
        __mmask32 BASE_MASK_20 = *((__mmask32*) &BASE_MASK_20_value);
        unsigned int BASE_MASK_30_value = ((unsigned int)0x3ff00000);
        __mmask32 BASE_MASK_30 = *((__mmask32*) &BASE_MASK_30_value);

        // Stage-1 word indices into a 5-row tile held in two registers (indices >= 32 address the second one).
        // base_0 gathers columns (0,1)(2,3)(4,5); the other index vectors are derived from it by shifting
        // the column selection per group.
        idx_stage1_base_0 = _mm512_set_epi16( 0,  0, 49, 48, 38, 37, 27, 26, 16, 15,  5,  4, 47, 46, 36, 35,
                                             25, 24, 14, 13,  3,  2, 45, 44, 34, 33, 23, 22, 12, 11,  1,  0);
        idx_stage1_base_1 = _mm512_add_epi16(idx_stage1_base_0, M512_EPI16_6);

        idx_stage1_base_2 = _mm512_mask_add_epi16(idx_stage1_base_0, BASE_MASK_10, idx_stage1_base_0, M512_EPI16_2);
        idx_stage1_base_2 = _mm512_mask_sub_epi16(idx_stage1_base_2, BASE_MASK_20, idx_stage1_base_0, M512_EPI16_2);
        idx_stage1_base_3 = _mm512_add_epi16(idx_stage1_base_2, M512_EPI16_6);

        idx_stage1_base_4 = _mm512_mask_add_epi16(idx_stage1_base_2, BASE_MASK_10, idx_stage1_base_2, M512_EPI16_2);
        idx_stage1_base_4 = _mm512_mask_add_epi16(idx_stage1_base_4, BASE_MASK_20, idx_stage1_base_2, M512_EPI16_2);
        idx_stage1_base_4 = _mm512_mask_sub_epi16(idx_stage1_base_4, BASE_MASK_30, idx_stage1_base_2, M512_EPI16_4);
        idx_stage1_base_5 = _mm512_add_epi16(idx_stage1_base_4, M512_EPI16_6);

        // 32-bit lanes 5-9 (second group of 5 rows) and 10-14 (third group of 5 rows)
        unsigned short idx_stage2_mask_1_value = ((unsigned short)0x03e0);
        __mmask16 idx_stage2_mask_1 = *((__mmask16*) &idx_stage2_mask_1_value);
        unsigned short idx_stage2_mask_2_value = ((unsigned short)0x7c00);
        __mmask16 idx_stage2_mask_2 = *((__mmask16*) &idx_stage2_mask_2_value);
        idx_stage2_base_0 = _mm512_set_epi32( 0,  0,  0,  0,  0,  0, 20, 19, 18, 17, 16,  9,  8,  7,  6,  5);
        idx_stage2_base_1 = _mm512_set_epi32( 0, 25, 24, 23, 22, 21,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0);
        idx_stage2_base_2 = _mm512_add_epi32(idx_stage2_base_0, M512_EPI32_5);
        idx_stage2_base_2 = _mm512_mask_add_epi32(idx_stage2_base_2, idx_stage2_mask_1, idx_stage2_base_2, M512_EPI32_5);
        idx_stage2_base_3 = _mm512_mask_sub_epi32(idx_stage2_base_1, idx_stage2_mask_2, idx_stage2_base_1, M512_EPI32_5);

        xArray_01 = _mm512_broadcastd_epi32(x128_0);                                // |x0 |x1 |x0 |x1 | ... |x0 |x1 |
        xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1));        // |x2 |x3 |x2 |x3 | ... |x2 |x3 |
        xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2));        // |x4 |x5 |x4 |x5 | ... |x4 |x5 |
        xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3));        // |x6 |x7 |x6 |x7 | ... |x6 |x7 |
        xArray_89 = _mm512_broadcastd_epi32(x128_1);                                // |x8 |x9 |x8 |x9 | ... |x8 |x9 |
        xArray_10 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_1, 0x1));        // |x10|0  |x10|0  | ... |x10|0  |

        // 32 + 23 elements = 5 rows of n=11
        unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 9);
        __mmask32 load_mask = *((__mmask32*) &load_mask_value);

        // 15 results are produced per iteration
        unsigned short store_mask_value = (((unsigned short)0xffff) >> 1);
        __mmask16 store_mask = *((__mmask16*) &store_mask_value);

        // Keeps the even 16-bit element of every 32-bit lane and zeroes the odd one
        __mmask32 even_word_mask = (__mmask32) 0x55555555;

        for (BLASLONG idx_m = 0; idx_m < tag_m_15x; idx_m+=15) {
            result_0 = _mm512_setzero_ps();
            result_1 = _mm512_setzero_ps();

            matrixArray_0 = _mm512_loadu_si512(&a[idx_m*11]);                                   // Load 2 rows with n=11 plus 10 elements
            matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[idx_m*11 + 32]);             // Load 2 rows with n=11 plus 1 element
            matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+5)*11]);                               // Load 2 rows with n=11 plus 10 elements
            matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+5)*11 + 32]);         // Load 2 rows with n=11 plus 1 element
            matrixArray_4 = _mm512_loadu_si512(&a[(idx_m+10)*11]);                              // Load 2 rows with n=11 plus 10 elements
            matrixArray_5 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+10)*11 + 32]);        // Load 2 rows with n=11 plus 1 element

            // Stage 1: interleave per 16 bits
            matrixArray_stage_0 = _mm512_permutex2var_epi16(matrixArray_0, idx_stage1_base_0, matrixArray_1);    // |a0|a1|...|e0|e1|a2|a3|...|e2|e3|a4 |a5|...|e4 |e5|
            matrixArray_stage_1 = _mm512_permutex2var_epi16(matrixArray_0, idx_stage1_base_1, matrixArray_1);    // |a6|a7|...|e6|e7|a8|a9|...|e8|e9|a10|x |...|e10|x |
            matrixArray_stage_2 = _mm512_permutex2var_epi16(matrixArray_2, idx_stage1_base_2, matrixArray_3);    // |f2|f3|...|j2|j3|f0|f1|...|j0|j1|f4 |f5|...|j4 |j5|
            matrixArray_stage_3 = _mm512_permutex2var_epi16(matrixArray_2, idx_stage1_base_3, matrixArray_3);    // |f8|f9|...|j8|j9|f6|f7|...|j6|j7|f10|x |...|j10|x |
            matrixArray_stage_4 = _mm512_permutex2var_epi16(matrixArray_4, idx_stage1_base_4, matrixArray_5);    // |k4|k5|...|o4|o5|k2|k3|...|o2|o3|k0 |k1|...|o0 |o1|
            matrixArray_stage_5 = _mm512_permutex2var_epi16(matrixArray_4, idx_stage1_base_5, matrixArray_5);    // |k10|x|...|o10|x|k8|k9|...|o8|o9|k6 |k7|...|o6 |o7|

            // Stage 2: interleave per 32 bits
            matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_0, matrixArray_stage_2);        // |a0|a1|...|j0|j1|x|x|x|x|x|x|x|x|x|x|x|x|
            matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_1, matrixArray_stage_3);        // |a6|a7|...|j6|j7|x|x|x|x|x|x|x|x|x|x|x|x|
            matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_0, matrixArray_stage_2);      // |a2|a3|...|j2|j3|x|x|x|x|x|x|x|x|x|x|x|x|
            matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_2, matrixArray_stage_2);      // |a4|a5|...|j4|j5|x|x|x|x|x|x|x|x|x|x|x|x|
            matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_0, matrixArray_stage_3);      // |a8|a9|...|j8|j9|x|x|x|x|x|x|x|x|x|x|x|x|
            matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_2, matrixArray_stage_3);      // |a10|x|...|j10|x|x|x|x|x|x|x|x|x|x|x|x|x|

            matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_0, matrixArray_stage_4);              // |a0|a1|.......................|o0|o1|x|x|
            matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_3, matrixArray_stage_5);              // |a6|a7|.......................|o6|o7|x|x|
            matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_1  , idx_stage2_base_1, matrixArray_stage_4);          // |a2|a3|.......................|o2|o3|x|x|
            matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_2  , idx_stage2_base_3, matrixArray_stage_4);          // |a4|a5|.......................|o4|o5|x|x|
            matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_4  , idx_stage2_base_1, matrixArray_stage_5);          // |a8|a9|.......................|o8|o9|x|x|
            matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_5  , idx_stage2_base_3, matrixArray_stage_5);          // |a10|x|.......................|o10|x|x|x|

            // The "x" next to column 10 is the first element of the following matrix row. It is multiplied
            // by the zero padding of x, but Inf*0 or NaN*0 would still produce NaN and leak into this row's
            // result, so zero it explicitly (the row-by-row path below gets the same effect from its masked load).
            matrixArray_5 = _mm512_maskz_mov_epi16(even_word_mask, matrixArray_5);                                       // |a10|0|.......................|o10|0|x|0|

            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_5, (__m512bh) xArray_10);
            result_0 = _mm512_add_ps(result_0, result_1);

            STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
        }
    }

    if (tag_m_15x != m) {
        // Remaining rows, one at a time against the zero-padded x vector
        __m256i matrixArray256;
        __m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
        __m256 result256;
        __m128 result128, tmp128;

        // 11 valid elements per row; everything past them is loaded as zero
        unsigned short load256_mask_value = (((unsigned short)0xffff) >> 5);
        __mmask16 load256_mask = *((__mmask16*) &load256_mask_value);

        for (BLASLONG i = tag_m_15x; i < m; i++) {
            result256 = _mm256_setzero_ps();
            matrixArray256 = _mm256_maskz_loadu_epi16(load256_mask, &a[(i)*11]);
            result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
            // Horizontal sum of the 8 partial results: 8 -> 4 -> 2 -> 1
            result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
            tmp128 = _mm_shuffle_ps(result128, result128, 14);
            result128 = _mm_add_ps(result128, tmp128);
            tmp128 = _mm_shuffle_ps(result128, result128, 1);
            result128 = _mm_add_ps(result128, tmp128);
#ifndef ZERO_BETA
#ifndef ONE_BETA
            y[i] = alpha * result128[0] + beta * y[i];
#else
            y[i] = alpha * result128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
            y[i] = result128[0] * alpha;
#else
            y[i] = result128[0];
#endif
#endif
        }
    }

    return 0;
}
// BF16 GEMV kernel for n=12 with rows stored back to back (lda ineffective scenario), 15 rows per iteration.
// For every row i of a (12 bf16 values, row stride 12) the dot product with x[0..11] is accumulated in fp32
// and merged into y[i] according to the alpha/beta variant selected by ZERO_BETA / ONE_BETA / ONE_ALPHA.
// Always returns 0.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_15x12_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_15x12_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_15x12_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_15x12(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    // Rows handled by the 15-row blocked path; the remainder goes through the row-by-row path below
    const BLASLONG m_blocked = m - (m % 15);

    // x0..x7 come from one full 128-bit load, x8..x11 from a zero-masked load of four elements
    const __mmask8 x_tail_mask = (__mmask8) 0x0f;
    __m128i x_lo = _mm_loadu_si128((__m128i *)x);                  // |x0|x1| x2| x3|x4|x5|x6|x7|
    __m128i x_hi = _mm_maskz_loadu_epi16(x_tail_mask, (x+8));      // |x8|x9|x10|x11| 0| 0| 0| 0|

    if (m_blocked > 0) {
        __m512 acc_0, acc_1;

#ifndef ONE_ALPHA
        __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
        __m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif

        // Stage-1 dword gather indices over a 5-row tile held in two registers (row stride = 6 dwords,
        // indices >= 16 address the second register). Every group of five lanes collects one column pair
        // of the five rows; the group order differs per tile so stage 2 can merge with few permutes.
        // Lane 15 is a don't-care lane.
        const __m512i gather_a_lo = _mm512_set_epi32( 0, 26, 20, 14,  8,  2, 25, 19, 13,  7,  1, 24, 18, 12,  6,  0);  // cols 01|23|45
        const __m512i gather_a_hi = _mm512_set_epi32( 3, 29, 23, 17, 11,  5, 28, 22, 16, 10,  4, 27, 21, 15,  9,  3);  // cols 67|89|10,11
        const __m512i gather_f_lo = _mm512_set_epi32( 0, 26, 20, 14,  8,  2, 24, 18, 12,  6,  0, 25, 19, 13,  7,  1);  // cols 23|01|45
        const __m512i gather_f_hi = _mm512_set_epi32( 3, 29, 23, 17, 11,  5, 27, 21, 15,  9,  3, 28, 22, 16, 10,  4);  // cols 89|67|10,11
        const __m512i gather_k_lo = _mm512_set_epi32(-2, 24, 18, 12,  6,  0, 25, 19, 13,  7,  1, 26, 20, 14,  8,  2);  // cols 45|23|01
        const __m512i gather_k_hi = _mm512_set_epi32( 1, 27, 21, 15,  9,  3, 28, 22, 16, 10,  4, 29, 23, 17, 11,  5);  // cols 10,11|89|67

        // Stage-2 merge controls
        const __mmask16 second_five_mask = (__mmask16) 0x03e0;   // dword lanes 5-9
        const __mmask16 third_five_mask  = (__mmask16) 0x7c00;   // dword lanes 10-14
        // first source lanes 5-9, then second source lanes 0-4
        const __m512i take_mid_then_lo  = _mm512_set_epi32( 0,  0,  0,  0,  0,  0, 20, 19, 18, 17, 16,  9,  8,  7,  6,  5);
        // first source lanes 10-14, then second source lanes 10-14
        const __m512i take_hi_then_hi   = _mm512_set_epi32( 5,  5,  5,  5,  5,  5, 30, 29, 28, 27, 26, 14, 13, 12, 11, 10);
        // keep first source lanes 0-9, append second source lanes 5-9
        const __m512i keep_then_mid     = _mm512_set_epi32( 0, 25, 24, 23, 22, 21,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0);
        // keep first source lanes 0-9, append second source lanes 0-4
        const __m512i keep_then_lo      = _mm512_set_epi32( 0, 20, 19, 18, 17, 16,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0);

        // Each x pair replicated into all 16 dword lanes
        const __m512i xpair_01   = _mm512_broadcastd_epi32(x_lo);                            // |x0 |x1 |x0 |x1 | ... |x0 |x1 |
        const __m512i xpair_23   = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x_lo, 0x1));    // |x2 |x3 |x2 |x3 | ... |x2 |x3 |
        const __m512i xpair_45   = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x_lo, 0x2));    // |x4 |x5 |x4 |x5 | ... |x4 |x5 |
        const __m512i xpair_67   = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x_lo, 0x3));    // |x6 |x7 |x6 |x7 | ... |x6 |x7 |
        const __m512i xpair_89   = _mm512_broadcastd_epi32(x_hi);                            // |x8 |x9 |x8 |x9 | ... |x8 |x9 |
        const __m512i xpair_1011 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x_hi, 0x1));    // |x10|x11|x10|x11| ... |x10|x11|

        const __mmask32 tile_rest_mask = (__mmask32) 0x0fffffff;  // 32 + 28 elements = 5 rows of n=12
        const __mmask16 y_store_mask   = (__mmask16) 0x7fff;      // 15 results per iteration

        for (BLASLONG row = 0; row < m_blocked; row += 15) {
            acc_0 = _mm512_setzero_ps();
            acc_1 = _mm512_setzero_ps();

            // Three tiles of five rows each (a-e, f-j, k-o), every tile split over two registers
            __m512i tile_a_0 = _mm512_loadu_si512(&a[row*12]);                                    // 2 rows with n=12 plus 8 elements
            __m512i tile_a_1 = _mm512_maskz_loadu_epi16(tile_rest_mask, &a[row*12 + 32]);         // 2 rows with n=12 plus 4 elements
            __m512i tile_f_0 = _mm512_loadu_si512(&a[(row+5)*12]);                                // 2 rows with n=12 plus 8 elements
            __m512i tile_f_1 = _mm512_maskz_loadu_epi16(tile_rest_mask, &a[(row+5)*12 + 32]);     // 2 rows with n=12 plus 4 elements
            __m512i tile_k_0 = _mm512_loadu_si512(&a[(row+10)*12]);                               // 2 rows with n=12 plus 8 elements
            __m512i tile_k_1 = _mm512_maskz_loadu_epi16(tile_rest_mask, &a[(row+10)*12 + 32]);    // 2 rows with n=12 plus 4 elements

            // Stage 1: gather column pairs per 32 bits within each tile
            __m512i mix_0 = _mm512_permutex2var_epi32(tile_a_0, gather_a_lo, tile_a_1);   // |a0 |a1 |...|e0 |e1 |a2|a3|...|e2|e3|a4 |a5 |...|e4 |e5 |
            __m512i mix_1 = _mm512_permutex2var_epi32(tile_a_0, gather_a_hi, tile_a_1);   // |a6 |a7 |...|e6 |e7 |a8|a9|...|e8|e9|a10|a11|...|e10|e11|
            __m512i mix_2 = _mm512_permutex2var_epi32(tile_f_0, gather_f_lo, tile_f_1);   // |f2 |f3 |...|j2 |j3 |f0|f1|...|j0|j1|f4 |f5 |...|j4 |j5 |
            __m512i mix_3 = _mm512_permutex2var_epi32(tile_f_0, gather_f_hi, tile_f_1);   // |f8 |f9 |...|j8 |j9 |f6|f7|...|j6|j7|f10|f11|...|j10|j11|
            __m512i mix_4 = _mm512_permutex2var_epi32(tile_k_0, gather_k_lo, tile_k_1);   // |k4 |k5 |...|o4 |o5 |k2|k3|...|o2|o3|k0 |k1 |...|o0 |o1 |
            __m512i mix_5 = _mm512_permutex2var_epi32(tile_k_0, gather_k_hi, tile_k_1);   // |k10|k11|...|o10|o11|k8|k9|...|o8|o9|k6 |k7 |...|o6 |o7 |

            // Stage 2a: merge rows a-e with rows f-j per column pair
            __m512i cols_01   = _mm512_mask_blend_epi32(second_five_mask, mix_0, mix_2);      // |a0 |a1 |...|j0 |j1 |x|x|x|x|x|x|x|x|x|x|x|x|
            __m512i cols_67   = _mm512_mask_blend_epi32(second_five_mask, mix_1, mix_3);      // |a6 |a7 |...|j6 |j7 |x|x|x|x|x|x|x|x|x|x|x|x|
            __m512i cols_23   = _mm512_permutex2var_epi32(mix_0, take_mid_then_lo, mix_2);    // |a2 |a3 |...|j2 |j3 |x|x|x|x|x|x|x|x|x|x|x|x|
            __m512i cols_45   = _mm512_permutex2var_epi32(mix_0, take_hi_then_hi, mix_2);     // |a4 |a5 |...|j4 |j5 |x|x|x|x|x|x|x|x|x|x|x|x|
            __m512i cols_89   = _mm512_permutex2var_epi32(mix_1, take_mid_then_lo, mix_3);    // |a8 |a9 |...|j8 |j9 |x|x|x|x|x|x|x|x|x|x|x|x|
            __m512i cols_1011 = _mm512_permutex2var_epi32(mix_1, take_hi_then_hi, mix_3);     // |a10|a11|...|j10|j11|x|x|x|x|x|x|x|x|x|x|x|x|

            // Stage 2b: append rows k-o
            cols_01   = _mm512_mask_blend_epi32(third_five_mask, cols_01, mix_4);             // |a0 |a1 |.....................|o0 |o1 |x|x|
            cols_67   = _mm512_mask_blend_epi32(third_five_mask, cols_67, mix_5);             // |a6 |a7 |.....................|o6 |o7 |x|x|
            cols_23   = _mm512_permutex2var_epi32(cols_23, keep_then_mid, mix_4);             // |a2 |a3 |.....................|o2 |o3 |x|x|
            cols_45   = _mm512_permutex2var_epi32(cols_45, keep_then_lo, mix_4);              // |a4 |a5 |.....................|o4 |o5 |x|x|
            cols_89   = _mm512_permutex2var_epi32(cols_89, keep_then_mid, mix_5);             // |a8 |a9 |.....................|o8 |o9 |x|x|
            cols_1011 = _mm512_permutex2var_epi32(cols_1011, keep_then_lo, mix_5);            // |a10|a11|.....................|o10|o11|x|x|

            // Two independent accumulators, summed at the end
            acc_0 = _mm512_dpbf16_ps(acc_0, (__m512bh) cols_01,   (__m512bh) xpair_01);
            acc_1 = _mm512_dpbf16_ps(acc_1, (__m512bh) cols_23,   (__m512bh) xpair_23);
            acc_0 = _mm512_dpbf16_ps(acc_0, (__m512bh) cols_45,   (__m512bh) xpair_45);
            acc_1 = _mm512_dpbf16_ps(acc_1, (__m512bh) cols_67,   (__m512bh) xpair_67);
            acc_0 = _mm512_dpbf16_ps(acc_0, (__m512bh) cols_89,   (__m512bh) xpair_89);
            acc_1 = _mm512_dpbf16_ps(acc_1, (__m512bh) cols_1011, (__m512bh) xpair_1011);
            acc_0 = _mm512_add_ps(acc_0, acc_1);

            STORE16_MASK_COMPLETE_RESULT(acc_0, y+row, y_store_mask)
        }
    }

    if (m_blocked != m) {
        // Remaining rows, one at a time: |x0..x7|x8..x11|0...| against one zero-padded matrix row
        const __m256i x_full = _mm256_insertf128_si256(_mm256_castsi128_si256(x_lo), x_hi, 0x1);
        const __mmask16 one_row_mask = (__mmask16) 0x0fff;      // 12 bf16 elements

        for (BLASLONG row = m_blocked; row < m; row++) {
            __m256i row_vals = _mm256_maskz_loadu_epi16(one_row_mask, &a[(row)*12]);
            __m256 partial = _mm256_dpbf16_ps(_mm256_setzero_ps(), (__m256bh) row_vals, (__m256bh) x_full);

            // Horizontal reduction of the 8 partial sums: 8 -> 4 -> 2 -> 1
            __m128 sum = _mm_add_ps(_mm256_castps256_ps128(partial), _mm256_extractf128_ps(partial, 0x1));
            sum = _mm_add_ps(sum, _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)));
            sum = _mm_add_ps(sum, _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 0, 1)));
            const float dot = _mm_cvtss_f32(sum);

#ifndef ZERO_BETA
#ifndef ONE_BETA
            y[row] = alpha * dot + beta * y[row];
#else
            y[row] = alpha * dot + y[row];
#endif
#else
#ifndef ONE_ALPHA
            y[row] = dot * alpha;
#else
            y[row] = dot;
#endif
#endif
        }
    }

    return 0;
}
// 16 rows parallel processing BF16 GEMV kernel for n=13 && lda ineffective scenario
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_16x13_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_16x13_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_16x13_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_16x13(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
BLASLONG tag_m_16x = m & (~15);
unsigned short x_load_mask_value = (((unsigned short)0xffff) >> 3);
__mmask16 x_load_mask = *((__mmask16*) &x_load_mask_value);
__m256i x256 = _mm256_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|0|0|0|
if (tag_m_16x > 0) {
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
__m512i xArray_0, xArray_1, xArray_2, xArray_3;
__m512 accum512_0, accum512_1;
__m512 result_0, result_1;
__m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
__m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif
__m512i M512_EPI32_4 = _mm512_set1_epi32(4);
__m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
__m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
// Prepare X with 2-step interleave way
xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
BF16_INTERLEAVE_1x32(xArray)
for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
accum512_0 = _mm512_setzero_ps();
accum512_1 = _mm512_setzero_ps();
// Load matrix
BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 13, idx_m, 0, x_load_mask)
matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 13, idx_m+8, 0, x_load_mask)
matrixArray_12 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
matrixArray_13 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
matrixArray_14 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
matrixArray_15 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
// interleave per 256 bits
BF16_INTERLEAVE256_8x32(matrixArray)
// 2-step interleave for matrix
BF16_INTERLEAVE_8x32(matrixArray)
// Calculate the temp result for a..p[0:15]
BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
// Reorder and add up the final result
result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
result_0 = _mm512_add_ps(result_0, result_1);
STORE16_COMPLETE_RESULT(result_0, y+idx_m)
}
if (m - tag_m_16x > 7) {
__m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
accum512_0 = _mm512_setzero_ps();
accum512_1 = _mm512_setzero_ps();
// Load matrix
BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 13, tag_m_16x, 0, x_load_mask)
matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
// interleave per 256 bits
matrixArray_0 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0x44);
matrixArray_1 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0xee);
matrixArray_2 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0x44);
matrixArray_3 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0xee);
// 2-step interleave for matrix
BF16_INTERLEAVE_4x32(matrixArray)
// Calculate the temp result for a..h[0:15]
BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
__m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
tag_m_16x += 8;
}
if (m - tag_m_16x > 3) {
__m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
__m256 accum256_0, accum256_1;
xArray256_0 = _mm512_castsi512_si256(xArray_0);
xArray256_1 = _mm512_castsi512_si256(xArray_1);
xArray256_2 = _mm512_castsi512_si256(xArray_2);
xArray256_3 = _mm512_castsi512_si256(xArray_3);
accum256_0 = _mm256_setzero_ps();
accum256_1 = _mm256_setzero_ps();
BF16_MATRIX_MASKZ_LOAD_4x16(matrixArray256, a, 13, tag_m_16x, 0, x_load_mask)
// 2-step interleave for matrix
BF16_INTERLEAVE_4x16(matrixArray256)
// Calculate the temp result for a..d[0:15]
BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
__m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
tag_m_16x += 4;
}
}
if (tag_m_16x != m) {
__m256i matrixArray256;
__m256 accum256;
__m128 accum128, tmp128;
for (BLASLONG i = tag_m_16x; i < m; i++) {
accum256 = _mm256_setzero_ps();
matrixArray256 = _mm256_maskz_loadu_epi16(x_load_mask, &a[(i)*13]); // Load 1 rows with n=13
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
accum128 = _mm_add_ps(accum128, tmp128);
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
accum128 = _mm_add_ps(accum128, tmp128);
#ifndef ZERO_BETA
#ifndef ONE_BETA
y[i] = alpha * accum128[0] + beta * y[i];
#else
y[i] = alpha * accum128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
y[i] = accum128[0] * alpha;
#else
y[i] = accum128[0];
#endif
#endif
}
}
return 0;
}
// 16 rows parallel processing BF16 GEMV kernel for n=14 && lda ineffective scenario
//
// For each row i in [0, m) the kernel forms the dot product of the 14 bf16 values
// starting at a[i*14] with the 14 bf16 values of x, and writes the result to y[i].
// Rows are read with a fixed stride of 14 elements (no lda parameter).
//
// The function name and parameter list are selected at preprocessing time:
//   - neither ZERO_BETA nor ONE_BETA : y[i] = alpha * dot + beta * y[i]
//   - ONE_BETA (ZERO_BETA undefined) : y[i] = alpha * dot + y[i]
//   - ZERO_BETA, ONE_ALPHA undefined : y[i] = dot * alpha
//   - ZERO_BETA and ONE_ALPHA        : y[i] = dot
// These formulas are what the scalar tail loop below does explicitly. The vector
// paths store through the STORE{16,8,4}_COMPLETE_RESULT macros, which are defined
// outside this block (presumably implementing the same combination -- verify there).
//
// Parameters:
//   m     - number of matrix rows / output elements
//   alpha - scalar applied to the dot product (see variants above)
//   a     - bf16 matrix data, row i starting at a[i*14]
//   x     - bf16 input vector, 14 elements are read
//   beta  - scalar applied to the previous y (only in the non-ZERO_BETA variants)
//   y     - float output vector, m elements
// Returns 0 unconditionally.
//
// Note: the 8-row and 4-row remainder paths live inside the `tag_m_16x > 0` branch,
// so for m < 16 every row is handled by the one-row-at-a-time loop at the end.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_16x14_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_16x14_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_16x14_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_16x14(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
// m rounded down to a multiple of 16; later advanced by 8 and 4 as remainder paths run
BLASLONG tag_m_16x = m & (~15);
// 0x3fff: the low 14 of 16 mask bits are set, so masked 16-lane loads fetch exactly 14 bf16 values
unsigned short x_load_mask_value = (((unsigned short)0xffff) >> 2);
__mmask16 x_load_mask = *((__mmask16*) &x_load_mask_value);
__m256i x256 = _mm256_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|0|0|
if (tag_m_16x > 0) {
// The numbered register names below are referenced by the BF16_* macros through
// their common prefix (matrixArray, xArray, accum512), so the names must not change.
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
__m512i xArray_0, xArray_1, xArray_2, xArray_3;
__m512 accum512_0, accum512_1;
__m512 result_0, result_1;
#ifndef ONE_ALPHA
// Not referenced directly in this block; presumably consumed by the STORE* macros
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
// Not referenced directly in this block; presumably consumed by the STORE* macros
__m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif
__m512i M512_EPI32_4 = _mm512_set1_epi32(4);
// Two-source float permute indices (0-15 pick from the first source, 16-31 from the second).
// idx_base_0 gathers floats 0-3 and 8-11 of both accumulators; idx_base_1 (= idx_base_0 + 4)
// gathers floats 4-7 and 12-15 of both. Adding the two permuted vectors folds the partial sums.
__m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
__m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
// 32-bit lane indices: lanes 0-6 keep dwords 0-6, lanes 8-14 take dwords 7-13,
// lanes 7 and 15 take dword 0. Passed to the BF16_PERMUTE_*_2 macros below.
__m512i shift_idx = _mm512_set_epi32(0, 13, 12, 11, 10, 9, 8, 7, 0, 6, 5, 4, 3, 2, 1, 0);
// 0x0fffffff: low 28 of 32 mask bits set (28 = 2 * 14, presumably two packed rows per 512-bit load)
unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 4);
__mmask32 load_mask = *((__mmask32*) &load_mask_value);
// Prepare X with 2-step interleave way
// x256 is duplicated into both 256-bit halves of xArray_0 before interleaving
xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
BF16_INTERLEAVE_1x32(xArray)
// Main loop: 16 rows per iteration
for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
accum512_0 = _mm512_setzero_ps();
accum512_1 = _mm512_setzero_ps();
// Load matrix
BF16_MATRIX_MASKZ_LOAD_8x32_2(matrixArray, a, 14, idx_m, 0, load_mask)
// Pre-stage: shift the 2nd vector 1 position right for each register
BF16_PERMUTE_8x32_2(shift_idx, matrixArray)
// interleave per 256 bits
BF16_INTERLEAVE256_8x32(matrixArray)
// 2-step interleave for matrix
BF16_INTERLEAVE_8x32(matrixArray)
// Calculate the temp result for a..p[0:15]
BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
// Reorder and add up the final result
result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
result_0 = _mm512_add_ps(result_0, result_1);
STORE16_COMPLETE_RESULT(result_0, y+idx_m)
}
// Remainder: one group of 8 rows, if at least 8 rows are left
if (m - tag_m_16x > 7) {
// Swaps the two middle groups of four floats (4-7 <-> 8-11) before the halves are added
__m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
accum512_0 = _mm512_setzero_ps();
accum512_1 = _mm512_setzero_ps();
// Load matrix
BF16_MATRIX_MASKZ_LOAD_4x32_2(matrixArray, a, 14, tag_m_16x, 0, load_mask)
// Pre-stage: shift the 2nd vector 1 position right for each register
BF16_PERMUTE_4x32_2(shift_idx, matrixArray)
// interleave per 256 bits
BF16_INTERLEAVE256_4x32(matrixArray)
// 2-step interleave for matrix
BF16_INTERLEAVE_4x32(matrixArray)
// Calculate the temp result for a..h[0:15]
BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
// Fold the upper 256 bits onto the lower 256 bits: 8 floats remain
__m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
tag_m_16x += 8;
}
// Remainder: one group of 4 rows, if at least 4 rows are left
if (m - tag_m_16x > 3) {
__m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
__m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
__m256 accum256_0, accum256_1;
// Reuse the low 256 bits of the already interleaved x registers
xArray256_0 = _mm512_castsi512_si256(xArray_0);
xArray256_1 = _mm512_castsi512_si256(xArray_1);
xArray256_2 = _mm512_castsi512_si256(xArray_2);
xArray256_3 = _mm512_castsi512_si256(xArray_3);
accum256_0 = _mm256_setzero_ps();
accum256_1 = _mm256_setzero_ps();
BF16_MATRIX_MASKZ_LOAD_4x16(matrixArray256, a, 14, tag_m_16x, 0, x_load_mask)
// 2-step interleave for matrix
BF16_INTERLEAVE_4x16(matrixArray256)
// Calculate the temp result for a..d[0:15]
BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
// Fold the upper 128 bits onto the lower 128 bits: 4 floats remain
__m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
tag_m_16x += 4;
}
}
// Remaining rows (fewer than 4, or all rows when m < 16): one row per iteration
if (tag_m_16x != m) {
__m256i matrixArray256;
__m256 accum256;
__m128 accum128, tmp128;
for (BLASLONG i = tag_m_16x; i < m; i++) {
accum256 = _mm256_setzero_ps();
matrixArray256 = _mm256_maskz_loadu_epi16(x_load_mask, &a[(i)*14]); // Load 1 rows with n=14
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
// Horizontal sum of the 8 partial sums: 8 -> 4 -> 2 -> 1, total ends up in element 0
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
accum128 = _mm_add_ps(accum128, tmp128);
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
accum128 = _mm_add_ps(accum128, tmp128);
// accum128[0] uses the compiler's vector subscript extension to read element 0
#ifndef ZERO_BETA
#ifndef ONE_BETA
y[i] = alpha * accum128[0] + beta * y[i];
#else
y[i] = alpha * accum128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
y[i] = accum128[0] * alpha;
#else
y[i] = accum128[0];
#endif
#endif
}
}
return 0;
}
// 16 rows parallel processing BF16 GEMV kernel for n=15 && lda ineffective scenario
//
// For each row i in [0, m) the kernel forms the dot product of the 15 bf16 values
// starting at a[i*15] with the 15 bf16 values of x, and writes the result to y[i].
// Rows are read with a fixed stride of 15 elements (no lda parameter).
//
// The function name and parameter list are selected at preprocessing time:
//   - neither ZERO_BETA nor ONE_BETA : y[i] = alpha * dot + beta * y[i]
//   - ONE_BETA (ZERO_BETA undefined) : y[i] = alpha * dot + y[i]
//   - ZERO_BETA, ONE_ALPHA undefined : y[i] = dot * alpha
//   - ZERO_BETA and ONE_ALPHA        : y[i] = dot
// These formulas are what the scalar tail loop below does explicitly. The vector
// paths store through the STORE{16,8,4}_COMPLETE_RESULT macros, which are defined
// outside this block (presumably implementing the same combination -- verify there).
//
// Parameters:
//   m     - number of matrix rows / output elements
//   alpha - scalar applied to the dot product (see variants above)
//   a     - bf16 matrix data, row i starting at a[i*15]
//   x     - bf16 input vector, 15 elements are read
//   beta  - scalar applied to the previous y (only in the non-ZERO_BETA variants)
//   y     - float output vector, m elements
// Returns 0 unconditionally.
//
// Note: the 8-row and 4-row remainder paths live inside the `tag_m_16x > 0` branch,
// so for m < 16 every row is handled by the one-row-at-a-time loop at the end.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_16x15_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_16x15_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_16x15_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_16x15(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
// m rounded down to a multiple of 16; later advanced by 8 and 4 as remainder paths run
BLASLONG tag_m_16x = m & (~15);
// 0x7fff: the low 15 of 16 mask bits are set, so masked 16-lane loads fetch exactly 15 bf16 values
unsigned short x_load_mask_value = (((unsigned short)0xffff) >> 1);
__mmask16 x_load_mask = *((__mmask16*) &x_load_mask_value);
__m256i x256 = _mm256_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|0|
if (tag_m_16x > 0) {
// The numbered register names below are referenced by the BF16_* macros through
// their common prefix (matrixArray, matrixArray256, xArray, accum512), so the names must not change.
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
__m512i xArray_0, xArray_1, xArray_2, xArray_3;
__m512 accum512_0, accum512_1;
__m512 result_0, result_1;
__m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
#ifndef ONE_ALPHA
// Not referenced directly in this block; presumably consumed by the STORE* macros
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
// Not referenced directly in this block; presumably consumed by the STORE* macros
__m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif
__m512i M512_EPI32_4 = _mm512_set1_epi32(4);
// Two-source float permute indices (0-15 pick from the first source, 16-31 from the second).
// idx_base_0 gathers floats 0-3 and 8-11 of both accumulators; idx_base_1 (= idx_base_0 + 4)
// gathers floats 4-7 and 12-15 of both. Adding the two permuted vectors folds the partial sums.
__m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
__m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
// Prepare X with 2-step interleave way
// x256 is duplicated into both 256-bit halves of xArray_0 before interleaving
xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
BF16_INTERLEAVE_1x32(xArray)
// Main loop: 16 rows per iteration, fetched as two batches of eight 256-bit masked loads
for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
accum512_0 = _mm512_setzero_ps();
accum512_1 = _mm512_setzero_ps();
// Load matrix
BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 15, idx_m, 0, x_load_mask)
// Pack pairs of 256-bit registers into one 512-bit register (even one low, odd one high)
matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
// Second batch, starting 8 rows further on
BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 15, idx_m+8, 0, x_load_mask)
matrixArray_12 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
matrixArray_13 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
matrixArray_14 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
matrixArray_15 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
// interleave per 256 bits
BF16_INTERLEAVE256_8x32(matrixArray)
// 2-step interleave for matrix
BF16_INTERLEAVE_8x32(matrixArray)
// Calculate the temp result for a..p[0:15]
BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
// Reorder and add up the final result
result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
result_0 = _mm512_add_ps(result_0, result_1);
STORE16_COMPLETE_RESULT(result_0, y+idx_m)
}
// Remainder: one group of 8 rows, if at least 8 rows are left
if (m - tag_m_16x > 7) {
// Swaps the two middle groups of four floats (4-7 <-> 8-11) before the halves are added
__m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
accum512_0 = _mm512_setzero_ps();
accum512_1 = _mm512_setzero_ps();
// Load matrix
BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 15, tag_m_16x, 0, x_load_mask)
matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
// interleave per 256 bits
// 0x44 combines the low 256-bit halves of the two sources, 0xee combines their high halves
matrixArray_0 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0x44);
matrixArray_1 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0xee);
matrixArray_2 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0x44);
matrixArray_3 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0xee);
// 2-step interleave for matrix
BF16_INTERLEAVE_4x32(matrixArray)
// Calculate the temp result for a..h[0:15]
BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
// Fold the upper 256 bits onto the lower 256 bits: 8 floats remain
__m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
tag_m_16x += 8;
}
// Remainder: one group of 4 rows, if at least 4 rows are left
if (m - tag_m_16x > 3) {
__m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
__m256 accum256_0, accum256_1;
// Reuse the low 256 bits of the already interleaved x registers
xArray256_0 = _mm512_castsi512_si256(xArray_0);
xArray256_1 = _mm512_castsi512_si256(xArray_1);
xArray256_2 = _mm512_castsi512_si256(xArray_2);
xArray256_3 = _mm512_castsi512_si256(xArray_3);
accum256_0 = _mm256_setzero_ps();
accum256_1 = _mm256_setzero_ps();
BF16_MATRIX_MASKZ_LOAD_4x16(matrixArray256, a, 15, tag_m_16x, 0, x_load_mask)
// 2-step interleave for matrix
BF16_INTERLEAVE_4x16(matrixArray256)
// Calculate the temp result for a..d[0:15]
BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
// Fold the upper 128 bits onto the lower 128 bits: 4 floats remain
__m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
tag_m_16x += 4;
}
}
// Remaining rows (fewer than 4, or all rows when m < 16): one row per iteration
if (tag_m_16x != m) {
__m256i matrixArray256;
__m256 accum256;
__m128 accum128, tmp128;
for (BLASLONG i = tag_m_16x; i < m; i++) {
accum256 = _mm256_setzero_ps();
matrixArray256 = _mm256_maskz_loadu_epi16(x_load_mask, &a[(i)*15]); // Load 1 rows with n=15
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
// Horizontal sum of the 8 partial sums: 8 -> 4 -> 2 -> 1, total ends up in element 0
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
accum128 = _mm_add_ps(accum128, tmp128);
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
accum128 = _mm_add_ps(accum128, tmp128);
// accum128[0] uses the compiler's vector subscript extension to read element 0
#ifndef ZERO_BETA
#ifndef ONE_BETA
y[i] = alpha * accum128[0] + beta * y[i];
#else
y[i] = alpha * accum128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
y[i] = accum128[0] * alpha;
#else
y[i] = accum128[0];
#endif
#endif
}
}
return 0;
}
// 16 rows parallel processing BF16 GEMV kernel for n=16 && lda ineffective scenario
//
// For each row i in [0, m) the kernel forms the dot product of the 16 bf16 values
// starting at a[i*16] with the 16 bf16 values of x, and writes the result to y[i].
// Rows are read with a fixed stride of 16 elements (no lda parameter), so a row is
// exactly 256 bits and all loads are unmasked.
//
// The function name and parameter list are selected at preprocessing time:
//   - neither ZERO_BETA nor ONE_BETA : y[i] = alpha * dot + beta * y[i]
//   - ONE_BETA (ZERO_BETA undefined) : y[i] = alpha * dot + y[i]
//   - ZERO_BETA, ONE_ALPHA undefined : y[i] = dot * alpha
//   - ZERO_BETA and ONE_ALPHA        : y[i] = dot
// These formulas are what the scalar tail loop below does explicitly. The vector
// paths store through the STORE{16,8,4}_COMPLETE_RESULT macros, which are defined
// outside this block (presumably implementing the same combination -- verify there).
//
// Parameters:
//   m     - number of matrix rows / output elements
//   alpha - scalar applied to the dot product (see variants above)
//   a     - bf16 matrix data, row i starting at a[i*16]
//   x     - bf16 input vector, 16 elements are read
//   beta  - scalar applied to the previous y (only in the non-ZERO_BETA variants)
//   y     - float output vector, m elements
// Returns 0 unconditionally.
//
// Note: the 8-row and 4-row remainder paths live inside the `tag_m_16x > 0` branch,
// so for m < 16 every row is handled by the one-row-at-a-time loop at the end.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_16x16_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_16x16_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_16x16_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
// m rounded down to a multiple of 16; later advanced by 8 and 4 as remainder paths run
BLASLONG tag_m_16x = m & (~15);
__m256i x256 = _mm256_loadu_si256((__m256i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15|
if (tag_m_16x > 0) {
// The numbered register names below are referenced by the BF16_* macros through
// their common prefix (matrixArray, xArray, accum512), so the names must not change.
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
__m512i xArray_0, xArray_1, xArray_2, xArray_3;
__m512 accum512_0, accum512_1;
__m512 result_0, result_1;
#ifndef ONE_ALPHA
// Not referenced directly in this block; presumably consumed by the STORE* macros
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
// Not referenced directly in this block; presumably consumed by the STORE* macros
__m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif
__m512i M512_EPI32_4 = _mm512_set1_epi32(4);
// Two-source float permute indices (0-15 pick from the first source, 16-31 from the second).
// idx_base_0 gathers floats 0-3 and 8-11 of both accumulators; idx_base_1 (= idx_base_0 + 4)
// gathers floats 4-7 and 12-15 of both. Adding the two permuted vectors folds the partial sums.
__m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
__m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
// Prepare X with 2-step interleave way
// x256 is duplicated into both 256-bit halves of xArray_0 before interleaving
xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
BF16_INTERLEAVE_1x32(xArray)
// Main loop: 16 rows per iteration, each 512-bit load covering two consecutive rows
for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
accum512_0 = _mm512_setzero_ps();
accum512_1 = _mm512_setzero_ps();
matrixArray_8 = _mm512_loadu_si512(&a[(idx_m )*16]); // Load 2 rows with n=16
matrixArray_9 = _mm512_loadu_si512(&a[(idx_m+2 )*16]); // Load 2 rows with n=16
matrixArray_10 = _mm512_loadu_si512(&a[(idx_m+4 )*16]); // Load 2 rows with n=16
matrixArray_11 = _mm512_loadu_si512(&a[(idx_m+6 )*16]); // Load 2 rows with n=16
matrixArray_12 = _mm512_loadu_si512(&a[(idx_m+8 )*16]); // Load 2 rows with n=16
matrixArray_13 = _mm512_loadu_si512(&a[(idx_m+10)*16]); // Load 2 rows with n=16
matrixArray_14 = _mm512_loadu_si512(&a[(idx_m+12)*16]); // Load 2 rows with n=16
matrixArray_15 = _mm512_loadu_si512(&a[(idx_m+14)*16]); // Load 2 rows with n=16
// interleave per 256 bits
BF16_INTERLEAVE256_8x32(matrixArray)
// 2-step interleave for matrix
BF16_INTERLEAVE_8x32(matrixArray)
// Calculate the temp result for a..p[0:15]
BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
// Reorder and add up the final result
result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
result_0 = _mm512_add_ps(result_0, result_1);
STORE16_COMPLETE_RESULT(result_0, y+idx_m)
}
// Remainder: one group of 8 rows, if at least 8 rows are left
if (m - tag_m_16x > 7) {
// Swaps the two middle groups of four floats (4-7 <-> 8-11) before the halves are added
__m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
accum512_0 = _mm512_setzero_ps();
accum512_1 = _mm512_setzero_ps();
// Loaded into registers 4..7 here (the 16-row path used 8..15); presumably the
// slots BF16_INTERLEAVE256_4x32 reads from -- confirm against the macro definition
matrixArray_4 = _mm512_loadu_si512(&a[(tag_m_16x )*16]); // Load 2 rows with n=16
matrixArray_5 = _mm512_loadu_si512(&a[(tag_m_16x+2 )*16]); // Load 2 rows with n=16
matrixArray_6 = _mm512_loadu_si512(&a[(tag_m_16x+4 )*16]); // Load 2 rows with n=16
matrixArray_7 = _mm512_loadu_si512(&a[(tag_m_16x+6 )*16]); // Load 2 rows with n=16
// interleave per 256 bits
BF16_INTERLEAVE256_4x32(matrixArray)
// 2-step interleave for matrix
BF16_INTERLEAVE_4x32(matrixArray)
// Calculate the temp result for a..h[0:15]
BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
// Fold the upper 256 bits onto the lower 256 bits: 8 floats remain
__m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
tag_m_16x += 8;
}
// Remainder: one group of 4 rows, if at least 4 rows are left
if (m - tag_m_16x > 3) {
__m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, \
matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
__m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
__m256 accum256_0, accum256_1;
// Reuse the low 256 bits of the already interleaved x registers
xArray256_0 = _mm512_castsi512_si256(xArray_0);
xArray256_1 = _mm512_castsi512_si256(xArray_1);
xArray256_2 = _mm512_castsi512_si256(xArray_2);
xArray256_3 = _mm512_castsi512_si256(xArray_3);
accum256_0 = _mm256_setzero_ps();
accum256_1 = _mm256_setzero_ps();
matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x )*16]); // Load 2 rows with n=16
matrixArray_1 = _mm512_loadu_si512(&a[(tag_m_16x+2 )*16]); // Load 2 rows with n=16
// Split each two-row 512-bit register into its two 256-bit rows
matrixArray256_0 = _mm512_castsi512_si256(matrixArray_0);
matrixArray256_1 = _mm512_extracti32x8_epi32(matrixArray_0, 0x1);
matrixArray256_2 = _mm512_castsi512_si256(matrixArray_1);
matrixArray256_3 = _mm512_extracti32x8_epi32(matrixArray_1, 0x1);
// 2-step interleave for matrix
BF16_INTERLEAVE_4x16(matrixArray256)
// Calculate the temp result for a..d[0:15]
BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
// Fold the upper 128 bits onto the lower 128 bits: 4 floats remain
__m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
tag_m_16x += 4;
}
}
// Remaining rows (fewer than 4, or all rows when m < 16): one row per iteration
if (tag_m_16x != m) {
__m256i matrixArray256;
__m256 accum256;
__m128 accum128, tmp128;
for (BLASLONG i = tag_m_16x; i < m; i++) {
accum256 = _mm256_setzero_ps();
matrixArray256 = _mm256_loadu_si256((__m256i *)&a[(i)*16]); // Load 1 rows with n=16
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
// Horizontal sum of the 8 partial sums: 8 -> 4 -> 2 -> 1, total ends up in element 0
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
accum128 = _mm_add_ps(accum128, tmp128);
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
accum128 = _mm_add_ps(accum128, tmp128);
// accum128[0] uses the compiler's vector subscript extension to read element 0
#ifndef ZERO_BETA
#ifndef ONE_BETA
y[i] = alpha * accum128[0] + beta * y[i];
#else
y[i] = alpha * accum128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
y[i] = accum128[0] * alpha;
#else
y[i] = accum128[0];
#endif
#endif
}
}
return 0;
}
// 8 rows parallel processing BF16 GEMV kernel for n>16 && lda effective scenario
//
// For each row i in [0, m) the kernel forms the dot product of the first n bf16
// values starting at a[i*lda] with the first n bf16 values of x, and writes the
// result to y[i]. Every row and x are fetched with a single 32-lane masked load,
// so at most 32 columns can be handled.
//
// NOTE(review): the mask is built with a shift by (32-n), which is only defined
// for 1 <= n <= 32; the caller is presumably expected to guarantee 16 < n <= 32
// (per the title above) -- confirm at the dispatch site.
//
// The function name and parameter list are selected at preprocessing time:
//   - neither ZERO_BETA nor ONE_BETA : y[i] = alpha * dot + beta * y[i]
//   - ONE_BETA (ZERO_BETA undefined) : y[i] = alpha * dot + y[i]
//   - ZERO_BETA, ONE_ALPHA undefined : y[i] = dot * alpha
//   - ZERO_BETA and ONE_ALPHA        : y[i] = dot
// These formulas are what the scalar tail loop below does explicitly. The vector
// paths store through the STORE{8,4}_COMPLETE_RESULT macros, which are defined
// outside this block (presumably implementing the same combination -- verify there).
//
// Parameters:
//   m     - number of matrix rows / output elements
//   n     - number of columns (elements per row that take part in the dot product)
//   alpha - scalar applied to the dot product (see variants above)
//   a     - bf16 matrix data, row i starting at a[i*lda]
//   lda   - distance in elements between the starts of consecutive rows
//   x     - bf16 input vector, n elements are read
//   beta  - scalar applied to the previous y (only in the non-ZERO_BETA variants)
//   y     - float output vector, m elements
// Returns 0 unconditionally.
//
// Note: the 4-row remainder path lives inside the `tag_m_8x > 0` branch, so for
// m < 8 every row is handled by the one-row-at-a-time loop at the end.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_8x16p_lda_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_8x16p_lda_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_8x16p_lda_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_8x16p_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
#endif
#endif
{
// m rounded down to a multiple of 8; later advanced by 4 if the 4-row remainder path runs
BLASLONG tag_m_8x = m & (~7);
// Low n bits set: selects the first n of 32 bf16 lanes
unsigned int load_mask_value = (((unsigned int)0xffffffff) >> (32-n));
__mmask32 load_mask = *((__mmask32*) &load_mask_value);
__m512i x512 = _mm512_maskz_loadu_epi16(load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15|...
#ifndef ONE_ALPHA
// Not referenced directly in this block; presumably consumed by the STORE* macros
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
// Not referenced directly in this block; presumably consumed by the STORE* macros
__m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif
// The numbered register names below are referenced by the BF16_* macros through
// their common prefix (matrixArray, xArray, accum512), so the names must not change.
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
__m512 accum512_0, accum512_1, accum512_2, accum512_3;
__m256 accum256;
__m128 accum128;
if (tag_m_8x > 0) {
__m512i xArray_0, xArray_1, xArray_2, xArray_3;
__m512i M512_EPI32_4 = _mm512_set1_epi32(4);
// Two-source float permute indices (0-15 pick from the first source, 16-31 from the second).
// idx_base_0 gathers floats 0-3 and 8-11 of both accumulators; idx_base_1 (= idx_base_0 + 4)
// gathers floats 4-7 and 12-15 of both. Adding the two permuted vectors folds the partial sums.
__m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
__m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
// Prepare X with 2-step interleave way
xArray_0 = x512;
BF16_INTERLEAVE_1x32(xArray)
// Main loop: 8 rows per iteration
for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
accum512_0 = _mm512_setzero_ps();
accum512_1 = _mm512_setzero_ps();
// Load 8 rows from matrix
BF16_MATRIX_MASKZ_LOAD_8x32(matrixArray, a, lda, idx_m, 0, load_mask)
// 2-step interleave for matrix
BF16_INTERLEAVE_8x32(matrixArray)
// Calculate the temp result for a..h[0:31]
BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
// Reorder and add up the final result
accum512_2 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
accum512_3 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
accum512_2 = _mm512_add_ps(accum512_2, accum512_3);
// Fold the upper 256 bits onto the lower 256 bits: 8 floats remain
accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_2), _mm512_extractf32x8_ps(accum512_2, 1));
STORE8_COMPLETE_RESULT(accum256, y+idx_m)
}
// Remainder: one group of 4 rows, if at least 4 rows are left
if (m - tag_m_8x > 3) {
accum512_0 = _mm512_setzero_ps();
accum512_1 = _mm512_setzero_ps();
// Load 4 rows from matrix
BF16_MATRIX_MASKZ_LOAD_4x32(matrixArray, a, lda, tag_m_8x, 0, load_mask)
// 2-step interleave for matrix
BF16_INTERLEAVE_4x32(matrixArray)
// Calculate the temp result for a..d[0:31]
BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
// Fold 16 -> 8 -> 4 floats by adding upper halves onto lower halves
accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
STORE4_COMPLETE_RESULT(accum128, y+tag_m_8x)
tag_m_8x += 4;
}
}
// Remaining rows (fewer than 4, or all rows when m < 8): one row per iteration
if (tag_m_8x != m) {
__m128 tmp128;
for (BLASLONG i = tag_m_8x; i < m; i++) {
accum512_0 = _mm512_setzero_ps();
matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(i)*lda]); // Load 1 row: first n elements (32-lane masked load), row stride lda
accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) matrixArray_0, (__m512bh) x512);
// Horizontal sum of the 16 partial sums: 16 -> 8 -> 4 -> 2 -> 1, total ends up in element 0
accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
accum128 = _mm_add_ps(accum128, tmp128);
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
accum128 = _mm_add_ps(accum128, tmp128);
// accum128[0] uses the compiler's vector subscript extension to read element 0
#ifndef ZERO_BETA
#ifndef ONE_BETA
y[i] = alpha * accum128[0] + beta * y[i];
#else
y[i] = alpha * accum128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
y[i] = accum128[0] * alpha;
#else
y[i] = accum128[0];
#endif
#endif
}
}
return 0;
}
// 8 rows parallel processing BF16 GEMV kernel for big N && lda effective scenario (process before interleave)
//
// For each row j in [0, m) the kernel forms the dot product of the n bf16 values
// starting at a[j*lda] with the n bf16 values of x, and writes the result to y[j].
// Each row is walked in three stages: chunks of 128 columns (four independent
// accumulators), then chunks of 32 columns, then one masked chunk for the last
// n%32 columns. Rows are processed in groups of 8 so that the eight per-row
// accumulators can be reduced and stored together; leftover rows (m%8) are
// reduced and stored one at a time.
//
// The function name and parameter list are selected at preprocessing time:
//   - neither ZERO_BETA nor ONE_BETA : y[j] = alpha * dot + beta * y[j]
//   - ONE_BETA (ZERO_BETA undefined) : y[j] = alpha * dot + y[j]
//   - ZERO_BETA, ONE_ALPHA undefined : y[j] = dot * alpha
//   - ZERO_BETA and ONE_ALPHA        : y[j] = dot
// These formulas are what the scalar tail loop below does explicitly. The 8-row
// path stores through STORE8_COMPLETE_RESULT, which is defined outside this block
// (presumably implementing the same combination -- verify there).
//
// Parameters:
//   m     - number of matrix rows / output elements
//   n     - number of columns
//   alpha - scalar applied to the dot product (see variants above)
//   a     - bf16 matrix data, row j starting at a[j*lda]
//   lda   - distance in elements between the starts of consecutive rows
//   x     - bf16 input vector, n elements are read
//   beta  - scalar applied to the previous y (only in the non-ZERO_BETA variants)
//   y     - float output vector, m elements
// Returns 0 unconditionally.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_1x128_lda_direct_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_1x128_lda_direct_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_1x128_lda_direct_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_1x128_lda_direct(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
#endif
#endif
{
// Row/column counts rounded down to the block sizes used below
BLASLONG tag_m_8x = m & (~7);
BLASLONG tag_n_32x = n & (~31);
BLASLONG tag_n_128x = n & (~127);
// One accumulator per row of the current 8-row group; the FP32_* macros below
// index this array by name, so it must keep this name and size.
__m512 accum512_bridge[8];
__m512 accum512_t_0, accum512_t_1, accum512_t_2, accum512_t_3;
__m256 accum256_0;
__m128 accum128;
#ifndef ONE_ALPHA
// Not referenced directly in this block; presumably consumed by the STORE* macros
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
// Not referenced directly in this block; presumably consumed by the STORE* macros
__m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
__m512i xArray_0, xArray_1, xArray_2, xArray_3;
// Low (n % 32) bits set: selects the trailing columns of a partial 32-column chunk.
// Built as (1 << k) - 1 rather than 0xffffffff >> (32 - k): the latter shifts a
// 32-bit value by 32 when n is a multiple of 32, which is undefined behaviour in C.
// For k in 1..31 both forms give the same mask; for k == 0 the mask is 0 and unused.
__mmask32 tail_mask = (__mmask32)((1U << (n & 31)) - 1U);
__m512i M512_EPI32_4 = _mm512_set1_epi32(4);
// Two-source float permute indices (0-15 pick from the first source, 16-31 from the second).
// idx_base_0 gathers floats 0-3 and 8-11 of both sources; idx_base_1 (= idx_base_0 + 4)
// gathers floats 4-7 and 12-15 of both. Adding the two permuted vectors folds the partial sums.
__m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
__m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
if (tag_m_8x > 0) {
for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
// Row index kept in BLASLONG (was int) so it cannot be narrowed relative to m
for (BLASLONG j = idx_m; j < idx_m + 8; j++) {
accum512_t_0 = _mm512_setzero_ps();
accum512_t_1 = _mm512_setzero_ps();
accum512_t_2 = _mm512_setzero_ps();
accum512_t_3 = _mm512_setzero_ps();
/* Processing the main chunk with 128-elements per round */
for (BLASLONG idx_n = 0; idx_n < tag_n_128x; idx_n += 128) {
BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n + 0)
BF16_MATRIX_LOAD_1x32(matrixArray_1, a, lda, j, idx_n + 32)
BF16_MATRIX_LOAD_1x32(matrixArray_2, a, lda, j, idx_n + 64)
BF16_MATRIX_LOAD_1x32(matrixArray_3, a, lda, j, idx_n + 96)
BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n + 0)
BF16_VECTOR_LOAD_1x32(xArray_1, x, idx_n + 32)
BF16_VECTOR_LOAD_1x32(xArray_2, x, idx_n + 64)
BF16_VECTOR_LOAD_1x32(xArray_3, x, idx_n + 96)
BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
BF16_DOT_1x32(accum512_t_1, matrixArray_1, xArray_1)
BF16_DOT_1x32(accum512_t_2, matrixArray_2, xArray_2)
BF16_DOT_1x32(accum512_t_3, matrixArray_3, xArray_3)
}
/* Processing the remaining <128 chunk with 32-elements per round */
for (BLASLONG idx_n = tag_n_128x; idx_n < tag_n_32x; idx_n += 32) {
BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n)
BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
}
/* Processing the remaining <32 chunk with masked 32-elements processing */
if ((n&31) != 0) {
BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_0, a, lda, j, tag_n_32x, tail_mask)
BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
BF16_DOT_1x32(accum512_t_2, matrixArray_0, xArray_0)
}
/* Accumulate the 4 registers into 1 register */
accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_1);
accum512_t_2 = _mm512_add_ps(accum512_t_2, accum512_t_3);
accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_2);
// Temporarily save this row's accumulator until the whole 8-row group is ready
accum512_bridge[j-idx_m] = accum512_t_0;
}
// Reduce the eight per-row accumulators. After the two macros the code below
// reads entries [0] and [4] as the remaining partial sums and reuses [1]/[2]
// as scratch (macro internals are defined elsewhere -- see bf16_common_macros.h).
FP32_INTERLEAVE_8x16_ARRAY(accum512_bridge)
FP32_ACCUM2_8x16_ARRAY(accum512_bridge)
accum512_bridge[1] = _mm512_permutex2var_ps(accum512_bridge[0], idx_base_0, accum512_bridge[4]);
accum512_bridge[2] = _mm512_permutex2var_ps(accum512_bridge[0], idx_base_1, accum512_bridge[4]);
accum512_bridge[1] = _mm512_add_ps(accum512_bridge[1], accum512_bridge[2]);
// Fold the upper 256 bits onto the lower 256 bits: 8 floats remain
accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_bridge[1]), _mm512_extractf32x8_ps(accum512_bridge[1], 1));
STORE8_COMPLETE_RESULT(accum256_0, y+idx_m)
}
}
// Remaining rows (m % 8): same per-row accumulation, reduced and stored one row at a time
if (tag_m_8x != m) {
__m128 tmp128;
for (BLASLONG j = tag_m_8x; j < m; j++) {
accum512_t_0 = _mm512_setzero_ps();
accum512_t_1 = _mm512_setzero_ps();
accum512_t_2 = _mm512_setzero_ps();
accum512_t_3 = _mm512_setzero_ps();
/* Processing the main chunk with 128-elements per round */
for (BLASLONG idx_n = 0; idx_n < tag_n_128x; idx_n += 128) {
BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n + 0)
BF16_MATRIX_LOAD_1x32(matrixArray_1, a, lda, j, idx_n + 32)
BF16_MATRIX_LOAD_1x32(matrixArray_2, a, lda, j, idx_n + 64)
BF16_MATRIX_LOAD_1x32(matrixArray_3, a, lda, j, idx_n + 96)
BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n + 0)
BF16_VECTOR_LOAD_1x32(xArray_1, x, idx_n + 32)
BF16_VECTOR_LOAD_1x32(xArray_2, x, idx_n + 64)
BF16_VECTOR_LOAD_1x32(xArray_3, x, idx_n + 96)
BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
BF16_DOT_1x32(accum512_t_1, matrixArray_1, xArray_1)
BF16_DOT_1x32(accum512_t_2, matrixArray_2, xArray_2)
BF16_DOT_1x32(accum512_t_3, matrixArray_3, xArray_3)
}
/* Processing the remaining <128 chunk with 32-elements per round */
for (BLASLONG idx_n = tag_n_128x; idx_n < tag_n_32x; idx_n += 32) {
BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n)
BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
}
/* Processing the remaining <32 chunk with masked 32-elements processing */
if ((n&31) != 0) {
BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_0, a, lda, j, tag_n_32x, tail_mask)
BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
BF16_DOT_1x32(accum512_t_2, matrixArray_0, xArray_0)
}
/* Accumulate the 4 registers into 1 register */
accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_1);
accum512_t_2 = _mm512_add_ps(accum512_t_2, accum512_t_3);
accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_2);
// Horizontal sum of the 16 partial sums: 16 -> 8 -> 4 -> 2 -> 1, total ends up in element 0
accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_t_0), _mm512_extractf32x8_ps(accum512_t_0, 1));
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
accum128 = _mm_add_ps(accum128, tmp128);
tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
accum128 = _mm_add_ps(accum128, tmp128);
// accum128[0] uses the compiler's vector subscript extension to read element 0
#ifndef ZERO_BETA
#ifndef ONE_BETA
y[j] = alpha * accum128[0] + beta * y[j];
#else
y[j] = alpha * accum128[0] + y[j];
#endif
#else
#ifndef ONE_ALPHA
y[j] = accum128[0] * alpha;
#else
y[j] = accum128[0];
#endif
#endif
}
}
return 0;
}
// 8 rows parallel processing BF16 GEMV kernel for the lda effective scenario (process before interleave).
//
// One dot product over n columns is accumulated per row index in [0, m) and merged into y
// according to the compile-time variant. The scalar leftover-row path spells the formulas out:
//   default               : y[i] = alpha * dot + beta * y[i]
//   ONE_BETA              : y[i] = alpha * dot + y[i]
//   ZERO_BETA             : y[i] = alpha * dot
//   ZERO_BETA + ONE_ALPHA : y[i] = dot
// The 8-row path stores through STORE8_COMPLETE_RESULT, which is defined outside this function
// (presumably applying the same formula via ALPHAVECTOR / BETAVECTOR -- confirm at its definition).
//
// Columns are consumed 32 at a time; the trailing (n & 31) columns use masked zeroing loads.
// Always returns 0.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_8x32_lda_direct_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_8x32_lda_direct_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_8x32_lda_direct_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_8x32_lda_direct(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
#endif
#endif
{
    BLASLONG tag_m_8x  = m & (~7);     // Largest multiple of 8 not exceeding m
    BLASLONG tag_n_32x = n & (~31);    // Largest multiple of 32 not exceeding n

    // Registers 8..15 are kept declared for the interleave/accumulate macros used below
    __m512 accum512_0, accum512_1, accum512_2,  accum512_3,  accum512_4,  accum512_5,  accum512_6,  accum512_7,
           accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15;
    __m256 accum256_0;
    __m128 accum128;

#ifndef ONE_ALPHA
    __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
    __m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
#endif

    __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
    __m512i xArray_0;

    // Lane mask with the low (n & 31) bits set, selecting the valid elements of the final partial chunk.
    // The shift is evaluated only when a tail exists: without a tail the shift count would be 32,
    // which is undefined behavior for a 32-bit operand. The mask is never read in that case.
    unsigned int tail_n = (unsigned int)(n & 31);
    unsigned int tail_mask_value = (tail_n != 0) ? (((unsigned int)0xffffffff) >> (32 - tail_n)) : 0;
    __mmask32 tail_mask = (__mmask32) tail_mask_value;

    if (tag_m_8x > 0) {
        // Indices for the two-source permutes: values >= 16 select from the second source.
        // idx_base_0 gathers elements 0-3 and 8-11 of each source, idx_base_1 elements 4-7 and 12-15.
        __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
        __m512i idx_base_0   = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
        __m512i idx_base_1   = _mm512_add_epi32(idx_base_0, M512_EPI32_4);

        for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
            accum512_0 = _mm512_setzero_ps();
            accum512_1 = _mm512_setzero_ps();
            accum512_2 = _mm512_setzero_ps();
            accum512_3 = _mm512_setzero_ps();
            accum512_4 = _mm512_setzero_ps();
            accum512_5 = _mm512_setzero_ps();
            accum512_6 = _mm512_setzero_ps();
            accum512_7 = _mm512_setzero_ps();

            for (BLASLONG idx_n = 0; idx_n < tag_n_32x; idx_n+=32) {
                // Load 8 rows from matrix
                BF16_MATRIX_LOAD_8x32(matrixArray, a, lda, idx_m, idx_n)

                // Load x
                BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)

                // Calculate the temp result for a..h[0:31]
                BF16_DOT_8x32(accum512, matrixArray, xArray_0)
            }

            if (tag_n_32x != n) {  // Go with masked 512
                // Load 8 rows from matrix
                BF16_MATRIX_MASKZ_LOAD_8x32(matrixArray, a, lda, idx_m, tag_n_32x, tail_mask)

                // Load x
                BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)

                // Calculate the temp result for a..h[0:31]
                BF16_DOT_8x32(accum512, matrixArray, xArray_0)
            }

            // 2-step interleave for FP32 register array
            FP32_INTERLEAVE_8x16(accum512)

            // Accumulate the 2 batch of registers into 2 register (0 and 4)
            FP32_ACCUM2_8x16(accum512)

            // Combine registers 0 and 4, then fold the upper 256 bits onto the lower 256 bits
            accum512_1 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_4);
            accum512_2 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_4);
            accum512_1 = _mm512_add_ps(accum512_1, accum512_2);
            accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_1), _mm512_extractf32x8_ps(accum512_1, 1));
            STORE8_COMPLETE_RESULT(accum256_0, y+idx_m)
        }
    }

    // Leftover rows (m not a multiple of 8) are handled one row at a time
    if (tag_m_8x != m) {
        __m128 tmp128;
        for (BLASLONG i = tag_m_8x; i < m; i++) {
            accum512_0 = _mm512_setzero_ps();

            for (BLASLONG idx_n = 0; idx_n < tag_n_32x; idx_n+=32) {
                // Load 32 elements from matrix
                BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, i, idx_n)

                // Load 32 elements from x
                BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)

                // Calculate and accumulate the temp result
                BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
            }

            if (tag_n_32x != n) {
                // Load tail elements from matrix
                BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_0, a, lda, i, tag_n_32x, tail_mask)

                // Load tail elements from x
                BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)

                // Calculate and accumulate the temp result
                BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
            }

            // Horizontal sum: 512 -> 256 -> 128 bits, then fold elements 2,3 onto 0,1 and 1 onto 0
            accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
            accum128   = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
            tmp128     = _mm_shuffle_ps(accum128, accum128, 0x0e);
            accum128   = _mm_add_ps(accum128, tmp128);
            tmp128     = _mm_shuffle_ps(accum128, accum128, 0x01);
            accum128   = _mm_add_ps(accum128, tmp128);
#ifndef ZERO_BETA
#ifndef ONE_BETA
            y[i] = alpha * accum128[0] + beta * y[i];
#else
            y[i] = alpha * accum128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
            y[i] = accum128[0] * alpha;
#else
            y[i] = accum128[0];
#endif
#endif
        }
    }

    return 0;
}
// 8 rows parallel processing BF16 GEMV kernel for n<=16 && lda effective scenario
//
// The whole x vector fits in one 256-bit register (16 BF16 values) and is loaded once up front.
// Row i of the matrix starts at a[i*lda]. One dot product is produced per row index in [0, m)
// and merged into y according to the compile-time variant (formulas as used in the leftover-row paths):
//   default               : y[i] = alpha * dot + beta * y[i]
//   ONE_BETA              : y[i] = alpha * dot + y[i]
//   ZERO_BETA             : y[i] = alpha * dot
//   ZERO_BETA + ONE_ALPHA : y[i] = dot
// The 8-row path stores through STORE8_COMPLETE_RESULT, which is defined outside this function.
//
// n == 16 uses full loads, any other n uses masked zeroing loads.
// NOTE(review): the mask computation assumes 0 <= n <= 16 -- confirm against the caller.
// Always returns 0.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_8x16m_lda_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_8x16m_lda_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_8x16m_lda_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_8x16m_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
#endif
#endif
{
    BLASLONG tag_m_8x = m & (~7);      // Largest multiple of 8 not exceeding m

    __m256i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
    __m256i xArray256;

    // Kept as 512-bit vectors to stay aligned with the other kernels and the macro definitions;
    // the high 256 bits are never used
#ifndef ONE_ALPHA
    __m512 ALPHAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(alpha));
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
    __m512 BETAVECTOR  = _mm512_castps256_ps512(_mm256_set1_ps(beta));
#endif
#endif

    // Registers 8..15 are kept declared for the interleave/accumulate macros used below
    __m256 accum256_0, accum256_1, accum256_2,  accum256_3,  accum256_4,  accum256_5,  accum256_6,  accum256_7,
           accum256_8, accum256_9, accum256_10, accum256_11, accum256_12, accum256_13, accum256_14, accum256_15;

    // Indices for the two-source permutes: values >= 8 select from the second source.
    // idx_base_0 gathers elements 0-3 of each source, idx_base_1 elements 4-7.
    __m256i M256_EPI32_4 = _mm256_set1_epi32(4);
    __m256i idx_base_0   = _mm256_set_epi32(11, 10, 9, 8, 3, 2, 1, 0);
    __m256i idx_base_1   = _mm256_add_epi32(idx_base_0, M256_EPI32_4);

    // Lane mask with the low n bits set (all 16 bits when n == 16)
    unsigned short load_mask_value = (((unsigned short)0xffff) >> (16-n));
    __mmask16 load_mask = *((__mmask16*) &load_mask_value);

    // x is loop-invariant: load it once
    if (n == 16) {
        BF16_VECTOR_LOAD_1x16(xArray256, x, 0)
    } else {
        BF16_VECTOR_MASKZ_LOAD_1x16(xArray256, x, 0, load_mask)
    }

    if (n == 16) {
        for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
            accum256_0 = _mm256_setzero_ps();
            accum256_1 = _mm256_setzero_ps();
            accum256_2 = _mm256_setzero_ps();
            accum256_3 = _mm256_setzero_ps();
            accum256_4 = _mm256_setzero_ps();
            accum256_5 = _mm256_setzero_ps();
            accum256_6 = _mm256_setzero_ps();
            accum256_7 = _mm256_setzero_ps();

            BF16_MATRIX_LOAD_8x16(matrixArray, a, lda, idx_m, 0)

            BF16_DOT_8x16(accum256, matrixArray, xArray256)

            // 2-step interleave for FP32 register array
            FP32_INTERLEAVE_8x8(accum256)

            // Accumulate the 2 batch of registers into 2 register (0 and 4)
            FP32_ACCUM2_8x8(accum256)

            accum256_1 = _mm256_permutex2var_ps(accum256_0, idx_base_0, accum256_4);
            accum256_2 = _mm256_permutex2var_ps(accum256_0, idx_base_1, accum256_4);
            accum256_1 = _mm256_add_ps(accum256_1, accum256_2);

            STORE8_COMPLETE_RESULT(accum256_1, y+idx_m)
        }

        // Leftover rows (m not a multiple of 8) are handled one row at a time
        if (tag_m_8x != m) {
            __m128 accum128, tmp128;
            for (BLASLONG i = tag_m_8x; i < m; i++) {
                accum256_0 = _mm256_setzero_ps();
                matrixArray_0 = _mm256_loadu_si256((__m256i *)&a[(i)*lda]);    // Load 1 row with n=16
                accum256_0 = _mm256_dpbf16_ps(accum256_0, (__m256bh) matrixArray_0, (__m256bh) xArray256);

                // Horizontal sum: 256 -> 128 bits, then fold elements 2,3 onto 0,1 and 1 onto 0
                accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
                tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
                accum128 = _mm_add_ps(accum128, tmp128);
                tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
                accum128 = _mm_add_ps(accum128, tmp128);

                // Apply the same alpha/beta variant as the n != 16 path and the sibling kernels.
                // A plain "y[i] += alpha * dot" would ignore beta and, for ZERO_BETA, would add
                // to whatever y[i] held on entry instead of overwriting it.
#ifndef ZERO_BETA
#ifndef ONE_BETA
                y[i] = alpha * accum128[0] + beta * y[i];
#else
                y[i] = alpha * accum128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
                y[i] = accum128[0] * alpha;
#else
                y[i] = accum128[0];
#endif
#endif
            }
        }
    } else {
        for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
            accum256_0 = _mm256_setzero_ps();
            accum256_1 = _mm256_setzero_ps();
            accum256_2 = _mm256_setzero_ps();
            accum256_3 = _mm256_setzero_ps();
            accum256_4 = _mm256_setzero_ps();
            accum256_5 = _mm256_setzero_ps();
            accum256_6 = _mm256_setzero_ps();
            accum256_7 = _mm256_setzero_ps();

            BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray, a, lda, idx_m, 0, load_mask)

            BF16_DOT_8x16(accum256, matrixArray, xArray256)

            // 2-step interleave for FP32 register array
            FP32_INTERLEAVE_8x8(accum256)

            // Accumulate the 2 batch of registers into 2 register (0 and 4)
            FP32_ACCUM2_8x8(accum256)

            accum256_1 = _mm256_permutex2var_ps(accum256_0, idx_base_0, accum256_4);
            accum256_2 = _mm256_permutex2var_ps(accum256_0, idx_base_1, accum256_4);
            accum256_1 = _mm256_add_ps(accum256_1, accum256_2);

            STORE8_COMPLETE_RESULT(accum256_1, y+idx_m)
        }

        // Leftover rows (m not a multiple of 8) are handled one row at a time
        if (tag_m_8x != m) {
            __m128 accum128, tmp128;
            for (BLASLONG i = tag_m_8x; i < m; i++) {
                accum256_0 = _mm256_setzero_ps();
                matrixArray_0 = _mm256_maskz_loadu_epi16(load_mask, &a[(i)*lda]);    // Load 1 row with the first n elements, rest zeroed
                accum256_0 = _mm256_dpbf16_ps(accum256_0, (__m256bh) matrixArray_0, (__m256bh) xArray256);

                // Horizontal sum: 256 -> 128 bits, then fold elements 2,3 onto 0,1 and 1 onto 0
                accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
                tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
                accum128 = _mm_add_ps(accum128, tmp128);
                tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
                accum128 = _mm_add_ps(accum128, tmp128);
#ifndef ZERO_BETA
#ifndef ONE_BETA
                y[i] = alpha * accum128[0] + beta * y[i];
#else
                y[i] = alpha * accum128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
                y[i] = accum128[0] * alpha;
#else
                y[i] = accum128[0];
#endif
#endif
            }
        }
    }

    return 0;
}