Merge pull request #3416 from guowangy/spr-bf16

sbgemm: add AMX-BF16 based kernel for Sapphire Rapids
This commit is contained in:
Martin Kroeker 2021-10-18 14:59:21 +02:00 committed by GitHub
commit 8cbf61792d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1079 additions and 0 deletions

View File

@ -1 +1,14 @@
include $(KERNELDIR)/KERNEL.COOPERLAKE
SBGEMM_SMALL_M_PERMIT = sbgemm_small_kernel_permit_spr.c
SBGEMM_BETA = sgemm_beta_skylakex.c
SBGEMMKERNEL = sbgemm_kernel_16x16_spr.c
SBGEMMINCOPY = sbgemm_ncopy_16_cooperlake.c
SBGEMMITCOPY = sbgemm_tcopy_16_cooperlake.c
SBGEMMONCOPY = sbgemm_oncopy_16_spr.c
SBGEMMOTCOPY = sbgemm_otcopy_16_spr.c
SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX)
SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX)
SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX)
SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)

View File

@ -0,0 +1,50 @@
/***************************************************************************
* Copyright (c) 2021, The OpenBLAS Project
* All rights reserved.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in
* the documentation and/or other materials provided with the
* distribution.
* 3. Neither the name of the OpenBLAS project nor the names of
* its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
* USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* *****************************************************************************/
#include "common.h"
#define ALPHA_ONE
#include "sbgemm_kernel_16x16_spr_tmpl.c"
#undef ALPHA_ONE
#include "sbgemm_kernel_16x16_spr_tmpl.c"
int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOAT * iB, FLOAT * C, BLASLONG ldc)
{
/* transport to Row Major matrix for AMX requirement */
BLASLONG m, n;
IFLOAT *A, *B;
m = in;
n = im;
A = iB;
B = iA;
if (alpha == 1.0f)
return sbgemm_kernel_spr_alpha_one(m, n, k, alpha, A, B, C, ldc);
else
return sbgemm_kernel_spr_alpha(m, n, k, alpha, A, B, C, ldc);
}

View File

@ -0,0 +1,530 @@
/***************************************************************************
* Copyright (c) 2021, The OpenBLAS Project
* All rights reserved.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in
* the documentation and/or other materials provided with the
* distribution.
* 3. Neither the name of the OpenBLAS project nor the names of
* its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
* USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* *****************************************************************************/
#include <immintrin.h>
#include <string.h>
#include "common.h"
#ifndef SBGEMM_KERNEL_SPR
#define SBGEMM_KERNEL_SPR
typedef struct {
char palette_id;
char start_row;
char dummy0[14]; // bytes 2-15 reserved, must be zero
short tile_colsb[8];
char dummy1[16]; // bytes 32-47 reserved, must be zero
char tile_rows[8];
char dummy2[16]; // bytes 56-63 reserved, must be zero
} tilecfg;
/* tile0/tile1 -- A (m x 2k)
* tile2/tile3 -- B (2k x n)
* tile4-7 -- C (m x n)
*/
#define TCONF(cfg, m, n, k2) \
memset(&cfg, 0, sizeof(tilecfg)); \
cfg.palette_id = 1; \
cfg.tile_rows[0] = m; \
cfg.tile_rows[1] = m; \
cfg.tile_rows[2] = k2>>1; \
cfg.tile_rows[3] = k2>>1; \
cfg.tile_rows[4] = m; \
cfg.tile_rows[5] = m; \
cfg.tile_rows[6] = m; \
cfg.tile_rows[7] = m; \
cfg.tile_colsb[0] = k2<<1; \
cfg.tile_colsb[1] = k2<<1; \
cfg.tile_colsb[2] = n * 4; \
cfg.tile_colsb[3] = n * 4; \
cfg.tile_colsb[4] = n * 4; \
cfg.tile_colsb[5] = n * 4; \
cfg.tile_colsb[6] = n * 4; \
cfg.tile_colsb[7] = n * 4; \
_tile_loadconfig(&cfg);
/* CONFIG for handling k2 and odd tail at the same time
* tile0 -- A (m x 2k)
* tile1 -- A (m x 1)
* tile2 -- B (2k x n)
* tile3 -- B (1 x n)
* tile4 -- C (m x n)
*/
#define TCONF_TAIL(cfg, m, n, k2) \
memset(&cfg, 0, sizeof(tilecfg)); \
cfg.palette_id = 1; \
cfg.tile_rows[0] = m; \
cfg.tile_rows[1] = m; \
cfg.tile_rows[2] = k2>>1; \
cfg.tile_rows[3] = 1; \
cfg.tile_rows[4] = m; \
cfg.tile_colsb[0] = k2<<1; \
cfg.tile_colsb[1] = 4; \
cfg.tile_colsb[2] = n * 4; \
cfg.tile_colsb[3] = n * 4; \
cfg.tile_colsb[4] = n * 4; \
_tile_loadconfig(&cfg);
#define T_A0 0
#define T_A1 1
#define T_B0 2
#define T_B1 3
#define T_C00 4
#define T_C01 5
#define T_C10 6
#define T_C11 7
// FIXME: gcc11 seem have problem in tile load/store address calc,
// need to multiply with element size (2 or 4) here.
#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2)
#define LOAD_A_TAIL(M, N) {\
__m256i ymm = _mm256_loadu_epi16(ptr_a##M); \
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \
_mm512_storeu_epi16(tail_a + 16 * M, zmm); \
_tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
}
#define MASK_LOAD_A_TAIL(M, N) {\
__m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \
_mm512_storeu_epi16(tail_a + 16 * M, zmm); \
_tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
}
#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2)
#define LOAD_B_TAIL(M, N) {\
__m256i ymm = _mm256_loadu_epi16(ptr_b##N); \
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \
_mm512_storeu_epi16(tail_b + 16 * N, zmm); \
_tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \
}
#define MASK_LOAD_B_TAIL(M, N) {\
__m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \
_mm512_storeu_epi16(tail_b + 16 * N, zmm); \
_tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \
}
#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N)
#define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N)
#define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4)
#define LOAD_C_F(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4)
#endif // end of SBGEMM_KERNEL_SPR
#ifdef ALPHA_ONE
#undef LOAD_C
#define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4)
#else
#undef LOAD_C
#define LOAD_C(M, N) _tile_zero(T_C##M##N)
#define ALPHA_STORE(N) \
__m512 zmm_d##N = _mm512_loadu_ps(dst##N + noffset); \
__m512 zmm_s##N = _mm512_loadu_ps(src##N + noffset); \
zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \
_mm512_storeu_ps(dst##N + noffset, zmm_d##N);
#define MASK_APLPHA_STORE(N) \
__m512 zmm_d##N = _mm512_maskz_loadu_ps(mask, dst##N + noffset); \
__m512 zmm_s##N = _mm512_maskz_loadu_ps(mask, src##N + noffset); \
zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \
_mm512_mask_storeu_ps(dst##N + noffset, mask, zmm_d##N);
#endif // end of ALPHA_ONE
#ifdef ALPHA_ONE
int sbgemm_kernel_spr_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc)
#else
int sbgemm_kernel_spr_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc)
#endif
{
/* Row Major matrix for AMX requirement */
IFLOAT *ptr_a = A, *ptr_b = B;
IFLOAT *ptr_b0, *ptr_b1;
IFLOAT *ptr_a0, *ptr_a1;
FLOAT *ptr_c = C;
FLOAT *ptr_c00, *ptr_c01, *ptr_c10, *ptr_c11;
BLASLONG lda, ldb;
BLASLONG m_count = m;
BLASLONG n_count, k_count;
#ifndef ALPHA_ONE
// make sure each row is 64 bytes aligned
BLASLONG cn = (n & 31) ? (n & ~31) + 32 : n;
FLOAT *raw_tmp_c;
if (k < 32) {
// only need to zero buff in this situation
raw_tmp_c = (FLOAT *)calloc(1, sizeof(FLOAT) * m * cn + 64);
} else {
raw_tmp_c = (FLOAT *)malloc(sizeof(FLOAT) * m * cn + 64);
}
// align buf to 64 byte boundary
FLOAT *tmp_c = (FLOAT *)(((uintptr_t) raw_tmp_c + 63) & ~(uintptr_t)63);
ptr_c = tmp_c;
BLASLONG ldc_o = ldc;
ldc = cn;
#endif
IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64)));
IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64)));
tilecfg cfg;
if (k > 31) {
for (; m_count > 31; m_count -= 32) {
ptr_b = B;
ptr_c00 = ptr_c;
ptr_c01 = ptr_c00 + 16;
ptr_c10 = ptr_c + 16 * ldc;
ptr_c11 = ptr_c10 + 16;
ptr_c += 32 * ldc;
n_count = n;
TCONF(cfg, 16, 16, 32);
for (; n_count > 31; n_count -= 32) {
ptr_a0 = ptr_a;
ptr_a1 = ptr_a + 16 * k;
ptr_b0 = ptr_b;
ptr_b1 = ptr_b + 16 * k;
ptr_b += 32 * k;
lda = 32;
ldb = 32;
LOAD_C(0, 0); LOAD_C(0, 1);
LOAD_C(1, 0); LOAD_C(1, 1);
k_count = k;
for (; k_count > 31; k_count -= 32) {
LOAD_A(0, x); LOAD_A(1, x);
ptr_a0 += 16 * 32;
ptr_a1 += 16 * 32;
LOAD_B(x, 0); LOAD_B(x, 1);
ptr_b0 += 16 * 32;
ptr_b1 += 16 * 32;
MATMUL(0, 0); MATMUL(0, 1);
MATMUL(1, 0); MATMUL(1, 1);
}
STORE_C(0, 0); STORE_C(0, 1);
STORE_C(1, 0); STORE_C(1, 1);
ptr_c00 += 32;
ptr_c01 += 32;
ptr_c10 += 32;
ptr_c11 += 32;
}
for (; n_count > 0; n_count -= 16) {
int tail_n = (n_count > 16) ? 16: n_count;
ptr_a0 = ptr_a;
ptr_a1 = ptr_a + 16 * k;
ptr_b0 = ptr_b;
ptr_b += tail_n * k;
lda = 32;
ldb = 2 * tail_n;
TCONF(cfg, 16, tail_n, 32);
LOAD_C(0, 0);
LOAD_C(1, 0);
k_count = k;
for (; k_count > 31; k_count -= 32) {
LOAD_A(0, x); LOAD_A(1, x);
ptr_a0 += 16 * 32;
ptr_a1 += 16 * 32;
LOAD_B(x, 0);
ptr_b0 += tail_n * 32;
MATMUL(0, 0);
MATMUL(1, 0);
}
STORE_C(0, 0);
STORE_C(1, 0);
ptr_c00 += tail_n;
ptr_c10 += tail_n;
}
ptr_a += 32 * k;
}
for (; m_count > 0; m_count -= 16) {
// process at most 16 m at a time
int tail_m = (m_count > 16) ? 16: m_count;
ptr_b = B;
ptr_c00 = ptr_c;
ptr_c01 = ptr_c00 + 16;
ptr_c += tail_m * ldc;
n_count = n;
TCONF(cfg, tail_m, 16, 32);
for (; n_count > 31; n_count -= 32) {
ptr_a0 = ptr_a;
ptr_b0 = ptr_b;
ptr_b1 = ptr_b + 16 * k;
ptr_b += 32 * k;
lda = 32;
ldb = 32;
LOAD_C(0, 0); LOAD_C(0, 1);
k_count = k;
for (; k_count > 31; k_count -= 32) {
LOAD_A(0, x);
ptr_a0 += tail_m * 32;
LOAD_B(x, 0); LOAD_B(x, 1);
ptr_b0 += 16 * 32;
ptr_b1 += 16 * 32;
MATMUL(0, 0); MATMUL(0, 1);
}
STORE_C(0, 0); STORE_C(0, 1);
ptr_c00 += 32;
ptr_c01 += 32;
}
for (; n_count > 0; n_count -= 16) {
int tail_n = (n_count > 16) ? 16: n_count;
ptr_a0 = ptr_a;
ptr_b0 = ptr_b;
ptr_b += tail_n * k;
lda = 32;
ldb = 2 * tail_n;
TCONF(cfg, tail_m, tail_n, 32);
LOAD_C(0, 0);
k_count = k;
for (; k_count > 31; k_count -= 32) {
LOAD_A(0, x);
ptr_a0 += tail_m * 32;
LOAD_B(x, 0);
ptr_b0 += tail_n * 32;
MATMUL(0, 0);
}
STORE_C(0, 0);
ptr_c00 += tail_n;
}
ptr_a += tail_m * k;
}
}
// process for k < 32
BLASLONG k32 = k & ~31;
BLASLONG k2 = k & ~1;
if (k32 != k) {
int remain_k2 = k2 - k32;
m_count = m;
ptr_a = A;
#ifndef ALPHA_ONE
ptr_c = tmp_c;
#else
ptr_c = C;
#endif
if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0)
for (; m_count > 0; m_count -= 16) {
int tail_m = (m_count > 16) ? 16: m_count;
__mmask16 amask = (1UL << tail_m) - 1;
ptr_a0 = ptr_a + tail_m * k32;
ptr_a1 = ptr_a + tail_m * k2;
ptr_a += tail_m * k;
ptr_b = B;
ptr_c00 = ptr_c;
ptr_c += tail_m * ldc;
n_count = n;
lda = remain_k2;
ldb = 32;
if (n_count > 15) {
TCONF_TAIL(cfg, tail_m, 16, remain_k2);
LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
for (; n_count > 15; n_count -= 16) {
ptr_b0 = ptr_b + 16 * k32;
ptr_b1 = ptr_b + 16 * k2;
LOAD_C_F(0, 0);
LOAD_B(x, 0); LOAD_B_TAIL(x, 1);
MATMUL(0, 0); MATMUL_TAIL(1, 1);
STORE_C(0, 0);
ptr_b += 16 * k;
ptr_c00 += 16;
}
}
if (n_count > 0) {
int tail_n = (n_count > 16) ? 16: n_count;
__mmask16 bmask = (1UL << tail_n) - 1;
ptr_b0 = ptr_b + tail_n * k32;
ptr_b1 = ptr_b + tail_n * k2;
ldb = 2 * tail_n;
TCONF_TAIL(cfg, tail_m, tail_n, remain_k2);
LOAD_C_F(0, 0);
LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1);
MATMUL(0, 0); MATMUL_TAIL(1, 1);
STORE_C(0, 0);
}
}
} else if (remain_k2 > 0) { // k%32 = 2x
for (; m_count > 0; m_count -= 16) {
int tail_m = (m_count > 16) ? 16: m_count;
ptr_a0 = ptr_a + tail_m * k32;
ptr_a += tail_m * k;
ptr_b = B;
ptr_c00 = ptr_c;
ptr_c += tail_m * ldc;
n_count = n;
lda = remain_k2;
ldb = 32;
if (n_count > 15) {
TCONF(cfg, tail_m, 16, remain_k2);
LOAD_A(0, x);
for (; n_count > 15; n_count -= 16) {
ptr_b0 = ptr_b + 16 * k32;
LOAD_C_F(0, 0);
LOAD_B(x, 0);
MATMUL(0, 0);
STORE_C(0, 0);
ptr_b += 16 * k;
ptr_c00 += 16;
}
}
if (n_count > 0) {
int tail_n = (n_count > 16) ? 16: n_count;
ptr_b0 = ptr_b + tail_n * k32;
ldb = 2 * tail_n;
TCONF(cfg, tail_m, tail_n, remain_k2);
LOAD_C_F(0, 0);
LOAD_A(0, x);
LOAD_B(x, 0);
MATMUL(0, 0);
STORE_C(0, 0);
}
}
} else { // k%32 = 1
for (; m_count > 0; m_count -= 16) {
int tail_m = (m_count > 16) ? 16: m_count;
__mmask16 amask = (1UL << tail_m) - 1;
ptr_a0 = ptr_a + tail_m * k2;
ptr_a += tail_m * k;
ptr_b = B;
ptr_c00 = ptr_c;
ptr_c += tail_m * ldc;
n_count = n;
if (n_count > 15) {
TCONF(cfg, tail_m, 16, 2);
MASK_LOAD_A_TAIL(0, x);
for (; n_count > 15; n_count -= 16) {
ptr_b0 = ptr_b + 16 * k2;
LOAD_C_F(0, 0);
LOAD_B_TAIL(x, 0);
MATMUL(0, 0);
STORE_C(0, 0);
ptr_b += 16 * k;
ptr_c00 += 16;
}
}
if (n_count > 0) {
int tail_n = (n_count > 16) ? 16: n_count;
__mmask16 bmask = (1UL << tail_n) - 1;
ptr_b0 = ptr_b + tail_n * k2;
TCONF(cfg, tail_m, tail_n, 2);
LOAD_C_F(0, 0);
MASK_LOAD_A_TAIL(0, x);
MASK_LOAD_B_TAIL(x, 0);
MATMUL(0, 0);
STORE_C(0, 0);
}
}
}
}
#ifndef ALPHA_ONE
__m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha));
BLASLONG n16 = n & ~15;
BLASLONG noffset;
FLOAT *src0, *src1, *src2, *src3;
FLOAT *dst0, *dst1, *dst2, *dst3;
FLOAT *src = tmp_c;
FLOAT *dst = C;
m_count = m;
for (; m_count > 3; m_count -= 4) {
src0 = src;
src1 = src0 + ldc;
src2 = src1 + ldc;
src3 = src2 + ldc;
src += 4 * ldc;
dst0 = dst;
dst1 = dst0 + ldc_o;
dst2 = dst1 + ldc_o;
dst3 = dst2 + ldc_o;
dst += 4 * ldc_o;
noffset = 0;
for (; noffset < n16; noffset += 16) {
ALPHA_STORE(0);
ALPHA_STORE(1);
ALPHA_STORE(2);
ALPHA_STORE(3);
}
if (noffset < n) {
__mmask16 mask = (1UL << (n - noffset)) - 1;
MASK_APLPHA_STORE(0);
MASK_APLPHA_STORE(1);
MASK_APLPHA_STORE(2);
MASK_APLPHA_STORE(3);
}
}
for (; m_count > 1; m_count -= 2) {
src0 = src;
src1 = src0 + ldc;
src += 2 * ldc;
dst0 = dst;
dst1 = dst0 + ldc_o;
dst += 2 * ldc_o;
noffset = 0;
for (; noffset < n16; noffset += 16) {
ALPHA_STORE(0);
ALPHA_STORE(1);
}
if (noffset < n) {
__mmask16 mask = (1UL << (n - noffset)) - 1;
MASK_APLPHA_STORE(0);
MASK_APLPHA_STORE(1);
}
}
for (; m_count > 0; m_count -= 1) {
src0 = src;
dst0 = dst;
noffset = 0;
for (; noffset < n16; noffset += 16) {
ALPHA_STORE(0);
}
if (noffset < n) {
__mmask16 mask = (1UL << (n - noffset)) - 1;
MASK_APLPHA_STORE(0);
}
}
free(raw_tmp_c);
#endif
return 0;
}

View File

@ -0,0 +1,128 @@
/***************************************************************************
* Copyright (c) 2021, The OpenBLAS Project
* All rights reserved.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in
* the documentation and/or other materials provided with the
* distribution.
* 3. Neither the name of the OpenBLAS project nor the names of
* its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
* USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* *****************************************************************************/
#include <immintrin.h>
#include "common.h"
typedef struct {
char palette_id;
char start_row;
char dummy0[14]; // bytes 2-15 reserved, must be zero
short tile_colsb[8];
char dummy1[16]; // bytes 32-47 reserved, must be zero
char tile_rows[8];
char dummy2[16]; // bytes 56-63 reserved, must be zero
} tilecfg;
#define T_16x32 0
#define T_16xm 1
#define T_nx32 2
#define T_nxm 3
#define TCONF(cfg, m, n) \
memset(&cfg, 0, sizeof(tilecfg)); \
cfg.palette_id = 1; \
cfg.tile_rows[T_16x32] = 16; \
cfg.tile_colsb[T_16x32] = 64; \
if (m) { \
cfg.tile_rows[T_16xm] = 16; \
cfg.tile_colsb[T_16xm] = m * 2; \
} \
if (n) { \
cfg.tile_rows[T_nx32] = n; \
cfg.tile_colsb[T_nx32] = 64; \
} \
if (m && n) { \
cfg.tile_rows[T_nxm] = n; \
cfg.tile_colsb[T_nxm] = m * 2; \
} \
_tile_loadconfig(&cfg);
int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) {
BLASLONG i, j;
IFLOAT *aoffset, *boffset;
IFLOAT *aoffset0;
aoffset = a;
boffset = b;
BLASLONG n16 = n & ~15;
BLASLONG m32 = m & ~31;
BLASLONG m2 = m & ~1;
BLASLONG tail_m = m2 - m32;
BLASLONG tail_n = n - n16;
tilecfg cfg;
TCONF(cfg, tail_m, tail_n);
for (j = 0; j < n16; j += 16) {
aoffset0 = aoffset;
for (i = 0; i < m32; i += 32) {
_tile_loadd(T_16x32, aoffset0, lda * 2);
_tile_stored(T_16x32, boffset, 32 * 2);
aoffset0 += 32;
boffset += 32 * 16;
}
if (i < m2) {
_tile_loadd(T_16xm, aoffset0, lda * 2);
_tile_stored(T_16xm, boffset, tail_m * 2);
aoffset0 += tail_m;
boffset += tail_m * 16;
i = m2;
}
if (i < m) {
/* the tail odd k should put alone */
for (int ii = 0; ii < 16; ii++) {
*(boffset + ii) = *(aoffset0 + lda * ii);
}
boffset += 16;
}
aoffset += 16 * lda;
}
if (j < n) {
aoffset0 = aoffset;
for (i = 0; i < m32; i += 32) {
_tile_loadd(T_nx32, aoffset0, lda * 2);
_tile_stored(T_nx32, boffset, 32 * 2);
aoffset0 += 32;
boffset += 32 * tail_n;
}
if (i < m2) {
_tile_loadd(T_nxm, aoffset0, lda * 2);
_tile_stored(T_nxm, boffset, tail_m * 2);
aoffset0 += tail_m;
boffset += tail_m * tail_n;
}
if (i < m) {
for (int ii = 0; ii < tail_n; ii++) {
*(boffset + ii) = *(aoffset0 + lda * ii);
}
}
}
return 0;
}

View File

@ -0,0 +1,302 @@
/***************************************************************************
* Copyright (c) 2021, The OpenBLAS Project
* All rights reserved.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in
* the documentation and/or other materials provided with the
* distribution.
* 3. Neither the name of the OpenBLAS project nor the names of
* its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
* USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* *****************************************************************************/
#include <immintrin.h>
#include "common.h"
#define LOAD_A_8VEC(aptr) \
r0 = _mm256_loadu_si256((__m256i *)(aptr + lda*0)); \
r1 = _mm256_loadu_si256((__m256i *)(aptr + lda*1)); \
r2 = _mm256_loadu_si256((__m256i *)(aptr + lda*2)); \
r3 = _mm256_loadu_si256((__m256i *)(aptr + lda*3)); \
r4 = _mm256_loadu_si256((__m256i *)(aptr + lda*4)); \
r5 = _mm256_loadu_si256((__m256i *)(aptr + lda*5)); \
r6 = _mm256_loadu_si256((__m256i *)(aptr + lda*6)); \
r7 = _mm256_loadu_si256((__m256i *)(aptr + lda*7));
#define MASK_LOAD_A_8VEC(aptr) \
r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*0)); \
r1 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*1)); \
r2 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*2)); \
r3 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*3)); \
r4 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*4)); \
r5 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*5)); \
r6 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*6)); \
r7 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*7));
#define SWITCH_LOAD_A_8VEC(aptr, cond) \
switch((cond)) { \
case 8: r7 = _mm256_loadu_si256((__m256i *)(aptr + lda*7)); \
case 7: r6 = _mm256_loadu_si256((__m256i *)(aptr + lda*6)); \
case 6: r5 = _mm256_loadu_si256((__m256i *)(aptr + lda*5)); \
case 5: r4 = _mm256_loadu_si256((__m256i *)(aptr + lda*4)); \
case 4: r3 = _mm256_loadu_si256((__m256i *)(aptr + lda*3)); \
case 3: r2 = _mm256_loadu_si256((__m256i *)(aptr + lda*2)); \
case 2: r1 = _mm256_loadu_si256((__m256i *)(aptr + lda*1)); \
case 1: r0 = _mm256_loadu_si256((__m256i *)(aptr + lda*0)); \
}
#define SWITCH_MASK_LOAD_A_8VEC(aptr, cond) \
switch((cond)) { \
case 8: r7 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*7)); \
case 7: r6 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*6)); \
case 6: r5 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*5)); \
case 5: r4 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*4)); \
case 4: r3 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*3)); \
case 3: r2 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*2)); \
case 2: r1 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*1)); \
case 1: r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*0)); \
}
#define REORDER_8x16(t0, t1, t2, t3, t4, t5, t6, t7) \
t0 = _mm256_unpacklo_epi16(r0, r1); \
t1 = _mm256_unpackhi_epi16(r0, r1); \
t2 = _mm256_unpacklo_epi16(r2, r3); \
t3 = _mm256_unpackhi_epi16(r2, r3); \
t4 = _mm256_unpacklo_epi16(r4, r5); \
t5 = _mm256_unpackhi_epi16(r4, r5); \
t6 = _mm256_unpacklo_epi16(r6, r7); \
t7 = _mm256_unpackhi_epi16(r6, r7); \
r0 = _mm256_unpacklo_epi32(t0, t2); \
r1 = _mm256_unpacklo_epi32(t1, t3); \
r2 = _mm256_unpacklo_epi32(t4, t6); \
r3 = _mm256_unpacklo_epi32(t5, t7); \
r4 = _mm256_unpackhi_epi32(t0, t2); \
r5 = _mm256_unpackhi_epi32(t1, t3); \
r6 = _mm256_unpackhi_epi32(t4, t6); \
r7 = _mm256_unpackhi_epi32(t5, t7); \
t0 = _mm256_unpacklo_epi64(r0, r2); \
t1 = _mm256_unpackhi_epi64(r0, r2); \
t2 = _mm256_unpacklo_epi64(r4, r6); \
t3 = _mm256_unpackhi_epi64(r4, r6); \
t4 = _mm256_unpacklo_epi64(r1, r3); \
t5 = _mm256_unpackhi_epi64(r1, r3); \
t6 = _mm256_unpacklo_epi64(r5, r7); \
t7 = _mm256_unpackhi_epi64(r5, r7);
#define STORE_256_LO(x) \
v = _mm256_permute2x128_si256(t0##x, t1##x, 0x20); \
_mm256_storeu_si256((__m256i *)(boffset + x*32), v);
#define STORE_256_HI(x) \
v = _mm256_permute2x128_si256(t0##x, t1##x, 0x31); \
_mm256_storeu_si256((__m256i *)(boffset + (x + 8)*32), v);
#define MASK_STORE_256_LO(x) \
v = _mm256_permute2x128_si256(t0##x, t1##x, 0x20); \
_mm256_mask_storeu_epi16(boffset + x*m_load, mmask, v);
#define MASK_STORE_256_HI(x) \
v = _mm256_permute2x128_si256(t0##x, t1##x, 0x31); \
_mm256_mask_storeu_epi16(boffset + (x + 8)*m_load, mmask, v);
#define STORE_256(x, y) {\
__m256i v; \
if (x == 0) { STORE_256_LO(y); } \
else { STORE_256_HI(y); } \
}
#define MASK_STORE_256(x, y) {\
__m256i v; \
if (x == 0) { MASK_STORE_256_LO(y); } \
else { MASK_STORE_256_HI(y); } \
}
#define SWITCH_STORE_16x(cond, func) \
switch((cond)) {\
case 15: func(1, 6); \
case 14: func(1, 5); \
case 13: func(1, 4); \
case 12: func(1, 3); \
case 11: func(1, 2); \
case 10: func(1, 1); \
case 9: func(1, 0); \
case 8: func(0, 7); \
case 7: func(0, 6); \
case 6: func(0, 5); \
case 5: func(0, 4); \
case 4: func(0, 3); \
case 3: func(0, 2); \
case 2: func(0, 1); \
case 1: func(0, 0); \
}
int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) {
IFLOAT *aoffset, *boffset;
IFLOAT *aoffset00, *aoffset01, *aoffset10, *aoffset11;
IFLOAT *boffset0;
__m256i r0, r1, r2, r3, r4, r5, r6, r7;
__m256i t00, t01, t02, t03, t04, t05, t06, t07;
__m256i t10, t11, t12, t13, t14, t15, t16, t17;
aoffset = a;
boffset = b;
BLASLONG n_count = n;
BLASLONG m_count = m;
for (; n_count > 15; n_count -= 16) {
aoffset00 = aoffset;
aoffset01 = aoffset00 + 8 * lda;
aoffset10 = aoffset01 + 8 * lda;
aoffset11 = aoffset10 + 8 * lda;
aoffset += 16;
m_count = m;
for (; m_count > 31; m_count -= 32) {
// first 16 rows
LOAD_A_8VEC(aoffset00);
REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
LOAD_A_8VEC(aoffset01);
REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
STORE_256(0, 0); STORE_256(0, 1); STORE_256(0, 2); STORE_256(0, 3);
STORE_256(0, 4); STORE_256(0, 5); STORE_256(0, 6); STORE_256(0, 7);
STORE_256(1, 0); STORE_256(1, 1); STORE_256(1, 2); STORE_256(1, 3);
STORE_256(1, 4); STORE_256(1, 5); STORE_256(1, 6); STORE_256(1, 7);
// last 16 rows
boffset += 16;
LOAD_A_8VEC(aoffset10);
REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
LOAD_A_8VEC(aoffset11);
REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
STORE_256(0, 0); STORE_256(0, 1); STORE_256(0, 2); STORE_256(0, 3);
STORE_256(0, 4); STORE_256(0, 5); STORE_256(0, 6); STORE_256(0, 7);
STORE_256(1, 0); STORE_256(1, 1); STORE_256(1, 2); STORE_256(1, 3);
STORE_256(1, 4); STORE_256(1, 5); STORE_256(1, 6); STORE_256(1, 7);
aoffset00 += 32 * lda;
aoffset01 += 32 * lda;
aoffset10 += 32 * lda;
aoffset11 += 32 * lda;
boffset += 31 * 16;
}
if (m_count > 1) {
int m_load = m_count & ~1;
m_count -= m_load;
__mmask16 mmask;
SWITCH_LOAD_A_8VEC(aoffset00, m_load > 8 ? 8: m_load);
REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
if (m_load > 8) {
SWITCH_LOAD_A_8VEC(aoffset01, m_load > 16 ? 8: m_load - 8);
REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
}
int this_load = m_load > 16 ? 16 : m_load;
mmask = (1UL << this_load) - 1;
MASK_STORE_256(0, 0); MASK_STORE_256(0, 1); MASK_STORE_256(0, 2); MASK_STORE_256(0, 3);
MASK_STORE_256(0, 4); MASK_STORE_256(0, 5); MASK_STORE_256(0, 6); MASK_STORE_256(0, 7);
MASK_STORE_256(1, 0); MASK_STORE_256(1, 1); MASK_STORE_256(1, 2); MASK_STORE_256(1, 3);
MASK_STORE_256(1, 4); MASK_STORE_256(1, 5); MASK_STORE_256(1, 6); MASK_STORE_256(1, 7);
boffset0 = boffset;
if (m_load > 16) {
boffset += this_load;
SWITCH_LOAD_A_8VEC(aoffset10, m_load > 24 ? 8: m_load - 16);
REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
if (m_load > 24) {
SWITCH_LOAD_A_8VEC(aoffset11, m_load - 24);
REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
}
this_load = m_load - 16;
mmask = (1UL << this_load) - 1;
MASK_STORE_256(0, 0); MASK_STORE_256(0, 1); MASK_STORE_256(0, 2); MASK_STORE_256(0, 3);
MASK_STORE_256(0, 4); MASK_STORE_256(0, 5); MASK_STORE_256(0, 6); MASK_STORE_256(0, 7);
MASK_STORE_256(1, 0); MASK_STORE_256(1, 1); MASK_STORE_256(1, 2); MASK_STORE_256(1, 3);
MASK_STORE_256(1, 4); MASK_STORE_256(1, 5); MASK_STORE_256(1, 6); MASK_STORE_256(1, 7);
}
boffset = boffset0 + 16 * m_load;
aoffset00 += m_load * lda;
}
if (m_count > 0) {
// just copy lask K to B directly
r0 = _mm256_loadu_si256((__m256i *)(aoffset00));
_mm256_storeu_si256((__m256i *)(boffset), r0);
boffset += 16;
}
}
if (n_count > 0) {
__mmask16 nmask = (1UL << n_count) - 1;
aoffset00 = aoffset;
aoffset01 = aoffset00 + 8 * lda;
aoffset10 = aoffset01 + 8 * lda;
aoffset11 = aoffset10 + 8 * lda;
m_count = m;
for (; m_count > 31; m_count -= 32) {
// first 16 rows
MASK_LOAD_A_8VEC(aoffset00);
REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
MASK_LOAD_A_8VEC(aoffset01);
REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
SWITCH_STORE_16x(n_count, STORE_256);
// last 16 rows
boffset0 = boffset;
boffset += 16;
MASK_LOAD_A_8VEC(aoffset10);
REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
MASK_LOAD_A_8VEC(aoffset11);
REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
SWITCH_STORE_16x(n_count, STORE_256);
aoffset00 += 32 * lda;
aoffset01 += 32 * lda;
aoffset10 += 32 * lda;
aoffset11 += 32 * lda;
boffset = 32 * n_count + boffset0;
}
if (m_count > 1) {
int m_load = m_count & ~1;
m_count -= m_load;
__mmask16 mmask;
SWITCH_MASK_LOAD_A_8VEC(aoffset00, m_load > 8 ? 8: m_load);
REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
if (m_load > 8) {
SWITCH_MASK_LOAD_A_8VEC(aoffset01, m_load > 16 ? 8: m_load - 8);
REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
}
int this_load = m_load > 16 ? 16 : m_load;
mmask = (1UL << this_load) - 1;
SWITCH_STORE_16x(n_count, MASK_STORE_256);
boffset0 = boffset;
if (m_load > 16) {
boffset += this_load;
SWITCH_MASK_LOAD_A_8VEC(aoffset10, m_load > 24 ? 8: m_load - 16);
REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07);
if (m_load > 24) {
SWITCH_MASK_LOAD_A_8VEC(aoffset11, m_load - 24);
REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17);
}
this_load = m_load - 16;
mmask = (1UL << this_load) - 1;
SWITCH_STORE_16x(n_count, MASK_STORE_256);
}
boffset = boffset0 + n_count * m_load;
aoffset00 += m_load * lda;
}
if (m_count > 0) {
// just copy lask K to B directly
r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aoffset00));
_mm256_mask_storeu_epi16((__m256i *)(boffset), nmask, r0);
boffset += 16;
}
}
return 0;
}

View File

@ -0,0 +1,42 @@
/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include "common.h"
#include "sbgemm_block_microk_cooperlake.c"
// Define micro kernels for ALPHA not ONE scenarios
#undef ONE_ALPHA
#include "sbgemm_microk_cooperlake_template.c"
// Define micro kernels for ALPHA as ONE scenarios
#define ONE_ALPHA 1
#include "sbgemm_microk_cooperlake_template.c"
int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta)
{
return 0;
}

14
param.h
View File

@ -1771,6 +1771,20 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#endif
#define USE_SGEMM_KERNEL_DIRECT 1
#undef SBGEMM_DEFAULT_UNROLL_N
#undef SBGEMM_DEFAULT_UNROLL_M
#undef SBGEMM_DEFAULT_P
#undef SBGEMM_DEFAULT_R
#undef SBGEMM_DEFAULT_Q
// FIXME: actually UNROLL_M = UNROLL_N = 16
// If M and N is equal, OpenBLAS will reuse OCOPY as ICOPY.
// But for AMX, they are not the same, set UNROLL_M = 32 to workaround
#define SBGEMM_DEFAULT_UNROLL_N 16
#define SBGEMM_DEFAULT_UNROLL_M 32
#define SBGEMM_DEFAULT_P 256
#define SBGEMM_DEFAULT_Q 1024
#define SBGEMM_DEFAULT_R sbgemm_r
#ifdef ARCH_X86
#define SGEMM_DEFAULT_UNROLL_M 4