From 76227e2948b1f847abef003de5f6d49ea0dd3171 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Fri, 6 Sep 2024 14:03:31 -0500 Subject: [PATCH] Initial commit for vectorized BF16 GEMV. Added GEMM_GEMV_FORWARD_BF16 to enable forwarding BF16 GEMM calls to GEMV when one matrix dimension is 1. Updated the unit test to cover inc_x != 1 and inc_y != 1 for GEMV. --- Makefile.system | 6 +- cmake/system.cmake | 3 + interface/gemm.c | 2 +- kernel/power/KERNEL.POWER10 | 2 + kernel/power/KERNEL.POWER8 | 2 + kernel/power/KERNEL.POWER9 | 2 + kernel/power/sbgemv_common.c | 285 ++++++++++++++++++++++++++++++ kernel/power/sbgemv_n.c | 189 ++++++++++++++++++++ kernel/power/sbgemv_n_power10.c | 33 ++++ kernel/power/sbgemv_n_vsx.c | 303 ++++++++++++++++++++++++++++++++ kernel/power/sbgemv_t.c | 117 ++++++++++++ kernel/power/sbgemv_t_power10.c | 32 ++++ kernel/power/sbgemv_t_vsx.c | 286 ++++++++++++++++++++++++++++++ test/compare_sgemm_sbgemm.c | 31 ++-- 14 files changed, 1277 insertions(+), 16 deletions(-) create mode 100644 kernel/power/sbgemv_common.c create mode 100644 kernel/power/sbgemv_n.c create mode 100644 kernel/power/sbgemv_n_power10.c create mode 100644 kernel/power/sbgemv_n_vsx.c create mode 100644 kernel/power/sbgemv_t.c create mode 100644 kernel/power/sbgemv_t_power10.c create mode 100644 kernel/power/sbgemv_t_vsx.c diff --git a/Makefile.system b/Makefile.system index b065f9a98..8c030842a 100644 --- a/Makefile.system +++ b/Makefile.system @@ -282,15 +282,19 @@ GEMM_GEMV_FORWARD = 1 endif ifeq ($(ARCH), power) GEMM_GEMV_FORWARD = 1 +GEMM_GEMV_FORWARD_BF16 = 1 endif ifeq ($(SMALL_MATRIX_OPT), 1) CCOMMON_OPT += -DSMALL_MATRIX_OPT endif -ifeq ($(GEMM_GEMV_FORWARD), 1) ifneq ($(ONLY_CBLAS), 1) +ifeq ($(GEMM_GEMV_FORWARD), 1) CCOMMON_OPT += -DGEMM_GEMV_FORWARD endif +ifeq ($(GEMM_GEMV_FORWARD_BF16), 1) +CCOMMON_OPT += -DGEMM_GEMV_FORWARD_BF16 +endif endif # This operation is expensive, so execution should be once.
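Note on the forwarding itself: the interface/gemm.c hunk below only relaxes the existing GEMM-to-GEMV guard so that BFLOAT16 builds may also take the GEMV path when GEMM_GEMV_FORWARD_BF16 is defined. As a quick plausibility check (a standalone sketch, not OpenBLAS code; plain float is used instead of bfloat16, and ref_gemm/ref_gemv are made-up reference helpers), a column-major GEMM whose result has a single column computes exactly what a non-transposed GEMV computes:

```c
#include <assert.h>
#include <math.h>
#include <stdlib.h>

/* Column-major reference GEMM: C(m x n) = alpha*A(m x k)*B(k x n) + beta*C. */
static void ref_gemm(int m, int n, int k, float alpha, const float *a, int lda,
                     const float *b, int ldb, float beta, float *c, int ldc)
{
    for (int j = 0; j < n; j++)
        for (int i = 0; i < m; i++) {
            float dot = 0.0f;
            for (int l = 0; l < k; l++)
                dot += a[l * lda + i] * b[j * ldb + l];
            c[j * ldc + i] = alpha * dot + beta * c[j * ldc + i];
        }
}

/* Reference GEMV: y = alpha*op(A)*x + beta*y, with op chosen by 'trans'. */
static void ref_gemv(char trans, int m, int n, float alpha, const float *a,
                     int lda, const float *x, float beta, float *y)
{
    int rows = (trans == 'N') ? m : n;
    int len  = (trans == 'N') ? n : m;
    for (int i = 0; i < rows; i++) {
        float dot = 0.0f;
        for (int l = 0; l < len; l++)
            dot += (trans == 'N') ? a[l * lda + i] * x[l]   /* A(i,l) */
                                  : a[i * lda + l] * x[l];  /* A(l,i) */
        y[i] = alpha * dot + beta * y[i];
    }
}

int main(void)
{
    enum { M = 5, K = 7 };
    float a[K * M], x[K], c1[M], c2[M];
    for (int i = 0; i < K * M; i++) a[i] = (float)rand() / RAND_MAX;
    for (int i = 0; i < K; i++)     x[i] = (float)rand() / RAND_MAX;
    for (int i = 0; i < M; i++)     c1[i] = c2[i] = 0.0f;

    /* GEMM with a single output column (n == 1) ... */
    ref_gemm(M, 1, K, 1.0f, a, M, x, K, 0.0f, c1, M);
    /* ... is the same computation as a non-transposed GEMV. */
    ref_gemv('N', M, K, 1.0f, a, M, x, 0.0f, c2);

    for (int i = 0; i < M; i++)
        assert(fabsf(c1[i] - c2[i]) < 1e-5f);
    return 0;
}
```

The single-row case (m == 1) maps the same way onto a transposed GEMV over B, which is why the patch provides both the "n" (non-transposed) and "t" (transposed) BF16 GEMV kernels.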
diff --git a/cmake/system.cmake b/cmake/system.cmake index a0b73ddae..fb2d350ab 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -398,6 +398,9 @@ endif () if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS) set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD") endif () +if (GEMM_GEMV_FORWARD_BF16 AND NOT ONLY_CBLAS) + set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD_BF16") +endif () if (SMALL_MATRIX_OPT) set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT") endif () diff --git a/interface/gemm.c b/interface/gemm.c index 64b8b620c..7cd0884fa 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -498,7 +498,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS args.m, args.n, args.k, args.lda, args.ldb, args.ldc); #endif -#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && !defined(BFLOAT16) +#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16)) // Check if we can convert GEMM -> GEMV if (args.k != 0) { if (args.n == 1) { diff --git a/kernel/power/KERNEL.POWER10 b/kernel/power/KERNEL.POWER10 index c84cd91d2..956b401fb 100644 --- a/kernel/power/KERNEL.POWER10 +++ b/kernel/power/KERNEL.POWER10 @@ -228,11 +228,13 @@ ZSWAPKERNEL = zswap.c # SGEMVNKERNEL = sgemv_n.c +SBGEMVNKERNEL = sbgemv_n_power10.c DGEMVNKERNEL = dgemv_n_power10.c CGEMVNKERNEL = cgemv_n.c ZGEMVNKERNEL = zgemv_n_power10.c # SGEMVTKERNEL = sgemv_t.c +SBGEMVTKERNEL = sbgemv_t_power10.c DGEMVTKERNEL = dgemv_t_power10.c CGEMVTKERNEL = cgemv_t.c ZGEMVTKERNEL = zgemv_t_4.c diff --git a/kernel/power/KERNEL.POWER8 b/kernel/power/KERNEL.POWER8 index 700a68e44..001401d53 100644 --- a/kernel/power/KERNEL.POWER8 +++ b/kernel/power/KERNEL.POWER8 @@ -257,11 +257,13 @@ ZSWAPKERNEL = zswap.c # SGEMVNKERNEL = sgemv_n.c +SBGEMVNKERNEL = sbgemv_n_vsx.c DGEMVNKERNEL = dgemv_n.c CGEMVNKERNEL = cgemv_n.c ZGEMVNKERNEL = zgemv_n_4.c # SGEMVTKERNEL = sgemv_t.c +SBGEMVTKERNEL = sbgemv_t_vsx.c DGEMVTKERNEL = dgemv_t.c CGEMVTKERNEL = cgemv_t.c ZGEMVTKERNEL = zgemv_t_4.c diff --git a/kernel/power/KERNEL.POWER9 b/kernel/power/KERNEL.POWER9 index 7d007d1a2..a18c31a2e 100644 --- a/kernel/power/KERNEL.POWER9 +++ b/kernel/power/KERNEL.POWER9 @@ -181,11 +181,13 @@ ZSWAPKERNEL = zswap.c # SGEMVNKERNEL = sgemv_n.c +SBGEMVNKERNEL = sbgemv_n_vsx.c DGEMVNKERNEL = dgemv_n.c CGEMVNKERNEL = cgemv_n.c ZGEMVNKERNEL = zgemv_n_4.c # SGEMVTKERNEL = sgemv_t.c +SBGEMVTKERNEL = sbgemv_t_vsx.c DGEMVTKERNEL = dgemv_t.c CGEMVTKERNEL = cgemv_t.c ZGEMVTKERNEL = zgemv_t_4.c diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c new file mode 100644 index 000000000..2aadcca6f --- /dev/null +++ b/kernel/power/sbgemv_common.c @@ -0,0 +1,285 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. 
Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#ifndef SBGEMV_COMMON_C +#define SBGEMV_COMMON_C +#include "common.h" + +#include + +#define FORCEINLINE inline __attribute__((always_inline)) + +#ifdef __clang__ +#define uint16_t unsigned short +#define uint32_t unsigned int +#define uint64_t unsigned long long +#endif + +#ifdef _ARCH_PWR10 +#ifdef __has_builtin +#if !__has_builtin(__builtin_vsx_assemble_pair) +#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair +#endif +#if !__has_builtin(__builtin_vsx_disassemble_pair) +#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair +#endif +#endif + +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v1, v0) +#else +#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v0, v1) +#endif + +#define USE_VECTOR_PAIRS +#endif + +typedef __vector IFLOAT vec_bf16; +typedef __vector FLOAT vec_f32; +typedef __vector unsigned char vec_uc8; + +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define BF16_HI(data, zero) (vec_f32)vec_mergeh(data, zero) +#define BF16_LO(data, zero) (vec_f32)vec_mergel(data, zero) +#else +#define BF16_HI(data, zero) (vec_f32)vec_mergeh(zero, data) +#define BF16_LO(data, zero) (vec_f32)vec_mergel(zero, data) +#endif + +FORCEINLINE vec_uc8 vec_load_vec(void *src) +{ + return vec_xl(0, (unsigned char *)(src)); +} + +FORCEINLINE void vec_load_pair(vec_f32 *dst, vec_f32 *src) +{ +#ifdef USE_VECTOR_PAIRS + __vector_pair vy0p; + vy0p = *(__vector_pair *)(src); + __builtin_vsx_disassemble_pair((void *)(dst), &vy0p); +#else + dst[0] = src[0]; + dst[1] = src[1]; +#endif +} + +FORCEINLINE void vec_store_pair(vec_f32 *dst, vec_f32 *src) +{ +#ifdef USE_VECTOR_PAIRS + __vector_pair vy0p; + __builtin_vsx_assemble_pair2(&vy0p, (vec_uc8)src[1], (vec_uc8)src[0]); + *(__vector_pair *)(dst) = vy0p; +#else + dst[0] = src[0]; + dst[1] = src[1]; +#endif +} + +FORCEINLINE vec_bf16 vec_loadN(void *src, BLASLONG n) +{ + IFLOAT *src2 = (IFLOAT *)(src); +#ifdef _ARCH_PWR9 + return vec_xl_len(src2, n * sizeof(IFLOAT)); +#else + __attribute__((aligned(16))) IFLOAT data[sizeof(vec_bf16) / sizeof(IFLOAT)]; + memset(data, 0, sizeof(vec_bf16)); + if (n & 4) { + memcpy(data, src2, sizeof(uint64_t)); + } + if (n & 2) { + BLASLONG n4 = n & 4; + memcpy(data + n4, src2 + n4, sizeof(uint32_t)); + } + if (n & 1) { + BLASLONG n6 = n & 6; + data[n6] = src2[n6]; + } + return (vec_bf16)vec_load_vec(data); +#endif +} + +FORCEINLINE vec_f32 vec_loadNHi(void *src, BLASLONG n, 
vec_bf16 zero) +{ + vec_bf16 data = vec_loadN(src, n); + return BF16_HI(data, zero); +} + +FORCEINLINE vec_f32 vec_loadN_f32(void *src, BLASLONG n) +{ +#ifndef _ARCH_PWR9 + if (n & 4) { + return (vec_f32)vec_load_vec(src); + } +#endif + return (vec_f32)vec_loadN(src, n * (sizeof(FLOAT) / sizeof(IFLOAT))); +} + +FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n) +{ + FLOAT *dst2 = (FLOAT *)(dst); +#ifdef _ARCH_PWR9 + vec_xst_len(data, dst2, n * sizeof(FLOAT)); +#else + if (n & 4) { + vec_xst(data, 0, dst2); + return; + } + __attribute__((aligned(16))) FLOAT data2[sizeof(vec_f32) / sizeof(FLOAT)]; + vec_xst(data, 0, data2); + if (n & 2) { + memcpy(dst2, data2, sizeof(uint64_t)); + } + if (n & 1) { + BLASLONG n2 = n & 2; + dst2[n2] = data2[n2]; + } +#endif +} + +FORCEINLINE vec_f32 vec_mult(vec_f32 *inp, vec_bf16 in0, vec_bf16 zero) +{ + vec_f32 v_in00 = BF16_HI(in0, zero); + vec_f32 v_in01 = BF16_LO(in0, zero); + + return (inp[0] * v_in00) + (inp[1] * v_in01); +} + +FORCEINLINE vec_f32 vec_load_mult(vec_bf16 *in, vec_f32 *inp, vec_bf16 zero) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(in); + + return vec_mult(inp, in0, zero); +} + +FORCEINLINE void vec_load_vec2(vec_bf16 *in, BLASLONG i, vec_f32 *v_x0, vec_bf16 zero) +{ + vec_bf16 inp = (vec_bf16)vec_load_vec(&in[i]); + + v_x0[0] = BF16_HI(inp, zero); + v_x0[1] = BF16_LO(inp, zero); +} + +FORCEINLINE void vec_mult2(vec_f32 v_x0, vec_bf16 in0, vec_bf16 zero, vec_f32 *vy0) +{ + vec_f32 v_in00 = BF16_HI(in0, zero); + vec_f32 v_in01 = BF16_LO(in0, zero); + + vy0[0] += (v_x0 * v_in00); + vy0[1] += (v_x0 * v_in01); +} + +FORCEINLINE void vec_load_mult2(vec_f32 v_x0, vec_bf16 *in, vec_bf16 zero, vec_f32 *vy0) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(in); + + vec_mult2(v_x0, in0, zero, vy0); +} + +FORCEINLINE vec_f32 vec_loadN_mult(vec_bf16 *in, vec_f32 *inp, BLASLONG n, vec_bf16 zero) +{ + vec_bf16 in0 = vec_loadN(in, n); + + return vec_mult(inp, in0, zero); +} + +FORCEINLINE void vec_loadN_vec2(vec_bf16 *in, BLASLONG i, vec_f32 *v_x0, BLASLONG n, vec_bf16 zero) +{ + vec_bf16 inp = vec_loadN(&in[i], n); + + v_x0[0] = BF16_HI(inp, zero); + v_x0[1] = BF16_LO(inp, zero); +} + +FORCEINLINE void vec_loadN_mult2(vec_f32 v_x0, vec_bf16 *in, BLASLONG n, vec_bf16 zero, vec_f32 *vy0) +{ + vec_bf16 in0 = vec_loadN(in, n); + + vec_mult2(v_x0, in0, zero, vy0); +} + +FORCEINLINE vec_f32 vec_loadNHi_mult(vec_bf16 *in, vec_f32 v_inp0, BLASLONG n, vec_bf16 zero) +{ + vec_f32 v_in00 = vec_loadNHi(in, n, zero); + + return (v_inp0 * v_in00); +} + +FORCEINLINE vec_f32 vec_loadNHi_multi2(vec_f32 v_x0, vec_bf16 *in, BLASLONG n, vec_bf16 zero) +{ + vec_f32 v_in00 = vec_loadNHi(in, n, zero); + + return (v_x0 * v_in00); +} + +FORCEINLINE vec_f32 vec_loadNHi_vec(vec_bf16 *in, BLASLONG i, BLASLONG n, vec_bf16 zero) +{ + return vec_loadNHi(&in[i], n, zero); +} + +FORCEINLINE void copy_x(BLASLONG n, IFLOAT *src, IFLOAT *dest, BLASLONG inc_src) +{ + for (BLASLONG i = 0; i < n; i++) { + *dest++ = *src; + src += inc_src; + } +} + +FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta) +{ + if (beta == 0) { + memset(dest, 0, sizeof(FLOAT) * n); + } else { + for (BLASLONG i = 0; i < n; i++) { + *dest++ = *src * beta; + src += inc_src; + } + } +} + +FORCEINLINE void copy_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta) +{ + if (beta == 0) { + for (BLASLONG i = 0; i < n; i++) { + *dest = *src++; + dest += inc_src; + } + } else { + for (BLASLONG i = 0; i < n; i++) { + *dest = *src++ + (beta * *dest); + dest 
+= inc_src; + } + } +} + +FORCEINLINE void add_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest) +{ + for (BLASLONG i = 0; i < n; i++) { + *dest += *src++; + dest += inc_dest; + } +} +#endif diff --git a/kernel/power/sbgemv_n.c b/kernel/power/sbgemv_n.c new file mode 100644 index 000000000..854ad93ee --- /dev/null +++ b/kernel/power/sbgemv_n.c @@ -0,0 +1,189 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#ifndef SBGEMV_N_COMMON_C +#define SBGEMV_N_COMMON_C +static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta) +{ + if (beta == 0) { + memset(output_vector, 0, sizeof(FLOAT) * n); + } else { + vec_f32 b = { beta, beta, beta, beta }; + + vec_f32 *in = (vec_f32 *)input_vector; + vec_f32 *out = (vec_f32 *)output_vector; + + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 v_inp0[2]; + + for (; i + 4 <= n8; i += 4) { + vec_f32 v_inp1[2], v_inp2[2], v_inp3[2]; + vec_load_pair(v_inp0, &in[(i * 2) + 0]); + vec_load_pair(v_inp1, &in[(i * 2) + 2]); + vec_load_pair(v_inp2, &in[(i * 2) + 4]); + vec_load_pair(v_inp3, &in[(i * 2) + 6]); + v_inp0[0] *= b; + v_inp0[1] *= b; + v_inp1[0] *= b; + v_inp1[1] *= b; + v_inp2[0] *= b; + v_inp2[1] *= b; + v_inp3[0] *= b; + v_inp3[1] *= b; + vec_store_pair(&out[(i * 2) + 0], v_inp0); + vec_store_pair(&out[(i * 2) + 2], v_inp1); + vec_store_pair(&out[(i * 2) + 4], v_inp2); + vec_store_pair(&out[(i * 2) + 6], v_inp3); + } + + for (; i < n8; i++) { + vec_load_pair(v_inp0, &in[(i * 2) + 0]); + v_inp0[0] *= b; + v_inp0[1] *= b; + vec_store_pair(&out[(i * 2) + 0], v_inp0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + v_inp0[0] = in[(i * 2) + 0]; + v_inp0[1] = vec_loadN_f32(&in[(i * 2) + 1], n3); + v_inp0[0] *= b; + v_inp0[1] *= b; + out[(i * 2) + 0] = v_inp0[0]; + vec_storeN_f32(v_inp0[1], &out[(i * 2) + 1], n3); + } else if (n) { + v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n); + v_inp0[0] *= b; + vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n); + } + } +} + +int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y) +{ + IFLOAT *x_ptr, *ap[4]; + IFLOAT xbuffer[8] __attribute__((aligned(16))); + FLOAT *y_ptr, *ybuffer; + FLOAT buffer[NBMAX] __attribute__((aligned(16))); + + if ((m < 1) || (n < 1)) return 0; + + ybuffer = buffer; + y_ptr = y; + + BLASLONG lda4 = lda << 2; + BLASLONG lda8 = lda << 3; + BLASLONG NB = NBMAX; + BLASLONG m2 = (m & (NBMAX - 1)); + + while (NB == NBMAX) { + m -= NB; + if (m < 0) { + if (m2 == 0) break; + NB = m2; + } + + if (inc_y != 1) { + copy_y_beta(NB, y_ptr, ybuffer, inc_y, beta); + } else { + ybuffer = y_ptr; + BF16GEMV_N_beta(NB, ybuffer, ybuffer, beta); + } + + x_ptr = x; + + ap[0] = a; + ap[1] = a + lda; + ap[2] = ap[1] + lda; + ap[3] = ap[2] + lda; + + if (inc_x == 1) { + for (BLASLONG j = 0; j + 8 <= n; j += 8) { + BF16GEMV_N_8(NB, ap, x_ptr, ybuffer, lda4, alpha); + ap[0] += lda8; + ap[1] += lda8; + ap[2] += lda8; + ap[3] += lda8; + x_ptr += 8; + } + if (n & 4) { + BF16GEMV_N_4(NB, ap, x_ptr, ybuffer, alpha); + ap[0] += lda4; + ap[1] += lda4; + x_ptr += 4; + } + if (n & 2) { + BF16GEMV_N_2(NB, ap, x_ptr, ybuffer, alpha); + ap[0] += (lda * 2); + x_ptr += 2; + } + if (n & 1) { + BF16GEMV_N_1(NB, ap, x_ptr, ybuffer, alpha); + } + } else { + for (BLASLONG j = 0; j + 8 <= n; j += 8) { + copy_x(8, x_ptr, xbuffer, inc_x); + BF16GEMV_N_8(NB, ap, xbuffer, ybuffer, lda4, alpha); + ap[0] += lda8; + ap[1] += lda8; + ap[2] += lda8; + ap[3] += lda8; + x_ptr += 8 * inc_x; + } + if (n & 4) { + copy_x(4, x_ptr, xbuffer, inc_x); + BF16GEMV_N_4(NB, ap, xbuffer, ybuffer, alpha); + ap[0] += lda4; + ap[1] += lda4; + x_ptr += 4 * inc_x; + } + if (n & 2) { + copy_x(2, x_ptr, xbuffer, inc_x); + BF16GEMV_N_2(NB, ap, xbuffer, ybuffer, alpha); + ap[0] += (lda * 2); + x_ptr += 2 * inc_x; + } + if (n & 1) { + copy_x(1, x_ptr, xbuffer, inc_x); + 
BF16GEMV_N_1(NB, ap, xbuffer, ybuffer, alpha); + } + } + + a += NB; + if (inc_y != 1) { + add_y(NB, ybuffer, y_ptr, inc_y); + y_ptr += (NB * inc_y); + } else { + y_ptr += NB; + } + } + + return 0; +} +#endif diff --git a/kernel/power/sbgemv_n_power10.c b/kernel/power/sbgemv_n_power10.c new file mode 100644 index 000000000..fc83b38c3 --- /dev/null +++ b/kernel/power/sbgemv_n_power10.c @@ -0,0 +1,33 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +//#include "sbgemv_common.c" + +#include "sbgemv_n_vsx.c" + +//#include "sbgemv_n.c" + diff --git a/kernel/power/sbgemv_n_vsx.c b/kernel/power/sbgemv_n_vsx.c new file mode 100644 index 000000000..ddbf908b3 --- /dev/null +++ b/kernel/power/sbgemv_n_vsx.c @@ -0,0 +1,303 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#ifndef SBGEMV_N_VSX +#define SBGEMV_N_VSX + +#include "sbgemv_common.c" + +#define NBMAX 4096 + +static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_f32 x_0 = vec_loadNHi(x_bf, 1, zero); + x_0 *= v_alpha; + + vec_f32 v_x0 = vec_splat(x_0, 0); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 vy0[2]; + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult2(v_x0, &va0[i], zero, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vy0[0] = v_y[(i * 2) + 0]; + vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3); + + vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); + + v_y[(i * 2) + 0] = vy0[0]; + vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3); + } else if (n) { + vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero); + + vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n); + } +} + +static void BF16GEMV_N_VSX_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_f32 x_0 = vec_loadNHi(x_bf, 2, zero); + x_0 *= v_alpha; + + vec_f32 v_x0 = vec_splat(x_0, 0); + vec_f32 v_x1 = vec_splat(x_0, 1); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 vy0[2]; + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult2(v_x0, &va0[i], zero, vy0); + vec_load_mult2(v_x1, &va1[i], zero, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vy0[0] = v_y[(i * 2) + 0]; + vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3); + + vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); + vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0); + + v_y[(i * 2) + 0] = vy0[0]; + vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3); + } else if (n) { + vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero); + + vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n); + } +} + +static void BF16GEMV_N_VSX_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + a2 = ap[2]; + a3 = ap[3]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + vec_bf16 *va2 = (vec_bf16 *)a2; + vec_bf16 *va3 = (vec_bf16 *)a3; + + vec_bf16 *x_bf = 
(vec_bf16 *)(xo); + vec_f32 x_0 = vec_loadNHi(x_bf, 4, zero); + x_0 *= v_alpha; + + vec_f32 v_x0 = vec_splat(x_0, 0); + vec_f32 v_x1 = vec_splat(x_0, 1); + vec_f32 v_x2 = vec_splat(x_0, 2); + vec_f32 v_x3 = vec_splat(x_0, 3); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 vy0[2]; + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult2(v_x0, &va0[i], zero, vy0); + vec_load_mult2(v_x1, &va1[i], zero, vy0); + vec_load_mult2(v_x2, &va2[i], zero, vy0); + vec_load_mult2(v_x3, &va3[i], zero, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vy0[0] = v_y[(i * 2) + 0]; + vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3); + + vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); + vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0); + vec_loadN_mult2(v_x2, &va2[i], n, zero, vy0); + vec_loadN_mult2(v_x3, &va3[i], n, zero, vy0); + + v_y[(i * 2) + 0] = vy0[0]; + vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3); + } else if (n) { + vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x2, &va2[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x3, &va3[i], n, zero); + + vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n); + } +} + +static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLASLONG lda4, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + a2 = ap[2]; + a3 = ap[3]; + b0 = a0 + lda4; + b1 = a1 + lda4; + b2 = a2 + lda4; + b3 = a3 + lda4; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + vec_bf16 *va2 = (vec_bf16 *)a2; + vec_bf16 *va3 = (vec_bf16 *)a3; + vec_bf16 *vb0 = (vec_bf16 *)b0; + vec_bf16 *vb1 = (vec_bf16 *)b1; + vec_bf16 *vb2 = (vec_bf16 *)b2; + vec_bf16 *vb3 = (vec_bf16 *)b3; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_bf16 x_in = (vec_bf16)vec_load_vec(x_bf); + vec_f32 x_0 = BF16_HI(x_in, zero); + vec_f32 x_1 = BF16_LO(x_in, zero); + x_0 *= v_alpha; + x_1 *= v_alpha; + + vec_f32 v_x0 = vec_splat(x_0, 0); + vec_f32 v_x1 = vec_splat(x_0, 1); + vec_f32 v_x2 = vec_splat(x_0, 2); + vec_f32 v_x3 = vec_splat(x_0, 3); + vec_f32 v_x4 = vec_splat(x_1, 0); + vec_f32 v_x5 = vec_splat(x_1, 1); + vec_f32 v_x6 = vec_splat(x_1, 2); + vec_f32 v_x7 = vec_splat(x_1, 3); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 vy0[2]; + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult2(v_x0, &va0[i], zero, vy0); + vec_load_mult2(v_x1, &va1[i], zero, vy0); + vec_load_mult2(v_x2, &va2[i], zero, vy0); + vec_load_mult2(v_x3, &va3[i], zero, vy0); + vec_load_mult2(v_x4, &vb0[i], zero, vy0); + vec_load_mult2(v_x5, &vb1[i], zero, vy0); + vec_load_mult2(v_x6, &vb2[i], zero, vy0); + vec_load_mult2(v_x7, &vb3[i], zero, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vy0[0] = v_y[(i * 2) + 0]; + vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3); + + vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); + vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0); + vec_loadN_mult2(v_x2, &va2[i], n, zero, vy0); + vec_loadN_mult2(v_x3, &va3[i], n, zero, vy0); + vec_loadN_mult2(v_x4, &vb0[i], n, zero, vy0); + vec_loadN_mult2(v_x5, &vb1[i], n, zero, vy0); + vec_loadN_mult2(v_x6, &vb2[i], n, zero, vy0); + vec_loadN_mult2(v_x7, 
&vb3[i], n, zero, vy0); + + v_y[(i * 2) + 0] = vy0[0]; + vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3); + } else + if (n) { + vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x2, &va2[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x3, &va3[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x4, &vb0[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x5, &vb1[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x6, &vb2[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x7, &vb3[i], n, zero); + + vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n); + } +} + +#define BF16GEMV_N_8 BF16GEMV_N_VSX_8 +#define BF16GEMV_N_4 BF16GEMV_N_VSX_4 +#define BF16GEMV_N_2 BF16GEMV_N_VSX_2 +#define BF16GEMV_N_1 BF16GEMV_N_VSX_1 + +#include "sbgemv_n.c" +#endif diff --git a/kernel/power/sbgemv_t.c b/kernel/power/sbgemv_t.c new file mode 100644 index 000000000..f0c79fe77 --- /dev/null +++ b/kernel/power/sbgemv_t.c @@ -0,0 +1,117 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#ifndef SBGEMV_T_COMMON_C +#define SBGEMV_T_COMMON_C +int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y) +{ + IFLOAT *xbuffer, *a_ptr; + IFLOAT buffer[NBMAX] __attribute__((aligned(16))); + FLOAT ybuffer[8] __attribute__((aligned(16))); + FLOAT *y_ptr; + + if ((m < 1) || (n < 1)) return 0; + + xbuffer = buffer; + + BLASLONG lda4 = lda << 2; + BLASLONG lda8 = lda << 3; + BLASLONG NB = NBMAX; + BLASLONG m2 = (m & (NBMAX - 1)); + + while (NB == NBMAX) { + m -= NB; + if (m < 0) { + if (m2 == 0) break; + NB = m2; + } + + a_ptr = a; + y_ptr = y; + + if (inc_x != 1) { + copy_x(NB, x, xbuffer, inc_x); + } else { + xbuffer = x; + } + + if (inc_y == 1) { + for (BLASLONG j = 0; j + 8 <= n; j += 8) { + BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); + y_ptr += 8; + a_ptr += lda8; + } + if (n & 4) { + BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); + y_ptr += 4; + a_ptr += lda4; + } + if (n & 2) { + BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); + y_ptr += 2; + a_ptr += (lda * 2); + } + if (n & 1) { + BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); + } + } else { + for (BLASLONG j = 0; j + 8 <= n; j += 8) { + memset(ybuffer, 0, sizeof(FLOAT) * 8); + BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); + copy_y(8, ybuffer, y_ptr, inc_y, beta); + y_ptr += 8 * inc_y; + a_ptr += lda8; + } + if (n & 4) { + memset(ybuffer, 0, sizeof(FLOAT) * 4); + BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); + copy_y(4, ybuffer, y_ptr, inc_y, beta); + y_ptr += 4 * inc_y; + a_ptr += lda4; + } + if (n & 2) { + memset(ybuffer, 0, sizeof(FLOAT) * 4); + BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); + copy_y(2, ybuffer, y_ptr, inc_y, beta); + y_ptr += 2 * inc_y; + a_ptr += (lda * 2); + } + if (n & 1) { + memset(ybuffer, 0, sizeof(FLOAT) * 4); + BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); + copy_y(1, ybuffer, y_ptr, inc_y, beta); + } + } + + a += NB; + x += NB * inc_x; + } + + return 0; +} +#endif + diff --git a/kernel/power/sbgemv_t_power10.c b/kernel/power/sbgemv_t_power10.c new file mode 100644 index 000000000..08bc4237c --- /dev/null +++ b/kernel/power/sbgemv_t_power10.c @@ -0,0 +1,32 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +//#include "sbgemv_common.c" + +#include "sbgemv_t_vsx.c" + +//#include "sbgemv_t.c" diff --git a/kernel/power/sbgemv_t_vsx.c b/kernel/power/sbgemv_t_vsx.c new file mode 100644 index 000000000..7da894109 --- /dev/null +++ b/kernel/power/sbgemv_t_vsx.c @@ -0,0 +1,286 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#ifndef SBGEMV_T_VSX +#define SBGEMV_T_VSX + +#include "sbgemv_common.c" + +#define NBMAX 4096 + +static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +{ + IFLOAT *a0; + vec_bf16 *va0, *v_x; + vec_f32 temp0 = { 0, 0, 0, 0 }; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 inp[2]; + + a0 = ap; + va0 = (vec_bf16 *)a0; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i < n8; i++) { + vec_load_vec2(v_x, i, inp, zero); + + temp0 += vec_load_mult(&va0[i], inp, zero); + } + + n &= 7; + if (n > 4) { + vec_loadN_vec2(v_x, i, inp, n, zero); + + temp0 += vec_loadN_mult(&va0[i], inp, n, zero); + } else if (n) { + vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero); + + temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero); + } + + y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]); +} + +static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +{ + IFLOAT *a0, *a1; + vec_bf16 *va0, *va1, *v_x; + vec_f32 temp0 = { 0, 0, 0, 0 }; + vec_f32 temp1 = { 0, 0, 0, 0 }; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 inp[2]; + + a0 = ap; + a1 = ap + lda; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i < n8; i++) { + vec_load_vec2(v_x, i, inp, zero); + + temp0 += vec_load_mult(&va0[i], inp, zero); + temp1 += vec_load_mult(&va1[i], inp, zero); + } + + n &= 7; + if (n > 4) { + vec_loadN_vec2(v_x, i, inp, n, zero); + + temp0 += vec_loadN_mult(&va0[i], inp, n, zero); + temp1 += vec_loadN_mult(&va1[i], inp, n, zero); + } else if (n) { + vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero); + + temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero); + temp1 += vec_loadNHi_mult(&va1[i], v_inp0, n, zero); + } + + y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]); + y[1] = (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3])) + (beta * y[1]); +} + +static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +{ + IFLOAT *a0, *a1, *a2, *a3; + vec_bf16 *va0, *va1, *va2, *va3, *v_x; + vec_f32 temp0 = { 0, 0, 0, 0 }; + vec_f32 temp1 = { 0, 0, 0, 0 }; + vec_f32 temp2 = { 0, 0, 0, 0 }; + vec_f32 temp3 = { 0, 0, 0, 0 }; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 inp[2]; + + a0 = ap; + a1 = ap + lda; + a2 = a1 + lda; + a3 = a2 + lda; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + va2 = (vec_bf16 *)a2; + va3 = (vec_bf16 *)a3; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i < n8; i++) { + vec_load_vec2(v_x, i, inp, zero); + + temp0 += vec_load_mult(&va0[i], inp, zero); + temp1 += vec_load_mult(&va1[i], inp, zero); + temp2 += vec_load_mult(&va2[i], inp, zero); + temp3 += vec_load_mult(&va3[i], inp, zero); + } + + n &= 7; + if (n > 4) { + vec_loadN_vec2(v_x, i, inp, n, zero); + + temp0 += vec_loadN_mult(&va0[i], inp, n, zero); + temp1 += vec_loadN_mult(&va1[i], inp, n, zero); + temp2 += vec_loadN_mult(&va2[i], inp, n, zero); + temp3 += vec_loadN_mult(&va3[i], inp, n, zero); + } else if (n) { + vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero); + + temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero); + temp1 += vec_loadNHi_mult(&va1[i], v_inp0, n, zero); + temp2 += vec_loadNHi_mult(&va2[i], v_inp0, n, zero); + temp3 += vec_loadNHi_mult(&va3[i], v_inp0, n, zero); + } + + vec_f32 t0, 
t1, t2, t3; + vec_f32 a = { alpha, alpha, alpha, alpha }; + vec_f32 b = { beta, beta, beta, beta }; + vec_f32 *v_y = (vec_f32 *) y; + + t0 = vec_mergeh(temp0, temp2); + t1 = vec_mergel(temp0, temp2); + t2 = vec_mergeh(temp1, temp3); + t3 = vec_mergel(temp1, temp3); + temp0 = vec_mergeh(t0, t2); + temp1 = vec_mergel(t0, t2); + temp2 = vec_mergeh(t1, t3); + temp3 = vec_mergel(t1, t3); + temp0 += temp1 + temp2 + temp3; + + v_y[0] = (a * temp0) + (b * v_y[0]); +} + +static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +{ + IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; + vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x; + vec_f32 temp0 = { 0, 0, 0, 0 }; + vec_f32 temp1 = { 0, 0, 0, 0 }; + vec_f32 temp2 = { 0, 0, 0, 0 }; + vec_f32 temp3 = { 0, 0, 0, 0 }; + vec_f32 temp4 = { 0, 0, 0, 0 }; + vec_f32 temp5 = { 0, 0, 0, 0 }; + vec_f32 temp6 = { 0, 0, 0, 0 }; + vec_f32 temp7 = { 0, 0, 0, 0 }; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 inp[2]; + + a0 = ap; + a1 = ap + lda; + a2 = a1 + lda; + a3 = a2 + lda; + a4 = a3 + lda; + a5 = a4 + lda; + a6 = a5 + lda; + a7 = a6 + lda; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + va2 = (vec_bf16 *)a2; + va3 = (vec_bf16 *)a3; + va4 = (vec_bf16 *)a4; + va5 = (vec_bf16 *)a5; + va6 = (vec_bf16 *)a6; + va7 = (vec_bf16 *)a7; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i < n8; i++) { + vec_load_vec2(v_x, i, inp, zero); + + temp0 += vec_load_mult(&va0[i], inp, zero); + temp1 += vec_load_mult(&va1[i], inp, zero); + temp2 += vec_load_mult(&va2[i], inp, zero); + temp3 += vec_load_mult(&va3[i], inp, zero); + temp4 += vec_load_mult(&va4[i], inp, zero); + temp5 += vec_load_mult(&va5[i], inp, zero); + temp6 += vec_load_mult(&va6[i], inp, zero); + temp7 += vec_load_mult(&va7[i], inp, zero); + } + + n &= 7; + if (n > 4) { + vec_loadN_vec2(v_x, i, inp, n, zero); + + temp0 += vec_loadN_mult(&va0[i], inp, n, zero); + temp1 += vec_loadN_mult(&va1[i], inp, n, zero); + temp2 += vec_loadN_mult(&va2[i], inp, n, zero); + temp3 += vec_loadN_mult(&va3[i], inp, n, zero); + temp4 += vec_loadN_mult(&va4[i], inp, n, zero); + temp5 += vec_loadN_mult(&va5[i], inp, n, zero); + temp6 += vec_loadN_mult(&va6[i], inp, n, zero); + temp7 += vec_loadN_mult(&va7[i], inp, n, zero); + } else if (n) { + vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero); + + temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero); + temp1 += vec_loadNHi_mult(&va1[i], v_inp0, n, zero); + temp2 += vec_loadNHi_mult(&va2[i], v_inp0, n, zero); + temp3 += vec_loadNHi_mult(&va3[i], v_inp0, n, zero); + temp4 += vec_loadNHi_mult(&va4[i], v_inp0, n, zero); + temp5 += vec_loadNHi_mult(&va5[i], v_inp0, n, zero); + temp6 += vec_loadNHi_mult(&va6[i], v_inp0, n, zero); + temp7 += vec_loadNHi_mult(&va7[i], v_inp0, n, zero); + } + + vec_f32 t0, t1, t2, t3; + vec_f32 a = { alpha, alpha, alpha, alpha }; + vec_f32 b = { beta, beta, beta, beta }; + vec_f32 *v_y = (vec_f32 *) y; + + t0 = vec_mergeh(temp0, temp2); + t1 = vec_mergel(temp0, temp2); + t2 = vec_mergeh(temp1, temp3); + t3 = vec_mergel(temp1, temp3); + temp0 = vec_mergeh(t0, t2); + temp1 = vec_mergel(t0, t2); + temp2 = vec_mergeh(t1, t3); + temp3 = vec_mergel(t1, t3); + temp0 += temp1 + temp2 + temp3; + + t0 = vec_mergeh(temp4, temp6); + t1 = vec_mergel(temp4, temp6); + t2 = vec_mergeh(temp5, temp7); + t3 = vec_mergel(temp5, temp7); + temp4 = vec_mergeh(t0, t2); + temp5 = vec_mergel(t0, t2); + temp6 = vec_mergeh(t1, t3); + temp7 = vec_mergel(t1, t3); + temp4 += temp5 + 
temp6 + temp7; + + v_y[0] = (a * temp0) + (b * v_y[0]); + v_y[1] = (a * temp4) + (b * v_y[1]); +} + +#define BF16GEMV_T_8 BF16GEMV_T_VSX_8 +#define BF16GEMV_T_4 BF16GEMV_T_VSX_4 +#define BF16GEMV_T_2 BF16GEMV_T_VSX_2 +#define BF16GEMV_T_1 BF16GEMV_T_VSX_1 + +#include "sbgemv_t.c" +#endif + diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c index b8aaee8be..a86c73d1c 100644 --- a/test/compare_sgemm_sbgemm.c +++ b/test/compare_sgemm_sbgemm.c @@ -202,16 +202,18 @@ main (int argc, char *argv[]) return ret; } + for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one. for (x = 1; x <= loop; x++) { - k = (x == 0) ? 0 : 1; + m = l + 1; + k = (x == 0) ? 0 : m; float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); - float *B = (float *)malloc_safe(x * sizeof(FLOAT)); - float *C = (float *)malloc_safe(x * sizeof(FLOAT)); + float *B = (float *)malloc_safe(x * sizeof(FLOAT) * m); + float *C = (float *)malloc_safe(x * sizeof(FLOAT) * m); bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits)); - bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits)); + bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) * m); float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); - float *CC = (float *)malloc_safe(x * sizeof(FLOAT)); + float *CC = (float *)malloc_safe(x * sizeof(FLOAT) * m); if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || (DD == NULL) || (CC == NULL)) return 1; @@ -226,9 +228,9 @@ main (int argc, char *argv[]) sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one); AA[j * x + i].v = atmp; } - B[j] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - sbstobf16_(&one, &B[j], &one, &btmp, &one); - BB[j].v = btmp; + B[j*m] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + sbstobf16_(&one, &B[j*m], &one, &btmp, &one); + BB[j*m].v = btmp; } for (y = 0; y < 2; y++) { @@ -238,9 +240,9 @@ main (int argc, char *argv[]) transA = 'T'; } - memset(CC, 0, x * sizeof(FLOAT)); + memset(CC, 0, x * m * sizeof(FLOAT)); memset(DD, 0, x * sizeof(FLOAT)); - memset(C, 0, x * sizeof(FLOAT)); + memset(C, 0, x * m * sizeof(FLOAT)); SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k); SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k); @@ -248,15 +250,15 @@ main (int argc, char *argv[]) for (j = 0; j < x; j++) for (i = 0; i < x; i++) if (transA == 'N') { - DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j]); + DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j*m]); } else if (transA == 'T') { - DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i]); + DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i*m]); } for (j = 0; j < x; j++) { - if (fabs (CC[j] - C[j]) > 1.0) + if (fabs (CC[j*m] - C[j*m]) > 1.0) ret++; - if (fabs (CC[j] - DD[j]) > 1.0) + if (fabs (CC[j*m] - DD[j]) > 1.0) ret++; } } @@ -268,6 +270,7 @@ main (int argc, char *argv[]) free(DD); free(CC); } + } if (ret != 0) fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret);
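For readers following the VSX kernels above: BF16_HI and BF16_LO widen bfloat16 lanes to float32 by merging them with a zero vector, which works because a bfloat16 value is simply the upper 16 bits of an IEEE-754 binary32. The endian-dependent branches of those macros exist only to keep the bfloat16 bits in the numerically significant half of each 32-bit element on both byte orders; all products and sums are then accumulated in float32, which is also why alpha and beta are applied in float32. A scalar sketch of the same conversion (a hypothetical helper for illustration, not part of the patch):

```c
#include <stdint.h>
#include <string.h>

/* Scalar equivalent of the widening done by BF16_HI/BF16_LO in the VSX
 * kernels: place the 16 bfloat16 bits in the high half of a 32-bit word,
 * zero the low half, and reinterpret the result as a float. */
static float bf16_to_f32(uint16_t bf16_bits)
{
    uint32_t u = (uint32_t)bf16_bits << 16;  /* bf16 is the top half of an f32 */
    float f;
    memcpy(&f, &u, sizeof f);                /* bit-cast without aliasing issues */
    return f;
}
```

This is also the model the updated compare_sgemm_sbgemm.c test uses implicitly: the BF16 GEMV result is checked against a float32 reference built from the widened bfloat16 inputs, within a coarse tolerance that absorbs the precision lost in the bfloat16 representation.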