diff --git a/kernel/arm64/KERNEL.NEOVERSEN2 b/kernel/arm64/KERNEL.NEOVERSEN2 index 7fe9acd5c..ae386d6e1 100644 --- a/kernel/arm64/KERNEL.NEOVERSEN2 +++ b/kernel/arm64/KERNEL.NEOVERSEN2 @@ -189,12 +189,11 @@ ZGEMMONCOPYOBJ = zgemm_oncopy$(TSUFFIX).$(SUFFIX) ZGEMMOTCOPYOBJ = zgemm_otcopy$(TSUFFIX).$(SUFFIX) SBGEMM_BETA = sbgemm_beta_neoversen2.c -# SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversen2.c -SBGEMMKERNEL = sbgemm_kernel_neoversen2_newbf16.c -SBGEMMINCOPY = sbgemm_ncopy_4_neoversen2.c -SBGEMMITCOPY = sbgemm_tcopy_8_neoversen2.c -SBGEMMONCOPY = sbgemm_ncopy_4_neoversen2.c -SBGEMMOTCOPY = sbgemm_tcopy_8_neoversen2.c +SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversen2.c +SBGEMMINCOPY = sbgemm_ncopy_$(SBGEMM_UNROLL_N)_neoversen2.c +SBGEMMITCOPY = sbgemm_tcopy_$(SBGEMM_UNROLL_M)_neoversen2.c +SBGEMMONCOPY = sbgemm_ncopy_$(SBGEMM_UNROLL_N)_neoversen2.c +SBGEMMOTCOPY = sbgemm_tcopy_$(SBGEMM_UNROLL_M)_neoversen2.c SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX) SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX) SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX) diff --git a/kernel/arm64/sbgemm_kernel_8x4_neoversen2.c b/kernel/arm64/sbgemm_kernel_8x4_neoversen2.c index 66e7dd38a..4c1385fbe 100644 --- a/kernel/arm64/sbgemm_kernel_8x4_neoversen2.c +++ b/kernel/arm64/sbgemm_kernel_8x4_neoversen2.c @@ -37,9 +37,9 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) { - if (alpha == 1.0f) - return sbgemm_kernel_neoversen2_alpha_one(m, n, k, alpha, A, B, C, ldc); - else - return sbgemm_kernel_neoversen2_alpha(m, n, k, alpha, A, B, C, ldc); - return 0; + if (alpha == 1.0f) + return sbgemm_kernel_neoversen2_alpha_one(m, n, k, alpha, A, B, C, ldc); + else + return sbgemm_kernel_neoversen2_alpha(m, n, k, alpha, A, B, C, ldc); + return 0; } diff --git a/kernel/arm64/sbgemm_kernel_8x4_neoversen2_impl.c b/kernel/arm64/sbgemm_kernel_8x4_neoversen2_impl.c index 7d53b1aa0..26ea7ee61 100644 --- a/kernel/arm64/sbgemm_kernel_8x4_neoversen2_impl.c +++ b/kernel/arm64/sbgemm_kernel_8x4_neoversen2_impl.c @@ -30,100 +30,37 @@ #include "common.h" -#ifdef ALPHA_ONE -#define LOAD_C(M, N) \ - mc##M##N = svld1_gather_index(pg32, ptr_c0##N + 2 * M , off_vc); - -#define LOAD_C_LOW(M, N) \ - mc##M##N = svld1_gather_index(pg32_low, ptr_c0##N + 2 * M, off_vc); - -#define LOAD_C_EVEN(M, N) \ - mc##M##N = svld1_gather_index(pg32_even, ptr_c0##N + 2 * M, off_vc); - -#define LOAD_C_FIRST(M, N) \ - mc##M##N = svld1_gather_index(pg32_first, ptr_c0##N + 2 * M, off_vc); - -#define STORE_C(M, N) \ - svst1_scatter_index(pg32, ptr_c0##N + 2 * M, off_vc, mc##M##N); - -#define STORE_C_LOW(M, N) \ - svst1_scatter_index(pg32_low, ptr_c0##N + 2 * M, off_vc, mc##M##N); - -#define STORE_C_EVEN(M, N) \ - svst1_scatter_index(pg32_even, ptr_c0##N + 2 * M, off_vc, mc##M##N); - -#define STORE_C_FIRST(M, N) \ - svst1_scatter_index(pg32_first, ptr_c0##N + 2 * M, off_vc, mc##M##N); - -#else -#define LOAD_C(M, N) \ - mc##M##N = svdup_f32(0); \ - oc##M##N = svld1_gather_index(pg32, ptr_c0##N + 2 * M , off_vc); - -#define LOAD_C_LOW(M, N) \ - mc##M##N = svdup_f32(0); \ - oc##M##N = svld1_gather_index(pg32_low, ptr_c0##N + 2 * M , off_vc); - -#define LOAD_C_EVEN(M, N) \ - mc##M##N = svdup_f32(0); \ - oc##M##N = svld1_gather_index(pg32_even, ptr_c0##N + 2 * M , off_vc); - -#define LOAD_C_FIRST(M, N) \ - mc##M##N = svdup_f32(0); \ - oc##M##N = svld1_gather_index(pg32_first, ptr_c0##N + 2 * M , off_vc); - -#define STORE_C(M, N) \ - mc##M##N = 
svmad_z(pg32, svalpha, mc##M##N, oc##M##N); \ - svst1_scatter_index(pg32, ptr_c0##N + 2 * M, off_vc, mc##M##N); - -#define STORE_C_LOW(M, N) \ - mc##M##N = svmad_z(pg32_low, svalpha, mc##M##N, oc##M##N); \ - svst1_scatter_index(pg32_low, ptr_c0##N + 2 * M, off_vc, mc##M##N); - -#define STORE_C_EVEN(M, N) \ - mc##M##N = svmad_z(pg32_even, svalpha, mc##M##N, oc##M##N); \ - svst1_scatter_index(pg32_even, ptr_c0##N + 2 * M, off_vc, mc##M##N); - -#define STORE_C_FIRST(M, N) \ - mc##M##N = svmad_z(pg32_first, svalpha, mc##M##N, oc##M##N); \ - svst1_scatter_index(pg32_first, ptr_c0##N + 2 * M, off_vc, mc##M##N); - -#endif - -#define LOAD_A(M) ma##M = svld1_bf16(pg16, ptr_a##M); - -#define LOAD_B(N) mb##N = svld1_bf16(pg16, ptr_b##N); +#define INIT_C(M, N) mc##M##N = svdup_f32(0); #define MATMUL(M, N) mc##M##N = svbfmmla(mc##M##N, ma##M, mb##N); -#define LOAD_KREST_1(NAME, M) \ - m##NAME##M = svdupq_bf16(*(ptr_##NAME##M), zero, zero, zero, \ - *(ptr_##NAME##M + 1), zero, zero, zero); - -#define LOAD_KREST_1_LOW(NAME, M) \ - m##NAME##M = svdupq_bf16(*(ptr_##NAME##M), zero, zero, zero, zero, zero, \ - zero, zero); - -#define LOAD_KREST_2(NAME, M) \ - m##NAME##M = \ - svdupq_bf16(*(ptr_##NAME##M), *(ptr_##NAME##M + 1), zero, zero, \ - *(ptr_##NAME##M + 2), *(ptr_##NAME##M + 3), zero, zero); - -#define LOAD_KREST_2_LOW(NAME, M) \ - m##NAME##M = svdupq_bf16(*(ptr_##NAME##M), *(ptr_##NAME##M + 1), zero, \ - zero, zero, zero, zero, zero); - -#define LOAD_KREST_3(NAME, M) \ - m##NAME##M = \ - svdupq_bf16(*(ptr_##NAME##M), *(ptr_##NAME##M + 1), \ - *(ptr_##NAME##M + 2), zero, *(ptr_##NAME##M + 3), \ - *(ptr_##NAME##M + 4), *(ptr_##NAME##M + 5), zero); - -#define LOAD_KREST_3_LOW(NAME, M) \ - m##NAME##M = \ - svdupq_bf16(*(ptr_##NAME##M), *(ptr_##NAME##M + 1), \ - *(ptr_##NAME##M + 2), zero, zero, zero, zero, zero); +#define INIT_C_8x4 \ + do { \ + INIT_C(0, 0); \ + INIT_C(0, 1); \ + INIT_C(1, 0); \ + INIT_C(1, 1); \ + INIT_C(2, 0); \ + INIT_C(2, 1); \ + INIT_C(3, 0); \ + INIT_C(3, 1); \ + } while (0); +#ifdef ALPHA_ONE +#define UPDATE_C(PG, PTR, DST, SRC) \ + do { \ + DST = svld1_f32((PG), (PTR)); \ + DST = svadd_z((PG), SRC, DST); \ + svst1_f32((PG), (PTR), DST); \ + } while (0); +#else +#define UPDATE_C(PG, PTR, DST, SRC) \ + do { \ + DST = svld1_f32((PG), (PTR)); \ + DST = svmad_z((PG), svalpha, SRC, DST); \ + svst1_f32((PG), (PTR), DST); \ + } while (0); +#endif #ifdef ALPHA_ONE int sbgemm_kernel_neoversen2_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc) @@ -131,535 +68,404 @@ int sbgemm_kernel_neoversen2_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT int sbgemm_kernel_neoversen2_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc) #endif { - bfloat16_t *ptr_a = (bfloat16_t *)A; - bfloat16_t *ptr_b = (bfloat16_t *)B; - FLOAT *ptr_c = C; + BLASLONG pad_k = (k + 3) & ~3; - bfloat16_t *ptr_a0, *ptr_a1, *ptr_a2, *ptr_a3; - bfloat16_t *ptr_b0, *ptr_b1; - FLOAT *ptr_c00, *ptr_c01; + svbfloat16_t ma0, ma1, ma2, ma3, mb0, mb1; + svfloat32_t mc00, mc01, mc10, mc11, mc20, mc21, mc30, mc31, + vc0, vc1, vc2, vc3, vc4, vc5, vc6, vc7, + oc0, oc1, oc2, oc3, oc4, oc5, oc6, oc7; + svfloat32_t svalpha = svdup_f32(alpha); - svbfloat16_t ma0, ma1, ma2, ma3, mb0, mb1; - svfloat32_t mc00, mc01, mc10, mc11, mc20, mc21, mc30, mc31; -#ifndef ALPHA_ONE - svfloat32_t oc00, oc01, oc10, oc11, oc20, oc21, oc30, oc31; -#endif - svbool_t pg16 = svptrue_b16(); - svbool_t pg16_low = svdupq_b16(1, 1, 1, 1, 0, 0, 0, 0); - 
svbool_t pg32 = svptrue_b32();
-  svbool_t pg32_low = svdupq_b32(1, 1, 0, 0);
-  svbool_t pg32_even = svdupq_b32(1, 0, 1, 0);
-  svbool_t pg32_first = svdupq_b32(1, 0, 0, 0);
-  svfloat32_t svalpha = svdup_f32(alpha);
-  bfloat16 tmp = 0;
-  bfloat16_t zero = *((bfloat16_t *)&tmp);
-  BLASLONG krest = k & 3;
+  svbool_t pg16 = svptrue_b16();
+  svbool_t pg16_low = svdupq_b16(1, 1, 1, 1, 0, 0, 0, 0);
+  svbool_t pg32 = svptrue_b32();
+  svbool_t pg32_low = svdupq_b32(1, 1, 0, 0);
+  svbool_t pg32_first = svdupq_b32(1, 0, 0, 0);
-  // 00 01 10 11
-  svuint32_t off_vc = svdupq_u32(0, (uint32_t)ldc, 1, (uint32_t)ldc + 1);
+  bfloat16_t *ptr_a = (bfloat16_t *)A;
+  bfloat16_t *ptr_b = (bfloat16_t *)B;
+  FLOAT *ptr_c = C;
-  for (BLASLONG j = 0; j < n / 4; j++) {
-    ptr_c00 = ptr_c;
-    ptr_c01 = ptr_c + 2 * ldc;
-    ptr_c += 4 * ldc;
+  bfloat16_t *ptr_a0, *ptr_a1, *ptr_a2, *ptr_a3;
+  bfloat16_t *ptr_b0, *ptr_b1;
+  FLOAT *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3;
-    ptr_a = (bfloat16_t *)A;
+  for (BLASLONG j = 0; j < n / 4; j++) {
+    ptr_c0 = ptr_c;
+    ptr_c1 = ptr_c0 + ldc;
+    ptr_c2 = ptr_c1 + ldc;
+    ptr_c3 = ptr_c2 + ldc;
+    ptr_c += 4 * ldc;
+    ptr_a = (bfloat16_t *)A;
-    for (BLASLONG i = 0; i < m / 8; i++) {
-      ptr_a0 = ptr_a;
-      ptr_a1 = ptr_a0 + 2 * k;
-      ptr_a2 = ptr_a1 + 2 * k;
-      ptr_a3 = ptr_a2 + 2 * k;
-      ptr_a += 8 * k;
+    for (BLASLONG i = 0; i < m / 8; i++) {
+      ptr_a0 = ptr_a;
+      ptr_a += 8 * pad_k;
-      ptr_b0 = ptr_b;
-      ptr_b1 = ptr_b0 + 2 * k;
+      ptr_b0 = ptr_b;
-      LOAD_C(0, 0); LOAD_C(0, 1);
-      LOAD_C(1, 0); LOAD_C(1, 1);
-      LOAD_C(2, 0); LOAD_C(2, 1);
-      LOAD_C(3, 0); LOAD_C(3, 1);
+      INIT_C_8x4;
-      for (BLASLONG p = 0; p < k / 4; p++) {
-        LOAD_A(0); LOAD_A(1); LOAD_A(2); LOAD_A(3);
-        LOAD_B(0); LOAD_B(1);
+      for (BLASLONG p = 0; p < pad_k; p += 4) {
+        ma0 = svld1_bf16(pg16, ptr_a0);
+        ma1 = svld1_bf16(pg16, ptr_a0 + 8);
+        ma2 = svld1_bf16(pg16, ptr_a0 + 16);
+        ma3 = svld1_bf16(pg16, ptr_a0 + 24);
-        MATMUL(0, 0); MATMUL(0, 1);
-        MATMUL(1, 0); MATMUL(1, 1);
-        MATMUL(2, 0); MATMUL(2, 1);
-        MATMUL(3, 0); MATMUL(3, 1);
+        mb0 = svld1_bf16(pg16, ptr_b0);
+        mb1 = svld1_bf16(pg16, ptr_b0 + 8);
-        ptr_a0 += 8; ptr_a1 += 8; ptr_a2 += 8; ptr_a3 += 8;
-        ptr_b0 += 8; ptr_b1 += 8;
-      }
+        MATMUL(0, 0); MATMUL(0, 1);
+        MATMUL(1, 0); MATMUL(1, 1);
+        MATMUL(2, 0); MATMUL(2, 1);
+        MATMUL(3, 0); MATMUL(3, 1);
-      if (krest) {
-        if (krest == 1) {
-          LOAD_KREST_1(a, 0); LOAD_KREST_1(a, 1);
-          LOAD_KREST_1(a, 2); LOAD_KREST_1(a, 3);
-          LOAD_KREST_1(b, 0); LOAD_KREST_1(b, 1);
-        } else if (krest == 2) {
-          LOAD_KREST_2(a, 0); LOAD_KREST_2(a, 1);
-          LOAD_KREST_2(a, 2); LOAD_KREST_2(a, 3);
-          LOAD_KREST_2(b, 0); LOAD_KREST_2(b, 1);
-        } else if (krest == 3) {
-          LOAD_KREST_3(a, 0); LOAD_KREST_3(a, 1);
-          LOAD_KREST_3(a, 2); LOAD_KREST_3(a, 3);
-          LOAD_KREST_3(b, 0); LOAD_KREST_3(b, 1);
-        }
-        MATMUL(0, 0); MATMUL(0, 1);
-        MATMUL(1, 0); MATMUL(1, 1);
-        MATMUL(2, 0); MATMUL(2, 1);
-        MATMUL(3, 0); MATMUL(3, 1);
-      }
+        ptr_a0 += 32;
+        ptr_b0 += 16;
+      }
-      STORE_C(0, 0); STORE_C(0, 1);
-      STORE_C(1, 0); STORE_C(1, 1);
-      STORE_C(2, 0); STORE_C(2, 1);
-      STORE_C(3, 0); STORE_C(3, 1);
+      vc0 = svuzp1(mc00, mc10);
+      vc1 = svuzp1(mc20, mc30);
+      vc2 = svuzp2(mc00, mc10);
+      vc3 = svuzp2(mc20, mc30);
+      vc4 = svuzp1(mc01, mc11);
+      vc5 = svuzp1(mc21, mc31);
+      vc6 = svuzp2(mc01, mc11);
+      vc7 = svuzp2(mc21, mc31);
-      ptr_c00 += 8; ptr_c01 += 8;
-    }
+      UPDATE_C(pg32, ptr_c0, oc0, vc0);
+      UPDATE_C(pg32, ptr_c0+4, oc1, vc1);
+      UPDATE_C(pg32, ptr_c1, oc2, vc2);
+      UPDATE_C(pg32, ptr_c1+4, oc3, vc3);
+      UPDATE_C(pg32, ptr_c2, oc4, vc4);
+      UPDATE_C(pg32, ptr_c2+4, oc5, vc5);
+      UPDATE_C(pg32, ptr_c3, oc6, vc6);
+      UPDATE_C(pg32, ptr_c3+4, oc7, vc7);
-    if (m & 4) {
-      ptr_a0 = ptr_a;
-      ptr_a1 = ptr_a0 + 2 * k;
-      ptr_a += 4 * k;
-
-      ptr_b0 = ptr_b;
-      ptr_b1 = ptr_b0 + 2 * k;
-
-      LOAD_C(0, 0); LOAD_C(0, 1);
-      LOAD_C(1, 0); LOAD_C(1, 1);
-
-      for (BLASLONG p = 0; p < k / 4; p++) {
-        LOAD_A(0); LOAD_A(1);
-        LOAD_B(0); LOAD_B(1);
-
-        MATMUL(0, 0); MATMUL(0, 1);
-        MATMUL(1, 0); MATMUL(1, 1);
-
-        ptr_a0 += 8; ptr_a1 += 8;
-        ptr_b0 += 8; ptr_b1 += 8;
-      }
-
-      if (krest) {
-        if (krest == 1) {
-          LOAD_KREST_1(a, 0); LOAD_KREST_1(a, 1);
-          LOAD_KREST_1(b, 0); LOAD_KREST_1(b, 1);
-        } else if (krest == 2) {
-          LOAD_KREST_2(a, 0); LOAD_KREST_2(a, 1);
-          LOAD_KREST_2(b, 0); LOAD_KREST_2(b, 1);
-        } else if (krest == 3) {
-          LOAD_KREST_3(a, 0); LOAD_KREST_3(a, 1);
-          LOAD_KREST_3(b, 0); LOAD_KREST_3(b, 1);
-        }
-        MATMUL(0, 0); MATMUL(0, 1);
-        MATMUL(1, 0); MATMUL(1, 1);
-      }
-
-      STORE_C(0, 0); STORE_C(0, 1);
-      STORE_C(1, 0); STORE_C(1, 1);
-
-      ptr_c00 += 4; ptr_c01 += 4;
-    }
-
-    if (m & 2) {
-      ptr_a0 = ptr_a;
-      ptr_a += 2 * k;
-
-      ptr_b0 = ptr_b;
-      ptr_b1 = ptr_b0 + 2 * k;
-
-      LOAD_C(0, 0); LOAD_C(0, 1);
-
-      for (BLASLONG p = 0; p < k / 4; p++) {
-        LOAD_A(0);
-        LOAD_B(0); LOAD_B(1);
-
-        MATMUL(0, 0); MATMUL(0, 1);
-
-        ptr_a0 += 8;
-        ptr_b0 += 8; ptr_b1 += 8;
-      }
-
-      if (krest) {
-        if (krest == 1) {
-          LOAD_KREST_1(a, 0);
-          LOAD_KREST_1(b, 0); LOAD_KREST_1(b, 1);
-        } else if (krest == 2) {
-          LOAD_KREST_2(a, 0);
-          LOAD_KREST_2(b, 0); LOAD_KREST_2(b, 1);
-        } else if (krest == 3) {
-          LOAD_KREST_3(a, 0);
-          LOAD_KREST_3(b, 0); LOAD_KREST_3(b, 1);
-        }
-        MATMUL(0, 0); MATMUL(0, 1);
-      }
-      STORE_C(0, 0); STORE_C(0, 1);
-      ptr_c00 += 2; ptr_c01 += 2;
-    }
-
-    if (m & 1) {
-      ptr_a0 = ptr_a;
-
-      ptr_b0 = ptr_b;
-      ptr_b1 = ptr_b0 + 2 * k;
-
-      LOAD_C_LOW(0, 0); LOAD_C_LOW(0, 1);
-
-      for (BLASLONG p = 0; p < k / 4; p++) {
-        ma0 = svld1_bf16(pg16_low, ptr_a0);
-        LOAD_B(0); LOAD_B(1);
-
-        MATMUL(0, 0); MATMUL(0, 1);
-
-        ptr_a0 += 4;
-        ptr_b0 += 8;
-        ptr_b1 += 8;
-      }
-
-      if (krest) {
-        if (krest == 1) {
-          LOAD_KREST_1_LOW(a, 0);
-          LOAD_KREST_1(b, 0); LOAD_KREST_1(b, 1);
-        } else if (krest == 2) {
-          LOAD_KREST_2_LOW(a, 0);
-          LOAD_KREST_2(b, 0); LOAD_KREST_2(b, 1);
-        } else if (krest == 3) {
-          LOAD_KREST_3_LOW(a, 0);
-          LOAD_KREST_3(b, 0); LOAD_KREST_3(b, 1);
-        }
-        MATMUL(0, 0); MATMUL(0, 1);
-      }
-      STORE_C_LOW(0, 0); STORE_C_LOW(0, 1);
-    }
-
-    ptr_b += 4 * k;
+      ptr_c0 += 8;
+      ptr_c1 += 8;
+      ptr_c2 += 8;
+      ptr_c3 += 8;
     }
-  if (n & 2) {
-    ptr_c00 = ptr_c;
-    ptr_c += 2 * ldc;
+    if (m & 4) {
+      ptr_a0 = ptr_a;
+      ptr_a += 4 * pad_k;
+      ptr_b0 = ptr_b;
-    ptr_a = (bfloat16_t *)A;
+      INIT_C(0, 0); INIT_C(0, 1);
+      INIT_C(1, 0); INIT_C(1, 1);
-    for (BLASLONG i = 0; i < m / 8; i++) {
-      ptr_a0 = ptr_a;
-      ptr_a1 = ptr_a0 + 2 * k;
-      ptr_a2 = ptr_a1 + 2 * k;
-      ptr_a3 = ptr_a2 + 2 * k;
-      ptr_a += 8 * k;
+      for (BLASLONG p = 0; p < pad_k; p += 4) {
+        ma0 = svld1_bf16(pg16, ptr_a0);
+        ma1 = svld1_bf16(pg16, ptr_a0 + 8);
+        mb0 = svld1_bf16(pg16, ptr_b0);
+        mb1 = svld1_bf16(pg16, ptr_b0 + 8);
-      ptr_b0 = ptr_b;
+        MATMUL(0, 0); MATMUL(0, 1);
+        MATMUL(1, 0); MATMUL(1, 1);
-      LOAD_C(0, 0);
-      LOAD_C(1, 0);
-      LOAD_C(2, 0);
-      LOAD_C(3, 0);
+        ptr_a0 += 16;
+        ptr_b0 += 16;
+      }
-      for (BLASLONG p = 0; p < k / 4; p++) {
-        LOAD_A(0); LOAD_A(1); LOAD_A(2); LOAD_A(3);
-        LOAD_B(0);
+      vc0 = svuzp1(mc00, mc10);
+      vc1 = svuzp2(mc00, mc10);
+      vc2 = svuzp1(mc01, mc11);
+      vc3 = svuzp2(mc01, mc11);
-        MATMUL(0, 0);
-        MATMUL(1, 0);
-        MATMUL(2, 0);
-        MATMUL(3, 0);
+      UPDATE_C(pg32, ptr_c0, oc0, vc0);
+      UPDATE_C(pg32, ptr_c1, oc1, vc1);
+      UPDATE_C(pg32, ptr_c2, oc2, vc2);
+      UPDATE_C(pg32,
ptr_c3, oc3, vc3); - ptr_a0 += 8; ptr_a1 += 8; ptr_a2 += 8; ptr_a3 += 8; - ptr_b0 += 8; - } - if (krest) { - if (krest == 1) { - LOAD_KREST_1(a, 0); LOAD_KREST_1(a, 1); - LOAD_KREST_1(a, 2); LOAD_KREST_1(a, 3); - LOAD_KREST_1(b, 0); - } else if (krest == 2) { - LOAD_KREST_2(a, 0); LOAD_KREST_2(a, 1); - LOAD_KREST_2(a, 2); LOAD_KREST_2(a, 3); - LOAD_KREST_2(b, 0); - } else if (krest == 3) { - LOAD_KREST_3(a, 0); LOAD_KREST_3(a, 1); - LOAD_KREST_3(a, 2); LOAD_KREST_3(a, 3); - LOAD_KREST_3(b, 0); - } - MATMUL(0, 0); - MATMUL(1, 0); - MATMUL(2, 0); - MATMUL(3, 0); - } - - STORE_C(0, 0); - STORE_C(1, 0); - STORE_C(2, 0); - STORE_C(3, 0); - - ptr_c00 += 8; - } - - if (m & 4) { - ptr_a0 = ptr_a; - ptr_a1 = ptr_a0 + 2 * k; - ptr_a += 4 * k; - - ptr_b0 = ptr_b; - - LOAD_C(0, 0); - LOAD_C(1, 0); - - for (BLASLONG p = 0; p < k / 4; p++) { - LOAD_A(0); LOAD_A(1); - LOAD_B(0); - - MATMUL(0, 0); - MATMUL(1, 0); - - ptr_a0 += 8; ptr_a1 += 8; - ptr_b0 += 8; - } - if (krest) { - if (krest == 1) { - LOAD_KREST_1(a, 0); LOAD_KREST_1(a, 1); - LOAD_KREST_1(b, 0); - } else if (krest == 2) { - LOAD_KREST_2(a, 0); LOAD_KREST_2(a, 1); - LOAD_KREST_2(b, 0); - } else if (krest == 3) { - LOAD_KREST_3(a, 0); LOAD_KREST_3(a, 1); - LOAD_KREST_3(b, 0); - } - MATMUL(0, 0); - MATMUL(1, 0); - } - STORE_C(0, 0) - STORE_C(1, 0) - - ptr_c00 += 4; - } - - if (m & 2) { - ptr_a0 = ptr_a; - ptr_a += 2 * k; - ptr_b0 = ptr_b; - - LOAD_C(0, 0); - for (BLASLONG p = 0; p < k / 4; p++) { - LOAD_A(0); - LOAD_B(0); - MATMUL(0, 0); - ptr_a0 += 8; - ptr_b0 += 8; - } - if (krest) { - if (krest == 1) { - LOAD_KREST_1(a, 0); - LOAD_KREST_1(b, 0); - } else if (krest == 2) { - LOAD_KREST_2(a, 0); - LOAD_KREST_2(b, 0); - } else if (krest == 3) { - LOAD_KREST_3(a, 0); - LOAD_KREST_3(b, 0); - } - MATMUL(0, 0); - } - STORE_C(0, 0); - ptr_c00 += 2; - } - - if (m & 1) { - ptr_a0 = ptr_a; - - ptr_b0 = ptr_b; - - LOAD_C(0, 0); - - for (BLASLONG p = 0; p < k / 4; p++) { - ma0 = svld1_bf16(pg16_low, ptr_a0); - LOAD_B(0); - MATMUL(0, 0); - ptr_a0 += 4; - ptr_b0 += 8; - } - if (krest) { - if (krest == 1) { - LOAD_KREST_1_LOW(a, 0); - LOAD_KREST_1(b, 0); - } else if (krest == 2) { - LOAD_KREST_2_LOW(a, 0); - LOAD_KREST_2(b, 0); - } else if (krest == 3) { - LOAD_KREST_3_LOW(a, 0); - LOAD_KREST_3(b, 0); - } - MATMUL(0, 0); - } - STORE_C_LOW(0, 0); - } - - ptr_b += 2 * k; + ptr_c0 += 4; + ptr_c1 += 4; + ptr_c2 += 4; + ptr_c3 += 4; } - if (n & 1) { - ptr_c00 = ptr_c; - ptr_a = (bfloat16_t *) A; + if (m & 2) { + ptr_a0 = ptr_a; + ptr_a += 2 * pad_k; + ptr_b0 = ptr_b; - for (BLASLONG i = 0; i < m / 8; i++) { - ptr_a0 = ptr_a; - ptr_a1 = ptr_a0 + 2 * k; - ptr_a2 = ptr_a1 + 2 * k; - ptr_a3 = ptr_a2 + 2 * k; - ptr_a += 8 * k; + INIT_C(0, 0); INIT_C(0, 1); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + mb0 = svld1_bf16(pg16, ptr_b0); + mb1 = svld1_bf16(pg16, ptr_b0 + 8); - ptr_b0 = ptr_b; + MATMUL(0, 0); MATMUL(0, 1); - LOAD_C_EVEN(0, 0); - LOAD_C_EVEN(1, 0); - LOAD_C_EVEN(2, 0); - LOAD_C_EVEN(3, 0); + ptr_a0 += 8; + ptr_b0 += 16; + } - for (BLASLONG p = 0; p < k / 4; p++) { - LOAD_A(0); LOAD_A(1); LOAD_A(2); LOAD_A(3); - mb0 = svld1_bf16(pg16_low, ptr_b0); + vc0 = svuzp1(mc00, mc00); + vc1 = svuzp2(mc00, mc00); + vc2 = svuzp1(mc01, mc01); + vc3 = svuzp2(mc01, mc01); - MATMUL(0, 0); - MATMUL(1, 0); - MATMUL(2, 0); - MATMUL(3, 0); + UPDATE_C(pg32_low, ptr_c0, oc0, vc0); + UPDATE_C(pg32_low, ptr_c1, oc1, vc1); + UPDATE_C(pg32_low, ptr_c2, oc2, vc2); + UPDATE_C(pg32_low, ptr_c3, oc3, vc3); - ptr_a0 += 8; ptr_a1 += 8; ptr_a2 += 8; ptr_a3 
+= 8; - ptr_b0 += 4; - } - if (krest) { - if (krest == 1) { - LOAD_KREST_1(a, 0); LOAD_KREST_1(a, 1); - LOAD_KREST_1(a, 2); LOAD_KREST_1(a, 3); - LOAD_KREST_1_LOW(b, 0); - } else if (krest == 2) { - LOAD_KREST_2(a, 0); LOAD_KREST_2(a, 1); - LOAD_KREST_2(a, 2); LOAD_KREST_2(a, 3); - LOAD_KREST_2_LOW(b, 0); - } else if (krest == 3) { - LOAD_KREST_3(a, 0); LOAD_KREST_3(a, 1); - LOAD_KREST_3(a, 2); LOAD_KREST_3(a, 3); - LOAD_KREST_3_LOW(b, 0); - } - MATMUL(0, 0); - MATMUL(1, 0); - MATMUL(2, 0); - MATMUL(3, 0); - } - STORE_C_EVEN(0, 0) - STORE_C_EVEN(1, 0); - STORE_C_EVEN(2, 0); - STORE_C_EVEN(3, 0); - - ptr_c00 += 8; - } - - if (m & 4) { - ptr_a0 = ptr_a; - ptr_a1 = ptr_a0 + 2 * k; - ptr_a += 4 * k; - - ptr_b0 = ptr_b; - - LOAD_C_EVEN(0, 0); - LOAD_C_EVEN(1, 0); - - for (BLASLONG p = 0; p < k / 4; p++) { - LOAD_A(0); LOAD_A(1); - mb0 = svld1_bf16(pg16_low, ptr_b0); - - MATMUL(0, 0); - MATMUL(1, 0); - - ptr_a0 += 8; ptr_a1 += 8; - ptr_b0 += 4; - } - if (krest) { - if (krest == 1) { - LOAD_KREST_1(a, 0); LOAD_KREST_1(a, 1); - LOAD_KREST_1_LOW(b, 0); - } else if (krest == 2) { - LOAD_KREST_2(a, 0); LOAD_KREST_2(a, 1); - LOAD_KREST_2_LOW(b, 0); - } else if (krest == 3) { - LOAD_KREST_3(a, 0); LOAD_KREST_3(a, 1); - LOAD_KREST_3_LOW(b, 0); - } - MATMUL(0, 0); - MATMUL(1, 0); - } - STORE_C_EVEN(0, 0) - STORE_C_EVEN(1, 0) - - ptr_c00 += 4; - } - - if (m & 2) { - ptr_a0 = ptr_a; - ptr_a += 2 * k; - - ptr_b0 = ptr_b; - - LOAD_C_EVEN(0, 0); - - for (BLASLONG p = 0; p < k / 4; p++) { - LOAD_A(0); - mb0 = svld1_bf16(pg16_low, ptr_b0); - - MATMUL(0, 0); - - ptr_a0 += 8; - ptr_b0 += 4; - } - if (krest) { - if (krest == 1) { - LOAD_KREST_1(a, 0); - LOAD_KREST_1_LOW(b, 0); - } else if (krest == 2) { - LOAD_KREST_2(a, 0); - LOAD_KREST_2_LOW(b, 0); - } else if (krest == 3) { - LOAD_KREST_3(a, 0); - LOAD_KREST_3_LOW(b, 0); - } - MATMUL(0, 0); - } - STORE_C_EVEN(0, 0); - ptr_c00 += 2; - } - if (m & 1) { - ptr_a0 = ptr_a; - ptr_b0 = ptr_b; - LOAD_C_FIRST(0, 0); - for (BLASLONG p = 0; p < k / 4; p++) { - ma0 = svld1_bf16(pg16_low, ptr_a0); - mb0 = svld1_bf16(pg16_low, ptr_b0); - - MATMUL(0, 0); - - ptr_a0 += 4; - ptr_b0 += 4; - } - if (krest) { - if (krest == 1) { - LOAD_KREST_1_LOW(a, 0); - LOAD_KREST_1_LOW(b, 0); - } else if (krest == 2) { - LOAD_KREST_2_LOW(a, 0); - LOAD_KREST_2_LOW(b, 0); - } else if (krest == 3) { - LOAD_KREST_3_LOW(a, 0); - LOAD_KREST_3_LOW(b, 0); - } - MATMUL(0, 0); - } - STORE_C_FIRST(0, 0); - } + ptr_c0 += 2; + ptr_c1 += 2; + ptr_c2 += 2; + ptr_c3 += 2; } - return 0; -} \ No newline at end of file + if (m & 1) { + ptr_a0 = ptr_a; + ptr_b0 = ptr_b; + + INIT_C(0, 0); INIT_C(0, 1); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_low, ptr_a0); + mb0 = svld1_bf16(pg16, ptr_b0); + mb1 = svld1_bf16(pg16, ptr_b0 + 8); + + MATMUL(0, 0); MATMUL(0, 1); + + ptr_a0 += 4; + ptr_b0 += 16; + } + + vc1 = svuzp2(mc00, mc00); + vc3 = svuzp2(mc01, mc01); + + UPDATE_C(pg32_first, ptr_c0, oc0, mc00); + UPDATE_C(pg32_first, ptr_c1, oc1, vc1); + UPDATE_C(pg32_first, ptr_c2, oc2, mc01); + UPDATE_C(pg32_first, ptr_c3, oc3, vc3); + + } + + ptr_b += 4 * pad_k; + } + + if (n & 2) { + ptr_c0 = ptr_c; + ptr_c1 = ptr_c0 + ldc; + ptr_c += 2 * ldc; + ptr_a = (bfloat16_t *)A; + + for (BLASLONG i = 0; i < m / 8; i++) { + ptr_a0 = ptr_a; + ptr_a += 8 * pad_k; + + ptr_b0 = ptr_b; + + INIT_C(0, 0); + INIT_C(1, 0); + INIT_C(2, 0); + INIT_C(3, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + ma1 = svld1_bf16(pg16, ptr_a0 + 8); + ma2 = svld1_bf16(pg16, ptr_a0 + 16); + ma3 = 
svld1_bf16(pg16, ptr_a0 + 24); + + mb0 = svld1_bf16(pg16, ptr_b0); + + MATMUL(0, 0); + MATMUL(1, 0); + MATMUL(2, 0); + MATMUL(3, 0); + + ptr_a0 += 32; + ptr_b0 += 8; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp1(mc20, mc30); + vc2 = svuzp2(mc00, mc10); + vc3 = svuzp2(mc20, mc30); + + UPDATE_C(pg32, ptr_c0, oc0, vc0); + UPDATE_C(pg32, ptr_c0 + 4, oc1, vc1); + UPDATE_C(pg32, ptr_c1, oc2, vc2); + UPDATE_C(pg32, ptr_c1 + 4, oc3, vc3); + + ptr_c0 += 8; + ptr_c1 += 8; + } + + if (m & 4) { + ptr_a0 = ptr_a; + ptr_a += 4 * pad_k; + ptr_b0 = ptr_b; + + INIT_C(0, 0); + INIT_C(1, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + ma1 = svld1_bf16(pg16, ptr_a0 + 8); + mb0 = svld1_bf16(pg16, ptr_b0); + MATMUL(0, 0); + MATMUL(1, 0); + ptr_a0 += 16; + ptr_b0 += 8; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp2(mc00, mc10); + + UPDATE_C(pg32, ptr_c0, oc0, vc0); + UPDATE_C(pg32, ptr_c1, oc1, vc1); + + ptr_c0 += 4; + ptr_c1 += 4; + } + + if (m & 2) { + ptr_a0 = ptr_a; + ptr_a += 2 * pad_k; + ptr_b0 = ptr_b; + + INIT_C(0, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + mb0 = svld1_bf16(pg16, ptr_b0); + + MATMUL(0, 0); + + ptr_a0 += 8; + ptr_b0 += 8; + } + + vc0 = svuzp1(mc00, mc00); + vc1 = svuzp2(mc00, mc00); + UPDATE_C(pg32_low, ptr_c0, oc0, vc0); + UPDATE_C(pg32_low, ptr_c1, oc1, vc1); + + ptr_c0 += 2; + ptr_c1 += 2; + + } + + if (m & 1) { + ptr_a0 = ptr_a; + ptr_b0 = ptr_b; + INIT_C(0, 0); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_low, ptr_a0); + mb0 = svld1_bf16(pg16, ptr_b0); + MATMUL(0, 0); + ptr_a0 += 4; + ptr_b0 += 8; + } + vc1 = svuzp2(mc00, mc00); + + UPDATE_C(pg32_first, ptr_c0, oc0, mc00); + UPDATE_C(pg32_first, ptr_c1, oc1, vc1); + } + + ptr_b += 2 * pad_k; + } + + if (n & 1) { + ptr_c0 = ptr_c; + ptr_a = (bfloat16_t *)A; + + for (BLASLONG i = 0; i < m / 8; i++) { + ptr_a0 = ptr_a; + ptr_a += 8 * pad_k; + + ptr_b0 = ptr_b; + + INIT_C(0, 0); + INIT_C(1, 0); + INIT_C(2, 0); + INIT_C(3, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + ma1 = svld1_bf16(pg16, ptr_a0 + 8); + ma2 = svld1_bf16(pg16, ptr_a0 + 16); + ma3 = svld1_bf16(pg16, ptr_a0 + 24); + + mb0 = svld1_bf16(pg16_low, ptr_b0); + + MATMUL(0, 0); + MATMUL(1, 0); + MATMUL(2, 0); + MATMUL(3, 0); + + ptr_a0 += 32; + ptr_b0 += 4; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp1(mc20, mc30); + + UPDATE_C(pg32, ptr_c0, oc0, vc0); + UPDATE_C(pg32, ptr_c0 + 4, oc1, vc1); + + ptr_c0 += 8; + } + + if (m & 4) { + ptr_a0 = ptr_a; + ptr_a += 4 * pad_k; + ptr_b0 = ptr_b; + INIT_C(0, 0); + INIT_C(1, 0); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + ma1 = svld1_bf16(pg16, ptr_a0 + 8); + mb0 = svld1_bf16(pg16_low, ptr_b0); + MATMUL(0, 0); + MATMUL(1, 0); + ptr_a0 += 16; + ptr_b0 += 4; + } + vc0 = svuzp1(mc00, mc10); + UPDATE_C(pg32, ptr_c0, oc0, vc0); + ptr_c0 += 4; + } + + if (m & 2) { + ptr_a0 = ptr_a; + ptr_a += 2 * pad_k; + ptr_b0 = ptr_b; + + INIT_C(0, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + mb0 = svld1_bf16(pg16_low, ptr_b0); + + MATMUL(0, 0); + + ptr_a0 += 8; + ptr_b0 += 4; + } + vc0 = svuzp1(mc00, mc00); + UPDATE_C(pg32_low, ptr_c0, oc0, vc0); + ptr_c0 += 2; + } + + if (m & 1) { + ptr_a0 = ptr_a; + ptr_b0 = ptr_b; + INIT_C(0, 0); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_low, ptr_a0); + mb0 = svld1_bf16(pg16_low, ptr_b0); + MATMUL(0, 0); + ptr_a0 += 4; + ptr_b0 += 4; + } + UPDATE_C(pg32_first, ptr_c0, oc0, 
mc00); + } + } + + return 0; +} diff --git a/kernel/arm64/sbgemm_kernel_neoversen2_newbf16.c b/kernel/arm64/sbgemm_kernel_neoversen2_newbf16.c deleted file mode 100644 index 1bf743c7f..000000000 --- a/kernel/arm64/sbgemm_kernel_neoversen2_newbf16.c +++ /dev/null @@ -1,467 +0,0 @@ -/*************************************************************************** - * Copyright (c) 2022, The OpenBLAS Project - * All rights reserved. - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in - * the documentation and/or other materials provided with the - * distribution. - * 3. Neither the name of the OpenBLAS project nor the names of - * its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. 
- * *****************************************************************************/
-
-#include <arm_sve.h>
-
-#include "common.h"
-
-#define LOAD_C(M, N) mc##M##N = svdup_f32(0);
-
-#define MATMUL(M, N) mc##M##N = svbfmmla(mc##M##N, ma##M, mb##N);
-
-#define LOAD_C_8x4 \
-  do { \
-    LOAD_C(0, 0); \
-    LOAD_C(0, 1); \
-    LOAD_C(1, 0); \
-    LOAD_C(1, 1); \
-    LOAD_C(2, 0); \
-    LOAD_C(2, 1); \
-    LOAD_C(3, 0); \
-    LOAD_C(3, 1); \
-  } while (0);
-
-#define STORE_C(PG, PTR, SRC, DST) \
-  do { \
-    SRC = svld1_f32((PG), (PTR)); \
-    DST = svmad_z((PG), svalpha, DST, SRC); \
-    svst1_f32((PG), (PTR), DST); \
-  } while (0);
-
-int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
-          FLOAT *C, BLASLONG ldc) {
-  BLASLONG pad_k = (k + 3) & ~3;
-
-  svbfloat16_t ma0, ma1, ma2, ma3, mb0, mb1;
-  svfloat32_t mc00, mc01, mc10, mc11, mc20, mc21, mc30, mc31,
-      vc0, vc1, vc2, vc3, vc4, vc5, vc6, vc7,
-      oc0, oc1, oc2, oc3, oc4, oc5, oc6, oc7;
-  svfloat32_t svalpha = svdup_f32(alpha);
-
-  svbool_t pg16 = svptrue_b16();
-  svbool_t pg16_low = svdupq_b16(1, 1, 1, 1, 0, 0, 0, 0);
-  svbool_t pg32 = svptrue_b32();
-  svbool_t pg32_low = svdupq_b32(1, 1, 0, 0);
-  svbool_t pg32_first = svdupq_b32(1, 0, 0, 0);
-
-  bfloat16_t *ptr_a = (bfloat16_t *)A;
-  bfloat16_t *ptr_b = (bfloat16_t *)B;
-  FLOAT *ptr_c = C;
-
-  bfloat16_t *ptr_a0, *ptr_a1, *ptr_a2, *ptr_a3;
-  bfloat16_t *ptr_b0, *ptr_b1;
-  FLOAT *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3;
-
-  for (BLASLONG j = 0; j < n / 4; j++) {
-    ptr_c0 = ptr_c;
-    ptr_c1 = ptr_c0 + ldc;
-    ptr_c2 = ptr_c1 + ldc;
-    ptr_c3 = ptr_c2 + ldc;
-    ptr_c += 4 * ldc;
-    ptr_a = (bfloat16_t *)A;
-
-    for (BLASLONG i = 0; i < m / 8; i++) {
-      ptr_a0 = ptr_a;
-      ptr_a += 8 * pad_k;
-
-      ptr_b0 = ptr_b;
-
-      LOAD_C_8x4;
-
-      for (BLASLONG p = 0; p < pad_k; p += 4) {
-        ma0 = svld1_bf16(pg16, ptr_a0);
-        ma1 = svld1_bf16(pg16, ptr_a0 + 8);
-        ma2 = svld1_bf16(pg16, ptr_a0 + 16);
-        ma3 = svld1_bf16(pg16, ptr_a0 + 24);
-
-        mb0 = svld1_bf16(pg16, ptr_b0);
-        mb1 = svld1_bf16(pg16, ptr_b0 + 8);
-
-#if 0
-        for (int q = 0; q < 8; q++) {
-          float tmp = 0;
-          *((bfloat16_t *)(&tmp) + 1) = ptr_b0[8+q];
-          printf("%.1f ", tmp);
-        }
-        printf("\n");
-#endif
-
-        MATMUL(0, 0); MATMUL(0, 1);
-        MATMUL(1, 0); MATMUL(1, 1);
-        MATMUL(2, 0); MATMUL(2, 1);
-        MATMUL(3, 0); MATMUL(3, 1);
-
-        ptr_a0 += 32;
-        ptr_b0 += 16;
-      }
-
-      vc0 = svuzp1(mc00, mc10);
-      vc1 = svuzp1(mc20, mc30);
-      vc2 = svuzp2(mc00, mc10);
-      vc3 = svuzp2(mc20, mc30);
-      vc4 = svuzp1(mc01, mc11);
-      vc5 = svuzp1(mc21, mc31);
-      vc6 = svuzp2(mc01, mc11);
-      vc7 = svuzp2(mc21, mc31);
-
-      STORE_C(pg32, ptr_c0, oc0, vc0);
-      STORE_C(pg32, ptr_c0+4, oc1, vc1);
-      STORE_C(pg32, ptr_c1, oc2, vc2);
-      STORE_C(pg32, ptr_c1+4, oc3, vc3);
-      STORE_C(pg32, ptr_c2, oc4, vc4)
-      STORE_C(pg32, ptr_c2+4, oc5, vc5);
-      STORE_C(pg32, ptr_c3, oc6, vc6)
-      STORE_C(pg32, ptr_c3+4, oc7, vc7);
-
-      ptr_c0 += 8;
-      ptr_c1 += 8;
-      ptr_c2 += 8;
-      ptr_c3 += 8;
-    }
-
-    if (m & 4) {
-      ptr_a0 = ptr_a;
-      ptr_a += 4 * pad_k;
-      ptr_b0 = ptr_b;
-
-      LOAD_C(0, 0); LOAD_C(0, 1);
-      LOAD_C(1, 0); LOAD_C(1, 1);
-
-      for (BLASLONG p = 0; p < pad_k; p += 4) {
-        ma0 = svld1_bf16(pg16, ptr_a0);
-        ma1 = svld1_bf16(pg16, ptr_a0 + 8);
-        mb0 = svld1_bf16(pg16, ptr_b0);
-        mb1 = svld1_bf16(pg16, ptr_b0 + 8);
-
-        MATMUL(0, 0); MATMUL(0, 1);
-        MATMUL(1, 0); MATMUL(1, 1);
-
-        ptr_a0 += 16;
-        ptr_b0 += 16;
-      }
-
-      vc0 = svuzp1(mc00, mc10);
-      vc1 = svuzp2(mc00, mc10);
-      vc2 = svuzp1(mc01, mc11);
-      vc3 = svuzp2(mc01, mc11);
-
-      STORE_C(pg32, ptr_c0, oc0, vc0);
-      STORE_C(pg32, ptr_c1, oc1, vc1);
-      STORE_C(pg32, ptr_c2, oc2, vc2);
-
STORE_C(pg32, ptr_c3, oc3, vc3); - - ptr_c0 += 4; - ptr_c1 += 4; - ptr_c2 += 4; - ptr_c3 += 4; - } - - if (m & 2) { - ptr_a0 = ptr_a; - ptr_a += 2 * pad_k; - ptr_b0 = ptr_b; - - LOAD_C(0, 0); LOAD_C(0, 1); - for (BLASLONG p = 0; p < pad_k; p += 4) { - ma0 = svld1_bf16(pg16, ptr_a0); - mb0 = svld1_bf16(pg16, ptr_b0); - mb1 = svld1_bf16(pg16, ptr_b0 + 8); - - MATMUL(0, 0); MATMUL(0, 1); - - ptr_a0 += 8; - ptr_b0 += 16; - } - - vc0 = svuzp1(mc00, mc00); - vc1 = svuzp2(mc00, mc00); - vc2 = svuzp1(mc01, mc01); - vc3 = svuzp2(mc01, mc01); - - STORE_C(pg32_low, ptr_c0, oc0, vc0); - STORE_C(pg32_low, ptr_c1, oc1, vc1); - STORE_C(pg32_low, ptr_c2, oc2, vc2); - STORE_C(pg32_low, ptr_c3, oc3, vc3); - - ptr_c0 += 2; - ptr_c1 += 2; - ptr_c2 += 2; - ptr_c3 += 2; - } - - if (m & 1) { - ptr_a0 = ptr_a; - ptr_b0 = ptr_b; - - LOAD_C(0, 0); LOAD_C(0, 1); - for (BLASLONG p = 0; p < pad_k; p += 4) { - ma0 = svld1_bf16(pg16_low, ptr_a0); - mb0 = svld1_bf16(pg16, ptr_b0); - mb1 = svld1_bf16(pg16, ptr_b0 + 8); - - MATMUL(0, 0); MATMUL(0, 1); - - ptr_a0 += 4; - ptr_b0 += 16; - } - - vc1 = svuzp2(mc00, mc00); - vc3 = svuzp2(mc01, mc01); - - STORE_C(pg32_first, ptr_c0, oc0, mc00); - STORE_C(pg32_first, ptr_c1, oc1, vc1); - STORE_C(pg32_first, ptr_c2, oc2, mc01); - STORE_C(pg32_first, ptr_c3, oc3, vc3); - - } - - ptr_b += 4 * pad_k; - } - - if (n & 2) { - ptr_c0 = ptr_c; - ptr_c1 = ptr_c0 + ldc; - ptr_c += 2 * ldc; - ptr_a = (bfloat16_t *)A; - - for (BLASLONG i = 0; i < m / 8; i++) { - ptr_a0 = ptr_a; - ptr_a += 8 * pad_k; - - ptr_b0 = ptr_b; - - LOAD_C(0, 0); - LOAD_C(1, 0); - LOAD_C(2, 0); - LOAD_C(3, 0); - - for (BLASLONG p = 0; p < pad_k; p += 4) { - ma0 = svld1_bf16(pg16, ptr_a0); - ma1 = svld1_bf16(pg16, ptr_a0 + 8); - ma2 = svld1_bf16(pg16, ptr_a0 + 16); - ma3 = svld1_bf16(pg16, ptr_a0 + 24); - - mb0 = svld1_bf16(pg16, ptr_b0); - - MATMUL(0, 0); - MATMUL(1, 0); - MATMUL(2, 0); - MATMUL(3, 0); - - ptr_a0 += 32; - ptr_b0 += 8; - } - - vc0 = svuzp1(mc00, mc10); - vc1 = svuzp1(mc20, mc30); - vc2 = svuzp2(mc00, mc10); - vc3 = svuzp2(mc20, mc30); - - STORE_C(pg32, ptr_c0, oc0, vc0); - STORE_C(pg32, ptr_c0 + 4, oc1, vc1); - STORE_C(pg32, ptr_c1, oc2, vc2); - STORE_C(pg32, ptr_c1 + 4, oc3, vc3); - - ptr_c0 += 8; - ptr_c1 += 8; - } - - if (m & 4) { - ptr_a0 = ptr_a; - ptr_a += 4 * pad_k; - ptr_b0 = ptr_b; - - LOAD_C(0, 0); - LOAD_C(1, 0); - - for (BLASLONG p = 0; p < pad_k; p += 4) { - ma0 = svld1_bf16(pg16, ptr_a0); - ma1 = svld1_bf16(pg16, ptr_a0 + 8); - mb0 = svld1_bf16(pg16, ptr_b0); - MATMUL(0, 0); - MATMUL(1, 0); - ptr_a0 += 16; - ptr_b0 += 8; - } - - vc0 = svuzp1(mc00, mc10); - vc1 = svuzp2(mc00, mc10); - - STORE_C(pg32, ptr_c0, oc0, vc0); - STORE_C(pg32, ptr_c1, oc1, vc1); - - ptr_c0 += 4; - ptr_c1 += 4; - } - - if (m & 2) { - ptr_a0 = ptr_a; - ptr_a += 2 * pad_k; - ptr_b0 = ptr_b; - - LOAD_C(0, 0); - - for (BLASLONG p = 0; p < pad_k; p += 4) { - ma0 = svld1_bf16(pg16, ptr_a0); - mb0 = svld1_bf16(pg16, ptr_b0); - - MATMUL(0, 0); - - ptr_a0 += 8; - ptr_b0 += 8; - } - - vc0 = svuzp1(mc00, mc00); - vc1 = svuzp2(mc00, mc00); - STORE_C(pg32_low, ptr_c0, oc0, vc0); - STORE_C(pg32_low, ptr_c1, oc1, vc1); - - ptr_c0 += 2; - ptr_c1 += 2; - - } - - if (m & 1) { - ptr_a0 = ptr_a; - ptr_b0 = ptr_b; - LOAD_C(0, 0); - for (BLASLONG p = 0; p < pad_k; p += 4) { - ma0 = svld1_bf16(pg16_low, ptr_a0); - mb0 = svld1_bf16(pg16, ptr_b0); - MATMUL(0, 0); - ptr_a0 += 4; - ptr_b0 += 8; - } - vc1 = svuzp2(mc00, mc00); - - STORE_C(pg32_first, ptr_c0, oc0, mc00); - STORE_C(pg32_first, ptr_c1, oc1, vc1); - } - - ptr_b += 2 * pad_k; - } - - 
if (n & 1) { - ptr_c0 = ptr_c; - ptr_a = (bfloat16_t *)A; - - for (BLASLONG i = 0; i < m / 8; i++) { - ptr_a0 = ptr_a; - ptr_a += 8 * pad_k; - - ptr_b0 = ptr_b; - - LOAD_C(0, 0); - LOAD_C(1, 0); - LOAD_C(2, 0); - LOAD_C(3, 0); - - for (BLASLONG p = 0; p < pad_k; p += 4) { - ma0 = svld1_bf16(pg16, ptr_a0); - ma1 = svld1_bf16(pg16, ptr_a0 + 8); - ma2 = svld1_bf16(pg16, ptr_a0 + 16); - ma3 = svld1_bf16(pg16, ptr_a0 + 24); - - mb0 = svld1_bf16(pg16_low, ptr_b0); - - MATMUL(0, 0); - MATMUL(1, 0); - MATMUL(2, 0); - MATMUL(3, 0); - - ptr_a0 += 32; - ptr_b0 += 4; - } - - vc0 = svuzp1(mc00, mc10); - vc1 = svuzp1(mc20, mc30); - - STORE_C(pg32, ptr_c0, oc0, vc0); - STORE_C(pg32, ptr_c0 + 4, oc1, vc1); - - ptr_c0 += 8; - } - - if (m & 4) { - ptr_a0 = ptr_a; - ptr_a += 4 * pad_k; - ptr_b0 = ptr_b; - LOAD_C(0, 0); - LOAD_C(1, 0); - for (BLASLONG p = 0; p < pad_k; p += 4) { - ma0 = svld1_bf16(pg16, ptr_a0); - ma1 = svld1_bf16(pg16, ptr_a0 + 8); - mb0 = svld1_bf16(pg16_low, ptr_b0); - MATMUL(0, 0); - MATMUL(1, 0); - ptr_a0 += 16; - ptr_b0 += 4; - } - vc0 = svuzp1(mc00, mc10); - STORE_C(pg32, ptr_c0, oc0, vc0); - ptr_c0 += 4; - } - - if (m & 2) { - ptr_a0 = ptr_a; - ptr_a += 2 * pad_k; - ptr_b0 = ptr_b; - - LOAD_C(0, 0); - - for (BLASLONG p = 0; p < pad_k; p += 4) { - ma0 = svld1_bf16(pg16, ptr_a0); - mb0 = svld1_bf16(pg16_low, ptr_b0); - - MATMUL(0, 0); - - ptr_a0 += 8; - ptr_b0 += 4; - } - vc0 = svuzp1(mc00, mc00); - STORE_C(pg32_low, ptr_c0, oc0, vc0); - ptr_c0 += 2; - } - - if (m & 1) { - ptr_a0 = ptr_a; - ptr_b0 = ptr_b; - LOAD_C(0, 0); - for (BLASLONG p = 0; p < pad_k; p += 4) { - ma0 = svld1_bf16(pg16_low, ptr_a0); - mb0 = svld1_bf16(pg16_low, ptr_b0); - MATMUL(0, 0); - ptr_a0 += 4; - ptr_b0 += 4; - } - STORE_C(pg32_first, ptr_c0, oc0, mc00); - } - } - - return 0; -} diff --git a/kernel/arm64/sbgemm_ncopy_4_neoversen2.c b/kernel/arm64/sbgemm_ncopy_4_neoversen2.c index 0b0e7a427..22978a388 100644 --- a/kernel/arm64/sbgemm_ncopy_4_neoversen2.c +++ b/kernel/arm64/sbgemm_ncopy_4_neoversen2.c @@ -58,17 +58,6 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { svst1_bf16(pg16, (bfloat16_t *)b_offset + 8, v2); svst1_bf16(pg16, (bfloat16_t *)b_offset + 12, v3); -#if 0 - for (int line = 0; line < 4; line++) { - for (int p = 0; p < 4; p++) { - float tmp = 0; - *((bfloat16 *)(&tmp) + 1) = b_offset[line * 4 + p]; - printf("%f ", tmp); - } - printf("\n"); - } -#endif - b_offset += 16; a_offsetx[0] += 4; a_offsetx[1] += 4; diff --git a/kernel/arm64/sbgemm_ncopy_neoversen2.c b/kernel/arm64/sbgemm_ncopy_neoversen2.c deleted file mode 100644 index 594067ebb..000000000 --- a/kernel/arm64/sbgemm_ncopy_neoversen2.c +++ /dev/null @@ -1,101 +0,0 @@ -/*************************************************************************** - * Copyright (c) 2022, The OpenBLAS Project - * All rights reserved. - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in - * the documentation and/or other materials provided with the - * distribution. - * 3. Neither the name of the OpenBLAS project nor the names of - * its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. 
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - * *****************************************************************************/ - -#include "common.h" - -int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { - IFLOAT *a_offset, *a_offset1, *a_offset2; - IFLOAT *b_offset; - - a_offset = a; - b_offset = b; - - for (BLASLONG j = 0; j < n / 2; j++) { - a_offset1 = a_offset; - a_offset2 = a_offset1 + lda; - a_offset += 2 * lda; - for (BLASLONG i = 0; i < m / 4; i++) { - *(b_offset + 0) = *(a_offset1 + 0); - *(b_offset + 1) = *(a_offset1 + 1); - *(b_offset + 2) = *(a_offset1 + 2); - *(b_offset + 3) = *(a_offset1 + 3); - *(b_offset + 4) = *(a_offset2 + 0); - *(b_offset + 5) = *(a_offset2 + 1); - *(b_offset + 6) = *(a_offset2 + 2); - *(b_offset + 7) = *(a_offset2 + 3); - - a_offset1 += 4; - a_offset2 += 4; - b_offset += 8; - } - BLASLONG rest = m & 3; - if (rest == 3) { - *(b_offset + 0) = *(a_offset1 + 0); - *(b_offset + 1) = *(a_offset1 + 1); - *(b_offset + 2) = *(a_offset1 + 2); - *(b_offset + 3) = *(a_offset2 + 0); - *(b_offset + 4) = *(a_offset2 + 1); - *(b_offset + 5) = *(a_offset2 + 2); - b_offset += 6; - } else if (rest == 2) { - *(b_offset + 0) = *(a_offset1 + 0); - *(b_offset + 1) = *(a_offset1 + 1); - *(b_offset + 2) = *(a_offset2 + 0); - *(b_offset + 3) = *(a_offset2 + 1); - b_offset += 4; - } else if (rest == 1) { - *(b_offset + 0) = *(a_offset1 + 0); - *(b_offset + 1) = *(a_offset2 + 0); - b_offset += 2; - } - } - if (n & 1) { - for (BLASLONG i = 0; i < m / 4; i++) { - *(b_offset + 0) = *(a_offset + 0); - *(b_offset + 1) = *(a_offset + 1); - *(b_offset + 2) = *(a_offset + 2); - *(b_offset + 3) = *(a_offset + 3); - - b_offset += 4; - a_offset += 4; - } - BLASLONG rest = m & 3; - if (rest == 3) { - *(b_offset + 0) = *(a_offset + 0); - *(b_offset + 1) = *(a_offset + 1); - *(b_offset + 2) = *(a_offset + 2); - } else if (rest == 2) { - *(b_offset + 0) = *(a_offset + 0); - *(b_offset + 1) = *(a_offset + 1); - } else if (rest == 1) { - *(b_offset + 0) = *(a_offset + 0); - } - } - - return 0; -} diff --git a/kernel/arm64/sbgemm_tcopy_8_neoversen2.c b/kernel/arm64/sbgemm_tcopy_8_neoversen2.c index 6c37e4bcf..a058b5a8e 100644 --- a/kernel/arm64/sbgemm_tcopy_8_neoversen2.c +++ b/kernel/arm64/sbgemm_tcopy_8_neoversen2.c @@ -43,15 +43,6 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { for (BLASLONG i = 0; i < m / 4; i++) { for (BLASLONG line = 0; line < 8; line++) { -#if 0 - float fv0 = 0, fv1 = 0, fv2 = 0, fv3 = 0; - *((bfloat16 *)(&fv0) + 1) = a_offset0[line]; - *((bfloat16 *)(&fv1) + 1) = a_offset1[line]; - *((bfloat16 *)(&fv2) + 1) = a_offset2[line]; - *((bfloat16 *)(&fv3) + 1) = a_offset3[line]; - printf("%f %f %f %f\n", fv0, fv1, fv2, fv3); -#endif - b_offset[line * 4] = a_offset0[line]; b_offset[line * 4 + 
1] = a_offset1[line]; b_offset[line * 4 + 2] = a_offset2[line]; diff --git a/kernel/arm64/sbgemm_tcopy_neoversen2.c b/kernel/arm64/sbgemm_tcopy_neoversen2.c deleted file mode 100644 index 2f3313379..000000000 --- a/kernel/arm64/sbgemm_tcopy_neoversen2.c +++ /dev/null @@ -1,109 +0,0 @@ -/*************************************************************************** - * Copyright (c) 2022, The OpenBLAS Project - * All rights reserved. - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in - * the documentation and/or other materials provided with the - * distribution. - * 3. Neither the name of the OpenBLAS project nor the names of - * its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. 
- * *****************************************************************************/ - -#include "common.h" - - -int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { - IFLOAT *a_offset, *a_offset1, *a_offset2, *a_offset3, *a_offset4; - IFLOAT *b_offset; - a_offset = a; - b_offset = b; - - for (BLASLONG j = 0; j < n / 2; j++) { - a_offset1 = a_offset; - a_offset2 = a_offset1 + lda; - a_offset3 = a_offset2 + lda; - a_offset4 = a_offset3 + lda; - a_offset += 2; - - for (BLASLONG i = 0; i < m / 4; i++) { - *(b_offset + 0) = *(a_offset1 + 0); - *(b_offset + 1) = *(a_offset2 + 0); - *(b_offset + 2) = *(a_offset3 + 0); - *(b_offset + 3) = *(a_offset4 + 0); - *(b_offset + 4) = *(a_offset1 + 1); - *(b_offset + 5) = *(a_offset2 + 1); - *(b_offset + 6) = *(a_offset3 + 1); - *(b_offset + 7) = *(a_offset4 + 1); - - b_offset += 8; - a_offset1 += 4 * lda; - a_offset2 += 4 * lda; - a_offset3 += 4 * lda; - a_offset4 += 4 * lda; - } - - if (m & 3) { - BLASLONG rest = m & 3; - if (rest == 3) { - *(b_offset + 0) = *(a_offset1 + 0); - *(b_offset + 1) = *(a_offset2 + 0); - *(b_offset + 2) = *(a_offset3 + 0); - *(b_offset + 3) = *(a_offset1 + 1); - *(b_offset + 4) = *(a_offset2 + 1); - *(b_offset + 5) = *(a_offset3 + 1); - b_offset += 6; - } else if (rest == 2) { - *(b_offset + 0) = *(a_offset1 + 0); - *(b_offset + 1) = *(a_offset2 + 0); - *(b_offset + 2) = *(a_offset1 + 1); - *(b_offset + 3) = *(a_offset2 + 1); - b_offset += 4; - } else if (rest == 1) { - *(b_offset + 0) = *(a_offset1 + 0); - *(b_offset + 1) = *(a_offset1 + 1); - b_offset += 2; - } - } - } - if (n & 1) { - for (BLASLONG i = 0; i < m / 4; i++) { - *(b_offset + 0) = *(a_offset); - *(b_offset + 1) = *(a_offset + lda); - *(b_offset + 2) = *(a_offset + lda * 2); - *(b_offset + 3) = *(a_offset + lda * 3); - - b_offset += 4; - a_offset += 4 * lda; - } - BLASLONG rest = m & 3; - if (rest == 3) { - *(b_offset + 0) = *(a_offset); - *(b_offset + 1) = *(a_offset + lda); - *(b_offset + 2) = *(a_offset + lda * 2); - } else if (rest == 2) { - *(b_offset + 0) = *(a_offset); - *(b_offset + 1) = *(a_offset + lda); - } else if (rest == 1) { - *(b_offset + 0) = *(a_offset); - } - } - - return 0; -}
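
Note on the BFMMLA layout the kernels above rely on: with the 128-bit SVE vectors of Neoverse N2, each svbfloat16_t holds a 2x4 bf16 tile and each svbfmmla accumulator holds a 2x2 f32 tile in row-major order {c00, c01, c10, c11}. This is why the copy routines pack the operands in pairs of rows or columns, four k-values per tile, why k is rounded up to pad_k = (k + 3) & ~3 with zero fill, and why the kernel de-interleaves the accumulators with svuzp1/svuzp2 into vc0..vc7 before the contiguous loads and stores in UPDATE_C, replacing the gather/scatter through off_vc in the old kernel. A minimal scalar sketch of the per-segment operation follows, for reference only; the helper names bf16_to_f32 and bfmmla_ref are illustrative and not part of the patch.

#include <stdint.h>
#include <string.h>

/* Widen one bf16 value (the upper half of an IEEE binary32) to float. */
static float bf16_to_f32(uint16_t x) {
  uint32_t u = (uint32_t)x << 16;
  float f;
  memcpy(&f, &u, sizeof(f));
  return f;
}

/* One BFMMLA segment: c[2][2] += a[2][4] * b[2][4]^T.  Zero padding in
 * the k direction contributes 0 to every product, so rounding k up to
 * a multiple of 4 in the packed buffers does not change the result. */
static void bfmmla_ref(float c[2][2], const uint16_t a[2][4],
                       const uint16_t b[2][4]) {
  for (int i = 0; i < 2; i++)
    for (int j = 0; j < 2; j++)
      for (int l = 0; l < 4; l++)
        c[i][j] += bf16_to_f32(a[i][l]) * bf16_to_f32(b[j][l]);
}

After the k loop, mc00..mc31 hold interleaved 2x2 tiles; svuzp1 (even elements) collects one column of four consecutive C rows and svuzp2 (odd elements) the other, which is what allows the plain svld1_f32/svst1_f32 in UPDATE_C on column-major C.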