MMA BF16 GEMV code.
This commit is contained in:
parent
7947970f9d
commit
89a12fa083
|
@ -4,6 +4,8 @@
|
||||||
|
|
||||||
#include <altivec.h>
|
#include <altivec.h>
|
||||||
|
|
||||||
|
#define NBMAX 4096
|
||||||
|
|
||||||
#define FORCEINLINE inline __attribute__((always_inline))
|
#define FORCEINLINE inline __attribute__((always_inline))
|
||||||
|
|
||||||
#ifdef __clang__
|
#ifdef __clang__
|
||||||
|
|
|
@ -111,7 +111,7 @@ FORCEINLINE vec_f32 vec_loadNHi_mult(vec_bf16 *in, vec_f32 v_inp0, BLASLONG n, v
|
||||||
return (v_inp0 * v_in00);
|
return (v_inp0 * v_in00);
|
||||||
}
|
}
|
||||||
|
|
||||||
FORCEINLINE vec_f32 vec_loadNHi_multi2(vec_f32 v_x0, vec_bf16 *in, BLASLONG n, vec_bf16 zero)
|
FORCEINLINE vec_f32 vec_loadNHi_mult2(vec_f32 v_x0, vec_bf16 *in, BLASLONG n, vec_bf16 zero)
|
||||||
{
|
{
|
||||||
vec_f32 v_in00 = vec_loadNHi(in, n, zero);
|
vec_f32 v_in00 = vec_loadNHi(in, n, zero);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,265 @@
|
||||||
|
/***************************************************************************
|
||||||
|
Copyright (c) 2024, The OpenBLAS Project
|
||||||
|
All rights reserved.
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
1. Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer in
|
||||||
|
the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
3. Neither the name of the OpenBLAS project nor the names of
|
||||||
|
its contributors may be used to endorse or promote products
|
||||||
|
derived from this software without specific prior written permission.
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||||
|
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
||||||
|
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||||
|
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*****************************************************************************/
|
||||||
|
|
||||||
|
#ifndef SBGEMV_COMMON_MMA_C
|
||||||
|
#define SBGEMV_COMMON_MMA_C
|
||||||
|
#include "sbgemv_common.c"
|
||||||
|
|
||||||
|
FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in0 = (vec_bf16)vec_load_vec(in);
|
||||||
|
|
||||||
|
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in0[2];
|
||||||
|
|
||||||
|
vec_load_pair((vec_f32 *)in0, (vec_f32 *)in);
|
||||||
|
|
||||||
|
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]);
|
||||||
|
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
|
||||||
|
{
|
||||||
|
vec_bf16 in0 = vec_loadN(in, n);
|
||||||
|
|
||||||
|
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_mult1_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in00 = vec_mergeh(in0, in0);
|
||||||
|
|
||||||
|
__builtin_mma_xvbf16ger2(out, (vec_uc8)inp, (vec_uc8)in00);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_mult2_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in01 = vec_mergel(in0, in0);
|
||||||
|
|
||||||
|
vec_mult1_mma(&out[0], in0, inp);
|
||||||
|
|
||||||
|
__builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_mult4_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_mult2_mma(out + 0, in0[0], inp);
|
||||||
|
vec_mult2_mma(out + 2, in0[1], inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_loadN_mult11_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
|
||||||
|
{
|
||||||
|
vec_bf16 in0 = vec_loadN(in, n);
|
||||||
|
|
||||||
|
vec_mult1_mma(out, in0, inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_loadN_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
|
||||||
|
{
|
||||||
|
vec_bf16 in0 = vec_loadN(in, n);
|
||||||
|
|
||||||
|
vec_mult2_mma(out, in0, inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_load_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in0 = (vec_bf16)vec_load_vec(in);
|
||||||
|
|
||||||
|
vec_mult2_mma(out, in0, inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_load_mult18_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in0[4];
|
||||||
|
|
||||||
|
vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0));
|
||||||
|
vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2));
|
||||||
|
|
||||||
|
vec_mult4_mma(&out[0], in0 + 0, inp);
|
||||||
|
vec_mult4_mma(&out[4], in0 + 2, inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_reduce1_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
|
||||||
|
{
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp, &out[0]);
|
||||||
|
|
||||||
|
vy0[0] += (temp[0] * v_alpha);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_reduce2_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
|
||||||
|
{
|
||||||
|
vec_reduce1_mma(&out[0], &temp[0], v_alpha, &vy0[0]);
|
||||||
|
vec_reduce1_mma(&out[1], &temp[4], v_alpha, &vy0[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_reduce8_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
|
||||||
|
{
|
||||||
|
vec_reduce2_mma(&out[0], &temp[0], v_alpha, vy0 + 0);
|
||||||
|
vec_reduce2_mma(&out[2], &temp[8], v_alpha, vy0 + 2);
|
||||||
|
vec_reduce2_mma(&out[4], &temp[16], v_alpha, vy0 + 4);
|
||||||
|
vec_reduce2_mma(&out[6], &temp[24], v_alpha, vy0 + 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in00 = vec_mergeh(in0, in1);
|
||||||
|
|
||||||
|
__builtin_mma_xvbf16ger2(out, (vec_uc8)inp, (vec_uc8)in00);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_mult2a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in01 = vec_mergel(in0, in1);
|
||||||
|
|
||||||
|
vec_mult11a_mma(&out[0], in0, in1, inp);
|
||||||
|
|
||||||
|
__builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_mult4a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_mult2a_mma(out + 0, in0[0], in1[0], inp);
|
||||||
|
vec_mult2a_mma(out + 2, in0[1], in1[1], inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_loadN_mult11a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
|
||||||
|
{
|
||||||
|
vec_bf16 in0 = vec_loadN(ina, n);
|
||||||
|
vec_bf16 in1 = vec_loadN(inb, n);
|
||||||
|
|
||||||
|
vec_mult11a_mma(out, in0, in1, inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_load_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in0 = (vec_bf16)vec_load_vec(ina);
|
||||||
|
vec_bf16 in1 = (vec_bf16)vec_load_vec(inb);
|
||||||
|
|
||||||
|
vec_mult2a_mma(out, in0, in1, inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_load_mult28a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in0[4], in1[4];
|
||||||
|
|
||||||
|
vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0));
|
||||||
|
vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0));
|
||||||
|
vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2));
|
||||||
|
vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2));
|
||||||
|
|
||||||
|
vec_mult4a_mma(&out[0], in0 + 0, in1 + 0, inp);
|
||||||
|
vec_mult4a_mma(&out[4], in0 + 2, in1 + 2, inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_loadN_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
|
||||||
|
{
|
||||||
|
vec_bf16 in0 = vec_loadN(ina, n);
|
||||||
|
vec_bf16 in1 = vec_loadN(inb, n);
|
||||||
|
|
||||||
|
vec_mult2a_mma(out, in0, in1, inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_mult11b_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in00 = vec_mergeh(in0, in1);
|
||||||
|
|
||||||
|
__builtin_mma_xvbf16ger2pp(out, (vec_uc8)inp, (vec_uc8)in00);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_mult2b_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in01 = vec_mergel(in0, in1);
|
||||||
|
|
||||||
|
vec_mult11b_mma(&out[0], in0, in1, inp);
|
||||||
|
|
||||||
|
__builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)inp, (vec_uc8)in01);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_mult4b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_mult2b_mma(out + 0, in0[0], in1[0], inp);
|
||||||
|
vec_mult2b_mma(out + 2, in0[1], in1[1], inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_loadN_mult11b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
|
||||||
|
{
|
||||||
|
vec_bf16 in0 = vec_loadN(ina, n);
|
||||||
|
vec_bf16 in1 = vec_loadN(inb, n);
|
||||||
|
|
||||||
|
vec_mult11b_mma(out, in0, in1, inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_load_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in0 = (vec_bf16)vec_load_vec(ina);
|
||||||
|
vec_bf16 in1 = (vec_bf16)vec_load_vec(inb);
|
||||||
|
|
||||||
|
vec_mult2b_mma(out, in0, in1, inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_load_mult28b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
|
||||||
|
{
|
||||||
|
vec_bf16 in0[4], in1[4];
|
||||||
|
|
||||||
|
vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0));
|
||||||
|
vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0));
|
||||||
|
vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2));
|
||||||
|
vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2));
|
||||||
|
|
||||||
|
vec_mult4b_mma(&out[0], in0 + 0, in1 + 0, inp);
|
||||||
|
vec_mult4b_mma(&out[4], in0 + 2, in1 + 2, inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
|
||||||
|
{
|
||||||
|
vec_bf16 in0 = vec_loadN(ina, n);
|
||||||
|
vec_bf16 in1 = vec_loadN(inb, n);
|
||||||
|
|
||||||
|
vec_mult2b_mma(out, in0, in1, inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_load4_pair(vec_f32 *vy0, vec_f32 *v_y)
|
||||||
|
{
|
||||||
|
vec_load_pair(vy0 + 0, v_y + 0);
|
||||||
|
vec_load_pair(vy0 + 2, v_y + 2);
|
||||||
|
vec_load_pair(vy0 + 4, v_y + 4);
|
||||||
|
vec_load_pair(vy0 + 6, v_y + 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE void vec_store4_pair(vec_f32 *v_y, vec_f32 *vy0)
|
||||||
|
{
|
||||||
|
vec_store_pair(v_y + 0, vy0 + 0);
|
||||||
|
vec_store_pair(v_y + 2, vy0 + 2);
|
||||||
|
vec_store_pair(v_y + 4, vy0 + 4);
|
||||||
|
vec_store_pair(v_y + 6, vy0 + 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -87,6 +87,10 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_N_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_N_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_N_VSX))
|
||||||
|
#define USE_N_8
|
||||||
|
#endif
|
||||||
|
|
||||||
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y)
|
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y)
|
||||||
{
|
{
|
||||||
IFLOAT *x_ptr, *ap[4];
|
IFLOAT *x_ptr, *ap[4];
|
||||||
|
@ -100,7 +104,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
|
||||||
y_ptr = y;
|
y_ptr = y;
|
||||||
|
|
||||||
BLASLONG lda4 = lda << 2;
|
BLASLONG lda4 = lda << 2;
|
||||||
|
#ifdef USE_N_8
|
||||||
BLASLONG lda8 = lda << 3;
|
BLASLONG lda8 = lda << 3;
|
||||||
|
#endif
|
||||||
BLASLONG NB = NBMAX;
|
BLASLONG NB = NBMAX;
|
||||||
BLASLONG m2 = (m & (NBMAX - 1));
|
BLASLONG m2 = (m & (NBMAX - 1));
|
||||||
|
|
||||||
|
@ -126,6 +132,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
|
||||||
ap[3] = ap[2] + lda;
|
ap[3] = ap[2] + lda;
|
||||||
|
|
||||||
if (inc_x == 1) {
|
if (inc_x == 1) {
|
||||||
|
#ifdef USE_N_8
|
||||||
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
|
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
|
||||||
BF16GEMV_N_8(NB, ap, x_ptr, ybuffer, lda4, alpha);
|
BF16GEMV_N_8(NB, ap, x_ptr, ybuffer, lda4, alpha);
|
||||||
ap[0] += lda8;
|
ap[0] += lda8;
|
||||||
|
@ -135,9 +142,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
|
||||||
x_ptr += 8;
|
x_ptr += 8;
|
||||||
}
|
}
|
||||||
if (n & 4) {
|
if (n & 4) {
|
||||||
|
#else
|
||||||
|
for (BLASLONG j = 0; j + 4 <= n; j += 4) {
|
||||||
|
#endif
|
||||||
BF16GEMV_N_4(NB, ap, x_ptr, ybuffer, alpha);
|
BF16GEMV_N_4(NB, ap, x_ptr, ybuffer, alpha);
|
||||||
ap[0] += lda4;
|
ap[0] += lda4;
|
||||||
ap[1] += lda4;
|
ap[1] += lda4;
|
||||||
|
#ifndef USE_N_8
|
||||||
|
ap[2] += lda4;
|
||||||
|
ap[3] += lda4;
|
||||||
|
#endif
|
||||||
x_ptr += 4;
|
x_ptr += 4;
|
||||||
}
|
}
|
||||||
if (n & 2) {
|
if (n & 2) {
|
||||||
|
@ -149,6 +163,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
|
||||||
BF16GEMV_N_1(NB, ap, x_ptr, ybuffer, alpha);
|
BF16GEMV_N_1(NB, ap, x_ptr, ybuffer, alpha);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
#ifdef USE_N_8
|
||||||
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
|
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
|
||||||
copy_x(8, x_ptr, xbuffer, inc_x);
|
copy_x(8, x_ptr, xbuffer, inc_x);
|
||||||
BF16GEMV_N_8(NB, ap, xbuffer, ybuffer, lda4, alpha);
|
BF16GEMV_N_8(NB, ap, xbuffer, ybuffer, lda4, alpha);
|
||||||
|
@ -159,10 +174,17 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
|
||||||
x_ptr += 8 * inc_x;
|
x_ptr += 8 * inc_x;
|
||||||
}
|
}
|
||||||
if (n & 4) {
|
if (n & 4) {
|
||||||
|
#else
|
||||||
|
for (BLASLONG j = 0; j + 4 <= n; j += 4) {
|
||||||
|
#endif
|
||||||
copy_x(4, x_ptr, xbuffer, inc_x);
|
copy_x(4, x_ptr, xbuffer, inc_x);
|
||||||
BF16GEMV_N_4(NB, ap, xbuffer, ybuffer, alpha);
|
BF16GEMV_N_4(NB, ap, xbuffer, ybuffer, alpha);
|
||||||
ap[0] += lda4;
|
ap[0] += lda4;
|
||||||
ap[1] += lda4;
|
ap[1] += lda4;
|
||||||
|
#ifndef USE_N_8
|
||||||
|
ap[2] += lda4;
|
||||||
|
ap[3] += lda4;
|
||||||
|
#endif
|
||||||
x_ptr += 4 * inc_x;
|
x_ptr += 4 * inc_x;
|
||||||
}
|
}
|
||||||
if (n & 2) {
|
if (n & 2) {
|
||||||
|
|
|
@ -25,9 +25,309 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
*****************************************************************************/
|
*****************************************************************************/
|
||||||
|
|
||||||
//#include "sbgemv_common.c"
|
#ifndef SBGEMV_N_MMA_C
|
||||||
|
#define SBGEMV_N_MMA_C
|
||||||
|
|
||||||
|
#define USE_BFGEMV_N_MMA
|
||||||
|
|
||||||
|
#ifdef USE_BFGEMV_N_MMA
|
||||||
|
#include "sbgemv_common_power10.c"
|
||||||
|
|
||||||
|
#ifndef BF16GEMV_N_X
|
||||||
|
#define BF16GEMV_N_X
|
||||||
|
#define BF16GEMV_N_8 BF16GEMV_N_MMA_8
|
||||||
|
#define BF16GEMV_N_4 BF16GEMV_N_MMA_4
|
||||||
|
#define BF16GEMV_N_2 BF16GEMV_N_MMA_2
|
||||||
|
#define BF16GEMV_N_1 BF16GEMV_N_MMA_1
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define USE_BFGEMV_8_N_MMA
|
||||||
|
|
||||||
|
static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha)
|
||||||
|
{
|
||||||
|
IFLOAT *a0;
|
||||||
|
__vector_quad temp[2*4];
|
||||||
|
vec_f32 temp0[8*4], vy0[2*4];
|
||||||
|
vec_f32 v_alpha = { alpha, alpha, alpha, alpha };
|
||||||
|
|
||||||
|
a0 = ap[0];
|
||||||
|
|
||||||
|
vec_bf16 *va0 = (vec_bf16 *)a0;
|
||||||
|
|
||||||
|
vec_bf16 *x_bf = (vec_bf16 *)(xo);
|
||||||
|
vec_bf16 v_x0 = vec_loadN(x_bf, 1);
|
||||||
|
|
||||||
|
vec_f32 *v_y = (vec_f32 *)y;
|
||||||
|
BLASLONG n8 = n / 8;
|
||||||
|
BLASLONG i = 0;
|
||||||
|
|
||||||
|
for (; i + 4 <= n8; i += 4) {
|
||||||
|
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);
|
||||||
|
|
||||||
|
vec_load_mult18_mma(&temp[0], &va0[i + 0], v_x0);
|
||||||
|
|
||||||
|
vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_store4_pair(&v_y[(i * 2) + 0], vy0);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (; i < n8; i++) {
|
||||||
|
vec_load_pair(vy0, &v_y[(i * 2) + 0]);
|
||||||
|
|
||||||
|
vec_load_mult12_mma(&temp[0], &va0[i], v_x0);
|
||||||
|
|
||||||
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_store_pair(&v_y[(i * 2) + 0], vy0);
|
||||||
|
}
|
||||||
|
|
||||||
|
n &= 7;
|
||||||
|
if (n > 4) {
|
||||||
|
BLASLONG n3 = n & 3;
|
||||||
|
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
|
|
||||||
|
vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0, n);
|
||||||
|
|
||||||
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
|
} else if (n) {
|
||||||
|
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||||
|
|
||||||
|
vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0, n);
|
||||||
|
|
||||||
|
vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha)
|
||||||
|
{
|
||||||
|
IFLOAT *a0, *a1;
|
||||||
|
__vector_quad temp[2*4];
|
||||||
|
vec_f32 temp0[8*4], vy0[2*4];
|
||||||
|
vec_f32 v_alpha = { alpha, alpha, alpha, alpha };
|
||||||
|
|
||||||
|
a0 = ap[0];
|
||||||
|
a1 = ap[1];
|
||||||
|
|
||||||
|
vec_bf16 *va0 = (vec_bf16 *)a0;
|
||||||
|
vec_bf16 *va1 = (vec_bf16 *)a1;
|
||||||
|
|
||||||
|
vec_bf16 *x_bf = (vec_bf16 *)(xo);
|
||||||
|
vec_bf16 v_x0 = vec_loadN(x_bf, 2);
|
||||||
|
|
||||||
|
vec_f32 *v_y = (vec_f32 *)y;
|
||||||
|
BLASLONG n8 = n / 8;
|
||||||
|
BLASLONG i = 0;
|
||||||
|
|
||||||
|
for (; i + 4 <= n8; i += 4) {
|
||||||
|
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);
|
||||||
|
|
||||||
|
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0);
|
||||||
|
|
||||||
|
vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_store4_pair(&v_y[(i * 2) + 0], vy0);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (; i < n8; i++) {
|
||||||
|
vec_load_pair(vy0, &v_y[(i * 2) + 0]);
|
||||||
|
|
||||||
|
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0);
|
||||||
|
|
||||||
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_store_pair(&v_y[(i * 2) + 0], vy0);
|
||||||
|
}
|
||||||
|
|
||||||
|
n &= 7;
|
||||||
|
if (n > 4) {
|
||||||
|
BLASLONG n3 = n & 3;
|
||||||
|
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
|
|
||||||
|
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0, n);
|
||||||
|
|
||||||
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
|
} else if (n) {
|
||||||
|
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||||
|
|
||||||
|
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0, n);
|
||||||
|
|
||||||
|
vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha)
|
||||||
|
{
|
||||||
|
IFLOAT *a0, *a1, *a2, *a3;
|
||||||
|
__vector_quad temp[2*4];
|
||||||
|
vec_f32 temp0[8*4], vy0[2*4];
|
||||||
|
vec_f32 v_alpha = { alpha, alpha, alpha, alpha };
|
||||||
|
|
||||||
|
a0 = ap[0];
|
||||||
|
a1 = ap[1];
|
||||||
|
a2 = ap[2];
|
||||||
|
a3 = ap[3];
|
||||||
|
|
||||||
|
vec_bf16 *va0 = (vec_bf16 *)a0;
|
||||||
|
vec_bf16 *va1 = (vec_bf16 *)a1;
|
||||||
|
vec_bf16 *va2 = (vec_bf16 *)a2;
|
||||||
|
vec_bf16 *va3 = (vec_bf16 *)a3;
|
||||||
|
|
||||||
|
vec_bf16 *x_bf = (vec_bf16 *)(xo);
|
||||||
|
vec_bf16 v_x00 = vec_loadN(x_bf, 4);
|
||||||
|
|
||||||
|
vec_bf16 v_x01 = (vec_bf16)vec_splat((vec_f32)v_x00, 1);
|
||||||
|
|
||||||
|
vec_f32 *v_y = (vec_f32 *)y;
|
||||||
|
BLASLONG n8 = n / 8;
|
||||||
|
BLASLONG i = 0;
|
||||||
|
|
||||||
|
for (; i + 4 <= n8; i += 4) {
|
||||||
|
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);
|
||||||
|
|
||||||
|
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x00);
|
||||||
|
vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x01);
|
||||||
|
|
||||||
|
vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_store4_pair(&v_y[(i * 2) + 0], vy0);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (; i < n8; i++) {
|
||||||
|
vec_load_pair(vy0, &v_y[(i * 2) + 0]);
|
||||||
|
|
||||||
|
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00);
|
||||||
|
vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01);
|
||||||
|
|
||||||
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_store_pair(&v_y[(i * 2) + 0], vy0);
|
||||||
|
}
|
||||||
|
|
||||||
|
n &= 7;
|
||||||
|
if (n > 4) {
|
||||||
|
BLASLONG n3 = n & 3;
|
||||||
|
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
|
|
||||||
|
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00, n);
|
||||||
|
vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01, n);
|
||||||
|
|
||||||
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
|
} else if (n) {
|
||||||
|
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||||
|
|
||||||
|
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x00, n);
|
||||||
|
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x01, n);
|
||||||
|
|
||||||
|
vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef USE_BFGEMV_8_N_MMA
|
||||||
|
static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLASLONG lda4, FLOAT alpha)
|
||||||
|
{
|
||||||
|
IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3;
|
||||||
|
__vector_quad temp[2*4];
|
||||||
|
vec_f32 temp0[8*4], vy0[2*4];
|
||||||
|
vec_f32 v_alpha = { alpha, alpha, alpha, alpha };
|
||||||
|
|
||||||
|
a0 = ap[0];
|
||||||
|
a1 = ap[1];
|
||||||
|
a2 = ap[2];
|
||||||
|
a3 = ap[3];
|
||||||
|
b0 = a0 + lda4;
|
||||||
|
b1 = a1 + lda4;
|
||||||
|
b2 = a2 + lda4;
|
||||||
|
b3 = a3 + lda4;
|
||||||
|
|
||||||
|
vec_bf16 *va0 = (vec_bf16 *)a0;
|
||||||
|
vec_bf16 *va1 = (vec_bf16 *)a1;
|
||||||
|
vec_bf16 *va2 = (vec_bf16 *)a2;
|
||||||
|
vec_bf16 *va3 = (vec_bf16 *)a3;
|
||||||
|
vec_bf16 *vb0 = (vec_bf16 *)b0;
|
||||||
|
vec_bf16 *vb1 = (vec_bf16 *)b1;
|
||||||
|
vec_bf16 *vb2 = (vec_bf16 *)b2;
|
||||||
|
vec_bf16 *vb3 = (vec_bf16 *)b3;
|
||||||
|
|
||||||
|
vec_bf16 *x_bf = (vec_bf16 *)(xo);
|
||||||
|
vec_bf16 v_x00 = (vec_bf16)vec_load_vec(x_bf);
|
||||||
|
|
||||||
|
vec_bf16 v_x01 = (vec_bf16)vec_splat((vec_f32)v_x00, 1);
|
||||||
|
vec_bf16 v_x02 = (vec_bf16)vec_splat((vec_f32)v_x00, 2);
|
||||||
|
vec_bf16 v_x03 = (vec_bf16)vec_splat((vec_f32)v_x00, 3);
|
||||||
|
|
||||||
|
vec_f32 *v_y = (vec_f32 *)y;
|
||||||
|
BLASLONG n8 = n / 8;
|
||||||
|
BLASLONG i = 0;
|
||||||
|
|
||||||
|
for (; i + 4 <= n8; i += 4) {
|
||||||
|
vec_load4_pair(vy0, &v_y[(i * 2) + 0]);
|
||||||
|
|
||||||
|
vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x00);
|
||||||
|
vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x01);
|
||||||
|
vec_load_mult28b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], v_x02);
|
||||||
|
vec_load_mult28b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], v_x03);
|
||||||
|
|
||||||
|
vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_store4_pair(&v_y[(i * 2) + 0], vy0);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (; i < n8; i++) {
|
||||||
|
vec_load_pair(vy0, &v_y[(i * 2) + 0]);
|
||||||
|
|
||||||
|
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00);
|
||||||
|
vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01);
|
||||||
|
vec_load_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x02);
|
||||||
|
vec_load_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x03);
|
||||||
|
|
||||||
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_store_pair(&v_y[(i * 2) + 0], vy0);
|
||||||
|
}
|
||||||
|
|
||||||
|
n &= 7;
|
||||||
|
if (n > 4) {
|
||||||
|
BLASLONG n3 = n & 3;
|
||||||
|
vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
|
|
||||||
|
vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00, n);
|
||||||
|
vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01, n);
|
||||||
|
vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x02, n);
|
||||||
|
vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x03, n);
|
||||||
|
|
||||||
|
vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
|
} else if (n) {
|
||||||
|
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||||
|
|
||||||
|
vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x00, n);
|
||||||
|
vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x01, n);
|
||||||
|
vec_loadN_mult11b_mma(&temp[0], &vb0[i], &vb1[i], v_x02, n);
|
||||||
|
vec_loadN_mult11b_mma(&temp[0], &vb2[i], &vb3[i], v_x03, n);
|
||||||
|
|
||||||
|
vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0);
|
||||||
|
|
||||||
|
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sbgemv_n.c"
|
||||||
|
#else
|
||||||
#include "sbgemv_n_vsx.c"
|
#include "sbgemv_n_vsx.c"
|
||||||
|
#endif
|
||||||
//#include "sbgemv_n.c"
|
#endif
|
||||||
|
|
||||||
|
|
|
@ -25,12 +25,20 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
*****************************************************************************/
|
*****************************************************************************/
|
||||||
|
|
||||||
#ifndef SBGEMV_N_VSX
|
#ifndef SBGEMV_N_VSX_C
|
||||||
#define SBGEMV_N_VSX
|
#define SBGEMV_N_VSX_C
|
||||||
|
|
||||||
#include "sbgemv_common.c"
|
#include "sbgemv_common.c"
|
||||||
|
|
||||||
#define NBMAX 4096
|
#ifndef BF16GEMV_N_X
|
||||||
|
#define BF16GEMV_N_X
|
||||||
|
#define BF16GEMV_N_8 BF16GEMV_N_VSX_8
|
||||||
|
#define BF16GEMV_N_4 BF16GEMV_N_VSX_4
|
||||||
|
#define BF16GEMV_N_2 BF16GEMV_N_VSX_2
|
||||||
|
#define BF16GEMV_N_1 BF16GEMV_N_VSX_1
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define USE_BFGEMV_8_N_VSX
|
||||||
|
|
||||||
static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha)
|
static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha)
|
||||||
{
|
{
|
||||||
|
@ -70,11 +78,11 @@ static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
|
||||||
|
|
||||||
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
} else if (n) {
|
} else if (n) {
|
||||||
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||||
|
|
||||||
vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero);
|
||||||
|
|
||||||
vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n);
|
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,12 +129,12 @@ static void BF16GEMV_N_VSX_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
|
||||||
|
|
||||||
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
} else if (n) {
|
} else if (n) {
|
||||||
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||||
|
|
||||||
vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero);
|
||||||
vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x1, &va1[i], n, zero);
|
||||||
|
|
||||||
vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n);
|
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,17 +191,18 @@ static void BF16GEMV_N_VSX_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
|
||||||
|
|
||||||
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
} else if (n) {
|
} else if (n) {
|
||||||
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||||
|
|
||||||
vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero);
|
||||||
vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x1, &va1[i], n, zero);
|
||||||
vy0 += vec_loadNHi_multi2(v_x2, &va2[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x2, &va2[i], n, zero);
|
||||||
vy0 += vec_loadNHi_multi2(v_x3, &va3[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x3, &va3[i], n, zero);
|
||||||
|
|
||||||
vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n);
|
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef USE_BFGEMV_8_N_VSX
|
||||||
static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLASLONG lda4, FLOAT alpha)
|
static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLASLONG lda4, FLOAT alpha)
|
||||||
{
|
{
|
||||||
IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3;
|
IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3;
|
||||||
|
@ -270,25 +279,21 @@ static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
|
||||||
|
|
||||||
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||||
} else if (n) {
|
} else if (n) {
|
||||||
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||||
|
|
||||||
vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero);
|
||||||
vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x1, &va1[i], n, zero);
|
||||||
vy0 += vec_loadNHi_multi2(v_x2, &va2[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x2, &va2[i], n, zero);
|
||||||
vy0 += vec_loadNHi_multi2(v_x3, &va3[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x3, &va3[i], n, zero);
|
||||||
vy0 += vec_loadNHi_multi2(v_x4, &vb0[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x4, &vb0[i], n, zero);
|
||||||
vy0 += vec_loadNHi_multi2(v_x5, &vb1[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x5, &vb1[i], n, zero);
|
||||||
vy0 += vec_loadNHi_multi2(v_x6, &vb2[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x6, &vb2[i], n, zero);
|
||||||
vy0 += vec_loadNHi_multi2(v_x7, &vb3[i], n, zero);
|
vy0[0] += vec_loadNHi_mult2(v_x7, &vb3[i], n, zero);
|
||||||
|
|
||||||
vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n);
|
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
#define BF16GEMV_N_8 BF16GEMV_N_VSX_8
|
|
||||||
#define BF16GEMV_N_4 BF16GEMV_N_VSX_4
|
|
||||||
#define BF16GEMV_N_2 BF16GEMV_N_VSX_2
|
|
||||||
#define BF16GEMV_N_1 BF16GEMV_N_VSX_1
|
|
||||||
|
|
||||||
#include "sbgemv_n.c"
|
#include "sbgemv_n.c"
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -27,6 +27,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
#ifndef SBGEMV_T_COMMON_C
|
#ifndef SBGEMV_T_COMMON_C
|
||||||
#define SBGEMV_T_COMMON_C
|
#define SBGEMV_T_COMMON_C
|
||||||
|
|
||||||
|
#if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_T_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_T_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_T_VSX))
|
||||||
|
#define USE_T_8
|
||||||
|
#endif
|
||||||
|
|
||||||
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y)
|
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y)
|
||||||
{
|
{
|
||||||
IFLOAT *xbuffer, *a_ptr;
|
IFLOAT *xbuffer, *a_ptr;
|
||||||
|
@ -39,7 +44,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
|
||||||
xbuffer = buffer;
|
xbuffer = buffer;
|
||||||
|
|
||||||
BLASLONG lda4 = lda << 2;
|
BLASLONG lda4 = lda << 2;
|
||||||
|
#ifdef USE_T_8
|
||||||
BLASLONG lda8 = lda << 3;
|
BLASLONG lda8 = lda << 3;
|
||||||
|
#endif
|
||||||
BLASLONG NB = NBMAX;
|
BLASLONG NB = NBMAX;
|
||||||
BLASLONG m2 = (m & (NBMAX - 1));
|
BLASLONG m2 = (m & (NBMAX - 1));
|
||||||
|
|
||||||
|
@ -60,12 +67,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
|
||||||
}
|
}
|
||||||
|
|
||||||
if (inc_y == 1) {
|
if (inc_y == 1) {
|
||||||
|
#ifdef USE_T_8
|
||||||
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
|
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
|
||||||
BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
|
BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
|
||||||
y_ptr += 8;
|
y_ptr += 8;
|
||||||
a_ptr += lda8;
|
a_ptr += lda8;
|
||||||
}
|
}
|
||||||
if (n & 4) {
|
if (n & 4) {
|
||||||
|
#else
|
||||||
|
for (BLASLONG j = 0; j + 4 <= n; j += 4) {
|
||||||
|
#endif
|
||||||
BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
|
BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
|
||||||
y_ptr += 4;
|
y_ptr += 4;
|
||||||
a_ptr += lda4;
|
a_ptr += lda4;
|
||||||
|
@ -79,6 +90,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
|
||||||
BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
|
BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
#ifdef USE_T_8
|
||||||
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
|
for (BLASLONG j = 0; j + 8 <= n; j += 8) {
|
||||||
memset(ybuffer, 0, sizeof(FLOAT) * 8);
|
memset(ybuffer, 0, sizeof(FLOAT) * 8);
|
||||||
BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
|
BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
|
||||||
|
@ -87,6 +99,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
|
||||||
a_ptr += lda8;
|
a_ptr += lda8;
|
||||||
}
|
}
|
||||||
if (n & 4) {
|
if (n & 4) {
|
||||||
|
#else
|
||||||
|
for (BLASLONG j = 0; j + 4 <= n; j += 4) {
|
||||||
|
#endif
|
||||||
memset(ybuffer, 0, sizeof(FLOAT) * 4);
|
memset(ybuffer, 0, sizeof(FLOAT) * 4);
|
||||||
BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
|
BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
|
||||||
copy_y(4, ybuffer, y_ptr, inc_y, beta);
|
copy_y(4, ybuffer, y_ptr, inc_y, beta);
|
||||||
|
|
|
@ -25,8 +25,334 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
*****************************************************************************/
|
*****************************************************************************/
|
||||||
|
|
||||||
//#include "sbgemv_common.c"
|
#ifndef SBGEMV_T_MMA_C
|
||||||
|
#define SBGEMV_T_MMA_C
|
||||||
|
|
||||||
|
#define USE_BFGEMV_T_MMA
|
||||||
|
|
||||||
|
#ifdef USE_BFGEMV_T_MMA
|
||||||
|
#include "sbgemv_common_power10.c"
|
||||||
|
|
||||||
|
#ifndef BF16GEMV_T_X
|
||||||
|
#define BF16GEMV_T_X
|
||||||
|
#define BF16GEMV_T_8 BF16GEMV_T_MMA_8
|
||||||
|
#define BF16GEMV_T_4 BF16GEMV_T_MMA_4
|
||||||
|
#define BF16GEMV_T_2 BF16GEMV_T_MMA_2
|
||||||
|
#define BF16GEMV_T_1 BF16GEMV_T_MMA_1
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define USE_BFGEMV_8_T_MMA
|
||||||
|
|
||||||
|
static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
|
||||||
|
{
|
||||||
|
IFLOAT *a0;
|
||||||
|
vec_bf16 *va0, *v_x;
|
||||||
|
__vector_quad temp0;
|
||||||
|
vec_f32 temp00[4];
|
||||||
|
vec_bf16 inp[2];
|
||||||
|
|
||||||
|
__builtin_mma_xxsetaccz(&temp0);
|
||||||
|
|
||||||
|
a0 = ap;
|
||||||
|
va0 = (vec_bf16 *)a0;
|
||||||
|
v_x = (vec_bf16 *)x;
|
||||||
|
BLASLONG n8 = n / 8;
|
||||||
|
BLASLONG i = 0;
|
||||||
|
|
||||||
|
for (; i + 2 <= n8; i += 2) {
|
||||||
|
vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]);
|
||||||
|
|
||||||
|
vec_load_mult2_mma(&temp0, &va0[i + 0], inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n8 & 1) {
|
||||||
|
inp[0] = (vec_bf16)vec_load_vec(&v_x[i]);
|
||||||
|
|
||||||
|
vec_load_mult_mma(&temp0, &va0[i], inp[0]);
|
||||||
|
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
n &= 7;
|
||||||
|
if (n) {
|
||||||
|
inp[0] = vec_loadN(&v_x[i], n);
|
||||||
|
|
||||||
|
vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n);
|
||||||
|
}
|
||||||
|
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp00, &temp0);
|
||||||
|
|
||||||
|
y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
|
||||||
|
{
|
||||||
|
IFLOAT *a0, *a1;
|
||||||
|
vec_bf16 *va0, *va1, *v_x;
|
||||||
|
__vector_quad temp0, temp1;
|
||||||
|
vec_f32 temp00[4], temp01[4];
|
||||||
|
vec_bf16 inp[2];
|
||||||
|
|
||||||
|
__builtin_mma_xxsetaccz(&temp0);
|
||||||
|
__builtin_mma_xxsetaccz(&temp1);
|
||||||
|
|
||||||
|
a0 = ap;
|
||||||
|
a1 = ap + lda;
|
||||||
|
va0 = (vec_bf16 *)a0;
|
||||||
|
va1 = (vec_bf16 *)a1;
|
||||||
|
v_x = (vec_bf16 *)x;
|
||||||
|
BLASLONG n8 = n / 8;
|
||||||
|
BLASLONG i = 0;
|
||||||
|
|
||||||
|
for (; i + 2 <= n8; i += 2) {
|
||||||
|
vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]);
|
||||||
|
|
||||||
|
vec_load_mult2_mma(&temp0, &va0[i + 0], inp);
|
||||||
|
vec_load_mult2_mma(&temp1, &va1[i + 0], inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n8 & 1) {
|
||||||
|
inp[0] = (vec_bf16)vec_load_vec(&v_x[i]);
|
||||||
|
|
||||||
|
vec_load_mult_mma(&temp0, &va0[i], inp[0]);
|
||||||
|
vec_load_mult_mma(&temp1, &va1[i], inp[0]);
|
||||||
|
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
n &= 7;
|
||||||
|
if (n) {
|
||||||
|
inp[0] = vec_loadN(&v_x[i], n);
|
||||||
|
|
||||||
|
vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n);
|
||||||
|
vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n);
|
||||||
|
}
|
||||||
|
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp00, &temp0);
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp01, &temp1);
|
||||||
|
|
||||||
|
y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]);
|
||||||
|
y[1] = (alpha * (temp01[0][0] + temp01[1][1] + temp01[2][2] + temp01[3][3])) + (beta * y[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
|
||||||
|
{
|
||||||
|
IFLOAT *a0, *a1, *a2, *a3;
|
||||||
|
vec_bf16 *va0, *va1, *va2, *va3, *v_x;
|
||||||
|
__vector_quad temp0, temp1, temp2, temp3;
|
||||||
|
vec_f32 temp00[4], temp01[4], temp02[4], temp03[4];
|
||||||
|
vec_bf16 inp[2];
|
||||||
|
|
||||||
|
__builtin_mma_xxsetaccz(&temp0);
|
||||||
|
__builtin_mma_xxsetaccz(&temp1);
|
||||||
|
__builtin_mma_xxsetaccz(&temp2);
|
||||||
|
__builtin_mma_xxsetaccz(&temp3);
|
||||||
|
|
||||||
|
a0 = ap;
|
||||||
|
a1 = ap + lda;
|
||||||
|
a2 = a1 + lda;
|
||||||
|
a3 = a2 + lda;
|
||||||
|
va0 = (vec_bf16 *)a0;
|
||||||
|
va1 = (vec_bf16 *)a1;
|
||||||
|
va2 = (vec_bf16 *)a2;
|
||||||
|
va3 = (vec_bf16 *)a3;
|
||||||
|
v_x = (vec_bf16 *)x;
|
||||||
|
BLASLONG n8 = n / 8;
|
||||||
|
BLASLONG i = 0;
|
||||||
|
|
||||||
|
for (; i + 2 <= n8; i += 2) {
|
||||||
|
vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]);
|
||||||
|
|
||||||
|
vec_load_mult2_mma(&temp0, &va0[i + 0], inp);
|
||||||
|
vec_load_mult2_mma(&temp1, &va1[i + 0], inp);
|
||||||
|
vec_load_mult2_mma(&temp2, &va2[i + 0], inp);
|
||||||
|
vec_load_mult2_mma(&temp3, &va3[i + 0], inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n8 & 1) {
|
||||||
|
inp[0] = (vec_bf16)vec_load_vec(&v_x[i]);
|
||||||
|
|
||||||
|
vec_load_mult_mma(&temp0, &va0[i], inp[0]);
|
||||||
|
vec_load_mult_mma(&temp1, &va1[i], inp[0]);
|
||||||
|
vec_load_mult_mma(&temp2, &va2[i], inp[0]);
|
||||||
|
vec_load_mult_mma(&temp3, &va3[i], inp[0]);
|
||||||
|
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
n &= 7;
|
||||||
|
if (n) {
|
||||||
|
inp[0] = vec_loadN(&v_x[i], n);
|
||||||
|
|
||||||
|
vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n);
|
||||||
|
vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n);
|
||||||
|
vec_loadN_mult_mma(&temp2, &va2[i], inp[0], n);
|
||||||
|
vec_loadN_mult_mma(&temp3, &va3[i], inp[0], n);
|
||||||
|
}
|
||||||
|
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp00, &temp0);
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp01, &temp1);
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp02, &temp2);
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp03, &temp3);
|
||||||
|
|
||||||
|
vec_f32 t0, t1, t2, t3, t4, t5, t6, t7;
|
||||||
|
vec_f32 a = { alpha, alpha, alpha, alpha };
|
||||||
|
vec_f32 b = { beta, beta, beta, beta };
|
||||||
|
vec_f32 *v_y = (vec_f32 *) y;
|
||||||
|
|
||||||
|
t0 = vec_mergeh(temp00[0], temp01[0]);
|
||||||
|
t1 = vec_mergeh(temp02[0], temp03[0]);
|
||||||
|
t2 = vec_mergeo(temp00[1], temp01[1]);
|
||||||
|
t3 = vec_mergeo(temp02[1], temp03[1]);
|
||||||
|
t4 = vec_mergel(temp00[2], temp01[2]);
|
||||||
|
t5 = vec_mergel(temp02[2], temp03[2]);
|
||||||
|
t6 = vec_mergeo(temp00[3], temp01[3]);
|
||||||
|
t7 = vec_mergeo(temp02[3], temp03[3]);
|
||||||
|
t0 = vec_xxpermdi(t0, t1, 0);
|
||||||
|
t2 = vec_xxpermdi(t2, t3, 0);
|
||||||
|
t4 = vec_xxpermdi(t4, t5, 0);
|
||||||
|
t6 = vec_xxpermdi(t6, t7, 3);
|
||||||
|
|
||||||
|
t0 += t2 + t4 + t6;
|
||||||
|
|
||||||
|
v_y[0] = (a * t0) + (b * v_y[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef USE_BFGEMV_8_T_MMA
|
||||||
|
static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
|
||||||
|
{
|
||||||
|
IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
|
||||||
|
vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x;
|
||||||
|
__vector_quad temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
|
||||||
|
vec_f32 temp00[4], temp01[4], temp02[4], temp03[4], temp04[4], temp05[4], temp06[4], temp07[4];
|
||||||
|
vec_bf16 inp[2];
|
||||||
|
|
||||||
|
__builtin_mma_xxsetaccz(&temp0);
|
||||||
|
__builtin_mma_xxsetaccz(&temp1);
|
||||||
|
__builtin_mma_xxsetaccz(&temp2);
|
||||||
|
__builtin_mma_xxsetaccz(&temp3);
|
||||||
|
__builtin_mma_xxsetaccz(&temp4);
|
||||||
|
__builtin_mma_xxsetaccz(&temp5);
|
||||||
|
__builtin_mma_xxsetaccz(&temp6);
|
||||||
|
__builtin_mma_xxsetaccz(&temp7);
|
||||||
|
|
||||||
|
a0 = ap;
|
||||||
|
a1 = ap + lda;
|
||||||
|
a2 = a1 + lda;
|
||||||
|
a3 = a2 + lda;
|
||||||
|
a4 = a3 + lda;
|
||||||
|
a5 = a4 + lda;
|
||||||
|
a6 = a5 + lda;
|
||||||
|
a7 = a6 + lda;
|
||||||
|
va0 = (vec_bf16 *)a0;
|
||||||
|
va1 = (vec_bf16 *)a1;
|
||||||
|
va2 = (vec_bf16 *)a2;
|
||||||
|
va3 = (vec_bf16 *)a3;
|
||||||
|
va4 = (vec_bf16 *)a4;
|
||||||
|
va5 = (vec_bf16 *)a5;
|
||||||
|
va6 = (vec_bf16 *)a6;
|
||||||
|
va7 = (vec_bf16 *)a7;
|
||||||
|
v_x = (vec_bf16 *)x;
|
||||||
|
BLASLONG n8 = n / 8;
|
||||||
|
BLASLONG i = 0;
|
||||||
|
|
||||||
|
for (; i + 2 <= n8; i += 2) {
|
||||||
|
vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]);
|
||||||
|
|
||||||
|
vec_load_mult2_mma(&temp0, &va0[i + 0], inp);
|
||||||
|
vec_load_mult2_mma(&temp1, &va1[i + 0], inp);
|
||||||
|
vec_load_mult2_mma(&temp2, &va2[i + 0], inp);
|
||||||
|
vec_load_mult2_mma(&temp3, &va3[i + 0], inp);
|
||||||
|
vec_load_mult2_mma(&temp4, &va4[i + 0], inp);
|
||||||
|
vec_load_mult2_mma(&temp5, &va5[i + 0], inp);
|
||||||
|
vec_load_mult2_mma(&temp6, &va6[i + 0], inp);
|
||||||
|
vec_load_mult2_mma(&temp7, &va7[i + 0], inp);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n8 & 1) {
|
||||||
|
inp[0] = (vec_bf16)vec_load_vec(&v_x[i]);
|
||||||
|
|
||||||
|
vec_load_mult_mma(&temp0, &va0[i], inp[0]);
|
||||||
|
vec_load_mult_mma(&temp1, &va1[i], inp[0]);
|
||||||
|
vec_load_mult_mma(&temp2, &va2[i], inp[0]);
|
||||||
|
vec_load_mult_mma(&temp3, &va3[i], inp[0]);
|
||||||
|
vec_load_mult_mma(&temp4, &va4[i], inp[0]);
|
||||||
|
vec_load_mult_mma(&temp5, &va5[i], inp[0]);
|
||||||
|
vec_load_mult_mma(&temp6, &va6[i], inp[0]);
|
||||||
|
vec_load_mult_mma(&temp7, &va7[i], inp[0]);
|
||||||
|
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
n &= 7;
|
||||||
|
if (n) {
|
||||||
|
inp[0] = vec_loadN(&v_x[i], n);
|
||||||
|
|
||||||
|
vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n);
|
||||||
|
vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n);
|
||||||
|
vec_loadN_mult_mma(&temp2, &va2[i], inp[0], n);
|
||||||
|
vec_loadN_mult_mma(&temp3, &va3[i], inp[0], n);
|
||||||
|
vec_loadN_mult_mma(&temp4, &va4[i], inp[0], n);
|
||||||
|
vec_loadN_mult_mma(&temp5, &va5[i], inp[0], n);
|
||||||
|
vec_loadN_mult_mma(&temp6, &va6[i], inp[0], n);
|
||||||
|
vec_loadN_mult_mma(&temp7, &va7[i], inp[0], n);
|
||||||
|
}
|
||||||
|
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp00, &temp0);
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp01, &temp1);
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp02, &temp2);
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp03, &temp3);
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp04, &temp4);
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp05, &temp5);
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp06, &temp6);
|
||||||
|
__builtin_mma_disassemble_acc((void*)temp07, &temp7);
|
||||||
|
|
||||||
|
vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17;
|
||||||
|
vec_f32 a = { alpha, alpha, alpha, alpha };
|
||||||
|
vec_f32 b = { beta, beta, beta, beta };
|
||||||
|
vec_f32 *v_y = (vec_f32 *) y;
|
||||||
|
|
||||||
|
t0 = vec_mergeh(temp00[0], temp01[0]);
|
||||||
|
t1 = vec_mergeh(temp02[0], temp03[0]);
|
||||||
|
t2 = vec_mergeo(temp00[1], temp01[1]);
|
||||||
|
t3 = vec_mergeo(temp02[1], temp03[1]);
|
||||||
|
t4 = vec_mergel(temp00[2], temp01[2]);
|
||||||
|
t5 = vec_mergel(temp02[2], temp03[2]);
|
||||||
|
t6 = vec_mergeo(temp00[3], temp01[3]);
|
||||||
|
t7 = vec_mergeo(temp02[3], temp03[3]);
|
||||||
|
t0 = vec_xxpermdi(t0, t1, 0);
|
||||||
|
t2 = vec_xxpermdi(t2, t3, 0);
|
||||||
|
t4 = vec_xxpermdi(t4, t5, 0);
|
||||||
|
t6 = vec_xxpermdi(t6, t7, 3);
|
||||||
|
|
||||||
|
t0 += t2 + t4 + t6;
|
||||||
|
|
||||||
|
t10 = vec_mergeh(temp04[0], temp05[0]);
|
||||||
|
t11 = vec_mergeh(temp06[0], temp07[0]);
|
||||||
|
t12 = vec_mergeo(temp04[1], temp05[1]);
|
||||||
|
t13 = vec_mergeo(temp06[1], temp07[1]);
|
||||||
|
t14 = vec_mergel(temp04[2], temp05[2]);
|
||||||
|
t15 = vec_mergel(temp06[2], temp07[2]);
|
||||||
|
t16 = vec_mergeo(temp04[3], temp05[3]);
|
||||||
|
t17 = vec_mergeo(temp06[3], temp07[3]);
|
||||||
|
t10 = vec_xxpermdi(t10, t11, 0);
|
||||||
|
t12 = vec_xxpermdi(t12, t13, 0);
|
||||||
|
t14 = vec_xxpermdi(t14, t15, 0);
|
||||||
|
t16 = vec_xxpermdi(t16, t17, 3);
|
||||||
|
|
||||||
|
t10 += t12 + t14 + t16;
|
||||||
|
|
||||||
|
vec_f32 inp2[2];
|
||||||
|
vec_load_pair(inp2, v_y);
|
||||||
|
inp2[0] = (a * t0) + (b * inp2[0]);
|
||||||
|
inp2[1] = (a * t10) + (b * inp2[1]);
|
||||||
|
vec_store_pair(v_y, inp2);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sbgemv_t.c"
|
||||||
|
#else
|
||||||
#include "sbgemv_t_vsx.c"
|
#include "sbgemv_t_vsx.c"
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
//#include "sbgemv_t.c"
|
|
||||||
|
|
|
@ -25,12 +25,20 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
*****************************************************************************/
|
*****************************************************************************/
|
||||||
|
|
||||||
#ifndef SBGEMV_T_VSX
|
#ifndef SBGEMV_T_VSX_C
|
||||||
#define SBGEMV_T_VSX
|
#define SBGEMV_T_VSX_C
|
||||||
|
|
||||||
#include "sbgemv_common.c"
|
#include "sbgemv_common.c"
|
||||||
|
|
||||||
#define NBMAX 4096
|
#ifndef BF16GEMV_T_X
|
||||||
|
#define BF16GEMV_T_X
|
||||||
|
#define BF16GEMV_T_8 BF16GEMV_T_VSX_8
|
||||||
|
#define BF16GEMV_T_4 BF16GEMV_T_VSX_4
|
||||||
|
#define BF16GEMV_T_2 BF16GEMV_T_VSX_2
|
||||||
|
#define BF16GEMV_T_1 BF16GEMV_T_VSX_1
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define USE_BFGEMV_8_T_VSX
|
||||||
|
|
||||||
static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
|
static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
|
||||||
{
|
{
|
||||||
|
@ -58,9 +66,9 @@ static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
|
||||||
|
|
||||||
temp0 += vec_loadN_mult(&va0[i], inp, n, zero);
|
temp0 += vec_loadN_mult(&va0[i], inp, n, zero);
|
||||||
} else if (n) {
|
} else if (n) {
|
||||||
vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero);
|
inp[0] = vec_loadNHi_vec(v_x, i, n, zero);
|
||||||
|
|
||||||
temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero);
|
temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero);
|
||||||
}
|
}
|
||||||
|
|
||||||
y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]);
|
y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]);
|
||||||
|
@ -97,10 +105,10 @@ static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
|
||||||
temp0 += vec_loadN_mult(&va0[i], inp, n, zero);
|
temp0 += vec_loadN_mult(&va0[i], inp, n, zero);
|
||||||
temp1 += vec_loadN_mult(&va1[i], inp, n, zero);
|
temp1 += vec_loadN_mult(&va1[i], inp, n, zero);
|
||||||
} else if (n) {
|
} else if (n) {
|
||||||
vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero);
|
inp[0] = vec_loadNHi_vec(v_x, i, n, zero);
|
||||||
|
|
||||||
temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero);
|
temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero);
|
||||||
temp1 += vec_loadNHi_mult(&va1[i], v_inp0, n, zero);
|
temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero);
|
||||||
}
|
}
|
||||||
|
|
||||||
y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]);
|
y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]);
|
||||||
|
@ -148,12 +156,12 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
|
||||||
temp2 += vec_loadN_mult(&va2[i], inp, n, zero);
|
temp2 += vec_loadN_mult(&va2[i], inp, n, zero);
|
||||||
temp3 += vec_loadN_mult(&va3[i], inp, n, zero);
|
temp3 += vec_loadN_mult(&va3[i], inp, n, zero);
|
||||||
} else if (n) {
|
} else if (n) {
|
||||||
vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero);
|
inp[0] = vec_loadNHi_vec(v_x, i, n, zero);
|
||||||
|
|
||||||
temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero);
|
temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero);
|
||||||
temp1 += vec_loadNHi_mult(&va1[i], v_inp0, n, zero);
|
temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero);
|
||||||
temp2 += vec_loadNHi_mult(&va2[i], v_inp0, n, zero);
|
temp2 += vec_loadNHi_mult(&va2[i], inp[0], n, zero);
|
||||||
temp3 += vec_loadNHi_mult(&va3[i], v_inp0, n, zero);
|
temp3 += vec_loadNHi_mult(&va3[i], inp[0], n, zero);
|
||||||
}
|
}
|
||||||
|
|
||||||
vec_f32 t0, t1, t2, t3;
|
vec_f32 t0, t1, t2, t3;
|
||||||
|
@ -174,6 +182,7 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
|
||||||
v_y[0] = (a * temp0) + (b * v_y[0]);
|
v_y[0] = (a * temp0) + (b * v_y[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef USE_BFGEMV_8_T_VSX
|
||||||
static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
|
static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
|
||||||
{
|
{
|
||||||
IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
|
IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
|
||||||
|
@ -235,16 +244,16 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
|
||||||
temp6 += vec_loadN_mult(&va6[i], inp, n, zero);
|
temp6 += vec_loadN_mult(&va6[i], inp, n, zero);
|
||||||
temp7 += vec_loadN_mult(&va7[i], inp, n, zero);
|
temp7 += vec_loadN_mult(&va7[i], inp, n, zero);
|
||||||
} else if (n) {
|
} else if (n) {
|
||||||
vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero);
|
inp[0] = vec_loadNHi_vec(v_x, i, n, zero);
|
||||||
|
|
||||||
temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero);
|
temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero);
|
||||||
temp1 += vec_loadNHi_mult(&va1[i], v_inp0, n, zero);
|
temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero);
|
||||||
temp2 += vec_loadNHi_mult(&va2[i], v_inp0, n, zero);
|
temp2 += vec_loadNHi_mult(&va2[i], inp[0], n, zero);
|
||||||
temp3 += vec_loadNHi_mult(&va3[i], v_inp0, n, zero);
|
temp3 += vec_loadNHi_mult(&va3[i], inp[0], n, zero);
|
||||||
temp4 += vec_loadNHi_mult(&va4[i], v_inp0, n, zero);
|
temp4 += vec_loadNHi_mult(&va4[i], inp[0], n, zero);
|
||||||
temp5 += vec_loadNHi_mult(&va5[i], v_inp0, n, zero);
|
temp5 += vec_loadNHi_mult(&va5[i], inp[0], n, zero);
|
||||||
temp6 += vec_loadNHi_mult(&va6[i], v_inp0, n, zero);
|
temp6 += vec_loadNHi_mult(&va6[i], inp[0], n, zero);
|
||||||
temp7 += vec_loadNHi_mult(&va7[i], v_inp0, n, zero);
|
temp7 += vec_loadNHi_mult(&va7[i], inp[0], n, zero);
|
||||||
}
|
}
|
||||||
|
|
||||||
vec_f32 t0, t1, t2, t3;
|
vec_f32 t0, t1, t2, t3;
|
||||||
|
@ -272,14 +281,12 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL
|
||||||
temp7 = vec_mergel(t1, t3);
|
temp7 = vec_mergel(t1, t3);
|
||||||
temp4 += temp5 + temp6 + temp7;
|
temp4 += temp5 + temp6 + temp7;
|
||||||
|
|
||||||
v_y[0] = (a * temp0) + (b * v_y[0]);
|
vec_load_pair(inp, v_y);
|
||||||
v_y[1] = (a * temp4) + (b * v_y[1]);
|
inp[0] = (a * temp0) + (b * inp[0]);
|
||||||
|
inp[1] = (a * temp4) + (b * inp[1]);
|
||||||
|
vec_store_pair(v_y, inp);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
#define BF16GEMV_T_8 BF16GEMV_T_VSX_8
|
|
||||||
#define BF16GEMV_T_4 BF16GEMV_T_VSX_4
|
|
||||||
#define BF16GEMV_T_2 BF16GEMV_T_VSX_2
|
|
||||||
#define BF16GEMV_T_1 BF16GEMV_T_VSX_1
|
|
||||||
|
|
||||||
#include "sbgemv_t.c"
|
#include "sbgemv_t.c"
|
||||||
#endif
|
#endif
|
||||||
|
|
Loading…
Reference in New Issue