From 135718eafce6b42c62ae1080a55e80c819dd4788 Mon Sep 17 00:00:00 2001 From: lilianhuang Date: Mon, 28 Nov 2022 04:17:54 -0500 Subject: [PATCH] Improve the performance of sbgemm_tcopy on neoversen2 --- kernel/arm64/sbgemm_tcopy_8_neoversen2.c | 58 +++++++++++++++++++----- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/kernel/arm64/sbgemm_tcopy_8_neoversen2.c b/kernel/arm64/sbgemm_tcopy_8_neoversen2.c index a058b5a8e..459dfa16a 100644 --- a/kernel/arm64/sbgemm_tcopy_8_neoversen2.c +++ b/kernel/arm64/sbgemm_tcopy_8_neoversen2.c @@ -25,6 +25,7 @@ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. * *****************************************************************************/ +#include #include "common.h" @@ -34,6 +35,9 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { a_offset = a; b_offset = b; + uint16x8_t v0, v1, v2, v3, v4, v5, v6, v7; + uint16x4_t v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; + for (BLASLONG j = 0; j < n / 8; j++) { a_offset0 = a_offset; a_offset1 = a_offset0 + lda; @@ -42,12 +46,29 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { a_offset += 8; for (BLASLONG i = 0; i < m / 4; i++) { - for (BLASLONG line = 0; line < 8; line++) { - b_offset[line * 4] = a_offset0[line]; - b_offset[line * 4 + 1] = a_offset1[line]; - b_offset[line * 4 + 2] = a_offset2[line]; - b_offset[line * 4 + 3] = a_offset3[line]; - } + v0 = vld1q_u16(a_offset0); + v1 = vld1q_u16(a_offset1); + v2 = vld1q_u16(a_offset2); + v3 = vld1q_u16(a_offset3); + + v4 = vtrn1q_u16(v0, v1); + v5 = vtrn2q_u16(v0, v1); + v6 = vtrn1q_u16(v2, v3); + v7 = vtrn2q_u16(v2, v3); + + v0 = (uint16x8_t)vtrn1q_u32((uint32x4_t)v4, (uint32x4_t)v6); + v1 = (uint16x8_t)vtrn1q_u32((uint32x4_t)v5, (uint32x4_t)v7); + v2 = (uint16x8_t)vtrn2q_u32((uint32x4_t)v4, (uint32x4_t)v6); + v3 = (uint16x8_t)vtrn2q_u32((uint32x4_t)v5, (uint32x4_t)v7); + + vst1_u16(b_offset, vget_low_u16(v0)); + vst1_u16(b_offset + 4, vget_low_u16(v1)); + vst1_u16(b_offset + 8, vget_low_u16(v2)); + vst1_u16(b_offset + 12, vget_low_u16(v3)); + vst1_u16(b_offset + 16, vget_high_u16(v0)); + vst1_u16(b_offset + 20, vget_high_u16(v1)); + vst1_u16(b_offset + 24, vget_high_u16(v2)); + vst1_u16(b_offset + 28, vget_high_u16(v3)); b_offset += 32; a_offset0 += 4 * lda; @@ -76,12 +97,25 @@ int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) { a_offset += 4; for (BLASLONG i = 0; i < m / 4; i++) { - for (BLASLONG line = 0; line < 4; line++) { - b_offset[line * 4] = a_offset0[line]; - b_offset[line * 4 + 1] = a_offset1[line]; - b_offset[line * 4 + 2] = a_offset2[line]; - b_offset[line * 4 + 3] = a_offset3[line]; - } + v0_h = vld1_u16(a_offset0); + v1_h = vld1_u16(a_offset1); + v2_h = vld1_u16(a_offset2); + v3_h = vld1_u16(a_offset3); + + v4_h = vtrn1_u16(v0_h, v1_h); + v5_h = vtrn2_u16(v0_h, v1_h); + v6_h = vtrn1_u16(v2_h, v3_h); + v7_h = vtrn2_u16(v2_h, v3_h); + + v0_h = (uint16x4_t)vtrn1_u32((uint32x2_t)v4_h, (uint32x2_t)v6_h); + v1_h = (uint16x4_t)vtrn1_u32((uint32x2_t)v5_h, (uint32x2_t)v7_h); + v2_h = (uint16x4_t)vtrn2_u32((uint32x2_t)v4_h, (uint32x2_t)v6_h); + v3_h = (uint16x4_t)vtrn2_u32((uint32x2_t)v5_h, (uint32x2_t)v7_h); + + vst1_u16(b_offset, v0_h); + vst1_u16(b_offset + 4, v1_h); + vst1_u16(b_offset + 8, v2_h); + vst1_u16(b_offset + 12, v3_h); b_offset += 16; a_offset0 += 4 * lda;