MMA BF16 GEMV code.
This commit is contained in:
@@ -25,12 +25,20 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef SBGEMV_N_VSX
|
||||
#define SBGEMV_N_VSX
|
||||
#ifndef SBGEMV_N_VSX_C
|
||||
#define SBGEMV_N_VSX_C
|
||||
|
||||
#include "sbgemv_common.c"
|
||||
|
||||
#define NBMAX 4096
|
||||
#ifndef BF16GEMV_N_X
|
||||
#define BF16GEMV_N_X
|
||||
#define BF16GEMV_N_8 BF16GEMV_N_VSX_8
|
||||
#define BF16GEMV_N_4 BF16GEMV_N_VSX_4
|
||||
#define BF16GEMV_N_2 BF16GEMV_N_VSX_2
|
||||
#define BF16GEMV_N_1 BF16GEMV_N_VSX_1
|
||||
#endif
|
||||
|
||||
#define USE_BFGEMV_8_N_VSX
|
||||
|
||||
static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha)
|
||||
{
|
||||
@@ -70,11 +78,11 @@ static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
|
||||
|
||||
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||
} else if (n) {
|
||||
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||
|
||||
vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero);
|
||||
|
||||
vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n);
|
||||
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,12 +129,12 @@ static void BF16GEMV_N_VSX_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
|
||||
|
||||
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||
} else if (n) {
|
||||
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||
|
||||
vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero);
|
||||
vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x1, &va1[i], n, zero);
|
||||
|
||||
vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n);
|
||||
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,17 +191,18 @@ static void BF16GEMV_N_VSX_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA
|
||||
|
||||
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||
} else if (n) {
|
||||
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||
|
||||
vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero);
|
||||
vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero);
|
||||
vy0 += vec_loadNHi_multi2(v_x2, &va2[i], n, zero);
|
||||
vy0 += vec_loadNHi_multi2(v_x3, &va3[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x1, &va1[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x2, &va2[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x3, &va3[i], n, zero);
|
||||
|
||||
vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n);
|
||||
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_BFGEMV_8_N_VSX
|
||||
static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLASLONG lda4, FLOAT alpha)
|
||||
{
|
||||
IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3;
|
||||
@@ -270,25 +279,21 @@ static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS
|
||||
|
||||
vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3);
|
||||
} else if (n) {
|
||||
vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||
vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n);
|
||||
|
||||
vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero);
|
||||
vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero);
|
||||
vy0 += vec_loadNHi_multi2(v_x2, &va2[i], n, zero);
|
||||
vy0 += vec_loadNHi_multi2(v_x3, &va3[i], n, zero);
|
||||
vy0 += vec_loadNHi_multi2(v_x4, &vb0[i], n, zero);
|
||||
vy0 += vec_loadNHi_multi2(v_x5, &vb1[i], n, zero);
|
||||
vy0 += vec_loadNHi_multi2(v_x6, &vb2[i], n, zero);
|
||||
vy0 += vec_loadNHi_multi2(v_x7, &vb3[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x1, &va1[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x2, &va2[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x3, &va3[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x4, &vb0[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x5, &vb1[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x6, &vb2[i], n, zero);
|
||||
vy0[0] += vec_loadNHi_mult2(v_x7, &vb3[i], n, zero);
|
||||
|
||||
vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n);
|
||||
vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n);
|
||||
}
|
||||
}
|
||||
|
||||
#define BF16GEMV_N_8 BF16GEMV_N_VSX_8
|
||||
#define BF16GEMV_N_4 BF16GEMV_N_VSX_4
|
||||
#define BF16GEMV_N_2 BF16GEMV_N_VSX_2
|
||||
#define BF16GEMV_N_1 BF16GEMV_N_VSX_1
|
||||
#endif
|
||||
|
||||
#include "sbgemv_n.c"
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user