Implementation of BF16 based gemv

1. Add a new API -- sbgemv to support bfloat16 based gemv
2. Implement a generic kernel for sbgemv
3. Implement an avx512-bf16 based kernel for sbgemv

Signed-off-by: Chen, Guobing <guobing.chen@intel.com>
This commit is contained in:
Chen, Guobing
2020-10-28 08:49:12 +08:00
parent 67f39ad813
commit a7b1f9b1bb
24 changed files with 5111 additions and 16 deletions

View File

@@ -384,6 +384,14 @@ endif
# NOTE(review): presumably the prerequisite shared by the gemv kernel rules;
# its use is outside this hunk - confirm.
GEMVDEP = ../l2param.h
# Default source file for SBGEMVNKERNEL; only applied when the variable has no
# value yet, so an earlier assignment takes precedence.
ifndef SBGEMVNKERNEL
SBGEMVNKERNEL = sbgemv_n.c
endif
# Same fallback for SBGEMVTKERNEL.
ifndef SBGEMVTKERNEL
SBGEMVTKERNEL = sbgemv_t.c
endif
# Same fallback for SGEMVNKERNEL.
ifndef SGEMVNKERNEL
SGEMVNKERNEL = sgemv_n.c
endif

View File

@@ -0,0 +1,795 @@
/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#ifndef __BF16_COMMON_MACROS
#define __BF16_COMMON_MACROS
#include <immintrin.h>
/* Take the low 256 bits of two 512-bit float registers.
 *   reg256 - name prefix of the destinations; writes reg256_0 and reg256_1
 *   reg512 - name prefix of the sources; reads reg512_0 and reg512_1
 * Both arguments are token-pasted with _0/_1, so they must be plain
 * identifier prefixes. The expansion is a statement sequence that already
 * carries its own semicolons.
 */
#define EXTRACT_LOW_256_FROM_512_2X(reg256, reg512) \
    reg256##_1 = _mm512_castps512_ps256(reg512##_1); \
    reg256##_0 = _mm512_castps512_ps256(reg512##_0);
/* Load 8 consecutive matrix rows into eight 512-bit registers (unaligned loads).
 *   regArray - register name prefix; writes regArray_0 .. regArray_7
 *   a        - base pointer of the matrix
 *   lda      - distance between the starts of consecutive rows, in elements of a
 *   idx_m    - index of the first row to load
 *   idx_n    - column offset applied to every row
 * Row r is read from &a[(idx_m + r) * lda + idx_n]; each load fetches 512 bits
 * (32 elements when a points at 16-bit BF16 data, as the name suggests).
 * Every expression argument is parenthesized so callers can pass expressions
 * such as "ld + 1" without precedence surprises. The expansion is a statement
 * sequence that already ends with ';'.
 */
#define BF16_MATRIX_LOAD_8x32(regArray, a, lda, idx_m, idx_n) \
    regArray##_0 = _mm512_loadu_si512(&(a)[((idx_m) + 0) * (lda) + (idx_n)]); \
    regArray##_1 = _mm512_loadu_si512(&(a)[((idx_m) + 1) * (lda) + (idx_n)]); \
    regArray##_2 = _mm512_loadu_si512(&(a)[((idx_m) + 2) * (lda) + (idx_n)]); \
    regArray##_3 = _mm512_loadu_si512(&(a)[((idx_m) + 3) * (lda) + (idx_n)]); \
    regArray##_4 = _mm512_loadu_si512(&(a)[((idx_m) + 4) * (lda) + (idx_n)]); \
    regArray##_5 = _mm512_loadu_si512(&(a)[((idx_m) + 5) * (lda) + (idx_n)]); \
    regArray##_6 = _mm512_loadu_si512(&(a)[((idx_m) + 6) * (lda) + (idx_n)]); \
    regArray##_7 = _mm512_loadu_si512(&(a)[((idx_m) + 7) * (lda) + (idx_n)]);
/* Load 8 consecutive matrix rows into eight 256-bit registers (unaligned loads).
 *   regArray - register name prefix; writes regArray_0 .. regArray_7
 *   a        - base pointer of the matrix
 *   lda      - distance between the starts of consecutive rows, in elements of a
 *   idx_m    - index of the first row to load
 *   idx_n    - column offset applied to every row
 * Row r is read from &a[(idx_m + r) * lda + idx_n]; each load fetches 256 bits
 * (16 elements when a points at 16-bit BF16 data, as the name suggests).
 * _mm256_loadu_si256 takes a pointer to __m256i, so the element address is
 * cast explicitly; passing the element pointer directly is an incompatible
 * pointer conversion that newer compilers reject. Expression arguments are
 * parenthesized. The expansion is a statement sequence that already ends with ';'.
 */
#define BF16_MATRIX_LOAD_8x16(regArray, a, lda, idx_m, idx_n) \
    regArray##_0 = _mm256_loadu_si256((const __m256i *) &(a)[((idx_m) + 0) * (lda) + (idx_n)]); \
    regArray##_1 = _mm256_loadu_si256((const __m256i *) &(a)[((idx_m) + 1) * (lda) + (idx_n)]); \
    regArray##_2 = _mm256_loadu_si256((const __m256i *) &(a)[((idx_m) + 2) * (lda) + (idx_n)]); \
    regArray##_3 = _mm256_loadu_si256((const __m256i *) &(a)[((idx_m) + 3) * (lda) + (idx_n)]); \
    regArray##_4 = _mm256_loadu_si256((const __m256i *) &(a)[((idx_m) + 4) * (lda) + (idx_n)]); \
    regArray##_5 = _mm256_loadu_si256((const __m256i *) &(a)[((idx_m) + 5) * (lda) + (idx_n)]); \
    regArray##_6 = _mm256_loadu_si256((const __m256i *) &(a)[((idx_m) + 6) * (lda) + (idx_n)]); \
    regArray##_7 = _mm256_loadu_si256((const __m256i *) &(a)[((idx_m) + 7) * (lda) + (idx_n)]);
/* Load 8 consecutive matrix rows into eight 128-bit registers (unaligned loads).
 *   regArray - register name prefix; writes regArray_0 .. regArray_7
 *   a        - base pointer of the matrix
 *   lda      - distance between the starts of consecutive rows, in elements of a
 *   idx_m    - index of the first row to load
 *   idx_n    - column offset applied to every row
 * Row r is read from &a[(idx_m + r) * lda + idx_n]; each load fetches 128 bits
 * (8 elements when a points at 16-bit BF16 data, as the name suggests).
 * _mm_loadu_si128 takes a pointer to __m128i, so the element address is cast
 * explicitly instead of relying on an incompatible pointer conversion.
 * Expression arguments are parenthesized. The expansion is a statement
 * sequence that already ends with ';'.
 */
#define BF16_MATRIX_LOAD_8x8(regArray, a, lda, idx_m, idx_n) \
    regArray##_0 = _mm_loadu_si128((const __m128i *) &(a)[((idx_m) + 0) * (lda) + (idx_n)]); \
    regArray##_1 = _mm_loadu_si128((const __m128i *) &(a)[((idx_m) + 1) * (lda) + (idx_n)]); \
    regArray##_2 = _mm_loadu_si128((const __m128i *) &(a)[((idx_m) + 2) * (lda) + (idx_n)]); \
    regArray##_3 = _mm_loadu_si128((const __m128i *) &(a)[((idx_m) + 3) * (lda) + (idx_n)]); \
    regArray##_4 = _mm_loadu_si128((const __m128i *) &(a)[((idx_m) + 4) * (lda) + (idx_n)]); \
    regArray##_5 = _mm_loadu_si128((const __m128i *) &(a)[((idx_m) + 5) * (lda) + (idx_n)]); \
    regArray##_6 = _mm_loadu_si128((const __m128i *) &(a)[((idx_m) + 6) * (lda) + (idx_n)]); \
    regArray##_7 = _mm_loadu_si128((const __m128i *) &(a)[((idx_m) + 7) * (lda) + (idx_n)]);
/* Load one matrix row into a single 512-bit register (unaligned load).
 *   regArray - destination register (used as-is, no suffix is appended)
 *   a        - base pointer of the matrix
 *   lda      - distance between the starts of consecutive rows, in elements of a
 *   idx_m    - row index
 *   idx_n    - column offset
 * Reads from &a[idx_m * lda + idx_n]. The arguments are parenthesized so that
 * e.g. idx_m = "i + 1" multiplies the whole sum by lda. The expansion already
 * ends with ';'.
 */
#define BF16_MATRIX_LOAD_1x32(regArray, a, lda, idx_m, idx_n) \
    regArray = _mm512_loadu_si512(&(a)[(idx_m) * (lda) + (idx_n)]);
/* Zero-masked load of 8 consecutive matrix rows into eight 512-bit registers.
 *   regArray - register name prefix; writes regArray_0 .. regArray_7
 *   a        - base pointer of the matrix
 *   lda      - distance between the starts of consecutive rows, in elements of a
 *   idx_m    - index of the first row to load
 *   idx_n    - column offset applied to every row
 *   mask     - per-16-bit-element mask; elements whose bit is clear are zeroed
 * Row r is read from &a[(idx_m + r) * lda + idx_n]. Expression arguments are
 * parenthesized so callers may pass expressions. The expansion is a statement
 * sequence that already ends with ';'.
 */
#define BF16_MATRIX_MASKZ_LOAD_8x32(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 0) * (lda) + (idx_n)]); \
    regArray##_1 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 1) * (lda) + (idx_n)]); \
    regArray##_2 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 2) * (lda) + (idx_n)]); \
    regArray##_3 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 3) * (lda) + (idx_n)]); \
    regArray##_4 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 4) * (lda) + (idx_n)]); \
    regArray##_5 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 5) * (lda) + (idx_n)]); \
    regArray##_6 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 6) * (lda) + (idx_n)]); \
    regArray##_7 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 7) * (lda) + (idx_n)]);
/* Zero-masked load of 8 consecutive matrix rows into eight 256-bit registers.
 *   regArray - register name prefix; writes regArray_0 .. regArray_7
 *   a        - base pointer of the matrix
 *   lda      - distance between the starts of consecutive rows, in elements of a
 *   idx_m    - index of the first row to load
 *   idx_n    - column offset applied to every row
 *   mask     - per-16-bit-element mask; elements whose bit is clear are zeroed
 * Row r is read from &a[(idx_m + r) * lda + idx_n]. Expression arguments are
 * parenthesized so callers may pass expressions. The expansion is a statement
 * sequence that already ends with ';'.
 */
#define BF16_MATRIX_MASKZ_LOAD_8x16(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m) + 0) * (lda) + (idx_n)]); \
    regArray##_1 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m) + 1) * (lda) + (idx_n)]); \
    regArray##_2 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m) + 2) * (lda) + (idx_n)]); \
    regArray##_3 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m) + 3) * (lda) + (idx_n)]); \
    regArray##_4 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m) + 4) * (lda) + (idx_n)]); \
    regArray##_5 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m) + 5) * (lda) + (idx_n)]); \
    regArray##_6 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m) + 6) * (lda) + (idx_n)]); \
    regArray##_7 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m) + 7) * (lda) + (idx_n)]);
/* Zero-masked load of 8 consecutive matrix rows into eight 128-bit registers.
 *   regArray - register name prefix; writes regArray_0 .. regArray_7
 *   a        - base pointer of the matrix
 *   lda      - distance between the starts of consecutive rows, in elements of a
 *   idx_m    - index of the first row to load
 *   idx_n    - column offset applied to every row
 *   mask     - per-16-bit-element mask; elements whose bit is clear are zeroed
 * Row r is read from &a[(idx_m + r) * lda + idx_n]. Expression arguments are
 * parenthesized so callers may pass expressions. The expansion is a statement
 * sequence that already ends with ';'.
 */
#define BF16_MATRIX_MASKZ_LOAD_8x8(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m) + 0) * (lda) + (idx_n)]); \
    regArray##_1 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m) + 1) * (lda) + (idx_n)]); \
    regArray##_2 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m) + 2) * (lda) + (idx_n)]); \
    regArray##_3 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m) + 3) * (lda) + (idx_n)]); \
    regArray##_4 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m) + 4) * (lda) + (idx_n)]); \
    regArray##_5 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m) + 5) * (lda) + (idx_n)]); \
    regArray##_6 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m) + 6) * (lda) + (idx_n)]); \
    regArray##_7 = _mm_maskz_loadu_epi16((mask), &(a)[((idx_m) + 7) * (lda) + (idx_n)]);
/* Zero-masked load of 4 consecutive matrix rows into four 512-bit registers.
 *   regArray - register name prefix; writes regArray_0 .. regArray_3
 *   a        - base pointer of the matrix
 *   lda      - distance between the starts of consecutive rows, in elements of a
 *   idx_m    - index of the first row to load
 *   idx_n    - column offset applied to every row
 *   mask     - per-16-bit-element mask; elements whose bit is clear are zeroed
 * Row r is read from &a[(idx_m + r) * lda + idx_n]. Expression arguments are
 * parenthesized so callers may pass expressions. The expansion is a statement
 * sequence that already ends with ';'.
 */
#define BF16_MATRIX_MASKZ_LOAD_4x32(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 0) * (lda) + (idx_n)]); \
    regArray##_1 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 1) * (lda) + (idx_n)]); \
    regArray##_2 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 2) * (lda) + (idx_n)]); \
    regArray##_3 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 3) * (lda) + (idx_n)]);
/* Zero-masked load of 4 consecutive matrix rows into four 256-bit registers.
 *   regArray - register name prefix; writes regArray_0 .. regArray_3
 *   a        - base pointer of the matrix
 *   lda      - distance between the starts of consecutive rows, in elements of a
 *   idx_m    - index of the first row to load
 *   idx_n    - column offset applied to every row
 *   mask     - per-16-bit-element mask; elements whose bit is clear are zeroed
 * Row r is read from &a[(idx_m + r) * lda + idx_n]. Expression arguments are
 * parenthesized so callers may pass expressions. The expansion is a statement
 * sequence that already ends with ';'.
 */
#define BF16_MATRIX_MASKZ_LOAD_4x16(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m) + 0) * (lda) + (idx_n)]); \
    regArray##_1 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m) + 1) * (lda) + (idx_n)]); \
    regArray##_2 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m) + 2) * (lda) + (idx_n)]); \
    regArray##_3 = _mm256_maskz_loadu_epi16((mask), &(a)[((idx_m) + 3) * (lda) + (idx_n)]);
/* Zero-masked load of every second matrix row into eight 512-bit registers.
 *   regArray - register name prefix; writes regArray_0 .. regArray_7
 *   a        - base pointer of the matrix
 *   lda      - distance between the starts of consecutive rows, in elements of a
 *   idx_m    - index of the first row to load
 *   idx_n    - column offset applied to every row
 *   mask     - per-16-bit-element mask; elements whose bit is clear are zeroed
 * Register k is read from row idx_m + 2*k (rows idx_m, idx_m+2, ..., idx_m+14),
 * i.e. from &a[(idx_m + 2*k) * lda + idx_n]. Expression arguments are
 * parenthesized so callers may pass expressions. The expansion is a statement
 * sequence that already ends with ';'.
 */
#define BF16_MATRIX_MASKZ_LOAD_8x32_2(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 0) * (lda) + (idx_n)]); \
    regArray##_1 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 2) * (lda) + (idx_n)]); \
    regArray##_2 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 4) * (lda) + (idx_n)]); \
    regArray##_3 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 6) * (lda) + (idx_n)]); \
    regArray##_4 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 8) * (lda) + (idx_n)]); \
    regArray##_5 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 10) * (lda) + (idx_n)]); \
    regArray##_6 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 12) * (lda) + (idx_n)]); \
    regArray##_7 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 14) * (lda) + (idx_n)]);
/* Zero-masked load of every second matrix row into four 512-bit registers.
 *   regArray - register name prefix; writes regArray_0 .. regArray_3
 *   a        - base pointer of the matrix
 *   lda      - distance between the starts of consecutive rows, in elements of a
 *   idx_m    - index of the first row to load
 *   idx_n    - column offset applied to every row
 *   mask     - per-16-bit-element mask; elements whose bit is clear are zeroed
 * Register k is read from row idx_m + 2*k (rows idx_m, idx_m+2, idx_m+4,
 * idx_m+6). Expression arguments are parenthesized so callers may pass
 * expressions. The expansion is a statement sequence that already ends with ';'.
 */
#define BF16_MATRIX_MASKZ_LOAD_4x32_2(regArray, a, lda, idx_m, idx_n, mask) \
    regArray##_0 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 0) * (lda) + (idx_n)]); \
    regArray##_1 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 2) * (lda) + (idx_n)]); \
    regArray##_2 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 4) * (lda) + (idx_n)]); \
    regArray##_3 = _mm512_maskz_loadu_epi16((mask), &(a)[((idx_m) + 6) * (lda) + (idx_n)]);
/* Zero-masked load of one matrix row into a single 512-bit register.
 *   regArray - destination register (used as-is, no suffix is appended)
 *   a        - base pointer of the matrix
 *   lda      - distance between the starts of consecutive rows, in elements of a
 *   idx_m    - row index
 *   idx_n    - column offset
 *   mask     - per-16-bit-element mask; elements whose bit is clear are zeroed
 * Reads from &a[idx_m * lda + idx_n]. The arguments are parenthesized so that
 * e.g. idx_m = "i + 1" multiplies the whole sum by lda. The expansion already
 * ends with ';'.
 */
#define BF16_MATRIX_MASKZ_LOAD_1x32(regArray, a, lda, idx_m, idx_n, mask) \
    regArray = _mm512_maskz_loadu_epi16((mask), &(a)[(idx_m) * (lda) + (idx_n)]);
/* Load 512 bits of a vector into reg (unaligned load).
 *   reg   - destination register
 *   x     - base pointer of the vector
 *   idx_n - element offset added to x
 * Reads from x + idx_n. Both arguments are parenthesized so that offsets such
 * as "j << 1" are applied as a whole. The expansion already ends with ';'.
 */
#define BF16_VECTOR_LOAD_1x32(reg, x, idx_n) \
    reg = _mm512_loadu_si512((x) + (idx_n));
/* Load 256 bits of a vector into reg (unaligned load).
 *   reg   - destination register
 *   x     - base pointer of the vector
 *   idx_n - element offset added to x
 * Reads from x + idx_n. _mm256_loadu_si256 takes a pointer to __m256i, so the
 * element address is cast explicitly instead of relying on an incompatible
 * pointer conversion. Both arguments are parenthesized. The expansion already
 * ends with ';'.
 */
#define BF16_VECTOR_LOAD_1x16(reg, x, idx_n) \
    reg = _mm256_loadu_si256((const __m256i *) ((x) + (idx_n)));
/* Load 128 bits of a vector into reg (unaligned load).
 *   reg   - destination register
 *   x     - base pointer of the vector
 *   idx_n - element offset added to x
 * Reads from x + idx_n. _mm_loadu_si128 takes a pointer to __m128i, so the
 * element address is cast explicitly instead of relying on an incompatible
 * pointer conversion. Both arguments are parenthesized. The expansion already
 * ends with ';'.
 */
#define BF16_VECTOR_LOAD_1x8(reg, x, idx_n) \
    reg = _mm_loadu_si128((const __m128i *) ((x) + (idx_n)));
/* Zero-masked load of up to 32 16-bit vector elements into a 512-bit register.
 *   reg   - destination register
 *   x     - base pointer of the vector
 *   idx_n - element offset added to x
 *   mask  - per-16-bit-element mask; elements whose bit is clear are zeroed
 * Reads from x + idx_n. Arguments are parenthesized so expression offsets are
 * applied as a whole. The expansion already ends with ';'.
 */
#define BF16_VECTOR_MASKZ_LOAD_1x32(reg, x, idx_n, mask) \
    reg = _mm512_maskz_loadu_epi16((mask), (x) + (idx_n));
/* Zero-masked load of up to 16 16-bit vector elements into a 256-bit register.
 *   reg   - destination register
 *   x     - base pointer of the vector
 *   idx_n - element offset added to x
 *   mask  - per-16-bit-element mask; elements whose bit is clear are zeroed
 * Reads from x + idx_n. Arguments are parenthesized so expression offsets are
 * applied as a whole. The expansion already ends with ';'.
 */
#define BF16_VECTOR_MASKZ_LOAD_1x16(reg, x, idx_n, mask) \
    reg = _mm256_maskz_loadu_epi16((mask), (x) + (idx_n));
/* Zero-masked load of up to 8 16-bit vector elements into a 128-bit register.
 *   reg   - destination register
 *   x     - base pointer of the vector
 *   idx_n - element offset added to x
 *   mask  - per-16-bit-element mask; elements whose bit is clear are zeroed
 * Reads from x + idx_n. Arguments are parenthesized so expression offsets are
 * applied as a whole. The expansion already ends with ';'.
 */
#define BF16_VECTOR_MASKZ_LOAD_1x8(reg, x, idx_n, mask) \
    reg = _mm_maskz_loadu_epi16((mask), (x) + (idx_n));
/* 2-step interleave for matrix against 8 rows with 32 BF16 elements per row
Input - register array of 8 rows of row-major matrix
Output - the output of Step 2
Step 1: 2-element interleave for matrix
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
|e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11|e16|e17|f16|f17|e18|e19|f18|f19|e24|e25|f24|f25|e26|e27|f26|f27
|g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11|g16|g17|h16|h17|g18|g19|h18|h19|g24|g25|h24|h25|g26|g27|h26|h27
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31
|e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15|e20|e21|f20|f21|e22|e23|f22|f23|e28|e29|f28|f29|e30|e31|f30|f31
|g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15|g20|g21|h20|h21|g22|g23|h22|h23|g28|g29|h28|h29|g30|g31|h30|h31
Step 2: 4-element interleave for matrix
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
|e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9 |e16|e17|f16|f17|g16|g17|h16|h17|e24|e25|f24|f25|g24|g25|h24|h25
|e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11|e18|e19|f18|f19|g18|g19|h18|h19|e26|e27|f26|f27|g26|g27|h26|h27
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
|e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13|e20|e21|f20|f21|g20|g21|h20|h21|e28|e29|f28|f29|g28|g29|h28|h29
|e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15|e22|e23|f22|f23|g22|g23|h22|h23|e30|e31|f30|f31|g30|g31|h30|h31
*/
/* In-place 2-step interleave of eight 512-bit row registers.
 *   regArray - register name prefix. Reads regArray_0 .. regArray_7, uses
 *              regArray_8 .. regArray_15 as step-1 results, and overwrites
 *              regArray_0 .. regArray_7 with the step-2 results.
 * Step 1 interleaves 32-bit groups of row pairs (0,1) (2,3) (4,5) (6,7);
 * step 2 interleaves 64-bit groups of the step-1 registers. All 16 registers
 * must be declared by the caller. The expansion is a statement sequence that
 * already ends with ';'.
 */
#define BF16_INTERLEAVE_8x32(regArray) \
    regArray##_8  = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \
    regArray##_12 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \
    regArray##_9  = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \
    regArray##_13 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \
    regArray##_10 = _mm512_unpacklo_epi32(regArray##_4, regArray##_5); \
    regArray##_14 = _mm512_unpackhi_epi32(regArray##_4, regArray##_5); \
    regArray##_11 = _mm512_unpacklo_epi32(regArray##_6, regArray##_7); \
    regArray##_15 = _mm512_unpackhi_epi32(regArray##_6, regArray##_7); \
    \
    regArray##_0 = _mm512_unpacklo_epi64(regArray##_8,  regArray##_9);  \
    regArray##_1 = _mm512_unpackhi_epi64(regArray##_8,  regArray##_9);  \
    regArray##_2 = _mm512_unpacklo_epi64(regArray##_10, regArray##_11); \
    regArray##_3 = _mm512_unpackhi_epi64(regArray##_10, regArray##_11); \
    regArray##_4 = _mm512_unpacklo_epi64(regArray##_12, regArray##_13); \
    regArray##_5 = _mm512_unpackhi_epi64(regArray##_12, regArray##_13); \
    regArray##_6 = _mm512_unpacklo_epi64(regArray##_14, regArray##_15); \
    regArray##_7 = _mm512_unpackhi_epi64(regArray##_14, regArray##_15);
/* 2-step interleave for matrix against 8 rows with 16 BF16 elements per row
Input - register array of 8 rows of row-major matrix
Output - the output of Step 2
Step 1: 2-element interleave for matrix
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
|e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11
|g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15
|e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15
|g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15
Step 2: 4-element interleave for matrix
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
|e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9
|e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
|e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13
|e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15
*/
/* In-place 2-step interleave of eight 256-bit row registers.
 *   regArray - register name prefix. Reads regArray_0 .. regArray_7, uses
 *              regArray_8 .. regArray_15 as step-1 results, and overwrites
 *              regArray_0 .. regArray_7 with the step-2 results.
 * Step 1 interleaves 32-bit groups of row pairs (0,1) (2,3) (4,5) (6,7);
 * step 2 interleaves 64-bit groups of the step-1 registers. All 16 registers
 * must be declared by the caller. The expansion is a statement sequence that
 * already ends with ';'.
 */
#define BF16_INTERLEAVE_8x16(regArray) \
    regArray##_8  = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \
    regArray##_12 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \
    regArray##_9  = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \
    regArray##_13 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \
    regArray##_10 = _mm256_unpacklo_epi32(regArray##_4, regArray##_5); \
    regArray##_14 = _mm256_unpackhi_epi32(regArray##_4, regArray##_5); \
    regArray##_11 = _mm256_unpacklo_epi32(regArray##_6, regArray##_7); \
    regArray##_15 = _mm256_unpackhi_epi32(regArray##_6, regArray##_7); \
    \
    regArray##_0 = _mm256_unpacklo_epi64(regArray##_8,  regArray##_9);  \
    regArray##_1 = _mm256_unpackhi_epi64(regArray##_8,  regArray##_9);  \
    regArray##_2 = _mm256_unpacklo_epi64(regArray##_10, regArray##_11); \
    regArray##_3 = _mm256_unpackhi_epi64(regArray##_10, regArray##_11); \
    regArray##_4 = _mm256_unpacklo_epi64(regArray##_12, regArray##_13); \
    regArray##_5 = _mm256_unpackhi_epi64(regArray##_12, regArray##_13); \
    regArray##_6 = _mm256_unpacklo_epi64(regArray##_14, regArray##_15); \
    regArray##_7 = _mm256_unpackhi_epi64(regArray##_14, regArray##_15);
/* 2-step interleave for matrix against 4 rows with 32 BF16 elements per row
Input - register array of 4 rows of row-major matrix
Output - the output of Step 2
Step 1: 2-element interleave for matrix
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31
Step 2: 4-element interleave for matrix
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
*/
/* In-place 2-step interleave of four 512-bit row registers.
 *   regArray - register name prefix. Reads regArray_0 .. regArray_3, uses
 *              regArray_4 .. regArray_7 as step-1 results, and overwrites
 *              regArray_0 .. regArray_3 with the step-2 results.
 * Step 1 interleaves 32-bit groups of row pairs (0,1) and (2,3); step 2
 * interleaves 64-bit groups of the step-1 registers. All 8 registers must be
 * declared by the caller. The expansion is a statement sequence that already
 * ends with ';'.
 */
#define BF16_INTERLEAVE_4x32(regArray) \
    regArray##_4 = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \
    regArray##_6 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \
    regArray##_5 = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \
    regArray##_7 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \
    \
    regArray##_0 = _mm512_unpacklo_epi64(regArray##_4, regArray##_5); \
    regArray##_1 = _mm512_unpackhi_epi64(regArray##_4, regArray##_5); \
    regArray##_2 = _mm512_unpacklo_epi64(regArray##_6, regArray##_7); \
    regArray##_3 = _mm512_unpackhi_epi64(regArray##_6, regArray##_7);
/* 2-step interleave for matrix against 4 rows with 16 BF16 elements per row
Input - register array of 4 rows of row-major matrix
Output - the output of Step 2
Step 1: 2-element interleave for matrix
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15
Step 2: 4-element interleave for matrix
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
*/
/* In-place 2-step interleave of four 256-bit row registers.
 *   regArray - register name prefix. Reads regArray_0 .. regArray_3, uses
 *              regArray_4 .. regArray_7 as step-1 results, and overwrites
 *              regArray_0 .. regArray_3 with the step-2 results.
 * Step 1 interleaves 32-bit groups of row pairs (0,1) and (2,3); step 2
 * interleaves 64-bit groups of the step-1 registers. All 8 registers must be
 * declared by the caller. The expansion is a statement sequence that already
 * ends with ';'.
 */
#define BF16_INTERLEAVE_4x16(regArray) \
    regArray##_4 = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \
    regArray##_6 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \
    regArray##_5 = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \
    regArray##_7 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \
    \
    regArray##_0 = _mm256_unpacklo_epi64(regArray##_4, regArray##_5); \
    regArray##_1 = _mm256_unpackhi_epi64(regArray##_4, regArray##_5); \
    regArray##_2 = _mm256_unpacklo_epi64(regArray##_6, regArray##_7); \
    regArray##_3 = _mm256_unpackhi_epi64(regArray##_6, regArray##_7);
/* 2-step interleave for x with 32 BF16 elements
Input - original vector
Output - the output of Step 2
Step 1: 2-element interleave for x:
|x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11|x16|x17|x16|x17|x18|x19|x18|x19|x24|x25|x24|x25|x26|x27|x26|x27
|x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15|x20|x21|x20|x21|x22|x23|x22|x23|x28|x29|x28|x29|x30|x31|x30|x31
Step 2: 4-element interleave for x:
|x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9 |x16|x17|x16|x17|x16|x17|x16|x17|x24|x25|x24|x25|x24|x25|x24|x25
|x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11|x18|x19|x18|x19|x18|x19|x18|x19|x26|x27|x26|x27|x26|x27|x26|x27
|x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13|x20|x21|x20|x21|x20|x21|x20|x21|x28|x29|x28|x29|x28|x29|x28|x29
|x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15|x22|x23|x22|x23|x22|x23|x22|x23|x30|x31|x30|x31|x30|x31|x30|x31
*/
/* 2-step self-interleave of one 512-bit vector register into four registers.
 *   regArray - register name prefix. Reads regArray_0 and writes
 *              regArray_0 .. regArray_3 (regArray_0 is overwritten).
 * Step 1 unpacks regArray_0 with itself at 32-bit granularity into regArray_1
 * (low halves) and regArray_3 (high halves). Step 2 unpacks each of those with
 * itself at 64-bit granularity. regArray_1 and regArray_3 serve both as step-1
 * temporaries and as final outputs, so each is only overwritten after the
 * "lo" result derived from it has been stored. The expansion is a statement
 * sequence that already ends with ';'.
 */
#define BF16_INTERLEAVE_1x32(regArray) \
    regArray##_3 = _mm512_unpackhi_epi32(regArray##_0, regArray##_0); \
    regArray##_1 = _mm512_unpacklo_epi32(regArray##_0, regArray##_0); \
    \
    regArray##_0 = _mm512_unpacklo_epi64(regArray##_1, regArray##_1); \
    regArray##_2 = _mm512_unpacklo_epi64(regArray##_3, regArray##_3); \
    regArray##_1 = _mm512_unpackhi_epi64(regArray##_1, regArray##_1); \
    regArray##_3 = _mm512_unpackhi_epi64(regArray##_3, regArray##_3);
/* 2-step interleave for x with 16 BF16 elements
Input - original vector
Output - the output of Step 2
Step 1: 2-element interleave for x:
|x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11
|x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15
Step 2: 4-element interleave for x:
|x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9
|x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11
|x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13
|x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15
*/
/* 2-step self-interleave of one 256-bit vector register into four registers.
 *   regArray - register name prefix. Reads regArray_0 and writes
 *              regArray_0 .. regArray_3 (regArray_0 is overwritten).
 * Step 1 unpacks regArray_0 with itself at 32-bit granularity into regArray_1
 * (low halves) and regArray_3 (high halves). Step 2 unpacks each of those with
 * itself at 64-bit granularity. regArray_1 and regArray_3 serve both as step-1
 * temporaries and as final outputs, so each is only overwritten after the
 * "lo" result derived from it has been stored. The expansion is a statement
 * sequence that already ends with ';'.
 */
#define BF16_INTERLEAVE_1x16(regArray) \
    regArray##_3 = _mm256_unpackhi_epi32(regArray##_0, regArray##_0); \
    regArray##_1 = _mm256_unpacklo_epi32(regArray##_0, regArray##_0); \
    \
    regArray##_0 = _mm256_unpacklo_epi64(regArray##_1, regArray##_1); \
    regArray##_2 = _mm256_unpacklo_epi64(regArray##_3, regArray##_3); \
    regArray##_1 = _mm256_unpackhi_epi64(regArray##_1, regArray##_1); \
    regArray##_3 = _mm256_unpackhi_epi64(regArray##_3, regArray##_3);
/* 1-step interleave to exchange the high 256 bits and low 256 bits of 4 pairs of registers
|a0|a1|...|a14|a15|i0|i1|...|i14|i15|
|b0|b1|...|b14|b15|j0|j1|...|j14|j15|
|c0|c1|...|c14|c15|k0|k1|...|k14|k15|
|d0|d1|...|d14|d15|l0|l1|...|l14|l15|
|e0|e1|...|e14|e15|m0|m1|...|m14|m15|
|f0|f1|...|f14|f15|n0|n1|...|n14|n15|
|g0|g1|...|g14|g15|o0|o1|...|o14|o15|
|h0|h1|...|h14|h15|p0|p1|...|p14|p15|
*/
/* Recombine 256-bit halves of four register pairs.
 *   regArray - register name prefix. Reads regArray_8 .. regArray_15 and
 *              writes regArray_0 .. regArray_7.
 * The pairs are (8,12), (9,13), (10,14) and (11,15). For each pair the
 * even-numbered output gets the low 256 bits of both inputs (selector 0x44)
 * and the odd-numbered output gets the high 256 bits of both inputs (selector
 * 0xee); the first input of the pair supplies the low half of the result.
 * The expansion is a statement sequence that already ends with ';'.
 */
#define BF16_INTERLEAVE256_8x32(regArray) \
    regArray##_0 = _mm512_shuffle_i32x4(regArray##_8,  regArray##_12, 0x44); \
    regArray##_2 = _mm512_shuffle_i32x4(regArray##_9,  regArray##_13, 0x44); \
    regArray##_4 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0x44); \
    regArray##_6 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0x44); \
    regArray##_1 = _mm512_shuffle_i32x4(regArray##_8,  regArray##_12, 0xee); \
    regArray##_3 = _mm512_shuffle_i32x4(regArray##_9,  regArray##_13, 0xee); \
    regArray##_5 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0xee); \
    regArray##_7 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0xee);
/* 1-step interleave to exchange the high 256 bits and low 256 bits of 2 pairs of registers
|a0|a1|...|a14|a15|e0|e1|...|e14|e15|
|b0|b1|...|b14|b15|f0|f1|...|f14|f15|
|c0|c1|...|c14|c15|g0|g1|...|g14|g15|
|d0|d1|...|d14|d15|h0|h1|...|h14|h15|
*/
/* Recombine 256-bit halves of two register pairs.
 *   regArray - register name prefix. Reads regArray_4 .. regArray_7 and
 *              writes regArray_0 .. regArray_3.
 * The pairs are (4,6) and (5,7). For each pair the even-numbered output gets
 * the low 256 bits of both inputs (selector 0x44) and the odd-numbered output
 * gets the high 256 bits of both inputs (selector 0xee); the first input of
 * the pair supplies the low half of the result. The expansion is a statement
 * sequence that already ends with ';'.
 */
#define BF16_INTERLEAVE256_4x32(regArray) \
    regArray##_0 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0x44); \
    regArray##_2 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0x44); \
    regArray##_1 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0xee); \
    regArray##_3 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0xee);
/* Permute the 16-bit elements of eight 512-bit registers with one index vector.
 *   idx      - index vector passed to _mm512_permutexvar_epi16 for every register
 *   regArray - register name prefix. Reads regArray_0 .. regArray_7 and writes
 *              the permuted results to regArray_8 .. regArray_15 (inputs are
 *              left untouched).
 * The expansion is a statement sequence that already ends with ';'.
 */
#define BF16_PERMUTE_8x32(idx, regArray) \
    regArray##_8  = _mm512_permutexvar_epi16(idx, regArray##_0); \
    regArray##_9  = _mm512_permutexvar_epi16(idx, regArray##_1); \
    regArray##_10 = _mm512_permutexvar_epi16(idx, regArray##_2); \
    regArray##_11 = _mm512_permutexvar_epi16(idx, regArray##_3); \
    regArray##_12 = _mm512_permutexvar_epi16(idx, regArray##_4); \
    regArray##_13 = _mm512_permutexvar_epi16(idx, regArray##_5); \
    regArray##_14 = _mm512_permutexvar_epi16(idx, regArray##_6); \
    regArray##_15 = _mm512_permutexvar_epi16(idx, regArray##_7);
/* Permute the 32-bit elements of eight 512-bit registers with one index vector.
 *   idx      - index vector passed to _mm512_permutexvar_epi32 for every register
 *   regArray - register name prefix. Reads regArray_0 .. regArray_7 and writes
 *              the permuted results to regArray_8 .. regArray_15 (inputs are
 *              left untouched).
 * Identical to BF16_PERMUTE_8x32 except that elements move in 32-bit units,
 * i.e. as pairs of 16-bit values. The expansion is a statement sequence that
 * already ends with ';'.
 */
#define BF16_PERMUTE_8x32_2(idx, regArray) \
    regArray##_8  = _mm512_permutexvar_epi32(idx, regArray##_0); \
    regArray##_9  = _mm512_permutexvar_epi32(idx, regArray##_1); \
    regArray##_10 = _mm512_permutexvar_epi32(idx, regArray##_2); \
    regArray##_11 = _mm512_permutexvar_epi32(idx, regArray##_3); \
    regArray##_12 = _mm512_permutexvar_epi32(idx, regArray##_4); \
    regArray##_13 = _mm512_permutexvar_epi32(idx, regArray##_5); \
    regArray##_14 = _mm512_permutexvar_epi32(idx, regArray##_6); \
    regArray##_15 = _mm512_permutexvar_epi32(idx, regArray##_7);
/* Permute the 16-bit elements of four 512-bit registers with one index vector.
 *   idx      - index vector passed to _mm512_permutexvar_epi16 for every register
 *   regArray - register name prefix. Reads regArray_0 .. regArray_3 and writes
 *              the permuted results to regArray_4 .. regArray_7 (inputs are
 *              left untouched).
 * The expansion is a statement sequence that already ends with ';'.
 */
#define BF16_PERMUTE_4x32(idx, regArray) \
    regArray##_4 = _mm512_permutexvar_epi16(idx, regArray##_0); \
    regArray##_5 = _mm512_permutexvar_epi16(idx, regArray##_1); \
    regArray##_6 = _mm512_permutexvar_epi16(idx, regArray##_2); \
    regArray##_7 = _mm512_permutexvar_epi16(idx, regArray##_3);
/* Permute the 32-bit elements of four 512-bit registers with one index vector.
 *   idx      - index vector passed to _mm512_permutexvar_epi32 for every register
 *   regArray - register name prefix. Reads regArray_0 .. regArray_3 and writes
 *              the permuted results to regArray_4 .. regArray_7 (inputs are
 *              left untouched).
 * Identical to BF16_PERMUTE_4x32 except that elements move in 32-bit units.
 * The expansion is a statement sequence that already ends with ';'.
 */
#define BF16_PERMUTE_4x32_2(idx, regArray) \
    regArray##_4 = _mm512_permutexvar_epi32(idx, regArray##_0); \
    regArray##_5 = _mm512_permutexvar_epi32(idx, regArray##_1); \
    regArray##_6 = _mm512_permutexvar_epi32(idx, regArray##_2); \
    regArray##_7 = _mm512_permutexvar_epi32(idx, regArray##_3);
/* Calculate the dot result for 2-step interleaved matrix and vector
   (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
   accumArray - prefix of the two accumulators accumArray_0 / accumArray_1 (read and updated)
   matArray   - prefix of the matrix registers matArray_0 .. matArray_7
   xArray     - prefix of the vector registers xArray_0 .. xArray_3
   accumArray_0 accumulates matArray_0*xArray_0, matArray_1*xArray_1,
   matArray_4*xArray_2 and matArray_5*xArray_3, in that order; accumArray_1
   accumulates matArray_2, _3, _6 and _7 against the same xArray registers.
   The statements alternate between the two accumulators, presumably so that
   back-to-back dpbf16 instructions do not depend on each other; keep the
   order when editing. The expansion already ends with ';'.
*/
#define BF16_2STEP_INTERLEAVED_DOT_8x32(accumArray, matArray, xArray) \
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_2, (__m512bh) xArray##_0); \
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_1); \
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_4, (__m512bh) xArray##_2); \
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_6, (__m512bh) xArray##_2); \
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_5, (__m512bh) xArray##_3); \
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_7, (__m512bh) xArray##_3);
/* Calculate the dot result for 2-step interleaved matrix and vector
   (Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
   accumArray - prefix of the two accumulators accumArray_0 / accumArray_1 (read and updated)
   matArray   - prefix of the matrix registers matArray_0 .. matArray_7
   xArray     - prefix of the vector registers xArray_0 .. xArray_3
   256-bit counterpart of BF16_2STEP_INTERLEAVED_DOT_8x32: accumArray_0
   accumulates matArray_0, _1, _4, _5 against xArray_0 .. xArray_3 and
   accumArray_1 accumulates matArray_2, _3, _6, _7 against the same xArray
   registers. The statements alternate between the two accumulators, presumably
   so that back-to-back dpbf16 instructions do not depend on each other; keep
   the order when editing. The expansion already ends with ';'.
*/
#define BF16_2STEP_INTERLEAVED_DOT_8x16(accumArray, matArray, xArray) \
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_2, (__m256bh) xArray##_0); \
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_1); \
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_4, (__m256bh) xArray##_2); \
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_6, (__m256bh) xArray##_2); \
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_5, (__m256bh) xArray##_3); \
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_7, (__m256bh) xArray##_3);
/* Calculate the dot result for 2-step interleaved matrix and vector
   (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
   accumArray - prefix of the two accumulators accumArray_0 / accumArray_1 (read and updated)
   matArray   - prefix of the matrix registers matArray_0 .. matArray_3
   xArray     - prefix of the vector registers xArray_0 .. xArray_3
   matArray_k is always multiplied with xArray_k; even k accumulate into
   accumArray_0 and odd k into accumArray_1, so the caller has to add the two
   accumulators to obtain the complete sum. The expansion already ends with ';'.
*/
#define BF16_2STEP_INTERLEAVED_DOT_4x32(accumArray, matArray, xArray) \
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_2, (__m512bh) xArray##_2); \
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_3);
/* Calculate the dot result for 2-step interleaved matrix and vector
   (Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
   accumArray - prefix of the two accumulators accumArray_0 / accumArray_1 (read and updated)
   matArray   - prefix of the matrix registers matArray_0 .. matArray_3
   xArray     - prefix of the vector registers xArray_0 .. xArray_3
   256-bit counterpart of BF16_2STEP_INTERLEAVED_DOT_4x32: matArray_k is always
   multiplied with xArray_k; even k accumulate into accumArray_0 and odd k into
   accumArray_1, so the caller has to add the two accumulators to obtain the
   complete sum. The expansion already ends with ';'.
*/
#define BF16_2STEP_INTERLEAVED_DOT_4x16(accumArray, matArray, xArray) \
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_2, (__m256bh) xArray##_2); \
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_3);
/* Accumulate BF16 dot products of eight 512-bit matrix rows with one vector.
 *   accumArray - prefix of the accumulators accumArray_0 .. accumArray_7 (read and updated)
 *   matArray   - prefix of the matrix registers matArray_0 .. matArray_7
 *   xArray     - the vector register itself (no suffix is appended)
 * Row k updates only accumArray_k, so the eight statements are independent
 * of each other.
 * (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
 * The expansion is a statement sequence that already ends with ';'.
 */
#define BF16_DOT_8x32(accumArray, matArray, xArray) \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray); \
    accumArray##_2 = _mm512_dpbf16_ps(accumArray##_2, (__m512bh) matArray##_2, (__m512bh) xArray); \
    accumArray##_3 = _mm512_dpbf16_ps(accumArray##_3, (__m512bh) matArray##_3, (__m512bh) xArray); \
    accumArray##_4 = _mm512_dpbf16_ps(accumArray##_4, (__m512bh) matArray##_4, (__m512bh) xArray); \
    accumArray##_5 = _mm512_dpbf16_ps(accumArray##_5, (__m512bh) matArray##_5, (__m512bh) xArray); \
    accumArray##_6 = _mm512_dpbf16_ps(accumArray##_6, (__m512bh) matArray##_6, (__m512bh) xArray); \
    accumArray##_7 = _mm512_dpbf16_ps(accumArray##_7, (__m512bh) matArray##_7, (__m512bh) xArray);
/* Accumulate the BF16 dot product of one 512-bit matrix row with one vector.
 *   accumArray - accumulator register (read and updated; no suffix is appended)
 *   matArray   - matrix row register
 *   xArray     - vector register
 * (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
 * The expansion already ends with ';'.
 */
#define BF16_DOT_1x32(accumArray, matArray, xArray) \
    accumArray = _mm512_dpbf16_ps(accumArray, \
                                  (__m512bh) matArray, \
                                  (__m512bh) xArray);
/* Accumulate BF16 dot products of eight 256-bit matrix rows with one vector.
 *   accumArray - prefix of the accumulators accumArray_0 .. accumArray_7 (read and updated)
 *   matArray   - prefix of the matrix registers matArray_0 .. matArray_7
 *   xArray     - the vector register itself (no suffix is appended)
 * Row k updates only accumArray_k, so the eight statements are independent
 * of each other.
 * (Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
 * The expansion is a statement sequence that already ends with ';'.
 */
#define BF16_DOT_8x16(accumArray, matArray, xArray) \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray); \
    accumArray##_2 = _mm256_dpbf16_ps(accumArray##_2, (__m256bh) matArray##_2, (__m256bh) xArray); \
    accumArray##_3 = _mm256_dpbf16_ps(accumArray##_3, (__m256bh) matArray##_3, (__m256bh) xArray); \
    accumArray##_4 = _mm256_dpbf16_ps(accumArray##_4, (__m256bh) matArray##_4, (__m256bh) xArray); \
    accumArray##_5 = _mm256_dpbf16_ps(accumArray##_5, (__m256bh) matArray##_5, (__m256bh) xArray); \
    accumArray##_6 = _mm256_dpbf16_ps(accumArray##_6, (__m256bh) matArray##_6, (__m256bh) xArray); \
    accumArray##_7 = _mm256_dpbf16_ps(accumArray##_7, (__m256bh) matArray##_7, (__m256bh) xArray);
/* 2-step interleave for matrix against 8 rows with 16 fp32 elements per row
Input - register array of 8 rows of row-major matrix
Output - the output of Step 2
Step 1: 2-element interleave for matrix
|a0|b0|a1|b1|a4|b4|a5|b5|a8 |b8 |a9 |b9 |a12|b12|a13|b13|
|c0|d0|c1|d1|c4|d4|c5|d5|c8 |d8 |c9 |d9 |c12|d12|c13|d13|
|e0|f0|e1|f1|e4|f4|e5|f5|e8 |f8 |e9 |f9 |e12|f12|e13|f13|
|g0|h0|g1|h1|g4|h4|g5|h5|g8 |h8 |g9 |h9 |g12|h12|g13|h13|
|a2|b2|a3|b3|a6|b6|a7|b7|a10|b10|a11|b11|a14|b14|a15|b15|
|c2|d2|c3|d3|c6|d6|c7|d7|c10|d10|c11|d11|c14|d14|c15|d15|
|e2|f2|e3|f3|e6|f6|e7|f7|e10|f10|e11|f11|e14|f14|e15|f15|
|g2|h2|g3|h3|g6|h6|g7|h7|g10|h10|g11|h11|g14|h14|g15|h15|
Step 2: 4-element interleave for matrix
|a0|b0|c0|d0|a4|b4|c4|d4|a8 |b8 |c8 |d8 |a12|b12|c12|d12|
|a1|b1|c1|d1|a5|b5|c5|d5|a9 |b9 |c9 |d9 |a13|b13|c13|d13|
|e0|f0|g0|h0|e4|f4|g4|h4|e8 |f8 |g8 |h8 |e12|f12|g12|h12|
|e1|f1|g1|h1|e5|f5|g5|h5|e9 |f9 |g9 |h9 |e13|f13|g13|h13|
|a2|b2|c2|d2|a6|b6|c6|d6|a10|b10|c10|d10|a14|b14|c14|d14|
|a3|b3|c3|d3|a7|b7|c7|d7|a11|b11|c11|d11|a15|b15|c15|d15|
|e2|f2|g2|h2|e6|f6|g6|h6|e10|f10|g10|h10|e14|f14|g14|h14|
|e3|f3|g3|h3|e7|f7|g7|h7|e11|f11|g11|h11|e15|f15|g15|h15|
*/
/* In-place 2-step interleave of eight 512-bit fp32 row registers.
 *   regArray - register name prefix. Reads regArray_0 .. regArray_7, uses
 *              regArray_8 .. regArray_15 as step-1 results, and overwrites
 *              regArray_0 .. regArray_7 with the step-2 results.
 * Step 1 interleaves single floats of row pairs (0,1) (2,3) (4,5) (6,7);
 * step 2 interleaves 64-bit groups of the step-1 registers through the
 * double-precision unpack intrinsics.
 * Output numbering: regArray_0 .. regArray_3 are all built from input rows
 * 0-3 (lo/hi of 8,9 then lo/hi of 12,13) and regArray_4 .. regArray_7 from
 * input rows 4-7 (lo/hi of 10,11 then lo/hi of 14,15).
 * The expansion is a statement sequence that already ends with ';'.
 */
#define FP32_INTERLEAVE_8x16(regArray) \
    regArray##_8  = _mm512_unpacklo_ps(regArray##_0, regArray##_1); \
    regArray##_12 = _mm512_unpackhi_ps(regArray##_0, regArray##_1); \
    regArray##_9  = _mm512_unpacklo_ps(regArray##_2, regArray##_3); \
    regArray##_13 = _mm512_unpackhi_ps(regArray##_2, regArray##_3); \
    regArray##_10 = _mm512_unpacklo_ps(regArray##_4, regArray##_5); \
    regArray##_14 = _mm512_unpackhi_ps(regArray##_4, regArray##_5); \
    regArray##_11 = _mm512_unpacklo_ps(regArray##_6, regArray##_7); \
    regArray##_15 = _mm512_unpackhi_ps(regArray##_6, regArray##_7); \
    \
    regArray##_0 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_8,  (__m512d) regArray##_9);  \
    regArray##_1 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_8,  (__m512d) regArray##_9);  \
    regArray##_2 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \
    regArray##_3 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \
    regArray##_4 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \
    regArray##_5 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \
    regArray##_6 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_14, (__m512d) regArray##_15); \
    regArray##_7 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_14, (__m512d) regArray##_15);
/* Array flavour of the 2-step 8x16 interleave: same operations as FP32_INTERLEAVE_8x16,
   but `regArray` is indexed with [] instead of being used as a name prefix.
   The expansion indexes regArray[0] .. regArray[15], so it needs at least 16 __m512 elements:
     - reads   regArray[0] .. regArray[7]   (the 8 input rows)
     - Step 1: writes regArray[8] .. regArray[15] (32-bit unpack lo/hi of row pairs)
     - Step 2: writes regArray[0] .. regArray[7]  (64-bit unpack lo/hi of the step-1 values)
   `regArray` is not parenthesised in the expansion, so pass a plain array/pointer name.
   The expansion already ends with ';'.
*/
#define FP32_INTERLEAVE_8x16_ARRAY(regArray) \
regArray[8] = _mm512_unpacklo_ps(regArray[0], regArray[1]); \
regArray[9] = _mm512_unpacklo_ps(regArray[2], regArray[3]); \
regArray[10] = _mm512_unpacklo_ps(regArray[4], regArray[5]); \
regArray[11] = _mm512_unpacklo_ps(regArray[6], regArray[7]); \
regArray[12] = _mm512_unpackhi_ps(regArray[0], regArray[1]); \
regArray[13] = _mm512_unpackhi_ps(regArray[2], regArray[3]); \
regArray[14] = _mm512_unpackhi_ps(regArray[4], regArray[5]); \
regArray[15] = _mm512_unpackhi_ps(regArray[6], regArray[7]); \
\
regArray[0] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[8], (__m512d) regArray[9]); \
regArray[1] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[8], (__m512d) regArray[9]); \
regArray[4] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
regArray[5] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
regArray[2] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
regArray[3] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
regArray[6] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[14], (__m512d) regArray[15]); \
regArray[7] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[14], (__m512d) regArray[15]);
/* 2-step interleave for matrix against 8 rows with 8 fp32 elements per row
Input - register array of 8 rows of row-major matrix
Output - the output of Step 2
Step 1: 2-element interleave for matrix
|a0|b0|a1|b1|a4|b4|a5|b5|
|c0|d0|c1|d1|c4|d4|c5|d5|
|e0|f0|e1|f1|e4|f4|e5|f5|
|g0|h0|g1|h1|g4|h4|g5|h5|
|a2|b2|a3|b3|a6|b6|a7|b7|
|c2|d2|c3|d3|c6|d6|c7|d7|
|e2|f2|e3|f3|e6|f6|e7|f7|
|g2|h2|g3|h3|g6|h6|g7|h7|
Step 2: 4-element interleave for matrix
|a0|b0|c0|d0|a4|b4|c4|d4|
|a1|b1|c1|d1|a5|b5|c5|d5|
|e0|f0|g0|h0|e4|f4|g4|h4|
|e1|f1|g1|h1|e5|f5|g5|h5|
|a2|b2|c2|d2|a6|b6|c6|d6|
|a3|b3|c3|d3|a7|b7|c7|d7|
|e2|f2|g2|h2|e6|f6|g6|h6|
|e3|f3|g3|h3|e7|f7|g7|h7|
*/
/* 256-bit (8 fp32 per row) counterpart of the register-name interleave.
   `regArray` is a name prefix that is token-pasted with _0 .. _15, so sixteen __m256
   variables with those suffixes must be visible where the macro is expanded.
     - reads   regArray_0 .. regArray_7  (the 8 input rows)
     - Step 1: writes regArray_8 .. regArray_15 (32-bit unpack lo/hi of row pairs)
     - Step 2: writes regArray_0 .. regArray_7  (64-bit unpack lo/hi of the step-1 values)
   The expansion already ends with ';'.
*/
#define FP32_INTERLEAVE_8x8(regArray) \
regArray##_8 = _mm256_unpacklo_ps(regArray##_0, regArray##_1); \
regArray##_9 = _mm256_unpacklo_ps(regArray##_2, regArray##_3); \
regArray##_10 = _mm256_unpacklo_ps(regArray##_4, regArray##_5); \
regArray##_11 = _mm256_unpacklo_ps(regArray##_6, regArray##_7); \
regArray##_12 = _mm256_unpackhi_ps(regArray##_0, regArray##_1); \
regArray##_13 = _mm256_unpackhi_ps(regArray##_2, regArray##_3); \
regArray##_14 = _mm256_unpackhi_ps(regArray##_4, regArray##_5); \
regArray##_15 = _mm256_unpackhi_ps(regArray##_6, regArray##_7); \
\
regArray##_0 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_8, (__m256d) regArray##_9); \
regArray##_1 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_8, (__m256d) regArray##_9); \
regArray##_4 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
regArray##_5 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
regArray##_2 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
regArray##_3 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
regArray##_6 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_14, (__m256d) regArray##_15); \
regArray##_7 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_14, (__m256d) regArray##_15);
/* Reduce two batches of four __m512 registers each (register-name flavour).
   `regArray` is a name prefix token-pasted with _0 .. _7.
   After the expansion:
     regArray_0 = regArray_0 + regArray_1 + regArray_2 + regArray_3
     regArray_4 = regArray_4 + regArray_5 + regArray_6 + regArray_7
   regArray_2 and regArray_6 are clobbered with the partial sums (2+3) and (6+7).
   The expansion already ends with ';'.
*/
#define FP32_ACCUM2_8x16(regArray) \
regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_1); \
regArray##_2 = _mm512_add_ps(regArray##_2, regArray##_3); \
regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_5); \
regArray##_6 = _mm512_add_ps(regArray##_6, regArray##_7); \
regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_2); \
regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_6);
/* Array flavour of the two-batch reduction: `regArray` is indexed with [] (elements 0 .. 7).
   After the expansion:
     regArray[0] = regArray[0] + regArray[1] + regArray[2] + regArray[3]
     regArray[4] = regArray[4] + regArray[5] + regArray[6] + regArray[7]
   regArray[2] and regArray[6] are clobbered with the partial sums (2+3) and (6+7).
   The expansion already ends with ';'.
*/
#define FP32_ACCUM2_8x16_ARRAY(regArray) \
regArray[0] = _mm512_add_ps(regArray[0], regArray[1]); \
regArray[2] = _mm512_add_ps(regArray[2], regArray[3]); \
regArray[4] = _mm512_add_ps(regArray[4], regArray[5]); \
regArray[6] = _mm512_add_ps(regArray[6], regArray[7]); \
regArray[0] = _mm512_add_ps(regArray[0], regArray[2]); \
regArray[4] = _mm512_add_ps(regArray[4], regArray[6]);
/* Reduce two batches of four __m256 registers each (256-bit, register-name flavour).
   `regArray` is a name prefix token-pasted with _0 .. _7.
   After the expansion:
     regArray_0 = regArray_0 + regArray_1 + regArray_2 + regArray_3
     regArray_4 = regArray_4 + regArray_5 + regArray_6 + regArray_7
   regArray_2 and regArray_6 are clobbered with the partial sums (2+3) and (6+7).
   The expansion already ends with ';'.
*/
#define FP32_ACCUM2_8x8(regArray) \
regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_1); \
regArray##_2 = _mm256_add_ps(regArray##_2, regArray##_3); \
regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_5); \
regArray##_6 = _mm256_add_ps(regArray##_6, regArray##_7); \
regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_2); \
regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_6);
/* ---- "ALPHA_BETA" store family: y = alpha * result + beta * y ----
   Common contract for the six macros below:
     - __m512 variables named ALPHAVECTOR and BETAVECTOR must be visible at the expansion site
       (the 8- and 4-wide variants use their low 256/128 bits via _mm512_castps512_ps*).
     - regResult is overwritten with the value that is written to targetAddr.
     - Loads and stores are unaligned; the masked variants load zero for, and do not write,
       lanes whose mask bit is clear.
     - Every expansion already ends with ';'.
*/
/* Store 16 (alpha * result + beta * y) to y
*/
#define STORE16_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_loadu_ps(targetAddr))); \
_mm512_storeu_ps(targetAddr, regResult);
/* Masked store of up to 16 (alpha * result + beta * y) to y; `mask` selects the lanes
*/
#define STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_maskz_loadu_ps(mask, targetAddr))); \
_mm512_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 8 (alpha * result + beta * y) to y; regResult is a __m256
*/
#define STORE8_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_loadu_ps(targetAddr))); \
_mm256_storeu_ps(targetAddr, regResult);
/* Masked store of up to 8 (alpha * result + beta * y) to y; regResult is a __m256
*/
#define STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_maskz_loadu_ps(mask, targetAddr))); \
_mm256_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 4 (alpha * result + beta * y) to y; regResult is a __m128
*/
#define STORE4_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_loadu_ps(targetAddr))); \
_mm_storeu_ps(targetAddr, regResult);
/* Masked store of up to 4 (alpha * result + beta * y) to y; regResult is a __m128
*/
#define STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_maskz_loadu_ps(mask, targetAddr))); \
_mm_mask_storeu_ps(targetAddr, mask, regResult);
/* ---- "ALPHA_ONE" store family: y = alpha * result + y (i.e. beta is taken as 1) ----
   Common contract for the six macros below:
     - a __m512 variable named ALPHAVECTOR must be visible at the expansion site
       (the 8- and 4-wide variants use its low 256/128 bits via _mm512_castps512_ps*);
       BETAVECTOR is not referenced.
     - regResult is overwritten with the value that is written to targetAddr.
     - Loads and stores are unaligned; the masked variants load zero for, and do not write,
       lanes whose mask bit is clear.
     - Every expansion already ends with ';'.
*/
/* Store 16 (alpha * result + y) to y
*/
#define STORE16_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_loadu_ps(targetAddr)); \
_mm512_storeu_ps(targetAddr, regResult);
/* Masked store of up to 16 (alpha * result + y) to y; `mask` selects the lanes
*/
#define STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \
_mm512_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 8 (alpha * result + y) to y; regResult is a __m256
*/
#define STORE8_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_loadu_ps(targetAddr)); \
_mm256_storeu_ps(targetAddr, regResult);
/* Masked store of up to 8 (alpha * result + y) to y; regResult is a __m256
*/
#define STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \
_mm256_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 4 (alpha * result + y) to y; regResult is a __m128
*/
#define STORE4_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_loadu_ps(targetAddr)); \
_mm_storeu_ps(targetAddr, regResult);
/* Masked store of up to 4 (alpha * result + y) to y; regResult is a __m128
*/
#define STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \
_mm_mask_storeu_ps(targetAddr, mask, regResult);
/* ---- "ALPHA" store family: y = alpha * result (the previous content of y is not read) ----
   Common contract for the six macros below:
     - a __m512 variable named ALPHAVECTOR must be visible at the expansion site
       (the 8- and 4-wide variants use its low 256/128 bits via _mm512_castps512_ps*).
     - regResult is only read; unlike the ALPHA_BETA / ALPHA_ONE families it is not modified.
     - Stores are unaligned; the masked variants do not write lanes whose mask bit is clear.
     - Every expansion already ends with ';'.
*/
/* Store 16 (alpha * result) to y
*/
#define STORE16_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
_mm512_storeu_ps(targetAddr, _mm512_mul_ps(ALPHAVECTOR, regResult));
/* Masked store of up to 16 (alpha * result) to y; `mask` selects the lanes
*/
#define STORE16_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
_mm512_mask_storeu_ps(targetAddr, mask, _mm512_mul_ps(ALPHAVECTOR, regResult));
/* Store 8 (alpha * result) to y; regResult is a __m256
*/
#define STORE8_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
_mm256_storeu_ps(targetAddr, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));
/* Masked store of up to 8 (alpha * result) to y; regResult is a __m256
*/
#define STORE8_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
_mm256_mask_storeu_ps(targetAddr, mask, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));
/* Store 4 (alpha * result) to y; regResult is a __m128
*/
#define STORE4_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
_mm_storeu_ps(targetAddr, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));
/* Masked store of up to 4 (alpha * result) to y; regResult is a __m128
*/
#define STORE4_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
_mm_mask_storeu_ps(targetAddr, mask, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));
/* ---- "DIRECT" store family: y = result (no alpha scaling, previous y not read) ----
   Common contract for the six macros below:
     - neither ALPHAVECTOR nor BETAVECTOR is referenced.
     - regResult is only read.
     - Stores are unaligned; the masked variants do not write lanes whose mask bit is clear.
     - Every expansion already ends with ';'.
*/
/* Store 16 result to y
*/
#define STORE16_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
_mm512_storeu_ps(targetAddr, regResult);
/* Masked store of up to 16 result to y; `mask` selects the lanes
*/
#define STORE16_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
_mm512_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 8 result to y; regResult is a __m256
*/
#define STORE8_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
_mm256_storeu_ps(targetAddr, regResult);
/* Masked store of up to 8 result to y; regResult is a __m256
*/
#define STORE8_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
_mm256_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 4 result to y; regResult is a __m128
*/
#define STORE4_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
_mm_storeu_ps(targetAddr, regResult);
/* Masked store of up to 4 result to y; regResult is a __m128
*/
#define STORE4_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
_mm_mask_storeu_ps(targetAddr, mask, regResult);
#endif

137
kernel/x86_64/sbgemv_n.c Normal file
View File

@@ -0,0 +1,137 @@
/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include "common.h"
#if defined (COOPERLAKE)
#include "sbgemv_n_microk_cooperlake.c"
#endif
// Allocate room for `alloc_size` elements of TYPE plus 63 bytes of slack.
//   ptr       receives the raw malloc() result; this is the pointer to hand to ALIGN64_FREE.
//   ptr_align receives the first 64-byte aligned address inside that allocation
//             (equal to ptr when malloc already returned an aligned block).
// Both are NULL when malloc fails; the caller is responsible for checking.
// The body is wrapped in do { } while (0) and every argument is parenthesised, so the macro
// behaves as one statement (e.g. under an unbraced `if`) and accepts expression arguments.
#define ALIGN64_ALLOC(alloc_size, TYPE, ptr_align, ptr) \
    do { \
        (ptr) = (TYPE *) malloc(sizeof(TYPE) * (alloc_size) + 63); \
        (ptr_align) = (((uintptr_t)(ptr) & (uintptr_t)0x3F) != 0) \
            ? (TYPE *)((char *)(ptr) + (64 - (int)((uintptr_t)(ptr) & (uintptr_t)0x3F))) \
            : (ptr); \
    } while (0)
// Release a buffer obtained with ALIGN64_ALLOC; pass the raw `ptr`, not the aligned pointer.
#define ALIGN64_FREE(ptr) \
    free(ptr)
#ifndef HAVE_SBGEMV_N_ACCL_KERNEL
// Generic (non-accelerated) kernel: y = alpha * A * x + beta * y.
//   A is m x n, stored column by column with `lda` elements between column starts (bf16).
//   x holds n contiguous bf16 values, y holds m contiguous floats.
// When beta == ZERO the previous content of y is not read.
// NOTE(review): the malloc results are not checked and the void return type gives no way
// to report an allocation failure to the caller.
static void sbgemv_kernel_n(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
{
    float * a_fp32 = malloc(sizeof(float)*m*n);
    float * x_fp32 = malloc(sizeof(float)*n);

    SBF16TOS_K(n, x, 1, x_fp32, 1);

    // Widen A to fp32 one column at a time, reading straight from the lda-strided source.
    // This packs the columns densely (stride m) without a separate bf16 staging copy.
    for (BLASLONG j=0; j<n; j++) {
        SBF16TOS_K(m, a + lda * j, 1, a_fp32 + m * j, 1);
    }

    for (BLASLONG i=0; i<m; i++) {
        float accum = 0.0;
        for (BLASLONG j=0; j<n; j++) {
            accum += a_fp32[j*m + i] * x_fp32[j];
        }
        if (beta == ZERO) {
            y[i] = alpha * accum;
        } else {
            y[i] = alpha * accum + beta * y[i];
        }
    }

    free(a_fp32);
    free(x_fp32);
}
#endif
// Gather n bf16 elements that lie `inc` apart in src into the contiguous buffer target.
static void bf16_compress_vector(BLASLONG n, bfloat16 * src, bfloat16 * target, BLASLONG inc)
{
    BLASLONG k = 0;
    while (k < n) {
        target[k] = src[k * inc];
        k++;
    }
}
// Gather n floats that lie `inc` apart in src into the contiguous buffer target.
static void fp32_compress_vector(BLASLONG n, float * src, float * target, BLASLONG inc)
{
    BLASLONG k = 0;
    while (k < n) {
        target[k] = src[k * inc];
        k++;
    }
}
// Scatter n contiguous floats from src into target, placing them `inc` elements apart.
static void fp32_expand_vector(BLASLONG n, float * src, float * target, BLASLONG inc)
{
    BLASLONG k = 0;
    while (k < n) {
        target[k * inc] = src[k];
        k++;
    }
}
// Entry point of the non-transposed bf16 GEMV: y = alpha * A * x + beta * y.
// sbgemv_kernel_n only handles unit-stride vectors, so a strided x is packed into a
// 64-byte aligned temporary, and a strided y is computed in a 64-byte aligned temporary
// and scattered back afterwards. Returns 0 in all cases; nothing is done when m < 1 or n < 1.
// NOTE(review): the temporaries come from ALIGN64_ALLOC and are not checked for NULL.
int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float * y, BLASLONG incy)
{
    bfloat16 *x_raw = NULL;    // raw allocation behind x_dense (only when incx != 1)
    float    *y_raw = NULL;    // raw allocation behind y_dense (only when incy != 1)
    bfloat16 *x_dense;
    float    *y_dense;

    if (!(m >= 1 && n >= 1)) {
        return 0;
    }

    // By default the kernel works directly on the caller's vectors.
    x_dense = x;
    y_dense = y;

    if (incx != 1) {
        ALIGN64_ALLOC(n, bfloat16, x_dense, x_raw);
        bf16_compress_vector(n, x, x_dense, incx);
    }

    if (incy != 1) {
        ALIGN64_ALLOC(m, float, y_dense, y_raw);
        // The old y values are only needed when beta takes part in the result.
        if (!(beta == ZERO)) {
            fp32_compress_vector(m, y, y_dense, incy);
        }
    }

    sbgemv_kernel_n(m, n, alpha, a, lda, x_dense, beta, y_dense);

    if (incy != 1) {
        fp32_expand_vector(m, y_dense, y, incy);
        ALIGN64_FREE(y_raw);
    }

    if (incx != 1) {
        ALIGN64_FREE(x_raw);
    }

    return 0;
}

View File

@@ -0,0 +1,76 @@
/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
/* need a new enough GCC for avx512 support */
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
#define HAVE_SBGEMV_N_ACCL_KERNEL 1
#include "common.h"
#include <immintrin.h>
// The kernel template is included four times. ZERO_BETA / ONE_BETA / ONE_ALPHA choose which
// function name the template defines and which STORE* macro family it uses.
// Instantiation 1: alpha general, beta neither 0 nor 1
//   -> sbgemv_kernel_32xN_lda_direct_alpha_beta
#undef ZERO_BETA
#undef ONE_BETA
#undef ONE_ALPHA
#include "sbgemv_n_microk_cooperlake_template.c"
// Instantiation 2: alpha general, beta == 1
//   -> sbgemv_kernel_32xN_lda_direct_alpha_one
#undef ZERO_BETA
#define ONE_BETA 1
#undef ONE_ALPHA
#include "sbgemv_n_microk_cooperlake_template.c"
// Instantiation 3: beta == 0 (y is not read), alpha general
//   -> sbgemv_kernel_32xN_lda_direct_alpha
// ONE_BETA is still defined at this point; the template tests ZERO_BETA first, so it is ignored.
#define ZERO_BETA 1
#undef ONE_ALPHA
#include "sbgemv_n_microk_cooperlake_template.c"
// Instantiation 4: beta == 0 (y is not read), alpha == 1
//   -> sbgemv_kernel_32xN_lda_direct
#define ZERO_BETA 1
#define ONE_ALPHA 1
#include "sbgemv_n_microk_cooperlake_template.c"
// Accelerated kernel dispatcher: picks the template instantiation that matches the
// special values of alpha and beta. Always returns 0.
static int sbgemv_kernel_n(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
{
    if (beta != ZERO) {
        // The original y takes part in the result, whatever alpha is.
        if (beta != ONE) {
            sbgemv_kernel_32xN_lda_direct_alpha_beta(m, n, alpha, a, lda, x, beta, y);
        } else {
            sbgemv_kernel_32xN_lda_direct_alpha_one(m, n, alpha, a, lda, x, beta, y);
        }
        return 0;
    }

    // beta == 0.0: y is write-only, only the alpha scaling remains to be decided.
    if (alpha != ONE) {
        // alpha != 1.0: the result has to be multiplied by alpha
        sbgemv_kernel_32xN_lda_direct_alpha(m, n, alpha, a, lda, x, y);
    } else {
        // alpha == 1.0: the result is stored as is
        sbgemv_kernel_32xN_lda_direct(m, n, alpha, a, lda, x, y);
    }
    return 0;
}
#endif

View File

@@ -0,0 +1,234 @@
/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include <immintrin.h>
#include "common.h"
// Include common macros for BF16 based operations with IA intrinsics
#include "bf16_common_macros.h"
// Bind the generic STORE*_COMPLETE_RESULT names to the store family that matches the
// ZERO_BETA / ONE_BETA / ONE_ALPHA configuration of this inclusion.
// This template is included several times with different configurations, so any binding
// left over from a previous inclusion is dropped first; redefining a macro with a different
// body without #undef is not allowed.
#undef STORE16_COMPLETE_RESULT
#undef STORE16_MASK_COMPLETE_RESULT
#undef STORE8_COMPLETE_RESULT
#undef STORE8_MASK_COMPLETE_RESULT
#undef STORE4_COMPLETE_RESULT
#undef STORE4_MASK_COMPLETE_RESULT

#ifndef ZERO_BETA // Beta is non-zero
#ifndef ONE_BETA // BETA is not ONE: y = alpha * result + beta * y
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_BETA
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA
#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA_BETA
#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA
#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA_BETA
#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA
#else // BETA is ONE: y = alpha * result + y
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_ONE
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE
#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA_ONE
#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE
#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA_ONE
#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE
#endif
#else // BETA is zero
#ifndef ONE_ALPHA // ALPHA is not ONE: y = alpha * result
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA
#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA
#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA
#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA
#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA
#else // ALPHA is ONE: y = result
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_DIRECT
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_DIRECT
#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_DIRECT
#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_DIRECT
#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_DIRECT
#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_DIRECT
#endif
#endif
// BF16 GEMV kernel for the non-transposed, lda-strided case.
// Rows are processed in groups of 32 (one 512-bit load of bf16 per column); while at least
// 128 rows remain, four such groups are handled per pass. A final masked pass covers the
// last (m % 32) rows. For every group the kernel walks all n columns, accumulates in fp32
// and then writes the result with the STORE* macro selected for this instantiation.
//
//   m, n   - matrix dimensions (m rows = length of y, n columns = length of x)
//   alpha  - scalar applied to A*x (unused by the variant built with ONE_ALPHA)
//   a, lda - bf16 matrix and the element distance between consecutive columns
//            (NOTE(review): layout inferred from the BF16_MATRIX_LOAD_1x32 argument order;
//             the macro itself is defined in bf16_common_macros.h)
//   x      - n contiguous bf16 values
//   beta   - scalar applied to the old y (only present when ZERO_BETA is not defined)
//   y      - m contiguous floats, updated in place
// Returns 0.
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_32xN_lda_direct_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_32xN_lda_direct_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_32xN_lda_direct_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_32xN_lda_direct(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
#endif
#endif
{
    BLASLONG tag_m_32x = m & (~31);     // number of rows covered by whole 32-row groups
    BLASLONG tag_m_128x = m & (~127);   // number of rows covered by whole 128-row groups

    __m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7,
           accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15;

#ifndef ONE_ALPHA
    __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
    __m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif

    __m512i matrixArray_seed_0, matrixArray_seed_1, matrixArray_seed_2, matrixArray_seed_3;
    __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
    __m512i xArray_0;

    __m512i ZERO512 = _mm512_setzero_si512();

    // 16-bit lane selectors: "lo" keeps the even bf16 elements of a loaded vector, "hi" the
    // odd ones; the other elements are replaced by zero in the blends below.
    __mmask32 blend_hi_mask = (__mmask32) 0xaaaaaaaaU;
    __mmask32 blend_lo_mask = (__mmask32) 0x55555555U;

    // Permutation indices that interleave the low 8 lanes (idx_base_0) or the high 8 lanes
    // (idx_base_1) of an "even rows" accumulator with those of the matching "odd rows"
    // accumulator, restoring natural row order.
    __m512i M512_EPI32_8 = _mm512_set1_epi32(8);
    __m512i idx_base_0 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);
    __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_8);

    // ---- 128 rows per pass ----
    for (BLASLONG idx_m = 0; idx_m < tag_m_128x; idx_m+=128) {
        accum512_0 = _mm512_setzero_ps();
        accum512_1 = _mm512_setzero_ps();
        accum512_2 = _mm512_setzero_ps();
        accum512_3 = _mm512_setzero_ps();
        accum512_4 = _mm512_setzero_ps();
        accum512_5 = _mm512_setzero_ps();
        accum512_6 = _mm512_setzero_ps();
        accum512_7 = _mm512_setzero_ps();

        for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
            // Broadcast x[idx_n] to every 16-bit lane
            xArray_0 = _mm512_set1_epi16(x[idx_n]);

            BF16_MATRIX_LOAD_1x32(matrixArray_seed_0, a, lda, idx_n, idx_m + 0)
            BF16_MATRIX_LOAD_1x32(matrixArray_seed_1, a, lda, idx_n, idx_m + 32)
            BF16_MATRIX_LOAD_1x32(matrixArray_seed_2, a, lda, idx_n, idx_m + 64)
            BF16_MATRIX_LOAD_1x32(matrixArray_seed_3, a, lda, idx_n, idx_m + 96)

            // Split each 32-element load into its even and odd elements (rest zeroed)
            matrixArray_0 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_0);
            matrixArray_1 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_0);
            matrixArray_2 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_1);
            matrixArray_3 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_1);
            matrixArray_4 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_2);
            matrixArray_5 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_2);
            matrixArray_6 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_3);
            matrixArray_7 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_3);

            BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
            BF16_DOT_1x32(accum512_1, matrixArray_1, xArray_0)
            BF16_DOT_1x32(accum512_2, matrixArray_2, xArray_0)
            BF16_DOT_1x32(accum512_3, matrixArray_3, xArray_0)
            BF16_DOT_1x32(accum512_4, matrixArray_4, xArray_0)
            BF16_DOT_1x32(accum512_5, matrixArray_5, xArray_0)
            BF16_DOT_1x32(accum512_6, matrixArray_6, xArray_0)
            BF16_DOT_1x32(accum512_7, matrixArray_7, xArray_0)
        }

        // Merge the even/odd accumulators back into natural row order
        accum512_8 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
        accum512_9 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
        accum512_10 = _mm512_permutex2var_ps(accum512_2, idx_base_0, accum512_3);
        accum512_11 = _mm512_permutex2var_ps(accum512_2, idx_base_1, accum512_3);
        accum512_12 = _mm512_permutex2var_ps(accum512_4, idx_base_0, accum512_5);
        accum512_13 = _mm512_permutex2var_ps(accum512_4, idx_base_1, accum512_5);
        accum512_14 = _mm512_permutex2var_ps(accum512_6, idx_base_0, accum512_7);
        accum512_15 = _mm512_permutex2var_ps(accum512_6, idx_base_1, accum512_7);

        STORE16_COMPLETE_RESULT(accum512_8, y+idx_m+0)
        STORE16_COMPLETE_RESULT(accum512_9, y+idx_m+16)
        STORE16_COMPLETE_RESULT(accum512_10, y+idx_m+32)
        STORE16_COMPLETE_RESULT(accum512_11, y+idx_m+48)
        STORE16_COMPLETE_RESULT(accum512_12, y+idx_m+64)
        STORE16_COMPLETE_RESULT(accum512_13, y+idx_m+80)
        STORE16_COMPLETE_RESULT(accum512_14, y+idx_m+96)
        STORE16_COMPLETE_RESULT(accum512_15, y+idx_m+112)
    }

    // ---- remaining whole 32-row groups ----
    for (BLASLONG idx_m = tag_m_128x; idx_m < tag_m_32x; idx_m+=32) {
        accum512_0 = _mm512_setzero_ps();
        accum512_1 = _mm512_setzero_ps();

        for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
            xArray_0 = _mm512_set1_epi16(x[idx_n]);

            BF16_MATRIX_LOAD_1x32(matrixArray_seed_0, a, lda, idx_n, idx_m)

            matrixArray_0 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_0);
            matrixArray_1 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_0);

            BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
            BF16_DOT_1x32(accum512_1, matrixArray_1, xArray_0)
        }

        accum512_8 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
        accum512_9 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);

        STORE16_COMPLETE_RESULT(accum512_8, y+idx_m+0)
        STORE16_COMPLETE_RESULT(accum512_9, y+idx_m+16)
    }

    // ---- tail: the last (m % 32) rows, 1..31 of them ----
    if (tag_m_32x != m) {
        // One bit per remaining row for the masked 32-element bf16 load
        __mmask32 tail_mask = (__mmask32) (0xffffffffU >> (32-(m&31)));
        // One bit per row of the partially filled 16-float store. This is a 16-lane mask,
        // so it is built directly as __mmask16 instead of reading a 16-bit object through
        // a 32-bit pointer. It is 0 when the tail is exactly 16 rows.
        __mmask16 store_tail_mask = (__mmask16) (0xffffU >> (16-(m&15)));

        accum512_0 = _mm512_setzero_ps();
        accum512_1 = _mm512_setzero_ps();

        for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
            xArray_0 = _mm512_set1_epi16(x[idx_n]);

            BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_seed_0, a, lda, idx_n, tag_m_32x, tail_mask)

            matrixArray_0 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_0);
            matrixArray_1 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_0);

            BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
            BF16_DOT_1x32(accum512_1, matrixArray_1, xArray_0)
        }

        accum512_8 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
        accum512_9 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);

        // ">=" (not ">"): with exactly 16 rows left the first 16 results need the full store;
        // store_tail_mask is 0 in that case, so the masked store that follows writes nothing.
        if ((m-tag_m_32x) >= 16) {
            STORE16_COMPLETE_RESULT(accum512_8, y+tag_m_32x+0)
            STORE16_MASK_COMPLETE_RESULT(accum512_9, y+tag_m_32x+16, store_tail_mask)
        } else {
            STORE16_MASK_COMPLETE_RESULT(accum512_8, y+tag_m_32x+0, store_tail_mask)
        }
    }

    return 0;
}

142
kernel/x86_64/sbgemv_t.c Normal file
View File

@@ -0,0 +1,142 @@
/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include "common.h"
#if defined (COOPERLAKE)
#include "sbgemv_t_microk_cooperlake.c"
#endif
// Allocate room for `alloc_size` elements of TYPE plus 63 bytes of slack.
//   ptr       receives the raw malloc() result; this is the pointer to hand to ALIGN64_FREE.
//   ptr_align receives the first 64-byte aligned address inside that allocation
//             (equal to ptr when malloc already returned an aligned block).
// Both are NULL when malloc fails; the caller is responsible for checking.
// The body is wrapped in do { } while (0) and every argument is parenthesised, so the macro
// behaves as one statement (e.g. under an unbraced `if`) and accepts expression arguments.
#define ALIGN64_ALLOC(alloc_size, TYPE, ptr_align, ptr) \
    do { \
        (ptr) = (TYPE *) malloc(sizeof(TYPE) * (alloc_size) + 63); \
        (ptr_align) = (((uintptr_t)(ptr) & (uintptr_t)0x3F) != 0) \
            ? (TYPE *)((char *)(ptr) + (64 - (int)((uintptr_t)(ptr) & (uintptr_t)0x3F))) \
            : (ptr); \
    } while (0)
// Release a buffer obtained with ALIGN64_ALLOC; pass the raw `ptr`, not the aligned pointer.
#define ALIGN64_FREE(ptr) \
    free(ptr)
#ifndef HAVE_SBGEMV_T_ACCL_KERNEL
// Generic (non-accelerated) kernel for the transposed case.
// As called from this file, m and n arrive already swapped: y has m elements, x has n
// elements, and output element i is the dot product of x with the n contiguous bf16
// values starting at a + lda * i:
//   y[i] = alpha * sum_j a[lda*i + j] * x[j] + beta * y[i]
// When beta == ZERO the previous content of y is not read.
// NOTE(review): the malloc results are not checked and the void return type gives no way
// to report an allocation failure to the caller.
static void sbgemv_kernel_t(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
{
    float * a_fp32 = malloc(sizeof(float)*m*n);
    float * x_fp32 = malloc(sizeof(float)*n);

    SBF16TOS_K(n, x, 1, x_fp32, 1);

    // Widen A to fp32 one lda-strided run at a time, reading straight from the source.
    // This packs the runs densely (stride n) without a separate bf16 staging copy.
    for (BLASLONG i=0; i<m; i++) {
        SBF16TOS_K(n, a + lda * i, 1, a_fp32 + n * i, 1);
    }

    for (BLASLONG i=0; i<m; i++) {
        BLASLONG offset_n = n * i;
        float accum = 0.0;
        for (BLASLONG j=0; j<n; j++) {
            accum += a_fp32[offset_n + j] * x_fp32[j];
        }
        if (beta == ZERO) {
            y[i] = alpha * accum;
        } else {
            y[i] = alpha * accum + beta * y[i];
        }
    }

    free(a_fp32);
    free(x_fp32);
}
#endif
// Gather n bf16 elements that lie `inc` apart in src into the contiguous buffer target.
static void bf16_compress_vector(BLASLONG n, bfloat16 * src, bfloat16 * target, BLASLONG inc)
{
    BLASLONG k = 0;
    while (k < n) {
        target[k] = src[k * inc];
        k++;
    }
}
// Gather n floats that lie `inc` apart in src into the contiguous buffer target.
static void fp32_compress_vector(BLASLONG n, float * src, float * target, BLASLONG inc)
{
    BLASLONG k = 0;
    while (k < n) {
        target[k] = src[k * inc];
        k++;
    }
}
// Scatter n contiguous floats from src into target, placing them `inc` elements apart.
static void fp32_expand_vector(BLASLONG n, float * src, float * target, BLASLONG inc)
{
    BLASLONG k = 0;
    while (k < n) {
        target[k * inc] = src[k];
        k++;
    }
}
// Entry point of the transposed bf16 GEMV.
// m and n are exchanged before any further work, so from then on x has n elements and
// y has m elements. sbgemv_kernel_t only handles unit-stride vectors, so a strided x is
// packed into a 64-byte aligned temporary, and a strided y is computed in a 64-byte aligned
// temporary and scattered back afterwards.
// Returns 0 in all cases; nothing is done when m < 1 or n < 1.
// NOTE(review): the temporaries come from ALIGN64_ALLOC and are not checked for NULL.
int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float * y, BLASLONG incy)
{
    bfloat16 *x_raw = NULL;    // raw allocation behind x_dense (only when incx != 1)
    float    *y_raw = NULL;    // raw allocation behind y_dense (only when incy != 1)
    bfloat16 *x_dense;
    float    *y_dense;
    BLASLONG  swap_tmp;

    if (!(m >= 1 && n >= 1)) {
        return 0;
    }

    // By default the kernel works directly on the caller's vectors.
    x_dense = x;
    y_dense = y;

    // Exchange m and n
    swap_tmp = m;
    m = n;
    n = swap_tmp;

    if (incx != 1) {
        ALIGN64_ALLOC(n, bfloat16, x_dense, x_raw);
        bf16_compress_vector(n, x, x_dense, incx);
    }

    if (incy != 1) {
        ALIGN64_ALLOC(m, float, y_dense, y_raw);
        // The old y values are only needed when beta takes part in the result.
        if (!(beta == ZERO)) {
            fp32_compress_vector(m, y, y_dense, incy);
        }
    }

    sbgemv_kernel_t(m, n, alpha, a, lda, x_dense, beta, y_dense);

    if (incy != 1) {
        fp32_expand_vector(m, y_dense, y, incy);
        ALIGN64_FREE(y_raw);
    }

    if (incx != 1) {
        ALIGN64_FREE(x_raw);
    }

    return 0;
}

View File

@@ -0,0 +1,202 @@
/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
/*
 * The accelerated kernels below are compiled only when the compiler is
 * GCC >= 10 with __AVX512BF16__ defined, or clang >= 9.
 * NOTE(review): the clang branch does not test __AVX512BF16__ -- confirm the
 * build always enables AVX512-BF16 code generation for this file under clang.
 */
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
// Marks that an accelerated sbgemv_kernel_t is provided here
// (presumably tested by the including file to skip its generic fallback -- confirm).
#define HAVE_SBGEMV_T_ACCL_KERNEL 1
// The micro kernel template is included four times; ZERO_BETA, ONE_BETA and
// ONE_ALPHA select which family of micro kernels each inclusion defines.
// NOTE(review): ONE_BETA is defined for the second inclusion and never
// undefined afterwards, so it is still defined for the third and fourth
// inclusions -- confirm the template gives ZERO_BETA precedence.
// Define micro kernels for ALPHA not ONE && BETA effective && BETA not ONE scenarios
#undef ZERO_BETA
#undef ONE_BETA
#undef ONE_ALPHA
#include "sbgemv_t_microk_cooperlake_template.c"
// Define micro kernels for ALPHA not ONE && BETA as ONE scenarios
#undef ZERO_BETA
#define ONE_BETA 1
#undef ONE_ALPHA
#include "sbgemv_t_microk_cooperlake_template.c"
// Define micro kernels for ALPHA not ONE && BETA ineffective (BETA == 0) scenarios
#define ZERO_BETA 1
#undef ONE_ALPHA
#include "sbgemv_t_microk_cooperlake_template.c"
// Define micro kernels for ALPHA as ONE && BETA ineffective (BETA == 0) scenarios
#define ZERO_BETA 1
#define ONE_ALPHA 1
#include "sbgemv_t_microk_cooperlake_template.c"
/*
 * Dispatcher for the accelerated transposed bf16 GEMV micro kernels.
 *
 * The kernel family is chosen from alpha and beta:
 *   beta == ZERO, alpha == ONE  -> kernels without suffix
 *   beta == ZERO, alpha != ONE  -> *_alpha kernels
 *   beta == ONE                 -> *_alpha_one kernels
 *   any other beta              -> *_alpha_beta kernels
 * Only the last two families receive beta as an argument.
 *
 * Within a family the kernel is chosen from n and lda:
 *   n > 127             -> 1x128_lda_direct
 *   32 < n <= 127       -> 8x32_lda_direct
 *   16 < n <= 32        -> 8x16p_lda
 *   n <= 16, lda != n   -> 8x16m_lda
 *   n <= 16, lda == n   -> the kernel dedicated to that n (1..16);
 *                          any other n calls nothing
 *
 * Always returns 0.
 */
static int sbgemv_kernel_t(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
{
    if (beta == ZERO && alpha == ONE) {
        // BETA == 0.0 and ALPHA == 1.0: no need to accumulate the original Y data or to multiply by ALPHA
        if (n > 127) {
            sbgemv_kernel_1x128_lda_direct(m, n, alpha, a, lda, x, y);
        } else if (n > 32) {
            sbgemv_kernel_8x32_lda_direct(m, n, alpha, a, lda, x, y);
        } else if (n > 16) {
            sbgemv_kernel_8x16p_lda(m, n, alpha, a, lda, x, y);
        } else if (lda != n) {
            sbgemv_kernel_8x16m_lda(m, n, alpha, a, lda, x, y);
        } else {
            switch (n) {
                case 1:  sbgemv_kernel_32x1 (m, alpha, a, x, y); break;
                case 2:  sbgemv_kernel_32x2 (m, alpha, a, x, y); break;
                case 3:  sbgemv_kernel_32x3 (m, alpha, a, x, y); break;
                case 4:  sbgemv_kernel_16x4 (m, alpha, a, x, y); break;
                case 5:  sbgemv_kernel_30x5 (m, alpha, a, x, y); break;
                case 6:  sbgemv_kernel_16x6 (m, alpha, a, x, y); break;
                case 7:  sbgemv_kernel_16x7 (m, alpha, a, x, y); break;
                case 8:  sbgemv_kernel_16x8 (m, alpha, a, x, y); break;
                case 9:  sbgemv_kernel_14x9 (m, alpha, a, x, y); break;
                case 10: sbgemv_kernel_12x10(m, alpha, a, x, y); break;
                case 11: sbgemv_kernel_15x11(m, alpha, a, x, y); break;
                case 12: sbgemv_kernel_15x12(m, alpha, a, x, y); break;
                case 13: sbgemv_kernel_16x13(m, alpha, a, x, y); break;
                case 14: sbgemv_kernel_16x14(m, alpha, a, x, y); break;
                case 15: sbgemv_kernel_16x15(m, alpha, a, x, y); break;
                case 16: sbgemv_kernel_16x16(m, alpha, a, x, y); break;
                default: break;
            }
        }
    } else if (beta == ZERO) {
        // BETA == 0.0 and ALPHA != 1.0: need to multiply by ALPHA
        if (n > 127) {
            sbgemv_kernel_1x128_lda_direct_alpha(m, n, alpha, a, lda, x, y);
        } else if (n > 32) {
            sbgemv_kernel_8x32_lda_direct_alpha(m, n, alpha, a, lda, x, y);
        } else if (n > 16) {
            sbgemv_kernel_8x16p_lda_alpha(m, n, alpha, a, lda, x, y);
        } else if (lda != n) {
            sbgemv_kernel_8x16m_lda_alpha(m, n, alpha, a, lda, x, y);
        } else {
            switch (n) {
                case 1:  sbgemv_kernel_32x1_alpha (m, alpha, a, x, y); break;
                case 2:  sbgemv_kernel_32x2_alpha (m, alpha, a, x, y); break;
                case 3:  sbgemv_kernel_32x3_alpha (m, alpha, a, x, y); break;
                case 4:  sbgemv_kernel_16x4_alpha (m, alpha, a, x, y); break;
                case 5:  sbgemv_kernel_30x5_alpha (m, alpha, a, x, y); break;
                case 6:  sbgemv_kernel_16x6_alpha (m, alpha, a, x, y); break;
                case 7:  sbgemv_kernel_16x7_alpha (m, alpha, a, x, y); break;
                case 8:  sbgemv_kernel_16x8_alpha (m, alpha, a, x, y); break;
                case 9:  sbgemv_kernel_14x9_alpha (m, alpha, a, x, y); break;
                case 10: sbgemv_kernel_12x10_alpha(m, alpha, a, x, y); break;
                case 11: sbgemv_kernel_15x11_alpha(m, alpha, a, x, y); break;
                case 12: sbgemv_kernel_15x12_alpha(m, alpha, a, x, y); break;
                case 13: sbgemv_kernel_16x13_alpha(m, alpha, a, x, y); break;
                case 14: sbgemv_kernel_16x14_alpha(m, alpha, a, x, y); break;
                case 15: sbgemv_kernel_16x15_alpha(m, alpha, a, x, y); break;
                case 16: sbgemv_kernel_16x16_alpha(m, alpha, a, x, y); break;
                default: break;
            }
        }
    } else if (beta == ONE) {
        // BETA == 1.0: need to accumulate the original Y data no matter what ALPHA is
        if (n > 127) {
            sbgemv_kernel_1x128_lda_direct_alpha_one(m, n, alpha, a, lda, x, beta, y);
        } else if (n > 32) {
            sbgemv_kernel_8x32_lda_direct_alpha_one(m, n, alpha, a, lda, x, beta, y);
        } else if (n > 16) {
            sbgemv_kernel_8x16p_lda_alpha_one(m, n, alpha, a, lda, x, beta, y);
        } else if (lda != n) {
            sbgemv_kernel_8x16m_lda_alpha_one(m, n, alpha, a, lda, x, beta, y);
        } else {
            switch (n) {
                case 1:  sbgemv_kernel_32x1_alpha_one (m, alpha, a, x, beta, y); break;
                case 2:  sbgemv_kernel_32x2_alpha_one (m, alpha, a, x, beta, y); break;
                case 3:  sbgemv_kernel_32x3_alpha_one (m, alpha, a, x, beta, y); break;
                case 4:  sbgemv_kernel_16x4_alpha_one (m, alpha, a, x, beta, y); break;
                case 5:  sbgemv_kernel_30x5_alpha_one (m, alpha, a, x, beta, y); break;
                case 6:  sbgemv_kernel_16x6_alpha_one (m, alpha, a, x, beta, y); break;
                case 7:  sbgemv_kernel_16x7_alpha_one (m, alpha, a, x, beta, y); break;
                case 8:  sbgemv_kernel_16x8_alpha_one (m, alpha, a, x, beta, y); break;
                case 9:  sbgemv_kernel_14x9_alpha_one (m, alpha, a, x, beta, y); break;
                case 10: sbgemv_kernel_12x10_alpha_one(m, alpha, a, x, beta, y); break;
                case 11: sbgemv_kernel_15x11_alpha_one(m, alpha, a, x, beta, y); break;
                case 12: sbgemv_kernel_15x12_alpha_one(m, alpha, a, x, beta, y); break;
                case 13: sbgemv_kernel_16x13_alpha_one(m, alpha, a, x, beta, y); break;
                case 14: sbgemv_kernel_16x14_alpha_one(m, alpha, a, x, beta, y); break;
                case 15: sbgemv_kernel_16x15_alpha_one(m, alpha, a, x, beta, y); break;
                case 16: sbgemv_kernel_16x16_alpha_one(m, alpha, a, x, beta, y); break;
                default: break;
            }
        }
    } else {
        // BETA is neither 0.0 nor 1.0: need to accumulate BETA times the original Y data
        if (n > 127) {
            sbgemv_kernel_1x128_lda_direct_alpha_beta(m, n, alpha, a, lda, x, beta, y);
        } else if (n > 32) {
            sbgemv_kernel_8x32_lda_direct_alpha_beta(m, n, alpha, a, lda, x, beta, y);
        } else if (n > 16) {
            sbgemv_kernel_8x16p_lda_alpha_beta(m, n, alpha, a, lda, x, beta, y);
        } else if (lda != n) {
            sbgemv_kernel_8x16m_lda_alpha_beta(m, n, alpha, a, lda, x, beta, y);
        } else {
            switch (n) {
                case 1:  sbgemv_kernel_32x1_alpha_beta (m, alpha, a, x, beta, y); break;
                case 2:  sbgemv_kernel_32x2_alpha_beta (m, alpha, a, x, beta, y); break;
                case 3:  sbgemv_kernel_32x3_alpha_beta (m, alpha, a, x, beta, y); break;
                case 4:  sbgemv_kernel_16x4_alpha_beta (m, alpha, a, x, beta, y); break;
                case 5:  sbgemv_kernel_30x5_alpha_beta (m, alpha, a, x, beta, y); break;
                case 6:  sbgemv_kernel_16x6_alpha_beta (m, alpha, a, x, beta, y); break;
                case 7:  sbgemv_kernel_16x7_alpha_beta (m, alpha, a, x, beta, y); break;
                case 8:  sbgemv_kernel_16x8_alpha_beta (m, alpha, a, x, beta, y); break;
                case 9:  sbgemv_kernel_14x9_alpha_beta (m, alpha, a, x, beta, y); break;
                case 10: sbgemv_kernel_12x10_alpha_beta(m, alpha, a, x, beta, y); break;
                case 11: sbgemv_kernel_15x11_alpha_beta(m, alpha, a, x, beta, y); break;
                case 12: sbgemv_kernel_15x12_alpha_beta(m, alpha, a, x, beta, y); break;
                case 13: sbgemv_kernel_16x13_alpha_beta(m, alpha, a, x, beta, y); break;
                case 14: sbgemv_kernel_16x14_alpha_beta(m, alpha, a, x, beta, y); break;
                case 15: sbgemv_kernel_16x15_alpha_beta(m, alpha, a, x, beta, y); break;
                case 16: sbgemv_kernel_16x16_alpha_beta(m, alpha, a, x, beta, y); break;
                default: break;
            }
        }
    }

    return 0;
}
#endif

File diff suppressed because it is too large Load Diff