Implementation of BF16-based gemv

1. Add a new API -- sbgemv -- to support bfloat16-based gemv (see the usage sketch below)
2. Implement a generic kernel for sbgemv
3. Implement an AVX512-BF16-based kernel for sbgemv
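For reference, a minimal usage sketch of the new CBLAS entry point (illustrative only; it assumes the inputs were already converted to bfloat16, e.g. with cblas_sbstobf16):

    #include <cblas.h>

    /* y = alpha * A * x + beta * y, with A (m x n) and x stored as bfloat16,
       and y accumulated/stored as float. */
    void sbgemv_example(blasint m, blasint n,
                        const bfloat16 *A, const bfloat16 *x, float *y)
    {
        cblas_sbgemv(CblasColMajor, CblasNoTrans, m, n,
                     1.0f,        /* alpha */
                     A, m,        /* lda >= max(1, m) for column-major */
                     x, 1,        /* incx */
                     0.0f,        /* beta */
                     y, 1);       /* incy */
    }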

Signed-off-by: Chen, Guobing <guobing.chen@intel.com>
Chen, Guobing 2020-10-28 08:49:12 +08:00
parent 67f39ad813
commit a7b1f9b1bb
24 changed files with 5111 additions and 16 deletions


@@ -393,6 +393,7 @@ void cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPE
void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout);
/* dot product of BFLOAT16 input arrays, with output as float */
float cblas_sbdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum CBLAS_TRANSPOSE trans, OPENBLAS_CONST blasint m, OPENBLAS_CONST blasint n, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *a, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float beta, float *y, OPENBLAS_CONST blasint incy);
#ifdef __cplusplus
}


@@ -184,8 +184,8 @@ macro(SetDefaultL2)
set(XHEMV_V_KERNEL ../generic/zhemv_k.c)
set(XHEMV_M_KERNEL ../generic/zhemv_k.c)
if (BUILD_BFLOAT16)
set(SBGEMVNKERNEL ../arm/gemv_n.c)
set(SBGEMVTKERNEL ../arm/gemv_t.c)
set(SBGEMVNKERNEL ../x86_64/sbgemv_n.c)
set(SBGEMVTKERNEL ../x86_64/sbgemv_t.c)
set(SHGERKERNEL ../generic/ger.c)
endif ()
endmacro ()


@@ -250,6 +250,8 @@ void BLASFUNC(xgeru)(blasint *, blasint *, xdouble *, xdouble *, blasint *,
void BLASFUNC(xgerc)(blasint *, blasint *, xdouble *, xdouble *, blasint *,
xdouble *, blasint *, xdouble *, blasint *);
void BLASFUNC(sbgemv)(char *, blasint *, blasint *, float *, bfloat16 *, blasint *,
bfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sgemv)(char *, blasint *, blasint *, float *, float *, blasint *,
float *, blasint *, float *, float *, blasint *);
void BLASFUNC(dgemv)(char *, blasint *, blasint *, double *, double *, blasint *,


@@ -44,6 +44,10 @@
extern "C" {
#endif
int sbgemv_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
int sbgemv_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
int sbgemv_thread_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int);
int sbgemv_thread_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int);
int sger_k (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int dger_k (BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, double *, BLASLONG, double *);
int qger_k (BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *);


@@ -646,10 +646,12 @@
#elif defined(BFLOAT16)
#define D_TO_BF16_K SBDTOBF16_K
#define D_BF16_TO_K DBF16TOD_K
#define S_TO_BF16_K SBSTOBF16_K
#define S_BF16_TO_K SBF16TOS_K
#define D_TO_BF16_K SBDTOBF16_K
#define D_BF16_TO_K DBF16TOD_K
#define S_TO_BF16_K SBSTOBF16_K
#define S_BF16_TO_K SBF16TOS_K
#define SBGEMV_N SBGEMV_N_K
#define SBGEMV_T SBGEMV_T_K
#define AMAX_K SAMAX_K
#define AMIN_K SAMIN_K


@@ -78,8 +78,8 @@ BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);
int (*sbscal_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*sbswap_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*sbgemv_n) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*sbgemv_t) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*sbgemv_n) (BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
int (*sbgemv_t) (BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
int (*sbger_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*sbsymv_L) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);


@@ -8,6 +8,8 @@
#define SBDTOBF16_K sbdtobf16_k
#define SBF16TOS_K sbf16tos_k
#define DBF16TOD_K dbf16tod_k
#define SBGEMV_N_K sbgemv_n
#define SBGEMV_T_K sbgemv_t
#define SBGEMM_ONCOPY sbgemm_oncopy
#define SBGEMM_OTCOPY sbgemm_otcopy
@@ -29,6 +31,8 @@
#define SBDTOBF16_K gotoblas -> sbdtobf16_k
#define SBF16TOS_K gotoblas -> sbf16tos_k
#define DBF16TOD_K gotoblas -> dbf16tod_k
#define SBGEMV_N_K gotoblas -> sbgemv_n
#define SBGEMV_T_K gotoblas -> sbgemv_t
#define SBGEMM_ONCOPY gotoblas -> sbgemm_oncopy
#define SBGEMM_OTCOPY gotoblas -> sbgemm_otcopy


@@ -413,7 +413,13 @@ XBLASOBJS += \
xtbmv_thread_RUU.$(SUFFIX) xtbmv_thread_RUN.$(SUFFIX) \
xtbmv_thread_RLU.$(SUFFIX) xtbmv_thread_RLN.$(SUFFIX) \
xtbmv_thread_CUU.$(SUFFIX) xtbmv_thread_CUN.$(SUFFIX) \
xtbmv_thread_CLU.$(SUFFIX) xtbmv_thread_CLN.$(SUFFIX) \
xtbmv_thread_CLU.$(SUFFIX) xtbmv_thread_CLN.$(SUFFIX)
ifeq ($(BUILD_BFLOAT16),1)
SBBLASOBJS += \
sbgemv_thread_n$(TSUFFIX).$(SUFFIX) \
sbgemv_thread_t$(TSUFFIX).$(SUFFIX)
endif
endif
@@ -3693,4 +3699,12 @@ xtrsv_CUU.$(SUFFIX) xtrsv_CUU.$(PSUFFIX) : ztrsv_L.c ../../param.h
xtrsv_CUN.$(SUFFIX) xtrsv_CUN.$(PSUFFIX) : ztrsv_L.c ../../param.h
$(CC) -c $(CFLAGS) -DXDOUBLE -DCOMPLEX -DTRANSA=4 -UUNIT $< -o $(@F)
ifeq ($(BUILD_BFLOAT16),1)
sbgemv_thread_n.$(SUFFIX) sbgemv_thread_n.$(PSUFFIX) : sbgemv_thread.c ../../common.h
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE -UTRANSA -UCONJ -UXCONJ $< -o $(@F)
sbgemv_thread_t.$(SUFFIX) sbgemv_thread_t.$(PSUFFIX) : sbgemv_thread.c ../../common.h
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE -DTRANSA -UCONJ -UXCONJ $< -o $(@F)
endif
include ../../Makefile.tail


@@ -0,0 +1,149 @@
/*********************************************************************/
/* Copyright 2009, 2010 The University of Texas at Austin. */
/* All rights reserved. */
/* */
/* Redistribution and use in source and binary forms, with or */
/* without modification, are permitted provided that the following */
/* conditions are met: */
/* */
/* 1. Redistributions of source code must retain the above */
/* copyright notice, this list of conditions and the following */
/* disclaimer. */
/* */
/* 2. Redistributions in binary form must reproduce the above */
/* copyright notice, this list of conditions and the following */
/* disclaimer in the documentation and/or other materials */
/* provided with the distribution. */
/* */
/* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */
/* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */
/* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */
/* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */
/* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */
/* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */
/* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */
/* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */
/* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */
/* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */
/* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */
/* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */
/* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */
/* POSSIBILITY OF SUCH DAMAGE. */
/* */
/* The views and conclusions contained in the software and */
/* documentation are those of the authors and should not be */
/* interpreted as representing official policies, either expressed */
/* or implied, of The University of Texas at Austin. */
/*********************************************************************/
#include <stdio.h>
#include <stdlib.h>
#include "common.h"
#ifndef TRANSA
#define SBGEMV SBGEMV_N
#else
#define SBGEMV SBGEMV_T
#endif
static int sbgemv_kernel(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *dummy1, FLOAT *dummy2, BLASLONG dummy3){
bfloat16 *a, *x;
float *y;
BLASLONG lda, incx, incy;
BLASLONG m_from, m_to, n_from, n_to;
a = (bfloat16 *)args->a;
x = (bfloat16 *)args->b;
y = (float *)args->c;
lda = args->lda;
incx = args->ldb;
incy = args->ldc;
#ifndef TRANSA // N
m_from = *(range_m + 0);
m_to = *(range_m + 1);
n_from = 0;
n_to = args -> n;
a += m_from;
y += m_from * incy;
#else // T
m_from = 0;
m_to = args->m;
n_from = *(range_n + 0);
n_to = *(range_n + 1);
a += n_from * lda;
y += n_from * incy;
#endif
SBGEMV(m_to - m_from, n_to - n_from, *((FLOAT *)(args->alpha)), a, lda, x, incx, *((FLOAT *)(args->beta)), y, incy);
return 0;
}
int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy, int threads)
{
blas_arg_t args;
blas_queue_t queue[MAX_CPU_NUMBER];
BLASLONG range[MAX_CPU_NUMBER + 1];
#ifndef TRANSA
BLASLONG width_for_split = m;
#else
BLASLONG width_for_split = n;
#endif
BLASLONG BLOCK_WIDTH = width_for_split/threads;
int mode = BLAS_BFLOAT16 | BLAS_REAL;
args.m = m;
args.n = n;
args.a = (void *)a;
args.b = (void *)x;
args.c = (void *)y;
args.lda = lda;
args.ldb = incx;
args.ldc = incy;
args.alpha = (void *)&alpha;
args.beta = (void *)&beta;
range[0] = 0;
int thread_idx;
for (thread_idx=0; thread_idx<threads; thread_idx++) {
if (thread_idx != threads-1) {
range[thread_idx + 1] = range[thread_idx] + BLOCK_WIDTH;
} else {
range[thread_idx + 1] = range[thread_idx] + width_for_split;
}
queue[thread_idx].mode = mode;
queue[thread_idx].routine = sbgemv_kernel;
queue[thread_idx].args = &args;
#ifndef TRANSA
queue[thread_idx].range_m = &range[thread_idx];
queue[thread_idx].range_n = NULL;
#else
queue[thread_idx].range_m = NULL;
queue[thread_idx].range_n = &range[thread_idx];
#endif
queue[thread_idx].sa = NULL;
queue[thread_idx].sb = NULL;
queue[thread_idx].next = &queue[thread_idx + 1];
width_for_split -= BLOCK_WIDTH;
}
if (thread_idx) {
queue[0].sa = NULL;
queue[0].sb = NULL;
queue[thread_idx - 1].next = NULL;
exec_blas(thread_idx, queue);
}
return 0;
}
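The range[] split above gives every thread floor(width/threads) rows (columns for the transposed case) and folds the division remainder into the last thread. A standalone scalar sketch of the same arithmetic (illustrative, not part of the commit):

    /* E.g. width = 10, threads = 3 yields ranges {0, 3, 6, 10}. */
    #include <stdio.h>

    int main(void)
    {
        long width = 10, threads = 3;
        long range[4], block = width / threads, left = width;
        range[0] = 0;
        for (long t = 0; t < threads; t++) {
            range[t + 1] = range[t] + (t == threads - 1 ? left : block);
            left -= block;
        }
        for (long t = 0; t <= threads; t++)
            printf("%ld ", range[t]);   /* prints: 0 3 6 10 */
        return 0;
    }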


@@ -352,7 +352,6 @@ fprintf(stderr,"UNHANDLED COMPLEX\n");
/* Other types in future */
}
}
if (!sb) fprintf(stderr,"SB not declared!!!\n");
queue->sb=sb;
}
}


@@ -51,7 +51,7 @@
zgeadd, dzsum);
@blasobjs = (lsame, xerbla);
@bfblasobjs = (sbgemm, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod);
@bfblasobjs = (sbgemm, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod);
@cblasobjsc = (
cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv,
cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k,
@@ -94,7 +94,7 @@
@cblasobjs = ( cblas_xerbla );
@bfcblasobjs = (cblas_sbgemm, cblas_sbdot, cblas_sbstobf16, cblas_sbdtobf16, cblas_sbf16tos, cblas_dbf16tod);
@bfcblasobjs = (cblas_sbgemm, cblas_sbgemv, cblas_sbdot, cblas_sbstobf16, cblas_sbdtobf16, cblas_sbf16tos, cblas_dbf16tod);
@exblasobjs = (
qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm,


@@ -48,6 +48,7 @@ SBLAS3OBJS = \
ifeq ($(BUILD_BFLOAT16),1)
SBBLAS1OBJS = sbdot.$(SUFFIX)
SBBLAS2OBJS = sbgemv.$(SUFFIX)
SBBLAS3OBJS = sbgemm.$(SUFFIX)
SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX)
endif
@@ -284,6 +285,7 @@ CSBLAS3OBJS = \
ifeq ($(BUILD_BFLOAT16),1)
CSBBLAS1OBJS = cblas_sbdot.$(SUFFIX)
CSBBLAS2OBJS = cblas_sbgemv.$(SUFFIX)
CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX)
CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX)
endif
@@ -382,6 +384,7 @@ SBLAS1OBJS += $(CSBLAS1OBJS)
SBLAS2OBJS += $(CSBLAS2OBJS)
SBLAS3OBJS += $(CSBLAS3OBJS)
SBBLAS1OBJS += $(CSBBLAS1OBJS)
SBBLAS2OBJS += $(CSBBLAS2OBJS)
SBBLAS3OBJS += $(CSBBLAS3OBJS)
DBLAS1OBJS += $(CDBLAS1OBJS)
DBLAS2OBJS += $(CDBLAS2OBJS)
@@ -399,7 +402,7 @@ CBAUXOBJS += $(CXERBLAOBJ)
endif
SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS)
SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS3OBJS)
SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS)
DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS)
QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS)
CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS)
@@ -538,7 +541,7 @@ clean ::
level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS)
@@ -929,6 +932,11 @@ xgeru.$(SUFFIX) xgeru.$(PSUFFIX) : zger.c
xgerc.$(SUFFIX) xgerc.$(PSUFFIX) : zger.c
$(CC) -c $(CFLAGS) -DCONJ $< -o $(@F)
ifeq ($(BUILD_BFLOAT16),1)
sbgemv.$(SUFFIX) sbgemv.$(PSUFFIX) : sbgemv.c
$(CC) $(CFLAGS) -c $< -o $(@F)
endif
ifndef USE_NETLIB_GEMV
sgemv.$(SUFFIX) sgemv.$(PSUFFIX): gemv.c
$(CC) -c $(CFLAGS) -o $(@F) $<
@@ -1656,6 +1664,11 @@ cblas_csscal.$(SUFFIX) cblas_csscal.$(PSUFFIX) : zscal.c
cblas_zdscal.$(SUFFIX) cblas_zdscal.$(PSUFFIX) : zscal.c
$(CC) $(CFLAGS) -DCBLAS -c -DSSCAL $< -o $(@F)
ifeq ($(BUILD_BFLOAT16),1)
cblas_sbgemv.$(SUFFIX) cblas_sbgemv.$(PSUFFIX) : sbgemv.c
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
endif
cblas_sgemv.$(SUFFIX) cblas_sgemv.$(PSUFFIX): gemv.c
$(CC) -DCBLAS -c $(CFLAGS) -o $(@F) $<


@@ -191,7 +191,6 @@ void CNAME(enum CBLAS_ORDER order,
}
#endif
//printf("m=%d, n=%d, trans=%d, incx=%d, incy=%d, alpha=%f, beta=%f\n", m, n, trans, incx, incy, alpha, beta);
if ((m==0) || (n==0)) return;
lenx = n;

interface/sbgemv.c (new file, 210 lines)

@@ -0,0 +1,210 @@
/*********************************************************************/
/* Copyright 2009, 2010 The University of Texas at Austin. */
/* All rights reserved. */
/* */
/* Redistribution and use in source and binary forms, with or */
/* without modification, are permitted provided that the following */
/* conditions are met: */
/* */
/* 1. Redistributions of source code must retain the above */
/* copyright notice, this list of conditions and the following */
/* disclaimer. */
/* */
/* 2. Redistributions in binary form must reproduce the above */
/* copyright notice, this list of conditions and the following */
/* disclaimer in the documentation and/or other materials */
/* provided with the distribution. */
/* */
/* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */
/* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */
/* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */
/* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */
/* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */
/* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */
/* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */
/* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */
/* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */
/* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */
/* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */
/* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */
/* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */
/* POSSIBILITY OF SUCH DAMAGE. */
/* */
/* The views and conclusions contained in the software and */
/* documentation are those of the authors and should not be */
/* interpreted as representing official policies, either expressed */
/* or implied, of The University of Texas at Austin. */
/*********************************************************************/
#include <stdio.h>
#include "common.h"
#include "l1param.h"
#ifdef FUNCTION_PROFILE
#include "functable.h"
#endif
#define ERROR_NAME "SBGEMV "
#ifdef SMP
static int (*sbgemv_thread[])(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 * , BLASLONG, float, float *, BLASLONG, int) = {
sbgemv_thread_n, sbgemv_thread_t,
};
#endif
#ifndef CBLAS
void NAME(char *TRANS, blasint *M, blasint *N, float *ALPHA, bfloat16 *a, blasint *LDA, bfloat16 *x, blasint *INCX, float *BETA, float *y, blasint *INCY)
{
char trans = *TRANS;
blasint m = *M;
blasint n = *N;
blasint lda = *LDA;
blasint incx = *INCX;
blasint incy = *INCY;
float alpha = *ALPHA;
float beta = *BETA;
#ifdef SMP
int nthreads;
#endif
int (*sbgemv[])(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 * , BLASLONG, float, float *, BLASLONG) = {
SBGEMV_N, SBGEMV_T,
};
blasint info;
blasint lenx, leny;
blasint i;
PRINT_DEBUG_NAME;
TOUPPER(trans);
info = 0;
i = -1;
if (trans == 'N') {i = 0;}
if (trans == 'T') {i = 1;}
if (trans == 'R') {i = 0;}
if (trans == 'C') {i = 1;}
if (incy == 0) {info = 11;}
if (incx == 0) {info = 8;}
if (lda < MAX(1, m)) {info = 6;}
if (n < 0) {info = 3;}
if (m < 0) {info = 2;}
if (i < 0) {info = 1;}
trans = i;
if (info != 0) {
BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
return;
}
#else
void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, blasint m, blasint n, float alpha, bfloat16 *a, blasint lda, bfloat16 *x, blasint incx, float beta, float *y, blasint incy)
{
blasint lenx, leny;
int trans;
blasint info, t;
#ifdef SMP
int nthreads;
#endif
int (*sbgemv[])(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 * , BLASLONG, float, float *, BLASLONG) = {
SBGEMV_N, SBGEMV_T,
};
PRINT_DEBUG_CNAME;
trans = -1;
info = 0;
if (order == CblasColMajor) { // Column Major
if (TransA == CblasNoTrans || TransA == CblasConjNoTrans) {
trans = 0;
} else if (TransA == CblasTrans || TransA == CblasConjTrans) {
trans = 1;
}
} else { // Row Major
if (TransA == CblasNoTrans || TransA == CblasConjNoTrans) {
trans = 1;
} else if (TransA == CblasTrans || TransA == CblasConjTrans) {
trans = 0;
}
t = n;
n = m;
m = t;
}
info = -1;
if (incy == 0) {info = 11;}
if (incx == 0) {info = 8;}
if (lda < MAX(1, m)) {info = 6;}
if (n < 0) {info = 3;}
if (m < 0) {info = 2;}
if (trans < 0) {info = 1;}
if (info >= 0) {
BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
return;
}
#endif
if ((m==0) || (n==0)) return;
if (trans) {
lenx = m;
leny = n;
} else {
lenx = n;
leny = m;
}
if (alpha == ZERO) {
if (beta != ONE) SCAL_K(leny, 0, 0, beta, y, blasabs(incy), NULL, 0, NULL, 0);
return;
}
IDEBUG_START;
FUNCTION_PROFILE_START();
if (incx < 0) {x -= (lenx - 1) * incx;}
if (incy < 0) {y -= (leny - 1) * incy;}
#ifdef SMP
int thread_thres_row = 20480;
if (trans) {
if (n <= thread_thres_row) {
nthreads = 1;
} else {
nthreads = num_cpu_avail(1);
}
} else {
if (m <= thread_thres_row) {
nthreads = 1;
} else {
nthreads = num_cpu_avail(1);
}
}
if (nthreads == 1) {
#endif
(sbgemv[(int)trans])(m, n, alpha, a, lda, x, incx, beta, y, incy);
#ifdef SMP
} else {
(sbgemv_thread[(int)trans])(m, n, alpha, a, lda, x, incx, beta, y, incy, nthreads);
}
#endif
FUNCTION_PROFILE_END(1, m * n + m + n, 2 * m * n);
IDEBUG_END;
return;
}
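Note how the CBLAS wrapper above maps row-major calls onto the column-major kernels: the transpose flag is flipped and m/n are swapped, because a row-major m x n matrix occupies memory exactly like a column-major n x m matrix. An equivalence sketch (illustrative; both calls use the same buffer A with row stride n):

    /* These two calls compute the same y = alpha*A*x + beta*y. */
    cblas_sbgemv(CblasRowMajor, CblasNoTrans, m, n, alpha, A, n, x, 1, beta, y, 1);
    cblas_sbgemv(CblasColMajor, CblasTrans,   n, m, alpha, A, n, x, 1, beta, y, 1);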


@@ -48,6 +48,16 @@ ifndef XGEMVTKERNEL
XGEMVTKERNEL = zgemv_t.S
endif
ifeq ($(BUILD_BFLOAT16),1)
ifndef SBGEMVNKERNEL
SBGEMVNKERNEL = ../x86_64/sbgemv_n.c
endif
ifndef SBGEMVTKERNEL
SBGEMVTKERNEL = ../x86_64/sbgemv_t.c
endif
endif
### GER ###
ifndef SGERKERNEL
@@ -234,6 +244,12 @@ XBLASOBJS += \
xhemv_U$(TSUFFIX).$(SUFFIX) xhemv_L$(TSUFFIX).$(SUFFIX) xhemv_V$(TSUFFIX).$(SUFFIX) xhemv_M$(TSUFFIX).$(SUFFIX) \
xgeru_k$(TSUFFIX).$(SUFFIX) xgerc_k$(TSUFFIX).$(SUFFIX) xgerv_k$(TSUFFIX).$(SUFFIX) xgerd_k$(TSUFFIX).$(SUFFIX)
ifeq ($(BUILD_BFLOAT16),1)
SBBLASOBJS += \
sbgemv_n$(TSUFFIX).$(SUFFIX) \
sbgemv_t$(TSUFFIX).$(SUFFIX)
endif
ifneq "$(or $(BUILD_SINGLE), $(BUILD_DOUBLE), $(BUILD_COMPLEX))" ""
$(KDIR)sgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)sgemv_n$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMVNKERNEL) $(TOPDIR)/common.h $(GEMVDEP)
$(CC) -c $(CFLAGS) -UDOUBLE -UCOMPLEX -UTRANS $< -o $@
@@ -483,4 +499,10 @@ $(KDIR)xhemv_V$(TSUFFIX).$(SUFFIX) $(KDIR)xhemv_V$(TSUFFIX).$(PSUFFIX) : $(KER
$(KDIR)xhemv_M$(TSUFFIX).$(SUFFIX) $(KDIR)xhemv_M$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(XHEMV_M_KERNEL) ../symcopy.h
$(CC) -c $(CFLAGS) -DCOMPLEX -DXDOUBLE -DLOWER -DHEMV -DHEMVREV $< -o $@
ifeq ($(BUILD_BFLOAT16),1)
$(KDIR)sbgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)sbgemv_n$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMVNKERNEL)
$(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@
$(KDIR)sbgemv_t$(TSUFFIX).$(SUFFIX) $(KDIR)sbgemv_t$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMVTKERNEL)
$(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@
endif


@@ -69,7 +69,7 @@ gotoblas_t TABLE_NAME = {
snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, sbdot_kTS,
dsdot_kTS,
srot_kTS, saxpy_kTS, sscal_kTS, sswap_kTS,
sgemv_nTS, sgemv_tTS, sger_kTS,
sbgemv_nTS, sbgemv_tTS, sger_kTS,
ssymv_LTS, ssymv_UTS,
sbgemm_kernelTS, sbgemm_betaTS,


@@ -384,6 +384,14 @@ endif
GEMVDEP = ../l2param.h
ifndef SBGEMVNKERNEL
SBGEMVNKERNEL = sbgemv_n.c
endif
ifndef SBGEMVTKERNEL
SBGEMVTKERNEL = sbgemv_t.c
endif
ifndef SGEMVNKERNEL
SGEMVNKERNEL = sgemv_n.c
endif


@@ -0,0 +1,795 @@
/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#ifndef __BF16_COMMON_MACROS
#define __BF16_COMMON_MACROS
#include <immintrin.h>
#define EXTRACT_LOW_256_FROM_512_2X(reg256, reg512) \
reg256##_0 = _mm512_castps512_ps256(reg512##_0); \
reg256##_1 = _mm512_castps512_ps256(reg512##_1);
#define BF16_MATRIX_LOAD_8x32(regArray, a, lda, idx_m, idx_n) \
regArray##_0 = _mm512_loadu_si512(&a[(idx_m+0)*lda + idx_n]); \
regArray##_1 = _mm512_loadu_si512(&a[(idx_m+1)*lda + idx_n]); \
regArray##_2 = _mm512_loadu_si512(&a[(idx_m+2)*lda + idx_n]); \
regArray##_3 = _mm512_loadu_si512(&a[(idx_m+3)*lda + idx_n]); \
regArray##_4 = _mm512_loadu_si512(&a[(idx_m+4)*lda + idx_n]); \
regArray##_5 = _mm512_loadu_si512(&a[(idx_m+5)*lda + idx_n]); \
regArray##_6 = _mm512_loadu_si512(&a[(idx_m+6)*lda + idx_n]); \
regArray##_7 = _mm512_loadu_si512(&a[(idx_m+7)*lda + idx_n]);
#define BF16_MATRIX_LOAD_8x16(regArray, a, lda, idx_m, idx_n) \
regArray##_0 = _mm256_loadu_si256(&a[(idx_m+0)*lda + idx_n]); \
regArray##_1 = _mm256_loadu_si256(&a[(idx_m+1)*lda + idx_n]); \
regArray##_2 = _mm256_loadu_si256(&a[(idx_m+2)*lda + idx_n]); \
regArray##_3 = _mm256_loadu_si256(&a[(idx_m+3)*lda + idx_n]); \
regArray##_4 = _mm256_loadu_si256(&a[(idx_m+4)*lda + idx_n]); \
regArray##_5 = _mm256_loadu_si256(&a[(idx_m+5)*lda + idx_n]); \
regArray##_6 = _mm256_loadu_si256(&a[(idx_m+6)*lda + idx_n]); \
regArray##_7 = _mm256_loadu_si256(&a[(idx_m+7)*lda + idx_n]);
#define BF16_MATRIX_LOAD_8x8(regArray, a, lda, idx_m, idx_n) \
regArray##_0 = _mm_loadu_si128(&a[(idx_m+0)*lda + idx_n]); \
regArray##_1 = _mm_loadu_si128(&a[(idx_m+1)*lda + idx_n]); \
regArray##_2 = _mm_loadu_si128(&a[(idx_m+2)*lda + idx_n]); \
regArray##_3 = _mm_loadu_si128(&a[(idx_m+3)*lda + idx_n]); \
regArray##_4 = _mm_loadu_si128(&a[(idx_m+4)*lda + idx_n]); \
regArray##_5 = _mm_loadu_si128(&a[(idx_m+5)*lda + idx_n]); \
regArray##_6 = _mm_loadu_si128(&a[(idx_m+6)*lda + idx_n]); \
regArray##_7 = _mm_loadu_si128(&a[(idx_m+7)*lda + idx_n]);
#define BF16_MATRIX_LOAD_1x32(regArray, a, lda, idx_m, idx_n) \
regArray = _mm512_loadu_si512(&a[idx_m*lda + idx_n]);
#define BF16_MATRIX_MASKZ_LOAD_8x32(regArray, a, lda, idx_m, idx_n, mask) \
regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
regArray##_4 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
regArray##_5 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
regArray##_6 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
regArray##_7 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);
#define BF16_MATRIX_MASKZ_LOAD_8x16(regArray, a, lda, idx_m, idx_n, mask) \
regArray##_0 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
regArray##_1 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
regArray##_2 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
regArray##_3 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
regArray##_4 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
regArray##_5 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
regArray##_6 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
regArray##_7 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);
#define BF16_MATRIX_MASKZ_LOAD_8x8(regArray, a, lda, idx_m, idx_n, mask) \
regArray##_0 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
regArray##_1 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
regArray##_2 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
regArray##_3 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
regArray##_4 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
regArray##_5 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
regArray##_6 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
regArray##_7 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);
#define BF16_MATRIX_MASKZ_LOAD_4x32(regArray, a, lda, idx_m, idx_n, mask) \
regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]);
#define BF16_MATRIX_MASKZ_LOAD_4x16(regArray, a, lda, idx_m, idx_n, mask) \
regArray##_0 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
regArray##_1 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
regArray##_2 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
regArray##_3 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]);
#define BF16_MATRIX_MASKZ_LOAD_8x32_2(regArray, a, lda, idx_m, idx_n, mask) \
regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
regArray##_4 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+8)*lda + idx_n]); \
regArray##_5 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+10)*lda + idx_n]); \
regArray##_6 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+12)*lda + idx_n]); \
regArray##_7 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+14)*lda + idx_n]);
#define BF16_MATRIX_MASKZ_LOAD_4x32_2(regArray, a, lda, idx_m, idx_n, mask) \
regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]);
#define BF16_MATRIX_MASKZ_LOAD_1x32(regArray, a, lda, idx_m, idx_n, mask) \
regArray = _mm512_maskz_loadu_epi16(mask, &a[idx_m*lda + idx_n]);
#define BF16_VECTOR_LOAD_1x32(reg, x, idx_n) \
reg = _mm512_loadu_si512(x + idx_n);
#define BF16_VECTOR_LOAD_1x16(reg, x, idx_n) \
reg = _mm256_loadu_si256(x + idx_n);
#define BF16_VECTOR_LOAD_1x8(reg, x, idx_n) \
reg = _mm_loadu_si128(x + idx_n);
#define BF16_VECTOR_MASKZ_LOAD_1x32(reg, x, idx_n, mask) \
reg = _mm512_maskz_loadu_epi16(mask, x + idx_n);
#define BF16_VECTOR_MASKZ_LOAD_1x16(reg, x, idx_n, mask) \
reg = _mm256_maskz_loadu_epi16(mask, x + idx_n);
#define BF16_VECTOR_MASKZ_LOAD_1x8(reg, x, idx_n, mask) \
reg = _mm_maskz_loadu_epi16(mask, x + idx_n);
/* 2-step interleave for a matrix block of 8 rows with 32 BF16 elements per row
Input - register array holding 8 rows of the row-major matrix
Output - the output of Step 2
Step 1: 2-element interleave for matrix
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
|e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11|e16|e17|f16|f17|e18|e19|f18|f19|e24|e25|f24|f25|e26|e27|f26|f27
|g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11|g16|g17|h16|h17|g18|g19|h18|h19|g24|g25|h24|h25|g26|g27|h26|h27
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31
|e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15|e20|e21|f20|f21|e22|e23|f22|f23|e28|e29|f28|f29|e30|e31|f30|f31
|g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15|g20|g21|h20|h21|g22|g23|h22|h23|g28|g29|h28|h29|g30|g31|h30|h31
Step 2: 4-element interleave for matrix
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
|e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9 |e16|e17|f16|f17|g16|g17|h16|h17|e24|e25|f24|f25|g24|g25|h24|h25
|e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11|e18|e19|f18|f19|g18|g19|h18|h19|e26|e27|f26|f27|g26|g27|h26|h27
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
|e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13|e20|e21|f20|f21|g20|g21|h20|h21|e28|e29|f28|f29|g28|g29|h28|h29
|e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15|e22|e23|f22|f23|g22|g23|h22|h23|e30|e31|f30|f31|g30|g31|h30|h31
*/
#define BF16_INTERLEAVE_8x32(regArray) \
regArray##_8 = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \
regArray##_9 = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \
regArray##_10 = _mm512_unpacklo_epi32(regArray##_4, regArray##_5); \
regArray##_11 = _mm512_unpacklo_epi32(regArray##_6, regArray##_7); \
regArray##_12 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \
regArray##_13 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \
regArray##_14 = _mm512_unpackhi_epi32(regArray##_4, regArray##_5); \
regArray##_15 = _mm512_unpackhi_epi32(regArray##_6, regArray##_7); \
\
regArray##_0 = _mm512_unpacklo_epi64(regArray##_8, regArray##_9); \
regArray##_1 = _mm512_unpackhi_epi64(regArray##_8, regArray##_9); \
regArray##_2 = _mm512_unpacklo_epi64(regArray##_10, regArray##_11); \
regArray##_3 = _mm512_unpackhi_epi64(regArray##_10, regArray##_11); \
regArray##_4 = _mm512_unpacklo_epi64(regArray##_12, regArray##_13); \
regArray##_5 = _mm512_unpackhi_epi64(regArray##_12, regArray##_13); \
regArray##_6 = _mm512_unpacklo_epi64(regArray##_14, regArray##_15); \
regArray##_7 = _mm512_unpackhi_epi64(regArray##_14, regArray##_15);
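/* Illustrative scalar model (not part of the commit) of BF16_INTERLEAVE_8x32.
   Each 512-bit register is modeled as 16 uint32 lanes (one lane = one BF16
   pair); the unpack*_epi32/epi64 intrinsics operate within 128-bit chunks of
   4 lanes, which is what the scalar helpers below reproduce. */
#include <stdio.h>
#include <stdint.h>

static void unpacklo32(uint32_t *d, const uint32_t *a, const uint32_t *b) {
    for (int j = 0; j < 4; j++) {             /* per 128-bit chunk */
        d[4*j+0] = a[4*j+0]; d[4*j+1] = b[4*j+0];
        d[4*j+2] = a[4*j+1]; d[4*j+3] = b[4*j+1];
    }
}
static void unpackhi32(uint32_t *d, const uint32_t *a, const uint32_t *b) {
    for (int j = 0; j < 4; j++) {
        d[4*j+0] = a[4*j+2]; d[4*j+1] = b[4*j+2];
        d[4*j+2] = a[4*j+3]; d[4*j+3] = b[4*j+3];
    }
}
static void unpacklo64(uint32_t *d, const uint32_t *a, const uint32_t *b) {
    for (int j = 0; j < 4; j++) {
        d[4*j+0] = a[4*j+0]; d[4*j+1] = a[4*j+1];
        d[4*j+2] = b[4*j+0]; d[4*j+3] = b[4*j+1];
    }
}
static void unpackhi64(uint32_t *d, const uint32_t *a, const uint32_t *b) {
    for (int j = 0; j < 4; j++) {
        d[4*j+0] = a[4*j+2]; d[4*j+1] = a[4*j+3];
        d[4*j+2] = b[4*j+2]; d[4*j+3] = b[4*j+3];
    }
}

int main(void) {
    uint32_t r[8][16], t[8][16];
    /* lane k of input row `row` holds the BF16 pair starting at element 2k */
    for (int row = 0; row < 8; row++)
        for (int k = 0; k < 16; k++) r[row][k] = (uint32_t)(row*32 + 2*k);
    /* Step 1: 2-element interleave (epi32 unpack), as in the macro above */
    unpacklo32(t[0], r[0], r[1]); unpacklo32(t[1], r[2], r[3]);
    unpacklo32(t[2], r[4], r[5]); unpacklo32(t[3], r[6], r[7]);
    unpackhi32(t[4], r[0], r[1]); unpackhi32(t[5], r[2], r[3]);
    unpackhi32(t[6], r[4], r[5]); unpackhi32(t[7], r[6], r[7]);
    /* Step 2: 4-element interleave (epi64 unpack) */
    unpacklo64(r[0], t[0], t[1]); unpackhi64(r[1], t[0], t[1]);
    unpacklo64(r[2], t[2], t[3]); unpackhi64(r[3], t[2], t[3]);
    unpacklo64(r[4], t[4], t[5]); unpackhi64(r[5], t[4], t[5]);
    unpacklo64(r[6], t[6], t[7]); unpackhi64(r[7], t[6], t[7]);
    /* Row 0 matches the first diagram row: a0 b0 c0 d0 a8 b8 c8 d8 ... */
    for (int k = 0; k < 16; k++)
        printf("%c%u ", (int)('a' + r[0][k]/32), (unsigned)(r[0][k]%32));
    printf("\n");
    return 0;
}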
/* 2-step interleave for a matrix block of 8 rows with 16 BF16 elements per row
Input - register array holding 8 rows of the row-major matrix
Output - the output of Step 2
Step 1: 2-element interleave for matrix
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
|e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11
|g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15
|e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15
|g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15
Step 2: 4-element interleave for matrix
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
|e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9
|e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
|e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13
|e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15
*/
#define BF16_INTERLEAVE_8x16(regArray) \
regArray##_8 = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \
regArray##_9 = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \
regArray##_10 = _mm256_unpacklo_epi32(regArray##_4, regArray##_5); \
regArray##_11 = _mm256_unpacklo_epi32(regArray##_6, regArray##_7); \
regArray##_12 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \
regArray##_13 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \
regArray##_14 = _mm256_unpackhi_epi32(regArray##_4, regArray##_5); \
regArray##_15 = _mm256_unpackhi_epi32(regArray##_6, regArray##_7); \
\
regArray##_0 = _mm256_unpacklo_epi64(regArray##_8, regArray##_9); \
regArray##_1 = _mm256_unpackhi_epi64(regArray##_8, regArray##_9); \
regArray##_2 = _mm256_unpacklo_epi64(regArray##_10, regArray##_11); \
regArray##_3 = _mm256_unpackhi_epi64(regArray##_10, regArray##_11); \
regArray##_4 = _mm256_unpacklo_epi64(regArray##_12, regArray##_13); \
regArray##_5 = _mm256_unpackhi_epi64(regArray##_12, regArray##_13); \
regArray##_6 = _mm256_unpacklo_epi64(regArray##_14, regArray##_15); \
regArray##_7 = _mm256_unpackhi_epi64(regArray##_14, regArray##_15);
/* 2-step interleave for a matrix block of 4 rows with 32 BF16 elements per row
Input - register array holding 4 rows of the row-major matrix
Output - the output of Step 2
Step 1: 2-element interleave for matrix
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31
Step 2: 4-element interleave for matrix
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
*/
#define BF16_INTERLEAVE_4x32(regArray) \
regArray##_4 = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \
regArray##_5 = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \
regArray##_6 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \
regArray##_7 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \
\
regArray##_0 = _mm512_unpacklo_epi64(regArray##_4, regArray##_5); \
regArray##_1 = _mm512_unpackhi_epi64(regArray##_4, regArray##_5); \
regArray##_2 = _mm512_unpacklo_epi64(regArray##_6, regArray##_7); \
regArray##_3 = _mm512_unpackhi_epi64(regArray##_6, regArray##_7);
/* 2-step interleave for a matrix block of 4 rows with 16 BF16 elements per row
Input - register array holding 4 rows of the row-major matrix
Output - the output of Step 2
Step 1: 2-element interleave for matrix
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15
Step 2: 4-element interleave for matrix
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
*/
#define BF16_INTERLEAVE_4x16(regArray) \
regArray##_4 = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \
regArray##_5 = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \
regArray##_6 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \
regArray##_7 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \
\
regArray##_0 = _mm256_unpacklo_epi64(regArray##_4, regArray##_5); \
regArray##_1 = _mm256_unpackhi_epi64(regArray##_4, regArray##_5); \
regArray##_2 = _mm256_unpacklo_epi64(regArray##_6, regArray##_7); \
regArray##_3 = _mm256_unpackhi_epi64(regArray##_6, regArray##_7);
/* 2-step interleave for x with 32 BF16 elements
Input - original vector
Output - the output of Step 2
Step 1: 2-element interleave for x:
|x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11|x16|x17|x16|x17|x18|x19|x18|x19|x24|x25|x24|x25|x26|x27|x26|x27
|x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15|x20|x21|x20|x21|x22|x23|x22|x23|x28|x29|x28|x29|x30|x31|x30|x31
Step 2: 4-element interleave for x:
|x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9 |x16|x17|x16|x17|x16|x17|x16|x17|x24|x25|x24|x25|x24|x25|x24|x25
|x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11|x18|x19|x18|x19|x18|x19|x18|x19|x26|x27|x26|x27|x26|x27|x26|x27
|x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13|x20|x21|x20|x21|x20|x21|x20|x21|x28|x29|x28|x29|x28|x29|x28|x29
|x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15|x22|x23|x22|x23|x22|x23|x22|x23|x30|x31|x30|x31|x30|x31|x30|x31
*/
#define BF16_INTERLEAVE_1x32(regArray) \
regArray##_1 = _mm512_unpacklo_epi32(regArray##_0, regArray##_0); \
regArray##_3 = _mm512_unpackhi_epi32(regArray##_0, regArray##_0); \
\
regArray##_0 = _mm512_unpacklo_epi64(regArray##_1, regArray##_1); \
regArray##_1 = _mm512_unpackhi_epi64(regArray##_1, regArray##_1); \
regArray##_2 = _mm512_unpacklo_epi64(regArray##_3, regArray##_3); \
regArray##_3 = _mm512_unpackhi_epi64(regArray##_3, regArray##_3);
/* 2-step interleave for x with 16 BF16 elements
Input - original vector
Output - the output of Step 2
Step 1: 2-element interleave for x:
|x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11
|x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15
Step 2: 4-element interleave for x:
|x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9
|x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11
|x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13
|x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15
*/
#define BF16_INTERLEAVE_1x16(regArray) \
regArray##_1 = _mm256_unpacklo_epi32(regArray##_0, regArray##_0); \
regArray##_3 = _mm256_unpackhi_epi32(regArray##_0, regArray##_0); \
\
regArray##_0 = _mm256_unpacklo_epi64(regArray##_1, regArray##_1); \
regArray##_1 = _mm256_unpackhi_epi64(regArray##_1, regArray##_1); \
regArray##_2 = _mm256_unpacklo_epi64(regArray##_3, regArray##_3); \
regArray##_3 = _mm256_unpackhi_epi64(regArray##_3, regArray##_3);
/* 1-step interleave to exchange the high 256 bits and low 256 bits of 4 pairs of registers
|a0|a1|...|a14|a15|i0|i1|...|i14|i15|
|b0|b1|...|b14|b15|j0|j1|...|j14|j15|
|c0|c1|...|c14|c15|k0|k1|...|k14|k15|
|d0|d1|...|d14|d15|l0|l1|...|l14|l15|
|e0|e1|...|e14|e15|m0|m1|...|m14|m15|
|f0|f1|...|f14|f15|n0|n1|...|n14|n15|
|g0|g1|...|g14|g15|o0|o1|...|o14|o15|
|h0|h1|...|h14|h15|p0|p1|...|p14|p15|
*/
#define BF16_INTERLEAVE256_8x32(regArray) \
regArray##_0 = _mm512_shuffle_i32x4(regArray##_8, regArray##_12, 0x44); \
regArray##_1 = _mm512_shuffle_i32x4(regArray##_8, regArray##_12, 0xee); \
regArray##_2 = _mm512_shuffle_i32x4(regArray##_9, regArray##_13, 0x44); \
regArray##_3 = _mm512_shuffle_i32x4(regArray##_9, regArray##_13, 0xee); \
regArray##_4 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0x44); \
regArray##_5 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0xee); \
regArray##_6 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0x44); \
regArray##_7 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0xee);
/* 1-step interleave to exchange the high 256 bits and low 256 bits of 2 pairs of registers
|a0|a1|...|a14|a15|e0|e1|...|e14|e15|
|b0|b1|...|b14|b15|f0|f1|...|f14|f15|
|c0|c1|...|c14|c15|g0|g1|...|g14|g15|
|d0|d1|...|d14|d15|h0|h1|...|h14|h15|
*/
#define BF16_INTERLEAVE256_4x32(regArray) \
regArray##_0 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0x44); \
regArray##_1 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0xee); \
regArray##_2 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0x44); \
regArray##_3 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0xee);
#define BF16_PERMUTE_8x32(idx, regArray) \
regArray##_8 = _mm512_permutexvar_epi16(idx, regArray##_0); \
regArray##_9 = _mm512_permutexvar_epi16(idx, regArray##_1); \
regArray##_10 = _mm512_permutexvar_epi16(idx, regArray##_2); \
regArray##_11 = _mm512_permutexvar_epi16(idx, regArray##_3); \
regArray##_12 = _mm512_permutexvar_epi16(idx, regArray##_4); \
regArray##_13 = _mm512_permutexvar_epi16(idx, regArray##_5); \
regArray##_14 = _mm512_permutexvar_epi16(idx, regArray##_6); \
regArray##_15 = _mm512_permutexvar_epi16(idx, regArray##_7);
#define BF16_PERMUTE_8x32_2(idx, regArray) \
regArray##_8 = _mm512_permutexvar_epi32(idx, regArray##_0); \
regArray##_9 = _mm512_permutexvar_epi32(idx, regArray##_1); \
regArray##_10 = _mm512_permutexvar_epi32(idx, regArray##_2); \
regArray##_11 = _mm512_permutexvar_epi32(idx, regArray##_3); \
regArray##_12 = _mm512_permutexvar_epi32(idx, regArray##_4); \
regArray##_13 = _mm512_permutexvar_epi32(idx, regArray##_5); \
regArray##_14 = _mm512_permutexvar_epi32(idx, regArray##_6); \
regArray##_15 = _mm512_permutexvar_epi32(idx, regArray##_7);
#define BF16_PERMUTE_4x32(idx, regArray) \
regArray##_4 = _mm512_permutexvar_epi16(idx, regArray##_0); \
regArray##_5 = _mm512_permutexvar_epi16(idx, regArray##_1); \
regArray##_6 = _mm512_permutexvar_epi16(idx, regArray##_2); \
regArray##_7 = _mm512_permutexvar_epi16(idx, regArray##_3);
#define BF16_PERMUTE_4x32_2(idx, regArray) \
regArray##_4 = _mm512_permutexvar_epi32(idx, regArray##_0); \
regArray##_5 = _mm512_permutexvar_epi32(idx, regArray##_1); \
regArray##_6 = _mm512_permutexvar_epi32(idx, regArray##_2); \
regArray##_7 = _mm512_permutexvar_epi32(idx, regArray##_3);
/* Calculate the dot result for 2-step interleaved matrix and vector
(Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
*/
#define BF16_2STEP_INTERLEAVED_DOT_8x32(accumArray, matArray, xArray) \
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_2, (__m512bh) xArray##_0); \
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_1); \
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_4, (__m512bh) xArray##_2); \
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_6, (__m512bh) xArray##_2); \
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_5, (__m512bh) xArray##_3); \
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_7, (__m512bh) xArray##_3);
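/* Scalar model (illustrative) of the _mm512_dpbf16_ps semantics the dot
   macros rely on: each fp32 accumulator lane gains the product-sum of one
   BF16 pair, so with the 2-step interleaved layout every lane accumulates
   terms belonging to a single output row. Rounding/subnormal details of the
   real instruction are ignored here. */
#include <stdint.h>
#include <string.h>

static float bf16_to_f32(uint16_t v)   /* bf16 = upper 16 bits of an fp32 */
{
    uint32_t u = (uint32_t)v << 16;
    float f;
    memcpy(&f, &u, sizeof f);
    return f;
}

static void dpbf16_ps_model(float acc[16], const uint16_t a[32], const uint16_t b[32])
{
    for (int i = 0; i < 16; i++)
        acc[i] += bf16_to_f32(a[2*i])   * bf16_to_f32(b[2*i])
                + bf16_to_f32(a[2*i+1]) * bf16_to_f32(b[2*i+1]);
}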
/* Calculate the dot result for 2-step interleaved matrix and vector
(Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
*/
#define BF16_2STEP_INTERLEAVED_DOT_8x16(accumArray, matArray, xArray) \
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_2, (__m256bh) xArray##_0); \
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_1); \
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_4, (__m256bh) xArray##_2); \
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_6, (__m256bh) xArray##_2); \
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_5, (__m256bh) xArray##_3); \
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_7, (__m256bh) xArray##_3);
/* Calculate the dot result for 2-step interleaved matrix and vector
(Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
*/
#define BF16_2STEP_INTERLEAVED_DOT_4x32(accumArray, matArray, xArray) \
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_2, (__m512bh) xArray##_2); \
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_3);
/* Calculate the dot result for 2-step interleaved matrix and vector
(Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
*/
#define BF16_2STEP_INTERLEAVED_DOT_4x16(accumArray, matArray, xArray) \
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_2, (__m256bh) xArray##_2); \
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_3);
/* Calculate the dot result for matrix and vector at 32 elements per row
(Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
*/
#define BF16_DOT_8x32(accumArray, matArray, xArray) \
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray); \
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray); \
accumArray##_2 = _mm512_dpbf16_ps(accumArray##_2, (__m512bh) matArray##_2, (__m512bh) xArray); \
accumArray##_3 = _mm512_dpbf16_ps(accumArray##_3, (__m512bh) matArray##_3, (__m512bh) xArray); \
accumArray##_4 = _mm512_dpbf16_ps(accumArray##_4, (__m512bh) matArray##_4, (__m512bh) xArray); \
accumArray##_5 = _mm512_dpbf16_ps(accumArray##_5, (__m512bh) matArray##_5, (__m512bh) xArray); \
accumArray##_6 = _mm512_dpbf16_ps(accumArray##_6, (__m512bh) matArray##_6, (__m512bh) xArray); \
accumArray##_7 = _mm512_dpbf16_ps(accumArray##_7, (__m512bh) matArray##_7, (__m512bh) xArray);
/* Calculate the dot result for matrix and vector at 32 elements per row
(Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
*/
#define BF16_DOT_1x32(accumArray, matArray, xArray) \
accumArray = _mm512_dpbf16_ps(accumArray, (__m512bh) matArray, (__m512bh) xArray);
/* Calculate the dot result for matrix and vector at 16 elements per row
(Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
*/
#define BF16_DOT_8x16(accumArray, matArray, xArray) \
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray); \
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray); \
accumArray##_2 = _mm256_dpbf16_ps(accumArray##_2, (__m256bh) matArray##_2, (__m256bh) xArray); \
accumArray##_3 = _mm256_dpbf16_ps(accumArray##_3, (__m256bh) matArray##_3, (__m256bh) xArray); \
accumArray##_4 = _mm256_dpbf16_ps(accumArray##_4, (__m256bh) matArray##_4, (__m256bh) xArray); \
accumArray##_5 = _mm256_dpbf16_ps(accumArray##_5, (__m256bh) matArray##_5, (__m256bh) xArray); \
accumArray##_6 = _mm256_dpbf16_ps(accumArray##_6, (__m256bh) matArray##_6, (__m256bh) xArray); \
accumArray##_7 = _mm256_dpbf16_ps(accumArray##_7, (__m256bh) matArray##_7, (__m256bh) xArray);
/* 2-step interleave for a matrix block of 8 rows with 16 fp32 elements per row
Input - register array holding 8 rows of the row-major matrix
Output - the output of Step 2
Step 1: 2-element interleave for matrix
|a0|b0|a1|b1|a4|b4|a5|b5|a8 |b8 |a9 |b9 |a12|b12|a13|b13|
|c0|d0|c1|d1|c4|d4|c5|d5|c8 |d8 |c9 |d9 |c12|d12|c13|d13|
|e0|f0|e1|f1|e4|f4|e5|f5|e8 |f8 |e9 |f9 |e12|f12|e13|f13|
|g0|h0|g1|h1|g4|h4|g5|h5|g8 |h8 |g9 |h9 |g12|h12|g13|h13|
|a2|b2|a3|b3|a6|b6|a7|b7|a10|b10|a11|b11|a14|b14|a15|b15|
|c2|d2|c3|d3|c6|d6|c7|d7|c10|d10|c11|d11|c14|d14|c15|d15|
|e2|f2|e3|f3|e6|f6|e7|f7|e10|f10|e11|f11|e14|f14|e15|f15|
|g2|h2|g3|h3|g6|h6|g7|h7|g10|h10|g11|h11|g14|h14|g15|h15|
Step 2: 4-element interleave for matrix
|a0|b0|c0|d0|a4|b4|c4|d4|a8 |b8 |c8 |d8 |a12|b12|c12|d12|
|a1|b1|c1|d1|a5|b5|c5|d5|a9 |b9 |c9 |d9 |a13|b13|c13|d13|
|e0|f0|g0|h0|e4|f4|g4|h4|e8 |f8 |g8 |h8 |e12|f12|g12|h12|
|e1|f1|g1|h1|e5|f5|g5|h5|e9 |f9 |g9 |h9 |e13|f13|g13|h13|
|a2|b2|c2|d2|a6|b6|c6|d6|a10|b10|c10|d10|a14|b14|c14|d14|
|a3|b3|c3|d3|a7|b7|c7|d7|a11|b11|c11|d11|a15|b15|c15|d15|
|e2|f2|g2|h2|e6|f6|g6|h6|e10|f10|g10|h10|e14|f14|g14|h14|
|e3|f3|g3|h3|e7|f7|g7|h7|e11|f11|g11|h11|e15|f15|g15|h15|
*/
#define FP32_INTERLEAVE_8x16(regArray) \
regArray##_8 = _mm512_unpacklo_ps(regArray##_0, regArray##_1); \
regArray##_9 = _mm512_unpacklo_ps(regArray##_2, regArray##_3); \
regArray##_10 = _mm512_unpacklo_ps(regArray##_4, regArray##_5); \
regArray##_11 = _mm512_unpacklo_ps(regArray##_6, regArray##_7); \
regArray##_12 = _mm512_unpackhi_ps(regArray##_0, regArray##_1); \
regArray##_13 = _mm512_unpackhi_ps(regArray##_2, regArray##_3); \
regArray##_14 = _mm512_unpackhi_ps(regArray##_4, regArray##_5); \
regArray##_15 = _mm512_unpackhi_ps(regArray##_6, regArray##_7); \
\
regArray##_0 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_8, (__m512d) regArray##_9); \
regArray##_1 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_8, (__m512d) regArray##_9); \
regArray##_4 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \
regArray##_5 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \
regArray##_2 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \
regArray##_3 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \
regArray##_6 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_14, (__m512d) regArray##_15); \
regArray##_7 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_14, (__m512d) regArray##_15);
#define FP32_INTERLEAVE_8x16_ARRAY(regArray) \
regArray[8] = _mm512_unpacklo_ps(regArray[0], regArray[1]); \
regArray[9] = _mm512_unpacklo_ps(regArray[2], regArray[3]); \
regArray[10] = _mm512_unpacklo_ps(regArray[4], regArray[5]); \
regArray[11] = _mm512_unpacklo_ps(regArray[6], regArray[7]); \
regArray[12] = _mm512_unpackhi_ps(regArray[0], regArray[1]); \
regArray[13] = _mm512_unpackhi_ps(regArray[2], regArray[3]); \
regArray[14] = _mm512_unpackhi_ps(regArray[4], regArray[5]); \
regArray[15] = _mm512_unpackhi_ps(regArray[6], regArray[7]); \
\
regArray[0] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[8], (__m512d) regArray[9]); \
regArray[1] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[8], (__m512d) regArray[9]); \
regArray[4] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
regArray[5] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
regArray[2] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
regArray[3] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
regArray[6] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[14], (__m512d) regArray[15]); \
regArray[7] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[14], (__m512d) regArray[15]);
/* 2-step interleave for a matrix block of 8 rows with 8 fp32 elements per row
Input - register array holding 8 rows of the row-major matrix
Output - the output of Step 2
Step 1: 2-element interleave for matrix
|a0|b0|a1|b1|a4|b4|a5|b5|
|c0|d0|c1|d1|c4|d4|c5|d5|
|e0|f0|e1|f1|e4|f4|e5|f5|
|g0|h0|g1|h1|g4|h4|g5|h5|
|a2|b2|a3|b3|a6|b6|a7|b7|
|c2|d2|c3|d3|c6|d6|c7|d7|
|e2|f2|e3|f3|e6|f6|e7|f7|
|g2|h2|g3|h3|g6|h6|g7|h7|
Step 2: 4-element interleave for matrix
|a0|b0|c0|d0|a4|b4|c4|d4|
|a1|b1|c1|d1|a5|b5|c5|d5|
|e0|f0|g0|h0|e4|f4|g4|h4|
|e1|f1|g1|h1|e5|f5|g5|h5|
|a2|b2|c2|d2|a6|b6|c6|d6|
|a3|b3|c3|d3|a7|b7|c7|d7|
|e2|f2|g2|h2|e6|f6|g6|h6|
|e3|f3|g3|h3|e7|f7|g7|h7|
*/
#define FP32_INTERLEAVE_8x8(regArray) \
regArray##_8 = _mm256_unpacklo_ps(regArray##_0, regArray##_1); \
regArray##_9 = _mm256_unpacklo_ps(regArray##_2, regArray##_3); \
regArray##_10 = _mm256_unpacklo_ps(regArray##_4, regArray##_5); \
regArray##_11 = _mm256_unpacklo_ps(regArray##_6, regArray##_7); \
regArray##_12 = _mm256_unpackhi_ps(regArray##_0, regArray##_1); \
regArray##_13 = _mm256_unpackhi_ps(regArray##_2, regArray##_3); \
regArray##_14 = _mm256_unpackhi_ps(regArray##_4, regArray##_5); \
regArray##_15 = _mm256_unpackhi_ps(regArray##_6, regArray##_7); \
\
regArray##_0 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_8, (__m256d) regArray##_9); \
regArray##_1 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_8, (__m256d) regArray##_9); \
regArray##_4 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
regArray##_5 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
regArray##_2 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
regArray##_3 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
regArray##_6 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_14, (__m256d) regArray##_15); \
regArray##_7 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_14, (__m256d) regArray##_15);
/* Accumulate the results of 2 batches of 4 registers
*/
#define FP32_ACCUM2_8x16(regArray) \
regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_1); \
regArray##_2 = _mm512_add_ps(regArray##_2, regArray##_3); \
regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_5); \
regArray##_6 = _mm512_add_ps(regArray##_6, regArray##_7); \
regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_2); \
regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_6);
#define FP32_ACCUM2_8x16_ARRAY(regArray) \
regArray[0] = _mm512_add_ps(regArray[0], regArray[1]); \
regArray[2] = _mm512_add_ps(regArray[2], regArray[3]); \
regArray[4] = _mm512_add_ps(regArray[4], regArray[5]); \
regArray[6] = _mm512_add_ps(regArray[6], regArray[7]); \
regArray[0] = _mm512_add_ps(regArray[0], regArray[2]); \
regArray[4] = _mm512_add_ps(regArray[4], regArray[6]);
/* Accumulate the results of 2 batches of 4 registers
*/
#define FP32_ACCUM2_8x8(regArray) \
regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_1); \
regArray##_2 = _mm256_add_ps(regArray##_2, regArray##_3); \
regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_5); \
regArray##_6 = _mm256_add_ps(regArray##_6, regArray##_7); \
regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_2); \
regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_6);
/* Store 16 (alpha * result + beta * y) to y
*/
#define STORE16_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_loadu_ps(targetAddr))); \
_mm512_storeu_ps(targetAddr, regResult);
/* Masked store 16 (alpha * result + beta * y) to y
*/
#define STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_maskz_loadu_ps(mask, targetAddr))); \
_mm512_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 8 (alpha * result + beta * y) to y
*/
#define STORE8_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_loadu_ps(targetAddr))); \
_mm256_storeu_ps(targetAddr, regResult);
/* Masked store 8 (alpha * result + beta * y) to y
*/
#define STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_maskz_loadu_ps(mask, targetAddr))); \
_mm256_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 4 (alpha * result + beta * y) to y
*/
#define STORE4_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_loadu_ps(targetAddr))); \
_mm_storeu_ps(targetAddr, regResult);
/* Masked store 4 (alpha * result + beta * y) to y
*/
#define STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_maskz_loadu_ps(mask, targetAddr))); \
_mm_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 16 (alpha * result + y) to y
*/
#define STORE16_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_loadu_ps(targetAddr)); \
_mm512_storeu_ps(targetAddr, regResult);
/* Masked store 16 (alpha * result + y) to y
*/
#define STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \
_mm512_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 8 (alpha * result + y) to y
*/
#define STORE8_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_loadu_ps(targetAddr)); \
_mm256_storeu_ps(targetAddr, regResult);
/* Masked store 8 (alpha * result + y) to y
*/
#define STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \
_mm256_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 4 (alpha * result + y) to y
*/
#define STORE4_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_loadu_ps(targetAddr)); \
_mm_storeu_ps(targetAddr, regResult);
/* Masked store 4 (alpha * result + y) to y
*/
#define STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \
_mm_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 16 (alpha * result) to y
*/
#define STORE16_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
_mm512_storeu_ps(targetAddr, _mm512_mul_ps(ALPHAVECTOR, regResult));
/* Masked store 16 (alpha * result) to y
*/
#define STORE16_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
_mm512_mask_storeu_ps(targetAddr, mask, _mm512_mul_ps(ALPHAVECTOR, regResult));
/* Store 8 (alpha * result) to y
*/
#define STORE8_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
_mm256_storeu_ps(targetAddr, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));
/* Masked store 8 (alpha * result) to y
*/
#define STORE8_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
_mm256_mask_storeu_ps(targetAddr, mask, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));
/* Store 4 (alpha * result) to y
*/
#define STORE4_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
_mm_storeu_ps(targetAddr, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));
/* Masked store 4 (alpha * result) to y
*/
#define STORE4_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
_mm_mask_storeu_ps(targetAddr, mask, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));
/* Store 16 result to y
*/
#define STORE16_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
_mm512_storeu_ps(targetAddr, regResult);
/* Masked store 16 result to y
*/
#define STORE16_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
_mm512_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 8 result to y
*/
#define STORE8_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
_mm256_storeu_ps(targetAddr, regResult);
/* Masked store 8 result to y
*/
#define STORE8_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
_mm256_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 4 result to y
*/
#define STORE4_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
_mm_storeu_ps(targetAddr, regResult);
/* Masked store 4 result to y
*/
#define STORE4_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
_mm_mask_storeu_ps(targetAddr, mask, regResult);
#endif
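The store macros above enumerate four per-element update rules, selected at compile time by the template includes further down. A scalar reference of what each family computes (a sketch for cross-checking, not part of the commit):

#include <stdio.h>

/* Per-element semantics of the four STORE*_COMPLETE_RESULT families */
static float store_alpha_beta(float r, float alpha, float beta, float y) { return alpha * r + beta * y; }
static float store_alpha_one (float r, float alpha, float y)             { return alpha * r + y; }
static float store_alpha     (float r, float alpha)                      { return alpha * r; }
static float store_direct    (float r)                                   { return r; }

int main(void) {
    float r = 4.0f, y = 2.0f;
    printf("%g %g %g %g\n",
           store_alpha_beta(r, 3.0f, 0.5f, y),  /* 3*4 + 0.5*2 = 13 */
           store_alpha_one (r, 3.0f, y),        /* 3*4 + 2     = 14 */
           store_alpha     (r, 3.0f),           /* 3*4         = 12 */
           store_direct    (r));                /* 4            */
    return 0;
}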

137
kernel/x86_64/sbgemv_n.c Normal file
View File

@ -0,0 +1,137 @@
/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include "common.h"
#if defined (COOPERLAKE)
#include "sbgemv_n_microk_cooperlake.c"
#endif
/* Allocate a 64-byte-aligned region: over-allocate by 63 bytes and round the pointer up;
   the raw pointer stays in 'ptr' so it can be handed back to free() */
#define ALIGN64_ALLOC(alloc_size, TYPE, ptr_align, ptr) \
ptr = (TYPE *) malloc(sizeof(TYPE)*alloc_size + 63); \
ptr_align = ((int)(((uintptr_t)ptr & (uintptr_t)0x3F))!=0) ? (TYPE *)((char *)ptr + (64 - (int)((uintptr_t)ptr & (uintptr_t)0x3F))) : ptr
#define ALIGN64_FREE(ptr) \
free(ptr)
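A quick standalone check of the alignment macro's behavior (hypothetical test; assumes the two macros above are in scope):

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    float *buf_align = NULL, *buf = NULL;
    ALIGN64_ALLOC(100, float, buf_align, buf);   /* buf keeps the raw pointer for free() */
    printf("aligned: %d\n", ((uintptr_t)buf_align & 63) == 0);  /* always prints 1 */
    ALIGN64_FREE(buf);
    return 0;
}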
#ifndef HAVE_SBGEMV_N_ACCL_KERNEL
/* Generic fallback kernel: repack A densely, widen A and x to FP32, then run the GEMV in float */
static void sbgemv_kernel_n(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
{
BLASLONG offset_lda, offset_m;
float accum = 0.0;
bfloat16 * a_bf16 = malloc(sizeof(bfloat16)*m*n);
float * a_fp32 = malloc(sizeof(float)*m*n);
float * x_fp32 = malloc(sizeof(float)*n);
for (BLASLONG j=0; j<n; j++) {
offset_lda = lda * j;
offset_m = m * j;
for (BLASLONG i=0; i<m; i++) {
a_bf16[offset_m + i] = a[offset_lda + i];
}
}
SBF16TOS_K(n, x, 1, x_fp32, 1);
SBF16TOS_K(m*n, a_bf16, 1, a_fp32, 1);
for (BLASLONG i=0; i<m; i++) {
accum = 0.0;
for (BLASLONG j=0; j<n; j++) {
accum += a_fp32[j*m + i] * x_fp32[j];
}
if (beta == ZERO) {
y[i] = alpha * accum;
} else {
y[i] = alpha * accum + beta * y[i];
}
}
free(a_bf16);
free(a_fp32);
free(x_fp32);
}
#endif
static void bf16_compress_vector(BLASLONG n, bfloat16 * src, bfloat16 * target, BLASLONG inc)
{
for(BLASLONG i=0; i<n; i++) {
target[i] = src[i*inc];
}
}
static void fp32_compress_vector(BLASLONG n, float * src, float * target, BLASLONG inc)
{
for(BLASLONG i=0; i<n; i++) {
target[i] = src[i*inc];
}
}
static void fp32_expand_vector(BLASLONG n, float * src, float * target, BLASLONG inc)
{
for(BLASLONG i=0; i<n; i++) {
target[i*inc] = src[i];
}
}
int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float * y, BLASLONG incy)
{
if ( m < 1 || n < 1) return(0);
bfloat16 * xbuffer_align = x;
float * ybuffer_align = y;
bfloat16 * xbuffer = NULL;
float * ybuffer = NULL;
if (incx != 1) {
ALIGN64_ALLOC(n, bfloat16, xbuffer_align, xbuffer);
bf16_compress_vector(n, x, xbuffer_align, incx);
}
if (incy != 1) {
ALIGN64_ALLOC(m, float, ybuffer_align, ybuffer);
if (beta != ZERO) {
fp32_compress_vector(m, y, ybuffer_align, incy);
}
}
sbgemv_kernel_n(m, n, alpha, a, lda, xbuffer_align, beta, ybuffer_align);
if (incy != 1) {
fp32_expand_vector(m, ybuffer_align, y, incy);
ALIGN64_FREE(ybuffer);
}
if (incx != 1) {
ALIGN64_FREE(xbuffer);
}
return(0);
}
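For reference, a hypothetical smoke test of the non-transposed path through the new CBLAS entry point; it assumes an OpenBLAS build with BUILD_BFLOAT16=1 and uses cblas_sbstobf16 to narrow the FP32 test data (small integers are exact in bfloat16):

#include <stdio.h>
#include <cblas.h>

int main(void) {
    const blasint m = 2, n = 3;
    float a_fp32[6] = {1, 2, 3, 4, 5, 6};   /* column-major 2x3 */
    float x_fp32[3] = {1, 1, 1};
    bfloat16 a[6], x[3];
    float y[2] = {0, 0};

    cblas_sbstobf16(m * n, a_fp32, 1, a, 1);
    cblas_sbstobf16(n, x_fp32, 1, x, 1);

    /* y = 1.0 * A * x + 0.0 * y -> row sums of A: {9, 12} */
    cblas_sbgemv(CblasColMajor, CblasNoTrans, m, n, 1.0f, a, m, x, 1, 0.0f, y, 1);

    printf("y = {%g, %g}\n", y[0], y[1]);
    return 0;
}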

76
kernel/x86_64/sbgemv_n_microk_cooperlake.c Normal file
View File

@ -0,0 +1,76 @@
/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
/* need a new enough compiler (GCC >= 10 or clang >= 9) for AVX512-BF16 support */
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
#define HAVE_SBGEMV_N_ACCL_KERNEL 1
#include "common.h"
#include <immintrin.h>
// Define micro kernels for the ALPHA != ONE && BETA != ZERO && BETA != ONE scenario
#undef ZERO_BETA
#undef ONE_BETA
#undef ONE_ALPHA
#include "sbgemv_n_microk_cooperlake_template.c"
// Define micro kernels for the ALPHA != ONE && BETA == ONE scenario
#undef ZERO_BETA
#define ONE_BETA 1
#undef ONE_ALPHA
#include "sbgemv_n_microk_cooperlake_template.c"
// Define micro kernels for the ALPHA != ONE && BETA == ZERO scenario
#define ZERO_BETA 1
#undef ONE_ALPHA
#include "sbgemv_n_microk_cooperlake_template.c"
// Define micro kernels for the ALPHA == ONE && BETA == ZERO scenario
#define ZERO_BETA 1
#define ONE_ALPHA 1
#include "sbgemv_n_microk_cooperlake_template.c"
static int sbgemv_kernel_n(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
{
if (beta == ZERO) { // BETA == 0.0, no need to accumulate the original Y data
if (alpha == ONE) { // ALPHA == 1.0, no need to multiply by ALPHA
sbgemv_kernel_32xN_lda_direct(m, n, alpha, a, lda, x, y);
} else { // ALPHA != 1.0, need to multiply by ALPHA
sbgemv_kernel_32xN_lda_direct_alpha(m, n, alpha, a, lda, x, y);
}
} else { // BETA != 0.0, need to accumulate the original Y data no matter what ALPHA is
if (beta == ONE) {
sbgemv_kernel_32xN_lda_direct_alpha_one(m, n, alpha, a, lda, x, beta, y);
} else {
sbgemv_kernel_32xN_lda_direct_alpha_beta(m, n, alpha, a, lda, x, beta, y);
}
}
return 0;
}
#endif
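The repeated template inclusion above compiles one kernel body four times under different ZERO_BETA/ONE_BETA/ONE_ALPHA settings. A minimal sketch of the same generate-variants-from-one-body idea, using a parameterized macro instead of repeated #include (hypothetical, for illustration only):

#include <stdio.h>

#define DEFINE_SAXPBY(NAME, USE_ALPHA, USE_BETA)              \
static void NAME(int n, float alpha, const float *x,          \
                 float beta, float *y) {                      \
    for (int i = 0; i < n; i++) {                             \
        float r = USE_ALPHA ? alpha * x[i] : x[i];            \
        y[i] = USE_BETA ? r + beta * y[i] : r;                \
    }                                                         \
}

DEFINE_SAXPBY(saxpby_alpha_beta, 1, 1)   /* ALPHA and BETA both effective */
DEFINE_SAXPBY(saxpby_direct,     0, 0)   /* ALPHA == 1, BETA == 0         */

int main(void) {
    float x[2] = {1, 2}, y[2] = {10, 20};
    saxpby_alpha_beta(2, 2.0f, x, 0.5f, y);  /* y = {2*1+5, 2*2+10} = {7, 14} */
    saxpby_direct(2, 0.0f, x, 0.0f, y);      /* y = {1, 2} */
    printf("%g %g\n", y[0], y[1]);
    return 0;
}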

234
kernel/x86_64/sbgemv_n_microk_cooperlake_template.c Normal file
View File

@ -0,0 +1,234 @@
/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include <immintrin.h>
#include "common.h"
// Include common macros for BF16-based operations with x86 intrinsics
#include "bf16_common_macros.h"
#ifndef ZERO_BETA // Beta is non-zero
#ifndef ONE_BETA // BETA is not ONE
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_BETA
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA
#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA_BETA
#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA
#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA_BETA
#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA
#else // BETA is ONE
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_ONE
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE
#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA_ONE
#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE
#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA_ONE
#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE
#endif
#else // BETA is zero
#ifndef ONE_ALPHA // ALPHA is not ONE
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA
#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA
#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA
#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA
#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA
#else // ALPHA is ONE
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_DIRECT
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_DIRECT
#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_DIRECT
#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_DIRECT
#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_DIRECT
#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_DIRECT
#endif
#endif
// BF16 GEMV kernel for the big-N, lda-strided scenario: processes 32 rows per pass (128 in the main loop), splitting even/odd lanes before the BF16 dot and re-interleaving the results afterwards
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_32xN_lda_direct_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_32xN_lda_direct_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_32xN_lda_direct_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_32xN_lda_direct(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
#endif
#endif
{
BLASLONG tag_m_32x = m & (~31);
BLASLONG tag_m_128x = m & (~127);
__m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7, \
accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15;
#ifndef ONE_ALPHA
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
__m512 BETAVECTOR = _mm512_set1_ps(beta);
#endif
__m512i matrixArray_seed_0, matrixArray_seed_1, matrixArray_seed_2, matrixArray_seed_3;
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
__m512i xArray_0;
__m512i ZERO512 = _mm512_setzero_si512();
unsigned int blend_hi_mask_value = ((unsigned int)0xaaaaaaaa);
__mmask32 blend_hi_mask = *((__mmask32*) &blend_hi_mask_value);
unsigned int blend_lo_mask_value = ((unsigned int)0x55555555);
__mmask32 blend_lo_mask = *((__mmask32*) &blend_lo_mask_value);
__m512i M512_EPI32_8 = _mm512_set1_epi32(8);
__m512i idx_base_0 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);
__m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_8);
for (BLASLONG idx_m = 0; idx_m < tag_m_128x; idx_m+=128) {
accum512_0 = _mm512_setzero_ps();
accum512_1 = _mm512_setzero_ps();
accum512_2 = _mm512_setzero_ps();
accum512_3 = _mm512_setzero_ps();
accum512_4 = _mm512_setzero_ps();
accum512_5 = _mm512_setzero_ps();
accum512_6 = _mm512_setzero_ps();
accum512_7 = _mm512_setzero_ps();
for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
xArray_0 = _mm512_set1_epi16(x[idx_n]);
BF16_MATRIX_LOAD_1x32(matrixArray_seed_0, a, lda, idx_n, idx_m + 0)
BF16_MATRIX_LOAD_1x32(matrixArray_seed_1, a, lda, idx_n, idx_m + 32)
BF16_MATRIX_LOAD_1x32(matrixArray_seed_2, a, lda, idx_n, idx_m + 64)
BF16_MATRIX_LOAD_1x32(matrixArray_seed_3, a, lda, idx_n, idx_m + 96)
matrixArray_0 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_0);
matrixArray_1 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_0);
matrixArray_2 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_1);
matrixArray_3 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_1);
matrixArray_4 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_2);
matrixArray_5 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_2);
matrixArray_6 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_3);
matrixArray_7 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_3);
BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
BF16_DOT_1x32(accum512_1, matrixArray_1, xArray_0)
BF16_DOT_1x32(accum512_2, matrixArray_2, xArray_0)
BF16_DOT_1x32(accum512_3, matrixArray_3, xArray_0)
BF16_DOT_1x32(accum512_4, matrixArray_4, xArray_0)
BF16_DOT_1x32(accum512_5, matrixArray_5, xArray_0)
BF16_DOT_1x32(accum512_6, matrixArray_6, xArray_0)
BF16_DOT_1x32(accum512_7, matrixArray_7, xArray_0)
}
accum512_8 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
accum512_9 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
accum512_10 = _mm512_permutex2var_ps(accum512_2, idx_base_0, accum512_3);
accum512_11 = _mm512_permutex2var_ps(accum512_2, idx_base_1, accum512_3);
accum512_12 = _mm512_permutex2var_ps(accum512_4, idx_base_0, accum512_5);
accum512_13 = _mm512_permutex2var_ps(accum512_4, idx_base_1, accum512_5);
accum512_14 = _mm512_permutex2var_ps(accum512_6, idx_base_0, accum512_7);
accum512_15 = _mm512_permutex2var_ps(accum512_6, idx_base_1, accum512_7);
STORE16_COMPLETE_RESULT(accum512_8, y+idx_m+0)
STORE16_COMPLETE_RESULT(accum512_9, y+idx_m+16)
STORE16_COMPLETE_RESULT(accum512_10, y+idx_m+32)
STORE16_COMPLETE_RESULT(accum512_11, y+idx_m+48)
STORE16_COMPLETE_RESULT(accum512_12, y+idx_m+64)
STORE16_COMPLETE_RESULT(accum512_13, y+idx_m+80)
STORE16_COMPLETE_RESULT(accum512_14, y+idx_m+96)
STORE16_COMPLETE_RESULT(accum512_15, y+idx_m+112)
}
for (BLASLONG idx_m = tag_m_128x; idx_m < tag_m_32x; idx_m+=32) {
accum512_0 = _mm512_setzero_ps();
accum512_1 = _mm512_setzero_ps();
for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
xArray_0 = _mm512_set1_epi16(x[idx_n]);
BF16_MATRIX_LOAD_1x32(matrixArray_seed_0, a, lda, idx_n, idx_m)
matrixArray_0 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_0);
matrixArray_1 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_0);
BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
BF16_DOT_1x32(accum512_1, matrixArray_1, xArray_0)
}
accum512_8 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
accum512_9 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
STORE16_COMPLETE_RESULT(accum512_8, y+idx_m+0)
STORE16_COMPLETE_RESULT(accum512_9, y+idx_m+16)
}
if (tag_m_32x != m) {
unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(m&31)));
__mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
unsigned short store_tail_mask_value = (((unsigned int)0xffff) >> (16-(m&15)));
__mmask16 store_tail_mask = *((__mmask16*) &store_tail_mask_value);  // 16-bit mask matching the FP32 stores
accum512_0 = _mm512_setzero_ps();
accum512_1 = _mm512_setzero_ps();
for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
xArray_0 = _mm512_set1_epi16(x[idx_n]);
BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_seed_0, a, lda, idx_n, tag_m_32x, tail_mask)
matrixArray_0 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_0);
matrixArray_1 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_0);
BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
BF16_DOT_1x32(accum512_1, matrixArray_1, xArray_0)
}
accum512_8 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
accum512_9 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
if ((m-tag_m_32x) >= 16) {    // a remainder of exactly 16 takes the full store plus an empty masked store
STORE16_COMPLETE_RESULT(accum512_8, y+tag_m_32x+0)
STORE16_MASK_COMPLETE_RESULT(accum512_9, y+tag_m_32x+16, store_tail_mask)
} else {
STORE16_MASK_COMPLETE_RESULT(accum512_8, y+tag_m_32x+0, store_tail_mask)
}
}
return 0;
}
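The BF16_DOT_1x32 macro used in the loops above wraps _mm512_dpbf16_ps, which multiplies adjacent bf16 pairs and accumulates each pair into one FP32 lane; that is why the kernel first blends even/odd elements against ZERO512 (so a lane never mixes two different rows) and re-interleaves the lanes with _mm512_permutex2var_ps at the end. A minimal standalone check of the instruction's pairwise semantics (hypothetical; assumes a compiler targeting -mavx512bf16):

#include <immintrin.h>
#include <stdio.h>

int main(void) {
    float a[32], b[32];
    for (int i = 0; i < 32; i++) { a[i] = 1.0f; b[i] = (float)i; }  /* exact in bf16 */

    /* _mm512_cvtne2ps_pbh packs its second argument into the low 16 lanes */
    __m512bh abh = _mm512_cvtne2ps_pbh(_mm512_loadu_ps(a + 16), _mm512_loadu_ps(a));
    __m512bh bbh = _mm512_cvtne2ps_pbh(_mm512_loadu_ps(b + 16), _mm512_loadu_ps(b));

    /* Each FP32 lane k accumulates a[2k]*b[2k] + a[2k+1]*b[2k+1] */
    __m512 acc = _mm512_dpbf16_ps(_mm512_setzero_ps(), abh, bbh);

    printf("%g\n", _mm512_reduce_add_ps(acc));   /* 0+1+...+31 = 496 */
    return 0;
}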

142
kernel/x86_64/sbgemv_t.c Normal file
View File

@ -0,0 +1,142 @@
/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include "common.h"
#if defined (COOPERLAKE)
#include "sbgemv_t_microk_cooperlake.c"
#endif
#define ALIGN64_ALLOC(alloc_size, TYPE, ptr_align, ptr) \
ptr = (TYPE *) malloc(sizeof(TYPE)*alloc_size + 63); \
ptr_align = ((int)(((uintptr_t)ptr & (uintptr_t)0x3F))!=0) ? (TYPE *)((char *)ptr + (64 - (int)((uintptr_t)ptr & (uintptr_t)0x3F))) : ptr
#define ALIGN64_FREE(ptr) \
free(ptr)
#ifndef HAVE_SBGEMV_T_ACCL_KERNEL
/* Generic fallback kernel: repack A densely, widen A and x to FP32, then run the transposed GEMV in float */
static void sbgemv_kernel_t(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
{
BLASLONG offset_lda, offset_n;
float accum = 0.0;
bfloat16 * a_bf16 = malloc(sizeof(bfloat16)*m*n);
float * a_fp32 = malloc(sizeof(float)*m*n);
float * x_fp32 = malloc(sizeof(float)*n);
for (BLASLONG i=0; i<m; i++) {
offset_lda = lda * i;
offset_n = n * i;
for (BLASLONG j=0; j<n; j++) {
a_bf16[offset_n + j] = a[offset_lda + j];
}
}
SBF16TOS_K(n, x, 1, x_fp32, 1);
SBF16TOS_K(m*n, a_bf16, 1, a_fp32, 1);
for (BLASLONG i=0; i<m; i++) {
offset_n = n * i;
accum = 0.0;
for (BLASLONG j=0; j<n; j++) {
accum += a_fp32[offset_n + j] * x_fp32[j];
}
if (beta == ZERO) {
y[i] = alpha * accum;
} else {
y[i] = alpha * accum + beta * y[i];
}
}
free(a_bf16);
free(a_fp32);
free(x_fp32);
}
#endif
static void bf16_compress_vector(BLASLONG n, bfloat16 * src, bfloat16 * target, BLASLONG inc)
{
for(BLASLONG i=0; i<n; i++) {
target[i] = src[i*inc];
}
}
static void fp32_compress_vector(BLASLONG n, float * src, float * target, BLASLONG inc)
{
for(BLASLONG i=0; i<n; i++) {
target[i] = src[i*inc];
}
}
static void fp32_expand_vector(BLASLONG n, float * src, float * target, BLASLONG inc)
{
for(BLASLONG i=0; i<n; i++) {
target[i*inc] = src[i];
}
}
int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float * y, BLASLONG incy)
{
if ( m < 1 || n < 1) return(0);
bfloat16 * xbuffer_align = x;
float * ybuffer_align = y;
bfloat16 * xbuffer = NULL;
float * ybuffer = NULL;
// Swap m and n: for the transposed kernel, m becomes the output (y) length and n the dot-product length
BLASLONG t = m;
m = n;
n = t;
if (incx != 1) {
ALIGN64_ALLOC(n, bfloat16, xbuffer_align, xbuffer);
bf16_compress_vector(n, x, xbuffer_align, incx);
}
if (incy != 1) {
ALIGN64_ALLOC(m, float, ybuffer_align, ybuffer);
if (beta != ZERO) {
fp32_compress_vector(m, y, ybuffer_align, incy);
}
}
sbgemv_kernel_t(m, n, alpha, a, lda, xbuffer_align, beta, ybuffer_align);
if (incy != 1) {
fp32_expand_vector(m, ybuffer_align, y, incy);
ALIGN64_FREE(ybuffer);
}
if (incx != 1) {
ALIGN64_FREE(xbuffer);
}
return(0);
}
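A companion to the earlier non-transposed example: the same hypothetical caller exercising the transposed path (same BUILD_BFLOAT16 assumption; note that with CblasTrans, x has length m and y has length n):

#include <stdio.h>
#include <cblas.h>

int main(void) {
    const blasint m = 3, n = 2;
    float a_fp32[6] = {1, 2, 3, 4, 5, 6};   /* column-major 3x2: columns {1,2,3} and {4,5,6} */
    float x_fp32[3] = {1, 1, 1};
    bfloat16 a[6], x[3];
    float y[2] = {0, 0};

    cblas_sbstobf16(m * n, a_fp32, 1, a, 1);
    cblas_sbstobf16(m, x_fp32, 1, x, 1);

    /* y = A^T * x = {1+2+3, 4+5+6} = {6, 15} */
    cblas_sbgemv(CblasColMajor, CblasTrans, m, n, 1.0f, a, m, x, 1, 0.0f, y, 1);

    printf("y = {%g, %g}\n", y[0], y[1]);
    return 0;
}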

202
kernel/x86_64/sbgemv_t_microk_cooperlake.c Normal file
View File

@ -0,0 +1,202 @@
/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
/* need a new enough compiler (GCC >= 10 or clang >= 9) for AVX512-BF16 support */
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
#define HAVE_SBGEMV_T_ACCL_KERNEL 1
// Define micro kernels for the ALPHA != ONE && BETA != ZERO && BETA != ONE scenario
#undef ZERO_BETA
#undef ONE_BETA
#undef ONE_ALPHA
#include "sbgemv_t_microk_cooperlake_template.c"
// Define micro kernels for the ALPHA != ONE && BETA == ONE scenario
#undef ZERO_BETA
#define ONE_BETA 1
#undef ONE_ALPHA
#include "sbgemv_t_microk_cooperlake_template.c"
// Define micro kernels for the ALPHA != ONE && BETA == ZERO scenario
#define ZERO_BETA 1
#undef ONE_ALPHA
#include "sbgemv_t_microk_cooperlake_template.c"
// Define micro kernels for the ALPHA == ONE && BETA == ZERO scenario
#define ZERO_BETA 1
#define ONE_ALPHA 1
#include "sbgemv_t_microk_cooperlake_template.c"
static int sbgemv_kernel_t(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
{
if (beta == ZERO) { // BETA == 0.0, no need to accumulate the original Y data
if (alpha == ONE) { // ALPHA == 1.0, no need to multiply by ALPHA
if (n > 127) {
sbgemv_kernel_1x128_lda_direct(m, n, alpha, a, lda, x, y);
} else if (n > 32) {
sbgemv_kernel_8x32_lda_direct(m, n, alpha, a, lda, x, y);
} else {
if (n > 16) {
sbgemv_kernel_8x16p_lda(m, n, alpha, a, lda, x, y);
} else {
if (lda == n) {
switch(n) {
case 1: sbgemv_kernel_32x1 (m, alpha, a, x, y); break;
case 2: sbgemv_kernel_32x2 (m, alpha, a, x, y); break;
case 3: sbgemv_kernel_32x3 (m, alpha, a, x, y); break;
case 4: sbgemv_kernel_16x4 (m, alpha, a, x, y); break;
case 5: sbgemv_kernel_30x5 (m, alpha, a, x, y); break;
case 6: sbgemv_kernel_16x6 (m, alpha, a, x, y); break;
case 7: sbgemv_kernel_16x7 (m, alpha, a, x, y); break;
case 8: sbgemv_kernel_16x8 (m, alpha, a, x, y); break;
case 9: sbgemv_kernel_14x9 (m, alpha, a, x, y); break;
case 10: sbgemv_kernel_12x10(m, alpha, a, x, y); break;
case 11: sbgemv_kernel_15x11(m, alpha, a, x, y); break;
case 12: sbgemv_kernel_15x12(m, alpha, a, x, y); break;
case 13: sbgemv_kernel_16x13(m, alpha, a, x, y); break;
case 14: sbgemv_kernel_16x14(m, alpha, a, x, y); break;
case 15: sbgemv_kernel_16x15(m, alpha, a, x, y); break;
case 16: sbgemv_kernel_16x16(m, alpha, a, x, y); break;
default: break;
}
} else {
sbgemv_kernel_8x16m_lda(m, n, alpha, a, lda, x, y);
}
}
}
} else { // ALPHA != 1.0, need to multiply by ALPHA
if (n > 127) {
sbgemv_kernel_1x128_lda_direct_alpha(m, n, alpha, a, lda, x, y);
} else if (n > 32) {
sbgemv_kernel_8x32_lda_direct_alpha(m, n, alpha, a, lda, x, y);
} else {
if (n > 16) {
sbgemv_kernel_8x16p_lda_alpha(m, n, alpha, a, lda, x, y);
} else {
if (lda == n) {
switch(n) {
case 1: sbgemv_kernel_32x1_alpha (m, alpha, a, x, y); break;
case 2: sbgemv_kernel_32x2_alpha (m, alpha, a, x, y); break;
case 3: sbgemv_kernel_32x3_alpha (m, alpha, a, x, y); break;
case 4: sbgemv_kernel_16x4_alpha (m, alpha, a, x, y); break;
case 5: sbgemv_kernel_30x5_alpha (m, alpha, a, x, y); break;
case 6: sbgemv_kernel_16x6_alpha (m, alpha, a, x, y); break;
case 7: sbgemv_kernel_16x7_alpha (m, alpha, a, x, y); break;
case 8: sbgemv_kernel_16x8_alpha (m, alpha, a, x, y); break;
case 9: sbgemv_kernel_14x9_alpha (m, alpha, a, x, y); break;
case 10: sbgemv_kernel_12x10_alpha(m, alpha, a, x, y); break;
case 11: sbgemv_kernel_15x11_alpha(m, alpha, a, x, y); break;
case 12: sbgemv_kernel_15x12_alpha(m, alpha, a, x, y); break;
case 13: sbgemv_kernel_16x13_alpha(m, alpha, a, x, y); break;
case 14: sbgemv_kernel_16x14_alpha(m, alpha, a, x, y); break;
case 15: sbgemv_kernel_16x15_alpha(m, alpha, a, x, y); break;
case 16: sbgemv_kernel_16x16_alpha(m, alpha, a, x, y); break;
default: break;
}
} else {
sbgemv_kernel_8x16m_lda_alpha(m, n, alpha, a, lda, x, y);
}
}
}
}
} else { // BETA != 0.0, need to accumulate the original Y data no matter what ALPHA is
if (beta == ONE) {
if (n > 127) {
sbgemv_kernel_1x128_lda_direct_alpha_one(m, n, alpha, a, lda, x, beta, y);
} else if (n > 32) {
sbgemv_kernel_8x32_lda_direct_alpha_one(m, n, alpha, a, lda, x, beta, y);
} else {
if (n > 16) {
sbgemv_kernel_8x16p_lda_alpha_one(m, n, alpha, a, lda, x, beta, y);
} else {
if (lda == n) {
switch(n) {
case 1: sbgemv_kernel_32x1_alpha_one (m, alpha, a, x, beta, y); break;
case 2: sbgemv_kernel_32x2_alpha_one (m, alpha, a, x, beta, y); break;
case 3: sbgemv_kernel_32x3_alpha_one (m, alpha, a, x, beta, y); break;
case 4: sbgemv_kernel_16x4_alpha_one (m, alpha, a, x, beta, y); break;
case 5: sbgemv_kernel_30x5_alpha_one (m, alpha, a, x, beta, y); break;
case 6: sbgemv_kernel_16x6_alpha_one (m, alpha, a, x, beta, y); break;
case 7: sbgemv_kernel_16x7_alpha_one (m, alpha, a, x, beta, y); break;
case 8: sbgemv_kernel_16x8_alpha_one (m, alpha, a, x, beta, y); break;
case 9: sbgemv_kernel_14x9_alpha_one (m, alpha, a, x, beta, y); break;
case 10: sbgemv_kernel_12x10_alpha_one(m, alpha, a, x, beta, y); break;
case 11: sbgemv_kernel_15x11_alpha_one(m, alpha, a, x, beta, y); break;
case 12: sbgemv_kernel_15x12_alpha_one(m, alpha, a, x, beta, y); break;
case 13: sbgemv_kernel_16x13_alpha_one(m, alpha, a, x, beta, y); break;
case 14: sbgemv_kernel_16x14_alpha_one(m, alpha, a, x, beta, y); break;
case 15: sbgemv_kernel_16x15_alpha_one(m, alpha, a, x, beta, y); break;
case 16: sbgemv_kernel_16x16_alpha_one(m, alpha, a, x, beta, y); break;
default: break;
}
} else {
sbgemv_kernel_8x16m_lda_alpha_one(m, n, alpha, a, lda, x, beta, y);
}
}
}
} else {
if (n > 127) {
sbgemv_kernel_1x128_lda_direct_alpha_beta(m, n, alpha, a, lda, x, beta, y);
} else if (n > 32) {
sbgemv_kernel_8x32_lda_direct_alpha_beta(m, n, alpha, a, lda, x, beta, y);
} else {
if (n > 16) {
sbgemv_kernel_8x16p_lda_alpha_beta(m, n, alpha, a, lda, x, beta, y);
} else {
if (lda == n) {
switch(n) {
case 1: sbgemv_kernel_32x1_alpha_beta (m, alpha, a, x, beta, y); break;
case 2: sbgemv_kernel_32x2_alpha_beta (m, alpha, a, x, beta, y); break;
case 3: sbgemv_kernel_32x3_alpha_beta (m, alpha, a, x, beta, y); break;
case 4: sbgemv_kernel_16x4_alpha_beta (m, alpha, a, x, beta, y); break;
case 5: sbgemv_kernel_30x5_alpha_beta (m, alpha, a, x, beta, y); break;
case 6: sbgemv_kernel_16x6_alpha_beta (m, alpha, a, x, beta, y); break;
case 7: sbgemv_kernel_16x7_alpha_beta (m, alpha, a, x, beta, y); break;
case 8: sbgemv_kernel_16x8_alpha_beta (m, alpha, a, x, beta, y); break;
case 9: sbgemv_kernel_14x9_alpha_beta (m, alpha, a, x, beta, y); break;
case 10: sbgemv_kernel_12x10_alpha_beta(m, alpha, a, x, beta, y); break;
case 11: sbgemv_kernel_15x11_alpha_beta(m, alpha, a, x, beta, y); break;
case 12: sbgemv_kernel_15x12_alpha_beta(m, alpha, a, x, beta, y); break;
case 13: sbgemv_kernel_16x13_alpha_beta(m, alpha, a, x, beta, y); break;
case 14: sbgemv_kernel_16x14_alpha_beta(m, alpha, a, x, beta, y); break;
case 15: sbgemv_kernel_16x15_alpha_beta(m, alpha, a, x, beta, y); break;
case 16: sbgemv_kernel_16x16_alpha_beta(m, alpha, a, x, beta, y); break;
default: break;
}
} else {
sbgemv_kernel_8x16m_lda_alpha_beta(m, n, alpha, a, lda, x, beta, y);
}
}
}
}
}
return 0;
}
#endif

File diff suppressed because it is too large