diff --git a/common_param.h b/common_param.h index 31fba9059..091840343 100644 --- a/common_param.h +++ b/common_param.h @@ -1193,6 +1193,7 @@ BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG); #ifdef BUILD_COMPLEX16 int (*zgeadd_k) (BLASLONG, BLASLONG, double, double, double *, BLASLONG, double, double, double *, BLASLONG); #endif + int align_k; // must be 2^n } gotoblas_t; extern gotoblas_t *gotoblas; diff --git a/driver/level3/level3.c b/driver/level3/level3.c index 4a8e193be..d3281345d 100644 --- a/driver/level3/level3.c +++ b/driver/level3/level3.c @@ -304,6 +304,16 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, while (gemm_p * min_l > l2size) gemm_p -= GEMM_UNROLL_M; } + BLASLONG pad_min_l = min_l; + +#if defined(HALF) && defined(DYNAMIC_ARCH) + pad_min_l = (min_l + gotoblas->align_k - 1) & ~(gotoblas->align_k-1); +#endif + +#if defined(HALF) && !defined(DYNAMIC_ARCH) && defined(NEOVERSEN2) + pad_min_l = (min_l + 3) & ~3; +#endif + /* First, we have to move data A to L2 cache */ min_i = m_to - m_from; l1stride = 1; @@ -350,7 +360,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, START_RPCC(); OCOPY_OPERATION(min_l, min_jj, b, ldb, ls, jjs, - sb + min_l * (jjs - js) * COMPSIZE * l1stride); + sb + pad_min_l * (jjs - js) * COMPSIZE * l1stride); STOP_RPCC(outercost); @@ -358,10 +368,10 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, #if !defined(XDOUBLE) || !defined(QUAD_PRECISION) KERNEL_OPERATION(min_i, min_jj, min_l, alpha, - sa, sb + min_l * (jjs - js) * COMPSIZE * l1stride, c, ldc, m_from, jjs); + sa, sb + pad_min_l * (jjs - js) * COMPSIZE * l1stride, c, ldc, m_from, jjs); #else KERNEL_OPERATION(min_i, min_jj, min_l, (void *)&xalpha, - sa, sb + min_l * (jjs - js) * COMPSIZE * l1stride, c, ldc, m_from, jjs); + sa, sb + pad_min_l * (jjs - js) * COMPSIZE * l1stride, c, ldc, m_from, jjs); #endif STOP_RPCC(kernelcost); diff --git a/driver/level3/level3_thread.c b/driver/level3/level3_thread.c index dfc7107b8..95c8e6d19 100644 --- a/driver/level3/level3_thread.c +++ b/driver/level3/level3_thread.c @@ -324,6 +324,16 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, } else { if (min_l > GEMM_Q) min_l = (min_l + 1) / 2; } + + BLASLONG pad_min_l = min_l; + +#if defined(HALF) && defined(DYNAMIC_ARCH) + pad_min_l = (min_l + gotoblas->align_k - 1) & ~(gotoblas->align_k-1); +#endif + +#if defined(HALF) && !defined(DYNAMIC_ARCH) && defined(NEOVERSEN2) + pad_min_l = (min_l + 3) & ~3; +#endif /* Determine step size in m * Note: We are currently on the first step in m @@ -382,13 +392,13 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, /* Copy part of local region of B into workspace */ START_RPCC(); OCOPY_OPERATION(min_l, min_jj, b, ldb, ls, jjs, - buffer[bufferside] + min_l * (jjs - js) * COMPSIZE * l1stride); + buffer[bufferside] + pad_min_l * (jjs - js) * COMPSIZE * l1stride); STOP_RPCC(copy_B); /* Apply kernel with local region of A and part of local region of B */ START_RPCC(); KERNEL_OPERATION(min_i, min_jj, min_l, alpha, - sa, buffer[bufferside] + min_l * (jjs - js) * COMPSIZE * l1stride, + sa, buffer[bufferside] + pad_min_l * (jjs - js) * COMPSIZE * l1stride, c, ldc, m_from, jjs); STOP_RPCC(kernel); diff --git a/kernel/arm64/KERNEL.NEOVERSEN2 b/kernel/arm64/KERNEL.NEOVERSEN2 index 07a94a043..7fe9acd5c 100644 --- a/kernel/arm64/KERNEL.NEOVERSEN2 +++ b/kernel/arm64/KERNEL.NEOVERSEN2 @@ -189,11 +189,12 @@ ZGEMMONCOPYOBJ = zgemm_oncopy$(TSUFFIX).$(SUFFIX) ZGEMMOTCOPYOBJ = zgemm_otcopy$(TSUFFIX).$(SUFFIX) SBGEMM_BETA = sbgemm_beta_neoversen2.c -SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversen2.c -SBGEMMINCOPY = sbgemm_ncopy_neoversen2.c -SBGEMMITCOPY = sbgemm_tcopy_neoversen2.c -SBGEMMONCOPY = sbgemm_ncopy_neoversen2.c -SBGEMMOTCOPY = sbgemm_tcopy_neoversen2.c +# SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversen2.c +SBGEMMKERNEL = sbgemm_kernel_neoversen2_newbf16.c +SBGEMMINCOPY = sbgemm_ncopy_4_neoversen2.c +SBGEMMITCOPY = sbgemm_tcopy_8_neoversen2.c +SBGEMMONCOPY = sbgemm_ncopy_4_neoversen2.c +SBGEMMOTCOPY = sbgemm_tcopy_8_neoversen2.c SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX) SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX) SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX) diff --git a/kernel/arm64/sbgemm_kernel_neoversen2_newbf16.c b/kernel/arm64/sbgemm_kernel_neoversen2_newbf16.c new file mode 100644 index 000000000..1bf743c7f --- /dev/null +++ b/kernel/arm64/sbgemm_kernel_neoversen2_newbf16.c @@ -0,0 +1,467 @@ +/*************************************************************************** + * Copyright (c) 2022, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include + +#include "common.h" + +#define LOAD_C(M, N) mc##M##N = svdup_f32(0); + +#define MATMUL(M, N) mc##M##N = svbfmmla(mc##M##N, ma##M, mb##N); + +#define LOAD_C_8x4 \ + do { \ + LOAD_C(0, 0); \ + LOAD_C(0, 1); \ + LOAD_C(1, 0); \ + LOAD_C(1, 1); \ + LOAD_C(2, 0); \ + LOAD_C(2, 1); \ + LOAD_C(3, 0); \ + LOAD_C(3, 1); \ + } while (0); + +#define STORE_C(PG, PTR, SRC, DST) \ + do { \ + SRC = svld1_f32((PG), (PTR)); \ + DST = svmad_z((PG), svalpha, DST, SRC); \ + svst1_f32((PG), (PTR), DST); \ + } while (0); + +int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, + FLOAT *C, BLASLONG ldc) { + BLASLONG pad_k = (k + 3) & ~3; + + svbfloat16_t ma0, ma1, ma2, ma3, mb0, mb1; + svfloat32_t mc00, mc01, mc10, mc11, mc20, mc21, mc30, mc31, + vc0, vc1, vc2, vc3, vc4, vc5, vc6, vc7, + oc0, oc1, oc2, oc3, oc4, oc5, oc6, oc7; + svfloat32_t svalpha = svdup_f32(alpha); + + svbool_t pg16 = svptrue_b16(); + svbool_t pg16_low = svdupq_b16(1, 1, 1, 1, 0, 0, 0, 0); + svbool_t pg32 = svptrue_b32(); + svbool_t pg32_low = svdupq_b32(1, 1, 0, 0); + svbool_t pg32_first = svdupq_b32(1, 0, 0, 0); + + bfloat16_t *ptr_a = (bfloat16_t *)A; + bfloat16_t *ptr_b = (bfloat16_t *)B; + FLOAT *ptr_c = C; + + bfloat16_t *ptr_a0, *ptr_a1, *ptr_a2, *ptr_a3; + bfloat16_t *ptr_b0, *ptr_b1; + FLOAT *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3; + + for (BLASLONG j = 0; j < n / 4; j++) { + ptr_c0 = ptr_c; + ptr_c1 = ptr_c0 + ldc; + ptr_c2 = ptr_c1 + ldc; + ptr_c3 = ptr_c2 + ldc; + ptr_c += 4 * ldc; + ptr_a = (bfloat16_t *)A; + + for (BLASLONG i = 0; i < m / 8; i++) { + ptr_a0 = ptr_a; + ptr_a += 8 * pad_k; + + ptr_b0 = ptr_b; + + LOAD_C_8x4; + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + ma1 = svld1_bf16(pg16, ptr_a0 + 8); + ma2 = svld1_bf16(pg16, ptr_a0 + 16); + ma3 = svld1_bf16(pg16, ptr_a0 + 24); + + mb0 = svld1_bf16(pg16, ptr_b0); + mb1 = svld1_bf16(pg16, ptr_b0 + 8); + +#if 0 + for (int q = 0; q < 8; q++) { + float tmp = 0; + *((bfloat16_t *)(&tmp) + 1) = ptr_b0[8+q]; + printf("%.1f ", tmp); + } + printf("\n"); +#endif + + MATMUL(0, 0); MATMUL(0, 1); + MATMUL(1, 0); MATMUL(1, 1); + MATMUL(2, 0); MATMUL(2, 1); + MATMUL(3, 0); MATMUL(3, 1); + + ptr_a0 += 32; + ptr_b0 += 16; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp1(mc20, mc30); + vc2 = svuzp2(mc00, mc10); + vc3 = svuzp2(mc20, mc30); + vc4 = svuzp1(mc01, mc11); + vc5 = svuzp1(mc21, mc31); + vc6 = svuzp2(mc01, mc11); + vc7 = svuzp2(mc21, mc31); + + STORE_C(pg32, ptr_c0, oc0, vc0); + STORE_C(pg32, ptr_c0+4, oc1, vc1); + STORE_C(pg32, ptr_c1, oc2, vc2); + STORE_C(pg32, ptr_c1+4, oc3, vc3); + STORE_C(pg32, ptr_c2, oc4, vc4) + STORE_C(pg32, ptr_c2+4, oc5, vc5); + STORE_C(pg32, ptr_c3, oc6, vc6) + STORE_C(pg32, ptr_c3+4, oc7, vc7); + + ptr_c0 += 8; + ptr_c1 += 8; + ptr_c2 += 8; + ptr_c3 += 8; + } + + if (m & 4) { + ptr_a0 = ptr_a; + ptr_a += 4 * pad_k; + ptr_b0 = ptr_b; + + LOAD_C(0, 0); LOAD_C(0, 1); + LOAD_C(1, 0); LOAD_C(1, 1); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + ma1 = svld1_bf16(pg16, ptr_a0 + 8); + mb0 = svld1_bf16(pg16, ptr_b0); + mb1 = svld1_bf16(pg16, ptr_b0 + 8); + + MATMUL(0, 0); MATMUL(0, 1); + MATMUL(1, 0); MATMUL(1, 1); + + ptr_a0 += 16; + ptr_b0 += 16; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp2(mc00, mc10); + vc2 = svuzp1(mc01, mc11); + vc3 = svuzp2(mc01, mc11); + + STORE_C(pg32, ptr_c0, oc0, vc0); + STORE_C(pg32, ptr_c1, oc1, vc1); + STORE_C(pg32, ptr_c2, oc2, vc2); + STORE_C(pg32, ptr_c3, oc3, vc3); + + ptr_c0 += 4; + ptr_c1 += 4; + ptr_c2 += 4; + ptr_c3 += 4; + } + + if (m & 2) { + ptr_a0 = ptr_a; + ptr_a += 2 * pad_k; + ptr_b0 = ptr_b; + + LOAD_C(0, 0); LOAD_C(0, 1); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + mb0 = svld1_bf16(pg16, ptr_b0); + mb1 = svld1_bf16(pg16, ptr_b0 + 8); + + MATMUL(0, 0); MATMUL(0, 1); + + ptr_a0 += 8; + ptr_b0 += 16; + } + + vc0 = svuzp1(mc00, mc00); + vc1 = svuzp2(mc00, mc00); + vc2 = svuzp1(mc01, mc01); + vc3 = svuzp2(mc01, mc01); + + STORE_C(pg32_low, ptr_c0, oc0, vc0); + STORE_C(pg32_low, ptr_c1, oc1, vc1); + STORE_C(pg32_low, ptr_c2, oc2, vc2); + STORE_C(pg32_low, ptr_c3, oc3, vc3); + + ptr_c0 += 2; + ptr_c1 += 2; + ptr_c2 += 2; + ptr_c3 += 2; + } + + if (m & 1) { + ptr_a0 = ptr_a; + ptr_b0 = ptr_b; + + LOAD_C(0, 0); LOAD_C(0, 1); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_low, ptr_a0); + mb0 = svld1_bf16(pg16, ptr_b0); + mb1 = svld1_bf16(pg16, ptr_b0 + 8); + + MATMUL(0, 0); MATMUL(0, 1); + + ptr_a0 += 4; + ptr_b0 += 16; + } + + vc1 = svuzp2(mc00, mc00); + vc3 = svuzp2(mc01, mc01); + + STORE_C(pg32_first, ptr_c0, oc0, mc00); + STORE_C(pg32_first, ptr_c1, oc1, vc1); + STORE_C(pg32_first, ptr_c2, oc2, mc01); + STORE_C(pg32_first, ptr_c3, oc3, vc3); + + } + + ptr_b += 4 * pad_k; + } + + if (n & 2) { + ptr_c0 = ptr_c; + ptr_c1 = ptr_c0 + ldc; + ptr_c += 2 * ldc; + ptr_a = (bfloat16_t *)A; + + for (BLASLONG i = 0; i < m / 8; i++) { + ptr_a0 = ptr_a; + ptr_a += 8 * pad_k; + + ptr_b0 = ptr_b; + + LOAD_C(0, 0); + LOAD_C(1, 0); + LOAD_C(2, 0); + LOAD_C(3, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + ma1 = svld1_bf16(pg16, ptr_a0 + 8); + ma2 = svld1_bf16(pg16, ptr_a0 + 16); + ma3 = svld1_bf16(pg16, ptr_a0 + 24); + + mb0 = svld1_bf16(pg16, ptr_b0); + + MATMUL(0, 0); + MATMUL(1, 0); + MATMUL(2, 0); + MATMUL(3, 0); + + ptr_a0 += 32; + ptr_b0 += 8; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp1(mc20, mc30); + vc2 = svuzp2(mc00, mc10); + vc3 = svuzp2(mc20, mc30); + + STORE_C(pg32, ptr_c0, oc0, vc0); + STORE_C(pg32, ptr_c0 + 4, oc1, vc1); + STORE_C(pg32, ptr_c1, oc2, vc2); + STORE_C(pg32, ptr_c1 + 4, oc3, vc3); + + ptr_c0 += 8; + ptr_c1 += 8; + } + + if (m & 4) { + ptr_a0 = ptr_a; + ptr_a += 4 * pad_k; + ptr_b0 = ptr_b; + + LOAD_C(0, 0); + LOAD_C(1, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + ma1 = svld1_bf16(pg16, ptr_a0 + 8); + mb0 = svld1_bf16(pg16, ptr_b0); + MATMUL(0, 0); + MATMUL(1, 0); + ptr_a0 += 16; + ptr_b0 += 8; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp2(mc00, mc10); + + STORE_C(pg32, ptr_c0, oc0, vc0); + STORE_C(pg32, ptr_c1, oc1, vc1); + + ptr_c0 += 4; + ptr_c1 += 4; + } + + if (m & 2) { + ptr_a0 = ptr_a; + ptr_a += 2 * pad_k; + ptr_b0 = ptr_b; + + LOAD_C(0, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + mb0 = svld1_bf16(pg16, ptr_b0); + + MATMUL(0, 0); + + ptr_a0 += 8; + ptr_b0 += 8; + } + + vc0 = svuzp1(mc00, mc00); + vc1 = svuzp2(mc00, mc00); + STORE_C(pg32_low, ptr_c0, oc0, vc0); + STORE_C(pg32_low, ptr_c1, oc1, vc1); + + ptr_c0 += 2; + ptr_c1 += 2; + + } + + if (m & 1) { + ptr_a0 = ptr_a; + ptr_b0 = ptr_b; + LOAD_C(0, 0); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_low, ptr_a0); + mb0 = svld1_bf16(pg16, ptr_b0); + MATMUL(0, 0); + ptr_a0 += 4; + ptr_b0 += 8; + } + vc1 = svuzp2(mc00, mc00); + + STORE_C(pg32_first, ptr_c0, oc0, mc00); + STORE_C(pg32_first, ptr_c1, oc1, vc1); + } + + ptr_b += 2 * pad_k; + } + + if (n & 1) { + ptr_c0 = ptr_c; + ptr_a = (bfloat16_t *)A; + + for (BLASLONG i = 0; i < m / 8; i++) { + ptr_a0 = ptr_a; + ptr_a += 8 * pad_k; + + ptr_b0 = ptr_b; + + LOAD_C(0, 0); + LOAD_C(1, 0); + LOAD_C(2, 0); + LOAD_C(3, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + ma1 = svld1_bf16(pg16, ptr_a0 + 8); + ma2 = svld1_bf16(pg16, ptr_a0 + 16); + ma3 = svld1_bf16(pg16, ptr_a0 + 24); + + mb0 = svld1_bf16(pg16_low, ptr_b0); + + MATMUL(0, 0); + MATMUL(1, 0); + MATMUL(2, 0); + MATMUL(3, 0); + + ptr_a0 += 32; + ptr_b0 += 4; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp1(mc20, mc30); + + STORE_C(pg32, ptr_c0, oc0, vc0); + STORE_C(pg32, ptr_c0 + 4, oc1, vc1); + + ptr_c0 += 8; + } + + if (m & 4) { + ptr_a0 = ptr_a; + ptr_a += 4 * pad_k; + ptr_b0 = ptr_b; + LOAD_C(0, 0); + LOAD_C(1, 0); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + ma1 = svld1_bf16(pg16, ptr_a0 + 8); + mb0 = svld1_bf16(pg16_low, ptr_b0); + MATMUL(0, 0); + MATMUL(1, 0); + ptr_a0 += 16; + ptr_b0 += 4; + } + vc0 = svuzp1(mc00, mc10); + STORE_C(pg32, ptr_c0, oc0, vc0); + ptr_c0 += 4; + } + + if (m & 2) { + ptr_a0 = ptr_a; + ptr_a += 2 * pad_k; + ptr_b0 = ptr_b; + + LOAD_C(0, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16, ptr_a0); + mb0 = svld1_bf16(pg16_low, ptr_b0); + + MATMUL(0, 0); + + ptr_a0 += 8; + ptr_b0 += 4; + } + vc0 = svuzp1(mc00, mc00); + STORE_C(pg32_low, ptr_c0, oc0, vc0); + ptr_c0 += 2; + } + + if (m & 1) { + ptr_a0 = ptr_a; + ptr_b0 = ptr_b; + LOAD_C(0, 0); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_low, ptr_a0); + mb0 = svld1_bf16(pg16_low, ptr_b0); + MATMUL(0, 0); + ptr_a0 += 4; + ptr_b0 += 4; + } + STORE_C(pg32_first, ptr_c0, oc0, mc00); + } + } + + return 0; +} diff --git a/kernel/arm64/sbgemm_ncopy_4_neoversen2.c b/kernel/arm64/sbgemm_ncopy_4_neoversen2.c new file mode 100644 index 000000000..0b0e7a427 --- /dev/null +++ b/kernel/arm64/sbgemm_ncopy_4_neoversen2.c @@ -0,0 +1,137 @@ +/*************************************************************************** + * Copyright (c) 2022, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include + +#include "common.h" + +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { + IFLOAT *a_offset; + IFLOAT *a_offsetx[4]; + IFLOAT *b_offset; + a_offset = a; + b_offset = b; + + svbool_t pg16 = svdupq_b16(1, 1, 1, 1, 0, 0, 0, 0); + svbfloat16_t v0, v1, v2, v3; + + for (BLASLONG j = 0; j < n / 4; j++) { + a_offsetx[0] = a_offset; + a_offsetx[1] = a_offsetx[0] + lda; + a_offsetx[2] = a_offsetx[1] + lda; + a_offsetx[3] = a_offsetx[2] + lda; + a_offset += 4 * lda; + + for (BLASLONG i = 0; i < m / 4; i++) { + v0 = svld1_bf16(pg16, (bfloat16_t *)a_offsetx[0]); + v1 = svld1_bf16(pg16, (bfloat16_t *)a_offsetx[1]); + v2 = svld1_bf16(pg16, (bfloat16_t *)a_offsetx[2]); + v3 = svld1_bf16(pg16, (bfloat16_t *)a_offsetx[3]); + + svst1_bf16(pg16, (bfloat16_t *)b_offset, v0); + svst1_bf16(pg16, (bfloat16_t *)b_offset + 4, v1); + svst1_bf16(pg16, (bfloat16_t *)b_offset + 8, v2); + svst1_bf16(pg16, (bfloat16_t *)b_offset + 12, v3); + +#if 0 + for (int line = 0; line < 4; line++) { + for (int p = 0; p < 4; p++) { + float tmp = 0; + *((bfloat16 *)(&tmp) + 1) = b_offset[line * 4 + p]; + printf("%f ", tmp); + } + printf("\n"); + } +#endif + + b_offset += 16; + a_offsetx[0] += 4; + a_offsetx[1] += 4; + a_offsetx[2] += 4; + a_offsetx[3] += 4; + } + + if (m & 3) { + BLASLONG rest = m & 3; + for (BLASLONG col = 0; col < 4; col++) { + b_offset[4 * col] = a_offsetx[col][0]; + b_offset[4 * col + 1] = rest == 1 ? 0 : a_offsetx[col][1]; + b_offset[4 * col + 2] = rest <= 2 ? 0 : a_offsetx[col][2]; + b_offset[4 * col + 3] = rest <= 3 ? 0 : a_offsetx[col][3]; + } + b_offset += 16; + } + } + + if (n & 2) { + a_offsetx[0] = a_offset; + a_offsetx[1] = a_offsetx[0] + lda; + a_offset += 2 * lda; + + for (BLASLONG i = 0; i < m / 4; i++) { + v0 = svld1_bf16(pg16, (bfloat16_t *)a_offsetx[0]); + v1 = svld1_bf16(pg16, (bfloat16_t *)a_offsetx[1]); + svst1_bf16(pg16, (bfloat16_t *)b_offset, v0); + svst1_bf16(pg16, (bfloat16_t *)b_offset + 4, v1); + + b_offset += 8; + a_offsetx[0] += 4; + a_offsetx[1] += 4; + } + + if (m & 3) { + BLASLONG rest = m & 3; + for (BLASLONG col = 0; col < 2; col++) { + b_offset[4 * col] = a_offsetx[col][0]; + b_offset[4 * col + 1] = rest == 1 ? 0 : a_offsetx[col][1]; + b_offset[4 * col + 2] = rest <= 2 ? 0 : a_offsetx[col][2]; + b_offset[4 * col + 3] = rest <= 3 ? 0 : a_offsetx[col][3]; + } + b_offset += 8; + } + } + + if (n & 1) { + a_offsetx[0] = a_offset; + for (BLASLONG i = 0; i < m / 4; i++) { + v0 = svld1_bf16(pg16, (bfloat16_t *)a_offsetx[0]); + svst1_bf16(pg16, (bfloat16_t *)b_offset, v0); + b_offset += 4; + a_offsetx[0] += 4; + } + if (m & 3) { + BLASLONG rest = m & 3; + b_offset[0] = a_offsetx[0][0]; + b_offset[1] = rest == 1 ? 0 : a_offsetx[0][1]; + b_offset[2] = rest <= 2 ? 0 : a_offsetx[0][2]; + b_offset[3] = rest <= 3 ? 0 : a_offsetx[0][3]; + } + } + + return 0; +} diff --git a/kernel/arm64/sbgemm_tcopy_8_neoversen2.c b/kernel/arm64/sbgemm_tcopy_8_neoversen2.c new file mode 100644 index 000000000..6c37e4bcf --- /dev/null +++ b/kernel/arm64/sbgemm_tcopy_8_neoversen2.c @@ -0,0 +1,174 @@ +/*************************************************************************** + * Copyright (c) 2022, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include "common.h" + +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { + IFLOAT *a_offset, *a_offset0, *a_offset1, *a_offset2, *a_offset3; + IFLOAT *b_offset; + a_offset = a; + b_offset = b; + + for (BLASLONG j = 0; j < n / 8; j++) { + a_offset0 = a_offset; + a_offset1 = a_offset0 + lda; + a_offset2 = a_offset1 + lda; + a_offset3 = a_offset2 + lda; + a_offset += 8; + + for (BLASLONG i = 0; i < m / 4; i++) { + for (BLASLONG line = 0; line < 8; line++) { +#if 0 + float fv0 = 0, fv1 = 0, fv2 = 0, fv3 = 0; + *((bfloat16 *)(&fv0) + 1) = a_offset0[line]; + *((bfloat16 *)(&fv1) + 1) = a_offset1[line]; + *((bfloat16 *)(&fv2) + 1) = a_offset2[line]; + *((bfloat16 *)(&fv3) + 1) = a_offset3[line]; + printf("%f %f %f %f\n", fv0, fv1, fv2, fv3); +#endif + + b_offset[line * 4] = a_offset0[line]; + b_offset[line * 4 + 1] = a_offset1[line]; + b_offset[line * 4 + 2] = a_offset2[line]; + b_offset[line * 4 + 3] = a_offset3[line]; + } + + b_offset += 32; + a_offset0 += 4 * lda; + a_offset1 += 4 * lda; + a_offset2 += 4 * lda; + a_offset3 += 4 * lda; + } + + if (m & 3) { + BLASLONG rest = m & 3; + for (BLASLONG line = 0; line < 8; line++) { + b_offset[line * 4] = a_offset0[line]; + b_offset[line * 4 + 1] = rest == 1 ? 0 : a_offset1[line]; + b_offset[line * 4 + 2] = rest <= 2 ? 0 : a_offset2[line]; + b_offset[line * 4 + 3] = rest <= 3 ? 0 : a_offset3[line]; + } + b_offset += 32; + } + } + + if (n & 4) { + a_offset0 = a_offset; + a_offset1 = a_offset0 + lda; + a_offset2 = a_offset1 + lda; + a_offset3 = a_offset2 + lda; + a_offset += 4; + + for (BLASLONG i = 0; i < m / 4; i++) { + for (BLASLONG line = 0; line < 4; line++) { + b_offset[line * 4] = a_offset0[line]; + b_offset[line * 4 + 1] = a_offset1[line]; + b_offset[line * 4 + 2] = a_offset2[line]; + b_offset[line * 4 + 3] = a_offset3[line]; + } + + b_offset += 16; + a_offset0 += 4 * lda; + a_offset1 += 4 * lda; + a_offset2 += 4 * lda; + a_offset3 += 4 * lda; + } + + if (m & 3) { + BLASLONG rest = m & 3; + for (BLASLONG line = 0; line < 4; line++) { + b_offset[line * 4] = a_offset0[line]; + b_offset[line * 4 + 1] = rest == 1 ? 0 : a_offset1[line]; + b_offset[line * 4 + 2] = rest <= 2 ? 0 : a_offset2[line]; + b_offset[line * 4 + 3] = rest <= 3 ? 0 : a_offset3[line]; + } + b_offset += 16; + } + } + + if (n & 2) { + a_offset0 = a_offset; + a_offset1 = a_offset0 + lda; + a_offset2 = a_offset1 + lda; + a_offset3 = a_offset2 + lda; + a_offset += 2; + + for (BLASLONG i = 0; i < m / 4; i++) { + for (BLASLONG line = 0; line < 2; line++) { + b_offset[line * 4] = a_offset0[line]; + b_offset[line * 4 + 1] = a_offset1[line]; + b_offset[line * 4 + 2] = a_offset2[line]; + b_offset[line * 4 + 3] = a_offset3[line]; + } + b_offset += 8; + a_offset0 += 4 * lda; + a_offset1 += 4 * lda; + a_offset2 += 4 * lda; + a_offset3 += 4 * lda; + } + + if (m & 3) { + BLASLONG rest = m & 3; + for (BLASLONG line = 0; line < 2; line++) { + b_offset[line * 4] = a_offset0[line]; + b_offset[line * 4 + 1] = rest == 1 ? 0 : a_offset1[line]; + b_offset[line * 4 + 2] = rest <= 2 ? 0 : a_offset2[line]; + b_offset[line * 4 + 3] = rest <= 3 ? 0 : a_offset3[line]; + } + b_offset += 8; + } + } + + if (n & 1) { + a_offset0 = a_offset; + a_offset1 = a_offset0 + lda; + a_offset2 = a_offset1 + lda; + a_offset3 = a_offset2 + lda; + + for (BLASLONG i = 0; i < m / 4; i++) { + b_offset[0] = *a_offset0; + b_offset[1] = *a_offset1; + b_offset[2] = *a_offset2; + b_offset[3] = *a_offset3; + b_offset += 4; + a_offset0 += 4 * lda; + a_offset1 += 4 * lda; + a_offset2 += 4 * lda; + a_offset3 += 4 * lda; + } + + if (m & 3) { + BLASLONG rest = m & 3; + b_offset[0] = *a_offset0; + b_offset[1] = rest == 1 ? 0 : *a_offset1; + b_offset[2] = rest <= 2 ? 0 : *a_offset2; + b_offset[3] = rest <= 3 ? 0 : *a_offset3; + } + } + return 0; +} diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index 8bcd31ef2..010c39bd4 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -866,8 +866,9 @@ gotoblas_t TABLE_NAME = { cgeadd_kTS, #endif #if BUILD_COMPLEX16==1 - zgeadd_kTS + zgeadd_kTS, #endif + 0, // padding_k }; #if (ARCH_ARM64) @@ -972,6 +973,12 @@ static void init_parameter(void) { TABLE_NAME.xgemm3m_r = TABLE_NAME.qgemm_r; #endif #endif + +#if defined(NEOVERSEN2) && BUILD_BFLOAT16 == 1 + TABLE_NAME.align_k = 4; +#else + TABLE_NAME.align_k = 1; +#endif } #else // (ARCH_ARM64)