From 1d48b7cb168c57285e8c3763b0e0acc204ddbb17 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Mon, 6 Sep 2021 19:48:23 -0700 Subject: [PATCH 01/16] sbgemm: spr: add dummy source files --- kernel/x86_64/KERNEL.SAPPHIRERAPIDS | 11 +++++++++ kernel/x86_64/sbgemm_incopy_16_spr.c | 32 ++++++++++++++++++++++++ kernel/x86_64/sbgemm_itcopy_16_spr.c | 32 ++++++++++++++++++++++++ kernel/x86_64/sbgemm_kernel_16x16_spr.c | 33 +++++++++++++++++++++++++ kernel/x86_64/sbgemm_oncopy_16_spr.c | 32 ++++++++++++++++++++++++ kernel/x86_64/sbgemm_otcopy_16_spr.c | 32 ++++++++++++++++++++++++ 6 files changed, 172 insertions(+) create mode 100644 kernel/x86_64/sbgemm_incopy_16_spr.c create mode 100644 kernel/x86_64/sbgemm_itcopy_16_spr.c create mode 100644 kernel/x86_64/sbgemm_kernel_16x16_spr.c create mode 100644 kernel/x86_64/sbgemm_oncopy_16_spr.c create mode 100644 kernel/x86_64/sbgemm_otcopy_16_spr.c diff --git a/kernel/x86_64/KERNEL.SAPPHIRERAPIDS b/kernel/x86_64/KERNEL.SAPPHIRERAPIDS index 61965c745..bee624b04 100644 --- a/kernel/x86_64/KERNEL.SAPPHIRERAPIDS +++ b/kernel/x86_64/KERNEL.SAPPHIRERAPIDS @@ -1 +1,12 @@ include $(KERNELDIR)/KERNEL.COOPERLAKE + +SBGEMM_BETA = sgemm_beta_skylakex.c +SBGEMMKERNEL = sbgemm_kernel_16x16_spr.c +SBGEMMINCOPY = sbgemm_incopy_16_spr.c +SBGEMMITCOPY = sbgemm_itcopy_16_spr.c +SBGEMMONCOPY = sbgemm_oncopy_16_spr.c +SBGEMMOTCOPY = sbgemm_otcopy_16_spr.c +SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX) +SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX) +SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX) +SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX) diff --git a/kernel/x86_64/sbgemm_incopy_16_spr.c b/kernel/x86_64/sbgemm_incopy_16_spr.c new file mode 100644 index 000000000..2f57ae7b6 --- /dev/null +++ b/kernel/x86_64/sbgemm_incopy_16_spr.c @@ -0,0 +1,32 @@ +/*************************************************************************** + * Copyright (c) 2021, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { + return 0; +} diff --git a/kernel/x86_64/sbgemm_itcopy_16_spr.c b/kernel/x86_64/sbgemm_itcopy_16_spr.c new file mode 100644 index 000000000..2f57ae7b6 --- /dev/null +++ b/kernel/x86_64/sbgemm_itcopy_16_spr.c @@ -0,0 +1,32 @@ +/*************************************************************************** + * Copyright (c) 2021, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { + return 0; +} diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr.c b/kernel/x86_64/sbgemm_kernel_16x16_spr.c new file mode 100644 index 000000000..51f44ba4a --- /dev/null +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr.c @@ -0,0 +1,33 @@ +/*************************************************************************** + * Copyright (c) 2021, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include "common.h" + +int CNAME (BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc) +{ + return 0; +} diff --git a/kernel/x86_64/sbgemm_oncopy_16_spr.c b/kernel/x86_64/sbgemm_oncopy_16_spr.c new file mode 100644 index 000000000..2f57ae7b6 --- /dev/null +++ b/kernel/x86_64/sbgemm_oncopy_16_spr.c @@ -0,0 +1,32 @@ +/*************************************************************************** + * Copyright (c) 2021, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { + return 0; +} diff --git a/kernel/x86_64/sbgemm_otcopy_16_spr.c b/kernel/x86_64/sbgemm_otcopy_16_spr.c new file mode 100644 index 000000000..2f57ae7b6 --- /dev/null +++ b/kernel/x86_64/sbgemm_otcopy_16_spr.c @@ -0,0 +1,32 @@ +/*************************************************************************** + * Copyright (c) 2021, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { + return 0; +} From d0b253ac6eb639cdcce1c70d28d64b23e1fbaa92 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 8 Sep 2021 19:41:12 -0700 Subject: [PATCH 02/16] sbgemm: spr: implement oncopy_16 --- kernel/x86_64/sbgemm_oncopy_16_spr.c | 145 +++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) diff --git a/kernel/x86_64/sbgemm_oncopy_16_spr.c b/kernel/x86_64/sbgemm_oncopy_16_spr.c index 2f57ae7b6..da353d2c7 100644 --- a/kernel/x86_64/sbgemm_oncopy_16_spr.c +++ b/kernel/x86_64/sbgemm_oncopy_16_spr.c @@ -25,8 +25,153 @@ * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * *****************************************************************************/ +#include #include "common.h" +#define COPY_32(N) _mm512_storeu_si512(boffset + 32 * N, _mm512_loadu_si512(aoffset##N + i)) +#define MASK_COPY_32(N) _mm512_mask_storeu_epi16(boffset + tail_m * N, mmask, _mm512_maskz_loadu_epi16(mmask, aoffset##N + i)) +#define COPY_ODD_TAIL(N) *(boffset + N) = *(aoffset##N + i); + int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { + printf("ONCOPY: m %d, n %d, lda %d\n", m, n, lda); + BLASLONG i, j; + IFLOAT *aoffset, *boffset; + IFLOAT *aoffset0, *aoffset1, *aoffset2, *aoffset3; + IFLOAT *aoffset4, *aoffset5, *aoffset6, *aoffset7; + IFLOAT *aoffset8, *aoffset9, *aoffset10, *aoffset11; + IFLOAT *aoffset12, *aoffset13, *aoffset14, *aoffset15; + + aoffset = a; + boffset = b; + + BLASLONG n16 = n & ~15; + BLASLONG m32 = m & ~31; + BLASLONG m2 = m & ~1; + + for (j = 0; j < n16; j += 16) { + aoffset0 = aoffset; + aoffset1 = aoffset0 + lda; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset5 = aoffset4 + lda; + aoffset6 = aoffset5 + lda; + aoffset7 = aoffset6 + lda; + aoffset8 = aoffset7 + lda; + aoffset9 = aoffset8 + lda; + aoffset10 = aoffset9 + lda; + aoffset11 = aoffset10 + lda; + aoffset12 = aoffset11 + lda; + aoffset13 = aoffset12 + lda; + aoffset14 = aoffset13 + lda; + aoffset15 = aoffset14 + lda; + aoffset += 16 * lda; + for (i = 0; i < m32; i += 32) { + COPY_32(0); COPY_32(1); COPY_32(2); COPY_32(3); + COPY_32(4); COPY_32(5); COPY_32(6); COPY_32(7); + COPY_32(8); COPY_32(9); COPY_32(10); COPY_32(11); + COPY_32(12); COPY_32(13); COPY_32(14); COPY_32(15); + boffset += 32 * 16; + } + if (i < m2) { + int tail_m = m2 - i; + __mmask32 mmask = (1UL << tail_m) - 1; + MASK_COPY_32(0); MASK_COPY_32(1); MASK_COPY_32(2); MASK_COPY_32(3); + MASK_COPY_32(4); MASK_COPY_32(5); MASK_COPY_32(6); MASK_COPY_32(7); + MASK_COPY_32(8); MASK_COPY_32(9); MASK_COPY_32(10); MASK_COPY_32(11); + MASK_COPY_32(12); MASK_COPY_32(13); MASK_COPY_32(14); MASK_COPY_32(15); + i = m2; + boffset += tail_m * 16; + } + if (i < m) { + /* the tail odd k should put alone */ + COPY_ODD_TAIL(0); COPY_ODD_TAIL(1); COPY_ODD_TAIL(2); COPY_ODD_TAIL(3); + COPY_ODD_TAIL(4); COPY_ODD_TAIL(5); COPY_ODD_TAIL(6); COPY_ODD_TAIL(7); + COPY_ODD_TAIL(8); COPY_ODD_TAIL(9); COPY_ODD_TAIL(10); COPY_ODD_TAIL(11); + COPY_ODD_TAIL(12); COPY_ODD_TAIL(13); COPY_ODD_TAIL(14); COPY_ODD_TAIL(15); + boffset += 16; + } + } + if (j < n) { + int remain_n = n - j; + aoffset0 = aoffset; + aoffset1 = aoffset0 + lda; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset5 = aoffset4 + lda; + aoffset6 = aoffset5 + lda; + aoffset7 = aoffset6 + lda; + aoffset8 = aoffset7 + lda; + aoffset9 = aoffset8 + lda; + aoffset10 = aoffset9 + lda; + aoffset11 = aoffset10 + lda; + aoffset12 = aoffset11 + lda; + aoffset13 = aoffset12 + lda; + aoffset14 = aoffset13 + lda; + aoffset15 = aoffset14 + lda; + for (i = 0; i < m32; i += 32) { + switch(remain_n) { + case 15: COPY_32(14); + case 14: COPY_32(13); + case 13: COPY_32(12); + case 12: COPY_32(11); + case 11: COPY_32(10); + case 10: COPY_32(9); + case 9: COPY_32(8); + case 8: COPY_32(7); + case 7: COPY_32(6); + case 6: COPY_32(5); + case 5: COPY_32(4); + case 4: COPY_32(3); + case 3: COPY_32(2); + case 2: COPY_32(1); + case 1: COPY_32(0); + } + boffset += 32 * remain_n; + } + if (i < m2) { + int tail_m = m2 - i; + __mmask32 mmask = (1UL << tail_m) - 1; + switch(remain_n) { + case 15: MASK_COPY_32(14); + case 14: MASK_COPY_32(13); + case 13: MASK_COPY_32(12); + case 12: MASK_COPY_32(11); + case 11: MASK_COPY_32(10); + case 10: MASK_COPY_32(9); + case 9: MASK_COPY_32(8); + case 8: MASK_COPY_32(7); + case 7: MASK_COPY_32(6); + case 6: MASK_COPY_32(5); + case 5: MASK_COPY_32(4); + case 4: MASK_COPY_32(3); + case 3: MASK_COPY_32(2); + case 2: MASK_COPY_32(1); + case 1: MASK_COPY_32(0); + } + i = m2; + boffset += tail_m * remain_n; + } + if (i < m) { + switch(remain_n) { + case 15: COPY_ODD_TAIL(14); + case 14: COPY_ODD_TAIL(13); + case 13: COPY_ODD_TAIL(12); + case 12: COPY_ODD_TAIL(11); + case 11: COPY_ODD_TAIL(10); + case 10: COPY_ODD_TAIL(9); + case 9: COPY_ODD_TAIL(8); + case 8: COPY_ODD_TAIL(7); + case 7: COPY_ODD_TAIL(6); + case 6: COPY_ODD_TAIL(5); + case 5: COPY_ODD_TAIL(4); + case 4: COPY_ODD_TAIL(3); + case 3: COPY_ODD_TAIL(2); + case 2: COPY_ODD_TAIL(1); + case 1: COPY_ODD_TAIL(0); + } + } + } return 0; } From 6051c867418bb0ddc762647155637365c5a92a24 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Fri, 10 Sep 2021 01:14:05 -0700 Subject: [PATCH 03/16] sbgemm: spr: kernel works for m32 in NN case --- kernel/x86_64/KERNEL.SAPPHIRERAPIDS | 2 +- kernel/x86_64/sbgemm_itcopy_16_spr.c | 32 --- kernel/x86_64/sbgemm_kernel_16x16_spr.c | 252 +++++++++++++++++++++++- 3 files changed, 252 insertions(+), 34 deletions(-) delete mode 100644 kernel/x86_64/sbgemm_itcopy_16_spr.c diff --git a/kernel/x86_64/KERNEL.SAPPHIRERAPIDS b/kernel/x86_64/KERNEL.SAPPHIRERAPIDS index bee624b04..3f67640cb 100644 --- a/kernel/x86_64/KERNEL.SAPPHIRERAPIDS +++ b/kernel/x86_64/KERNEL.SAPPHIRERAPIDS @@ -3,7 +3,7 @@ include $(KERNELDIR)/KERNEL.COOPERLAKE SBGEMM_BETA = sgemm_beta_skylakex.c SBGEMMKERNEL = sbgemm_kernel_16x16_spr.c SBGEMMINCOPY = sbgemm_incopy_16_spr.c -SBGEMMITCOPY = sbgemm_itcopy_16_spr.c +SBGEMMITCOPY = sbgemm_tcopy_16_cooperlake.c SBGEMMONCOPY = sbgemm_oncopy_16_spr.c SBGEMMOTCOPY = sbgemm_otcopy_16_spr.c SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX) diff --git a/kernel/x86_64/sbgemm_itcopy_16_spr.c b/kernel/x86_64/sbgemm_itcopy_16_spr.c deleted file mode 100644 index 2f57ae7b6..000000000 --- a/kernel/x86_64/sbgemm_itcopy_16_spr.c +++ /dev/null @@ -1,32 +0,0 @@ -/*************************************************************************** - * Copyright (c) 2021, The OpenBLAS Project - * All rights reserved. - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in - * the documentation and/or other materials provided with the - * distribution. - * 3. Neither the name of the OpenBLAS project nor the names of - * its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE - * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * *****************************************************************************/ - -#include "common.h" - -int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { - return 0; -} diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr.c b/kernel/x86_64/sbgemm_kernel_16x16_spr.c index 51f44ba4a..41d2634d5 100644 --- a/kernel/x86_64/sbgemm_kernel_16x16_spr.c +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr.c @@ -25,9 +25,259 @@ * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * *****************************************************************************/ +#include +#include #include "common.h" -int CNAME (BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc) +typedef struct { + char palette_id; + char start_row; + char dummy0[14]; // bytes 2-15 reserved, must be zero + short tile_colsb[8]; + char dummy1[16]; // bytes 32-47 reserved, must be zero + char tile_rows[8]; + char dummy2[16]; // bytes 56-63 reserved, must be zero +} tilecfg; + +/* tile0/tile1 -- A (m x 2k) + * tile2/tile3 -- B (2k x n) + * tile4-7 -- C (m x n) + */ +#define TCONF(cfg, m, n, k2) \ + memset(&cfg, 0, sizeof(tilecfg)); \ + cfg.palette_id = 1; \ + cfg.tile_rows[0] = m; \ + cfg.tile_rows[1] = m; \ + cfg.tile_rows[2] = k2>>1; \ + cfg.tile_rows[3] = k2>>1; \ + cfg.tile_rows[4] = m; \ + cfg.tile_rows[5] = m; \ + cfg.tile_rows[6] = m; \ + cfg.tile_rows[7] = m; \ + cfg.tile_colsb[0] = k2<<1; \ + cfg.tile_colsb[1] = k2<<1; \ + cfg.tile_colsb[2] = n * 4; \ + cfg.tile_colsb[3] = n * 4; \ + cfg.tile_colsb[4] = n * 4; \ + cfg.tile_colsb[5] = n * 4; \ + cfg.tile_colsb[6] = n * 4; \ + cfg.tile_colsb[7] = n * 4; \ + _tile_loadconfig(&cfg); + +#define T_A0 0 +#define T_A1 1 +#define T_B0 2 +#define T_B1 3 +#define T_C00 4 +#define T_C01 5 +#define T_C10 6 +#define T_C11 7 + +// FIXME: gcc11 seem have problem in tile load/store address calc, +// need to multiply with element size (2 or 4) here. +#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2) +#define LOAD_A_TAIL(M, N) {\ + __m256i ymm = _mm256_loadu_epi16(ptr_a##M); \ + __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ + _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ + _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ +} +#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2) +#define LOAD_B_TAIL(M, N) {\ + __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \ + __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ + _mm512_storeu_epi16(tail_b + 16 * N, zmm); \ + _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ +} +#define MASK_LOAD_B_TAIL(M, N) {\ + __m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \ + __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ + _mm512_storeu_epi16(tail_b + 16 * N, zmm); \ + _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ +} + +#define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4) +#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N) +#define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4) + + +int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOAT * iB, FLOAT * C, BLASLONG ldc) { + /* transport to Row Major matrix for AMX requirement */ + BLASLONG m, n; + IFLOAT *A, *B; + m = in; + n = im; + A = iB; + B = iA; + + printf("kernel: m %d, n %d, k %d, ldc: %d\n", m, n, k, ldc); + IFLOAT *ptr_a = A, *ptr_b = B; + IFLOAT *ptr_b0, *ptr_b1; + IFLOAT *ptr_a0, *ptr_a1; + FLOAT *ptr_c = C; + FLOAT *ptr_c00, *ptr_c01, *ptr_c10, *ptr_c11; + + BLASLONG lda, ldb; + BLASLONG m_count = m; + BLASLONG n_count, k_count; + + IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64))); + IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64))); + tilecfg cfg; + + for (; m_count > 31; m_count -= 32) { + ptr_b = B; + + ptr_c00 = ptr_c; + ptr_c01 = ptr_c00 + 16; + ptr_c10 = ptr_c + 16 * ldc; + ptr_c11 = ptr_c10 + 16; + ptr_c += 32 * ldc; + n_count = n; + for (; n_count > 31; n_count -= 32) { + ptr_a0 = ptr_a; + ptr_a1 = ptr_a + 16 * k; + + ptr_b0 = ptr_b; + ptr_b1 = ptr_b + 16 * k; + ptr_b += 32 * k; + + lda = 32; + ldb = 32; + TCONF(cfg, 16, 16, 32); + LOAD_C(0, 0); LOAD_C(0, 1); + LOAD_C(1, 0); LOAD_C(1, 1); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); LOAD_A(1, x); + ptr_a0 += 16 * 32; + ptr_a1 += 16 * 32; + LOAD_B(x, 0); LOAD_B(x, 1); + ptr_b0 += 16 * 32; + ptr_b1 += 16 * 32; + + MATMUL(0, 0); MATMUL(0, 1); + MATMUL(1, 0); MATMUL(1, 1); + } + STORE_C(0, 0); STORE_C(0, 1); + STORE_C(1, 0); STORE_C(1, 1); + if (k_count > 1) { + /* still have more than 2*k */ + int remain_k2 = k_count & ~1; + k_count -= remain_k2; + lda = remain_k2; + TCONF(cfg, 16, 16, remain_k2); + /* reconfig will clear all tiles, + * need to store/load again + */ + LOAD_C(0, 0); LOAD_C(0, 1); + LOAD_C(1, 0); LOAD_C(1, 1); + + LOAD_A(0, x); LOAD_A(1, x); + ptr_a0 += 16 * remain_k2; + ptr_a1 += 16 * remain_k2; + LOAD_B(x, 0); LOAD_B(x, 1); + ptr_b0 += 16 * remain_k2; + ptr_b1 += 16 * remain_k2; + + MATMUL(0, 0); MATMUL(0, 1); + MATMUL(1, 0); MATMUL(1, 1); + + STORE_C(0, 0); STORE_C(0, 1); + STORE_C(1, 0); STORE_C(1, 1); + } + if (k_count > 0) { + /* still have odd tail k, need to transform into 2*k */ + TCONF(cfg, 16, 16, 2); + + LOAD_C(0, 0); LOAD_C(0, 1); + LOAD_C(1, 0); LOAD_C(1, 1); + + LOAD_A_TAIL(0, x); LOAD_A_TAIL(1, x); + LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1); + + MATMUL(0, 0); MATMUL(0, 1); + MATMUL(1, 0); MATMUL(1, 1); + + STORE_C(0, 0); STORE_C(0, 1); + STORE_C(1, 0); STORE_C(1, 1); + } + ptr_c00 += 32; + ptr_c01 += 32; + ptr_c10 += 32; + ptr_c11 += 32; + } + for (; n_count > 0; n_count -= 16) { + int tail_n = (n_count > 16) ? 16: n_count; + __mmask16 bmask = (1UL << tail_n) - 1; + ptr_a0 = ptr_a; + ptr_a1 = ptr_a + 16 * k; + + ptr_b0 = ptr_b; + ptr_b += tail_n * k; + + lda = 32; + ldb = 2 * tail_n; + TCONF(cfg, 16, tail_n, 32); + LOAD_C(0, 0); + LOAD_C(1, 0); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); LOAD_A(1, x); + ptr_a0 += 16 * 32; + ptr_a1 += 16 * 32; + LOAD_B(x, 0); + ptr_b0 += tail_n * 32; + + MATMUL(0, 0); + MATMUL(1, 0); + } + STORE_C(0, 0); + STORE_C(1, 0); + if (k_count > 1) { + /* still have more than 2*k */ + int remain_k2 = k_count & ~1; + k_count -= remain_k2; + lda = remain_k2; + TCONF(cfg, 16, tail_n, remain_k2); + /* reconfig will clear all tiles, + * need to store/load again + */ + LOAD_C(0, 0); + LOAD_C(1, 0); + + LOAD_A(0, x); LOAD_A(1, x); + ptr_a0 += 16 * remain_k2; + ptr_a1 += 16 * remain_k2; + LOAD_B(x, 0); + ptr_b0 += tail_n * remain_k2; + + MATMUL(0, 0); + MATMUL(1, 0); + + STORE_C(0, 0); + STORE_C(1, 0); + } + if (k_count > 0) { + /* still have odd tail k, need to transform into 2*k */ + TCONF(cfg, 16, tail_n, 2); + + LOAD_C(0, 0); + LOAD_C(1, 0); + + LOAD_A_TAIL(0, x); LOAD_A_TAIL(1, x); + MASK_LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + MATMUL(1, 0); + + STORE_C(0, 0); + STORE_C(1, 0); + } + ptr_c00 += tail_n; + ptr_c10 += tail_n; + } + ptr_a += 32 * k; + } return 0; } From a70bfb52d519ced4e6b0d633bfb95d3f4d332fe5 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Sun, 12 Sep 2021 19:22:58 -0700 Subject: [PATCH 04/16] sbgemm: spr: kernel works for NN case when alpha is 1.0 --- kernel/x86_64/sbgemm_kernel_16x16_spr.c | 135 +++++++++++++++++++++++- kernel/x86_64/sbgemm_oncopy_16_spr.c | 2 +- 2 files changed, 135 insertions(+), 2 deletions(-) diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr.c b/kernel/x86_64/sbgemm_kernel_16x16_spr.c index 41d2634d5..b7b4e36a3 100644 --- a/kernel/x86_64/sbgemm_kernel_16x16_spr.c +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr.c @@ -82,6 +82,12 @@ typedef struct { _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ } +#define MASK_LOAD_A_TAIL(M, N) {\ + __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \ + __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ + _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ + _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ +} #define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2) #define LOAD_B_TAIL(M, N) {\ __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \ @@ -111,7 +117,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA A = iB; B = iA; - printf("kernel: m %d, n %d, k %d, ldc: %d\n", m, n, k, ldc); IFLOAT *ptr_a = A, *ptr_b = B; IFLOAT *ptr_b0, *ptr_b1; IFLOAT *ptr_a0, *ptr_a1; @@ -279,5 +284,133 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA } ptr_a += 32 * k; } + for (; m_count > 0; m_count -= 16) { + // process at most 16 m at a time + int tail_m = (m_count > 16) ? 16: m_count; + __mmask16 amask = (1UL << tail_m) - 1; + + ptr_b = B; + + ptr_c00 = ptr_c; + ptr_c01 = ptr_c00 + 16; + ptr_c += tail_m * ldc; + n_count = n; + for (; n_count > 31; n_count -= 32) { + ptr_a0 = ptr_a; + + ptr_b0 = ptr_b; + ptr_b1 = ptr_b + 16 * k; + ptr_b += 32 * k; + + lda = 32; + ldb = 32; + TCONF(cfg, tail_m, 16, 32); + LOAD_C(0, 0); LOAD_C(0, 1); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); + ptr_a0 += tail_m * 32; + LOAD_B(x, 0); LOAD_B(x, 1); + ptr_b0 += 16 * 32; + ptr_b1 += 16 * 32; + + MATMUL(0, 0); MATMUL(0, 1); + } + STORE_C(0, 0); STORE_C(0, 1); + if (k_count > 1) { + /* still have more than 2*k */ + int remain_k2 = k_count & ~1; + k_count -= remain_k2; + lda = remain_k2; + TCONF(cfg, tail_m, 16, remain_k2); + /* reconfig will clear all tiles, + * need to store/load again + */ + LOAD_C(0, 0); LOAD_C(0, 1); + + LOAD_A(0, x); + ptr_a0 += tail_m * remain_k2; + LOAD_B(x, 0); LOAD_B(x, 1); + ptr_b0 += 16 * remain_k2; + ptr_b1 += 16 * remain_k2; + + MATMUL(0, 0); MATMUL(0, 1); + + STORE_C(0, 0); STORE_C(0, 1); + } + if (k_count > 0) { + /* still have odd tail k, need to transform into 2*k */ + TCONF(cfg, tail_m, 16, 2); + + LOAD_C(0, 0); LOAD_C(0, 1); + + MASK_LOAD_A_TAIL(0, x); + LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1); + + MATMUL(0, 0); MATMUL(0, 1); + + STORE_C(0, 0); STORE_C(0, 1); + } + ptr_c00 += 32; + ptr_c01 += 32; + } + for (; n_count > 0; n_count -= 16) { + int tail_n = (n_count > 16) ? 16: n_count; + __mmask16 bmask = (1UL << tail_n) - 1; + ptr_a0 = ptr_a; + + ptr_b0 = ptr_b; + ptr_b += tail_n * k; + + lda = 32; + ldb = 2 * tail_n; + TCONF(cfg, tail_m, tail_n, 32); + LOAD_C(0, 0); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); + ptr_a0 += tail_m * 32; + LOAD_B(x, 0); + ptr_b0 += tail_n * 32; + + MATMUL(0, 0); + } + STORE_C(0, 0); + if (k_count > 1) { + /* still have more than 2*k */ + int remain_k2 = k_count & ~1; + k_count -= remain_k2; + lda = remain_k2; + TCONF(cfg, tail_m, tail_n, remain_k2); + /* reconfig will clear all tiles, + * need to store/load again + */ + LOAD_C(0, 0); + + LOAD_A(0, x); + ptr_a0 += tail_m * remain_k2; + LOAD_B(x, 0); + ptr_b0 += tail_n * remain_k2; + + MATMUL(0, 0); + + STORE_C(0, 0); + } + if (k_count > 0) { + /* still have odd tail k, need to transform into 2*k */ + TCONF(cfg, tail_m, tail_n, 2); + + LOAD_C(0, 0); + + MASK_LOAD_A_TAIL(0, x); + MASK_LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + + STORE_C(0, 0); + } + ptr_c00 += tail_n; + } + ptr_a += tail_m * k; + } return 0; } diff --git a/kernel/x86_64/sbgemm_oncopy_16_spr.c b/kernel/x86_64/sbgemm_oncopy_16_spr.c index da353d2c7..f5668e26e 100644 --- a/kernel/x86_64/sbgemm_oncopy_16_spr.c +++ b/kernel/x86_64/sbgemm_oncopy_16_spr.c @@ -32,8 +32,8 @@ #define MASK_COPY_32(N) _mm512_mask_storeu_epi16(boffset + tail_m * N, mmask, _mm512_maskz_loadu_epi16(mmask, aoffset##N + i)) #define COPY_ODD_TAIL(N) *(boffset + N) = *(aoffset##N + i); + int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { - printf("ONCOPY: m %d, n %d, lda %d\n", m, n, lda); BLASLONG i, j; IFLOAT *aoffset, *boffset; IFLOAT *aoffset0, *aoffset1, *aoffset2, *aoffset3; From 0abbcd19c1589bac3e5d1eae5d87a40535b26510 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Mon, 13 Sep 2021 01:44:53 -0700 Subject: [PATCH 05/16] sbgemm: spr: tuning for blocking params --- param.h | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/param.h b/param.h index 4e83714d1..c2c6916bc 100644 --- a/param.h +++ b/param.h @@ -1771,6 +1771,20 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #endif #define USE_SGEMM_KERNEL_DIRECT 1 +#undef SBGEMM_DEFAULT_UNROLL_N +#undef SBGEMM_DEFAULT_UNROLL_M +#undef SBGEMM_DEFAULT_P +#undef SBGEMM_DEFAULT_R +#undef SBGEMM_DEFAULT_Q +// FIXME: actually UNROLL_M = UNROLL_N = 16 +// If M and N is equal, OpenBLAS will reuse OCOPY as ICOPY. +// But for AMX, they are not the same, set UNROLL_M = 32 to workaround +#define SBGEMM_DEFAULT_UNROLL_N 16 +#define SBGEMM_DEFAULT_UNROLL_M 32 +#define SBGEMM_DEFAULT_P 192 +#define SBGEMM_DEFAULT_Q 1024 +#define SBGEMM_DEFAULT_R sbgemm_r + #ifdef ARCH_X86 #define SGEMM_DEFAULT_UNROLL_M 4 From 88154ed02d14aa6b9048dbcbe29a854d5a691713 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 15 Sep 2021 01:11:15 -0700 Subject: [PATCH 06/16] sbgemm: spr: reduce tile conf loading by seperate tail k handling --- kernel/x86_64/sbgemm_kernel_16x16_spr.c | 240 +++++++++--------------- 1 file changed, 92 insertions(+), 148 deletions(-) diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr.c b/kernel/x86_64/sbgemm_kernel_16x16_spr.c index b7b4e36a3..112adbe1c 100644 --- a/kernel/x86_64/sbgemm_kernel_16x16_spr.c +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr.c @@ -127,10 +127,14 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA BLASLONG m_count = m; BLASLONG n_count, k_count; + IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64))); IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64))); tilecfg cfg; + if (k < 32) + goto tail_k; + for (; m_count > 31; m_count -= 32) { ptr_b = B; @@ -140,6 +144,7 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA ptr_c11 = ptr_c10 + 16; ptr_c += 32 * ldc; n_count = n; + TCONF(cfg, 16, 16, 32); for (; n_count > 31; n_count -= 32) { ptr_a0 = ptr_a; ptr_a1 = ptr_a + 16 * k; @@ -150,7 +155,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA lda = 32; ldb = 32; - TCONF(cfg, 16, 16, 32); LOAD_C(0, 0); LOAD_C(0, 1); LOAD_C(1, 0); LOAD_C(1, 1); k_count = k; @@ -167,47 +171,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA } STORE_C(0, 0); STORE_C(0, 1); STORE_C(1, 0); STORE_C(1, 1); - if (k_count > 1) { - /* still have more than 2*k */ - int remain_k2 = k_count & ~1; - k_count -= remain_k2; - lda = remain_k2; - TCONF(cfg, 16, 16, remain_k2); - /* reconfig will clear all tiles, - * need to store/load again - */ - LOAD_C(0, 0); LOAD_C(0, 1); - LOAD_C(1, 0); LOAD_C(1, 1); - - LOAD_A(0, x); LOAD_A(1, x); - ptr_a0 += 16 * remain_k2; - ptr_a1 += 16 * remain_k2; - LOAD_B(x, 0); LOAD_B(x, 1); - ptr_b0 += 16 * remain_k2; - ptr_b1 += 16 * remain_k2; - - MATMUL(0, 0); MATMUL(0, 1); - MATMUL(1, 0); MATMUL(1, 1); - - STORE_C(0, 0); STORE_C(0, 1); - STORE_C(1, 0); STORE_C(1, 1); - } - if (k_count > 0) { - /* still have odd tail k, need to transform into 2*k */ - TCONF(cfg, 16, 16, 2); - - LOAD_C(0, 0); LOAD_C(0, 1); - LOAD_C(1, 0); LOAD_C(1, 1); - - LOAD_A_TAIL(0, x); LOAD_A_TAIL(1, x); - LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1); - - MATMUL(0, 0); MATMUL(0, 1); - MATMUL(1, 0); MATMUL(1, 1); - - STORE_C(0, 0); STORE_C(0, 1); - STORE_C(1, 0); STORE_C(1, 1); - } ptr_c00 += 32; ptr_c01 += 32; ptr_c10 += 32; @@ -240,45 +203,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA } STORE_C(0, 0); STORE_C(1, 0); - if (k_count > 1) { - /* still have more than 2*k */ - int remain_k2 = k_count & ~1; - k_count -= remain_k2; - lda = remain_k2; - TCONF(cfg, 16, tail_n, remain_k2); - /* reconfig will clear all tiles, - * need to store/load again - */ - LOAD_C(0, 0); - LOAD_C(1, 0); - - LOAD_A(0, x); LOAD_A(1, x); - ptr_a0 += 16 * remain_k2; - ptr_a1 += 16 * remain_k2; - LOAD_B(x, 0); - ptr_b0 += tail_n * remain_k2; - - MATMUL(0, 0); - MATMUL(1, 0); - - STORE_C(0, 0); - STORE_C(1, 0); - } - if (k_count > 0) { - /* still have odd tail k, need to transform into 2*k */ - TCONF(cfg, 16, tail_n, 2); - - LOAD_C(0, 0); - LOAD_C(1, 0); - - LOAD_A_TAIL(0, x); LOAD_A_TAIL(1, x); - MASK_LOAD_B_TAIL(x, 0); - MATMUL(0, 0); - MATMUL(1, 0); - - STORE_C(0, 0); - STORE_C(1, 0); - } ptr_c00 += tail_n; ptr_c10 += tail_n; } @@ -295,6 +219,7 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA ptr_c01 = ptr_c00 + 16; ptr_c += tail_m * ldc; n_count = n; + TCONF(cfg, tail_m, 16, 32); for (; n_count > 31; n_count -= 32) { ptr_a0 = ptr_a; @@ -304,7 +229,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA lda = 32; ldb = 32; - TCONF(cfg, tail_m, 16, 32); LOAD_C(0, 0); LOAD_C(0, 1); k_count = k; for (; k_count > 31; k_count -= 32) { @@ -317,40 +241,6 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA MATMUL(0, 0); MATMUL(0, 1); } STORE_C(0, 0); STORE_C(0, 1); - if (k_count > 1) { - /* still have more than 2*k */ - int remain_k2 = k_count & ~1; - k_count -= remain_k2; - lda = remain_k2; - TCONF(cfg, tail_m, 16, remain_k2); - /* reconfig will clear all tiles, - * need to store/load again - */ - LOAD_C(0, 0); LOAD_C(0, 1); - - LOAD_A(0, x); - ptr_a0 += tail_m * remain_k2; - LOAD_B(x, 0); LOAD_B(x, 1); - ptr_b0 += 16 * remain_k2; - ptr_b1 += 16 * remain_k2; - - MATMUL(0, 0); MATMUL(0, 1); - - STORE_C(0, 0); STORE_C(0, 1); - } - if (k_count > 0) { - /* still have odd tail k, need to transform into 2*k */ - TCONF(cfg, tail_m, 16, 2); - - LOAD_C(0, 0); LOAD_C(0, 1); - - MASK_LOAD_A_TAIL(0, x); - LOAD_B_TAIL(x, 0); LOAD_B_TAIL(x, 1); - - MATMUL(0, 0); MATMUL(0, 1); - - STORE_C(0, 0); STORE_C(0, 1); - } ptr_c00 += 32; ptr_c01 += 32; } @@ -376,41 +266,95 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA MATMUL(0, 0); } STORE_C(0, 0); - if (k_count > 1) { - /* still have more than 2*k */ - int remain_k2 = k_count & ~1; - k_count -= remain_k2; - lda = remain_k2; - TCONF(cfg, tail_m, tail_n, remain_k2); - /* reconfig will clear all tiles, - * need to store/load again - */ - LOAD_C(0, 0); - - LOAD_A(0, x); - ptr_a0 += tail_m * remain_k2; - LOAD_B(x, 0); - ptr_b0 += tail_n * remain_k2; - - MATMUL(0, 0); - - STORE_C(0, 0); - } - if (k_count > 0) { - /* still have odd tail k, need to transform into 2*k */ - TCONF(cfg, tail_m, tail_n, 2); - - LOAD_C(0, 0); - - MASK_LOAD_A_TAIL(0, x); - MASK_LOAD_B_TAIL(x, 0); - MATMUL(0, 0); - - STORE_C(0, 0); - } ptr_c00 += tail_n; } ptr_a += tail_m * k; } + +tail_k: + // process for k < 32 + BLASLONG k32 = k & ~31; + BLASLONG k2 = k & ~1; + int remain_k2 = k2 - k32; + if (remain_k2 > 0) { + m_count = m; + ptr_a = A; + ptr_c = C; + for (; m_count > 0; m_count -= 16) { + int tail_m = (m_count > 16) ? 16: m_count; + __mmask16 amask = (1UL << tail_m) - 1; + + ptr_a0 = ptr_a + tail_m * k32; + ptr_a += tail_m * k; + ptr_b = B; + ptr_c00 = ptr_c; + ptr_c += tail_m * ldc; + n_count = n; + lda = remain_k2; + ldb = 32; + TCONF(cfg, tail_m, 16, remain_k2); + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k32; + LOAD_C(0, 0); + LOAD_A(0, x); + LOAD_B(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } + if (n_count > 0) { + int tail_n = (n_count > 16) ? 16: n_count; + __mmask16 bmask = (1UL << tail_n) - 1; + ptr_b0 = ptr_b + tail_n * k32; + ldb = 2 * tail_n; + TCONF(cfg, tail_m, tail_n, remain_k2); + LOAD_C(0, 0); + LOAD_A(0, x); + LOAD_B(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + } + } + } + if (k2 != k) { + m_count = m; + ptr_a = A; + ptr_c = C; + for (; m_count > 0; m_count -= 16) { + int tail_m = (m_count > 16) ? 16: m_count; + __mmask16 amask = (1UL << tail_m) - 1; + + ptr_a0 = ptr_a + tail_m * k2; + ptr_a += tail_m * k; + ptr_b = B; + ptr_c00 = ptr_c; + ptr_c += tail_m * ldc; + n_count = n; + TCONF(cfg, tail_m, 16, 2); + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k2; + LOAD_C(0, 0); + MASK_LOAD_A_TAIL(0, x); + LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } + if (n_count > 0) { + int tail_n = (n_count > 16) ? 16: n_count; + __mmask16 bmask = (1UL << tail_n) - 1; + ptr_b0 = ptr_b + tail_n * k2; + TCONF(cfg, tail_m, tail_n, 2); + LOAD_C(0, 0); + MASK_LOAD_A_TAIL(0, x); + MASK_LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + } + } + + } return 0; } From 10d52646e2f04e741ccdcf15f4e68f9501ef6c40 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 15 Sep 2021 19:36:02 -0700 Subject: [PATCH 07/16] sbgemm: spr: oncopy: avoid handling too much pointer at a time --- kernel/x86_64/sbgemm_oncopy_16_spr.c | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/kernel/x86_64/sbgemm_oncopy_16_spr.c b/kernel/x86_64/sbgemm_oncopy_16_spr.c index f5668e26e..593f2433d 100644 --- a/kernel/x86_64/sbgemm_oncopy_16_spr.c +++ b/kernel/x86_64/sbgemm_oncopy_16_spr.c @@ -49,27 +49,39 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { BLASLONG m2 = m & ~1; for (j = 0; j < n16; j += 16) { + IFLOAT *boffset0 = boffset; aoffset0 = aoffset; aoffset1 = aoffset0 + lda; aoffset2 = aoffset1 + lda; aoffset3 = aoffset2 + lda; + for (i = 0; i < m32; i += 32) { + COPY_32(0); COPY_32(1); COPY_32(2); COPY_32(3); + boffset += 32 * 16; + } aoffset4 = aoffset3 + lda; aoffset5 = aoffset4 + lda; aoffset6 = aoffset5 + lda; aoffset7 = aoffset6 + lda; + boffset = boffset0; + for (i = 0; i < m32; i += 32) { + COPY_32(4); COPY_32(5); COPY_32(6); COPY_32(7); + boffset += 32 * 16; + } aoffset8 = aoffset7 + lda; aoffset9 = aoffset8 + lda; aoffset10 = aoffset9 + lda; aoffset11 = aoffset10 + lda; + boffset = boffset0; + for (i = 0; i < m32; i += 32) { + COPY_32(8); COPY_32(9); COPY_32(10); COPY_32(11); + boffset += 32 * 16; + } aoffset12 = aoffset11 + lda; aoffset13 = aoffset12 + lda; aoffset14 = aoffset13 + lda; aoffset15 = aoffset14 + lda; - aoffset += 16 * lda; + boffset = boffset0; for (i = 0; i < m32; i += 32) { - COPY_32(0); COPY_32(1); COPY_32(2); COPY_32(3); - COPY_32(4); COPY_32(5); COPY_32(6); COPY_32(7); - COPY_32(8); COPY_32(9); COPY_32(10); COPY_32(11); COPY_32(12); COPY_32(13); COPY_32(14); COPY_32(15); boffset += 32 * 16; } @@ -91,6 +103,7 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { COPY_ODD_TAIL(12); COPY_ODD_TAIL(13); COPY_ODD_TAIL(14); COPY_ODD_TAIL(15); boffset += 16; } + aoffset += 16 * lda; } if (j < n) { int remain_n = n - j; From 7b2f5cb3b7378b3111010678bd1433ebdb13d9a6 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 15 Sep 2021 20:29:49 -0700 Subject: [PATCH 08/16] sbgemm: spr: enlarge P to 256 for performance --- param.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/param.h b/param.h index c2c6916bc..23f406d74 100644 --- a/param.h +++ b/param.h @@ -1781,7 +1781,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // But for AMX, they are not the same, set UNROLL_M = 32 to workaround #define SBGEMM_DEFAULT_UNROLL_N 16 #define SBGEMM_DEFAULT_UNROLL_M 32 -#define SBGEMM_DEFAULT_P 192 +#define SBGEMM_DEFAULT_P 256 #define SBGEMM_DEFAULT_Q 1024 #define SBGEMM_DEFAULT_R sbgemm_r From 9ab33228bbf9301c34b3dedfa4047cbfee9bb847 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Wed, 15 Sep 2021 23:59:38 -0700 Subject: [PATCH 09/16] sbgemm: spr: process k2 and odd k at the same time --- kernel/x86_64/sbgemm_kernel_16x16_spr.c | 206 +++++++++++++++--------- 1 file changed, 133 insertions(+), 73 deletions(-) diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr.c b/kernel/x86_64/sbgemm_kernel_16x16_spr.c index 112adbe1c..dbfacd6ab 100644 --- a/kernel/x86_64/sbgemm_kernel_16x16_spr.c +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr.c @@ -64,6 +64,28 @@ typedef struct { cfg.tile_colsb[7] = n * 4; \ _tile_loadconfig(&cfg); +/* CONFIG for handling k2 and odd tail at the same time + * tile0 -- A (m x 2k) + * tile1 -- A (m x 1) + * tile2 -- B (2k x n) + * tile3 -- B (1 x n) + * tile4 -- C (m x n) + */ +#define TCONF_TAIL(cfg, m, n, k2) \ + memset(&cfg, 0, sizeof(tilecfg)); \ + cfg.palette_id = 1; \ + cfg.tile_rows[0] = m; \ + cfg.tile_rows[1] = m; \ + cfg.tile_rows[2] = k2>>1; \ + cfg.tile_rows[3] = 1; \ + cfg.tile_rows[4] = m; \ + cfg.tile_colsb[0] = k2<<1; \ + cfg.tile_colsb[1] = 4; \ + cfg.tile_colsb[2] = n * 4; \ + cfg.tile_colsb[3] = n * 4; \ + cfg.tile_colsb[4] = n * 4; \ + _tile_loadconfig(&cfg); + #define T_A0 0 #define T_A1 1 #define T_B0 2 @@ -104,6 +126,7 @@ typedef struct { #define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4) #define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N) +#define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N) #define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4) @@ -275,86 +298,123 @@ tail_k: // process for k < 32 BLASLONG k32 = k & ~31; BLASLONG k2 = k & ~1; - int remain_k2 = k2 - k32; - if (remain_k2 > 0) { + if (k32 != k) { + int remain_k2 = k2 - k32; m_count = m; ptr_a = A; ptr_c = C; - for (; m_count > 0; m_count -= 16) { - int tail_m = (m_count > 16) ? 16: m_count; - __mmask16 amask = (1UL << tail_m) - 1; + if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0) + for (; m_count > 0; m_count -= 16) { + int tail_m = (m_count > 16) ? 16: m_count; + __mmask16 amask = (1UL << tail_m) - 1; - ptr_a0 = ptr_a + tail_m * k32; - ptr_a += tail_m * k; - ptr_b = B; - ptr_c00 = ptr_c; - ptr_c += tail_m * ldc; - n_count = n; - lda = remain_k2; - ldb = 32; - TCONF(cfg, tail_m, 16, remain_k2); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k32; - LOAD_C(0, 0); - LOAD_A(0, x); - LOAD_B(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; + ptr_a0 = ptr_a + tail_m * k32; + ptr_a1 = ptr_a + tail_m * k2; + ptr_a += tail_m * k; + ptr_b = B; + ptr_c00 = ptr_c; + ptr_c += tail_m * ldc; + n_count = n; + lda = remain_k2; + ldb = 32; + TCONF_TAIL(cfg, tail_m, 16, remain_k2); + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k32; + ptr_b1 = ptr_b + 16 * k2; + LOAD_C(0, 0); + LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); + LOAD_B(x, 0); LOAD_B_TAIL(x, 1); + MATMUL(0, 0); MATMUL_TAIL(1, 1); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } + if (n_count > 0) { + int tail_n = (n_count > 16) ? 16: n_count; + __mmask16 bmask = (1UL << tail_n) - 1; + ptr_b0 = ptr_b + tail_n * k32; + ptr_b1 = ptr_b + tail_n * k2; + ldb = 2 * tail_n; + TCONF_TAIL(cfg, tail_m, tail_n, remain_k2); + LOAD_C(0, 0); + LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); + LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1); + MATMUL(0, 0); MATMUL_TAIL(1, 1); + STORE_C(0, 0); + } } - if (n_count > 0) { - int tail_n = (n_count > 16) ? 16: n_count; - __mmask16 bmask = (1UL << tail_n) - 1; - ptr_b0 = ptr_b + tail_n * k32; - ldb = 2 * tail_n; - TCONF(cfg, tail_m, tail_n, remain_k2); - LOAD_C(0, 0); - LOAD_A(0, x); - LOAD_B(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); + + } else if (remain_k2 > 0) { // k%32 = 2x + for (; m_count > 0; m_count -= 16) { + int tail_m = (m_count > 16) ? 16: m_count; + + ptr_a0 = ptr_a + tail_m * k32; + ptr_a += tail_m * k; + ptr_b = B; + ptr_c00 = ptr_c; + ptr_c += tail_m * ldc; + n_count = n; + lda = remain_k2; + ldb = 32; + TCONF(cfg, tail_m, 16, remain_k2); + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k32; + LOAD_C(0, 0); + LOAD_A(0, x); + LOAD_B(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } + if (n_count > 0) { + int tail_n = (n_count > 16) ? 16: n_count; + ptr_b0 = ptr_b + tail_n * k32; + ldb = 2 * tail_n; + TCONF(cfg, tail_m, tail_n, remain_k2); + LOAD_C(0, 0); + LOAD_A(0, x); + LOAD_B(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + } } + } else { // k%32 = 1 + for (; m_count > 0; m_count -= 16) { + int tail_m = (m_count > 16) ? 16: m_count; + __mmask16 amask = (1UL << tail_m) - 1; + + ptr_a0 = ptr_a + tail_m * k2; + ptr_a += tail_m * k; + ptr_b = B; + ptr_c00 = ptr_c; + ptr_c += tail_m * ldc; + n_count = n; + TCONF(cfg, tail_m, 16, 2); + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k2; + LOAD_C(0, 0); + MASK_LOAD_A_TAIL(0, x); + LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } + if (n_count > 0) { + int tail_n = (n_count > 16) ? 16: n_count; + __mmask16 bmask = (1UL << tail_n) - 1; + ptr_b0 = ptr_b + tail_n * k2; + TCONF(cfg, tail_m, tail_n, 2); + LOAD_C(0, 0); + MASK_LOAD_A_TAIL(0, x); + MASK_LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + } + } + } } - if (k2 != k) { - m_count = m; - ptr_a = A; - ptr_c = C; - for (; m_count > 0; m_count -= 16) { - int tail_m = (m_count > 16) ? 16: m_count; - __mmask16 amask = (1UL << tail_m) - 1; - - ptr_a0 = ptr_a + tail_m * k2; - ptr_a += tail_m * k; - ptr_b = B; - ptr_c00 = ptr_c; - ptr_c += tail_m * ldc; - n_count = n; - TCONF(cfg, tail_m, 16, 2); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k2; - LOAD_C(0, 0); - MASK_LOAD_A_TAIL(0, x); - LOAD_B_TAIL(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; - } - if (n_count > 0) { - int tail_n = (n_count > 16) ? 16: n_count; - __mmask16 bmask = (1UL << tail_n) - 1; - ptr_b0 = ptr_b + tail_n * k2; - TCONF(cfg, tail_m, tail_n, 2); - LOAD_C(0, 0); - MASK_LOAD_A_TAIL(0, x); - MASK_LOAD_B_TAIL(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - } - } - - } return 0; } From f2485352a603f11bd3c7b45788c1c80449ba76eb Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 16 Sep 2021 01:04:01 -0700 Subject: [PATCH 10/16] sbgemm: spr: only load A once in tail_k handling --- kernel/x86_64/sbgemm_kernel_16x16_spr.c | 62 ++++++++++++++----------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr.c b/kernel/x86_64/sbgemm_kernel_16x16_spr.c index dbfacd6ab..b34035896 100644 --- a/kernel/x86_64/sbgemm_kernel_16x16_spr.c +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr.c @@ -317,17 +317,19 @@ tail_k: n_count = n; lda = remain_k2; ldb = 32; - TCONF_TAIL(cfg, tail_m, 16, remain_k2); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k32; - ptr_b1 = ptr_b + 16 * k2; - LOAD_C(0, 0); + if (n_count > 15) { + TCONF_TAIL(cfg, tail_m, 16, remain_k2); LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); - LOAD_B(x, 0); LOAD_B_TAIL(x, 1); - MATMUL(0, 0); MATMUL_TAIL(1, 1); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k32; + ptr_b1 = ptr_b + 16 * k2; + LOAD_C(0, 0); + LOAD_B(x, 0); LOAD_B_TAIL(x, 1); + MATMUL(0, 0); MATMUL_TAIL(1, 1); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } } if (n_count > 0) { int tail_n = (n_count > 16) ? 16: n_count; @@ -356,16 +358,18 @@ tail_k: n_count = n; lda = remain_k2; ldb = 32; - TCONF(cfg, tail_m, 16, remain_k2); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k32; - LOAD_C(0, 0); + if (n_count > 15) { + TCONF(cfg, tail_m, 16, remain_k2); LOAD_A(0, x); - LOAD_B(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k32; + LOAD_C(0, 0); + LOAD_B(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } } if (n_count > 0) { int tail_n = (n_count > 16) ? 16: n_count; @@ -390,16 +394,18 @@ tail_k: ptr_c00 = ptr_c; ptr_c += tail_m * ldc; n_count = n; - TCONF(cfg, tail_m, 16, 2); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k2; - LOAD_C(0, 0); + if (n_count > 15) { + TCONF(cfg, tail_m, 16, 2); MASK_LOAD_A_TAIL(0, x); - LOAD_B_TAIL(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k2; + LOAD_C(0, 0); + LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } } if (n_count > 0) { int tail_n = (n_count > 16) ? 16: n_count; From a52456b168897ab252d1be968ff2040d8d909296 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 16 Sep 2021 20:08:42 -0700 Subject: [PATCH 11/16] sbgemm: spr: oncopy: use tile load/store instead --- kernel/x86_64/sbgemm_oncopy_16_spr.c | 180 +++++++++------------------ 1 file changed, 59 insertions(+), 121 deletions(-) diff --git a/kernel/x86_64/sbgemm_oncopy_16_spr.c b/kernel/x86_64/sbgemm_oncopy_16_spr.c index 593f2433d..ccb00ada1 100644 --- a/kernel/x86_64/sbgemm_oncopy_16_spr.c +++ b/kernel/x86_64/sbgemm_oncopy_16_spr.c @@ -28,18 +28,45 @@ #include #include "common.h" -#define COPY_32(N) _mm512_storeu_si512(boffset + 32 * N, _mm512_loadu_si512(aoffset##N + i)) -#define MASK_COPY_32(N) _mm512_mask_storeu_epi16(boffset + tail_m * N, mmask, _mm512_maskz_loadu_epi16(mmask, aoffset##N + i)) -#define COPY_ODD_TAIL(N) *(boffset + N) = *(aoffset##N + i); +typedef struct { + char palette_id; + char start_row; + char dummy0[14]; // bytes 2-15 reserved, must be zero + short tile_colsb[8]; + char dummy1[16]; // bytes 32-47 reserved, must be zero + char tile_rows[8]; + char dummy2[16]; // bytes 56-63 reserved, must be zero +} tilecfg; + +#define T_16x32 0 +#define T_16xm 1 +#define T_nx32 2 +#define T_nxm 3 + +#define TCONF(cfg, m, n) \ + memset(&cfg, 0, sizeof(tilecfg)); \ + cfg.palette_id = 1; \ + cfg.tile_rows[T_16x32] = 16; \ + cfg.tile_colsb[T_16x32] = 64; \ + if (m) { \ + cfg.tile_rows[T_16xm] = 16; \ + cfg.tile_colsb[T_16xm] = m * 2; \ + } \ + if (n) { \ + cfg.tile_rows[T_nx32] = n; \ + cfg.tile_colsb[T_nx32] = 64; \ + } \ + if (m && n) { \ + cfg.tile_rows[T_nxm] = n; \ + cfg.tile_colsb[T_nxm] = m * 2; \ + } \ + _tile_loadconfig(&cfg); int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { BLASLONG i, j; IFLOAT *aoffset, *boffset; - IFLOAT *aoffset0, *aoffset1, *aoffset2, *aoffset3; - IFLOAT *aoffset4, *aoffset5, *aoffset6, *aoffset7; - IFLOAT *aoffset8, *aoffset9, *aoffset10, *aoffset11; - IFLOAT *aoffset12, *aoffset13, *aoffset14, *aoffset15; + IFLOAT *aoffset0; aoffset = a; boffset = b; @@ -48,141 +75,52 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { BLASLONG m32 = m & ~31; BLASLONG m2 = m & ~1; + BLASLONG tail_m = m2 - m32; + BLASLONG tail_n = n - n16; + tilecfg cfg; + TCONF(cfg, tail_m, tail_n); + for (j = 0; j < n16; j += 16) { - IFLOAT *boffset0 = boffset; aoffset0 = aoffset; - aoffset1 = aoffset0 + lda; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; for (i = 0; i < m32; i += 32) { - COPY_32(0); COPY_32(1); COPY_32(2); COPY_32(3); - boffset += 32 * 16; - } - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - boffset = boffset0; - for (i = 0; i < m32; i += 32) { - COPY_32(4); COPY_32(5); COPY_32(6); COPY_32(7); - boffset += 32 * 16; - } - aoffset8 = aoffset7 + lda; - aoffset9 = aoffset8 + lda; - aoffset10 = aoffset9 + lda; - aoffset11 = aoffset10 + lda; - boffset = boffset0; - for (i = 0; i < m32; i += 32) { - COPY_32(8); COPY_32(9); COPY_32(10); COPY_32(11); - boffset += 32 * 16; - } - aoffset12 = aoffset11 + lda; - aoffset13 = aoffset12 + lda; - aoffset14 = aoffset13 + lda; - aoffset15 = aoffset14 + lda; - boffset = boffset0; - for (i = 0; i < m32; i += 32) { - COPY_32(12); COPY_32(13); COPY_32(14); COPY_32(15); + _tile_loadd(T_16x32, aoffset0, lda * 2); + _tile_stored(T_16x32, boffset, 32 * 2); + aoffset0 += 32; boffset += 32 * 16; } if (i < m2) { - int tail_m = m2 - i; - __mmask32 mmask = (1UL << tail_m) - 1; - MASK_COPY_32(0); MASK_COPY_32(1); MASK_COPY_32(2); MASK_COPY_32(3); - MASK_COPY_32(4); MASK_COPY_32(5); MASK_COPY_32(6); MASK_COPY_32(7); - MASK_COPY_32(8); MASK_COPY_32(9); MASK_COPY_32(10); MASK_COPY_32(11); - MASK_COPY_32(12); MASK_COPY_32(13); MASK_COPY_32(14); MASK_COPY_32(15); - i = m2; + _tile_loadd(T_16xm, aoffset0, lda * 2); + _tile_stored(T_16xm, boffset, tail_m * 2); + aoffset0 += tail_m; boffset += tail_m * 16; + i = m2; } if (i < m) { /* the tail odd k should put alone */ - COPY_ODD_TAIL(0); COPY_ODD_TAIL(1); COPY_ODD_TAIL(2); COPY_ODD_TAIL(3); - COPY_ODD_TAIL(4); COPY_ODD_TAIL(5); COPY_ODD_TAIL(6); COPY_ODD_TAIL(7); - COPY_ODD_TAIL(8); COPY_ODD_TAIL(9); COPY_ODD_TAIL(10); COPY_ODD_TAIL(11); - COPY_ODD_TAIL(12); COPY_ODD_TAIL(13); COPY_ODD_TAIL(14); COPY_ODD_TAIL(15); + for (int ii = 0; ii < 16; ii++) { + *(boffset + ii) = *(aoffset0 + lda * ii); + } boffset += 16; } aoffset += 16 * lda; } if (j < n) { - int remain_n = n - j; aoffset0 = aoffset; - aoffset1 = aoffset0 + lda; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; - aoffset9 = aoffset8 + lda; - aoffset10 = aoffset9 + lda; - aoffset11 = aoffset10 + lda; - aoffset12 = aoffset11 + lda; - aoffset13 = aoffset12 + lda; - aoffset14 = aoffset13 + lda; - aoffset15 = aoffset14 + lda; for (i = 0; i < m32; i += 32) { - switch(remain_n) { - case 15: COPY_32(14); - case 14: COPY_32(13); - case 13: COPY_32(12); - case 12: COPY_32(11); - case 11: COPY_32(10); - case 10: COPY_32(9); - case 9: COPY_32(8); - case 8: COPY_32(7); - case 7: COPY_32(6); - case 6: COPY_32(5); - case 5: COPY_32(4); - case 4: COPY_32(3); - case 3: COPY_32(2); - case 2: COPY_32(1); - case 1: COPY_32(0); - } - boffset += 32 * remain_n; + _tile_loadd(T_nx32, aoffset0, lda * 2); + _tile_stored(T_nx32, boffset, 32 * 2); + aoffset0 += 32; + boffset += 32 * tail_n; } if (i < m2) { - int tail_m = m2 - i; - __mmask32 mmask = (1UL << tail_m) - 1; - switch(remain_n) { - case 15: MASK_COPY_32(14); - case 14: MASK_COPY_32(13); - case 13: MASK_COPY_32(12); - case 12: MASK_COPY_32(11); - case 11: MASK_COPY_32(10); - case 10: MASK_COPY_32(9); - case 9: MASK_COPY_32(8); - case 8: MASK_COPY_32(7); - case 7: MASK_COPY_32(6); - case 6: MASK_COPY_32(5); - case 5: MASK_COPY_32(4); - case 4: MASK_COPY_32(3); - case 3: MASK_COPY_32(2); - case 2: MASK_COPY_32(1); - case 1: MASK_COPY_32(0); - } - i = m2; - boffset += tail_m * remain_n; + _tile_loadd(T_nxm, aoffset0, lda * 2); + _tile_stored(T_nxm, boffset, tail_m * 2); + aoffset0 += tail_m; + boffset += tail_m * tail_n; } if (i < m) { - switch(remain_n) { - case 15: COPY_ODD_TAIL(14); - case 14: COPY_ODD_TAIL(13); - case 13: COPY_ODD_TAIL(12); - case 12: COPY_ODD_TAIL(11); - case 11: COPY_ODD_TAIL(10); - case 10: COPY_ODD_TAIL(9); - case 9: COPY_ODD_TAIL(8); - case 8: COPY_ODD_TAIL(7); - case 7: COPY_ODD_TAIL(6); - case 6: COPY_ODD_TAIL(5); - case 5: COPY_ODD_TAIL(4); - case 4: COPY_ODD_TAIL(3); - case 3: COPY_ODD_TAIL(2); - case 2: COPY_ODD_TAIL(1); - case 1: COPY_ODD_TAIL(0); + for (int ii = 0; ii < tail_n; ii++) { + *(boffset + ii) = *(aoffset0 + lda * ii); } } } From f018aa342a5d97603a5f59fa1feb36e1c77e0571 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Fri, 17 Sep 2021 00:48:52 -0700 Subject: [PATCH 12/16] sbgemm: spr: kernel handle alpha != 1.0 --- kernel/x86_64/sbgemm_kernel_16x16_spr.c | 392 +------------- kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c | 521 +++++++++++++++++++ 2 files changed, 529 insertions(+), 384 deletions(-) create mode 100644 kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr.c b/kernel/x86_64/sbgemm_kernel_16x16_spr.c index b34035896..955db3163 100644 --- a/kernel/x86_64/sbgemm_kernel_16x16_spr.c +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr.c @@ -25,109 +25,12 @@ * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * *****************************************************************************/ -#include -#include #include "common.h" -typedef struct { - char palette_id; - char start_row; - char dummy0[14]; // bytes 2-15 reserved, must be zero - short tile_colsb[8]; - char dummy1[16]; // bytes 32-47 reserved, must be zero - char tile_rows[8]; - char dummy2[16]; // bytes 56-63 reserved, must be zero -} tilecfg; - -/* tile0/tile1 -- A (m x 2k) - * tile2/tile3 -- B (2k x n) - * tile4-7 -- C (m x n) - */ -#define TCONF(cfg, m, n, k2) \ - memset(&cfg, 0, sizeof(tilecfg)); \ - cfg.palette_id = 1; \ - cfg.tile_rows[0] = m; \ - cfg.tile_rows[1] = m; \ - cfg.tile_rows[2] = k2>>1; \ - cfg.tile_rows[3] = k2>>1; \ - cfg.tile_rows[4] = m; \ - cfg.tile_rows[5] = m; \ - cfg.tile_rows[6] = m; \ - cfg.tile_rows[7] = m; \ - cfg.tile_colsb[0] = k2<<1; \ - cfg.tile_colsb[1] = k2<<1; \ - cfg.tile_colsb[2] = n * 4; \ - cfg.tile_colsb[3] = n * 4; \ - cfg.tile_colsb[4] = n * 4; \ - cfg.tile_colsb[5] = n * 4; \ - cfg.tile_colsb[6] = n * 4; \ - cfg.tile_colsb[7] = n * 4; \ - _tile_loadconfig(&cfg); - -/* CONFIG for handling k2 and odd tail at the same time - * tile0 -- A (m x 2k) - * tile1 -- A (m x 1) - * tile2 -- B (2k x n) - * tile3 -- B (1 x n) - * tile4 -- C (m x n) - */ -#define TCONF_TAIL(cfg, m, n, k2) \ - memset(&cfg, 0, sizeof(tilecfg)); \ - cfg.palette_id = 1; \ - cfg.tile_rows[0] = m; \ - cfg.tile_rows[1] = m; \ - cfg.tile_rows[2] = k2>>1; \ - cfg.tile_rows[3] = 1; \ - cfg.tile_rows[4] = m; \ - cfg.tile_colsb[0] = k2<<1; \ - cfg.tile_colsb[1] = 4; \ - cfg.tile_colsb[2] = n * 4; \ - cfg.tile_colsb[3] = n * 4; \ - cfg.tile_colsb[4] = n * 4; \ - _tile_loadconfig(&cfg); - -#define T_A0 0 -#define T_A1 1 -#define T_B0 2 -#define T_B1 3 -#define T_C00 4 -#define T_C01 5 -#define T_C10 6 -#define T_C11 7 - -// FIXME: gcc11 seem have problem in tile load/store address calc, -// need to multiply with element size (2 or 4) here. -#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2) -#define LOAD_A_TAIL(M, N) {\ - __m256i ymm = _mm256_loadu_epi16(ptr_a##M); \ - __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ - _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ - _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ -} -#define MASK_LOAD_A_TAIL(M, N) {\ - __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \ - __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ - _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ - _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ -} -#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2) -#define LOAD_B_TAIL(M, N) {\ - __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \ - __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ - _mm512_storeu_epi16(tail_b + 16 * N, zmm); \ - _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ -} -#define MASK_LOAD_B_TAIL(M, N) {\ - __m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \ - __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ - _mm512_storeu_epi16(tail_b + 16 * N, zmm); \ - _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ -} - -#define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4) -#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N) -#define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N) -#define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4) +#define ALPHA_ONE +#include "sbgemm_kernel_16x16_spr_tmpl.c" +#undef ALPHA_ONE +#include "sbgemm_kernel_16x16_spr_tmpl.c" int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOAT * iB, FLOAT * C, BLASLONG ldc) @@ -140,287 +43,8 @@ int CNAME (BLASLONG im, BLASLONG in, BLASLONG k, FLOAT alpha, IFLOAT * iA, IFLOA A = iB; B = iA; - IFLOAT *ptr_a = A, *ptr_b = B; - IFLOAT *ptr_b0, *ptr_b1; - IFLOAT *ptr_a0, *ptr_a1; - FLOAT *ptr_c = C; - FLOAT *ptr_c00, *ptr_c01, *ptr_c10, *ptr_c11; - - BLASLONG lda, ldb; - BLASLONG m_count = m; - BLASLONG n_count, k_count; - - - IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64))); - IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64))); - tilecfg cfg; - - if (k < 32) - goto tail_k; - - for (; m_count > 31; m_count -= 32) { - ptr_b = B; - - ptr_c00 = ptr_c; - ptr_c01 = ptr_c00 + 16; - ptr_c10 = ptr_c + 16 * ldc; - ptr_c11 = ptr_c10 + 16; - ptr_c += 32 * ldc; - n_count = n; - TCONF(cfg, 16, 16, 32); - for (; n_count > 31; n_count -= 32) { - ptr_a0 = ptr_a; - ptr_a1 = ptr_a + 16 * k; - - ptr_b0 = ptr_b; - ptr_b1 = ptr_b + 16 * k; - ptr_b += 32 * k; - - lda = 32; - ldb = 32; - LOAD_C(0, 0); LOAD_C(0, 1); - LOAD_C(1, 0); LOAD_C(1, 1); - k_count = k; - for (; k_count > 31; k_count -= 32) { - LOAD_A(0, x); LOAD_A(1, x); - ptr_a0 += 16 * 32; - ptr_a1 += 16 * 32; - LOAD_B(x, 0); LOAD_B(x, 1); - ptr_b0 += 16 * 32; - ptr_b1 += 16 * 32; - - MATMUL(0, 0); MATMUL(0, 1); - MATMUL(1, 0); MATMUL(1, 1); - } - STORE_C(0, 0); STORE_C(0, 1); - STORE_C(1, 0); STORE_C(1, 1); - ptr_c00 += 32; - ptr_c01 += 32; - ptr_c10 += 32; - ptr_c11 += 32; - } - for (; n_count > 0; n_count -= 16) { - int tail_n = (n_count > 16) ? 16: n_count; - __mmask16 bmask = (1UL << tail_n) - 1; - ptr_a0 = ptr_a; - ptr_a1 = ptr_a + 16 * k; - - ptr_b0 = ptr_b; - ptr_b += tail_n * k; - - lda = 32; - ldb = 2 * tail_n; - TCONF(cfg, 16, tail_n, 32); - LOAD_C(0, 0); - LOAD_C(1, 0); - k_count = k; - for (; k_count > 31; k_count -= 32) { - LOAD_A(0, x); LOAD_A(1, x); - ptr_a0 += 16 * 32; - ptr_a1 += 16 * 32; - LOAD_B(x, 0); - ptr_b0 += tail_n * 32; - - MATMUL(0, 0); - MATMUL(1, 0); - } - STORE_C(0, 0); - STORE_C(1, 0); - ptr_c00 += tail_n; - ptr_c10 += tail_n; - } - ptr_a += 32 * k; - } - for (; m_count > 0; m_count -= 16) { - // process at most 16 m at a time - int tail_m = (m_count > 16) ? 16: m_count; - __mmask16 amask = (1UL << tail_m) - 1; - - ptr_b = B; - - ptr_c00 = ptr_c; - ptr_c01 = ptr_c00 + 16; - ptr_c += tail_m * ldc; - n_count = n; - TCONF(cfg, tail_m, 16, 32); - for (; n_count > 31; n_count -= 32) { - ptr_a0 = ptr_a; - - ptr_b0 = ptr_b; - ptr_b1 = ptr_b + 16 * k; - ptr_b += 32 * k; - - lda = 32; - ldb = 32; - LOAD_C(0, 0); LOAD_C(0, 1); - k_count = k; - for (; k_count > 31; k_count -= 32) { - LOAD_A(0, x); - ptr_a0 += tail_m * 32; - LOAD_B(x, 0); LOAD_B(x, 1); - ptr_b0 += 16 * 32; - ptr_b1 += 16 * 32; - - MATMUL(0, 0); MATMUL(0, 1); - } - STORE_C(0, 0); STORE_C(0, 1); - ptr_c00 += 32; - ptr_c01 += 32; - } - for (; n_count > 0; n_count -= 16) { - int tail_n = (n_count > 16) ? 16: n_count; - __mmask16 bmask = (1UL << tail_n) - 1; - ptr_a0 = ptr_a; - - ptr_b0 = ptr_b; - ptr_b += tail_n * k; - - lda = 32; - ldb = 2 * tail_n; - TCONF(cfg, tail_m, tail_n, 32); - LOAD_C(0, 0); - k_count = k; - for (; k_count > 31; k_count -= 32) { - LOAD_A(0, x); - ptr_a0 += tail_m * 32; - LOAD_B(x, 0); - ptr_b0 += tail_n * 32; - - MATMUL(0, 0); - } - STORE_C(0, 0); - ptr_c00 += tail_n; - } - ptr_a += tail_m * k; - } - -tail_k: - // process for k < 32 - BLASLONG k32 = k & ~31; - BLASLONG k2 = k & ~1; - if (k32 != k) { - int remain_k2 = k2 - k32; - m_count = m; - ptr_a = A; - ptr_c = C; - if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0) - for (; m_count > 0; m_count -= 16) { - int tail_m = (m_count > 16) ? 16: m_count; - __mmask16 amask = (1UL << tail_m) - 1; - - ptr_a0 = ptr_a + tail_m * k32; - ptr_a1 = ptr_a + tail_m * k2; - ptr_a += tail_m * k; - ptr_b = B; - ptr_c00 = ptr_c; - ptr_c += tail_m * ldc; - n_count = n; - lda = remain_k2; - ldb = 32; - if (n_count > 15) { - TCONF_TAIL(cfg, tail_m, 16, remain_k2); - LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k32; - ptr_b1 = ptr_b + 16 * k2; - LOAD_C(0, 0); - LOAD_B(x, 0); LOAD_B_TAIL(x, 1); - MATMUL(0, 0); MATMUL_TAIL(1, 1); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; - } - } - if (n_count > 0) { - int tail_n = (n_count > 16) ? 16: n_count; - __mmask16 bmask = (1UL << tail_n) - 1; - ptr_b0 = ptr_b + tail_n * k32; - ptr_b1 = ptr_b + tail_n * k2; - ldb = 2 * tail_n; - TCONF_TAIL(cfg, tail_m, tail_n, remain_k2); - LOAD_C(0, 0); - LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); - LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1); - MATMUL(0, 0); MATMUL_TAIL(1, 1); - STORE_C(0, 0); - } - } - - } else if (remain_k2 > 0) { // k%32 = 2x - for (; m_count > 0; m_count -= 16) { - int tail_m = (m_count > 16) ? 16: m_count; - - ptr_a0 = ptr_a + tail_m * k32; - ptr_a += tail_m * k; - ptr_b = B; - ptr_c00 = ptr_c; - ptr_c += tail_m * ldc; - n_count = n; - lda = remain_k2; - ldb = 32; - if (n_count > 15) { - TCONF(cfg, tail_m, 16, remain_k2); - LOAD_A(0, x); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k32; - LOAD_C(0, 0); - LOAD_B(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; - } - } - if (n_count > 0) { - int tail_n = (n_count > 16) ? 16: n_count; - ptr_b0 = ptr_b + tail_n * k32; - ldb = 2 * tail_n; - TCONF(cfg, tail_m, tail_n, remain_k2); - LOAD_C(0, 0); - LOAD_A(0, x); - LOAD_B(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - } - } - } else { // k%32 = 1 - for (; m_count > 0; m_count -= 16) { - int tail_m = (m_count > 16) ? 16: m_count; - __mmask16 amask = (1UL << tail_m) - 1; - - ptr_a0 = ptr_a + tail_m * k2; - ptr_a += tail_m * k; - ptr_b = B; - ptr_c00 = ptr_c; - ptr_c += tail_m * ldc; - n_count = n; - if (n_count > 15) { - TCONF(cfg, tail_m, 16, 2); - MASK_LOAD_A_TAIL(0, x); - for (; n_count > 15; n_count -= 16) { - ptr_b0 = ptr_b + 16 * k2; - LOAD_C(0, 0); - LOAD_B_TAIL(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - ptr_b += 16 * k; - ptr_c00 += 16; - } - } - if (n_count > 0) { - int tail_n = (n_count > 16) ? 16: n_count; - __mmask16 bmask = (1UL << tail_n) - 1; - ptr_b0 = ptr_b + tail_n * k2; - TCONF(cfg, tail_m, tail_n, 2); - LOAD_C(0, 0); - MASK_LOAD_A_TAIL(0, x); - MASK_LOAD_B_TAIL(x, 0); - MATMUL(0, 0); - STORE_C(0, 0); - } - } - - } - } - return 0; + if (alpha == 1.0f) + return sbgemm_kernel_spr_alpha_one(m, n, k, alpha, A, B, C, ldc); + else + return sbgemm_kernel_spr_alpha(m, n, k, alpha, A, B, C, ldc); } diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c b/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c new file mode 100644 index 000000000..465b9eb75 --- /dev/null +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c @@ -0,0 +1,521 @@ +/*************************************************************************** + * Copyright (c) 2021, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include +#include +#include "common.h" + +#ifndef SBGEMM_KERNEL_SPR +#define SBGEMM_KERNEL_SPR +typedef struct { + char palette_id; + char start_row; + char dummy0[14]; // bytes 2-15 reserved, must be zero + short tile_colsb[8]; + char dummy1[16]; // bytes 32-47 reserved, must be zero + char tile_rows[8]; + char dummy2[16]; // bytes 56-63 reserved, must be zero +} tilecfg; + +/* tile0/tile1 -- A (m x 2k) + * tile2/tile3 -- B (2k x n) + * tile4-7 -- C (m x n) + */ +#define TCONF(cfg, m, n, k2) \ + memset(&cfg, 0, sizeof(tilecfg)); \ + cfg.palette_id = 1; \ + cfg.tile_rows[0] = m; \ + cfg.tile_rows[1] = m; \ + cfg.tile_rows[2] = k2>>1; \ + cfg.tile_rows[3] = k2>>1; \ + cfg.tile_rows[4] = m; \ + cfg.tile_rows[5] = m; \ + cfg.tile_rows[6] = m; \ + cfg.tile_rows[7] = m; \ + cfg.tile_colsb[0] = k2<<1; \ + cfg.tile_colsb[1] = k2<<1; \ + cfg.tile_colsb[2] = n * 4; \ + cfg.tile_colsb[3] = n * 4; \ + cfg.tile_colsb[4] = n * 4; \ + cfg.tile_colsb[5] = n * 4; \ + cfg.tile_colsb[6] = n * 4; \ + cfg.tile_colsb[7] = n * 4; \ + _tile_loadconfig(&cfg); + +/* CONFIG for handling k2 and odd tail at the same time + * tile0 -- A (m x 2k) + * tile1 -- A (m x 1) + * tile2 -- B (2k x n) + * tile3 -- B (1 x n) + * tile4 -- C (m x n) + */ +#define TCONF_TAIL(cfg, m, n, k2) \ + memset(&cfg, 0, sizeof(tilecfg)); \ + cfg.palette_id = 1; \ + cfg.tile_rows[0] = m; \ + cfg.tile_rows[1] = m; \ + cfg.tile_rows[2] = k2>>1; \ + cfg.tile_rows[3] = 1; \ + cfg.tile_rows[4] = m; \ + cfg.tile_colsb[0] = k2<<1; \ + cfg.tile_colsb[1] = 4; \ + cfg.tile_colsb[2] = n * 4; \ + cfg.tile_colsb[3] = n * 4; \ + cfg.tile_colsb[4] = n * 4; \ + _tile_loadconfig(&cfg); + +#define T_A0 0 +#define T_A1 1 +#define T_B0 2 +#define T_B1 3 +#define T_C00 4 +#define T_C01 5 +#define T_C10 6 +#define T_C11 7 + +// FIXME: gcc11 seem have problem in tile load/store address calc, +// need to multiply with element size (2 or 4) here. +#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2) +#define LOAD_A_TAIL(M, N) {\ + __m256i ymm = _mm256_loadu_epi16(ptr_a##M); \ + __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ + _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ + _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ +} +#define MASK_LOAD_A_TAIL(M, N) {\ + __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \ + __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ + _mm512_storeu_epi16(tail_a + 16 * M, zmm); \ + _tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ +} +#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2) +#define LOAD_B_TAIL(M, N) {\ + __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \ + __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ + _mm512_storeu_epi16(tail_b + 16 * N, zmm); \ + _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ +} +#define MASK_LOAD_B_TAIL(M, N) {\ + __m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \ + __m512i zmm = _mm512_cvtepu16_epi32(ymm); \ + _mm512_storeu_epi16(tail_b + 16 * N, zmm); \ + _tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ +} + +#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N) +#define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N) +#define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4) +#define LOAD_C_F(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4) + +#endif // end of SBGEMM_KERNEL_SPR + +#ifdef ALPHA_ONE +#undef LOAD_C +#define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4) +#else +#undef LOAD_C +#define LOAD_C(M, N) _tile_zero(T_C##M##N) +#define ALPHA_STORE(N) \ + __m512 zmm_d##N = _mm512_loadu_ps(dst##N + noffset); \ + __m512 zmm_s##N = _mm512_loadu_ps(src##N + noffset); \ + zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \ + _mm512_storeu_ps(dst##N + noffset, zmm_d##N); +#define MASK_APLPHA_STORE(N) \ + __m512 zmm_d##N = _mm512_maskz_loadu_ps(mask, dst##N + noffset); \ + __m512 zmm_s##N = _mm512_maskz_loadu_ps(mask, src##N + noffset); \ + zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \ + _mm512_mask_storeu_ps(dst##N + noffset, mask, zmm_d##N); +#endif // end of ALPHA_ONE + + +#ifdef ALPHA_ONE +int sbgemm_kernel_spr_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc) +#else +int sbgemm_kernel_spr_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc) +#endif +{ + /* Row Major matrix for AMX requirement */ + IFLOAT *ptr_a = A, *ptr_b = B; + IFLOAT *ptr_b0, *ptr_b1; + IFLOAT *ptr_a0, *ptr_a1; + FLOAT *ptr_c = C; + FLOAT *ptr_c00, *ptr_c01, *ptr_c10, *ptr_c11; + + BLASLONG lda, ldb; + BLASLONG m_count = m; + BLASLONG n_count, k_count; + +#ifndef ALPHA_ONE + FLOAT *tmp_c = malloc(sizeof(FLOAT) * m * n); + memset(tmp_c, 0, sizeof(FLOAT) * m * n); + ptr_c = tmp_c; + BLASLONG ldc_o = ldc; + ldc = n; +#endif + IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64))); + IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64))); + tilecfg cfg; + + if (k > 31) { + for (; m_count > 31; m_count -= 32) { + ptr_b = B; + + ptr_c00 = ptr_c; + ptr_c01 = ptr_c00 + 16; + ptr_c10 = ptr_c + 16 * ldc; + ptr_c11 = ptr_c10 + 16; + ptr_c += 32 * ldc; + n_count = n; + TCONF(cfg, 16, 16, 32); + for (; n_count > 31; n_count -= 32) { + ptr_a0 = ptr_a; + ptr_a1 = ptr_a + 16 * k; + + ptr_b0 = ptr_b; + ptr_b1 = ptr_b + 16 * k; + ptr_b += 32 * k; + + lda = 32; + ldb = 32; + LOAD_C(0, 0); LOAD_C(0, 1); + LOAD_C(1, 0); LOAD_C(1, 1); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); LOAD_A(1, x); + ptr_a0 += 16 * 32; + ptr_a1 += 16 * 32; + LOAD_B(x, 0); LOAD_B(x, 1); + ptr_b0 += 16 * 32; + ptr_b1 += 16 * 32; + + MATMUL(0, 0); MATMUL(0, 1); + MATMUL(1, 0); MATMUL(1, 1); + } + STORE_C(0, 0); STORE_C(0, 1); + STORE_C(1, 0); STORE_C(1, 1); + ptr_c00 += 32; + ptr_c01 += 32; + ptr_c10 += 32; + ptr_c11 += 32; + } + for (; n_count > 0; n_count -= 16) { + int tail_n = (n_count > 16) ? 16: n_count; + ptr_a0 = ptr_a; + ptr_a1 = ptr_a + 16 * k; + + ptr_b0 = ptr_b; + ptr_b += tail_n * k; + + lda = 32; + ldb = 2 * tail_n; + TCONF(cfg, 16, tail_n, 32); + LOAD_C(0, 0); + LOAD_C(1, 0); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); LOAD_A(1, x); + ptr_a0 += 16 * 32; + ptr_a1 += 16 * 32; + LOAD_B(x, 0); + ptr_b0 += tail_n * 32; + + MATMUL(0, 0); + MATMUL(1, 0); + } + STORE_C(0, 0); + STORE_C(1, 0); + ptr_c00 += tail_n; + ptr_c10 += tail_n; + } + ptr_a += 32 * k; + } + for (; m_count > 0; m_count -= 16) { + // process at most 16 m at a time + int tail_m = (m_count > 16) ? 16: m_count; + + ptr_b = B; + + ptr_c00 = ptr_c; + ptr_c01 = ptr_c00 + 16; + ptr_c += tail_m * ldc; + n_count = n; + TCONF(cfg, tail_m, 16, 32); + for (; n_count > 31; n_count -= 32) { + ptr_a0 = ptr_a; + + ptr_b0 = ptr_b; + ptr_b1 = ptr_b + 16 * k; + ptr_b += 32 * k; + + lda = 32; + ldb = 32; + LOAD_C(0, 0); LOAD_C(0, 1); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); + ptr_a0 += tail_m * 32; + LOAD_B(x, 0); LOAD_B(x, 1); + ptr_b0 += 16 * 32; + ptr_b1 += 16 * 32; + + MATMUL(0, 0); MATMUL(0, 1); + } + STORE_C(0, 0); STORE_C(0, 1); + ptr_c00 += 32; + ptr_c01 += 32; + } + for (; n_count > 0; n_count -= 16) { + int tail_n = (n_count > 16) ? 16: n_count; + ptr_a0 = ptr_a; + + ptr_b0 = ptr_b; + ptr_b += tail_n * k; + + lda = 32; + ldb = 2 * tail_n; + TCONF(cfg, tail_m, tail_n, 32); + LOAD_C(0, 0); + k_count = k; + for (; k_count > 31; k_count -= 32) { + LOAD_A(0, x); + ptr_a0 += tail_m * 32; + LOAD_B(x, 0); + ptr_b0 += tail_n * 32; + + MATMUL(0, 0); + } + STORE_C(0, 0); + ptr_c00 += tail_n; + } + ptr_a += tail_m * k; + } + } + + // process for k < 32 + BLASLONG k32 = k & ~31; + BLASLONG k2 = k & ~1; + if (k32 != k) { + int remain_k2 = k2 - k32; + m_count = m; + ptr_a = A; +#ifndef ALPHA_ONE + ptr_c = tmp_c; +#else + ptr_c = C; +#endif + if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0) + for (; m_count > 0; m_count -= 16) { + int tail_m = (m_count > 16) ? 16: m_count; + __mmask16 amask = (1UL << tail_m) - 1; + + ptr_a0 = ptr_a + tail_m * k32; + ptr_a1 = ptr_a + tail_m * k2; + ptr_a += tail_m * k; + ptr_b = B; + ptr_c00 = ptr_c; + ptr_c += tail_m * ldc; + n_count = n; + lda = remain_k2; + ldb = 32; + if (n_count > 15) { + TCONF_TAIL(cfg, tail_m, 16, remain_k2); + LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k32; + ptr_b1 = ptr_b + 16 * k2; + LOAD_C_F(0, 0); + LOAD_B(x, 0); LOAD_B_TAIL(x, 1); + MATMUL(0, 0); MATMUL_TAIL(1, 1); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } + } + if (n_count > 0) { + int tail_n = (n_count > 16) ? 16: n_count; + __mmask16 bmask = (1UL << tail_n) - 1; + ptr_b0 = ptr_b + tail_n * k32; + ptr_b1 = ptr_b + tail_n * k2; + ldb = 2 * tail_n; + TCONF_TAIL(cfg, tail_m, tail_n, remain_k2); + LOAD_C_F(0, 0); + LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x); + LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1); + MATMUL(0, 0); MATMUL_TAIL(1, 1); + STORE_C(0, 0); + } + } + + } else if (remain_k2 > 0) { // k%32 = 2x + for (; m_count > 0; m_count -= 16) { + int tail_m = (m_count > 16) ? 16: m_count; + + ptr_a0 = ptr_a + tail_m * k32; + ptr_a += tail_m * k; + ptr_b = B; + ptr_c00 = ptr_c; + ptr_c += tail_m * ldc; + n_count = n; + lda = remain_k2; + ldb = 32; + if (n_count > 15) { + TCONF(cfg, tail_m, 16, remain_k2); + LOAD_A(0, x); + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k32; + LOAD_C_F(0, 0); + LOAD_B(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } + } + if (n_count > 0) { + int tail_n = (n_count > 16) ? 16: n_count; + ptr_b0 = ptr_b + tail_n * k32; + ldb = 2 * tail_n; + TCONF(cfg, tail_m, tail_n, remain_k2); + LOAD_C_F(0, 0); + LOAD_A(0, x); + LOAD_B(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + } + } + } else { // k%32 = 1 + for (; m_count > 0; m_count -= 16) { + int tail_m = (m_count > 16) ? 16: m_count; + __mmask16 amask = (1UL << tail_m) - 1; + + ptr_a0 = ptr_a + tail_m * k2; + ptr_a += tail_m * k; + ptr_b = B; + ptr_c00 = ptr_c; + ptr_c += tail_m * ldc; + n_count = n; + if (n_count > 15) { + TCONF(cfg, tail_m, 16, 2); + MASK_LOAD_A_TAIL(0, x); + for (; n_count > 15; n_count -= 16) { + ptr_b0 = ptr_b + 16 * k2; + LOAD_C_F(0, 0); + LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + ptr_b += 16 * k; + ptr_c00 += 16; + } + } + if (n_count > 0) { + int tail_n = (n_count > 16) ? 16: n_count; + __mmask16 bmask = (1UL << tail_n) - 1; + ptr_b0 = ptr_b + tail_n * k2; + TCONF(cfg, tail_m, tail_n, 2); + LOAD_C_F(0, 0); + MASK_LOAD_A_TAIL(0, x); + MASK_LOAD_B_TAIL(x, 0); + MATMUL(0, 0); + STORE_C(0, 0); + } + } + + } + } +#ifndef ALPHA_ONE + __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); + BLASLONG n16 = n & ~15; + BLASLONG noffset; + FLOAT *src0, *src1, *src2, *src3; + FLOAT *dst0, *dst1, *dst2, *dst3; + FLOAT *src = tmp_c; + FLOAT *dst = C; + m_count = m; + for (; m_count > 3; m_count -= 4) { + src0 = src; + src1 = src0 + ldc; + src2 = src1 + ldc; + src3 = src2 + ldc; + src += 4 * ldc; + + dst0 = dst; + dst1 = dst0 + ldc_o; + dst2 = dst1 + ldc_o; + dst3 = dst2 + ldc_o; + dst += 4 * ldc_o; + + noffset = 0; + for (; noffset < n16; noffset += 16) { + ALPHA_STORE(0); + ALPHA_STORE(1); + ALPHA_STORE(2); + ALPHA_STORE(3); + } + if (noffset < n) { + __mmask16 mask = (1UL << (n - noffset)) - 1; + MASK_APLPHA_STORE(0); + MASK_APLPHA_STORE(1); + MASK_APLPHA_STORE(2); + MASK_APLPHA_STORE(3); + } + } + for (; m_count > 1; m_count -= 2) { + src0 = src; + src1 = src0 + ldc; + src += 2 * ldc; + + dst0 = dst; + dst1 = dst0 + ldc_o; + dst += 2 * ldc_o; + + noffset = 0; + for (; noffset < n16; noffset += 16) { + ALPHA_STORE(0); + ALPHA_STORE(1); + } + if (noffset < n) { + __mmask16 mask = (1UL << (n - noffset)) - 1; + MASK_APLPHA_STORE(0); + MASK_APLPHA_STORE(1); + } + } + for (; m_count > 0; m_count -= 1) { + src0 = src; + dst0 = dst; + noffset = 0; + for (; noffset < n16; noffset += 16) { + ALPHA_STORE(0); + } + if (noffset < n) { + __mmask16 mask = (1UL << (n - noffset)) - 1; + MASK_APLPHA_STORE(0); + } + } + free(tmp_c); +#endif + return 0; +} From 6bc8204ce5c137cc18c2580eabc30264d4b8b2fe Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Fri, 17 Sep 2021 23:59:32 -0700 Subject: [PATCH 13/16] sbgemm: spr: optimization for tmp_c buffer --- kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c b/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c index 465b9eb75..90e0a32c7 100644 --- a/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c +++ b/kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c @@ -170,11 +170,20 @@ int sbgemm_kernel_spr_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFL BLASLONG n_count, k_count; #ifndef ALPHA_ONE - FLOAT *tmp_c = malloc(sizeof(FLOAT) * m * n); - memset(tmp_c, 0, sizeof(FLOAT) * m * n); + // make sure each row is 64 bytes aligned + BLASLONG cn = (n & 31) ? (n & ~31) + 32 : n; + FLOAT *raw_tmp_c; + if (k < 32) { + // only need to zero buff in this situation + raw_tmp_c = (FLOAT *)calloc(1, sizeof(FLOAT) * m * cn + 64); + } else { + raw_tmp_c = (FLOAT *)malloc(sizeof(FLOAT) * m * cn + 64); + } + // align buf to 64 byte boundary + FLOAT *tmp_c = (FLOAT *)(((uintptr_t) raw_tmp_c + 63) & ~(uintptr_t)63); ptr_c = tmp_c; BLASLONG ldc_o = ldc; - ldc = n; + ldc = cn; #endif IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64))); IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64))); @@ -515,7 +524,7 @@ int sbgemm_kernel_spr_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFL MASK_APLPHA_STORE(0); } } - free(tmp_c); + free(raw_tmp_c); #endif return 0; } From 8632380a96f9172a6bba4610d9014faf9fb0cd74 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Sat, 18 Sep 2021 01:11:31 -0700 Subject: [PATCH 14/16] sbgemm: spr: reuse ncopy_16 from cooperlake as incopy --- kernel/x86_64/KERNEL.SAPPHIRERAPIDS | 2 +- kernel/x86_64/sbgemm_incopy_16_spr.c | 32 ---------------------------- 2 files changed, 1 insertion(+), 33 deletions(-) delete mode 100644 kernel/x86_64/sbgemm_incopy_16_spr.c diff --git a/kernel/x86_64/KERNEL.SAPPHIRERAPIDS b/kernel/x86_64/KERNEL.SAPPHIRERAPIDS index 3f67640cb..e061b913d 100644 --- a/kernel/x86_64/KERNEL.SAPPHIRERAPIDS +++ b/kernel/x86_64/KERNEL.SAPPHIRERAPIDS @@ -2,7 +2,7 @@ include $(KERNELDIR)/KERNEL.COOPERLAKE SBGEMM_BETA = sgemm_beta_skylakex.c SBGEMMKERNEL = sbgemm_kernel_16x16_spr.c -SBGEMMINCOPY = sbgemm_incopy_16_spr.c +SBGEMMINCOPY = sbgemm_ncopy_16_cooperlake.c SBGEMMITCOPY = sbgemm_tcopy_16_cooperlake.c SBGEMMONCOPY = sbgemm_oncopy_16_spr.c SBGEMMOTCOPY = sbgemm_otcopy_16_spr.c diff --git a/kernel/x86_64/sbgemm_incopy_16_spr.c b/kernel/x86_64/sbgemm_incopy_16_spr.c deleted file mode 100644 index 2f57ae7b6..000000000 --- a/kernel/x86_64/sbgemm_incopy_16_spr.c +++ /dev/null @@ -1,32 +0,0 @@ -/*************************************************************************** - * Copyright (c) 2021, The OpenBLAS Project - * All rights reserved. - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in - * the documentation and/or other materials provided with the - * distribution. - * 3. Neither the name of the OpenBLAS project nor the names of - * its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE - * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * *****************************************************************************/ - -#include "common.h" - -int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { - return 0; -} From 82194ea9d2c8bc2f3e8521421904eeb3419c3ab3 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 23 Sep 2021 01:08:40 -0700 Subject: [PATCH 15/16] sbgemm: spr: implement otcopy_16 --- kernel/x86_64/sbgemm_otcopy_16_spr.c | 270 +++++++++++++++++++++++++++ 1 file changed, 270 insertions(+) diff --git a/kernel/x86_64/sbgemm_otcopy_16_spr.c b/kernel/x86_64/sbgemm_otcopy_16_spr.c index 2f57ae7b6..b5d5d38fb 100644 --- a/kernel/x86_64/sbgemm_otcopy_16_spr.c +++ b/kernel/x86_64/sbgemm_otcopy_16_spr.c @@ -25,8 +25,278 @@ * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * *****************************************************************************/ +#include #include "common.h" +#define LOAD_A_8VEC(aptr) \ + r0 = _mm256_loadu_si256((__m256i *)(aptr + lda*0)); \ + r1 = _mm256_loadu_si256((__m256i *)(aptr + lda*1)); \ + r2 = _mm256_loadu_si256((__m256i *)(aptr + lda*2)); \ + r3 = _mm256_loadu_si256((__m256i *)(aptr + lda*3)); \ + r4 = _mm256_loadu_si256((__m256i *)(aptr + lda*4)); \ + r5 = _mm256_loadu_si256((__m256i *)(aptr + lda*5)); \ + r6 = _mm256_loadu_si256((__m256i *)(aptr + lda*6)); \ + r7 = _mm256_loadu_si256((__m256i *)(aptr + lda*7)); + +#define MASK_LOAD_A_8VEC(aptr) \ + r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*0)); \ + r1 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*1)); \ + r2 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*2)); \ + r3 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*3)); \ + r4 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*4)); \ + r5 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*5)); \ + r6 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*6)); \ + r7 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*7)); + +#define SWITCH_LOAD_A_8VEC(aptr, cond) \ + switch((cond)) { \ + case 8: r7 = _mm256_loadu_si256((__m256i *)(aptr + lda*7)); \ + case 7: r6 = _mm256_loadu_si256((__m256i *)(aptr + lda*6)); \ + case 6: r5 = _mm256_loadu_si256((__m256i *)(aptr + lda*5)); \ + case 5: r4 = _mm256_loadu_si256((__m256i *)(aptr + lda*4)); \ + case 4: r3 = _mm256_loadu_si256((__m256i *)(aptr + lda*3)); \ + case 3: r2 = _mm256_loadu_si256((__m256i *)(aptr + lda*2)); \ + case 2: r1 = _mm256_loadu_si256((__m256i *)(aptr + lda*1)); \ + case 1: r0 = _mm256_loadu_si256((__m256i *)(aptr + lda*0)); \ + } + +#define SWITCH_MASK_LOAD_A_8VEC(aptr, cond) \ + switch((cond)) { \ + case 8: r7 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*7)); \ + case 7: r6 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*6)); \ + case 6: r5 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*5)); \ + case 5: r4 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*4)); \ + case 4: r3 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*3)); \ + case 3: r2 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*2)); \ + case 2: r1 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*1)); \ + case 1: r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aptr + lda*0)); \ + } + +#define REORDER_8x16(t0, t1, t2, t3, t4, t5, t6, t7) \ + t0 = _mm256_unpacklo_epi16(r0, r1); \ + t1 = _mm256_unpackhi_epi16(r0, r1); \ + t2 = _mm256_unpacklo_epi16(r2, r3); \ + t3 = _mm256_unpackhi_epi16(r2, r3); \ + t4 = _mm256_unpacklo_epi16(r4, r5); \ + t5 = _mm256_unpackhi_epi16(r4, r5); \ + t6 = _mm256_unpacklo_epi16(r6, r7); \ + t7 = _mm256_unpackhi_epi16(r6, r7); \ + r0 = _mm256_unpacklo_epi32(t0, t2); \ + r1 = _mm256_unpacklo_epi32(t1, t3); \ + r2 = _mm256_unpacklo_epi32(t4, t6); \ + r3 = _mm256_unpacklo_epi32(t5, t7); \ + r4 = _mm256_unpackhi_epi32(t0, t2); \ + r5 = _mm256_unpackhi_epi32(t1, t3); \ + r6 = _mm256_unpackhi_epi32(t4, t6); \ + r7 = _mm256_unpackhi_epi32(t5, t7); \ + t0 = _mm256_unpacklo_epi64(r0, r2); \ + t1 = _mm256_unpackhi_epi64(r0, r2); \ + t2 = _mm256_unpacklo_epi64(r4, r6); \ + t3 = _mm256_unpackhi_epi64(r4, r6); \ + t4 = _mm256_unpacklo_epi64(r1, r3); \ + t5 = _mm256_unpackhi_epi64(r1, r3); \ + t6 = _mm256_unpacklo_epi64(r5, r7); \ + t7 = _mm256_unpackhi_epi64(r5, r7); + +#define STORE_256_LO(x) \ + v = _mm256_permute2x128_si256(t0##x, t1##x, 0x20); \ + _mm256_storeu_si256((__m256i *)(boffset + x*32), v); + +#define STORE_256_HI(x) \ + v = _mm256_permute2x128_si256(t0##x, t1##x, 0x31); \ + _mm256_storeu_si256((__m256i *)(boffset + (x + 8)*32), v); + +#define MASK_STORE_256_LO(x) \ + v = _mm256_permute2x128_si256(t0##x, t1##x, 0x20); \ + _mm256_mask_storeu_epi16(boffset + x*m_load, mmask, v); + +#define MASK_STORE_256_HI(x) \ + v = _mm256_permute2x128_si256(t0##x, t1##x, 0x31); \ + _mm256_mask_storeu_epi16(boffset + (x + 8)*m_load, mmask, v); + +#define STORE_256(x, y) {\ + __m256i v; \ + if (x == 0) { STORE_256_LO(y); } \ + else { STORE_256_HI(y); } \ +} + +#define MASK_STORE_256(x, y) {\ + __m256i v; \ + if (x == 0) { MASK_STORE_256_LO(y); } \ + else { MASK_STORE_256_HI(y); } \ +} + +#define SWITCH_STORE_16x(cond, func) \ + switch((cond)) {\ + case 15: func(1, 6); \ + case 14: func(1, 5); \ + case 13: func(1, 4); \ + case 12: func(1, 3); \ + case 11: func(1, 2); \ + case 10: func(1, 1); \ + case 9: func(1, 0); \ + case 8: func(0, 7); \ + case 7: func(0, 6); \ + case 6: func(0, 5); \ + case 5: func(0, 4); \ + case 4: func(0, 3); \ + case 3: func(0, 2); \ + case 2: func(0, 1); \ + case 1: func(0, 0); \ + } + + int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { + IFLOAT *aoffset, *boffset; + IFLOAT *aoffset00, *aoffset01, *aoffset10, *aoffset11; + IFLOAT *boffset0; + + __m256i r0, r1, r2, r3, r4, r5, r6, r7; + __m256i t00, t01, t02, t03, t04, t05, t06, t07; + __m256i t10, t11, t12, t13, t14, t15, t16, t17; + + aoffset = a; + boffset = b; + BLASLONG n_count = n; + BLASLONG m_count = m; + for (; n_count > 15; n_count -= 16) { + aoffset00 = aoffset; + aoffset01 = aoffset00 + 8 * lda; + aoffset10 = aoffset01 + 8 * lda; + aoffset11 = aoffset10 + 8 * lda; + aoffset += 16; + m_count = m; + for (; m_count > 31; m_count -= 32) { + // first 16 rows + LOAD_A_8VEC(aoffset00); + REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); + LOAD_A_8VEC(aoffset01); + REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); + STORE_256(0, 0); STORE_256(0, 1); STORE_256(0, 2); STORE_256(0, 3); + STORE_256(0, 4); STORE_256(0, 5); STORE_256(0, 6); STORE_256(0, 7); + STORE_256(1, 0); STORE_256(1, 1); STORE_256(1, 2); STORE_256(1, 3); + STORE_256(1, 4); STORE_256(1, 5); STORE_256(1, 6); STORE_256(1, 7); + // last 16 rows + boffset += 16; + LOAD_A_8VEC(aoffset10); + REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); + LOAD_A_8VEC(aoffset11); + REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); + STORE_256(0, 0); STORE_256(0, 1); STORE_256(0, 2); STORE_256(0, 3); + STORE_256(0, 4); STORE_256(0, 5); STORE_256(0, 6); STORE_256(0, 7); + STORE_256(1, 0); STORE_256(1, 1); STORE_256(1, 2); STORE_256(1, 3); + STORE_256(1, 4); STORE_256(1, 5); STORE_256(1, 6); STORE_256(1, 7); + aoffset00 += 32 * lda; + aoffset01 += 32 * lda; + aoffset10 += 32 * lda; + aoffset11 += 32 * lda; + boffset += 31 * 16; + } + if (m_count > 1) { + int m_load = m_count & ~1; + m_count -= m_load; + __mmask16 mmask; + SWITCH_LOAD_A_8VEC(aoffset00, m_load > 8 ? 8: m_load); + REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); + if (m_load > 8) { + SWITCH_LOAD_A_8VEC(aoffset01, m_load > 16 ? 8: m_load - 8); + REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); + } + int this_load = m_load > 16 ? 16 : m_load; + mmask = (1UL << this_load) - 1; + MASK_STORE_256(0, 0); MASK_STORE_256(0, 1); MASK_STORE_256(0, 2); MASK_STORE_256(0, 3); + MASK_STORE_256(0, 4); MASK_STORE_256(0, 5); MASK_STORE_256(0, 6); MASK_STORE_256(0, 7); + MASK_STORE_256(1, 0); MASK_STORE_256(1, 1); MASK_STORE_256(1, 2); MASK_STORE_256(1, 3); + MASK_STORE_256(1, 4); MASK_STORE_256(1, 5); MASK_STORE_256(1, 6); MASK_STORE_256(1, 7); + boffset0 = boffset; + if (m_load > 16) { + boffset += this_load; + SWITCH_LOAD_A_8VEC(aoffset10, m_load > 24 ? 8: m_load - 16); + REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); + if (m_load > 24) { + SWITCH_LOAD_A_8VEC(aoffset11, m_load - 24); + REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); + } + this_load = m_load - 16; + mmask = (1UL << this_load) - 1; + MASK_STORE_256(0, 0); MASK_STORE_256(0, 1); MASK_STORE_256(0, 2); MASK_STORE_256(0, 3); + MASK_STORE_256(0, 4); MASK_STORE_256(0, 5); MASK_STORE_256(0, 6); MASK_STORE_256(0, 7); + MASK_STORE_256(1, 0); MASK_STORE_256(1, 1); MASK_STORE_256(1, 2); MASK_STORE_256(1, 3); + MASK_STORE_256(1, 4); MASK_STORE_256(1, 5); MASK_STORE_256(1, 6); MASK_STORE_256(1, 7); + } + boffset = boffset0 + 16 * m_load; + aoffset00 += m_load * lda; + } + if (m_count > 0) { + // just copy lask K to B directly + r0 = _mm256_loadu_si256((__m256i *)(aoffset00)); + _mm256_storeu_si256((__m256i *)(boffset), r0); + boffset += 16; + } + } + if (n_count > 0) { + __mmask16 nmask = (1UL << n_count) - 1; + aoffset00 = aoffset; + aoffset01 = aoffset00 + 8 * lda; + aoffset10 = aoffset01 + 8 * lda; + aoffset11 = aoffset10 + 8 * lda; + m_count = m; + for (; m_count > 31; m_count -= 32) { + // first 16 rows + MASK_LOAD_A_8VEC(aoffset00); + REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); + MASK_LOAD_A_8VEC(aoffset01); + REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); + SWITCH_STORE_16x(n_count, STORE_256); + // last 16 rows + boffset0 = boffset; + boffset += 16; + MASK_LOAD_A_8VEC(aoffset10); + REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); + MASK_LOAD_A_8VEC(aoffset11); + REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); + SWITCH_STORE_16x(n_count, STORE_256); + aoffset00 += 32 * lda; + aoffset01 += 32 * lda; + aoffset10 += 32 * lda; + aoffset11 += 32 * lda; + boffset = 32 * n_count + boffset0; + } + if (m_count > 1) { + int m_load = m_count & ~1; + m_count -= m_load; + __mmask16 mmask; + SWITCH_MASK_LOAD_A_8VEC(aoffset00, m_load > 8 ? 8: m_load); + REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); + if (m_load > 8) { + SWITCH_MASK_LOAD_A_8VEC(aoffset01, m_load > 16 ? 8: m_load - 8); + REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); + } + int this_load = m_load > 16 ? 16 : m_load; + mmask = (1UL << this_load) - 1; + SWITCH_STORE_16x(n_count, MASK_STORE_256); + boffset0 = boffset; + if (m_load > 16) { + boffset += this_load; + SWITCH_MASK_LOAD_A_8VEC(aoffset10, m_load > 24 ? 8: m_load - 16); + REORDER_8x16(t00, t01, t02, t03, t04, t05, t06, t07); + if (m_load > 24) { + SWITCH_MASK_LOAD_A_8VEC(aoffset11, m_load - 24); + REORDER_8x16(t10, t11, t12, t13, t14, t15, t16, t17); + } + this_load = m_load - 16; + mmask = (1UL << this_load) - 1; + SWITCH_STORE_16x(n_count, MASK_STORE_256); + } + boffset = boffset0 + n_count * m_load; + aoffset00 += m_load * lda; + } + if (m_count > 0) { + // just copy lask K to B directly + r0 = _mm256_maskz_loadu_epi16(nmask, (__m256i *)(aoffset00)); + _mm256_mask_storeu_epi16((__m256i *)(boffset), nmask, r0); + boffset += 16; + } + } return 0; } From 63a103ba6e8a55c4f117f99716d71ef341a03fa1 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Tue, 12 Oct 2021 01:18:37 -0700 Subject: [PATCH 16/16] sbgemm: spr: disable small matrix path by default --- kernel/x86_64/KERNEL.SAPPHIRERAPIDS | 2 + .../x86_64/sbgemm_small_kernel_permit_spr.c | 42 +++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 kernel/x86_64/sbgemm_small_kernel_permit_spr.c diff --git a/kernel/x86_64/KERNEL.SAPPHIRERAPIDS b/kernel/x86_64/KERNEL.SAPPHIRERAPIDS index e061b913d..88f574668 100644 --- a/kernel/x86_64/KERNEL.SAPPHIRERAPIDS +++ b/kernel/x86_64/KERNEL.SAPPHIRERAPIDS @@ -1,5 +1,7 @@ include $(KERNELDIR)/KERNEL.COOPERLAKE +SBGEMM_SMALL_M_PERMIT = sbgemm_small_kernel_permit_spr.c + SBGEMM_BETA = sgemm_beta_skylakex.c SBGEMMKERNEL = sbgemm_kernel_16x16_spr.c SBGEMMINCOPY = sbgemm_ncopy_16_cooperlake.c diff --git a/kernel/x86_64/sbgemm_small_kernel_permit_spr.c b/kernel/x86_64/sbgemm_small_kernel_permit_spr.c new file mode 100644 index 000000000..98d8ca06a --- /dev/null +++ b/kernel/x86_64/sbgemm_small_kernel_permit_spr.c @@ -0,0 +1,42 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +#include "sbgemm_block_microk_cooperlake.c" +// Define micro kernels for ALPHA not ONE scenarios +#undef ONE_ALPHA +#include "sbgemm_microk_cooperlake_template.c" + +// Define micro kernels for ALPHA as ONE scenarios +#define ONE_ALPHA 1 +#include "sbgemm_microk_cooperlake_template.c" + +int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta) +{ + return 0; +}