Merge pull request #3416 from guowangy/spr-bf16
sbgemm: add AMX-BF16 based kernel for Sapphire Rapids
This commit is contained in: commit 8cbf61792d
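The kernels added below drive the Sapphire Rapids BF16 tile unit (AMX-BF16) through the _tile_* intrinsics from <immintrin.h>. As a quick orientation (this sketch is an editorial illustration, not part of the patch; it assumes a compiler invoked with -mamx-tile -mamx-bf16 and an OS that has already granted the process AMX tile state), the basic flow the new SBGEMMKERNEL builds on is: write a 64-byte tile configuration, load bf16 tiles of A and B, accumulate with _tile_dpbf16ps, and store the fp32 result:

#include <immintrin.h>
#include <stdint.h>
#include <string.h>

/* 64-byte tile configuration for palette 1 (same layout as the kernel's
 * tilecfg struct below: colsb is in bytes, rows in elements). */
typedef struct {
    uint8_t  palette_id;
    uint8_t  start_row;
    uint8_t  reserved[14];
    uint16_t colsb[16];
    uint8_t  rows[16];
} amx_tilecfg_t;

/* C[16x16] += A[16x32] * B, bf16 inputs, fp32 output.  B is assumed
 * pre-packed so that each of its 16 rows holds the two k-values
 * (2k, 2k+1) for every one of the 16 columns, which is the layout the
 * oncopy/otcopy routines in this commit produce. */
static void bf16_tile_block(const uint16_t *A, const uint16_t *B, float *C)
{
    amx_tilecfg_t cfg;
    memset(&cfg, 0, sizeof(cfg));
    cfg.palette_id = 1;
    cfg.rows[0] = 16; cfg.colsb[0] = 64;   /* A: 16 x 32 bf16 (64 bytes/row) */
    cfg.rows[1] = 16; cfg.colsb[1] = 64;   /* B: 16 x 16 bf16 pairs          */
    cfg.rows[2] = 16; cfg.colsb[2] = 64;   /* C: 16 x 16 fp32                */
    _tile_loadconfig(&cfg);

    _tile_zero(2);
    _tile_loadd(0, A, 64);                 /* strides are in bytes */
    _tile_loadd(1, B, 64);
    _tile_dpbf16ps(2, 0, 1);               /* C-tile += A-tile * B-tile */
    _tile_stored(2, C, 64);
    _tile_release();
}

On Linux the first tile instruction additionally requires requesting the XTILEDATA component via arch_prctl(ARCH_REQ_XCOMP_PERM, ...); that is assumed to happen during library initialization and is not part of the files in this commit.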
@@ -1 +1,14 @@
include $(KERNELDIR)/KERNEL.COOPERLAKE

SBGEMM_SMALL_M_PERMIT = sbgemm_small_kernel_permit_spr.c

SBGEMM_BETA = sgemm_beta_skylakex.c
SBGEMMKERNEL = sbgemm_kernel_16x16_spr.c
SBGEMMINCOPY = sbgemm_ncopy_16_cooperlake.c
SBGEMMITCOPY = sbgemm_tcopy_16_cooperlake.c
SBGEMMONCOPY = sbgemm_oncopy_16_spr.c
SBGEMMOTCOPY = sbgemm_otcopy_16_spr.c
SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX)
SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX)
SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX)
SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)
@@ -0,0 +1,50 @@
/***************************************************************************
 * Copyright (c) 2021, The OpenBLAS Project
 * All rights reserved.
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are
 * met:
 * 1. Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in
 * the documentation and/or other materials provided with the
 * distribution.
 * 3. Neither the name of the OpenBLAS project nor the names of
 * its contributors may be used to endorse or promote products
 * derived from this software without specific prior written permission.
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 * *****************************************************************************/

#include "common.h"

#define ALPHA_ONE
#include "sbgemm_kernel_16x16_spr_tmpl.c"
#undef ALPHA_ONE
#include "sbgemm_kernel_16x16_spr_tmpl.c"


int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOAT * iB, FLOAT * C, BLASLONG ldc)
{
    /* swap the operands to present a row-major problem, as AMX requires */
    BLASLONG m, n;
    IFLOAT *A, *B;
    m = in;
    n = im;
    A = iB;
    B = iA;

    if (alpha == 1.0f)
        return sbgemm_kernel_spr_alpha_one(m, n, k, alpha, A, B, C, ldc);
    else
        return sbgemm_kernel_spr_alpha(m, n, k, alpha, A, B, C, ldc);
}
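A note on the operand swap in the wrapper above: the BLAS convention hands the kernel a column-major C, while the AMX tiles operate on row-major data. Since C = A*B is equivalent to C^T = B^T * A^T, and a column-major C read row-by-row is exactly C^T, the wrapper simply exchanges the roles of A and B and of m and n; the row-major kernel then writes C^T into the same storage and no explicit transposition of C is needed.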
@@ -0,0 +1,530 @@
/***************************************************************************
 * Copyright (c) 2021, The OpenBLAS Project
 * All rights reserved.
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are
 * met:
 * 1. Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in
 * the documentation and/or other materials provided with the
 * distribution.
 * 3. Neither the name of the OpenBLAS project nor the names of
 * its contributors may be used to endorse or promote products
 * derived from this software without specific prior written permission.
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 * *****************************************************************************/

#include <immintrin.h>
#include <string.h>
#include "common.h"

#ifndef SBGEMM_KERNEL_SPR
#define SBGEMM_KERNEL_SPR
typedef struct {
    char palette_id;
    char start_row;
    char dummy0[14];   // bytes 2-15 reserved, must be zero
    short tile_colsb[8];
    char dummy1[16];   // bytes 32-47 reserved, must be zero
    char tile_rows[8];
    char dummy2[16];   // bytes 56-63 reserved, must be zero
} tilecfg;

/* tile0/tile1 -- A (m x 2k)
 * tile2/tile3 -- B (2k x n)
 * tile4-7     -- C (m x n)
 */
#define TCONF(cfg, m, n, k2) \
    memset(&cfg, 0, sizeof(tilecfg)); \
    cfg.palette_id = 1; \
    cfg.tile_rows[0] = m; \
    cfg.tile_rows[1] = m; \
    cfg.tile_rows[2] = k2>>1; \
    cfg.tile_rows[3] = k2>>1; \
    cfg.tile_rows[4] = m; \
    cfg.tile_rows[5] = m; \
    cfg.tile_rows[6] = m; \
    cfg.tile_rows[7] = m; \
    cfg.tile_colsb[0] = k2<<1; \
    cfg.tile_colsb[1] = k2<<1; \
    cfg.tile_colsb[2] = n * 4; \
    cfg.tile_colsb[3] = n * 4; \
    cfg.tile_colsb[4] = n * 4; \
    cfg.tile_colsb[5] = n * 4; \
    cfg.tile_colsb[6] = n * 4; \
    cfg.tile_colsb[7] = n * 4; \
    _tile_loadconfig(&cfg);

/* config for handling the even part of k and the odd tail at the same time
 * tile0 -- A (m x 2k)
 * tile1 -- A (m x 1)
 * tile2 -- B (2k x n)
 * tile3 -- B (1 x n)
 * tile4 -- C (m x n)
 */
#define TCONF_TAIL(cfg, m, n, k2) \
    memset(&cfg, 0, sizeof(tilecfg)); \
    cfg.palette_id = 1; \
    cfg.tile_rows[0] = m; \
    cfg.tile_rows[1] = m; \
    cfg.tile_rows[2] = k2>>1; \
    cfg.tile_rows[3] = 1; \
    cfg.tile_rows[4] = m; \
    cfg.tile_colsb[0] = k2<<1; \
    cfg.tile_colsb[1] = 4; \
    cfg.tile_colsb[2] = n * 4; \
    cfg.tile_colsb[3] = n * 4; \
    cfg.tile_colsb[4] = n * 4; \
    _tile_loadconfig(&cfg);

#define T_A0  0
#define T_A1  1
#define T_B0  2
#define T_B1  3
#define T_C00 4
#define T_C01 5
#define T_C10 6
#define T_C11 7

// FIXME: gcc 11 seems to have a problem with tile load/store address
// calculation; the stride must be multiplied by the element size (2 or 4) here.
#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2)
#define LOAD_A_TAIL(M, N) {\
    __m256i ymm = _mm256_loadu_epi16(ptr_a##M); \
    __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
    _mm512_storeu_epi16(tail_a + 16 * M, zmm); \
    _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
}
#define MASK_LOAD_A_TAIL(M, N) {\
    __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \
    __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
    _mm512_storeu_epi16(tail_a + 16 * M, zmm); \
    _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
}
#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2)
#define LOAD_B_TAIL(M, N) {\
    __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \
    __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
    _mm512_storeu_epi16(tail_b + 16 * N, zmm); \
    _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \
}
#define MASK_LOAD_B_TAIL(M, N) {\
    __m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \
    __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
    _mm512_storeu_epi16(tail_b + 16 * N, zmm); \
    _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \
}

#define MATMUL(M, N)      _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N)
#define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N)
#define STORE_C(M, N)     _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4)
#define LOAD_C_F(M, N)    _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4)

#endif  // end of SBGEMM_KERNEL_SPR

#ifdef ALPHA_ONE
#undef LOAD_C
#define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4)
#else
#undef LOAD_C
#define LOAD_C(M, N) _tile_zero(T_C##M##N)
#define ALPHA_STORE(N) \
    __m512 zmm_d##N = _mm512_loadu_ps(dst##N + noffset); \
    __m512 zmm_s##N = _mm512_loadu_ps(src##N + noffset); \
    zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \
    _mm512_storeu_ps(dst##N + noffset, zmm_d##N);
#define MASK_APLPHA_STORE(N) \
    __m512 zmm_d##N = _mm512_maskz_loadu_ps(mask, dst##N + noffset); \
    __m512 zmm_s##N = _mm512_maskz_loadu_ps(mask, src##N + noffset); \
    zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \
    _mm512_mask_storeu_ps(dst##N + noffset, mask, zmm_d##N);
#endif  // end of ALPHA_ONE


#ifdef ALPHA_ONE
int sbgemm_kernel_spr_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc)
#else
int sbgemm_kernel_spr_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc)
#endif
{
    /* operands are already presented row-major, as AMX requires */
    IFLOAT *ptr_a = A, *ptr_b = B;
    IFLOAT *ptr_b0, *ptr_b1;
    IFLOAT *ptr_a0, *ptr_a1;
    FLOAT *ptr_c = C;
    FLOAT *ptr_c00, *ptr_c01, *ptr_c10, *ptr_c11;

    BLASLONG lda, ldb;
    BLASLONG m_count = m;
    BLASLONG n_count, k_count;

#ifndef ALPHA_ONE
    // make sure each row of the scratch C buffer is 64-byte aligned
    BLASLONG cn = (n & 31) ? (n & ~31) + 32 : n;
    FLOAT *raw_tmp_c;
    if (k < 32) {
        // the buffer only needs zeroing in this case
        raw_tmp_c = (FLOAT *)calloc(1, sizeof(FLOAT) * m * cn + 64);
    } else {
        raw_tmp_c = (FLOAT *)malloc(sizeof(FLOAT) * m * cn + 64);
    }
    // align the buffer to a 64-byte boundary
    FLOAT *tmp_c = (FLOAT *)(((uintptr_t) raw_tmp_c + 63) & ~(uintptr_t)63);
    ptr_c = tmp_c;
    BLASLONG ldc_o = ldc;
    ldc = cn;
#endif
    IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64)));
    IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64)));
    tilecfg cfg;

    if (k > 31) {
        for (; m_count > 31; m_count -= 32) {
            ptr_b = B;

            ptr_c00 = ptr_c;
            ptr_c01 = ptr_c00 + 16;
            ptr_c10 = ptr_c + 16 * ldc;
            ptr_c11 = ptr_c10 + 16;
            ptr_c += 32 * ldc;
            n_count = n;
            TCONF(cfg, 16, 16, 32);
            for (; n_count > 31; n_count -= 32) {
                ptr_a0 = ptr_a;
                ptr_a1 = ptr_a + 16 * k;

                ptr_b0 = ptr_b;
                ptr_b1 = ptr_b + 16 * k;
                ptr_b += 32 * k;

                lda = 32;
                ldb = 32;
                LOAD_C(0, 0); LOAD_C(0, 1);
                LOAD_C(1, 0); LOAD_C(1, 1);
                k_count = k;
                for (; k_count > 31; k_count -= 32) {
                    LOAD_A(0, x); LOAD_A(1, x);
                    ptr_a0 += 16 * 32;
                    ptr_a1 += 16 * 32;
                    LOAD_B(x, 0); LOAD_B(x, 1);
                    ptr_b0 += 16 * 32;
                    ptr_b1 += 16 * 32;

                    MATMUL(0, 0); MATMUL(0, 1);
                    MATMUL(1, 0); MATMUL(1, 1);
                }
                STORE_C(0, 0); STORE_C(0, 1);
                STORE_C(1, 0); STORE_C(1, 1);
                ptr_c00 += 32;
                ptr_c01 += 32;
                ptr_c10 += 32;
                ptr_c11 += 32;
            }
            for (; n_count > 0; n_count -= 16) {
                int tail_n = (n_count > 16) ? 16 : n_count;
                ptr_a0 = ptr_a;
                ptr_a1 = ptr_a + 16 * k;

                ptr_b0 = ptr_b;
                ptr_b += tail_n * k;

                lda = 32;
                ldb = 2 * tail_n;
                TCONF(cfg, 16, tail_n, 32);
                LOAD_C(0, 0);
                LOAD_C(1, 0);
                k_count = k;
                for (; k_count > 31; k_count -= 32) {
                    LOAD_A(0, x); LOAD_A(1, x);
                    ptr_a0 += 16 * 32;
                    ptr_a1 += 16 * 32;
                    LOAD_B(x, 0);
                    ptr_b0 += tail_n * 32;

                    MATMUL(0, 0);
                    MATMUL(1, 0);
                }
                STORE_C(0, 0);
                STORE_C(1, 0);
                ptr_c00 += tail_n;
                ptr_c10 += tail_n;
            }
            ptr_a += 32 * k;
        }
        for (; m_count > 0; m_count -= 16) {
            // process at most 16 rows of m at a time
            int tail_m = (m_count > 16) ? 16 : m_count;

            ptr_b = B;

            ptr_c00 = ptr_c;
            ptr_c01 = ptr_c00 + 16;
            ptr_c += tail_m * ldc;
            n_count = n;
            TCONF(cfg, tail_m, 16, 32);
            for (; n_count > 31; n_count -= 32) {
                ptr_a0 = ptr_a;

                ptr_b0 = ptr_b;
                ptr_b1 = ptr_b + 16 * k;
                ptr_b += 32 * k;

                lda = 32;
                ldb = 32;
                LOAD_C(0, 0); LOAD_C(0, 1);
                k_count = k;
                for (; k_count > 31; k_count -= 32) {
                    LOAD_A(0, x);
                    ptr_a0 += tail_m * 32;
                    LOAD_B(x, 0); LOAD_B(x, 1);
                    ptr_b0 += 16 * 32;
                    ptr_b1 += 16 * 32;

                    MATMUL(0, 0); MATMUL(0, 1);
                }
                STORE_C(0, 0); STORE_C(0, 1);
                ptr_c00 += 32;
                ptr_c01 += 32;
            }
            for (; n_count > 0; n_count -= 16) {
                int tail_n = (n_count > 16) ? 16 : n_count;
                ptr_a0 = ptr_a;

                ptr_b0 = ptr_b;
                ptr_b += tail_n * k;

                lda = 32;
                ldb = 2 * tail_n;
                TCONF(cfg, tail_m, tail_n, 32);
                LOAD_C(0, 0);
                k_count = k;
                for (; k_count > 31; k_count -= 32) {
                    LOAD_A(0, x);
                    ptr_a0 += tail_m * 32;
                    LOAD_B(x, 0);
                    ptr_b0 += tail_n * 32;

                    MATMUL(0, 0);
                }
                STORE_C(0, 0);
                ptr_c00 += tail_n;
            }
            ptr_a += tail_m * k;
        }
    }

    // handle the remaining k < 32 tail
    BLASLONG k32 = k & ~31;
    BLASLONG k2 = k & ~1;
    if (k32 != k) {
        int remain_k2 = k2 - k32;
        m_count = m;
        ptr_a = A;
#ifndef ALPHA_ONE
        ptr_c = tmp_c;
#else
        ptr_c = C;
#endif
        if (remain_k2 > 0 && k2 != k) {  // k%32 = 2x + 1 (x != 0)
            for (; m_count > 0; m_count -= 16) {
                int tail_m = (m_count > 16) ? 16 : m_count;
                __mmask16 amask = (1UL << tail_m) - 1;

                ptr_a0 = ptr_a + tail_m * k32;
                ptr_a1 = ptr_a + tail_m * k2;
                ptr_a += tail_m * k;
                ptr_b = B;
                ptr_c00 = ptr_c;
                ptr_c += tail_m * ldc;
                n_count = n;
                lda = remain_k2;
                ldb = 32;
                if (n_count > 15) {
                    TCONF_TAIL(cfg, tail_m, 16, remain_k2);
                    LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
                    for (; n_count > 15; n_count -= 16) {
                        ptr_b0 = ptr_b + 16 * k32;
                        ptr_b1 = ptr_b + 16 * k2;
                        LOAD_C_F(0, 0);
                        LOAD_B(x, 0); LOAD_B_TAIL(x, 1);
                        MATMUL(0, 0); MATMUL_TAIL(1, 1);
                        STORE_C(0, 0);
                        ptr_b += 16 * k;
                        ptr_c00 += 16;
                    }
                }
                if (n_count > 0) {
                    int tail_n = (n_count > 16) ? 16 : n_count;
                    __mmask16 bmask = (1UL << tail_n) - 1;
                    ptr_b0 = ptr_b + tail_n * k32;
                    ptr_b1 = ptr_b + tail_n * k2;
                    ldb = 2 * tail_n;
                    TCONF_TAIL(cfg, tail_m, tail_n, remain_k2);
                    LOAD_C_F(0, 0);
                    LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
                    LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1);
                    MATMUL(0, 0); MATMUL_TAIL(1, 1);
                    STORE_C(0, 0);
                }
            }

        } else if (remain_k2 > 0) {  // k%32 = 2x
            for (; m_count > 0; m_count -= 16) {
                int tail_m = (m_count > 16) ? 16 : m_count;

                ptr_a0 = ptr_a + tail_m * k32;
                ptr_a += tail_m * k;
                ptr_b = B;
                ptr_c00 = ptr_c;
                ptr_c += tail_m * ldc;
                n_count = n;
                lda = remain_k2;
                ldb = 32;
                if (n_count > 15) {
                    TCONF(cfg, tail_m, 16, remain_k2);
                    LOAD_A(0, x);
                    for (; n_count > 15; n_count -= 16) {
                        ptr_b0 = ptr_b + 16 * k32;
                        LOAD_C_F(0, 0);
                        LOAD_B(x, 0);
                        MATMUL(0, 0);
                        STORE_C(0, 0);
                        ptr_b += 16 * k;
                        ptr_c00 += 16;
                    }
                }
                if (n_count > 0) {
                    int tail_n = (n_count > 16) ? 16 : n_count;
                    ptr_b0 = ptr_b + tail_n * k32;
                    ldb = 2 * tail_n;
                    TCONF(cfg, tail_m, tail_n, remain_k2);
                    LOAD_C_F(0, 0);
                    LOAD_A(0, x);
                    LOAD_B(x, 0);
                    MATMUL(0, 0);
                    STORE_C(0, 0);
                }
            }
        } else {  // k%32 = 1
            for (; m_count > 0; m_count -= 16) {
                int tail_m = (m_count > 16) ? 16 : m_count;
                __mmask16 amask = (1UL << tail_m) - 1;

                ptr_a0 = ptr_a + tail_m * k2;
                ptr_a += tail_m * k;
                ptr_b = B;
                ptr_c00 = ptr_c;
                ptr_c += tail_m * ldc;
                n_count = n;
                if (n_count > 15) {
                    TCONF(cfg, tail_m, 16, 2);
                    MASK_LOAD_A_TAIL(0, x);
                    for (; n_count > 15; n_count -= 16) {
                        ptr_b0 = ptr_b + 16 * k2;
                        LOAD_C_F(0, 0);
                        LOAD_B_TAIL(x, 0);
                        MATMUL(0, 0);
                        STORE_C(0, 0);
                        ptr_b += 16 * k;
                        ptr_c00 += 16;
                    }
                }
                if (n_count > 0) {
                    int tail_n = (n_count > 16) ? 16 : n_count;
                    __mmask16 bmask = (1UL << tail_n) - 1;
                    ptr_b0 = ptr_b + tail_n * k2;
                    TCONF(cfg, tail_m, tail_n, 2);
                    LOAD_C_F(0, 0);
                    MASK_LOAD_A_TAIL(0, x);
                    MASK_LOAD_B_TAIL(x, 0);
                    MATMUL(0, 0);
                    STORE_C(0, 0);
                }
            }

        }
    }
#ifndef ALPHA_ONE
    __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha));
    BLASLONG n16 = n & ~15;
    BLASLONG noffset;
    FLOAT *src0, *src1, *src2, *src3;
    FLOAT *dst0, *dst1, *dst2, *dst3;
    FLOAT *src = tmp_c;
    FLOAT *dst = C;
    m_count = m;
    for (; m_count > 3; m_count -= 4) {
        src0 = src;
        src1 = src0 + ldc;
        src2 = src1 + ldc;
        src3 = src2 + ldc;
        src += 4 * ldc;

        dst0 = dst;
        dst1 = dst0 + ldc_o;
        dst2 = dst1 + ldc_o;
        dst3 = dst2 + ldc_o;
        dst += 4 * ldc_o;

        noffset = 0;
        for (; noffset < n16; noffset += 16) {
            ALPHA_STORE(0);
            ALPHA_STORE(1);
            ALPHA_STORE(2);
            ALPHA_STORE(3);
        }
        if (noffset < n) {
            __mmask16 mask = (1UL << (n - noffset)) - 1;
            MASK_APLPHA_STORE(0);
            MASK_APLPHA_STORE(1);
            MASK_APLPHA_STORE(2);
            MASK_APLPHA_STORE(3);
        }
    }
    for (; m_count > 1; m_count -= 2) {
        src0 = src;
        src1 = src0 + ldc;
        src += 2 * ldc;

        dst0 = dst;
        dst1 = dst0 + ldc_o;
        dst += 2 * ldc_o;

        noffset = 0;
        for (; noffset < n16; noffset += 16) {
            ALPHA_STORE(0);
            ALPHA_STORE(1);
        }
        if (noffset < n) {
            __mmask16 mask = (1UL << (n - noffset)) - 1;
            MASK_APLPHA_STORE(0);
            MASK_APLPHA_STORE(1);
        }
    }
    for (; m_count > 0; m_count -= 1) {
        src0 = src;
        dst0 = dst;
        noffset = 0;
        for (; noffset < n16; noffset += 16) {
            ALPHA_STORE(0);
        }
        if (noffset < n) {
            __mmask16 mask = (1UL << (n - noffset)) - 1;
            MASK_APLPHA_STORE(0);
        }
    }
    free(raw_tmp_c);
#endif
    return 0;
}
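For orientation, the template above is compiled twice by the wrapper: once with ALPHA_ONE defined and once without. With ALPHA_ONE, LOAD_C loads the existing C tile so the tile unit accumulates directly into C. Without it, LOAD_C zeroes the tile, the products are written to the 64-byte-aligned tmp_c scratch buffer, and the ALPHA_STORE / MASK_APLPHA_STORE loops at the end fold the scratch back into C as C += alpha * tmp_c with AVX-512 FMAs. A simplified, self-contained sketch of that fold (the helper name is illustrative, not something defined in the kernel):

#include <immintrin.h>

/* c[i] += alpha * t[i] over one row of n floats; a condensed version of
 * what ALPHA_STORE and MASK_APLPHA_STORE do 16 floats at a time. */
static void fold_alpha_row(float *c, const float *t, float alpha, long n)
{
    __m512 valpha = _mm512_set1_ps(alpha);
    long i = 0;
    for (; i + 16 <= n; i += 16) {
        __m512 vc = _mm512_loadu_ps(c + i);
        __m512 vt = _mm512_loadu_ps(t + i);
        _mm512_storeu_ps(c + i, _mm512_fmadd_ps(valpha, vt, vc));
    }
    if (i < n) {  /* masked tail, as in the kernel's MASK_ variant */
        __mmask16 mask = (__mmask16)((1U << (n - i)) - 1);
        __m512 vc = _mm512_maskz_loadu_ps(mask, c + i);
        __m512 vt = _mm512_maskz_loadu_ps(mask, t + i);
        _mm512_mask_storeu_ps(c + i, mask, _mm512_fmadd_ps(valpha, vt, vc));
    }
}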
@@ -0,0 +1,128 @@
/***************************************************************************
 * Copyright (c) 2021, The OpenBLAS Project
 * All rights reserved.
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are
 * met:
 * 1. Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in
 * the documentation and/or other materials provided with the
 * distribution.
 * 3. Neither the name of the OpenBLAS project nor the names of
 * its contributors may be used to endorse or promote products
 * derived from this software without specific prior written permission.
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 * *****************************************************************************/

#include <immintrin.h>
#include "common.h"

typedef struct {
    char palette_id;
    char start_row;
    char dummy0[14];   // bytes 2-15 reserved, must be zero
    short tile_colsb[8];
    char dummy1[16];   // bytes 32-47 reserved, must be zero
    char tile_rows[8];
    char dummy2[16];   // bytes 56-63 reserved, must be zero
} tilecfg;

#define T_16x32 0
#define T_16xm  1
#define T_nx32  2
#define T_nxm   3

#define TCONF(cfg, m, n) \
    memset(&cfg, 0, sizeof(tilecfg)); \
    cfg.palette_id = 1; \
    cfg.tile_rows[T_16x32] = 16; \
    cfg.tile_colsb[T_16x32] = 64; \
    if (m) { \
        cfg.tile_rows[T_16xm] = 16; \
        cfg.tile_colsb[T_16xm] = m * 2; \
    } \
    if (n) { \
        cfg.tile_rows[T_nx32] = n; \
        cfg.tile_colsb[T_nx32] = 64; \
    } \
    if (m && n) { \
        cfg.tile_rows[T_nxm] = n; \
        cfg.tile_colsb[T_nxm] = m * 2; \
    } \
    _tile_loadconfig(&cfg);


int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) {
    BLASLONG i, j;
    IFLOAT *aoffset, *boffset;
    IFLOAT *aoffset0;

    aoffset = a;
    boffset = b;

    BLASLONG n16 = n & ~15;
    BLASLONG m32 = m & ~31;
    BLASLONG m2 = m & ~1;

    BLASLONG tail_m = m2 - m32;
    BLASLONG tail_n = n - n16;
    tilecfg cfg;
    TCONF(cfg, tail_m, tail_n);

    for (j = 0; j < n16; j += 16) {
        aoffset0 = aoffset;
        for (i = 0; i < m32; i += 32) {
            _tile_loadd(T_16x32, aoffset0, lda * 2);
            _tile_stored(T_16x32, boffset, 32 * 2);
            aoffset0 += 32;
            boffset += 32 * 16;
        }
        if (i < m2) {
            _tile_loadd(T_16xm, aoffset0, lda * 2);
            _tile_stored(T_16xm, boffset, tail_m * 2);
            aoffset0 += tail_m;
            boffset += tail_m * 16;
            i = m2;
        }
        if (i < m) {
            /* the odd-k tail row is packed on its own */
            for (int ii = 0; ii < 16; ii++) {
                *(boffset + ii) = *(aoffset0 + lda * ii);
            }
            boffset += 16;
        }
        aoffset += 16 * lda;
    }
    if (j < n) {
        aoffset0 = aoffset;
        for (i = 0; i < m32; i += 32) {
            _tile_loadd(T_nx32, aoffset0, lda * 2);
            _tile_stored(T_nx32, boffset, 32 * 2);
            aoffset0 += 32;
            boffset += 32 * tail_n;
        }
        if (i < m2) {
            _tile_loadd(T_nxm, aoffset0, lda * 2);
            _tile_stored(T_nxm, boffset, tail_m * 2);
            aoffset0 += tail_m;
            boffset += tail_m * tail_n;
        }
        if (i < m) {
            for (int ii = 0; ii < tail_n; ii++) {
                *(boffset + ii) = *(aoffset0 + lda * ii);
            }
        }
    }
    return 0;
}
@@ -0,0 +1,302 @@
/***************************************************************************
 * Copyright (c) 2021, The OpenBLAS Project
 * All rights reserved.
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are
 * met:
 * 1. Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in
 * the documentation and/or other materials provided with the
 * distribution.
 * 3. Neither the name of the OpenBLAS project nor the names of
 * its contributors may be used to endorse or promote products
 * derived from this software without specific prior written permission.
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 * *****************************************************************************/

#include <immintrin.h>
#include "common.h"

#define LOAD_A_8VEC(aptr) \
    r0 = _mm256_loadu_si256((__m256i *)(aptr + lda*0)); \
    r1 = _mm256_loadu_si256((__m256i *)(aptr + lda*1)); \
    r2 = _mm256_loadu_si256((__m256i *)(aptr + lda*2)); \
    r3 = _mm256_loadu_si256((__m256i *)(aptr + lda*3)); \
    r4 = _mm256_loadu_si256((__m256i *)(aptr + lda*4)); \
    r5 = _mm256_loadu_si256((__m256i *)(aptr + lda*5)); \
    r6 = _mm256_loadu_si256((__m256i *)(aptr + lda*6)); \
    r7 = _mm256_loadu_si256((__m256i *)(aptr + lda*7));

#define MASK_LOAD_A_8VEC(aptr) \
    r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*0)); \
    r1 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*1)); \
    r2 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*2)); \
    r3 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*3)); \
    r4 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*4)); \
    r5 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*5)); \
    r6 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*6)); \
    r7 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*7));

#define SWITCH_LOAD_A_8VEC(aptr, cond) \
    switch((cond)) { \
    case 8: r7 = _mm256_loadu_si256((__m256i *)(aptr + lda*7)); \
    case 7: r6 = _mm256_loadu_si256((__m256i *)(aptr + lda*6)); \
    case 6: r5 = _mm256_loadu_si256((__m256i *)(aptr + lda*5)); \
    case 5: r4 = _mm256_loadu_si256((__m256i *)(aptr + lda*4)); \
    case 4: r3 = _mm256_loadu_si256((__m256i *)(aptr + lda*3)); \
    case 3: r2 = _mm256_loadu_si256((__m256i *)(aptr + lda*2)); \
    case 2: r1 = _mm256_loadu_si256((__m256i *)(aptr + lda*1)); \
    case 1: r0 = _mm256_loadu_si256((__m256i *)(aptr + lda*0)); \
    }

#define SWITCH_MASK_LOAD_A_8VEC(aptr, cond) \
    switch((cond)) { \
    case 8: r7 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*7)); \
    case 7: r6 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*6)); \
    case 6: r5 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*5)); \
    case 5: r4 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*4)); \
    case 4: r3 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*3)); \
    case 3: r2 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*2)); \
    case 2: r1 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*1)); \
    case 1: r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*0)); \
    }

#define REORDER_8x16(t0, t1, t2, t3, t4, t5, t6, t7) \
    t0 = _mm256_unpacklo_epi16(r0, r1); \
    t1 = _mm256_unpackhi_epi16(r0, r1); \
    t2 = _mm256_unpacklo_epi16(r2, r3); \
    t3 = _mm256_unpackhi_epi16(r2, r3); \
    t4 = _mm256_unpacklo_epi16(r4, r5); \
    t5 = _mm256_unpackhi_epi16(r4, r5); \
    t6 = _mm256_unpacklo_epi16(r6, r7); \
    t7 = _mm256_unpackhi_epi16(r6, r7); \
    r0 = _mm256_unpacklo_epi32(t0, t2); \
    r1 = _mm256_unpacklo_epi32(t1, t3); \
    r2 = _mm256_unpacklo_epi32(t4, t6); \
    r3 = _mm256_unpacklo_epi32(t5, t7); \
    r4 = _mm256_unpackhi_epi32(t0, t2); \
    r5 = _mm256_unpackhi_epi32(t1, t3); \
    r6 = _mm256_unpackhi_epi32(t4, t6); \
    r7 = _mm256_unpackhi_epi32(t5, t7); \
    t0 = _mm256_unpacklo_epi64(r0, r2); \
    t1 = _mm256_unpackhi_epi64(r0, r2); \
    t2 = _mm256_unpacklo_epi64(r4, r6); \
    t3 = _mm256_unpackhi_epi64(r4, r6); \
    t4 = _mm256_unpacklo_epi64(r1, r3); \
    t5 = _mm256_unpackhi_epi64(r1, r3); \
    t6 = _mm256_unpacklo_epi64(r5, r7); \
    t7 = _mm256_unpackhi_epi64(r5, r7);

#define STORE_256_LO(x) \
    v = _mm256_permute2x128_si256(t0##x, t1##x, 0x20); \
    _mm256_storeu_si256((__m256i *)(boffset + x*32), v);

#define STORE_256_HI(x) \
    v = _mm256_permute2x128_si256(t0##x, t1##x, 0x31); \
    _mm256_storeu_si256((__m256i *)(boffset + (x + 8)*32), v);

#define MASK_STORE_256_LO(x) \
    v = _mm256_permute2x128_si256(t0##x, t1##x, 0x20); \
    _mm256_mask_storeu_epi16(boffset + x*m_load, mmask, v);

#define MASK_STORE_256_HI(x) \
    v = _mm256_permute2x128_si256(t0##x, t1##x, 0x31); \
    _mm256_mask_storeu_epi16(boffset + (x + 8)*m_load, mmask, v);

#define STORE_256(x, y) {\
    __m256i v; \
    if (x == 0) { STORE_256_LO(y); } \
    else { STORE_256_HI(y); } \
}

#define MASK_STORE_256(x, y) {\
    __m256i v; \
    if (x == 0) { MASK_STORE_256_LO(y); } \
    else { MASK_STORE_256_HI(y); } \
}

#define SWITCH_STORE_16x(cond, func) \
    switch((cond)) {\
    case 15: func(1, 6); \
    case 14: func(1, 5); \
    case 13: func(1, 4); \
    case 12: func(1, 3); \
    case 11: func(1, 2); \
    case 10: func(1, 1); \
    case 9:  func(1, 0); \
    case 8:  func(0, 7); \
    case 7:  func(0, 6); \
    case 6:  func(0, 5); \
    case 5:  func(0, 4); \
    case 4:  func(0, 3); \
    case 3:  func(0, 2); \
    case 2:  func(0, 1); \
    case 1:  func(0, 0); \
    }


int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) {
    IFLOAT *aoffset, *boffset;
    IFLOAT *aoffset00, *aoffset01, *aoffset10, *aoffset11;
    IFLOAT *boffset0;

    __m256i r0, r1, r2, r3, r4, r5, r6, r7;
    __m256i t00, t01, t02, t03, t04, t05, t06, t07;
    __m256i t10, t11, t12, t13, t14, t15, t16, t17;

    aoffset = a;
    boffset = b;
    BLASLONG n_count = n;
    BLASLONG m_count = m;
    for (; n_count > 15; n_count -= 16) {
        aoffset00 = aoffset;
        aoffset01 = aoffset00 + 8 * lda;
        aoffset10 = aoffset01 + 8 * lda;
        aoffset11 = aoffset10 + 8 * lda;
        aoffset += 16;
        m_count = m;
        for (; m_count > 31; m_count -= 32) {
            // first 16 rows
            LOAD_A_8VEC(aoffset00);
            REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
            LOAD_A_8VEC(aoffset01);
            REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
            STORE_256(0, 0); STORE_256(0, 1); STORE_256(0, 2); STORE_256(0, 3);
            STORE_256(0, 4); STORE_256(0, 5); STORE_256(0, 6); STORE_256(0, 7);
            STORE_256(1, 0); STORE_256(1, 1); STORE_256(1, 2); STORE_256(1, 3);
            STORE_256(1, 4); STORE_256(1, 5); STORE_256(1, 6); STORE_256(1, 7);
            // last 16 rows
            boffset += 16;
            LOAD_A_8VEC(aoffset10);
            REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
            LOAD_A_8VEC(aoffset11);
            REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
            STORE_256(0, 0); STORE_256(0, 1); STORE_256(0, 2); STORE_256(0, 3);
            STORE_256(0, 4); STORE_256(0, 5); STORE_256(0, 6); STORE_256(0, 7);
            STORE_256(1, 0); STORE_256(1, 1); STORE_256(1, 2); STORE_256(1, 3);
            STORE_256(1, 4); STORE_256(1, 5); STORE_256(1, 6); STORE_256(1, 7);
            aoffset00 += 32 * lda;
            aoffset01 += 32 * lda;
            aoffset10 += 32 * lda;
            aoffset11 += 32 * lda;
            boffset += 31 * 16;
        }
        if (m_count > 1) {
            int m_load = m_count & ~1;
            m_count -= m_load;
            __mmask16 mmask;
            SWITCH_LOAD_A_8VEC(aoffset00, m_load > 8 ? 8 : m_load);
            REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
            if (m_load > 8) {
                SWITCH_LOAD_A_8VEC(aoffset01, m_load > 16 ? 8 : m_load - 8);
                REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
            }
            int this_load = m_load > 16 ? 16 : m_load;
            mmask = (1UL << this_load) - 1;
            MASK_STORE_256(0, 0); MASK_STORE_256(0, 1); MASK_STORE_256(0, 2); MASK_STORE_256(0, 3);
            MASK_STORE_256(0, 4); MASK_STORE_256(0, 5); MASK_STORE_256(0, 6); MASK_STORE_256(0, 7);
            MASK_STORE_256(1, 0); MASK_STORE_256(1, 1); MASK_STORE_256(1, 2); MASK_STORE_256(1, 3);
            MASK_STORE_256(1, 4); MASK_STORE_256(1, 5); MASK_STORE_256(1, 6); MASK_STORE_256(1, 7);
            boffset0 = boffset;
            if (m_load > 16) {
                boffset += this_load;
                SWITCH_LOAD_A_8VEC(aoffset10, m_load > 24 ? 8 : m_load - 16);
                REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
                if (m_load > 24) {
                    SWITCH_LOAD_A_8VEC(aoffset11, m_load - 24);
                    REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
                }
                this_load = m_load - 16;
                mmask = (1UL << this_load) - 1;
                MASK_STORE_256(0, 0); MASK_STORE_256(0, 1); MASK_STORE_256(0, 2); MASK_STORE_256(0, 3);
                MASK_STORE_256(0, 4); MASK_STORE_256(0, 5); MASK_STORE_256(0, 6); MASK_STORE_256(0, 7);
                MASK_STORE_256(1, 0); MASK_STORE_256(1, 1); MASK_STORE_256(1, 2); MASK_STORE_256(1, 3);
                MASK_STORE_256(1, 4); MASK_STORE_256(1, 5); MASK_STORE_256(1, 6); MASK_STORE_256(1, 7);
            }
            boffset = boffset0 + 16 * m_load;
            aoffset00 += m_load * lda;
        }
        if (m_count > 0) {
            // just copy the last (odd) k row to B directly
            r0 = _mm256_loadu_si256((__m256i *)(aoffset00));
            _mm256_storeu_si256((__m256i *)(boffset), r0);
            boffset += 16;
        }
    }
    if (n_count > 0) {
        __mmask16 nmask = (1UL << n_count) - 1;
        aoffset00 = aoffset;
        aoffset01 = aoffset00 + 8 * lda;
        aoffset10 = aoffset01 + 8 * lda;
        aoffset11 = aoffset10 + 8 * lda;
        m_count = m;
        for (; m_count > 31; m_count -= 32) {
            // first 16 rows
            MASK_LOAD_A_8VEC(aoffset00);
            REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
            MASK_LOAD_A_8VEC(aoffset01);
            REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
            SWITCH_STORE_16x(n_count, STORE_256);
            // last 16 rows
            boffset0 = boffset;
            boffset += 16;
            MASK_LOAD_A_8VEC(aoffset10);
            REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
            MASK_LOAD_A_8VEC(aoffset11);
            REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
            SWITCH_STORE_16x(n_count, STORE_256);
            aoffset00 += 32 * lda;
            aoffset01 += 32 * lda;
            aoffset10 += 32 * lda;
            aoffset11 += 32 * lda;
            boffset = 32 * n_count + boffset0;
        }
        if (m_count > 1) {
            int m_load = m_count & ~1;
            m_count -= m_load;
            __mmask16 mmask;
            SWITCH_MASK_LOAD_A_8VEC(aoffset00, m_load > 8 ? 8 : m_load);
            REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
            if (m_load > 8) {
                SWITCH_MASK_LOAD_A_8VEC(aoffset01, m_load > 16 ? 8 : m_load - 8);
                REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
            }
            int this_load = m_load > 16 ? 16 : m_load;
            mmask = (1UL << this_load) - 1;
            SWITCH_STORE_16x(n_count, MASK_STORE_256);
            boffset0 = boffset;
            if (m_load > 16) {
                boffset += this_load;
                SWITCH_MASK_LOAD_A_8VEC(aoffset10, m_load > 24 ? 8 : m_load - 16);
                REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
                if (m_load > 24) {
                    SWITCH_MASK_LOAD_A_8VEC(aoffset11, m_load - 24);
                    REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
                }
                this_load = m_load - 16;
                mmask = (1UL << this_load) - 1;
                SWITCH_STORE_16x(n_count, MASK_STORE_256);
            }
            boffset = boffset0 + n_count * m_load;
            aoffset00 += m_load * lda;
        }
        if (m_count > 0) {
            // just copy the last (odd) k row to B directly
            r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aoffset00));
            _mm256_mask_storeu_epi16((__m256i *)(boffset), nmask, r0);
            boffset += 16;
        }
    }
    return 0;
}
@@ -0,0 +1,42 @@
/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/

#include "common.h"

#include "sbgemm_block_microk_cooperlake.c"
// Define the micro kernels for the ALPHA != ONE case
#undef ONE_ALPHA
#include "sbgemm_microk_cooperlake_template.c"

// Define the micro kernels for the ALPHA == ONE case
#define ONE_ALPHA 1
#include "sbgemm_microk_cooperlake_template.c"

int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta)
{
    return 0;
}
param.h
@@ -1771,6 +1771,20 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#endif
#define USE_SGEMM_KERNEL_DIRECT 1

#undef SBGEMM_DEFAULT_UNROLL_N
#undef SBGEMM_DEFAULT_UNROLL_M
#undef SBGEMM_DEFAULT_P
#undef SBGEMM_DEFAULT_R
#undef SBGEMM_DEFAULT_Q
// FIXME: in reality UNROLL_M = UNROLL_N = 16.
// When UNROLL_M and UNROLL_N are equal, OpenBLAS reuses OCOPY as ICOPY,
// but for AMX the two copy layouts are not the same, so set UNROLL_M = 32 as a workaround.
#define SBGEMM_DEFAULT_UNROLL_N 16
#define SBGEMM_DEFAULT_UNROLL_M 32
#define SBGEMM_DEFAULT_P 256
#define SBGEMM_DEFAULT_Q 1024
#define SBGEMM_DEFAULT_R sbgemm_r

#ifdef ARCH_X86

#define SGEMM_DEFAULT_UNROLL_M 4
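For context on the values above (this reflects the usual meaning of the OpenBLAS blocking constants, not something stated in the patch itself): P blocks the M dimension and Q blocks the K dimension of the packed panels, so one packed bf16 A panel is roughly P * Q * 2 bytes = 256 * 1024 * 2 = 512 KiB, which fits comfortably within a Sapphire Rapids core's 2 MB L2 cache while the 16x16 tile block (nominally declared as 32x16 here, per the FIXME) sweeps over it.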