Merge pull request #9 from xianyi/develop

rebase
Martin Kroeker 2021-01-24 23:14:45 +01:00 committed by GitHub
commit e3ff4cdd23
14 changed files with 1737 additions and 8 deletions

View File

@@ -202,7 +202,7 @@ static gotoblas_t *get_coretype(void) {
return &gotoblas_POWER10;
#endif
/* Fall back to the POWER9 implementation if the toolchain is too old or the MMA feature is not set */
#if (!defined __GNUC__) || ( __GNUC__ >= 6)
#if (!defined __GNUC__) || ( __GNUC__ >= 11) || (__GNUC__ == 10 && __GNUC_MINOR__ >= 2)
if (__builtin_cpu_is("power10"))
return &gotoblas_POWER9;
#endif

View File

@@ -246,6 +246,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
#ifdef SMP
double MNK;
#if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY)
#ifndef COMPLEX
#ifdef XDOUBLE
int mode = BLAS_XDOUBLE | BLAS_REAL;
@@ -264,6 +265,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
#endif
#endif
#endif
#endif
#if defined(SMP) && !defined(NO_AFFINITY) && !defined(USE_SIMPLE_THREADED_LEVEL3)
int nodes;
@@ -417,8 +419,10 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
sb = (XFLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
#ifdef SMP
#if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY)
mode |= (transa << BLAS_TRANSA_SHIFT);
mode |= (transb << BLAS_TRANSB_SHIFT);
#endif
MNK = (double) args.m * (double) args.n * (double) args.k;
if ( MNK <= (SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD) )

View File

@@ -107,7 +107,6 @@ void CNAME(FLOAT *dd1, FLOAT *dd2, FLOAT *dx1, FLOAT dy1, FLOAT *dparam){
dq1 = dp1 * *dx1;
if(ABS(dq1) > ABS(dq2))
{
dflag = ZERO;
dh11 = ONE;
dh22 = ONE;
dh21 = - dy1 / *dx1;

View File

@@ -39,9 +39,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma GCC optimize "O1"
#if defined(POWER8) || defined(POWER9) || defined(POWER10)
#if defined(__VEC__) || defined(__ALTIVEC__)
#if defined(POWER8) || defined(POWER9)
#include "drot_microk_power8.c"
#elif defined(POWER10)
#include "drot_microk_power10.c"
#endif
#endif
@@ -115,12 +117,30 @@ int CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT
if ( (inc_x == 1) && (inc_y == 1) )
{
#if defined(POWER10)
if ( n >= 16 )
{
BLASLONG align = ((32 - ((uintptr_t)y & (uintptr_t)0x1F)) >> 3) & 0x3;
for (i = 0; i < align; i++) {
temp = c*x[i] + s*y[i] ;
y[i] = c*y[i] - s*x[i] ;
x[i] = temp ;
}
}
BLASLONG n1 = (n-i) & -16;
if ( n1 > 0 )
{
drot_kernel_16(n1,&x[i], &y[i], c, s);
i+=n1;
}
#else
BLASLONG n1 = n & -16;
if ( n1 > 0 )
{
drot_kernel_16(n1, x1, y1, c, s);
i=n1;
}
#endif
while(i < n)
{
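The POWER10 branch above peels off up to three leading elements so that y lands on a 32-byte boundary before the lxvp-based kernel runs; a minimal standalone sketch of that alignment arithmetic (helper name hypothetical, not part of the patch):

#include <stdint.h>

/* Hedged sketch: how many leading doubles to handle with scalar code so that
 * the pointer is 32-byte aligned for the paired vector loads (lxvp).
 * (32 - (addr mod 32)) is the byte distance to the next 32-byte boundary,
 * >> 3 converts bytes to 8-byte doubles, and & 0x3 caps the count at 0..3
 * (an already aligned pointer gives 0). The single-precision srot/sscal
 * variants below use the same idea with >> 2 and & 0x7 for 4-byte floats. */
static inline long leading_unaligned_doubles(const double *p)
{
    return ((32 - ((uintptr_t)p & (uintptr_t)0x1F)) >> 3) & 0x3;
}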

View File

@@ -0,0 +1,148 @@
/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#define HAVE_KERNEL_16 1
static void drot_kernel_16 (long n, double *x, double *y, double c, double s)
{
__asm__
(
XXSPLTD_S(36,%x5,0) // load c to both dwords
XXSPLTD_S(37,%x6,0) // load s to both dwords
"lxvp 32, 0(%3) \n\t" // load x
"lxvp 34, 32(%3) \n\t"
"lxvp 48, 0(%4) \n\t" // load y
"lxvp 50, 32(%4) \n\t"
"addic. %2, %2, -8 \n\t"
"ble two%= \n\t"
".align 5 \n"
"one%=: \n\t"
"xvmuldp 40, 32, 36 \n\t" // c * x
"xvmuldp 41, 33, 36 \n\t"
"xvmuldp 42, 34, 36 \n\t"
"xvmuldp 43, 35, 36 \n\t"
"xvmuldp 44, 32, 37 \n\t" // s * x
"xvmuldp 45, 33, 37 \n\t"
"xvmuldp 46, 34, 37 \n\t"
"xvmuldp 47, 35, 37 \n\t"
"lxvp 32, 64(%3) \n\t" // load x
"lxvp 34, 96(%3) \n\t"
"xvmuldp 52, 48, 36 \n\t" // c * y
"xvmuldp 53, 49, 36 \n\t"
"xvmuldp 54, 50, 36 \n\t"
"xvmuldp 55, 51, 36 \n\t"
"xvmuldp 38, 48, 37 \n\t" // s * y
"xvmuldp 39, 49, 37 \n\t"
"xvmuldp 56, 50, 37 \n\t"
"xvmuldp 57, 51, 37 \n\t"
"lxvp 48, 64(%4) \n\t" // load y
"lxvp 50, 96(%4) \n\t"
"xvadddp 40, 40, 38 \n\t" // c * x + s * y
"xvadddp 41, 41, 39 \n\t" // c * x + s * y
"xvadddp 42, 42, 56 \n\t" // c * x + s * y
"xvadddp 43, 43, 57 \n\t" // c * x + s * y
"stxvp 40, 0(%3) \n\t" // store x
"stxvp 42, 32(%3) \n\t"
"xvsubdp 52, 52, 44 \n\t" // c * y - s * x
"xvsubdp 53, 53, 45 \n\t" // c * y - s * x
"xvsubdp 54, 54, 46 \n\t" // c * y - s * x
"xvsubdp 55, 55, 47 \n\t" // c * y - s * x
"stxvp 52, 0(%4) \n\t" // store y
"stxvp 54, 32(%4) \n\t"
"addi %3, %3, 64 \n\t"
"addi %4, %4, 64 \n\t"
"addic. %2, %2, -8 \n\t"
"bgt one%= \n"
"two%=: \n\t"
"xvmuldp 40, 32, 36 \n\t" // c * x
"xvmuldp 41, 33, 36 \n\t"
"xvmuldp 42, 34, 36 \n\t"
"xvmuldp 43, 35, 36 \n\t"
"xvmuldp 52, 48, 36 \n\t" // c * y
"xvmuldp 53, 49, 36 \n\t"
"xvmuldp 54, 50, 36 \n\t"
"xvmuldp 55, 51, 36 \n\t"
"xvmuldp 44, 32, 37 \n\t" // s * x
"xvmuldp 45, 33, 37 \n\t"
"xvmuldp 46, 34, 37 \n\t"
"xvmuldp 47, 35, 37 \n\t"
"xvmuldp 38, 48, 37 \n\t" // s * y
"xvmuldp 39, 49, 37 \n\t"
"xvmuldp 56, 50, 37 \n\t"
"xvmuldp 57, 51, 37 \n\t"
"xvadddp 40, 40, 38 \n\t" // c * x + s * y
"xvadddp 41, 41, 39 \n\t" // c * x + s * y
"xvadddp 42, 42, 56 \n\t" // c * x + s * y
"xvadddp 43, 43, 57 \n\t" // c * x + s * y
"stxvp 40, 0(%3) \n\t" // store x
"stxvp 42, 32(%3) \n\t"
"xvsubdp 52, 52, 44 \n\t" // c * y - s * x
"xvsubdp 53, 53, 45 \n\t" // c * y - s * x
"xvsubdp 54, 54, 46 \n\t" // c * y - s * x
"xvsubdp 55, 55, 47 \n\t" // c * y - s * x
"stxvp 52, 0(%4) \n\t" // store y
"stxvp 54, 32(%4) \n\t"
"#n=%2 x=%0=%3 y=%1=%4 c=%5 s=%6\n"
:
"+m" (*x),
"+m" (*y),
"+r" (n), // 2
"+b" (x), // 3
"+b" (y) // 4
:
"d" (c), // 5
"d" (s) // 6
:
"cr0",
"vs32","vs33","vs34","vs35","vs36","vs37","vs38","vs39",
"vs40","vs41","vs42","vs43","vs44","vs45","vs46","vs47",
"vs48","vs49","vs50","vs51","vs52","vs53","vs54","vs55",
"vs56","vs57"
);
}
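For reference, a hedged plain-C model of what drot_kernel_16 computes (not part of the commit): each pair (x[i], y[i]) is replaced by the Givens rotation, matching the c * x + s * y and c * y - s * x comments in the assembly; the caller rounds n down to a multiple of 16.

/* Hedged scalar reference for drot_kernel_16 (n a multiple of 16). */
static void drot_reference(long n, double *x, double *y, double c, double s)
{
    for (long i = 0; i < n; i++) {
        double temp = c * x[i] + s * y[i];
        y[i] = c * y[i] - s * x[i];
        x[i] = temp;
    }
}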

View File

@@ -35,9 +35,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "common.h"
#if defined(POWER8) || defined(POWER9) || defined(POWER10)
#if defined(__VEC__) || defined(__ALTIVEC__)
#if defined(POWER8) || defined(POWER9)
#include "dscal_microk_power8.c"
#elif defined(POWER10)
#include "dscal_microk_power10.c"
#endif
#endif
@@ -100,12 +102,28 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
if ( da == 0.0 )
{
#if defined(POWER10)
if ( n >= 16 )
{
BLASLONG align = ((32 - ((uintptr_t)x & (uintptr_t)0x1F)) >> 3) & 0x3;
for (j = 0; j < align; j++) {
x[j] = 0.0;
}
}
BLASLONG n1 = (n-j) & -16;
if ( n1 > 0 )
{
dscal_kernel_8_zero(n1, &x[j]);
j+=n1;
}
#else
BLASLONG n1 = n & -16;
if ( n1 > 0 )
{
dscal_kernel_8_zero(n1, x);
j=n1;
}
#endif
while(j < n)
{
@@ -118,12 +136,28 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
else
{
#if defined(POWER10)
if ( n >= 16 )
{
BLASLONG align = ((32 - ((uintptr_t)x & (uintptr_t)0x1F)) >> 3) & 0x3;
for (j = 0; j < align; j++) {
x[j] = da * x[j];
}
}
BLASLONG n1 = (n-j) & -16;
if ( n1 > 0 )
{
dscal_kernel_8(n1, &x[j], da);
j+=n1;
}
#else
BLASLONG n1 = n & -16;
if ( n1 > 0 )
{
dscal_kernel_8(n1, x, da);
j=n1;
}
#endif
while(j < n)
{

View File

@@ -0,0 +1,134 @@
/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#define HAVE_KERNEL_8 1
static void dscal_kernel_8 (long n, double *x, double alpha)
{
__asm__
(
"dcbt 0, %2 \n\t"
XXSPLTD_S(48,%x3,0)
"lxvp 32, 0(%2) \n\t"
"lxvp 34, 32(%2) \n\t"
"lxvp 36, 64(%2) \n\t"
"lxvp 38, 96(%2) \n\t"
"addic. %1, %1, -16 \n\t"
"ble two%= \n\t"
".align 5 \n"
"one%=: \n\t"
"xvmuldp 40, 32, 48 \n\t"
"xvmuldp 41, 33, 48 \n\t"
"xvmuldp 42, 34, 48 \n\t"
"xvmuldp 43, 35, 48 \n\t"
"lxvp 32, 128(%2) \n\t"
"lxvp 34, 160(%2) \n\t"
"xvmuldp 44, 36, 48 \n\t"
"xvmuldp 45, 37, 48 \n\t"
"xvmuldp 46, 38, 48 \n\t"
"xvmuldp 47, 39, 48 \n\t"
"lxvp 36, 192(%2) \n\t"
"lxvp 38, 224(%2) \n\t"
"stxvp 40, 0(%2) \n\t"
"stxvp 42, 32(%2) \n\t"
"stxvp 44, 64(%2) \n\t"
"stxvp 46, 96(%2) \n\t"
"addi %2, %2, 128 \n\t"
"addic. %1, %1, -16 \n\t"
"bgt one%= \n"
"two%=: \n\t"
"xvmuldp 40, 32, 48 \n\t"
"xvmuldp 41, 33, 48 \n\t"
"xvmuldp 42, 34, 48 \n\t"
"xvmuldp 43, 35, 48 \n\t"
"xvmuldp 44, 36, 48 \n\t"
"xvmuldp 45, 37, 48 \n\t"
"xvmuldp 46, 38, 48 \n\t"
"xvmuldp 47, 39, 48 \n\t"
"stxvp 40, 0(%2) \n\t"
"stxvp 42, 32(%2) \n\t"
"stxvp 44, 64(%2) \n\t"
"stxvp 46, 96(%2) \n\t"
"#n=%1 alpha=%3 x=%0=%2"
:
"+m" (*x),
"+r" (n), // 1
"+b" (x) // 2
:
"d" (alpha) // 3
:
"cr0",
"vs32","vs33","vs34","vs35","vs36","vs37","vs38","vs39",
"vs40","vs41","vs42","vs43","vs44","vs45","vs46","vs47","vs48"
);
}
static void dscal_kernel_8_zero (long n, double *x)
{
__asm__
(
"xxlxor 32, 32, 32 \n\t"
"xxlxor 33, 33, 33 \n\t"
".align 5 \n"
"one%=: \n\t"
"stxvp 32, 0(%2) \n\t"
"stxvp 32, 32(%2) \n\t"
"stxvp 32, 64(%2) \n\t"
"stxvp 32, 96(%2) \n\t"
"addi %2, %2, 128 \n\t"
"addic. %1, %1, -16 \n\t"
"bgt one%= \n"
"#n=%1 x=%0=%2 "
:
"=m" (*x),
"+r" (n), // 1
"+b" (x) // 2
:
:
"cr0","vs32","vs33"
);
}
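As a hedged plain-C reference (not part of the commit), the two kernels above reduce to scaling or zeroing n consecutive doubles in place, with n rounded down to a multiple of 16 by the caller:

/* Hedged scalar reference for dscal_kernel_8 and dscal_kernel_8_zero. */
static void dscal_reference(long n, double *x, double alpha)
{
    for (long i = 0; i < n; i++)
        x[i] *= alpha;
}

static void dscal_zero_reference(long n, double *x)
{
    for (long i = 0; i < n; i++)
        x[i] = 0.0;
}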

View File

@@ -39,9 +39,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma GCC optimize "O1"
#if defined(POWER8) || defined(POWER9) || defined(POWER10)
#if defined(__VEC__) || defined(__ALTIVEC__)
#if defined(POWER8) || defined(POWER9)
#include "srot_microk_power8.c"
#elif defined(POWER10)
#include "srot_microk_power10.c"
#endif
#endif
@@ -115,6 +117,23 @@ int CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT
if ( (inc_x == 1) && (inc_y == 1) )
{
#if defined(POWER10)
if ( n >= 16 )
{
BLASLONG align = ((32 - ((uintptr_t)y & (uintptr_t)0x1F)) >> 2) & 0x7;
for (i = 0; i < align; i++) {
temp = c*x[i] + s*y[i] ;
y[i] = c*y[i] - s*x[i] ;
x[i] = temp ;
}
}
BLASLONG n1 = (n-i) & -16;
if ( n1 > 0 )
{
srot_kernel_16(n1, &x1[i], &y1[i], c, s);
i+=n1;
}
#else
BLASLONG n1 = n & -16;
if ( n1 > 0 )
{
@@ -122,6 +141,7 @@ int CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT
i=n1;
}
#endif
while(i < n)
{
temp = c*x[i] + s*y[i] ;

View File

@@ -0,0 +1,151 @@
/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#define HAVE_KERNEL_16 1
static void srot_kernel_16 (long n, float *x, float *y, float c, float s)
{
__asm__
(
"xscvdpspn 36, %x5 \n\t" // load c to all words
"xxspltw 36, 36, 0 \n\t"
"xscvdpspn 37, %x6 \n\t" // load s to all words
"xxspltw 37, 37, 0 \n\t"
"lxvp 32, 0(%3) \n\t" // load x
"lxvp 34, 32(%3) \n\t"
"lxvp 48, 0(%4) \n\t" // load y
"lxvp 50, 32(%4) \n\t"
"addic. %2, %2, -16 \n\t"
"ble two%= \n\t"
".align 5 \n"
"one%=: \n\t"
"xvmulsp 40, 32, 36 \n\t" // c * x
"xvmulsp 41, 33, 36 \n\t"
"xvmulsp 42, 34, 36 \n\t"
"xvmulsp 43, 35, 36 \n\t"
"xvmulsp 44, 32, 37 \n\t" // s * x
"xvmulsp 45, 33, 37 \n\t"
"xvmulsp 46, 34, 37 \n\t"
"xvmulsp 47, 35, 37 \n\t"
"lxvp 32, 64(%3) \n\t" // load x
"lxvp 34, 96(%3) \n\t"
"xvmulsp 52, 48, 36 \n\t" // c * y
"xvmulsp 53, 49, 36 \n\t"
"xvmulsp 54, 50, 36 \n\t"
"xvmulsp 55, 51, 36 \n\t"
"xvmulsp 38, 48, 37 \n\t" // s * y
"xvmulsp 39, 49, 37 \n\t"
"xvmulsp 56, 50, 37 \n\t"
"xvmulsp 57, 51, 37 \n\t"
"lxvp 48, 64(%4) \n\t" // load y
"lxvp 50, 96(%4) \n\t"
"xvaddsp 40, 40, 38 \n\t" // c * x + s * y
"xvaddsp 41, 41, 39 \n\t" // c * x + s * y
"xvaddsp 42, 42, 56 \n\t" // c * x + s * y
"xvaddsp 43, 43, 57 \n\t" // c * x + s * y
"stxvp 40, 0(%3) \n\t" // store x
"stxvp 42, 32(%3) \n\t"
"xvsubsp 52, 52, 44 \n\t" // c * y - s * x
"xvsubsp 53, 53, 45 \n\t" // c * y - s * x
"xvsubsp 54, 54, 46 \n\t" // c * y - s * x
"xvsubsp 55, 55, 47 \n\t" // c * y - s * x
"stxvp 52, 0(%4) \n\t" // store y
"stxvp 54, 32(%4) \n\t"
"addi %3, %3, 64 \n\t"
"addi %4, %4, 64 \n\t"
"addic. %2, %2, -16 \n\t"
"bgt one%= \n"
"two%=: \n\t"
"xvmulsp 40, 32, 36 \n\t" // c * x
"xvmulsp 41, 33, 36 \n\t"
"xvmulsp 42, 34, 36 \n\t"
"xvmulsp 43, 35, 36 \n\t"
"xvmulsp 52, 48, 36 \n\t" // c * y
"xvmulsp 53, 49, 36 \n\t"
"xvmulsp 54, 50, 36 \n\t"
"xvmulsp 55, 51, 36 \n\t"
"xvmulsp 44, 32, 37 \n\t" // s * x
"xvmulsp 45, 33, 37 \n\t"
"xvmulsp 46, 34, 37 \n\t"
"xvmulsp 47, 35, 37 \n\t"
"xvmulsp 38, 48, 37 \n\t" // s * y
"xvmulsp 39, 49, 37 \n\t"
"xvmulsp 56, 50, 37 \n\t"
"xvmulsp 57, 51, 37 \n\t"
"xvaddsp 40, 40, 38 \n\t" // c * x + s * y
"xvaddsp 41, 41, 39 \n\t" // c * x + s * y
"xvaddsp 42, 42, 56 \n\t" // c * x + s * y
"xvaddsp 43, 43, 57 \n\t" // c * x + s * y
"stxvp 40, 0(%3) \n\t" // store x
"stxvp 42, 32(%3) \n\t"
"xvsubsp 52, 52, 44 \n\t" // c * y - s * x
"xvsubsp 53, 53, 45 \n\t" // c * y - s * x
"xvsubsp 54, 54, 46 \n\t" // c * y - s * x
"xvsubsp 55, 55, 47 \n\t" // c * y - s * x
"stxvp 52, 0(%4) \n\t" // store y
"stxvp 54, 32(%4) \n\t"
"#n=%2 x=%0=%3 y=%1=%4 c=%5 s=%6\n"
:
"+m" (*x),
"+m" (*y),
"+r" (n), // 2
"+b" (x), // 3
"+b" (y) // 4
:
"f" (c), // 5
"f" (s) // 6
:
"cr0",
"vs32","vs33","vs34","vs35","vs36","vs37","vs38","vs39",
"vs40","vs41","vs42","vs43","vs44","vs45","vs46","vs47",
"vs48","vs49","vs50","vs51","vs52","vs53","vs54","vs55",
"vs56","vs57"
);
}

View File

@@ -35,9 +35,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "common.h"
#if defined(POWER8) || defined(POWER9) || defined(POWER10)
#if defined(__VEC__) || defined(__ALTIVEC__)
#if defined(POWER8) || defined(POWER9)
#include "sscal_microk_power8.c"
#elif defined(POWER10)
#include "sscal_microk_power10.c"
#endif
#endif
@@ -102,12 +104,28 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
if ( da == 0.0 )
{
#if defined(POWER10)
if ( n >= 32 )
{
BLASLONG align = ((32 - ((uintptr_t)x & (uintptr_t)0x1F)) >> 2) & 0x7;
for (j = 0; j < align; j++) {
x[j] = 0.0;
}
}
BLASLONG n1 = (n-j) & -32;
if ( n1 > 0 )
{
sscal_kernel_16_zero(n1, &x[j]);
j+=n1;
}
#else
BLASLONG n1 = n & -32;
if ( n1 > 0 )
{
sscal_kernel_16_zero(n1, x);
j=n1;
}
#endif
while(j < n)
{
@@ -120,12 +138,28 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS
else
{
#if defined(POWER10)
if ( n >= 32 )
{
BLASLONG align = ((32 - ((uintptr_t)x & (uintptr_t)0x1F)) >> 2) & 0x7;
for (j = 0; j < align; j++) {
x[j] = da * x[j];
}
}
BLASLONG n1 = (n-j) & -32;
if ( n1 > 0 )
{
sscal_kernel_16(n1, &x[j], da);
j+=n1;
}
#else
BLASLONG n1 = n & -32;
if ( n1 > 0 )
{
sscal_kernel_16(n1, x, da);
j=n1;
}
#endif
while(j < n)
{

View File

@@ -0,0 +1,135 @@
/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#define HAVE_KERNEL_16 1
static void sscal_kernel_16 (long n, float *x, float alpha)
{
__asm__
(
"dcbt 0, %2 \n\t"
"xscvdpspn 48, %x3 \n\t"
"xxspltw 48, 48, 0 \n\t"
"lxvp 32, 0(%2) \n\t"
"lxvp 34, 32(%2) \n\t"
"lxvp 36, 64(%2) \n\t"
"lxvp 38, 96(%2) \n\t"
"addic. %1, %1, -32 \n\t"
"ble two%= \n\t"
".align 5 \n"
"one%=: \n\t"
"xvmulsp 40, 32, 48 \n\t"
"xvmulsp 41, 33, 48 \n\t"
"xvmulsp 42, 34, 48 \n\t"
"xvmulsp 43, 35, 48 \n\t"
"lxvp 32, 128(%2) \n\t"
"lxvp 34, 160(%2) \n\t"
"xvmulsp 44, 36, 48 \n\t"
"xvmulsp 45, 37, 48 \n\t"
"xvmulsp 46, 38, 48 \n\t"
"xvmulsp 47, 39, 48 \n\t"
"lxvp 36, 192(%2) \n\t"
"lxvp 38, 224(%2) \n\t"
"stxvp 40, 0(%2) \n\t"
"stxvp 42, 32(%2) \n\t"
"stxvp 44, 64(%2) \n\t"
"stxvp 46, 96(%2) \n\t"
"addi %2, %2, 128 \n\t"
"addic. %1, %1, -32 \n\t"
"bgt one%= \n"
"two%=: \n\t"
"xvmulsp 40, 32, 48 \n\t"
"xvmulsp 41, 33, 48 \n\t"
"xvmulsp 42, 34, 48 \n\t"
"xvmulsp 43, 35, 48 \n\t"
"xvmulsp 44, 36, 48 \n\t"
"xvmulsp 45, 37, 48 \n\t"
"xvmulsp 46, 38, 48 \n\t"
"xvmulsp 47, 39, 48 \n\t"
"stxvp 40, 0(%2) \n\t"
"stxvp 42, 32(%2) \n\t"
"stxvp 44, 64(%2) \n\t"
"stxvp 46, 96(%2) \n\t"
"#n=%1 alpha=%3 x=%0=%2"
:
"+m" (*x),
"+r" (n), // 1
"+b" (x) // 2
:
"f" (alpha) // 3
:
"cr0",
"vs32","vs33","vs34","vs35","vs36","vs37","vs38","vs39",
"vs40","vs41","vs42","vs43","vs44","vs45","vs46","vs47","vs48"
);
}
static void sscal_kernel_16_zero (long n, float *x)
{
__asm__
(
"xxlxor 32, 32, 32 \n\t"
"xxlxor 33, 33, 33 \n\t"
".align 5 \n"
"one%=: \n\t"
"stxvp 32, 0(%2) \n\t"
"stxvp 32, 32(%2) \n\t"
"stxvp 32, 64(%2) \n\t"
"stxvp 32, 96(%2) \n\t"
"addi %2, %2, 128 \n\t"
"addic. %1, %1, -32 \n\t"
"bgt one%= \n"
"#n=%1 x=%0=%2 "
:
"=m" (*x),
"+r" (n), // 1
"+b" (x) // 2
:
:
"cr0","vs32","vs33"
);
}

View File

@@ -93,7 +93,6 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
#if defined(SMP)
int nthreads;
FLOAT dummy_alpha;
FLOAT * dummy_b;
#endif
FLOAT sumf = 0.0;
@@ -115,7 +114,7 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
#else
mode = BLAS_DOUBLE | BLAS_REAL;
#endif
blas_level1_thread_with_return_value(mode, n, 0, 0, &dummy_alpha, x, inc_x, dummy_b, 0, result, 0, (void *)asum_thread_function, nthreads);
blas_level1_thread_with_return_value(mode, n, 0, 0, &dummy_alpha, x, inc_x, NULL, 0, result, 0, (void *)asum_thread_function, nthreads);
ptr = (FLOAT *)result;
for (i = 0; i < nthreads; i++) {
sumf += (*ptr);

View File

@@ -0,0 +1,426 @@
#include "sbgemm.h"
#include <immintrin.h>
// Work around intrinsics missing from some compiler versions
#define MM256_LOADU_EPI16(addr) \
_mm256_maskz_loadu_epi16(~0, (addr))
#define MM256_STOREU_EPI16(addr, reg) \
_mm256_mask_storeu_epi16((addr), ~0, (reg))
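/* Hedged note (not part of the original commit): with an all-ones mask the
 * masked variants above behave exactly like plain unaligned 256-bit epi16
 * loads/stores; the wrappers exist only because some compiler versions do
 * not expose _mm256_loadu_epi16 / _mm256_storeu_epi16 as intrinsics. */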
#include <stdio.h>
void print_block(BLASLONG m, BLASLONG n, bfloat16 * mat)
{
printf("---- BLOCK %ld x %ld ----\n", m, n);
for (BLASLONG i=0; i<m; i++) {
for (BLASLONG j=0; j<n; j++) {
printf("%-4X ", *(mat + i*n +j));
}
printf("\n");
}
printf("---- End of BLOCK ----\n");
}
void COL_MAJOR_INCOPY_KERNEL_Kx32(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
{
BLASLONG tag_k_2x = k & (~1);
__m512i array512_0, array512_1, array512_2, array512_3;
BLASLONG idx_src_base0, idx_src_base1;
BLASLONG idx_target_base0, idx_target_base1;
BLASLONG LDA_2x = 2*lda;
BLASLONG BF16_BLOCK_T_M_2x = 2*32;
idx_src_base0 = 0;
idx_src_base1 = lda;
idx_target_base0 = 0;
idx_target_base1 = 32;
for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
array512_0 = _mm512_loadu_si512(&A[idx_src_base0]);
array512_1 = _mm512_loadu_si512(&A[idx_src_base1]);
array512_2 = _mm512_unpacklo_epi16(array512_0, array512_1);
array512_3 = _mm512_unpackhi_epi16(array512_0, array512_1);
_mm512_storeu_si512(&block_A[idx_target_base0], array512_2);
_mm512_storeu_si512(&block_A[idx_target_base1], array512_3);
idx_src_base0 += LDA_2x;
idx_src_base1 += LDA_2x;
idx_target_base0 += BF16_BLOCK_T_M_2x;
idx_target_base1 += BF16_BLOCK_T_M_2x;
}
if (tag_k_2x != k) {
__m512i ZERO512 = _mm512_setzero_si512();
array512_0 = _mm512_loadu_si512(&A[idx_src_base0]);
array512_2 = _mm512_unpacklo_epi16(array512_0, ZERO512);
array512_3 = _mm512_unpackhi_epi16(array512_0, ZERO512);
_mm512_storeu_si512(&block_A[idx_target_base0], array512_2);
_mm512_storeu_si512(&block_A[idx_target_base1], array512_3);
}
#ifdef DEBUG_PROFILE
print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A);
#endif
}
void COL_MAJOR_INCOPY_KERNEL_Kx32m(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
{
BLASLONG tag_k_2x = k & (~1);
unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-m));
__mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
__m512i array512_0, array512_1, array512_2, array512_3;
BLASLONG idx_src_base0, idx_src_base1;
BLASLONG idx_target_base0, idx_target_base1;
BLASLONG LDA_2x = 2*lda;
BLASLONG BF16_BLOCK_T_M_2x = 2*32;
idx_src_base0 = 0;
idx_src_base1 = lda;
idx_target_base0 = 0;
idx_target_base1 = 32;
for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
array512_0 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]);
array512_1 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base1]);
array512_2 = _mm512_unpacklo_epi16(array512_0, array512_1);
array512_3 = _mm512_unpackhi_epi16(array512_0, array512_1);
_mm512_storeu_si512(&block_A[idx_target_base0], array512_2);
_mm512_storeu_si512(&block_A[idx_target_base1], array512_3);
idx_src_base0 += LDA_2x;
idx_src_base1 += LDA_2x;
idx_target_base0 += BF16_BLOCK_T_M_2x;
idx_target_base1 += BF16_BLOCK_T_M_2x;
}
if (tag_k_2x != k) {
__m512i ZERO512 = _mm512_setzero_si512();
array512_0 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]);
array512_2 = _mm512_unpacklo_epi16(array512_0, ZERO512);
array512_3 = _mm512_unpackhi_epi16(array512_0, ZERO512);
_mm512_storeu_si512(&block_A[idx_target_base0], array512_2);
_mm512_storeu_si512(&block_A[idx_target_base1], array512_3);
}
#ifdef DEBUG_PROFILE
print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A);
#endif
}
void COL_MAJOR_INCOPY_KERNEL_Kx16(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
{
BLASLONG tag_k_2x = k & (~1);
__m256i array256_0, array256_1, array256_2, array256_3;
BLASLONG idx_src_base0, idx_src_base1;
BLASLONG idx_target_base0;
BLASLONG LDA_2x = 2*lda;
idx_src_base0 = 0;
idx_src_base1 = lda;
idx_target_base0 = 0;
for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
array256_0 = MM256_LOADU_EPI16(&A[idx_src_base0]);
array256_1 = MM256_LOADU_EPI16(&A[idx_src_base1]);
array256_2 = _mm256_unpacklo_epi16(array256_0, array256_1);
array256_3 = _mm256_unpackhi_epi16(array256_0, array256_1);
// Store into one row of block_A
MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2);
MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3);
idx_src_base0 += LDA_2x;
idx_src_base1 += LDA_2x;
idx_target_base0 += 32;
}
if (tag_k_2x != k) {
__m256i ZERO256 = _mm256_setzero_si256();
array256_0 = MM256_LOADU_EPI16(&A[idx_src_base0]);
array256_2 = _mm256_unpacklo_epi16(array256_0, ZERO256);
array256_3 = _mm256_unpackhi_epi16(array256_0, ZERO256);
// Store into one row of block_A
MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2);
MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3);
}
#ifdef DEBUG_PROFILE
print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A);
#endif
}
void COL_MAJOR_INCOPY_KERNEL_Kx16m(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
{
BLASLONG tag_k_2x = k & (~1);
unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m));
__mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
__m256i array256_0, array256_1, array256_2, array256_3;
BLASLONG idx_src_base0, idx_src_base1;
BLASLONG idx_target_base0;
BLASLONG LDA_2x = 2*lda;
idx_src_base0 = 0;
idx_src_base1 = lda;
idx_target_base0 = 0;
for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) {
array256_0 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]);
array256_1 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base1]);
array256_2 = _mm256_unpacklo_epi16(array256_0, array256_1);
array256_3 = _mm256_unpackhi_epi16(array256_0, array256_1);
// Store into one row of block_A
MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2);
MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3);
idx_src_base0 += LDA_2x;
idx_src_base1 += LDA_2x;
idx_target_base0 += 32;
}
if (tag_k_2x != k) {
__m256i ZERO256 = _mm256_setzero_si256();
array256_0 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]);
array256_2 = _mm256_unpacklo_epi16(array256_0, ZERO256);
array256_3 = _mm256_unpackhi_epi16(array256_0, ZERO256);
// Store into one row of block_A
MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2);
MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3);
}
#ifdef DEBUG_PROFILE
print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A);
#endif
}
void COL_MAJOR_ONCOPY_KERNEL_8x32(BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B)
{
BLASLONG tag_k_32x = k & (~31);
BLASLONG idx_src_base0, idx_src_base1, idx_src_base2, idx_src_base3, idx_src_base4, idx_src_base5, idx_src_base6, idx_src_base7;
BLASLONG idx_target_base0;
idx_src_base0 = 0;
idx_src_base1 = 1*ldb;
idx_src_base2 = 2*ldb;
idx_src_base3 = 3*ldb;
idx_src_base4 = 4*ldb;
idx_src_base5 = 5*ldb;
idx_src_base6 = 6*ldb;
idx_src_base7 = 7*ldb;
idx_target_base0 = 0;
for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_loadu_si512(&B[idx_src_base0+idx_k]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_loadu_si512(&B[idx_src_base1+idx_k]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*2], _mm512_loadu_si512(&B[idx_src_base2+idx_k]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*3], _mm512_loadu_si512(&B[idx_src_base3+idx_k]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*4], _mm512_loadu_si512(&B[idx_src_base4+idx_k]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*5], _mm512_loadu_si512(&B[idx_src_base5+idx_k]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*6], _mm512_loadu_si512(&B[idx_src_base6+idx_k]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*7], _mm512_loadu_si512(&B[idx_src_base7+idx_k]));
idx_target_base0 += 32*8;
}
if (tag_k_32x != k) {
unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x)));
__mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0+tag_k_32x]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base1+tag_k_32x]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*2], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base2+tag_k_32x]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*3], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base3+tag_k_32x]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*4], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base4+tag_k_32x]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*5], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base5+tag_k_32x]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*6], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base6+tag_k_32x]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*7], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base7+tag_k_32x]));
}
#ifdef DEBUG_PROFILE
print_block(BF16_BLOCK_THRES_N, BF16_BLOCK_THRES_K, block_B);
#endif
}
void COL_MAJOR_ONCOPY_KERNEL_Nx32(BLASLONG n, BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B)
{
BLASLONG tag_k_32x = k & (~31);
BLASLONG tag_n_2x = n & (~1);
BLASLONG idx_src_base0;
BLASLONG idx_target_base0;
BLASLONG LDB_2x = 2*ldb;
idx_target_base0 = 0;
for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
idx_src_base0 = 0;
for (BLASLONG idx_n = 0; idx_n < tag_n_2x; idx_n += 2) {
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_loadu_si512(&B[idx_src_base0 + idx_k]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_loadu_si512(&B[idx_src_base0 + ldb + idx_k]));
idx_src_base0 += LDB_2x;
idx_target_base0 += 64;
}
if (tag_n_2x != n) {
_mm512_storeu_si512(&block_B[idx_target_base0], _mm512_loadu_si512(&B[idx_src_base0 + idx_k]));
idx_target_base0 += 32;
}
}
if (tag_k_32x != k) {
unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x)));
__mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
idx_src_base0 = 0;
for (BLASLONG idx_n = 0; idx_n < tag_n_2x; idx_n += 2) {
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0 + tag_k_32x]));
_mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0 + ldb + tag_k_32x]));
idx_src_base0 += LDB_2x;
idx_target_base0 += 64;
}
if (tag_n_2x != n) {
_mm512_storeu_si512(&block_B[idx_target_base0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0 + tag_k_32x]));
}
}
#ifdef DEBUG_PROFILE
print_block(BF16_BLOCK_THRES_N, BF16_BLOCK_THRES_K, block_B);
#endif
}
// Scale matrix C when beta is neither ZERO nor ONE
void sbgemm_scal_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc)
{
BLASLONG tag_n_Nx = N & (~3);
BLASLONG tag_n_Mx = M & (~15);
BLASLONG LDC4x = ldc*4;
BLASLONG idx_base_0 = 0;
BLASLONG idx_base_1 = ldc;
BLASLONG idx_base_2 = ldc*2;
BLASLONG idx_base_3 = ldc*3;
unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-M+tag_n_Mx));
__mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
__m512 array_512_0, array_512_1, array_512_2, array_512_3;
__m512 BETAVECTOR = _mm512_set1_ps(beta);
if (Order == CblasColMajor) {
for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) {
for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
array_512_0 = _mm512_loadu_ps(&C[idx_base_0+idx_m]);
array_512_1 = _mm512_loadu_ps(&C[idx_base_1+idx_m]);
array_512_2 = _mm512_loadu_ps(&C[idx_base_2+idx_m]);
array_512_3 = _mm512_loadu_ps(&C[idx_base_3+idx_m]);
array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1);
array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2);
array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3);
_mm512_storeu_ps(&C[idx_base_0+idx_m], array_512_0);
_mm512_storeu_ps(&C[idx_base_1+idx_m], array_512_1);
_mm512_storeu_ps(&C[idx_base_2+idx_m], array_512_2);
_mm512_storeu_ps(&C[idx_base_3+idx_m], array_512_3);
}
if (tag_n_Mx != M) {
array_512_0 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_0+tag_n_Mx]);
array_512_1 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_1+tag_n_Mx]);
array_512_2 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_2+tag_n_Mx]);
array_512_3 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_3+tag_n_Mx]);
array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1);
array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2);
array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3);
_mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, array_512_0);
_mm512_mask_storeu_ps(&C[idx_base_1+tag_n_Mx], tail_mask, array_512_1);
_mm512_mask_storeu_ps(&C[idx_base_2+tag_n_Mx], tail_mask, array_512_2);
_mm512_mask_storeu_ps(&C[idx_base_3+tag_n_Mx], tail_mask, array_512_3);
}
idx_base_0 += LDC4x;
idx_base_1 += LDC4x;
idx_base_2 += LDC4x;
idx_base_3 += LDC4x;
}
if (tag_n_Nx != N) {
for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) {
for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
array_512_0 = _mm512_loadu_ps(&C[idx_base_0+idx_m]);
array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
_mm512_storeu_ps(&C[idx_base_0+idx_m], array_512_0);
}
if (tag_n_Mx != M) {
array_512_0 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_0+tag_n_Mx]);
array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0);
_mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, array_512_0);
}
idx_base_0 += ldc;
}
}
} else {
}
}
// Zero out matrix C when beta is ZERO
void sbgemm_zero_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, float *C, OPENBLAS_CONST blasint ldc)
{
BLASLONG tag_n_Nx = N & (~3);
BLASLONG tag_n_Mx = M & (~15);
BLASLONG LDC4x = ldc*4;
BLASLONG idx_base_0 = 0;
BLASLONG idx_base_1 = ldc;
BLASLONG idx_base_2 = ldc*2;
BLASLONG idx_base_3 = ldc*3;
unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-M+tag_n_Mx));
__mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
__m512 ZEROVECTOR = _mm512_setzero_ps();
if (Order == CblasColMajor) {
for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) {
for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
_mm512_storeu_ps(&C[idx_base_0+idx_m], ZEROVECTOR);
_mm512_storeu_ps(&C[idx_base_1+idx_m], ZEROVECTOR);
_mm512_storeu_ps(&C[idx_base_2+idx_m], ZEROVECTOR);
_mm512_storeu_ps(&C[idx_base_3+idx_m], ZEROVECTOR);
}
if (tag_n_Mx != M) {
_mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, ZEROVECTOR);
_mm512_mask_storeu_ps(&C[idx_base_1+tag_n_Mx], tail_mask, ZEROVECTOR);
_mm512_mask_storeu_ps(&C[idx_base_2+tag_n_Mx], tail_mask, ZEROVECTOR);
_mm512_mask_storeu_ps(&C[idx_base_3+tag_n_Mx], tail_mask, ZEROVECTOR);
}
idx_base_0 += LDC4x;
idx_base_1 += LDC4x;
idx_base_2 += LDC4x;
idx_base_3 += LDC4x;
}
if (tag_n_Nx != N) {
for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) {
for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) {
_mm512_storeu_ps(&C[idx_base_0+idx_m], ZEROVECTOR);
}
if (tag_n_Mx != M) {
_mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, ZEROVECTOR);
}
idx_base_0 += ldc;
}
}
} else {
}
}
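Two idioms recur in the packing code above and in the compute kernels of the next file: bf16 pair dot products and low-bit tail masks. A hedged plain-C illustration of both (helper names hypothetical, not from the commit):

#include <stdint.h>

/* Hedged model of one 32-bit lane of _mm512_dpbf16_ps: the accumulator gains
 * the dot product of a bf16 pair from A with a bf16 pair from B -- which is
 * why the in-copy kernels interleave rows k and k+1 into adjacent elements.
 * Plain floats are used here for clarity. */
static inline float dpbf16_lane(float acc, float a_lo, float a_hi,
                                float b_lo, float b_hi)
{
    return acc + a_lo * b_lo + a_hi * b_hi;
}

/* Hedged model of the tail-mask idiom: shifting an all-ones value of the
 * lane width right by (lanes - m) leaves the low m bits set (equivalently,
 * (1 << m) - 1), so masked loads/stores touch only the first m elements. */
static inline uint32_t low_bits_mask(unsigned m)
{
    return (m >= 32) ? 0xffffffffu : ((1u << m) - 1u);
}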

View File

@@ -0,0 +1,625 @@
#include "sbgemm.h"
#include "bf16_common_macros.h"
#include <immintrin.h>
#undef STORE16_COMPLETE_RESULT
#undef STORE16_MASK_COMPLETE_RESULT
#undef SBGEMM_BLOCK_KERNEL_32x8x32
#undef SBGEMM_BLOCK_KERNEL_16x8x32
#undef SBGEMM_BLOCK_KERNEL_32xNx32
#undef SBGEMM_BLOCK_KERNEL_16xNx32
#undef SBGEMM_BLOCKING_KERNEL_2
#ifndef ONE_ALPHA // ALPHA is not ONE
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_ONE
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE
#define SBGEMM_BLOCK_KERNEL_32x8x32 sbgemm_block_kernel_32x8x32_alpha
#define SBGEMM_BLOCK_KERNEL_16x8x32 sbgemm_block_kernel_16x8x32_alpha
#define SBGEMM_BLOCK_KERNEL_32xNx32 sbgemm_block_kernel_32xNx32_alpha
#define SBGEMM_BLOCK_KERNEL_16xNx32 sbgemm_block_kernel_16xNx32_alpha
#define SBGEMM_BLOCKING_KERNEL_2 sbgemm_blocking_kernel_2_alpha
#else // ALPHA is ONE
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ONE_ONE
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ONE_ONE
#define SBGEMM_BLOCK_KERNEL_32x8x32 sbgemm_block_kernel_32x8x32_one
#define SBGEMM_BLOCK_KERNEL_16x8x32 sbgemm_block_kernel_16x8x32_one
#define SBGEMM_BLOCK_KERNEL_32xNx32 sbgemm_block_kernel_32xNx32_one
#define SBGEMM_BLOCK_KERNEL_16xNx32 sbgemm_block_kernel_16xNx32_one
#define SBGEMM_BLOCKING_KERNEL_2 sbgemm_blocking_kernel_2_one
#endif
// SBGEMM Kernel for 16<M<=32, N=8, K can be any number, but the processing will take 32 as a base
#ifndef ONE_ALPHA // ALPHA is not ONE
void sbgemm_block_kernel_32x8x32_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#else // ALPHA is ONE
void sbgemm_block_kernel_32x8x32_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#endif
{
int SHUFFLE_MAGIC_NO = 0x39;
BLASLONG tag_k_32x = k & (~31);
BLASLONG idxA_base = 0;
BLASLONG idxB_base = 0;
BLASLONG width = 32;
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
__m512i arrayA_512_0, arrayA_512_1;
__m512i arrayB_512_0, arrayB_512_1, arrayB_512_2, arrayB_512_3, arrayB_512_4, arrayB_512_5, arrayB_512_6, arrayB_512_7;
__m512 result_512_0, result_512_1, result_512_2, result_512_3, result_512_4, result_512_5, result_512_6, result_512_7,
result_512_8, result_512_9, result_512_10, result_512_11, result_512_12, result_512_13, result_512_14, result_512_15;
__m512 result_512_tmp_0, result_512_tmp_1, result_512_tmp_2, result_512_tmp_3;
__m512i M512_EPI32_8 = _mm512_set1_epi32(8);
__m512i shuffle_idx_base0 = _mm512_set_epi32(23, 22, 21, 20, 7, 6, 5, 4, 19, 18, 17, 16, 3, 2, 1, 0);
__m512i shuffle_idx_base1 = _mm512_add_epi32(shuffle_idx_base0, M512_EPI32_8);
result_512_0 = _mm512_setzero_ps();
result_512_1 = _mm512_setzero_ps();
result_512_2 = _mm512_setzero_ps();
result_512_3 = _mm512_setzero_ps();
result_512_4 = _mm512_setzero_ps();
result_512_5 = _mm512_setzero_ps();
result_512_6 = _mm512_setzero_ps();
result_512_7 = _mm512_setzero_ps();
result_512_8 = _mm512_setzero_ps();
result_512_9 = _mm512_setzero_ps();
result_512_10 = _mm512_setzero_ps();
result_512_11 = _mm512_setzero_ps();
result_512_12 = _mm512_setzero_ps();
result_512_13 = _mm512_setzero_ps();
result_512_14 = _mm512_setzero_ps();
result_512_15 = _mm512_setzero_ps();
for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) {
// Load B with unroll 8
idxB_base = idx_k << 3;
arrayB_512_0 = _mm512_loadu_si512(&B[idxB_base + 32*0]);
arrayB_512_1 = _mm512_loadu_si512(&B[idxB_base + 32*1]);
arrayB_512_2 = _mm512_loadu_si512(&B[idxB_base + 32*2]);
arrayB_512_3 = _mm512_loadu_si512(&B[idxB_base + 32*3]);
arrayB_512_4 = _mm512_loadu_si512(&B[idxB_base + 32*4]);
arrayB_512_5 = _mm512_loadu_si512(&B[idxB_base + 32*5]);
arrayB_512_6 = _mm512_loadu_si512(&B[idxB_base + 32*6]);
arrayB_512_7 = _mm512_loadu_si512(&B[idxB_base + 32*7]);
if (idx_k == tag_k_32x) {width = k - tag_k_32x;}
for (BLASLONG idx = 0; idx < width;) {
// Each two rows are a group for 32-pair bf16 elements
idxA_base = idx << 5;
arrayA_512_0 = _mm512_loadu_si512(&A[idxA_base]);
arrayA_512_1 = _mm512_loadu_si512(&A[idxA_base + 32]);
result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_0)));
result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_1)));
result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_2)));
result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_3)));
result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_4)));
result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_5)));
result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_6)));
result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_7)));
result_512_8 = _mm512_dpbf16_ps(result_512_8, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_0)));
result_512_9 = _mm512_dpbf16_ps(result_512_9, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_1)));
result_512_10 = _mm512_dpbf16_ps(result_512_10, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_2)));
result_512_11 = _mm512_dpbf16_ps(result_512_11, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_3)));
result_512_12 = _mm512_dpbf16_ps(result_512_12, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_4)));
result_512_13 = _mm512_dpbf16_ps(result_512_13, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_5)));
result_512_14 = _mm512_dpbf16_ps(result_512_14, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_6)));
result_512_15 = _mm512_dpbf16_ps(result_512_15, (__m512bh) arrayA_512_1, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_7)));
arrayB_512_0 = _mm512_shuffle_epi32(arrayB_512_0, SHUFFLE_MAGIC_NO);
arrayB_512_1 = _mm512_shuffle_epi32(arrayB_512_1, SHUFFLE_MAGIC_NO);
arrayB_512_2 = _mm512_shuffle_epi32(arrayB_512_2, SHUFFLE_MAGIC_NO);
arrayB_512_3 = _mm512_shuffle_epi32(arrayB_512_3, SHUFFLE_MAGIC_NO);
arrayB_512_4 = _mm512_shuffle_epi32(arrayB_512_4, SHUFFLE_MAGIC_NO);
arrayB_512_5 = _mm512_shuffle_epi32(arrayB_512_5, SHUFFLE_MAGIC_NO);
arrayB_512_6 = _mm512_shuffle_epi32(arrayB_512_6, SHUFFLE_MAGIC_NO);
arrayB_512_7 = _mm512_shuffle_epi32(arrayB_512_7, SHUFFLE_MAGIC_NO);
idx += 2;
// Every 4 loops we need to switch to next 128 bits of arrayB registers
if ((idx & (~7)) == idx) {
arrayB_512_0 = _mm512_shuffle_i32x4(arrayB_512_0, arrayB_512_0, SHUFFLE_MAGIC_NO);
arrayB_512_1 = _mm512_shuffle_i32x4(arrayB_512_1, arrayB_512_1, SHUFFLE_MAGIC_NO);
arrayB_512_2 = _mm512_shuffle_i32x4(arrayB_512_2, arrayB_512_2, SHUFFLE_MAGIC_NO);
arrayB_512_3 = _mm512_shuffle_i32x4(arrayB_512_3, arrayB_512_3, SHUFFLE_MAGIC_NO);
arrayB_512_4 = _mm512_shuffle_i32x4(arrayB_512_4, arrayB_512_4, SHUFFLE_MAGIC_NO);
arrayB_512_5 = _mm512_shuffle_i32x4(arrayB_512_5, arrayB_512_5, SHUFFLE_MAGIC_NO);
arrayB_512_6 = _mm512_shuffle_i32x4(arrayB_512_6, arrayB_512_6, SHUFFLE_MAGIC_NO);
arrayB_512_7 = _mm512_shuffle_i32x4(arrayB_512_7, arrayB_512_7, SHUFFLE_MAGIC_NO);
}
}
}
if (m != 32) {
unsigned short tail_mask_value = (((unsigned short)0xffff) >> (32-m));
__mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base0, result_512_8);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base1, result_512_8);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base0, result_512_9);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base1, result_512_9);
STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*0]))
STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*0+16]), tail_mask)
STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*1]))
STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*1+16]), tail_mask)
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base0, result_512_10);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base1, result_512_10);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base0, result_512_11);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base1, result_512_11);
STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*2]))
STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*2+16]), tail_mask)
STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*3]))
STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*3+16]), tail_mask)
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base0, result_512_12);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base1, result_512_12);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base0, result_512_13);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base1, result_512_13);
STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*4]))
STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*4+16]), tail_mask)
STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*5]))
STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*5+16]), tail_mask)
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base0, result_512_14);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base1, result_512_14);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base0, result_512_15);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base1, result_512_15);
STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*6]))
STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*6+16]), tail_mask)
STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*7]))
STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*7+16]), tail_mask)
} else {
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base0, result_512_8);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base1, result_512_8);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base0, result_512_9);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base1, result_512_9);
STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*0]))
STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*0+16]))
STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*1]))
STORE16_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*1+16]))
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base0, result_512_10);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base1, result_512_10);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base0, result_512_11);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base1, result_512_11);
STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*2]))
STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*2+16]))
STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*3]))
STORE16_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*3+16]))
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base0, result_512_12);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base1, result_512_12);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base0, result_512_13);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base1, result_512_13);
STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*4]))
STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*4+16]))
STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*5]))
STORE16_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*5+16]))
result_512_tmp_0 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base0, result_512_14);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base1, result_512_14);
result_512_tmp_2 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base0, result_512_15);
result_512_tmp_3 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base1, result_512_15);
STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*6]))
STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*6+16]))
STORE16_COMPLETE_RESULT(result_512_tmp_2, (&C[ldc*7]))
STORE16_COMPLETE_RESULT(result_512_tmp_3, (&C[ldc*7+16]))
}
}
// SBGEMM Kernel for M<=16, N=8, K can be any number, but the processing will take 32 as a base
#ifndef ONE_ALPHA // ALPHA is not ONE
void sbgemm_block_kernel_16x8x32_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#else // ALPHA is ONE
void sbgemm_block_kernel_16x8x32_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#endif
{
int SHUFFLE_MAGIC_NO = 0x39;
BLASLONG tag_k_32x = k & (~31);
BLASLONG idxB_base = 0;
BLASLONG width = 32;
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
__m512i arrayA_512_0;
__m512i arrayB_512_0, arrayB_512_1, arrayB_512_2, arrayB_512_3, arrayB_512_4, arrayB_512_5, arrayB_512_6, arrayB_512_7;
__m512 result_512_0, result_512_1, result_512_2, result_512_3, result_512_4, result_512_5, result_512_6, result_512_7;
result_512_0 = _mm512_setzero_ps();
result_512_1 = _mm512_setzero_ps();
result_512_2 = _mm512_setzero_ps();
result_512_3 = _mm512_setzero_ps();
result_512_4 = _mm512_setzero_ps();
result_512_5 = _mm512_setzero_ps();
result_512_6 = _mm512_setzero_ps();
result_512_7 = _mm512_setzero_ps();
for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) {
// Load B with unroll 8
idxB_base = idx_k << 3;
arrayB_512_0 = _mm512_loadu_si512(&B[idxB_base + 32*0]);
arrayB_512_1 = _mm512_loadu_si512(&B[idxB_base + 32*1]);
arrayB_512_2 = _mm512_loadu_si512(&B[idxB_base + 32*2]);
arrayB_512_3 = _mm512_loadu_si512(&B[idxB_base + 32*3]);
arrayB_512_4 = _mm512_loadu_si512(&B[idxB_base + 32*4]);
arrayB_512_5 = _mm512_loadu_si512(&B[idxB_base + 32*5]);
arrayB_512_6 = _mm512_loadu_si512(&B[idxB_base + 32*6]);
arrayB_512_7 = _mm512_loadu_si512(&B[idxB_base + 32*7]);
if (idx_k == tag_k_32x) {width = k - tag_k_32x;}
for (BLASLONG idx = 0; idx < width;) {
// Each two rows are a group for 32-pair bf16 elements
// Load two rows into a 512 register
arrayA_512_0 = _mm512_loadu_si512(&A[idx<<4]);
result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_0)));
result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_1)));
result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_2)));
result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_3)));
result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_4)));
result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_5)));
result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_6)));
result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512_7)));
arrayB_512_0 = _mm512_shuffle_epi32(arrayB_512_0, SHUFFLE_MAGIC_NO);
arrayB_512_1 = _mm512_shuffle_epi32(arrayB_512_1, SHUFFLE_MAGIC_NO);
arrayB_512_2 = _mm512_shuffle_epi32(arrayB_512_2, SHUFFLE_MAGIC_NO);
arrayB_512_3 = _mm512_shuffle_epi32(arrayB_512_3, SHUFFLE_MAGIC_NO);
arrayB_512_4 = _mm512_shuffle_epi32(arrayB_512_4, SHUFFLE_MAGIC_NO);
arrayB_512_5 = _mm512_shuffle_epi32(arrayB_512_5, SHUFFLE_MAGIC_NO);
arrayB_512_6 = _mm512_shuffle_epi32(arrayB_512_6, SHUFFLE_MAGIC_NO);
arrayB_512_7 = _mm512_shuffle_epi32(arrayB_512_7, SHUFFLE_MAGIC_NO);
idx += 2;
// Every 4 loops we need to switch to next 128 bits of arrayB registers
if ((idx & (~7)) == idx) {
arrayB_512_0 = _mm512_shuffle_i32x4(arrayB_512_0, arrayB_512_0, SHUFFLE_MAGIC_NO);
arrayB_512_1 = _mm512_shuffle_i32x4(arrayB_512_1, arrayB_512_1, SHUFFLE_MAGIC_NO);
arrayB_512_2 = _mm512_shuffle_i32x4(arrayB_512_2, arrayB_512_2, SHUFFLE_MAGIC_NO);
arrayB_512_3 = _mm512_shuffle_i32x4(arrayB_512_3, arrayB_512_3, SHUFFLE_MAGIC_NO);
arrayB_512_4 = _mm512_shuffle_i32x4(arrayB_512_4, arrayB_512_4, SHUFFLE_MAGIC_NO);
arrayB_512_5 = _mm512_shuffle_i32x4(arrayB_512_5, arrayB_512_5, SHUFFLE_MAGIC_NO);
arrayB_512_6 = _mm512_shuffle_i32x4(arrayB_512_6, arrayB_512_6, SHUFFLE_MAGIC_NO);
arrayB_512_7 = _mm512_shuffle_i32x4(arrayB_512_7, arrayB_512_7, SHUFFLE_MAGIC_NO);
}
}
}
if (m != 16) {
unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m));
__mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
result_512_0 = _mm512_shuffle_f32x4(result_512_0, result_512_0, 0xd8);
result_512_1 = _mm512_shuffle_f32x4(result_512_1, result_512_1, 0xd8);
result_512_2 = _mm512_shuffle_f32x4(result_512_2, result_512_2, 0xd8);
result_512_3 = _mm512_shuffle_f32x4(result_512_3, result_512_3, 0xd8);
STORE16_MASK_COMPLETE_RESULT(result_512_0, (&C[ldc*0]), tail_mask)
STORE16_MASK_COMPLETE_RESULT(result_512_1, (&C[ldc*1]), tail_mask)
STORE16_MASK_COMPLETE_RESULT(result_512_2, (&C[ldc*2]), tail_mask)
STORE16_MASK_COMPLETE_RESULT(result_512_3, (&C[ldc*3]), tail_mask)
result_512_4 = _mm512_shuffle_f32x4(result_512_4, result_512_4, 0xd8);
result_512_5 = _mm512_shuffle_f32x4(result_512_5, result_512_5, 0xd8);
result_512_6 = _mm512_shuffle_f32x4(result_512_6, result_512_6, 0xd8);
result_512_7 = _mm512_shuffle_f32x4(result_512_7, result_512_7, 0xd8);
STORE16_MASK_COMPLETE_RESULT(result_512_4, (&C[ldc*4]), tail_mask)
STORE16_MASK_COMPLETE_RESULT(result_512_5, (&C[ldc*5]), tail_mask)
STORE16_MASK_COMPLETE_RESULT(result_512_6, (&C[ldc*6]), tail_mask)
STORE16_MASK_COMPLETE_RESULT(result_512_7, (&C[ldc*7]), tail_mask)
} else {
result_512_0 = _mm512_shuffle_f32x4(result_512_0, result_512_0, 0xd8);
result_512_1 = _mm512_shuffle_f32x4(result_512_1, result_512_1, 0xd8);
result_512_2 = _mm512_shuffle_f32x4(result_512_2, result_512_2, 0xd8);
result_512_3 = _mm512_shuffle_f32x4(result_512_3, result_512_3, 0xd8);
STORE16_COMPLETE_RESULT(result_512_0, (&C[ldc*0]))
STORE16_COMPLETE_RESULT(result_512_1, (&C[ldc*1]))
STORE16_COMPLETE_RESULT(result_512_2, (&C[ldc*2]))
STORE16_COMPLETE_RESULT(result_512_3, (&C[ldc*3]))
result_512_4 = _mm512_shuffle_f32x4(result_512_4, result_512_4, 0xd8);
result_512_5 = _mm512_shuffle_f32x4(result_512_5, result_512_5, 0xd8);
result_512_6 = _mm512_shuffle_f32x4(result_512_6, result_512_6, 0xd8);
result_512_7 = _mm512_shuffle_f32x4(result_512_7, result_512_7, 0xd8);
STORE16_COMPLETE_RESULT(result_512_4, (&C[ldc*4]))
STORE16_COMPLETE_RESULT(result_512_5, (&C[ldc*5]))
STORE16_COMPLETE_RESULT(result_512_6, (&C[ldc*6]))
STORE16_COMPLETE_RESULT(result_512_7, (&C[ldc*7]))
}
}
// SBGEMM kernel for 16<M<=32, N<8; K can be any number, but it is processed in chunks of 32
#ifndef ONE_ALPHA // ALPHA is not ONE
void sbgemm_block_kernel_32xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#else // ALPHA is ONE
void sbgemm_block_kernel_32xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#endif
{
int SHUFFLE_MAGIC_NO = 0x39;
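// 0x39 rotates the four 32-bit elements within each 128-bit lane by one position
// (and, when used with _mm512_shuffle_i32x4, rotates the four 128-bit lanes by one),
// stepping through the packed bf16 pairs of B one dword at a time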
BLASLONG tag_k_32x = k & (~31);
BLASLONG idxA_base = 0;
BLASLONG idxB_base = 0;
BLASLONG width = 32;
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
__m512i arrayA_512[2];
__m512i arrayB_512[8];
__m512 result_512[16];
__m512 result_512_tmp_0, result_512_tmp_1;
__m512i M512_EPI32_8 = _mm512_set1_epi32(8);
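// shuffle_idx_base0/1 pick alternating 4-float (128-bit) blocks from the two accumulator
// halves; they are used below to merge result_512[i] and result_512[i+8] into the
// 32 output rows of one C column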
__m512i shuffle_idx_base0 = _mm512_set_epi32(23, 22, 21, 20, 7, 6, 5, 4, 19, 18, 17, 16, 3, 2, 1, 0);
__m512i shuffle_idx_base1 = _mm512_add_epi32(shuffle_idx_base0, M512_EPI32_8);
for (int i = 0; i < 15; i += 2) {
result_512[i] = _mm512_setzero_ps();
result_512[i+1] = _mm512_setzero_ps();
}
for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) {
// Load the packed B panel for this 32-deep k block: one 512-bit register (32 bf16 values) per column, unrolled over the n columns
for (int i = 0; i < n; i ++) {
arrayB_512[i] = _mm512_loadu_si512(&B[idxB_base]);
idxB_base += 32;
}
if (idx_k == tag_k_32x) {width = k - tag_k_32x;}
for (BLASLONG idx = 0; idx < width;) {
// Every two k-rows of packed A form one group of bf16 pairs (one pair per row of the 32-row block), held in two 512-bit registers
idxA_base = idx << 5;
arrayA_512[0] = _mm512_loadu_si512(&A[idxA_base]);
arrayA_512[1] = _mm512_loadu_si512(&A[idxA_base + 32]);
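// For each column i: broadcast the current bf16 pair of B column i to all lanes and
// multiply-accumulate it against the 32 A rows held in the two A registers; the 32
// partial sums are split across result_512[i] and result_512[i+8]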
for (int i = 0; i < n; i++) {
result_512[i] = _mm512_dpbf16_ps(result_512[i] , (__m512bh) arrayA_512[0], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512[1], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
}
idx += 2;
// Every 4 iterations (idx reaches a multiple of 8) switch to the next 128 bits of the arrayB registers
if ((idx & (~7)) == idx) {
for (int i = 0; i < n; i++) {
arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
}
}
}
}
if (m != 32) {
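// 16 < m < 32 here: the low 16 rows are stored unmasked, and the tail mask
// (m-16 bits set) guards the store of the upper rows at &C[ldc*i+16]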
unsigned short tail_mask_value = (((unsigned short)0xffff) >> (32-m));
__mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
for (int i = 0; i < n; i++) {
result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*i]))
STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*i+16]), tail_mask)
}
} else {
for (int i = 0; i < n; i++) {
result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*i]))
STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*i+16]))
}
}
}
// SBGEMM kernel for M<=16, N<8; K can be any number, but it is processed in chunks of 32
#ifndef ONE_ALPHA // ALPHA is not ONE
void sbgemm_block_kernel_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#else // ALPHA is ONE
void sbgemm_block_kernel_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
#endif
{
int SHUFFLE_MAGIC_NO = 0x39;
BLASLONG tag_k_32x = k & (~31);
BLASLONG idxB_base = 0;
BLASLONG width = 32;
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
__m512i arrayA_512;
__m512i arrayB_512[8];
__m512 result_512[8];
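// one float accumulator per output column (at most 8); each holds the 16 partial sums
// for this 16-row block of C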
for (int i = 0; i < 8; i += 2) {
result_512[i] = _mm512_setzero_ps();
result_512[i+1] = _mm512_setzero_ps();
}
for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) {
// Load the packed B panel for this 32-deep k block: one 512-bit register (32 bf16 values) per column, unrolled over the n columns
for (int i = 0; i < n; i ++) {
arrayB_512[i] = _mm512_loadu_si512(&B[idxB_base]);
idxB_base += 32;
}
if (idx_k == tag_k_32x) {width = k - tag_k_32x;}
for (BLASLONG idx = 0; idx < width;) {
// Every two k-rows of packed A form one group of bf16 pairs (one pair per row of the 16-row block)
// Load both rows into a single 512-bit register
arrayA_512 = _mm512_loadu_si512(&A[idx<<4]);
for (int i = 0; i < n; i ++) {
result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
}
idx += 2;
// Every 4 iterations (idx reaches a multiple of 8) switch to the next 128 bits of the arrayB registers
if ((idx & (~7)) == idx) {
for (int i = 0; i < n; i++) {
arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
}
}
}
}
if (m != 16) {
unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m));
__mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
for (int i = 0; i < n; i++) {
result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
STORE16_MASK_COMPLETE_RESULT(result_512[i], (&C[ldc*i]), tail_mask)
}
} else {
for (int i = 0; i < n; i++) {
result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
STORE16_COMPLETE_RESULT(result_512[i], (&C[ldc*i]))
}
}
}
#ifndef ONE_ALPHA // ALPHA is not ONE
void sbgemm_blocking_kernel_2_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
#else // ALPHA is ONE
void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
#endif
{
BLASLONG m_step, n_step, k_step, k_step_round32;
BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1));
BLASLONG n_from, n_to;
BLASLONG tag_n_Nx;
n_from = 0;
n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N;
tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K;
k_step_round32 = k_step & (~31);
k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
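// Blocking scheme: N is walked in panels of BF16_BLOCK_THRES_N columns, K in blocks of
// BF16_BLOCK_THRES_K and M in blocks of BF16_BLOCK_THRES_M rows; k_step_round32 (k_step
// rounded up to the next multiple of 32) is the per-column stride inside the packed block_B buffer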
if (M >= BF16_BLOCK_THRES_M) {
while (n_from < N) {
for (BLASLONG idx_k = 0; idx_k < K;) {
// Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, &A(idx_k, 0), lda, block_A);
// TODO: MT
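// block_B is packed here on the first pass over the M blocks and reused, without
// repacking, by the idx_m loop further down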
for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
// Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
SBGEMM_BLOCK_KERNEL_32x8x32(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
}
if (tag_n_Nx != n_to) {
n_step = n_to - tag_n_Nx;
COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
SBGEMM_BLOCK_KERNEL_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
}
for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) {
COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, &A(idx_k, idx_m), lda, block_A);
for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
SBGEMM_BLOCK_KERNEL_32x8x32(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
}
if (tag_n_Nx != n_to) {
n_step = n_to - tag_n_Nx;
SBGEMM_BLOCK_KERNEL_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
}
}
if (tag_m_Nx != M) {
m_step = M - tag_m_Nx;
if (m_step > 16) {
COL_MAJOR_INCOPY_KERNEL_Kx32m(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
SBGEMM_BLOCK_KERNEL_32x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
}
if (tag_n_Nx != n_to) {
n_step = n_to - tag_n_Nx;
SBGEMM_BLOCK_KERNEL_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
}
} else if (m_step == 16) {
COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
SBGEMM_BLOCK_KERNEL_16x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
}
if (tag_n_Nx != n_to) {
n_step = n_to - tag_n_Nx;
SBGEMM_BLOCK_KERNEL_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
}
} else {
COL_MAJOR_INCOPY_KERNEL_Kx16m(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
SBGEMM_BLOCK_KERNEL_16x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
}
if (tag_n_Nx != n_to) {
n_step = n_to - tag_n_Nx;
SBGEMM_BLOCK_KERNEL_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
}
}
}
idx_k += k_step;
k_step = K - idx_k;
k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
k_step_round32 = k_step & (~31);
k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
}
n_from = n_to;
n_to += BF16_BLOCK_THRES_N;
n_to = (n_to > N) ? N : n_to;
tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
}
} else {
m_step = M - tag_m_Nx;
while (n_from < N) {
for (BLASLONG idx_k = 0; idx_k < K;) {
// Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
COL_MAJOR_INCOPY_KERNEL_Kx32m(k_step, m_step, &A(idx_k, 0), lda, block_A);
// TODO: MT
for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
// Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
SBGEMM_BLOCK_KERNEL_32x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
}
if (tag_n_Nx != n_to) {
n_step = n_to - tag_n_Nx;
COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
SBGEMM_BLOCK_KERNEL_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
}
idx_k += k_step;
k_step = K - idx_k;
k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
k_step_round32 = k_step & (~31);
k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
}
n_from = n_to;
n_to += BF16_BLOCK_THRES_N;
n_to = (n_to > N) ? N : n_to;
tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
}
}
}
#ifndef ONE_ALPHA // ALPHA is not ONE
void sbgemm_internal_kernel_alpha(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, float *C, OPENBLAS_CONST blasint ldc)
#else // ALPHA is ONE
void sbgemm_internal_kernel_one(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, float *C, OPENBLAS_CONST blasint ldc)
#endif
{
bfloat16 block_A[BF16_BLOCK_THRES_K * BF16_BLOCK_THRES_M];
bfloat16 block_B[BF16_BLOCK_THRES_N * BF16_BLOCK_THRES_K];
// TODO: currently assumes neither A nor B is transposed; support for the transposed cases is still to be added
if (Order == CblasColMajor) {
SBGEMM_BLOCKING_KERNEL_2(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
} else {
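// Row-major (CblasRowMajor) input is not handled yet; only the column-major,
// non-transposed path is implemented (see the TODO above)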
}
}