209 lines
7.6 KiB
C
209 lines
7.6 KiB
C
/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
|
|
|
|
#include <stdio.h>
|
|
#include <immintrin.h>
|
|
#include "common.h"
|
|
|
|
/*
 * REORDER_4x32(r0, r1, r2, r3)
 *
 * In-place transpose of a 4 x 16 matrix of 32-bit lanes held in four
 * __m512i lvalues.  On entry r<c> holds lanes c[0..15]; on exit the
 * concatenation r0|r1|r2|r3 holds, for k = 0..15 in order, the four lanes
 *     r0_in[k], r1_in[k], r2_in[k], r3_in[k].
 *
 * Stage 1 (unpack*_epi32 / unpack*_epi64) transposes 4x4 inside every
 * 128-bit lane; stage 2 and 3 (permutex2var) move whole 128-bit and then
 * 256-bit groups into their final position.
 *
 * The macro reads four index vectors that must be in scope at the point of
 * expansion: idx_lo_128, idx_hi_128, idx_lo_256 and idx_hi_256.
 *
 * All arguments must be plain lvalues; each is evaluated more than once.
 * Wrapped in do { } while (0) so that `REORDER_4x32(...);` is a single
 * statement (safe as the body of an unbraced if/else).  The temporaries use
 * reserved-looking suffixes so they cannot shadow the caller's arguments.
 */
#define REORDER_4x32(r0, r1, r2, r3) \
    do { \
        __m512i reorder32_t0_, reorder32_t1_, reorder32_t2_, reorder32_t3_; \
        /* stage 1: 4x4 transpose of 32-bit lanes inside each 128-bit lane */ \
        reorder32_t0_ = _mm512_unpacklo_epi32(r0, r1); \
        reorder32_t1_ = _mm512_unpackhi_epi32(r0, r1); \
        reorder32_t2_ = _mm512_unpacklo_epi32(r2, r3); \
        reorder32_t3_ = _mm512_unpackhi_epi32(r2, r3); \
        r0 = _mm512_unpacklo_epi64(reorder32_t0_, reorder32_t2_); \
        r1 = _mm512_unpackhi_epi64(reorder32_t0_, reorder32_t2_); \
        r2 = _mm512_unpacklo_epi64(reorder32_t1_, reorder32_t3_); \
        r3 = _mm512_unpackhi_epi64(reorder32_t1_, reorder32_t3_); \
        /* stage 2: interleave 128-bit groups of two registers */ \
        reorder32_t0_ = _mm512_permutex2var_epi32(r0, idx_lo_128, r1); \
        reorder32_t1_ = _mm512_permutex2var_epi32(r0, idx_hi_128, r1); \
        reorder32_t2_ = _mm512_permutex2var_epi32(r2, idx_lo_128, r3); \
        reorder32_t3_ = _mm512_permutex2var_epi32(r2, idx_hi_128, r3); \
        /* stage 3: gather the low / high 256-bit halves */ \
        r0 = _mm512_permutex2var_epi32(reorder32_t0_, idx_lo_256, reorder32_t2_); \
        r1 = _mm512_permutex2var_epi32(reorder32_t1_, idx_lo_256, reorder32_t3_); \
        r2 = _mm512_permutex2var_epi32(reorder32_t0_, idx_hi_256, reorder32_t2_); \
        r3 = _mm512_permutex2var_epi32(reorder32_t1_, idx_hi_256, reorder32_t3_); \
    } while (0)
|
|
|
|
/*
 * REORDER_4x8(r0, r1, r2, r3)
 *
 * In-place 4 x 4 transpose of the 32-bit lanes of four __m128i lvalues.
 * On exit r<k> holds lane k of the original r0, r1, r2, r3 (in that order).
 * Viewed as 16-bit elements, r<k> becomes
 *     r0_in[2k], r0_in[2k+1], r1_in[2k], r1_in[2k+1], ... , r3_in[2k+1].
 *
 * All arguments must be plain lvalues; each is evaluated more than once.
 * Wrapped in do { } while (0) so that `REORDER_4x8(...);` is a single
 * statement (safe as the body of an unbraced if/else).  The temporaries use
 * reserved-looking suffixes so they cannot shadow the caller's arguments.
 */
#define REORDER_4x8(r0, r1, r2, r3) \
    do { \
        __m128i reorder8_t0_, reorder8_t1_, reorder8_t2_, reorder8_t3_; \
        reorder8_t0_ = _mm_unpacklo_epi32(r0, r1); \
        reorder8_t1_ = _mm_unpackhi_epi32(r0, r1); \
        reorder8_t2_ = _mm_unpacklo_epi32(r2, r3); \
        reorder8_t3_ = _mm_unpackhi_epi32(r2, r3); \
        r0 = _mm_unpacklo_epi64(reorder8_t0_, reorder8_t2_); \
        r1 = _mm_unpackhi_epi64(reorder8_t0_, reorder8_t2_); \
        r2 = _mm_unpacklo_epi64(reorder8_t1_, reorder8_t3_); \
        r3 = _mm_unpackhi_epi64(reorder8_t1_, reorder8_t3_); \
    } while (0)
|
|
|
|
/*
 * GET_TAIL(tail, remain_m)
 *
 * Selects the register that holds the last (unpaired) element when
 * remain_m is odd: ceil(remain_m / 2) == 1, 2, 3, 4 picks r0, r1, r2, r3
 * respectively and assigns it to `tail`.
 *
 * The macro reads the __m128i variables r0, r1, r2 and r3 from the scope in
 * which it is expanded.  `remain_m` is parenthesised so that an expression
 * argument (e.g. `m & 7`) is evaluated before the `+ 1`.  Any other selector
 * value yields an all-zero vector so that `tail` is never left unset.
 * Wrapped in do { } while (0) so the expansion is a single statement.
 */
#define GET_TAIL(tail, remain_m) \
    do { \
        switch (((remain_m) + 1) / 2) { \
        case 1: (tail) = r0; break; \
        case 2: (tail) = r1; break; \
        case 3: (tail) = r2; break; \
        case 4: (tail) = r3; break; \
        default: (tail) = _mm_setzero_si128(); break; \
        } \
    } while (0)
|
|
|
|
/*
 * Packs an m x n panel of 16-bit elements from `a` into the contiguous
 * buffer `b`.
 *
 * Input:  vector j (0 <= j < n) starts at a + j * lda and supplies m
 *         consecutive elements.  (The element type IFLOAT is handled as a
 *         16-bit quantity here; the tail comment calls it bfloat.)
 * Output: vectors are consumed in groups of four.  Within a group the
 *         elements are emitted pair by pair: for k = 0, 1, ...
 *             v0[2k] v0[2k+1] v1[2k] v1[2k+1] v2[2k] v2[2k+1] v3[2k] v3[2k+1]
 *         If m is odd, the last element of each vector is emitted on its
 *         own as v0[m-1] v1[m-1] v2[m-1] v3[m-1].
 *         The final 1..3 vectors (n % 4) use the same scheme with only
 *         remain_n vectors per pair.
 *
 * The variable names idx_lo_128 / idx_hi_128 / idx_lo_256 / idx_hi_256 and
 * r0..r3 are part of the contract with the REORDER_4x32 and GET_TAIL macros,
 * which read them from the expansion scope.
 *
 * Always returns 0.
 */
int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){
    BLASLONG i, j;
    IFLOAT *aoffset = a;
    IFLOAT *aoffset0, *aoffset1, *aoffset2, *aoffset3;
    IFLOAT *boffset = b;

    BLASLONG m32 = m & ~31;   /* m rounded down to a multiple of 32 */
    BLASLONG m8 = m & ~7;     /* m rounded down to a multiple of 8  */
    BLASLONG n4 = n & ~3;     /* n rounded down to a multiple of 4  */

    /* Index vectors for _mm512_permutex2var_epi32; indices >= 0x10 select
     * from the second source operand.  static const: built once, not on
     * every call. */
    static const int permute_table[] = {
        0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b,
        0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f,
        0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
        0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
    };
    __m512i idx_lo_128 = _mm512_loadu_si512(permute_table);
    __m512i idx_hi_128 = _mm512_loadu_si512(permute_table + 16);
    __m512i idx_lo_256 = _mm512_loadu_si512(permute_table + 32);
    __m512i idx_hi_256 = _mm512_loadu_si512(permute_table + 48);

    /* ---- full groups of four vectors ---- */
    for (j = 0; j < n4; j += 4) {
        aoffset0 = aoffset;
        aoffset1 = aoffset0 + lda;
        aoffset2 = aoffset1 + lda;
        aoffset3 = aoffset2 + lda;
        aoffset += 4 * lda;

        /* 32 elements of each vector per iteration */
        for (i = 0; i < m32; i += 32) {
            __m512i r0 = _mm512_loadu_si512(aoffset0 + i);
            __m512i r1 = _mm512_loadu_si512(aoffset1 + i);
            __m512i r2 = _mm512_loadu_si512(aoffset2 + i);
            __m512i r3 = _mm512_loadu_si512(aoffset3 + i);
            REORDER_4x32(r0, r1, r2, r3);
            _mm512_storeu_si512(boffset + 32*0, r0);
            _mm512_storeu_si512(boffset + 32*1, r1);
            _mm512_storeu_si512(boffset + 32*2, r2);
            _mm512_storeu_si512(boffset + 32*3, r3);
            boffset += 32 * 4;
        }

        /* 8 elements of each vector per iteration */
        for (; i < m8; i += 8) {
            __m128i r0 = _mm_loadu_si128((const __m128i *)(aoffset0 + i));
            __m128i r1 = _mm_loadu_si128((const __m128i *)(aoffset1 + i));
            __m128i r2 = _mm_loadu_si128((const __m128i *)(aoffset2 + i));
            __m128i r3 = _mm_loadu_si128((const __m128i *)(aoffset3 + i));
            REORDER_4x8(r0, r1, r2, r3);
            _mm_storeu_si128((__m128i *)(boffset + 8*0), r0);
            _mm_storeu_si128((__m128i *)(boffset + 8*1), r1);
            _mm_storeu_si128((__m128i *)(boffset + 8*2), r2);
            _mm_storeu_si128((__m128i *)(boffset + 8*3), r3);
            boffset += 8 * 4;
        }

        /* 1..7 remaining elements: masked load zero-fills the rest */
        if (i < m) {
            int remain_m = m - i;
            __mmask8 r_mask = (1UL << remain_m) - 1;
            __m128i r0 = _mm_maskz_loadu_epi16(r_mask, aoffset0 + i);
            __m128i r1 = _mm_maskz_loadu_epi16(r_mask, aoffset1 + i);
            __m128i r2 = _mm_maskz_loadu_epi16(r_mask, aoffset2 + i);
            __m128i r3 = _mm_maskz_loadu_epi16(r_mask, aoffset3 + i);
            REORDER_4x8(r0, r1, r2, r3);

            /* store only the complete pairs; the odd last element (if any)
             * is handled separately below */
            int num_store = remain_m/2;
            switch(num_store) {
            case 3: _mm_storeu_si128((__m128i *)(boffset + 8*2), r2);
                /* fallthrough */
            case 2: _mm_storeu_si128((__m128i *)(boffset + 8*1), r1);
                /* fallthrough */
            case 1: _mm_storeu_si128((__m128i *)(boffset + 8*0), r0);
                break;
            default:
                break;
            }
            boffset += 8 * num_store;

            if (m & 0x1) { /* handling the tail */
                __m128i tail;
                GET_TAIL(tail, remain_m);
                /* The tail vector is zero-filled like:
                 *     a, 0, b, 0, c, 0, d, 0
                 * so narrow each 32-bit lane to its low 16-bit word and
                 * store the resulting four elements.
                 */
                tail = _mm_cvtepi32_epi16(tail);
                /* only the lower 64 bits (4 elements) are valid */
                _mm_storel_epi64((__m128i *)boffset, tail);
                boffset += 4;
            }
        }
    }

    /* ---- remaining 1..3 vectors ---- */
    if (j < n) {
        int remain_n = n - j;
        /* one mask bit per remaining vector; used for 32-bit (pair) stores
         * and for the 16-bit tail store */
        __mmask8 nmask = (1UL << remain_n) - 1;
        aoffset0 = aoffset;
        aoffset1 = aoffset0 + lda;
        aoffset2 = aoffset1 + lda;
        aoffset3 = aoffset2 + lda;

        for (i = 0; i < m8; i += 8) {
            /* Registers for absent vectors stay zero: their lanes are
             * masked out on store, but they must not be read while
             * uninitialised by REORDER_4x8. */
            __m128i r0 = _mm_setzero_si128();
            __m128i r1 = _mm_setzero_si128();
            __m128i r2 = _mm_setzero_si128();
            __m128i r3 = _mm_setzero_si128();
            switch (remain_n) {
            case 3: r2 = _mm_loadu_si128((const __m128i *)(aoffset2 + i));
                /* fallthrough */
            case 2: r1 = _mm_loadu_si128((const __m128i *)(aoffset1 + i));
                /* fallthrough */
            case 1: r0 = _mm_loadu_si128((const __m128i *)(aoffset0 + i));
                break;
            default:
                break;
            }
            REORDER_4x8(r0, r1, r2, r3);
            /* each register contributes remain_n pairs = 2 * remain_n elements */
            _mm_mask_storeu_epi32(boffset + remain_n * 0, nmask, r0);
            _mm_mask_storeu_epi32(boffset + remain_n * 2, nmask, r1);
            _mm_mask_storeu_epi32(boffset + remain_n * 4, nmask, r2);
            _mm_mask_storeu_epi32(boffset + remain_n * 6, nmask, r3);
            boffset += 8 * remain_n;
        }

        if (i < m) {
            int remain_m = m - i;
            __mmask8 mmask = (1UL << remain_m) - 1;
            __m128i r0 = _mm_setzero_si128();
            __m128i r1 = _mm_setzero_si128();
            __m128i r2 = _mm_setzero_si128();
            __m128i r3 = _mm_setzero_si128();
            switch (remain_n) {
            case 3: r2 = _mm_maskz_loadu_epi16(mmask, aoffset2 + i);
                /* fallthrough */
            case 2: r1 = _mm_maskz_loadu_epi16(mmask, aoffset1 + i);
                /* fallthrough */
            case 1: r0 = _mm_maskz_loadu_epi16(mmask, aoffset0 + i);
                break;
            default:
                break;
            }
            REORDER_4x8(r0, r1, r2, r3);

            /* complete pairs only; odd last element handled below */
            int num_store = remain_m/2;
            switch (num_store) {
            case 3: _mm_mask_storeu_epi32(boffset + remain_n * 4, nmask, r2);
                /* fallthrough */
            case 2: _mm_mask_storeu_epi32(boffset + remain_n * 2, nmask, r1);
                /* fallthrough */
            case 1: _mm_mask_storeu_epi32(boffset + remain_n * 0, nmask, r0);
                break;
            default:
                break;
            }
            boffset += 2 * num_store * remain_n;

            if (m & 0x1) {
                __m128i tail;
                GET_TAIL(tail, remain_m);
                /* keep the low 16-bit word of every 32-bit lane, then write
                 * one element per remaining vector */
                tail = _mm_cvtepi32_epi16(tail);
                _mm_mask_storeu_epi16(boffset, nmask, tail);
            }
        }
    }
    return 0;
}
|