Vectorize BF16 GEMV (VSX & MMA). Use GEMM_GEMV_FORWARD_BF16 (for Power).
parent 3184b7f209
commit 36bd3eeddf

@@ -282,15 +282,19 @@ GEMM_GEMV_FORWARD = 1
endif
ifeq ($(ARCH), power)
GEMM_GEMV_FORWARD = 1
GEMM_GEMV_FORWARD_BF16 = 1
endif

ifeq ($(SMALL_MATRIX_OPT), 1)
CCOMMON_OPT += -DSMALL_MATRIX_OPT
endif
ifeq ($(GEMM_GEMV_FORWARD), 1)
ifneq ($(ONLY_CBLAS), 1)
ifeq ($(GEMM_GEMV_FORWARD), 1)
CCOMMON_OPT += -DGEMM_GEMV_FORWARD
endif
ifeq ($(GEMM_GEMV_FORWARD_BF16), 1)
CCOMMON_OPT += -DGEMM_GEMV_FORWARD_BF16
endif
endif

# This operation is expensive, so execution should be once.

@@ -398,6 +398,9 @@ endif ()
if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS)
set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD")
endif ()
if (GEMM_GEMV_FORWARD_BF16 AND NOT ONLY_CBLAS)
set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD_BF16")
endif ()
if (SMALL_MATRIX_OPT)
set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT")
endif ()

@@ -498,7 +498,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
          args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
#endif

#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && !defined(BFLOAT16)
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16))
  // Check if we can convert GEMM -> GEMV
  if (args.k != 0) {
    if (args.n == 1) {

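The only functional change in this hunk is the widened guard: BFLOAT16 builds can now take the existing GEMM -> GEMV shortcut as well, provided GEMM_GEMV_FORWARD_BF16 is defined. As a plain-C sketch of why that forwarding is legitimate (illustration only, not OpenBLAS code; the reference routines below are made up for the example), a column-major GEMM whose B has a single column reduces exactly to a GEMV:

    #include <stdio.h>

    /* Reference column-major GEMM: C = alpha*A*B + beta*C, A is m x k, B is k x n. */
    static void ref_gemm(int m, int n, int k, float alpha, const float *A, int lda,
                         const float *B, int ldb, float beta, float *C, int ldc) {
      for (int j = 0; j < n; j++)
        for (int i = 0; i < m; i++) {
          float s = 0.0f;
          for (int p = 0; p < k; p++) s += A[i + p * lda] * B[p + j * ldb];
          C[i + j * ldc] = alpha * s + beta * C[i + j * ldc];
        }
    }

    /* Reference GEMV: y = alpha*A*x + beta*y (non-transposed, column-major A). */
    static void ref_gemv(int m, int k, float alpha, const float *A, int lda,
                         const float *x, float beta, float *y) {
      for (int i = 0; i < m; i++) {
        float s = 0.0f;
        for (int p = 0; p < k; p++) s += A[i + p * lda] * x[p];
        y[i] = alpha * s + beta * y[i];
      }
    }

    int main(void) {
      float A[6] = {1, 2, 3, 4, 5, 6};   /* 2 x 3, column-major */
      float x[3] = {1, -1, 2};           /* B with n == 1 is just a vector */
      float c1[2] = {7, 8}, c2[2] = {7, 8};

      ref_gemm(2, 1, 3, 0.5f, A, 2, x, 3, 2.0f, c1, 2);  /* GEMM with n == 1 */
      ref_gemv(2, 3, 0.5f, A, 2, x, 2.0f, c2);           /* forwarded GEMV   */

      printf("gemm: %g %g\ngemv: %g %g\n", c1[0], c1[1], c2[0], c2[1]);
      return 0;
    }

Both calls print 18 21, which is the point of the forwarding check above: when args.n == 1 the cheaper GEMV path computes the same result.
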
@@ -236,11 +236,13 @@ ZSWAPKERNEL = zswap.c
#

SGEMVNKERNEL = sgemv_n.c
SBGEMVNKERNEL = sbgemv_n_power10.c
DGEMVNKERNEL = dgemv_n_power10.c
CGEMVNKERNEL = cgemv_n.c
ZGEMVNKERNEL = zgemv_n_power10.c
#
SGEMVTKERNEL = sgemv_t.c
SBGEMVTKERNEL = sbgemv_t_power10.c
DGEMVTKERNEL = dgemv_t_power10.c
CGEMVTKERNEL = cgemv_t.c
ZGEMVTKERNEL = zgemv_t_4.c

@@ -257,11 +257,13 @@ ZSWAPKERNEL = zswap.c
#

SGEMVNKERNEL = sgemv_n.c
SBGEMVNKERNEL = sbgemv_n_vsx.c
DGEMVNKERNEL = dgemv_n.c
CGEMVNKERNEL = cgemv_n.c
ZGEMVNKERNEL = zgemv_n_4.c
#
SGEMVTKERNEL = sgemv_t.c
SBGEMVTKERNEL = sbgemv_t_vsx.c
DGEMVTKERNEL = dgemv_t.c
CGEMVTKERNEL = cgemv_t.c
ZGEMVTKERNEL = zgemv_t_4.c

@@ -181,11 +181,13 @@ ZSWAPKERNEL = zswap.c
#

SGEMVNKERNEL = sgemv_n.c
SBGEMVNKERNEL = sbgemv_n_vsx.c
DGEMVNKERNEL = dgemv_n.c
CGEMVNKERNEL = cgemv_n.c
ZGEMVNKERNEL = zgemv_n_4.c
#
SGEMVTKERNEL = sgemv_t.c
SBGEMVTKERNEL = sbgemv_t_vsx.c
DGEMVTKERNEL = dgemv_t.c
CGEMVTKERNEL = cgemv_t.c
ZGEMVTKERNEL = zgemv_t_4.c

@@ -0,0 +1,153 @@
#ifndef GEMM_COMMON_C
#define GEMM_COMMON_C
#include "common.h"

#include <altivec.h>
#include <inttypes.h>

#define NBMAX 4096

#define FORCEINLINE inline __attribute__((always_inline))

#ifdef _ARCH_PWR10
#ifdef __has_builtin
#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
#if !__has_builtin(__builtin_vsx_disassemble_pair)
#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair
#endif
#endif

#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v1, v0)
#else
#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v0, v1)
#endif

#define USE_VECTOR_PAIRS
#endif

typedef __vector IFLOAT vec_bf16;
typedef __vector FLOAT vec_f32;
typedef __vector unsigned char vec_uc8;

FORCEINLINE vec_uc8 vec_load_vec(void *src)
{
  return vec_xl(0, (unsigned char *)(src));
}

FORCEINLINE void vec_load_pair(vec_f32 *dst, vec_f32 *src)
{
#ifdef USE_VECTOR_PAIRS
  __vector_pair vy0p;
#ifdef __clang__
  vy0p = __builtin_vsx_lxvp(0L, (const __vector_pair *)(src));
#else
  vy0p = *(__vector_pair *)(src);
#endif
  __builtin_vsx_disassemble_pair((void *)(dst), &vy0p);
#else
  dst[0] = src[0];
  dst[1] = src[1];
#endif
}

FORCEINLINE void vec_store_pair(vec_f32 *dst, vec_f32 *src)
{
#ifdef USE_VECTOR_PAIRS
  __vector_pair vy0p;
  __builtin_vsx_assemble_pair2(&vy0p, (vec_uc8)src[1], (vec_uc8)src[0]);
#ifdef __clang__
  __builtin_vsx_stxvp(vy0p, 0L, (__vector_pair *)(dst));
#else
  *(__vector_pair *)(dst) = vy0p;
#endif
#else
  dst[0] = src[0];
  dst[1] = src[1];
#endif
}

FORCEINLINE vec_bf16 vec_loadN(void *src, BLASLONG n)
{
  IFLOAT *src2 = (IFLOAT *)(src);
#ifdef _ARCH_PWR9
  return vec_xl_len(src2, n * sizeof(IFLOAT));
#else
  __attribute__((aligned(16))) IFLOAT data[sizeof(vec_bf16) / sizeof(IFLOAT)];
  memset(data, 0, sizeof(vec_bf16));
  if (n & 4) {
    memcpy(data, src2, sizeof(uint64_t));
  }
  if (n & 2) {
    BLASLONG n4 = n & 4;
    memcpy(data + n4, src2 + n4, sizeof(uint32_t));
  }
  if (n & 1) {
    BLASLONG n6 = n & 6;
    data[n6] = src2[n6];
  }
  return (vec_bf16)vec_load_vec(data);
#endif
}

FORCEINLINE vec_f32 vec_loadN_f32(void *src, BLASLONG n)
{
#ifndef _ARCH_PWR9
  if (n & 4) {
    return (vec_f32)vec_load_vec(src);
  }
#endif
  return (vec_f32)vec_loadN(src, n * (sizeof(FLOAT) / sizeof(IFLOAT)));
}

FORCEINLINE void vec_loadN2_f32(vec_f32 *data, vec_f32 *src, BLASLONG n)
{
  data[0] = src[0];
  data[1] = vec_loadN_f32(&src[1], n);
}

FORCEINLINE void vec_storeN(vec_bf16 data, void *dst, BLASLONG n)
{
  IFLOAT *dst2 = (IFLOAT *)(dst);
#ifdef _ARCH_PWR9
  vec_xst_len(data, dst2, n * sizeof(IFLOAT));
#else
  if (n & 8) {
    vec_xst(data, 0, dst2);
    return;
  }
  __attribute__((aligned(16))) IFLOAT data2[sizeof(vec_f32) / sizeof(IFLOAT)];
  vec_xst(data, 0, data2);
  if (n & 4) {
    memcpy(dst2, data2, sizeof(uint64_t));
  }
  if (n & 2) {
    BLASLONG n4 = n & 4;
    memcpy(dst2 + n4, data2 + n4, sizeof(uint32_t));
  }
  if (n & 1) {
    BLASLONG n6 = n & 6;
    dst2[n6] = data2[n6];
  }
#endif
}

FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n)
{
#ifndef _ARCH_PWR9
  if (n & 4) {
    vec_xst(data, 0, (FLOAT *)dst);
    return;
  }
#endif
  return vec_storeN((vec_bf16)data, dst, n * (sizeof(FLOAT) / sizeof(IFLOAT)));
}

FORCEINLINE void vec_storeN2_f32(vec_f32 *data, vec_f32 *dst, BLASLONG n)
{
  dst[0] = data[0];
  vec_storeN_f32(data[1], &dst[1], n);
}
#endif

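Taken together, vec_loadN/vec_storeN give the BF16 kernels a zero-padded partial load and a clipped partial store for ragged tails: vec_xl_len/vec_xst_len on POWER9 and later, the 4/2/1-element memcpy cascade shown above otherwise. A minimal scalar model of the load side, purely for illustration (plain C, no VSX; loadN_model is a hypothetical name, not part of the file above):

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    /* Scalar model of vec_loadN: copy only the first n 16-bit (bf16) elements
     * into a full 8-element vector image and leave the remainder zeroed. */
    static void loadN_model(const uint16_t *src, uint16_t dst[8], long n) {
      memset(dst, 0, 8 * sizeof(uint16_t));           /* zero the whole vector */
      memcpy(dst, src, (size_t)n * sizeof(uint16_t)); /* then copy n elements  */
    }

    int main(void) {
      uint16_t src[8] = {1, 2, 3, 4, 5, 6, 7, 8}, dst[8];
      loadN_model(src, dst, 5);                       /* a tail of 5 elements  */
      for (int i = 0; i < 8; i++) printf("%u ", dst[i]); /* 1 2 3 4 5 0 0 0    */
      printf("\n");
      return 0;
    }

The zero-fill convention is what allows the tail iterations of the GEMV loops to reuse the full-width multiply-accumulate paths without reading past the end of a row; the store helpers then write back only the n valid elements.
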
@ -0,0 +1,223 @@
|
|||
/***************************************************************************
|
||||
Copyright (c) 2024, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef SBGEMV_COMMON_C
|
||||
#define SBGEMV_COMMON_C
|
||||
#include "gemm_common.c"
|
||||
|
||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
||||
#define BF16_HI(data, zero) (vec_f32)vec_mergeh(data, zero)
|
||||
#define BF16_LO(data, zero) (vec_f32)vec_mergel(data, zero)
|
||||
#else
|
||||
#define BF16_HI(data, zero) (vec_f32)vec_mergeh(zero, data)
|
||||
#define BF16_LO(data, zero) (vec_f32)vec_mergel(zero, data)
|
||||
#endif
|
||||
|
||||
FORCEINLINE vec_f32 vec_loadNHi(void *src, BLASLONG n, vec_bf16 zero)
|
||||
{
|
||||
vec_bf16 data = vec_loadN(src, n);
|
||||
return BF16_HI(data, zero);
|
||||
}
|
||||
|
||||
FORCEINLINE vec_f32 vec_mult(vec_f32 *inp, vec_bf16 in0, vec_bf16 zero)
|
||||
{
|
||||
vec_f32 v_in00 = BF16_HI(in0, zero);
|
||||
vec_f32 v_in01 = BF16_LO(in0, zero);
|
||||
|
||||
return (inp[0] * v_in00) + (inp[1] * v_in01);
|
||||
}
|
||||
|
||||
FORCEINLINE vec_f32 vec_load_mult(vec_bf16 *in, vec_f32 *inp, vec_bf16 zero)
|
||||
{
|
||||
vec_bf16 in0 = (vec_bf16)vec_load_vec(in);
|
||||
|
||||
return vec_mult(inp, in0, zero);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_vec2(vec_bf16 *in, vec_f32 *v_x0, vec_bf16 zero)
|
||||
{
|
||||
vec_bf16 inp = (vec_bf16)vec_load_vec(in);
|
||||
|
||||
v_x0[0] = BF16_HI(inp, zero);
|
||||
v_x0[1] = BF16_LO(inp, zero);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult2(vec_f32 v_x0, vec_bf16 in0, vec_bf16 zero, vec_f32 *vy0)
|
||||
{
|
||||
vec_f32 v_in00 = BF16_HI(in0, zero);
|
||||
vec_f32 v_in01 = BF16_LO(in0, zero);
|
||||
|
||||
vy0[0] += (v_x0 * v_in00);
|
||||
vy0[1] += (v_x0 * v_in01);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult2(vec_f32 v_x0, vec_bf16 *in, vec_bf16 zero, vec_f32 *vy0)
|
||||
{
|
||||
vec_bf16 in0 = (vec_bf16)vec_load_vec(in);
|
||||
|
||||
vec_mult2(v_x0, in0, zero, vy0);
|
||||
}
|
||||
|
||||
FORCEINLINE vec_f32 vec_loadN_mult(vec_bf16 *in, vec_f32 *inp, BLASLONG n, vec_bf16 zero)
|
||||
{
|
||||
vec_bf16 in0 = vec_loadN(in, n);
|
||||
|
||||
return vec_mult(inp, in0, zero);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_loadN_vec2(vec_bf16 *in, vec_f32 *v_x0, BLASLONG n, vec_bf16 zero)
|
||||
{
|
||||
vec_bf16 inp = vec_loadN(in, n);
|
||||
|
||||
v_x0[0] = BF16_HI(inp, zero);
|
||||
v_x0[1] = BF16_LO(inp, zero);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_loadN_mult2(vec_f32 v_x0, vec_bf16 *in, BLASLONG n, vec_bf16 zero, vec_f32 *vy0)
|
||||
{
|
||||
vec_bf16 in0 = vec_loadN(in, n);
|
||||
|
||||
vec_mult2(v_x0, in0, zero, vy0);
|
||||
}
|
||||
|
||||
FORCEINLINE vec_f32 vec_loadNHi_mult(vec_bf16 *in, vec_f32 v_inp0, BLASLONG n, vec_bf16 zero)
|
||||
{
|
||||
vec_f32 v_in00 = vec_loadNHi(in, n, zero);
|
||||
|
||||
return (v_inp0 * v_in00);
|
||||
}
|
||||
|
||||
FORCEINLINE void copy_x(BLASLONG n, IFLOAT *src, IFLOAT *dest, BLASLONG inc_src)
|
||||
{
|
||||
for (BLASLONG i = 0; i < n; i++) {
|
||||
*dest++ = *src;
|
||||
src += inc_src;
|
||||
}
|
||||
}
|
||||
|
||||
FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta)
|
||||
{
|
||||
if (beta == (FLOAT)0) {
|
||||
memset(dest, 0, n * sizeof(FLOAT));
|
||||
} else if (beta == (FLOAT)1) {
|
||||
for (BLASLONG i = 0; i < n; i++) {
|
||||
*dest++ = *src;
|
||||
src += inc_src;
|
||||
}
|
||||
} else {
|
||||
for (BLASLONG i = 0; i < n; i++) {
|
||||
*dest++ = *src * beta;
|
||||
src += inc_src;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FORCEINLINE void move_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest)
|
||||
{
|
||||
for (BLASLONG i = 0; i < n; i++) {
|
||||
*dest = *src++;
|
||||
dest += inc_dest;
|
||||
}
|
||||
}
|
||||
|
||||
FORCEINLINE void copy_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta)
|
||||
{
|
||||
if (beta == (FLOAT)0) {
|
||||
move_y(n, src, dest, inc_src);
|
||||
} else if (beta == (FLOAT)1) {
|
||||
for (BLASLONG i = 0; i < n; i++) {
|
||||
*dest += *src++;
|
||||
dest += inc_src;
|
||||
}
|
||||
} else {
|
||||
for (BLASLONG i = 0; i < n; i++) {
|
||||
*dest = *src++ + (beta * *dest);
|
||||
dest += inc_src;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta)
|
||||
{
|
||||
if (beta == (FLOAT)0) {
|
||||
memset(output_vector, 0, sizeof(FLOAT) * n);
|
||||
} else if (beta == (FLOAT)1) {
|
||||
if (output_vector != input_vector) {
|
||||
memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
|
||||
}
|
||||
} else {
|
||||
vec_f32 b = { beta, beta, beta, beta };
|
||||
|
||||
vec_f32 *in = (vec_f32 *)input_vector;
|
||||
vec_f32 *out = (vec_f32 *)output_vector;
|
||||
|
||||
BLASLONG n8 = n / 8;
|
||||
BLASLONG i = 0;
|
||||
vec_f32 v_inp0[2];
|
||||
|
||||
for (; i + 4 <= n8; i += 4) {
|
||||
vec_f32 v_inp1[2], v_inp2[2], v_inp3[2];
|
||||
vec_load_pair(v_inp0, &in[(i * 2) + 0]);
|
||||
vec_load_pair(v_inp1, &in[(i * 2) + 2]);
|
||||
vec_load_pair(v_inp2, &in[(i * 2) + 4]);
|
||||
vec_load_pair(v_inp3, &in[(i * 2) + 6]);
|
||||
v_inp0[0] *= b;
|
||||
v_inp0[1] *= b;
|
||||
v_inp1[0] *= b;
|
||||
v_inp1[1] *= b;
|
||||
v_inp2[0] *= b;
|
||||
v_inp2[1] *= b;
|
||||
v_inp3[0] *= b;
|
||||
v_inp3[1] *= b;
|
||||
vec_store_pair(&out[(i * 2) + 0], v_inp0);
|
||||
vec_store_pair(&out[(i * 2) + 2], v_inp1);
|
||||
vec_store_pair(&out[(i * 2) + 4], v_inp2);
|
||||
vec_store_pair(&out[(i * 2) + 6], v_inp3);
|
||||
}
|
||||
|
||||
for (; i < n8; i++) {
|
||||
vec_load_pair(v_inp0, &in[(i * 2) + 0]);
|
||||
v_inp0[0] *= b;
|
||||
v_inp0[1] *= b;
|
||||
vec_store_pair(&out[(i * 2) + 0], v_inp0);
|
||||
}
|
||||
|
||||
n &= 7;
|
||||
if (n > 4) {
|
||||
BLASLONG n3 = n & 3;
|
||||
vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
|
||||
v_inp0[0] *= b;
|
||||
v_inp0[1] *= b;
|
||||
vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
|
||||
} else if (n) {
|
||||
v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
|
||||
v_inp0[0] *= b;
|
||||
vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
|
@ -0,0 +1,629 @@
|
|||
/***************************************************************************
|
||||
Copyright (c) 2024, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef SBGEMV_COMMON_MMA_C
|
||||
#define SBGEMV_COMMON_MMA_C
|
||||
#include "sbgemv_common.c"
|
||||
|
||||
#if defined(_AIX) || defined(__clang__)
|
||||
#define USE_MERGE_MMA
|
||||
#endif
|
||||
|
||||
FORCEINLINE void vec_load_pair2(vec_bf16 *in0, vec_bf16 *in)
|
||||
{
|
||||
vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0));
|
||||
vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2));
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in0 = (vec_bf16)vec_load_vec(in);
|
||||
|
||||
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in11 = (vec_bf16)vec_load_vec(in1);
|
||||
|
||||
vec_load_mult_mma(out, in0, inp);
|
||||
|
||||
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in21 = (vec_bf16)vec_load_vec(in2);
|
||||
vec_bf16 in31 = (vec_bf16)vec_load_vec(in3);
|
||||
|
||||
vec_load_mult12a_mma(out, in0, in1, inp);
|
||||
|
||||
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
|
||||
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
|
||||
{
|
||||
vec_bf16 in0[2];
|
||||
|
||||
vec_load_pair((vec_f32 *)in0, (vec_f32 *)in);
|
||||
|
||||
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]);
|
||||
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult2d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *inp)
|
||||
{
|
||||
__builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
|
||||
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
|
||||
{
|
||||
vec_bf16 in01[2], in11[2];
|
||||
|
||||
vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0);
|
||||
vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1);
|
||||
|
||||
vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0);
|
||||
vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
|
||||
{
|
||||
vec_bf16 in01[2], in11[2], in21[2], in31[2];
|
||||
|
||||
vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0);
|
||||
vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1);
|
||||
vec_load_pair((vec_f32 *)in21, (vec_f32 *)in2);
|
||||
vec_load_pair((vec_f32 *)in31, (vec_f32 *)in3);
|
||||
|
||||
vec_mult2d_mma(out + 0, in01 + 0, in11 + 0, inp + 0);
|
||||
vec_mult2d_mma(out + 2, in21 + 0, in31 + 0, inp + 0);
|
||||
vec_mult2d_mma(out + 0, in01 + 1, in11 + 1, inp + 1);
|
||||
vec_mult2d_mma(out + 2, in21 + 1, in31 + 1, inp + 1);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
|
||||
{
|
||||
vec_bf16 in0[2];
|
||||
|
||||
vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 2));
|
||||
|
||||
vec_load_mult2_mma(out, in + 0, inp + 0);
|
||||
|
||||
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[2]);
|
||||
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[3]);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
|
||||
{
|
||||
vec_bf16 in01[4], in11[4];
|
||||
|
||||
vec_load_pair2(in01, in0);
|
||||
vec_load_pair2(in11, in1);
|
||||
|
||||
vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0);
|
||||
vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1);
|
||||
vec_mult2d_mma(out, in01 + 2, in11 + 2, inp + 2);
|
||||
vec_mult2d_mma(out, in01 + 3, in11 + 3, inp + 3);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult4d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *in21, vec_bf16 *in31, vec_bf16 *inp)
|
||||
{
|
||||
vec_mult2d_mma(out + 0, in01, in11, inp);
|
||||
vec_mult2d_mma(out + 2, in21, in31, inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
|
||||
{
|
||||
vec_bf16 in01[4], in11[4], in21[4], in31[4];
|
||||
|
||||
vec_load_pair2(in01, in0);
|
||||
vec_load_pair2(in11, in1);
|
||||
vec_load_pair2(in21, in2);
|
||||
vec_load_pair2(in31, in3);
|
||||
|
||||
vec_mult4d_mma(out, in01 + 0, in11 + 0, in21 + 0, in31 + 0, inp + 0);
|
||||
vec_mult4d_mma(out, in01 + 1, in11 + 1, in21 + 1, in31 + 1, inp + 1);
|
||||
vec_mult4d_mma(out, in01 + 2, in11 + 2, in21 + 2, in31 + 2, inp + 2);
|
||||
vec_mult4d_mma(out, in01 + 3, in11 + 3, in21 + 3, in31 + 3, inp + 3);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
|
||||
{
|
||||
vec_bf16 in0 = vec_loadN(in, n);
|
||||
|
||||
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n)
|
||||
{
|
||||
vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n);
|
||||
|
||||
vec_loadN_mult_mma(out, in0, inp, n);
|
||||
|
||||
__builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n)
|
||||
{
|
||||
vec_bf16 in21 = (vec_bf16)vec_loadN(in2, n);
|
||||
vec_bf16 in31 = (vec_bf16)vec_loadN(in3, n);
|
||||
|
||||
vec_loadN_mult12a_mma(out, in0, in1, inp, n);
|
||||
|
||||
__builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
|
||||
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult1_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in00 = vec_mergeh(in0, in0);
|
||||
|
||||
__builtin_mma_xvbf16ger2(out, (vec_uc8)inp, (vec_uc8)in00);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult2_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in01 = vec_mergel(in0, in0);
|
||||
|
||||
vec_mult1_mma(&out[0], in0, inp);
|
||||
|
||||
__builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01);
|
||||
}
|
||||
|
||||
#ifndef USE_MERGE_MMA
|
||||
FORCEINLINE void vec_mult4_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 inp)
|
||||
{
|
||||
vec_mult2_mma(out + 0, in0[0], inp);
|
||||
vec_mult2_mma(out + 2, in0[1], inp);
|
||||
}
|
||||
#endif
|
||||
|
||||
FORCEINLINE void vec_loadN_mult11_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
|
||||
{
|
||||
vec_bf16 in0 = vec_loadN(in, n);
|
||||
|
||||
vec_mult1_mma(out, in0, inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_loadN_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
|
||||
{
|
||||
vec_bf16 in0 = vec_loadN(in, n);
|
||||
|
||||
vec_mult2_mma(out, in0, inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in0 = (vec_bf16)vec_load_vec(in);
|
||||
|
||||
vec_mult2_mma(out, in0, inp);
|
||||
}
|
||||
|
||||
#ifndef USE_MERGE_MMA
|
||||
FORCEINLINE void vec_load_mult18_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in0[4];
|
||||
|
||||
vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0));
|
||||
vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2));
|
||||
|
||||
vec_mult4_mma(&out[0], in0 + 0, inp);
|
||||
vec_mult4_mma(&out[4], in0 + 2, inp);
|
||||
}
|
||||
#endif
|
||||
|
||||
FORCEINLINE void vec_reduce1_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
|
||||
{
|
||||
__builtin_mma_disassemble_acc((void*)temp, &out[0]);
|
||||
|
||||
vy0[0] += (temp[0] * v_alpha);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_reduce2_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
|
||||
{
|
||||
vec_reduce1_mma(&out[0], &temp[0], v_alpha, &vy0[0]);
|
||||
vec_reduce1_mma(&out[1], &temp[4], v_alpha, &vy0[1]);
|
||||
}
|
||||
|
||||
#ifndef USE_MERGE_MMA
|
||||
FORCEINLINE void vec_reduce8_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
|
||||
{
|
||||
vec_reduce2_mma(&out[0], &temp[0], v_alpha, vy0 + 0);
|
||||
vec_reduce2_mma(&out[2], &temp[8], v_alpha, vy0 + 2);
|
||||
vec_reduce2_mma(&out[4], &temp[16], v_alpha, vy0 + 4);
|
||||
vec_reduce2_mma(&out[6], &temp[24], v_alpha, vy0 + 6);
|
||||
}
|
||||
#else
|
||||
FORCEINLINE void vec_reduce44_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
|
||||
{
|
||||
__builtin_mma_disassemble_acc((void*)temp, &out[0]);
|
||||
|
||||
vy0[0] += (temp[0] * v_alpha);
|
||||
vy0[2] += (temp[1] * v_alpha);
|
||||
vy0[4] += (temp[2] * v_alpha);
|
||||
vy0[6] += (temp[3] * v_alpha);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_reduce84_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
|
||||
{
|
||||
vec_reduce44_mma(&out[0], &temp[0], v_alpha, vy0 + 0);
|
||||
vec_reduce44_mma(&out[1], &temp[4], v_alpha, vy0 + 1);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_reduce88_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
|
||||
{
|
||||
vec_reduce44_mma(&out[0], &temp[ 0], v_alpha, vy0 + 0);
|
||||
vec_reduce44_mma(&out[1], &temp[ 4], v_alpha, vy0 + 1);
|
||||
vec_reduce44_mma(&out[2], &temp[ 8], v_alpha, vy0 + 8);
|
||||
vec_reduce44_mma(&out[3], &temp[12], v_alpha, vy0 + 9);
|
||||
}
|
||||
#endif
|
||||
|
||||
FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in00 = vec_mergeh(in0, in1);
|
||||
|
||||
__builtin_mma_xvbf16ger2(out, (vec_uc8)inp, (vec_uc8)in00);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult2a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in01 = vec_mergel(in0, in1);
|
||||
|
||||
vec_mult11a_mma(&out[0], in0, in1, inp);
|
||||
|
||||
__builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult4a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
|
||||
{
|
||||
vec_mult2a_mma(out + 0, in0[0], in1[0], inp);
|
||||
vec_mult2a_mma(out + 2, in0[1], in1[1], inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_loadN_mult11a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
|
||||
{
|
||||
vec_bf16 in0 = vec_loadN(ina, n);
|
||||
vec_bf16 in1 = vec_loadN(inb, n);
|
||||
|
||||
vec_mult11a_mma(out, in0, in1, inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in0 = (vec_bf16)vec_load_vec(ina);
|
||||
vec_bf16 in1 = (vec_bf16)vec_load_vec(inb);
|
||||
|
||||
vec_mult2a_mma(out, in0, in1, inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load4_mma(vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *ina, vec_bf16 *inb)
|
||||
{
|
||||
vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0));
|
||||
vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0));
|
||||
vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2));
|
||||
vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2));
|
||||
}
|
||||
|
||||
#ifndef USE_MERGE_MMA
|
||||
FORCEINLINE void vec_load_mult28a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in0[4], in1[4];
|
||||
|
||||
vec_load4_mma(in0, in1, ina, inb);
|
||||
|
||||
vec_mult4a_mma(&out[0], in0 + 0, in1 + 0, inp);
|
||||
vec_mult4a_mma(&out[4], in0 + 2, in1 + 2, inp);
|
||||
}
|
||||
#endif
|
||||
|
||||
FORCEINLINE void vec_loadN_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
|
||||
{
|
||||
vec_bf16 in0 = vec_loadN(ina, n);
|
||||
vec_bf16 in1 = vec_loadN(inb, n);
|
||||
|
||||
vec_mult2a_mma(out, in0, in1, inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult11b_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in00 = vec_mergeh(in0, in1);
|
||||
|
||||
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)inp, (vec_uc8)in00);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult2b_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in01 = vec_mergel(in0, in1);
|
||||
|
||||
vec_mult11b_mma(&out[0], in0, in1, inp);
|
||||
|
||||
__builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)inp, (vec_uc8)in01);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult4b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
|
||||
{
|
||||
vec_mult2b_mma(out + 0, in0[0], in1[0], inp);
|
||||
vec_mult2b_mma(out + 2, in0[1], in1[1], inp);
|
||||
}
|
||||
|
||||
#ifdef USE_MERGE_MMA
|
||||
FORCEINLINE void vec_mult1c_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in00 = vec_mergeh(in0, in0);
|
||||
|
||||
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)inp, (vec_uc8)in00);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult2c_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in01 = vec_mergel(in0, in0);
|
||||
|
||||
vec_mult1c_mma(&out[0], in0, inp);
|
||||
|
||||
__builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)inp, (vec_uc8)in01);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult44_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
|
||||
{
|
||||
vec_mult2_mma(out, in[0], inp[0]);
|
||||
vec_mult2c_mma(out, in[1], inp[1]);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult44c_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
|
||||
{
|
||||
vec_mult2c_mma(out, in[0], inp[0]);
|
||||
vec_mult2c_mma(out, in[1], inp[1]);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult44a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
|
||||
{
|
||||
vec_mult2a_mma(out, in0[0], in1[0], inp[0]);
|
||||
vec_mult2b_mma(out, in0[1], in1[1], inp[1]);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_mult44b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
|
||||
{
|
||||
vec_mult2b_mma(out, in0[0], in1[0], inp[0]);
|
||||
vec_mult2b_mma(out, in0[1], in1[1], inp[1]);
|
||||
}
|
||||
#endif
|
||||
|
||||
FORCEINLINE void vec_loadN_mult11b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
|
||||
{
|
||||
vec_bf16 in0 = vec_loadN(ina, n);
|
||||
vec_bf16 in1 = vec_loadN(inb, n);
|
||||
|
||||
vec_mult11b_mma(out, in0, in1, inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in0 = (vec_bf16)vec_load_vec(ina);
|
||||
vec_bf16 in1 = (vec_bf16)vec_load_vec(inb);
|
||||
|
||||
vec_mult2b_mma(out, in0, in1, inp);
|
||||
}
|
||||
|
||||
#ifndef USE_MERGE_MMA
|
||||
FORCEINLINE void vec_load_mult28b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
|
||||
{
|
||||
vec_bf16 in0[4], in1[4];
|
||||
|
||||
vec_load4_mma(in0, in1, ina, inb);
|
||||
|
||||
vec_mult4b_mma(&out[0], in0 + 0, in1 + 0, inp);
|
||||
vec_mult4b_mma(&out[4], in0 + 2, in1 + 2, inp);
|
||||
}
|
||||
#else
|
||||
FORCEINLINE void vec_load_mult184_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
|
||||
{
|
||||
vec_bf16 in0[4];
|
||||
|
||||
vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0));
|
||||
vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2));
|
||||
|
||||
vec_mult44_mma(out, in0 + 0, inp + 0);
|
||||
vec_mult44c_mma(out, in0 + 2, inp + 2);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult284a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
|
||||
{
|
||||
vec_bf16 in0[4], in1[4];
|
||||
|
||||
vec_load4_mma(in0, in1, ina, inb);
|
||||
|
||||
vec_mult44a_mma(out, in0 + 0, in1 + 0, inp + 0);
|
||||
vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult284b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
|
||||
{
|
||||
vec_bf16 in0[4], in1[4];
|
||||
|
||||
vec_load4_mma(in0, in1, ina, inb);
|
||||
|
||||
vec_mult44b_mma(out, in0 + 0, in1 + 0, inp + 0);
|
||||
vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult288a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
|
||||
{
|
||||
vec_bf16 in0[8], in1[8];
|
||||
|
||||
vec_load4_mma(in0 + 0, in1 + 0, ina + 0, inb + 0);
|
||||
vec_load4_mma(in0 + 4, in1 + 4, ina + 4, inb + 4);
|
||||
|
||||
vec_mult44a_mma(out + 0, in0 + 0, in1 + 0, inp + 0);
|
||||
vec_mult44a_mma(out + 2, in0 + 4, in1 + 4, inp + 0);
|
||||
vec_mult44b_mma(out + 0, in0 + 2, in1 + 2, inp + 2);
|
||||
vec_mult44b_mma(out + 2, in0 + 6, in1 + 6, inp + 2);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load_mult288b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
|
||||
{
|
||||
vec_bf16 in0[8], in1[8];
|
||||
|
||||
vec_load4_mma(in0 + 0, in1 + 0, ina + 0, inb + 0);
|
||||
vec_load4_mma(in0 + 4, in1 + 4, ina + 4, inb + 4);
|
||||
|
||||
vec_mult44b_mma(out + 0, in0 + 0, in1 + 0, inp + 0);
|
||||
vec_mult44b_mma(out + 2, in0 + 4, in1 + 4, inp + 0);
|
||||
vec_mult44b_mma(out + 0, in0 + 2, in1 + 2, inp + 2);
|
||||
vec_mult44b_mma(out + 2, in0 + 6, in1 + 6, inp + 2);
|
||||
}
|
||||
#endif
|
||||
|
||||
FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
|
||||
{
|
||||
vec_bf16 in0 = vec_loadN(ina, n);
|
||||
vec_bf16 in1 = vec_loadN(inb, n);
|
||||
|
||||
vec_mult2b_mma(out, in0, in1, inp);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_load4_pair(vec_f32 *vy0, vec_f32 *v_y)
|
||||
{
|
||||
vec_load_pair(vy0 + 0, v_y + 0);
|
||||
vec_load_pair(vy0 + 2, v_y + 2);
|
||||
vec_load_pair(vy0 + 4, v_y + 4);
|
||||
vec_load_pair(vy0 + 6, v_y + 6);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_store4_pair(vec_f32 *v_y, vec_f32 *vy0)
|
||||
{
|
||||
vec_store_pair(v_y + 0, vy0 + 0);
|
||||
vec_store_pair(v_y + 2, vy0 + 2);
|
||||
vec_store_pair(v_y + 4, vy0 + 4);
|
||||
vec_store_pair(v_y + 6, vy0 + 6);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_setzero_2(__vector_quad *temp0)
|
||||
{
|
||||
__builtin_mma_xxsetaccz(&temp0[0]);
|
||||
__builtin_mma_xxsetaccz(&temp0[1]);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_setzero_4(__vector_quad *temp0)
|
||||
{
|
||||
vec_setzero_2(temp0 + 0);
|
||||
vec_setzero_2(temp0 + 2);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_setzero_8(__vector_quad *temp0)
|
||||
{
|
||||
vec_setzero_4(temp0 + 0);
|
||||
vec_setzero_4(temp0 + 4);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_reduce_2(vec_f32 *temp00, __vector_quad *temp0)
|
||||
{
|
||||
__builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]);
|
||||
__builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_reduce_4(vec_f32 *temp00, __vector_quad *temp0)
|
||||
{
|
||||
vec_reduce_2(temp00 + 0, temp0 + 0);
|
||||
vec_reduce_2(temp00 + 8, temp0 + 2);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_reduce_8(vec_f32 *temp00, __vector_quad *temp0)
|
||||
{
|
||||
vec_reduce_4(temp00 + 0, temp0 + 0);
|
||||
vec_reduce_4(temp00 + 16, temp0 + 4);
|
||||
}
|
||||
|
||||
#ifdef USE_MERGE_MMA
|
||||
FORCEINLINE void vec_load8_pair(vec_f32 *vy0, vec_f32 *v_y)
|
||||
{
|
||||
vec_load4_pair(vy0 + 0, v_y + 0);
|
||||
vec_load4_pair(vy0 + 8, v_y + 8);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_store8_pair(vec_f32 *v_y, vec_f32 *vy0)
|
||||
{
|
||||
vec_store4_pair(v_y + 0, vy0 + 0);
|
||||
vec_store4_pair(v_y + 8, vy0 + 8);
|
||||
}
|
||||
|
||||
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
||||
#define VEC_SHIFT(data, shift) vec_sldw(data, data, 4 - shift)
|
||||
|
||||
#define MASK_0 0xf000
|
||||
#define MASK_1 0x0f00
|
||||
#define MASK_2 0x00f0
|
||||
#define MASK_3 0x000f
|
||||
#else
|
||||
#define VEC_SHIFT(data, shift) vec_sldw(data, data, shift)
|
||||
|
||||
#define MASK_0 0x000f
|
||||
#define MASK_1 0x00f0
|
||||
#define MASK_2 0x0f00
|
||||
#define MASK_3 0xf000
|
||||
#endif
|
||||
|
||||
FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0, const bool mask)
|
||||
{
|
||||
if (mask) {
|
||||
v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_0));
|
||||
}
|
||||
|
||||
v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 1);
|
||||
v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 2);
|
||||
v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 3);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_make_mult2(vec_bf16 *v_x0)
|
||||
{
|
||||
v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_1));
|
||||
vec_make_mult1(v_x0, true);
|
||||
|
||||
v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 3);
|
||||
v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 1);
|
||||
v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 2);
|
||||
}
|
||||
|
||||
FORCEINLINE void vec_make_mult4(vec_bf16 *v_x0)
|
||||
{
|
||||
v_x0[10] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_2));
|
||||
v_x0[15] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_3));
|
||||
vec_make_mult2(v_x0);
|
||||
|
||||
v_x0[ 8] = VEC_SHIFT(v_x0[10], 2);
|
||||
v_x0[ 9] = VEC_SHIFT(v_x0[10], 3);
|
||||
v_x0[11] = VEC_SHIFT(v_x0[10], 1);
|
||||
v_x0[12] = VEC_SHIFT(v_x0[15], 1);
|
||||
v_x0[13] = VEC_SHIFT(v_x0[15], 2);
|
||||
v_x0[14] = VEC_SHIFT(v_x0[15], 3);
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
|
@ -0,0 +1,152 @@
|
|||
/***************************************************************************
|
||||
Copyright (c) 2024, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef SBGEMV_N_COMMON_C
|
||||
#define SBGEMV_N_COMMON_C
|
||||
|
||||
#if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_N_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_N_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_N_VSX))
|
||||
#define USE_N_8
|
||||
#endif
|
||||
|
||||
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y)
|
||||
{
|
||||
IFLOAT *x_ptr, *ap[4];
|
||||
IFLOAT xbuffer[8] __attribute__((aligned(16)));
|
||||
FLOAT *y_ptr, *ybuffer;
|
||||
FLOAT buffer[NBMAX] __attribute__((aligned(16)));
|
||||
|
||||
if ((m < 1) || (n < 1)) return 0;
|
||||
|
||||
ybuffer = buffer;
|
||||
y_ptr = y;
|
||||
|
||||
BLASLONG lda4 = lda << 2;
|
||||
#ifdef USE_N_8
|
||||
BLASLONG lda8 = lda << 3;
|
||||
#endif
|
||||
BLASLONG NB = NBMAX;
|
||||
BLASLONG m2 = (m & (NBMAX - 1));
|
||||
|
||||
while (NB == NBMAX) {
|
||||
m -= NB;
|
||||
if (m < 0) {
|
||||
if (m2 == 0) break;
|
||||
NB = m2;
|
||||
}
|
||||
|
||||
if (inc_y != 1) {
|
||||
copy_y_beta(NB, y_ptr, ybuffer, inc_y, beta);
|
||||
} else {
|
||||
ybuffer = y_ptr;
|
||||
BF16GEMV_N_beta(NB, ybuffer, ybuffer, beta);
|
||||
}
|
||||
|
||||
x_ptr = x;
|
||||
|
||||
ap[0] = a;
|
||||
ap[1] = a + lda;
|
||||
ap[2] = ap[1] + lda;
|
||||
ap[3] = ap[2] + lda;
|
||||
|
||||
if (inc_x == 1) {
|
||||
#ifdef USE_N_8
|
||||
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
|
||||
BF16GEMV_N_8(NB, ap, x_ptr, ybuffer, lda4, alpha);
|
||||
ap[0] += lda8;
|
||||
ap[1] += lda8;
|
||||
ap[2] += lda8;
|
||||
ap[3] += lda8;
|
||||
x_ptr += 8;
|
||||
}
|
||||
if (n & 4) {
|
||||
#else
|
||||
for (BLASLONG j = 0; j + 4 <= n; j += 4) {
|
||||
#endif
|
||||
BF16GEMV_N_4(NB, ap, x_ptr, ybuffer, alpha);
|
||||
ap[0] += lda4;
|
||||
ap[1] += lda4;
|
||||
#ifndef USE_N_8
|
||||
ap[2] += lda4;
|
||||
ap[3] += lda4;
|
||||
#endif
|
||||
x_ptr += 4;
|
||||
}
|
||||
if (n & 2) {
|
||||
BF16GEMV_N_2(NB, ap, x_ptr, ybuffer, alpha);
|
||||
ap[0] += (lda * 2);
|
||||
x_ptr += 2;
|
||||
}
|
||||
if (n & 1) {
|
||||
BF16GEMV_N_1(NB, ap, x_ptr, ybuffer, alpha);
|
||||
}
|
||||
} else {
|
||||
#ifdef USE_N_8
|
||||
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
|
||||
copy_x(8, x_ptr, xbuffer, inc_x);
|
||||
BF16GEMV_N_8(NB, ap, xbuffer, ybuffer, lda4, alpha);
|
||||
ap[0] += lda8;
|
||||
ap[1] += lda8;
|
||||
ap[2] += lda8;
|
||||
ap[3] += lda8;
|
||||
x_ptr += 8 * inc_x;
|
||||
}
|
||||
if (n & 4) {
|
||||
#else
|
||||
for (BLASLONG j = 0; j + 4 <= n; j += 4) {
|
||||
#endif
|
||||
copy_x(4, x_ptr, xbuffer, inc_x);
|
||||
BF16GEMV_N_4(NB, ap, xbuffer, ybuffer, alpha);
|
||||
ap[0] += lda4;
|
||||
ap[1] += lda4;
|
||||
#ifndef USE_N_8
|
||||
ap[2] += lda4;
|
||||
ap[3] += lda4;
|
||||
#endif
|
||||
x_ptr += 4 * inc_x;
|
||||
}
|
||||
if (n & 2) {
|
||||
copy_x(2, x_ptr, xbuffer, inc_x);
|
||||
BF16GEMV_N_2(NB, ap, xbuffer, ybuffer, alpha);
|
||||
ap[0] += (lda * 2);
|
||||
x_ptr += 2 * inc_x;
|
||||
}
|
||||
if (n & 1) {
|
||||
copy_x(1, x_ptr, xbuffer, inc_x);
|
||||
BF16GEMV_N_1(NB, ap, xbuffer, ybuffer, alpha);
|
||||
}
|
||||
}
|
||||
|
||||
a += NB;
|
||||
if (inc_y != 1) {
|
||||
move_y(NB, ybuffer, y_ptr, inc_y);
|
||||
}
|
||||
y_ptr += (NB * inc_y);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
#endif
|
|
@ -0,0 +1,474 @@
|
|||
/***************************************************************************
|
||||
Copyright (c) 2024, The OpenBLAS Project
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
3. Neither the name of the OpenBLAS project nor the names of
|
||||
its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef SBGEMV_N_MMA_C
|
||||
#define SBGEMV_N_MMA_C
|
||||
|
||||
#define USE_BFGEMV_N_MMA
|
||||
|
||||
#ifdef USE_BFGEMV_N_MMA
|
||||
#include "sbgemv_common_power10.c"
|
||||
|
||||
#ifndef BF16GEMV_N_X
|
||||
#define BF16GEMV_N_X
|
||||
#define BF16GEMV_N_8 BF16GEMV_N_MMA_8
|
||||
#define BF16GEMV_N_4 BF16GEMV_N_MMA_4
|
||||
#define BF16GEMV_N_2 BF16GEMV_N_MMA_2
|
||||
#define BF16GEMV_N_1 BF16GEMV_N_MMA_1
|
||||
#endif
|
||||
|
||||
#define USE_BFGEMV_8_N_MMA
|
||||
|
||||
static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha)
|
||||
{
|
||||
IFLOAT *a0;
|
||||
__vector_quad temp[2*4];
|
||||
vec_f32 temp0[8*4];
|
||||
vec_f32 v_alpha = { alpha, alpha, alpha, alpha };
|
||||
|
||||
a0 = ap[0];
|
||||
|
||||
vec_bf16 *va0 = (vec_bf16 *)a0;
|
||||
|
||||
vec_bf16 *x_bf = (vec_bf16 *)(xo);
|
||||
|
||||
vec_f32 *v_y = (vec_f32 *)y;
|
||||
BLASLONG n8 = n / 8;
|
||||
BLASLONG i = 0;
|
||||
|
||||
#ifdef USE_MERGE_MMA
|
||||
vec_bf16 v_x0[4];
|
||||
v_x0[0] = vec_loadN(x_bf, 1);
|
||||
vec_f32 vy0[2*4*2];
|
||||
|
||||
vec_make_mult1(v_x0, false);
|
||||
|
||||
for (; i + 8 <= n8; i += 8) {
|
||||
vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]);
|
||||
vec_load_mult184_mma(&temp[2], &va0[i + 4], &v_x0[ 0]);
|
||||
|
||||
vec_load8_pair(vy0, &v_y[(i * 2) + 0]);
|
||||
|
||||
vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
|
||||
|
||||
vec_store8_pair(&v_y[(i * 2) + 0], vy0);
|
||||
}
|
||||
|
||||
if (n8 & 4) {
|
||||
vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]);
|
||||
|
||||
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);
|
||||
|
||||
vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_store4_pair(&v_y[(i * 2) + 0], vy0);
|
||||
|
||||
i += 4;
|
||||
}
|
||||
#else
|
||||
vec_bf16 v_x0[1];
|
||||
v_x0[0] = vec_loadN(x_bf, 1);
|
||||
vec_f32 vy0[2*4];
|
||||
|
||||
for (; i + 4 <= n8; i += 4) {
|
||||
vec_load_mult18_mma(&temp[0], &va0[i + 0], v_x0[ 0]);
|
||||
|
||||
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);
|
||||
|
||||
vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_store4_pair(&v_y[(i * 2) + 0], vy0);
|
||||
}
|
||||
#endif
|
||||
|
||||
for (; i < n8; i++) {
|
||||
vec_load_mult12_mma(&temp[0], &va0[i], v_x0[ 0]);
|
||||
|
||||
vec_load_pair(vy0, &v_y[(i * 2) + 0]);
|
||||
|
||||
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_store_pair(&v_y[(i * 2) + 0], vy0);
|
||||
}
|
||||
|
||||
n &= 7;
|
||||
if (n > 4) {
|
||||
vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0[ 0], n);
|
||||
|
||||
n &= 3;
|
||||
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n);
|
||||
|
||||
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n);
|
||||
} else if (n) {
|
||||
vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0[ 0], n);
|
||||
|
||||
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||
|
||||
vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||
}
|
||||
}
|
||||
|
||||
static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha)
|
||||
{
|
||||
IFLOAT *a0, *a1;
|
||||
__vector_quad temp[2*4];
|
||||
vec_f32 temp0[8*4];
|
||||
vec_f32 v_alpha = { alpha, alpha, alpha, alpha };
|
||||
|
||||
a0 = ap[0];
|
||||
a1 = ap[1];
|
||||
|
||||
vec_bf16 *va0 = (vec_bf16 *)a0;
|
||||
vec_bf16 *va1 = (vec_bf16 *)a1;
|
||||
|
||||
vec_bf16 *x_bf = (vec_bf16 *)(xo);
|
||||
|
||||
vec_f32 *v_y = (vec_f32 *)y;
|
||||
BLASLONG n8 = n / 8;
|
||||
BLASLONG i = 0;
|
||||
|
||||
#ifdef USE_MERGE_MMA
|
||||
vec_bf16 v_x0[4];
|
||||
vec_f32 vy0[2*4*2];
|
||||
v_x0[0] = vec_loadN(x_bf, 2);
|
||||
|
||||
vec_make_mult1(v_x0, false);
|
||||
|
||||
for (; i + 8 <= n8; i += 8) {
|
||||
vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
|
||||
|
||||
vec_load8_pair(vy0, &v_y[(i * 2) + 0]);
|
||||
|
||||
vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
|
||||
|
||||
vec_store8_pair(&v_y[(i * 2) + 0], vy0);
|
||||
}
|
||||
|
||||
if (n8 & 4) {
|
||||
vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
|
||||
|
||||
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);
|
||||
|
||||
vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_store4_pair(&v_y[(i * 2) + 0], vy0);
|
||||
|
||||
i += 4;
|
||||
}
|
||||
#else
|
||||
vec_bf16 v_x0[1];
|
||||
vec_f32 vy0[2*4];
|
||||
v_x0[0] = vec_loadN(x_bf, 2);
|
||||
|
||||
for (; i + 4 <= n8; i += 4) {
|
||||
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]);
|
||||
|
||||
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);
|
||||
|
||||
vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_store4_pair(&v_y[(i * 2) + 0], vy0);
|
||||
}
|
||||
#endif
|
||||
|
||||
for (; i < n8; i++) {
|
||||
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]);
|
||||
|
||||
vec_load_pair(vy0, &v_y[(i * 2) + 0]);
|
||||
|
||||
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_store_pair(&v_y[(i * 2) + 0], vy0);
|
||||
}
|
||||
|
||||
n &= 7;
|
||||
if (n > 4) {
|
||||
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
|
||||
|
||||
n &= 3;
|
||||
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n);
|
||||
|
||||
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n);
|
||||
} else if (n) {
|
||||
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
|
||||
|
||||
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||
|
||||
vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||
}
|
||||
}
|
||||
|
||||
static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha)
|
||||
{
|
||||
IFLOAT *a0, *a1, *a2, *a3;
|
||||
__vector_quad temp[2*4];
|
||||
vec_f32 temp0[8*4];
|
||||
vec_f32 v_alpha = { alpha, alpha, alpha, alpha };
|
||||
|
||||
a0 = ap[0];
|
||||
a1 = ap[1];
|
||||
a2 = ap[2];
|
||||
a3 = ap[3];
|
||||
|
||||
vec_bf16 *va0 = (vec_bf16 *)a0;
|
||||
vec_bf16 *va1 = (vec_bf16 *)a1;
|
||||
vec_bf16 *va2 = (vec_bf16 *)a2;
|
||||
vec_bf16 *va3 = (vec_bf16 *)a3;
|
||||
|
||||
vec_bf16 *x_bf = (vec_bf16 *)(xo);
|
||||
|
||||
vec_f32 *v_y = (vec_f32 *)y;
|
||||
BLASLONG n8 = n / 8;
|
||||
BLASLONG i = 0;
|
||||
|
||||
#ifdef USE_MERGE_MMA
|
||||
vec_bf16 v_x0[8];
|
||||
vec_f32 vy0[2*4*2];
|
||||
v_x0[0] = vec_loadN(x_bf, 4);
|
||||
|
||||
vec_make_mult2(v_x0);
|
||||
|
||||
for (; i + 8 <= n8; i += 8) {
|
||||
vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
|
||||
vec_load_mult288b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);
|
||||
|
||||
vec_load8_pair(vy0, &v_y[(i * 2) + 0]);
|
||||
|
||||
vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);
|
||||
|
||||
vec_store8_pair(&v_y[(i * 2) + 0], vy0);
|
||||
}
|
||||
|
||||
if (n8 & 4) {
|
||||
vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
|
||||
vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);
|
||||
|
||||
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);
|
||||
|
||||
vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_store4_pair(&v_y[(i * 2) + 0], vy0);
|
||||
|
||||
i += 4;
|
||||
}
|
||||
#else
|
||||
vec_bf16 v_x0[5];
|
||||
vec_f32 vy0[2*4];
|
||||
v_x0[0] = vec_loadN(x_bf, 4);
|
||||
|
||||
v_x0[ 4] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 1);
|
||||
|
||||
for (; i + 4 <= n8; i += 4) {
|
||||
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]);
|
||||
vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x0[ 4]);
|
||||
|
||||
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);
|
||||
|
||||
vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_store4_pair(&v_y[(i * 2) + 0], vy0);
|
||||
}
|
||||
#endif
|
||||
|
||||
for (; i < n8; i++) {
|
||||
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]);
|
||||
vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4]);
|
||||
|
||||
vec_load_pair(vy0, &v_y[(i * 2) + 0]);
|
||||
|
||||
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_store_pair(&v_y[(i * 2) + 0], vy0);
|
||||
}
|
||||
|
||||
n &= 7;
|
||||
if (n > 4) {
|
||||
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
|
||||
vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);
|
||||
|
||||
n &= 3;
|
||||
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n);
|
||||
|
||||
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n);
|
||||
} else if (n) {
|
||||
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
|
||||
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);
|
||||
|
||||
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||
|
||||
vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0);
|
||||
|
||||
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_BFGEMV_8_N_MMA
static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLASLONG lda4, FLOAT alpha)
{
  IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3;
  __vector_quad temp[2*4];
  vec_f32 temp0[8*4];
  vec_f32 v_alpha = { alpha, alpha, alpha, alpha };

  a0 = ap[0];
  a1 = ap[1];
  a2 = ap[2];
  a3 = ap[3];
  b0 = a0 + lda4;
  b1 = a1 + lda4;
  b2 = a2 + lda4;
  b3 = a3 + lda4;

  vec_bf16 *va0 = (vec_bf16 *)a0;
  vec_bf16 *va1 = (vec_bf16 *)a1;
  vec_bf16 *va2 = (vec_bf16 *)a2;
  vec_bf16 *va3 = (vec_bf16 *)a3;
  vec_bf16 *vb0 = (vec_bf16 *)b0;
  vec_bf16 *vb1 = (vec_bf16 *)b1;
  vec_bf16 *vb2 = (vec_bf16 *)b2;
  vec_bf16 *vb3 = (vec_bf16 *)b3;

  vec_bf16 *x_bf = (vec_bf16 *)(xo);

  vec_f32 *v_y = (vec_f32 *)y;
  BLASLONG n8 = n / 8;
  BLASLONG i = 0;

#ifdef USE_MERGE_MMA
  vec_bf16 v_x0[16];
  vec_f32 vy0[2*4*2];
  v_x0[0] = (vec_bf16)vec_load_vec(x_bf);

  vec_make_mult4(v_x0);

  for (; i + 8 <= n8; i += 8) {
    vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
    vec_load_mult288b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);
    vec_load_mult288b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]);
    vec_load_mult288b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]);

    vec_load8_pair(vy0, &v_y[(i * 2) + 0]);

    vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0);

    vec_store8_pair(&v_y[(i * 2) + 0], vy0);
  }

  if (n8 & 4) {
    vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]);
    vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]);
    vec_load_mult284b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]);
    vec_load_mult284b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]);

    vec_load4_pair(vy0, &v_y[(i * 2) + 0]);

    vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0);

    vec_store4_pair(&v_y[(i * 2) + 0], vy0);

    i += 4;
  }
#else
  vec_bf16 v_x0[13];
  vec_f32 vy0[2*4];
  v_x0[0] = (vec_bf16)vec_load_vec(x_bf);

  v_x0[ 4] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 1);
  v_x0[ 8] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 2);
  v_x0[12] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 3);

  for (; i + 4 <= n8; i += 4) {
    vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]);
    vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x0[ 4]);
    vec_load_mult28b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], v_x0[ 8]);
    vec_load_mult28b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], v_x0[12]);

    vec_load4_pair(vy0, &v_y[(i * 2) + 0]);

    vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0);

    vec_store4_pair(&v_y[(i * 2) + 0], vy0);
  }
#endif

  for (; i < n8; i++) {
    vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]);
    vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4]);
    vec_load_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8]);
    vec_load_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12]);

    vec_load_pair(vy0, &v_y[(i * 2) + 0]);

    vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);

    vec_store_pair(&v_y[(i * 2) + 0], vy0);
  }

  n &= 7;
  if (n > 4) {
    vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
    vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);
    vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n);
    vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n);

    n &= 3;
    vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n);

    vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);

    vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n);
  } else if (n) {
    vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n);
    vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n);
    vec_loadN_mult11b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n);
    vec_loadN_mult11b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n);

    vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);

    vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0);

    vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
  }
}
#endif

#include "sbgemv_n.c"
#else
#include "sbgemv_n_vsx.c"
#endif
#endif

@@ -0,0 +1,299 @@
/***************************************************************************
Copyright (c) 2024, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/

#ifndef SBGEMV_N_VSX_C
#define SBGEMV_N_VSX_C

#include "sbgemv_common.c"

#ifndef BF16GEMV_N_X
#define BF16GEMV_N_X
#define BF16GEMV_N_8 BF16GEMV_N_VSX_8
#define BF16GEMV_N_4 BF16GEMV_N_VSX_4
#define BF16GEMV_N_2 BF16GEMV_N_VSX_2
#define BF16GEMV_N_1 BF16GEMV_N_VSX_1
#endif

#define USE_BFGEMV_8_N_VSX

static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha)
{
  IFLOAT *a0;
  vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 };
  vec_f32 v_alpha = { alpha, alpha, alpha, alpha };

  a0 = ap[0];

  vec_bf16 *va0 = (vec_bf16 *)a0;

  vec_bf16 *x_bf = (vec_bf16 *)(xo);
  vec_f32 x_0 = vec_loadNHi(x_bf, 1, zero);
  x_0 *= v_alpha;

  vec_f32 v_x0 = vec_splat(x_0, 0);

  vec_f32 *v_y = (vec_f32 *)y;
  BLASLONG n8 = n / 8;
  BLASLONG i = 0;
  vec_f32 vy0[2];

  for (; i < n8; i++) {
    vec_load_pair(vy0, &v_y[(i * 2) + 0]);

    vec_load_mult2(v_x0, &va0[i], zero, vy0);

    vec_store_pair(&v_y[(i * 2) + 0], vy0);
  }

  n &= 7;
  if (n > 4) {
    BLASLONG n3 = n & 3;
    vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);

    vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);

    vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
  } else if (n) {
    vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);

    vy0[0] += vec_loadNHi_mult(&va0[i], v_x0, n, zero);

    vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
  }
}
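
/* Illustrative sketch (hypothetical helper, not used by the kernels): bfloat16
 * stores the upper 16 bits of an IEEE-754 binary32, so widening one element to
 * float is a 16-bit shift into the high half.  The VSX kernels do the same
 * thing a whole vector at a time by merging bf16 lanes with zero half-words
 * (the vec_loadNHi / BF16_HI / BF16_LO style helpers above). */
static inline float bf16_to_float_sketch(unsigned short b)
{
  union { unsigned int u; float f; } w;
  w.u = (unsigned int)b << 16;    /* bf16 bits become the top half of an f32 */
  return w.f;
}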

static void BF16GEMV_N_VSX_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha)
{
  IFLOAT *a0, *a1;
  vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 };
  vec_f32 v_alpha = { alpha, alpha, alpha, alpha };

  a0 = ap[0];
  a1 = ap[1];

  vec_bf16 *va0 = (vec_bf16 *)a0;
  vec_bf16 *va1 = (vec_bf16 *)a1;

  vec_bf16 *x_bf = (vec_bf16 *)(xo);
  vec_f32 x_0 = vec_loadNHi(x_bf, 2, zero);
  x_0 *= v_alpha;

  vec_f32 v_x0 = vec_splat(x_0, 0);
  vec_f32 v_x1 = vec_splat(x_0, 1);

  vec_f32 *v_y = (vec_f32 *)y;
  BLASLONG n8 = n / 8;
  BLASLONG i = 0;
  vec_f32 vy0[2];

  for (; i < n8; i++) {
    vec_load_pair(vy0, &v_y[(i * 2) + 0]);

    vec_load_mult2(v_x0, &va0[i], zero, vy0);
    vec_load_mult2(v_x1, &va1[i], zero, vy0);

    vec_store_pair(&v_y[(i * 2) + 0], vy0);
  }

  n &= 7;
  if (n > 4) {
    BLASLONG n3 = n & 3;
    vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);

    vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
    vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);

    vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
  } else if (n) {
    vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);

    vy0[0] += vec_loadNHi_mult(&va0[i], v_x0, n, zero);
    vy0[0] += vec_loadNHi_mult(&va1[i], v_x1, n, zero);

    vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
  }
}

static void BF16GEMV_N_VSX_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha)
{
  IFLOAT *a0, *a1, *a2, *a3;
  vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 };
  vec_f32 v_alpha = { alpha, alpha, alpha, alpha };

  a0 = ap[0];
  a1 = ap[1];
  a2 = ap[2];
  a3 = ap[3];

  vec_bf16 *va0 = (vec_bf16 *)a0;
  vec_bf16 *va1 = (vec_bf16 *)a1;
  vec_bf16 *va2 = (vec_bf16 *)a2;
  vec_bf16 *va3 = (vec_bf16 *)a3;

  vec_bf16 *x_bf = (vec_bf16 *)(xo);
  vec_f32 x_0 = vec_loadNHi(x_bf, 4, zero);
  x_0 *= v_alpha;

  vec_f32 v_x0 = vec_splat(x_0, 0);
  vec_f32 v_x1 = vec_splat(x_0, 1);
  vec_f32 v_x2 = vec_splat(x_0, 2);
  vec_f32 v_x3 = vec_splat(x_0, 3);

  vec_f32 *v_y = (vec_f32 *)y;
  BLASLONG n8 = n / 8;
  BLASLONG i = 0;
  vec_f32 vy0[2];

  for (; i < n8; i++) {
    vec_load_pair(vy0, &v_y[(i * 2) + 0]);

    vec_load_mult2(v_x0, &va0[i], zero, vy0);
    vec_load_mult2(v_x1, &va1[i], zero, vy0);
    vec_load_mult2(v_x2, &va2[i], zero, vy0);
    vec_load_mult2(v_x3, &va3[i], zero, vy0);

    vec_store_pair(&v_y[(i * 2) + 0], vy0);
  }

  n &= 7;
  if (n > 4) {
    BLASLONG n3 = n & 3;
    vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);

    vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
    vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
    vec_loadN_mult2(v_x2, &va2[i], n, zero, vy0);
    vec_loadN_mult2(v_x3, &va3[i], n, zero, vy0);

    vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
  } else if (n) {
    vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);

    vy0[0] += vec_loadNHi_mult(&va0[i], v_x0, n, zero);
    vy0[0] += vec_loadNHi_mult(&va1[i], v_x1, n, zero);
    vy0[0] += vec_loadNHi_mult(&va2[i], v_x2, n, zero);
    vy0[0] += vec_loadNHi_mult(&va3[i], v_x3, n, zero);

    vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
  }
}

#ifdef USE_BFGEMV_8_N_VSX
static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLASLONG lda4, FLOAT alpha)
{
  IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3;
  vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 };
  vec_f32 v_alpha = { alpha, alpha, alpha, alpha };

  a0 = ap[0];
  a1 = ap[1];
  a2 = ap[2];
  a3 = ap[3];
  b0 = a0 + lda4;
  b1 = a1 + lda4;
  b2 = a2 + lda4;
  b3 = a3 + lda4;

  vec_bf16 *va0 = (vec_bf16 *)a0;
  vec_bf16 *va1 = (vec_bf16 *)a1;
  vec_bf16 *va2 = (vec_bf16 *)a2;
  vec_bf16 *va3 = (vec_bf16 *)a3;
  vec_bf16 *vb0 = (vec_bf16 *)b0;
  vec_bf16 *vb1 = (vec_bf16 *)b1;
  vec_bf16 *vb2 = (vec_bf16 *)b2;
  vec_bf16 *vb3 = (vec_bf16 *)b3;

  vec_bf16 *x_bf = (vec_bf16 *)(xo);
  vec_bf16 x_in = (vec_bf16)vec_load_vec(x_bf);
  vec_f32 x_0 = BF16_HI(x_in, zero);
  vec_f32 x_1 = BF16_LO(x_in, zero);
  x_0 *= v_alpha;
  x_1 *= v_alpha;

  vec_f32 v_x0 = vec_splat(x_0, 0);
  vec_f32 v_x1 = vec_splat(x_0, 1);
  vec_f32 v_x2 = vec_splat(x_0, 2);
  vec_f32 v_x3 = vec_splat(x_0, 3);
  vec_f32 v_x4 = vec_splat(x_1, 0);
  vec_f32 v_x5 = vec_splat(x_1, 1);
  vec_f32 v_x6 = vec_splat(x_1, 2);
  vec_f32 v_x7 = vec_splat(x_1, 3);

  vec_f32 *v_y = (vec_f32 *)y;
  BLASLONG n8 = n / 8;
  BLASLONG i = 0;
  vec_f32 vy0[2];

  for (; i < n8; i++) {
    vec_load_pair(vy0, &v_y[(i * 2) + 0]);

    vec_load_mult2(v_x0, &va0[i], zero, vy0);
    vec_load_mult2(v_x1, &va1[i], zero, vy0);
    vec_load_mult2(v_x2, &va2[i], zero, vy0);
    vec_load_mult2(v_x3, &va3[i], zero, vy0);
    vec_load_mult2(v_x4, &vb0[i], zero, vy0);
    vec_load_mult2(v_x5, &vb1[i], zero, vy0);
    vec_load_mult2(v_x6, &vb2[i], zero, vy0);
    vec_load_mult2(v_x7, &vb3[i], zero, vy0);

    vec_store_pair(&v_y[(i * 2) + 0], vy0);
  }

  n &= 7;
  if (n > 4) {
    BLASLONG n3 = n & 3;
    vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);

    vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0);
    vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0);
    vec_loadN_mult2(v_x2, &va2[i], n, zero, vy0);
    vec_loadN_mult2(v_x3, &va3[i], n, zero, vy0);
    vec_loadN_mult2(v_x4, &vb0[i], n, zero, vy0);
    vec_loadN_mult2(v_x5, &vb1[i], n, zero, vy0);
    vec_loadN_mult2(v_x6, &vb2[i], n, zero, vy0);
    vec_loadN_mult2(v_x7, &vb3[i], n, zero, vy0);

    vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
  } else if (n) {
    vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);

    vy0[0] += vec_loadNHi_mult(&va0[i], v_x0, n, zero);
    vy0[0] += vec_loadNHi_mult(&va1[i], v_x1, n, zero);
    vy0[0] += vec_loadNHi_mult(&va2[i], v_x2, n, zero);
    vy0[0] += vec_loadNHi_mult(&va3[i], v_x3, n, zero);
    vy0[0] += vec_loadNHi_mult(&vb0[i], v_x4, n, zero);
    vy0[0] += vec_loadNHi_mult(&vb1[i], v_x5, n, zero);
    vy0[0] += vec_loadNHi_mult(&vb2[i], v_x6, n, zero);
    vy0[0] += vec_loadNHi_mult(&vb3[i], v_x7, n, zero);

    vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
  }
}
#endif

#include "sbgemv_n.c"
#endif

@@ -0,0 +1,137 @@
/***************************************************************************
Copyright (c) 2024, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/

#ifndef SBGEMV_T_COMMON_C
#define SBGEMV_T_COMMON_C

#if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_T_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_T_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_T_VSX))
#define USE_T_8
#endif

int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y)
{
  IFLOAT *xbuffer, *a_ptr;
  IFLOAT buffer[NBMAX] __attribute__((aligned(16)));
  FLOAT ybuffer[8] __attribute__((aligned(16)));
  FLOAT *y_ptr;

  if ((m < 1) || (n < 1)) return 0;

  if (inc_y == 1) {
    BF16GEMV_N_beta(n, y, y, beta);
  }

  xbuffer = buffer;

  BLASLONG lda4 = lda << 2;
#ifdef USE_T_8
  BLASLONG lda8 = lda << 3;
#endif
  BLASLONG NB = NBMAX;
  BLASLONG m2 = (m & (NBMAX - 1));

  while (NB == NBMAX) {
    m -= NB;
    if (m < 0) {
      if (m2 == 0) break;
      NB = m2;
    }

    a_ptr = a;
    a += NB;
    y_ptr = y;

    if (inc_x != 1) {
      copy_x(NB, x, xbuffer, inc_x);
      x += NB * inc_x;
    } else {
      xbuffer = x;
      x += NB;
    }

    if (inc_y == 1) {
#ifdef USE_T_8
      for (BLASLONG j = 0; j + 8 <= n; j += 8) {
        BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
        y_ptr += 8;
        a_ptr += lda8;
      }
      if (n & 4) {
#else
      for (BLASLONG j = 0; j + 4 <= n; j += 4) {
#endif
        BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
        y_ptr += 4;
        a_ptr += lda4;
      }
      if (n & 2) {
        BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
        y_ptr += 2;
        a_ptr += (lda * 2);
      }
      if (n & 1) {
        BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
      }
    } else {
#ifdef USE_T_8
      for (BLASLONG j = 0; j + 8 <= n; j += 8) {
        memset(ybuffer, 0, sizeof(FLOAT) * 8);
        BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
        copy_y(8, ybuffer, y_ptr, inc_y, beta);
        y_ptr += 8 * inc_y;
        a_ptr += lda8;
      }
      if (n & 4) {
#else
      for (BLASLONG j = 0; j + 4 <= n; j += 4) {
#endif
        memset(ybuffer, 0, sizeof(FLOAT) * 4);
        BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
        copy_y(4, ybuffer, y_ptr, inc_y, beta);
        y_ptr += 4 * inc_y;
        a_ptr += lda4;
      }
      if (n & 2) {
        memset(ybuffer, 0, sizeof(FLOAT) * 4);
        BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
        copy_y(2, ybuffer, y_ptr, inc_y, beta);
        y_ptr += 2 * inc_y;
        a_ptr += (lda * 2);
      }
      if (n & 1) {
        memset(ybuffer, 0, sizeof(FLOAT) * 4);
        BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
        copy_y(1, ybuffer, y_ptr, inc_y, beta);
      }
      beta = (FLOAT)1;
    }
  }

  return 0;
}
#endif
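
/* Illustrative scalar reference (hypothetical helper, plain float for clarity)
 * for the operation the driver above blocks and vectorizes: for every column j
 * of the column-major matrix, y[j] = beta * y[j] + alpha * sum_i A[i + j*lda] * x[i].
 * The kernels compute the same column sums over bf16 inputs, 8/4/2/1 columns at
 * a time, with the matrix rows blocked into NBMAX-sized chunks. */
static inline void sbgemv_t_reference_sketch(long m, long n, float alpha,
                                             const float *A, long lda,
                                             const float *x, float beta, float *y)
{
  for (long j = 0; j < n; j++) {
    float sum = 0.0f;
    for (long i = 0; i < m; i++)
      sum += A[i + j * lda] * x[i];   /* dot of column j with x */
    y[j] = beta * y[j] + alpha * sum;
  }
}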

@@ -0,0 +1,338 @@
/***************************************************************************
Copyright (c) 2024, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/

#ifndef SBGEMV_T_MMA_C
#define SBGEMV_T_MMA_C

#define USE_BFGEMV_T_MMA

#ifdef USE_BFGEMV_T_MMA
#include "sbgemv_common_power10.c"

#ifndef BF16GEMV_T_X
#define BF16GEMV_T_X
#define BF16GEMV_T_8 BF16GEMV_T_MMA_8
#define BF16GEMV_T_4 BF16GEMV_T_MMA_4
#define BF16GEMV_T_2 BF16GEMV_T_MMA_2
#define BF16GEMV_T_1 BF16GEMV_T_MMA_1
#endif

#define USE_BFGEMV_8_T_MMA

static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
  IFLOAT *a0;
  vec_bf16 *va0, *v_x;
  __vector_quad temp0;
  vec_f32 temp00[4];
  vec_bf16 inp[4];

  __builtin_mma_xxsetaccz(&temp0);

  a0 = ap;
  va0 = (vec_bf16 *)a0;
  v_x = (vec_bf16 *)x;
  BLASLONG n8 = n / 8;
  BLASLONG i = 0;

  for (; i + 4 <= n8; i += 4) {
    vec_load_pair2(inp, &v_x[i]);

    vec_load_mult4_mma(&temp0, &va0[i + 0], inp);
  }

  if (n8 & 2) {
    vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]);

    vec_load_mult2_mma(&temp0, &va0[i + 0], inp);

    i += 2;
  }

  if (n8 & 1) {
    inp[0] = (vec_bf16)vec_load_vec(&v_x[i]);

    vec_load_mult_mma(&temp0, &va0[i], inp[0]);

    i++;
  }

  n &= 7;
  if (n) {
    inp[0] = vec_loadN(&v_x[i], n);

    vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n);
  }

  __builtin_mma_disassemble_acc((void*)temp00, &temp0);

  y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3]));
}

static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
  IFLOAT *a0, *a1;
  vec_bf16 *va0, *va1, *v_x;
  __vector_quad temp0[2];
  vec_f32 temp00[4*2];
  vec_bf16 inp[4];

  vec_setzero_2(&temp0[0]);

  a0 = ap;
  a1 = ap + lda;
  va0 = (vec_bf16 *)a0;
  va1 = (vec_bf16 *)a1;
  v_x = (vec_bf16 *)x;
  BLASLONG n8 = n / 8;
  BLASLONG i = 0;

  for (; i + 4 <= n8; i += 4) {
    vec_load_pair2(inp, &v_x[i]);

    vec_load_mult42_mma(&temp0[0], &va0[i + 0], &va1[i + 0], inp);
  }

  if (n8 & 2) {
    vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]);

    vec_load_mult22_mma(&temp0[0], &va0[i + 0], &va1[i + 0], inp);

    i += 2;
  }

  if (n8 & 1) {
    inp[0] = (vec_bf16)vec_load_vec(&v_x[i]);

    vec_load_mult12a_mma(&temp0[0], &va0[i], &va1[i], inp[0]);

    i++;
  }

  n &= 7;
  if (n) {
    inp[0] = vec_loadN(&v_x[i], n);

    vec_loadN_mult12a_mma(&temp0[0], &va0[i], &va1[i], inp[0], n);
  }

  vec_reduce_2(temp00, &temp0[0]);

  y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3]));
  y[1] += (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3]));
}

static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
  IFLOAT *a0, *a1, *a2, *a3;
  vec_bf16 *va0, *va1, *va2, *va3, *v_x;
  __vector_quad temp0[4];
  vec_f32 temp00[4*4];
  vec_bf16 inp[4];

  vec_setzero_4(&temp0[0]);

  a0 = ap;
  a1 = ap + lda;
  a2 = a1 + lda;
  a3 = a2 + lda;
  va0 = (vec_bf16 *)a0;
  va1 = (vec_bf16 *)a1;
  va2 = (vec_bf16 *)a2;
  va3 = (vec_bf16 *)a3;
  v_x = (vec_bf16 *)x;
  BLASLONG n8 = n / 8;
  BLASLONG i = 0;

  for (; i + 4 <= n8; i += 4) {
    vec_load_pair2(inp, &v_x[i]);

    vec_load_mult44_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp);
  }

  if (n8 & 2) {
    vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]);

    vec_load_mult24_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp);

    i += 2;
  }

  if (n8 & 1) {
    inp[0] = (vec_bf16)vec_load_vec(&v_x[i]);

    vec_load_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0]);

    i++;
  }

  n &= 7;
  if (n) {
    inp[0] = vec_loadN(&v_x[i], n);

    vec_loadN_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0], n);
  }

  vec_reduce_4(temp00, &temp0[0]);

  vec_f32 t0, t1, t2, t3, t4, t5, t6, t7;
  vec_f32 a = { alpha, alpha, alpha, alpha };
  vec_f32 *v_y = (vec_f32 *) y;

  t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
  t1 = vec_mergeh(temp00[ 8], temp00[12]);
  t2 = vec_mergeo(temp00[ 1], temp00[ 5]);
  t3 = vec_mergeo(temp00[ 9], temp00[13]);
  t4 = vec_mergel(temp00[ 2], temp00[ 6]);
  t5 = vec_mergel(temp00[10], temp00[14]);
  t6 = vec_mergeo(temp00[ 3], temp00[ 7]);
  t7 = vec_mergeo(temp00[11], temp00[15]);
  t0 = vec_xxpermdi(t0, t1, 0);
  t2 = vec_xxpermdi(t2, t3, 0);
  t4 = vec_xxpermdi(t4, t5, 0);
  t6 = vec_xxpermdi(t6, t7, 3);

  t0 += t2 + t4 + t6;

  v_y[0] += (a * t0);
}
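
/* Illustrative sketch (POWER10 only; assumes GCC's MMA builtins and a
 * -mcpu=power10 build; not used by the kernels): the accumulator pattern above
 * in its smallest form.  One __vector_quad is zeroed, updated with a single
 * bf16 rank-2 outer-product, and disassembled into four float rows; the eight
 * partial products of a dot product land on the accumulator diagonal, which is
 * why the reductions above sum temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3]. */
#ifdef _ARCH_PWR10
static inline float bf16_dot8_mma_sketch(__vector unsigned char a8, __vector unsigned char x8)
{
  __vector_quad acc;
  __vector float rows[4];

  __builtin_mma_xxsetaccz(&acc);              /* acc = 0 */
  __builtin_mma_xvbf16ger2pp(&acc, a8, x8);   /* acc += rank-2 bf16 update of a8 by x8 */
  __builtin_mma_disassemble_acc((void *)rows, &acc);

  return rows[0][0] + rows[1][1] + rows[2][2] + rows[3][3];
}
#endif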

#ifdef USE_BFGEMV_8_T_MMA
static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
  IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
  vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x;
  __vector_quad temp0[8];
  vec_f32 temp00[4*8];
  vec_bf16 inp[4];

  vec_setzero_8(&temp0[0]);

  BLASLONG lda4 = lda << 2;
  a0 = ap;
  a1 = ap + lda;
  a2 = a1 + lda;
  a3 = a2 + lda;
  a4 = a0 + lda4;
  a5 = a1 + lda4;
  a6 = a2 + lda4;
  a7 = a3 + lda4;
  va0 = (vec_bf16 *)a0;
  va1 = (vec_bf16 *)a1;
  va2 = (vec_bf16 *)a2;
  va3 = (vec_bf16 *)a3;
  va4 = (vec_bf16 *)a4;
  va5 = (vec_bf16 *)a5;
  va6 = (vec_bf16 *)a6;
  va7 = (vec_bf16 *)a7;
  v_x = (vec_bf16 *)x;
  BLASLONG n8 = n / 8;
  BLASLONG i = 0;

  for (; i + 4 <= n8; i += 4) {
    vec_load_pair2(inp, &v_x[i]);

    vec_load_mult44_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp);
    vec_load_mult44_mma(&temp0[4], &va4[i + 0], &va5[i + 0], &va6[i + 0], &va7[i + 0], inp);
  }

  if (n8 & 2) {
    vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]);

    vec_load_mult24_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp);
    vec_load_mult24_mma(&temp0[4], &va4[i + 0], &va5[i + 0], &va6[i + 0], &va7[i + 0], inp);

    i += 2;
  }

  if (n8 & 1) {
    inp[0] = (vec_bf16)vec_load_vec(&v_x[i]);

    vec_load_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0]);
    vec_load_mult14_mma(&temp0[4], &va4[i], &va5[i], &va6[i], &va7[i], inp[0]);

    i++;
  }

  n &= 7;
  if (n) {
    inp[0] = vec_loadN(&v_x[i], n);

    vec_loadN_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0], n);
    vec_loadN_mult14_mma(&temp0[4], &va4[i], &va5[i], &va6[i], &va7[i], inp[0], n);
  }

  vec_reduce_8(temp00, &temp0[0]);

  vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17;
  vec_f32 a = { alpha, alpha, alpha, alpha };
  vec_f32 *v_y = (vec_f32 *) y;

  t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
  t1 = vec_mergeh(temp00[ 8], temp00[12]);
  t2 = vec_mergeo(temp00[ 1], temp00[ 5]);
  t3 = vec_mergeo(temp00[ 9], temp00[13]);
  t4 = vec_mergel(temp00[ 2], temp00[ 6]);
  t5 = vec_mergel(temp00[10], temp00[14]);
  t6 = vec_mergeo(temp00[ 3], temp00[ 7]);
  t7 = vec_mergeo(temp00[11], temp00[15]);
  t0 = vec_xxpermdi(t0, t1, 0);
  t2 = vec_xxpermdi(t2, t3, 0);
  t4 = vec_xxpermdi(t4, t5, 0);
  t6 = vec_xxpermdi(t6, t7, 3);

  t0 += t2 + t4 + t6;

  t10 = vec_mergeh(temp00[16], temp00[20]);
  t11 = vec_mergeh(temp00[24], temp00[28]);
  t12 = vec_mergeo(temp00[17], temp00[21]);
  t13 = vec_mergeo(temp00[25], temp00[29]);
  t14 = vec_mergel(temp00[18], temp00[22]);
  t15 = vec_mergel(temp00[26], temp00[30]);
  t16 = vec_mergeo(temp00[19], temp00[23]);
  t17 = vec_mergeo(temp00[27], temp00[31]);
  t10 = vec_xxpermdi(t10, t11, 0);
  t12 = vec_xxpermdi(t12, t13, 0);
  t14 = vec_xxpermdi(t14, t15, 0);
  t16 = vec_xxpermdi(t16, t17, 3);

  t10 += t12 + t14 + t16;

  vec_f32 inp2[2];
  vec_load_pair(inp2, v_y);
  inp2[0] += (a * t0);
  inp2[1] += (a * t10);
  vec_store_pair(v_y, inp2);
}
#endif

#include "sbgemv_t.c"
#else
#include "sbgemv_t_vsx.c"
#endif
#endif

@@ -0,0 +1,292 @@
/***************************************************************************
Copyright (c) 2024, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/

#ifndef SBGEMV_T_VSX_C
#define SBGEMV_T_VSX_C

#include "sbgemv_common.c"

#ifndef BF16GEMV_T_X
#define BF16GEMV_T_X
#define BF16GEMV_T_8 BF16GEMV_T_VSX_8
#define BF16GEMV_T_4 BF16GEMV_T_VSX_4
#define BF16GEMV_T_2 BF16GEMV_T_VSX_2
#define BF16GEMV_T_1 BF16GEMV_T_VSX_1
#endif

#define USE_BFGEMV_8_T_VSX

static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
  IFLOAT *a0;
  vec_bf16 *va0, *v_x;
  vec_f32 temp0 = { 0, 0, 0, 0 };
  vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 };
  vec_f32 inp[2];

  a0 = ap;
  va0 = (vec_bf16 *)a0;
  v_x = (vec_bf16 *)x;
  BLASLONG n8 = n / 8;
  BLASLONG i = 0;

  for (; i < n8; i++) {
    vec_load_vec2(&v_x[i], inp, zero);

    temp0 += vec_load_mult(&va0[i], inp, zero);
  }

  n &= 7;
  if (n > 4) {
    vec_loadN_vec2(&v_x[i], inp, n, zero);

    temp0 += vec_loadN_mult(&va0[i], inp, n, zero);
  } else if (n) {
    inp[0] = vec_loadNHi(&v_x[i], n, zero);

    temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero);
  }

  y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3]));
}

static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
  IFLOAT *a0, *a1;
  vec_bf16 *va0, *va1, *v_x;
  vec_f32 temp0 = { 0, 0, 0, 0 };
  vec_f32 temp1 = { 0, 0, 0, 0 };
  vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 };
  vec_f32 inp[2];

  a0 = ap;
  a1 = ap + lda;
  va0 = (vec_bf16 *)a0;
  va1 = (vec_bf16 *)a1;
  v_x = (vec_bf16 *)x;
  BLASLONG n8 = n / 8;
  BLASLONG i = 0;

  for (; i < n8; i++) {
    vec_load_vec2(&v_x[i], inp, zero);

    temp0 += vec_load_mult(&va0[i], inp, zero);
    temp1 += vec_load_mult(&va1[i], inp, zero);
  }

  n &= 7;
  if (n > 4) {
    vec_loadN_vec2(&v_x[i], inp, n, zero);

    temp0 += vec_loadN_mult(&va0[i], inp, n, zero);
    temp1 += vec_loadN_mult(&va1[i], inp, n, zero);
  } else if (n) {
    inp[0] = vec_loadNHi(&v_x[i], n, zero);

    temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero);
    temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero);
  }

  y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3]));
  y[1] += (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3]));
}

static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
  IFLOAT *a0, *a1, *a2, *a3;
  vec_bf16 *va0, *va1, *va2, *va3, *v_x;
  vec_f32 temp0 = { 0, 0, 0, 0 };
  vec_f32 temp1 = { 0, 0, 0, 0 };
  vec_f32 temp2 = { 0, 0, 0, 0 };
  vec_f32 temp3 = { 0, 0, 0, 0 };
  vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 };
  vec_f32 inp[2];

  a0 = ap;
  a1 = ap + lda;
  a2 = a1 + lda;
  a3 = a2 + lda;
  va0 = (vec_bf16 *)a0;
  va1 = (vec_bf16 *)a1;
  va2 = (vec_bf16 *)a2;
  va3 = (vec_bf16 *)a3;
  v_x = (vec_bf16 *)x;
  BLASLONG n8 = n / 8;
  BLASLONG i = 0;

  for (; i < n8; i++) {
    vec_load_vec2(&v_x[i], inp, zero);

    temp0 += vec_load_mult(&va0[i], inp, zero);
    temp1 += vec_load_mult(&va1[i], inp, zero);
    temp2 += vec_load_mult(&va2[i], inp, zero);
    temp3 += vec_load_mult(&va3[i], inp, zero);
  }

  n &= 7;
  if (n > 4) {
    vec_loadN_vec2(&v_x[i], inp, n, zero);

    temp0 += vec_loadN_mult(&va0[i], inp, n, zero);
    temp1 += vec_loadN_mult(&va1[i], inp, n, zero);
    temp2 += vec_loadN_mult(&va2[i], inp, n, zero);
    temp3 += vec_loadN_mult(&va3[i], inp, n, zero);
  } else if (n) {
    inp[0] = vec_loadNHi(&v_x[i], n, zero);

    temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero);
    temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero);
    temp2 += vec_loadNHi_mult(&va2[i], inp[0], n, zero);
    temp3 += vec_loadNHi_mult(&va3[i], inp[0], n, zero);
  }

  vec_f32 t0, t1, t2, t3;
  vec_f32 a = { alpha, alpha, alpha, alpha };
  vec_f32 *v_y = (vec_f32 *) y;

  t0 = vec_mergeh(temp0, temp2);
  t1 = vec_mergel(temp0, temp2);
  t2 = vec_mergeh(temp1, temp3);
  t3 = vec_mergel(temp1, temp3);
  temp0 = vec_mergeh(t0, t2);
  temp1 = vec_mergel(t0, t2);
  temp2 = vec_mergeh(t1, t3);
  temp3 = vec_mergel(t1, t3);
  temp0 += temp1 + temp2 + temp3;

  v_y[0] += (a * temp0);
}
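
/* Illustrative sketch (hypothetical scalar equivalent, not used by the kernels):
 * the vec_mergeh/vec_mergel transpose above only exists so that the four
 * per-column accumulators can be reduced and scaled with a single vector
 * update; element-wise the reduction is simply: */
static inline void bf16gemv_t_reduce4_sketch(const float t[4][4], float alpha, float y[4])
{
  for (int j = 0; j < 4; j++)    /* one output per matrix column */
    y[j] += alpha * (t[j][0] + t[j][1] + t[j][2] + t[j][3]);
}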

#ifdef USE_BFGEMV_8_T_VSX
static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
{
  IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
  vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x;
  vec_f32 temp0 = { 0, 0, 0, 0 };
  vec_f32 temp1 = { 0, 0, 0, 0 };
  vec_f32 temp2 = { 0, 0, 0, 0 };
  vec_f32 temp3 = { 0, 0, 0, 0 };
  vec_f32 temp4 = { 0, 0, 0, 0 };
  vec_f32 temp5 = { 0, 0, 0, 0 };
  vec_f32 temp6 = { 0, 0, 0, 0 };
  vec_f32 temp7 = { 0, 0, 0, 0 };
  vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 };
  vec_f32 inp[2];

  BLASLONG lda4 = lda << 2;
  a0 = ap;
  a1 = ap + lda;
  a2 = a1 + lda;
  a3 = a2 + lda;
  a4 = a0 + lda4;
  a5 = a1 + lda4;
  a6 = a2 + lda4;
  a7 = a3 + lda4;
  va0 = (vec_bf16 *)a0;
  va1 = (vec_bf16 *)a1;
  va2 = (vec_bf16 *)a2;
  va3 = (vec_bf16 *)a3;
  va4 = (vec_bf16 *)a4;
  va5 = (vec_bf16 *)a5;
  va6 = (vec_bf16 *)a6;
  va7 = (vec_bf16 *)a7;
  v_x = (vec_bf16 *)x;
  BLASLONG n8 = n / 8;
  BLASLONG i = 0;

  for (; i < n8; i++) {
    vec_load_vec2(&v_x[i], inp, zero);

    temp0 += vec_load_mult(&va0[i], inp, zero);
    temp1 += vec_load_mult(&va1[i], inp, zero);
    temp2 += vec_load_mult(&va2[i], inp, zero);
    temp3 += vec_load_mult(&va3[i], inp, zero);
    temp4 += vec_load_mult(&va4[i], inp, zero);
    temp5 += vec_load_mult(&va5[i], inp, zero);
    temp6 += vec_load_mult(&va6[i], inp, zero);
    temp7 += vec_load_mult(&va7[i], inp, zero);
  }

  n &= 7;
  if (n > 4) {
    vec_loadN_vec2(&v_x[i], inp, n, zero);

    temp0 += vec_loadN_mult(&va0[i], inp, n, zero);
    temp1 += vec_loadN_mult(&va1[i], inp, n, zero);
    temp2 += vec_loadN_mult(&va2[i], inp, n, zero);
    temp3 += vec_loadN_mult(&va3[i], inp, n, zero);
    temp4 += vec_loadN_mult(&va4[i], inp, n, zero);
    temp5 += vec_loadN_mult(&va5[i], inp, n, zero);
    temp6 += vec_loadN_mult(&va6[i], inp, n, zero);
    temp7 += vec_loadN_mult(&va7[i], inp, n, zero);
  } else if (n) {
    inp[0] = vec_loadNHi(&v_x[i], n, zero);

    temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero);
    temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero);
    temp2 += vec_loadNHi_mult(&va2[i], inp[0], n, zero);
    temp3 += vec_loadNHi_mult(&va3[i], inp[0], n, zero);
    temp4 += vec_loadNHi_mult(&va4[i], inp[0], n, zero);
    temp5 += vec_loadNHi_mult(&va5[i], inp[0], n, zero);
    temp6 += vec_loadNHi_mult(&va6[i], inp[0], n, zero);
    temp7 += vec_loadNHi_mult(&va7[i], inp[0], n, zero);
  }

  vec_f32 t0, t1, t2, t3, t10, t11, t12, t13;
  vec_f32 a = { alpha, alpha, alpha, alpha };
  vec_f32 *v_y = (vec_f32 *) y;

  t0 = vec_mergeh(temp0, temp2);
  t1 = vec_mergel(temp0, temp2);
  t2 = vec_mergeh(temp1, temp3);
  t3 = vec_mergel(temp1, temp3);
  temp0 = vec_mergeh(t0, t2);
  temp1 = vec_mergel(t0, t2);
  temp2 = vec_mergeh(t1, t3);
  temp3 = vec_mergel(t1, t3);
  temp0 += temp1 + temp2 + temp3;

  t10 = vec_mergeh(temp4, temp6);
  t11 = vec_mergel(temp4, temp6);
  t12 = vec_mergeh(temp5, temp7);
  t13 = vec_mergel(temp5, temp7);
  temp4 = vec_mergeh(t10, t12);
  temp5 = vec_mergel(t10, t12);
  temp6 = vec_mergeh(t11, t13);
  temp7 = vec_mergel(t11, t13);
  temp4 += temp5 + temp6 + temp7;

  vec_load_pair(inp, v_y);
  inp[0] += (a * temp0);
  inp[1] += (a * temp4);
  vec_store_pair(v_y, inp);
}
#endif

#include "sbgemv_t.c"
#endif

@@ -202,16 +202,17 @@ main (int argc, char *argv[])
      return ret;
    }

  for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one.
  for (x = 1; x <= loop; x++)
    {
      k = (x == 0) ? 0 : 1;
      k = (x == 0) ? 0 : l + 1;
      float *A = (float *)malloc_safe(x * x * sizeof(FLOAT));
      float *B = (float *)malloc_safe(x * sizeof(FLOAT));
      float *C = (float *)malloc_safe(x * sizeof(FLOAT));
      float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l);
      float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l);
      bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits));
      bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits));
      bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) << l);
      float *DD = (float *)malloc_safe(x * sizeof(FLOAT));
      float *CC = (float *)malloc_safe(x * sizeof(FLOAT));
      float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l);
      if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
          (DD == NULL) || (CC == NULL))
        return 1;

@@ -226,9 +227,9 @@ main (int argc, char *argv[])
          sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one);
          AA[j * x + i].v = atmp;
        }
        B[j] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
        sbstobf16_(&one, &B[j], &one, &btmp, &one);
        BB[j].v = btmp;
        B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
        sbstobf16_(&one, &B[j << l], &one, &btmp, &one);
        BB[j << l].v = btmp;
      }
      for (y = 0; y < 2; y++)
        {

@@ -238,9 +239,9 @@ main (int argc, char *argv[])
            transA = 'T';
          }

          memset(CC, 0, x * sizeof(FLOAT));
          memset(CC, 0, x * sizeof(FLOAT) << l);
          memset(DD, 0, x * sizeof(FLOAT));
          memset(C, 0, x * sizeof(FLOAT));
          memset(C, 0, x * sizeof(FLOAT) << l);

          SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k);
          SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k);

@@ -248,15 +249,15 @@ main (int argc, char *argv[])
          for (j = 0; j < x; j++)
            for (i = 0; i < x; i++)
              if (transA == 'N') {
                DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j]);
                DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]);
              } else if (transA == 'T') {
                DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i]);
                DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]);
              }

          for (j = 0; j < x; j++) {
            if (fabs (CC[j] - C[j]) > 1.0)
            if (fabs (CC[j << l] - C[j << l]) > 1.0)
              ret++;
            if (fabs (CC[j] - DD[j]) > 1.0)
            if (fabs (CC[j << l] - DD[j]) > 1.0)
              ret++;
          }
        }

@@ -268,6 +269,7 @@ main (int argc, char *argv[])
      free(DD);
      free(CC);
    }
  }

  if (ret != 0)
    fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret);