From f018aa342a5d97603a5f59fa1feb36e1c77e0571 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Fri, 17 Sep 2021 00:48:52 -0700 Subject: [PATCH] sbgemm: spr: kernel handle alpha != 1.0 --- kernel/x86_64/sbgemm_kernel_16x16_spr.c | 392 +------------- kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c | 521 +++++++++++++++++++ 2 files changed, 529 insertions(+), 384 deletions(-) create mode 100644 kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr.c b/kernel/x86_64/sbgemm_kernel_16x16_spr.c index b34035896..955db3163 100644 --- a/kernel/x86_64/sbgemm_kernel_16x16_spr.c +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr.c @@ -25,109 +25,12 @@ * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * *****************************************************************************/ -#include -#include #include "common.h" -typedef struct { - char palette_id; - char start_row; - char dummy0[14]; // bytes 2-15 reserved, must be zero - short tile_colsb[8]; - char dummy1[16]; // bytes 32-47 reserved, must be zero - char tile_rows[8]; - char dummy2[16]; // bytes 56-63 reserved, must be zero -} tilecfg; - -/* tile0/tile1 -- A (m x 2k) - * tile2/tile3 -- B (2k x n) - * tile4-7 -- C (m x n) - */ -#define TCONF(cfg, m, n, k2) \ - memset(&cfg, 0, sizeof(tilecfg)); \ - cfg.palette_id = 1; \ - cfg.tile_rows[0] = m; \ - cfg.tile_rows[1] = m; \ - cfg.tile_rows[2] = k2>>1; \ - cfg.tile_rows[3] = k2>>1; \ - cfg.tile_rows[4] = m; \ - cfg.tile_rows[5] = m; \ - cfg.tile_rows[6] = m; \ - cfg.tile_rows[7] = m; \ - cfg.tile_colsb[0] = k2<<1; \ - cfg.tile_colsb[1] = k2<<1; \ - cfg.tile_colsb[2] = n * 4; \ - cfg.tile_colsb[3] = n * 4; \ - cfg.tile_colsb[4] = n * 4; \ - cfg.tile_colsb[5] = n * 4; \ - cfg.tile_colsb[6] = n * 4; \ - cfg.tile_colsb[7] = n * 4; \ - _tile_loadconfig(&cfg); - -/* CONFIG for handling k2 and odd tail at the same time - * tile0 -- A (m x 2k) - * tile1 -- A (m x 1) - * tile2 -- B (2k x n) - * tile3 -- B (1 x n) - * tile4 -- C (m x n) - */ -#define TCONF_TAIL(cfg, m, n, k2) \ - memset(&cfg, 0, sizeof(tilecfg)); \ - cfg.palette_id = 1; \ - cfg.tile_rows[0] = m; \ - cfg.tile_rows[1] = m; \ - cfg.tile_rows[2] = k2>>1; \ - cfg.tile_rows[3] = 1; \ - cfg.tile_rows[4] = m; \ - cfg.tile_colsb[0] = k2<<1; \ - cfg.tile_colsb[1] = 4; \ - cfg.tile_colsb[2] = n * 4; \ - cfg.tile_colsb[3] = n * 4; \ - cfg.tile_colsb[4] = n * 4; \ - _tile_loadconfig(&cfg); - -#define T_A0 0 -#define T_A1 1 -#define T_B0 2 -#define T_B1 3 -#define T_C00 4 -#define T_C01 5 -#define T_C10 6 -#define T_C11 7 - -// FIXME: gcc11 seem have problem in tile load/store address calc, -// need to multiply with element size (2 or 4) here. -#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2) -#define LOAD_A_TAIL(M, N) {\ - __m256i ymm = _mm256_loadu_epi16(ptr_a##M); \ - __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ - _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ - _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ -} -#define MASK_LOAD_A_TAIL(M, N) {\ - __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \ - __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ - _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ - _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ -} -#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2) -#define LOAD_B_TAIL(M, N) {\ - __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \ - __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ - _mm512_storeu_epi16(tail_b + 16 * N, zmm); \ - _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ -} -#define MASK_LOAD_B_TAIL(M, N) {\ - __m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \ - __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ - _mm512_storeu_epi16(tail_b + 16 * N, zmm); \ - _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ -} - -#define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4) -#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N) -#define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N) -#define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4) +#define ALPHA_ONE +#include "sbgemm_kernel_16x16_spr_tmpl.c" +#undef ALPHA_ONE +#include "sbgemm_kernel_16x16_spr_tmpl.c" int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOAT * iB, FLOAT * C, BLASLONG ldc) @@ -140,287 +43,8 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA A = iB; B = iA; - IFLOAT *ptr_a = A, *ptr_b = B; - IFLOAT *ptr_b0, *ptr_b1; - IFLOAT *ptr_a0, *ptr_a1; - FLOAT *ptr_c = C; - FLOAT *ptr_c00, *ptr_c01, *ptr_c10, *ptr_c11; - - BLASLONG lda, ldb; - BLASLONG m_count = m; - BLASLONG n_count, k_count; - - - IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64))); - IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64))); - tilecfg cfg; - - if (k < 32) - goto tail_k; - - for (; m_count > 31; m_count -= 32) { - ptr_b = B; - - ptr_c00 = ptr_c; - ptr_c01 = ptr_c00 + 16; - ptr_c10 = ptr_c + 16 * ldc; - ptr_c11 = ptr_c10 + 16; - ptr_c += 32 * ldc; - n_count = n; - TCONF(cfg, 16, 16, 32); - for (; n_count > 31; n_count -= 32) { - ptr_a0 = ptr_a; - ptr_a1 = ptr_a + 16 * k; - - ptr_b0 = ptr_b; - ptr_b1 = ptr_b + 16 * k; - ptr_b += 32 * k; - - lda = 32; - ldb = 32; - LOAD_C(0, 0); LOAD_C(0, 1); - LOAD_C(1, 0); LOAD_C(1, 1); - k_count = k; - for (; k_count > 31; k_count -= 32) { - LOAD_A(0, x); LOAD_A(1, x); - ptr_a0 += 16 * 32; - ptr_a1 += 16 * 32; - LOAD_B(x, 0); LOAD_B(x, 1); - ptr_b0 += 16 * 32; - ptr_b1 += 16 * 32; - - MATMUL(0, 0); MATMUL(0, 1); - MATMUL(1, 0); MATMUL(1, 1); - } - STORE_C(0, 0); STORE_C(0, 1); - STORE_C(1, 0); STORE_C(1, 1); - ptr_c00 += 32; - ptr_c01 += 32; - ptr_c10 += 32; - ptr_c11 += 32; - } - for (; n_count > 0; n_count -= 16) { - int tail_n = (n_count > 16) ? 16: n_count; - __mmask16 bmask = (1UL << tail_n) - 1; - ptr_a0 = ptr_a; - ptr_a1 = ptr_a + 16 * k; - - ptr_b0 = ptr_b; - ptr_b += tail_n * k; - - lda = 32; - ldb = 2 * tail_n; - TCONF(cfg, 16, tail_n, 32); - LOAD_C(0, 0); - LOAD_C(1, 0); - k_count = k; - for (; k_count > 31; k_count -= 32) { - LOAD_A(0, x); LOAD_A(1, x); - ptr_a0 += 16 * 32; - ptr_a1 += 16 * 32; - LOAD_B(x, 0); - ptr_b0 += tail_n * 32; - - MATMUL(0, 0); - MATMUL(1, 0); - } - STORE_C(0, 0); - STORE_C(1, 0); - ptr_c00 += tail_n; - ptr_c10 += tail_n; - } - ptr_a += 32 * k; - } - for (; m_count > 0; m_count -= 16) { - // process at most 16 m at a time - int tail_m = (m_count > 16) ? 16: m_count; - __mmask16 amask = (1UL << tail_m) - 1; - - ptr_b = B; - - ptr_c00 = ptr_c; - ptr_c01 = ptr_c00 + 16; - ptr_c += tail_m * ldc; - n_count = n; - TCONF(cfg, tail_m, 16, 32); - for (; n_count > 31; n_count -= 32) { - ptr_a0 = ptr_a; - - ptr_b0 = ptr_b; - ptr_b1 = ptr_b + 16 * k; - ptr_b += 32 * k; - - lda = 32; - ldb = 32; - LOAD_C(0, 0); LOAD_C(0, 1); - k_count = k; - for (; k_count > 31; k_count -= 32) { - LOAD_A(0, x); - ptr_a0 += tail_m * 32; - LOAD_B(x, 0); LOAD_B(x, 1); - ptr_b0 += 16 * 32; - ptr_b1 += 16 * 32; - - MATMUL(0, 0); MATMUL(0, 1); - } - STORE_C(0, 0); STORE_C(0, 1); - ptr_c00 += 32; - ptr_c01 += 32; - } - for (; n_count > 0; n_count -= 16) { - int tail_n = (n_count > 16) ? 16: n_count; - __mmask16 bmask = (1UL << tail_n) - 1; - ptr_a0 = ptr_a; - - ptr_b0 = ptr_b; - ptr_b += tail_n * k; - - lda = 32; - ldb = 2 * tail_n; - TCONF(cfg, tail_m, tail_n, 32); - LOAD_C(0, 0); - k_count = k; - for (; k_count > 31; k_count -= 32) { - LOAD_A(0, x); - ptr_a0 += tail_m * 32; - LOAD_B(x, 0); - ptr_b0 += tail_n * 32; - - MATMUL(0, 0); - } - STORE_C(0, 0); - ptr_c00 += tail_n; - } - ptr_a += tail_m * k; - } - -tail_k: - // process for k < 32 - BLASLONG k32 = k & ~31; - BLASLONG k2 = k & ~1; - if (k32 != k) { - int remain_k2 = k2 - k32; - m_count = m; - ptr_a = A; - ptr_c = C; - if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0) - for (; m_count > 0; m_count -= 16) { - int tail_m = (m_count > 16) ? 16: m_count; - __mmask16 amask = (1UL << tail_m) - 1; - - ptr_a0 = ptr_a + tail_m * k32; - ptr_a1 = ptr_a + tail_m * k2; - ptr_a += tail_m * k; - ptr_b = B; - ptr_c00 = ptr_c; - ptr_c += tail_m * ldc; - n_count = n; - lda = remain_k2; - ldb = 32; - if (n_count > 15) { - TCONF_TAIL(cfg, tail_m, 16, remain_k2); - LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k32; - ptr_b1 = ptr_b + 16 * k2; - LOAD_C(0, 0); - LOAD_B(x, 0); LOAD_B_TAIL(x, 1); - MATMUL(0, 0); MATMUL_TAIL(1, 1); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; - } - } - if (n_count > 0) { - int tail_n = (n_count > 16) ? 16: n_count; - __mmask16 bmask = (1UL << tail_n) - 1; - ptr_b0 = ptr_b + tail_n * k32; - ptr_b1 = ptr_b + tail_n * k2; - ldb = 2 * tail_n; - TCONF_TAIL(cfg, tail_m, tail_n, remain_k2); - LOAD_C(0, 0); - LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); - LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1); - MATMUL(0, 0); MATMUL_TAIL(1, 1); - STORE_C(0, 0); - } - } - - } else if (remain_k2 > 0) { // k%32 = 2x - for (; m_count > 0; m_count -= 16) { - int tail_m = (m_count > 16) ? 16: m_count; - - ptr_a0 = ptr_a + tail_m * k32; - ptr_a += tail_m * k; - ptr_b = B; - ptr_c00 = ptr_c; - ptr_c += tail_m * ldc; - n_count = n; - lda = remain_k2; - ldb = 32; - if (n_count > 15) { - TCONF(cfg, tail_m, 16, remain_k2); - LOAD_A(0, x); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k32; - LOAD_C(0, 0); - LOAD_B(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; - } - } - if (n_count > 0) { - int tail_n = (n_count > 16) ? 16: n_count; - ptr_b0 = ptr_b + tail_n * k32; - ldb = 2 * tail_n; - TCONF(cfg, tail_m, tail_n, remain_k2); - LOAD_C(0, 0); - LOAD_A(0, x); - LOAD_B(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - } - } - } else { // k%32 = 1 - for (; m_count > 0; m_count -= 16) { - int tail_m = (m_count > 16) ? 16: m_count; - __mmask16 amask = (1UL << tail_m) - 1; - - ptr_a0 = ptr_a + tail_m * k2; - ptr_a += tail_m * k; - ptr_b = B; - ptr_c00 = ptr_c; - ptr_c += tail_m * ldc; - n_count = n; - if (n_count > 15) { - TCONF(cfg, tail_m, 16, 2); - MASK_LOAD_A_TAIL(0, x); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k2; - LOAD_C(0, 0); - LOAD_B_TAIL(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; - } - } - if (n_count > 0) { - int tail_n = (n_count > 16) ? 16: n_count; - __mmask16 bmask = (1UL << tail_n) - 1; - ptr_b0 = ptr_b + tail_n * k2; - TCONF(cfg, tail_m, tail_n, 2); - LOAD_C(0, 0); - MASK_LOAD_A_TAIL(0, x); - MASK_LOAD_B_TAIL(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - } - } - - } - } - return 0; + if (alpha == 1.0f) + return sbgemm_kernel_spr_alpha_one(m, n, k, alpha, A, B, C, ldc); + else + return sbgemm_kernel_spr_alpha(m, n, k, alpha, A, B, C, ldc); } diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c b/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c new file mode 100644 index 000000000..465b9eb75 --- /dev/null +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c @@ -0,0 +1,521 @@ +/*************************************************************************** + * Copyright (c) 2021, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include +#include +#include "common.h" + +#ifndef SBGEMM_KERNEL_SPR +#define SBGEMM_KERNEL_SPR +typedef struct { + char palette_id; + char start_row; + char dummy0[14]; // bytes 2-15 reserved, must be zero + short tile_colsb[8]; + char dummy1[16]; // bytes 32-47 reserved, must be zero + char tile_rows[8]; + char dummy2[16]; // bytes 56-63 reserved, must be zero +} tilecfg; + +/* tile0/tile1 -- A (m x 2k) + * tile2/tile3 -- B (2k x n) + * tile4-7 -- C (m x n) + */ +#define TCONF(cfg, m, n, k2) \ + memset(&cfg, 0, sizeof(tilecfg)); \ + cfg.palette_id = 1; \ + cfg.tile_rows[0] = m; \ + cfg.tile_rows[1] = m; \ + cfg.tile_rows[2] = k2>>1; \ + cfg.tile_rows[3] = k2>>1; \ + cfg.tile_rows[4] = m; \ + cfg.tile_rows[5] = m; \ + cfg.tile_rows[6] = m; \ + cfg.tile_rows[7] = m; \ + cfg.tile_colsb[0] = k2<<1; \ + cfg.tile_colsb[1] = k2<<1; \ + cfg.tile_colsb[2] = n * 4; \ + cfg.tile_colsb[3] = n * 4; \ + cfg.tile_colsb[4] = n * 4; \ + cfg.tile_colsb[5] = n * 4; \ + cfg.tile_colsb[6] = n * 4; \ + cfg.tile_colsb[7] = n * 4; \ + _tile_loadconfig(&cfg); + +/* CONFIG for handling k2 and odd tail at the same time + * tile0 -- A (m x 2k) + * tile1 -- A (m x 1) + * tile2 -- B (2k x n) + * tile3 -- B (1 x n) + * tile4 -- C (m x n) + */ +#define TCONF_TAIL(cfg, m, n, k2) \ + memset(&cfg, 0, sizeof(tilecfg)); \ + cfg.palette_id = 1; \ + cfg.tile_rows[0] = m; \ + cfg.tile_rows[1] = m; \ + cfg.tile_rows[2] = k2>>1; \ + cfg.tile_rows[3] = 1; \ + cfg.tile_rows[4] = m; \ + cfg.tile_colsb[0] = k2<<1; \ + cfg.tile_colsb[1] = 4; \ + cfg.tile_colsb[2] = n * 4; \ + cfg.tile_colsb[3] = n * 4; \ + cfg.tile_colsb[4] = n * 4; \ + _tile_loadconfig(&cfg); + +#define T_A0 0 +#define T_A1 1 +#define T_B0 2 +#define T_B1 3 +#define T_C00 4 +#define T_C01 5 +#define T_C10 6 +#define T_C11 7 + +// FIXME: gcc11 seem have problem in tile load/store address calc, +// need to multiply with element size (2 or 4) here. +#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2) +#define LOAD_A_TAIL(M, N) {\ + __m256i ymm = _mm256_loadu_epi16(ptr_a##M); \ + __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ + _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ + _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ +} +#define MASK_LOAD_A_TAIL(M, N) {\ + __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \ + __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ + _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ + _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ +} +#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2) +#define LOAD_B_TAIL(M, N) {\ + __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \ + __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ + _mm512_storeu_epi16(tail_b + 16 * N, zmm); \ + _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ +} +#define MASK_LOAD_B_TAIL(M, N) {\ + __m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \ + __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ + _mm512_storeu_epi16(tail_b + 16 * N, zmm); \ + _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ +} + +#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N) +#define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N) +#define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4) +#define LOAD_C_F(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4) + +#endif // end of SBGEMM_KERNEL_SPR + +#ifdef ALPHA_ONE +#undef LOAD_C +#define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4) +#else +#undef LOAD_C +#define LOAD_C(M, N) _tile_zero(T_C##M##N) +#define ALPHA_STORE(N) \ + __m512 zmm_d##N = _mm512_loadu_ps(dst##N + noffset); \ + __m512 zmm_s##N = _mm512_loadu_ps(src##N + noffset); \ + zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \ + _mm512_storeu_ps(dst##N + noffset, zmm_d##N); +#define MASK_APLPHA_STORE(N) \ + __m512 zmm_d##N = _mm512_maskz_loadu_ps(mask, dst##N + noffset); \ + __m512 zmm_s##N = _mm512_maskz_loadu_ps(mask, src##N + noffset); \ + zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \ + _mm512_mask_storeu_ps(dst##N + noffset, mask, zmm_d##N); +#endif // end of ALPHA_ONE + + +#ifdef ALPHA_ONE +int sbgemm_kernel_spr_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc) +#else +int sbgemm_kernel_spr_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc) +#endif +{ + /* Row Major matrix for AMX requirement */ + IFLOAT *ptr_a = A, *ptr_b = B; + IFLOAT *ptr_b0, *ptr_b1; + IFLOAT *ptr_a0, *ptr_a1; + FLOAT *ptr_c = C; + FLOAT *ptr_c00, *ptr_c01, *ptr_c10, *ptr_c11; + + BLASLONG lda, ldb; + BLASLONG m_count = m; + BLASLONG n_count, k_count; + +#ifndef ALPHA_ONE + FLOAT *tmp_c = malloc(sizeof(FLOAT) * m * n); + memset(tmp_c, 0, sizeof(FLOAT) * m * n); + ptr_c = tmp_c; + BLASLONG ldc_o = ldc; + ldc = n; +#endif + IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64))); + IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64))); + tilecfg cfg; + + if (k > 31) { + for (; m_count > 31; m_count -= 32) { + ptr_b = B; + + ptr_c00 = ptr_c; + ptr_c01 = ptr_c00 + 16; + ptr_c10 = ptr_c + 16 * ldc; + ptr_c11 = ptr_c10 + 16; + ptr_c += 32 * ldc; + n_count = n; + TCONF(cfg, 16, 16, 32); + for (; n_count > 31; n_count -= 32) { + ptr_a0 = ptr_a; + ptr_a1 = ptr_a + 16 * k; + + ptr_b0 = ptr_b; + ptr_b1 = ptr_b + 16 * k; + ptr_b += 32 * k; + + lda = 32; + ldb = 32; + LOAD_C(0, 0); LOAD_C(0, 1); + LOAD_C(1, 0); LOAD_C(1, 1); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); LOAD_A(1, x); + ptr_a0 += 16 * 32; + ptr_a1 += 16 * 32; + LOAD_B(x, 0); LOAD_B(x, 1); + ptr_b0 += 16 * 32; + ptr_b1 += 16 * 32; + + MATMUL(0, 0); MATMUL(0, 1); + MATMUL(1, 0); MATMUL(1, 1); + } + STORE_C(0, 0); STORE_C(0, 1); + STORE_C(1, 0); STORE_C(1, 1); + ptr_c00 += 32; + ptr_c01 += 32; + ptr_c10 += 32; + ptr_c11 += 32; + } + for (; n_count > 0; n_count -= 16) { + int tail_n = (n_count > 16) ? 16: n_count; + ptr_a0 = ptr_a; + ptr_a1 = ptr_a + 16 * k; + + ptr_b0 = ptr_b; + ptr_b += tail_n * k; + + lda = 32; + ldb = 2 * tail_n; + TCONF(cfg, 16, tail_n, 32); + LOAD_C(0, 0); + LOAD_C(1, 0); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); LOAD_A(1, x); + ptr_a0 += 16 * 32; + ptr_a1 += 16 * 32; + LOAD_B(x, 0); + ptr_b0 += tail_n * 32; + + MATMUL(0, 0); + MATMUL(1, 0); + } + STORE_C(0, 0); + STORE_C(1, 0); + ptr_c00 += tail_n; + ptr_c10 += tail_n; + } + ptr_a += 32 * k; + } + for (; m_count > 0; m_count -= 16) { + // process at most 16 m at a time + int tail_m = (m_count > 16) ? 16: m_count; + + ptr_b = B; + + ptr_c00 = ptr_c; + ptr_c01 = ptr_c00 + 16; + ptr_c += tail_m * ldc; + n_count = n; + TCONF(cfg, tail_m, 16, 32); + for (; n_count > 31; n_count -= 32) { + ptr_a0 = ptr_a; + + ptr_b0 = ptr_b; + ptr_b1 = ptr_b + 16 * k; + ptr_b += 32 * k; + + lda = 32; + ldb = 32; + LOAD_C(0, 0); LOAD_C(0, 1); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); + ptr_a0 += tail_m * 32; + LOAD_B(x, 0); LOAD_B(x, 1); + ptr_b0 += 16 * 32; + ptr_b1 += 16 * 32; + + MATMUL(0, 0); MATMUL(0, 1); + } + STORE_C(0, 0); STORE_C(0, 1); + ptr_c00 += 32; + ptr_c01 += 32; + } + for (; n_count > 0; n_count -= 16) { + int tail_n = (n_count > 16) ? 16: n_count; + ptr_a0 = ptr_a; + + ptr_b0 = ptr_b; + ptr_b += tail_n * k; + + lda = 32; + ldb = 2 * tail_n; + TCONF(cfg, tail_m, tail_n, 32); + LOAD_C(0, 0); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); + ptr_a0 += tail_m * 32; + LOAD_B(x, 0); + ptr_b0 += tail_n * 32; + + MATMUL(0, 0); + } + STORE_C(0, 0); + ptr_c00 += tail_n; + } + ptr_a += tail_m * k; + } + } + + // process for k < 32 + BLASLONG k32 = k & ~31; + BLASLONG k2 = k & ~1; + if (k32 != k) { + int remain_k2 = k2 - k32; + m_count = m; + ptr_a = A; +#ifndef ALPHA_ONE + ptr_c = tmp_c; +#else + ptr_c = C; +#endif + if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0) + for (; m_count > 0; m_count -= 16) { + int tail_m = (m_count > 16) ? 16: m_count; + __mmask16 amask = (1UL << tail_m) - 1; + + ptr_a0 = ptr_a + tail_m * k32; + ptr_a1 = ptr_a + tail_m * k2; + ptr_a += tail_m * k; + ptr_b = B; + ptr_c00 = ptr_c; + ptr_c += tail_m * ldc; + n_count = n; + lda = remain_k2; + ldb = 32; + if (n_count > 15) { + TCONF_TAIL(cfg, tail_m, 16, remain_k2); + LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k32; + ptr_b1 = ptr_b + 16 * k2; + LOAD_C_F(0, 0); + LOAD_B(x, 0); LOAD_B_TAIL(x, 1); + MATMUL(0, 0); MATMUL_TAIL(1, 1); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } + } + if (n_count > 0) { + int tail_n = (n_count > 16) ? 16: n_count; + __mmask16 bmask = (1UL << tail_n) - 1; + ptr_b0 = ptr_b + tail_n * k32; + ptr_b1 = ptr_b + tail_n * k2; + ldb = 2 * tail_n; + TCONF_TAIL(cfg, tail_m, tail_n, remain_k2); + LOAD_C_F(0, 0); + LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); + LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1); + MATMUL(0, 0); MATMUL_TAIL(1, 1); + STORE_C(0, 0); + } + } + + } else if (remain_k2 > 0) { // k%32 = 2x + for (; m_count > 0; m_count -= 16) { + int tail_m = (m_count > 16) ? 16: m_count; + + ptr_a0 = ptr_a + tail_m * k32; + ptr_a += tail_m * k; + ptr_b = B; + ptr_c00 = ptr_c; + ptr_c += tail_m * ldc; + n_count = n; + lda = remain_k2; + ldb = 32; + if (n_count > 15) { + TCONF(cfg, tail_m, 16, remain_k2); + LOAD_A(0, x); + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k32; + LOAD_C_F(0, 0); + LOAD_B(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } + } + if (n_count > 0) { + int tail_n = (n_count > 16) ? 16: n_count; + ptr_b0 = ptr_b + tail_n * k32; + ldb = 2 * tail_n; + TCONF(cfg, tail_m, tail_n, remain_k2); + LOAD_C_F(0, 0); + LOAD_A(0, x); + LOAD_B(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + } + } + } else { // k%32 = 1 + for (; m_count > 0; m_count -= 16) { + int tail_m = (m_count > 16) ? 16: m_count; + __mmask16 amask = (1UL << tail_m) - 1; + + ptr_a0 = ptr_a + tail_m * k2; + ptr_a += tail_m * k; + ptr_b = B; + ptr_c00 = ptr_c; + ptr_c += tail_m * ldc; + n_count = n; + if (n_count > 15) { + TCONF(cfg, tail_m, 16, 2); + MASK_LOAD_A_TAIL(0, x); + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k2; + LOAD_C_F(0, 0); + LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } + } + if (n_count > 0) { + int tail_n = (n_count > 16) ? 16: n_count; + __mmask16 bmask = (1UL << tail_n) - 1; + ptr_b0 = ptr_b + tail_n * k2; + TCONF(cfg, tail_m, tail_n, 2); + LOAD_C_F(0, 0); + MASK_LOAD_A_TAIL(0, x); + MASK_LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + } + } + + } + } +#ifndef ALPHA_ONE + __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); + BLASLONG n16 = n & ~15; + BLASLONG noffset; + FLOAT *src0, *src1, *src2, *src3; + FLOAT *dst0, *dst1, *dst2, *dst3; + FLOAT *src = tmp_c; + FLOAT *dst = C; + m_count = m; + for (; m_count > 3; m_count -= 4) { + src0 = src; + src1 = src0 + ldc; + src2 = src1 + ldc; + src3 = src2 + ldc; + src += 4 * ldc; + + dst0 = dst; + dst1 = dst0 + ldc_o; + dst2 = dst1 + ldc_o; + dst3 = dst2 + ldc_o; + dst += 4 * ldc_o; + + noffset = 0; + for (; noffset < n16; noffset += 16) { + ALPHA_STORE(0); + ALPHA_STORE(1); + ALPHA_STORE(2); + ALPHA_STORE(3); + } + if (noffset < n) { + __mmask16 mask = (1UL << (n - noffset)) - 1; + MASK_APLPHA_STORE(0); + MASK_APLPHA_STORE(1); + MASK_APLPHA_STORE(2); + MASK_APLPHA_STORE(3); + } + } + for (; m_count > 1; m_count -= 2) { + src0 = src; + src1 = src0 + ldc; + src += 2 * ldc; + + dst0 = dst; + dst1 = dst0 + ldc_o; + dst += 2 * ldc_o; + + noffset = 0; + for (; noffset < n16; noffset += 16) { + ALPHA_STORE(0); + ALPHA_STORE(1); + } + if (noffset < n) { + __mmask16 mask = (1UL << (n - noffset)) - 1; + MASK_APLPHA_STORE(0); + MASK_APLPHA_STORE(1); + } + } + for (; m_count > 0; m_count -= 1) { + src0 = src; + dst0 = dst; + noffset = 0; + for (; noffset < n16; noffset += 16) { + ALPHA_STORE(0); + } + if (noffset < n) { + __mmask16 mask = (1UL << (n - noffset)) - 1; + MASK_APLPHA_STORE(0); + } + } + free(tmp_c); +#endif + return 0; +}