Remove beta from optimized functions.

parent 7cc00f68c9
commit 7ec3c16d82
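In brief: the optimized BF16 GEMV kernels used to apply beta themselves (computing y = alpha * op(A) * x + beta * y on every call); after this commit the driver scales y by beta exactly once, and the kernels only accumulate alpha * op(A) * x. The first hunk below is separate housekeeping, evidently enabling the same GEMM-to-GEMV forwarding for BF16 on POWER. A minimal scalar model of the new split, for the transposed case the remaining hunks modify (names are mine, not from this commit):

/* Model of the reworked sbgemv flow: computes y = alpha * A^T * x + beta * y
   for column-major A (m x n), contiguous x and y. */
void sbgemv_t_model(long m, long n, float alpha, const float *a, long lda,
                    const float *x, float beta, float *y)
{
    /* Step 1 (driver): beta touches y exactly once. */
    for (long j = 0; j < n; j++)
        y[j] = (beta == 0) ? 0.0f : beta * y[j];

    /* Step 2 (kernels): pure accumulation; this is why the beta
       parameter disappears from the kernel signatures below. */
    for (long j = 0; j < n; j++) {
        float sum = 0.0f;
        for (long i = 0; i < m; i++)
            sum += a[i + j * lda] * x[i];
        y[j] += alpha * sum;
    }
}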
@@ -282,6 +282,7 @@ GEMM_GEMV_FORWARD = 1
 endif
 ifeq ($(ARCH), power)
 GEMM_GEMV_FORWARD = 1
+GEMM_GEMV_FORWARD_BF16 = 1
 endif
 
 ifeq ($(SMALL_MATRIX_OPT), 1)
@@ -122,7 +122,10 @@ FORCEINLINE void copy_x(BLASLONG n, IFLOAT *src, IFLOAT *dest, BLASLONG inc_src)
 FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta)
 {
   if (beta == 0) {
-    memset(dest, 0, sizeof(FLOAT) * n);
+    for (BLASLONG i = 0; i < n; i++) {
+      *dest++ = (FLOAT)0;
+      src += inc_src;
+    }
   } else if (beta == 1) {
     for (BLASLONG i = 0; i < n; i++) {
       *dest++ = *src;
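copy_y_beta gathers a strided source into a contiguous destination while applying beta; the hunk above only rewrites its beta == 0 fast path from a single memset into an explicit store loop that also keeps the src pointer walking, matching the shape of the other branches (the stored zeros are identical for IEEE floats). Its contract, with the general-beta branch assumed since it lies outside this hunk:

/* Assumed contract of copy_y_beta (general branch not shown in the hunk):
   dest[i] = beta * src[i * inc_src] for i = 0 .. n-1, where beta == 0
   writes zeros without reading src and beta == 1 is a plain strided copy. */
void copy_y_beta_ref(long n, const float *src, float *dest,
                     long inc_src, float beta)
{
    for (long i = 0; i < n; i++)
        dest[i] = (beta == 0) ? 0.0f : beta * src[i * inc_src];
}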
@@ -163,4 +166,64 @@ FORCEINLINE void move_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest)
     dest += inc_dest;
   }
 }
+
+static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta)
+{
+  if (beta == 0) {
+    memset(output_vector, 0, sizeof(FLOAT) * n);
+  } else if (beta == 1) {
+    if (output_vector != input_vector) {
+      memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
+    }
+  } else {
+    vec_f32 b = { beta, beta, beta, beta };
+
+    vec_f32 *in = (vec_f32 *)input_vector;
+    vec_f32 *out = (vec_f32 *)output_vector;
+
+    BLASLONG n8 = n / 8;
+    BLASLONG i = 0;
+    vec_f32 v_inp0[2];
+
+    for (; i + 4 <= n8; i += 4) {
+      vec_f32 v_inp1[2], v_inp2[2], v_inp3[2];
+      vec_load_pair(v_inp0, &in[(i * 2) + 0]);
+      vec_load_pair(v_inp1, &in[(i * 2) + 2]);
+      vec_load_pair(v_inp2, &in[(i * 2) + 4]);
+      vec_load_pair(v_inp3, &in[(i * 2) + 6]);
+      v_inp0[0] *= b;
+      v_inp0[1] *= b;
+      v_inp1[0] *= b;
+      v_inp1[1] *= b;
+      v_inp2[0] *= b;
+      v_inp2[1] *= b;
+      v_inp3[0] *= b;
+      v_inp3[1] *= b;
+      vec_store_pair(&out[(i * 2) + 0], v_inp0);
+      vec_store_pair(&out[(i * 2) + 2], v_inp1);
+      vec_store_pair(&out[(i * 2) + 4], v_inp2);
+      vec_store_pair(&out[(i * 2) + 6], v_inp3);
+    }
+
+    for (; i < n8; i++) {
+      vec_load_pair(v_inp0, &in[(i * 2) + 0]);
+      v_inp0[0] *= b;
+      v_inp0[1] *= b;
+      vec_store_pair(&out[(i * 2) + 0], v_inp0);
+    }
+
+    n &= 7;
+    if (n > 4) {
+      BLASLONG n3 = n & 3;
+      vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
+      v_inp0[0] *= b;
+      v_inp0[1] *= b;
+      vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
+    } else if (n) {
+      v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
+      v_inp0[0] *= b;
+      vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n);
+    }
+  }
+}
+
 #endif
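The BF16GEMV_N_beta added above is the vectorized y = beta * y (or out = beta * in) primitive: fast paths for beta == 0 (memset) and beta == 1 (memcpy only when needed), otherwise a VSX loop scaling 32 floats per unrolled iteration (four vec_load_pair/vec_store_pair pairs of 8), a single-pair loop for the rest of n8 = n / 8, and vec_loadN/vec_storeN handling for the n & 7 tail. It lands in the shared common header so the transposed driver can reuse it; the next hunk deletes the original copy from the N-path common file. A scalar reference of what the vector code computes:

#include <string.h>

/* Scalar equivalent of BF16GEMV_N_beta for contiguous vectors. */
static void gemv_n_beta_ref(long n, float *out, const float *in, float beta)
{
    if (beta == 0) {
        memset(out, 0, sizeof(float) * n);        /* out = 0         */
    } else if (beta == 1) {
        if (out != in)
            memcpy(out, in, sizeof(float) * n);   /* out = in        */
    } else {
        for (long i = 0; i < n; i++)
            out[i] = beta * in[i];                /* out = beta * in */
    }
}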
@@ -27,65 +27,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #ifndef SBGEMV_N_COMMON_C
 #define SBGEMV_N_COMMON_C
-static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta)
-{
-  if (beta == 0) {
-    memset(output_vector, 0, sizeof(FLOAT) * n);
-  } else if (beta == 1) {
-    if (output_vector != input_vector) {
-      memcpy(output_vector, input_vector, sizeof(FLOAT) * n);
-    }
-  } else {
-    vec_f32 b = { beta, beta, beta, beta };
-
-    vec_f32 *in = (vec_f32 *)input_vector;
-    vec_f32 *out = (vec_f32 *)output_vector;
-
-    BLASLONG n8 = n / 8;
-    BLASLONG i = 0;
-    vec_f32 v_inp0[2];
-
-    for (; i + 4 <= n8; i += 4) {
-      vec_f32 v_inp1[2], v_inp2[2], v_inp3[2];
-      vec_load_pair(v_inp0, &in[(i * 2) + 0]);
-      vec_load_pair(v_inp1, &in[(i * 2) + 2]);
-      vec_load_pair(v_inp2, &in[(i * 2) + 4]);
-      vec_load_pair(v_inp3, &in[(i * 2) + 6]);
-      v_inp0[0] *= b;
-      v_inp0[1] *= b;
-      v_inp1[0] *= b;
-      v_inp1[1] *= b;
-      v_inp2[0] *= b;
-      v_inp2[1] *= b;
-      v_inp3[0] *= b;
-      v_inp3[1] *= b;
-      vec_store_pair(&out[(i * 2) + 0], v_inp0);
-      vec_store_pair(&out[(i * 2) + 2], v_inp1);
-      vec_store_pair(&out[(i * 2) + 4], v_inp2);
-      vec_store_pair(&out[(i * 2) + 6], v_inp3);
-    }
-
-    for (; i < n8; i++) {
-      vec_load_pair(v_inp0, &in[(i * 2) + 0]);
-      v_inp0[0] *= b;
-      v_inp0[1] *= b;
-      vec_store_pair(&out[(i * 2) + 0], v_inp0);
-    }
-
-    n &= 7;
-    if (n > 4) {
-      BLASLONG n3 = n & 3;
-      vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3);
-      v_inp0[0] *= b;
-      v_inp0[1] *= b;
-      vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3);
-    } else if (n) {
-      v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n);
-      v_inp0[0] *= b;
-      vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n);
-    }
-  }
-}
-
 #if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_N_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_N_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_N_VSX))
 #define USE_N_8
@@ -41,6 +41,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
 
   if ((m < 1) || (n < 1)) return 0;
 
+  if (inc_y == 1) {
+    BF16GEMV_N_beta(n, y, y, beta);
+  }
+
   xbuffer = buffer;
 
   BLASLONG lda4 = lda << 2;
@@ -58,18 +62,21 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
     }
 
     a_ptr = a;
+    a += NB;
     y_ptr = y;
 
     if (inc_x != 1) {
      copy_x(NB, x, xbuffer, inc_x);
+      x += NB * inc_x;
     } else {
       xbuffer = x;
+      x += NB;
     }
 
     if (inc_y == 1) {
 #ifdef USE_T_8
       for (BLASLONG j = 0; j + 8 <= n; j += 8) {
-        BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
+        BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
         y_ptr += 8;
         a_ptr += lda8;
       }
@@ -77,23 +84,23 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
 #else
       for (BLASLONG j = 0; j + 4 <= n; j += 4) {
 #endif
-        BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
+        BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
         y_ptr += 4;
         a_ptr += lda4;
       }
       if (n & 2) {
-        BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
+        BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
         y_ptr += 2;
         a_ptr += (lda * 2);
       }
       if (n & 1) {
-        BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta);
+        BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha);
       }
     } else {
 #ifdef USE_T_8
       for (BLASLONG j = 0; j + 8 <= n; j += 8) {
         memset(ybuffer, 0, sizeof(FLOAT) * 8);
-        BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
+        BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
         copy_y(8, ybuffer, y_ptr, inc_y, beta);
         y_ptr += 8 * inc_y;
         a_ptr += lda8;
@@ -103,28 +110,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
       for (BLASLONG j = 0; j + 4 <= n; j += 4) {
 #endif
         memset(ybuffer, 0, sizeof(FLOAT) * 4);
-        BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
+        BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
         copy_y(4, ybuffer, y_ptr, inc_y, beta);
         y_ptr += 4 * inc_y;
         a_ptr += lda4;
       }
       if (n & 2) {
         memset(ybuffer, 0, sizeof(FLOAT) * 4);
-        BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
+        BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
         copy_y(2, ybuffer, y_ptr, inc_y, beta);
         y_ptr += 2 * inc_y;
         a_ptr += (lda * 2);
       }
       if (n & 1) {
         memset(ybuffer, 0, sizeof(FLOAT) * 4);
-        BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta);
+        BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha);
         copy_y(1, ybuffer, y_ptr, inc_y, beta);
       }
+      beta = (FLOAT)1;
     }
-
-    a += NB;
-    x += NB * inc_x;
-    beta = (FLOAT)1;
   }
 
   return 0;
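Two details of the driver rework above are worth spelling out. For inc_y == 1, beta is consumed once by BF16GEMV_N_beta before the row-block loop, and the pointer advances (a += NB, x += NB * inc_x) move next to the assignments they pair with. For strided y, each block's result is built in a zeroed ybuffer and merged by copy_y(..., beta); after the first row block the driver sets beta = (FLOAT)1, so later blocks merge by plain accumulation. A runnable model of that merge (copy_y semantics assumed from its call sites, name mine):

/* Assumed copy_y semantics: y = partial + beta * y, with stride inc_y;
   beta == 0 avoids reading y. First row block runs with the caller's
   beta, every later block with beta == 1 -- exactly what the new
   beta = (FLOAT)1 at the end of the else branch achieves. */
static void merge_block(long n, const float *partial, float *y,
                        long inc_y, float beta)
{
    for (long j = 0; j < n; j++) {
        y[j * inc_y] = (beta == 0) ? partial[j]
                                   : partial[j] + beta * y[j * inc_y];
    }
}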
@@ -43,7 +43,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #define USE_BFGEMV_8_T_MMA
 
-static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0;
   vec_bf16 *va0, *v_x;
@@ -90,10 +90,10 @@ static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
 
   __builtin_mma_disassemble_acc((void*)temp00, &temp0);
 
-  y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]);
+  y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3]));
 }
 
-static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0, *a1;
   vec_bf16 *va0, *va1, *v_x;
@@ -142,11 +142,11 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
 
   vec_reduce_2(temp00, &temp0[0]);
 
-  y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]);
-  y[1] = (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3])) + (beta * y[1]);
+  y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3]));
+  y[1] += (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3]));
 }
 
-static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0, *a1, *a2, *a3;
   vec_bf16 *va0, *va1, *va2, *va3, *v_x;
@@ -201,7 +201,6 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
 
   vec_f32 t0, t1, t2, t3, t4, t5, t6, t7;
   vec_f32 a = { alpha, alpha, alpha, alpha };
-  vec_f32 b = { beta, beta, beta, beta };
   vec_f32 *v_y = (vec_f32 *) y;
 
   t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
@@ -219,11 +218,11 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
 
   t0 += t2 + t4 + t6;
 
-  v_y[0] = (a * t0) + (b * v_y[0]);
+  v_y[0] += (a * t0);
 }
 
 #ifdef USE_BFGEMV_8_T_MMA
-static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
   vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x;
@@ -291,7 +290,6 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
 
   vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17;
   vec_f32 a = { alpha, alpha, alpha, alpha };
-  vec_f32 b = { beta, beta, beta, beta };
   vec_f32 *v_y = (vec_f32 *) y;
 
   t0 = vec_mergeh(temp00[ 0], temp00[ 4]);
@@ -326,8 +324,8 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
 
   vec_f32 inp2[2];
   vec_load_pair(inp2, v_y);
-  inp2[0] = (a * t0) + (b * inp2[0]);
-  inp2[1] = (a * t10) + (b * inp2[1]);
+  inp2[0] += (a * t0);
+  inp2[1] += (a * t10);
   vec_store_pair(v_y, inp2);
 }
 #endif
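The kernel-side edits above are mechanical applications of the new contract: drop the trailing FLOAT beta parameter, delete the b = { beta, ... } splat, and turn each final blend into an accumulate. For the single-column MMA reduction, the changed line in scalar form (paraphrase; temp00 holds the four rows of the disassembled accumulator, and the dot product is gathered along its diagonal, as the original expression shows):

/* Shape of the changed reduction in BF16GEMV_T_MMA_1. */
static void reduce_acc_ref(float temp00[4][4], float alpha, float *y0)
{
    float s = temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3];
    *y0 += alpha * s;    /* was: *y0 = alpha * s + beta * *y0; */
}

The VSX hunks below apply the identical transformation with plain vector mergeh/mergel reductions in place of MMA accumulators.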
@@ -40,7 +40,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #define USE_BFGEMV_8_T_VSX
 
-static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0;
   vec_bf16 *va0, *v_x;
@@ -71,10 +71,10 @@ static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
     temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero);
   }
 
-  y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]);
+  y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3]));
 }
 
-static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0, *a1;
   vec_bf16 *va0, *va1, *v_x;
@@ -111,11 +111,11 @@ static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
     temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero);
   }
 
-  y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]);
-  y[1] = (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3])) + (beta * y[1]);
+  y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3]));
+  y[1] += (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3]));
 }
 
-static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0, *a1, *a2, *a3;
   vec_bf16 *va0, *va1, *va2, *va3, *v_x;
@@ -166,7 +166,6 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
 
   vec_f32 t0, t1, t2, t3;
   vec_f32 a = { alpha, alpha, alpha, alpha };
-  vec_f32 b = { beta, beta, beta, beta };
   vec_f32 *v_y = (vec_f32 *) y;
 
   t0 = vec_mergeh(temp0, temp2);
@@ -179,11 +178,11 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
   temp3 = vec_mergel(t1, t3);
   temp0 += temp1 + temp2 + temp3;
 
-  v_y[0] = (a * temp0) + (b * v_y[0]);
+  v_y[0] += (a * temp0);
 }
 
 #ifdef USE_BFGEMV_8_T_VSX
-static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
+static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha)
 {
   IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
   vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x;
@@ -259,7 +258,6 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
 
   vec_f32 t0, t1, t2, t3, t10, t11, t12, t13;
   vec_f32 a = { alpha, alpha, alpha, alpha };
-  vec_f32 b = { beta, beta, beta, beta };
   vec_f32 *v_y = (vec_f32 *) y;
 
   t0 = vec_mergeh(temp0, temp2);
@@ -283,8 +281,8 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta)
   temp4 += temp5 + temp6 + temp7;
 
   vec_load_pair(inp, v_y);
-  inp[0] = (a * temp0) + (b * inp[0]);
-  inp[1] = (a * temp4) + (b * inp[1]);
+  inp[0] += (a * temp0);
+  inp[1] += (a * temp4);
   vec_store_pair(v_y, inp);
 }
 #endif