1180 lines
29 KiB
C
1180 lines
29 KiB
C
/*********************************************************************************
|
|
Copyright (c) 2013, The OpenBLAS Project
|
|
All rights reserved.
|
|
Redistribution and use in source and binary forms, with or without
|
|
modification, are permitted provided that the following conditions are
|
|
met:
|
|
1. Redistributions of source code must retain the above copyright
|
|
notice, this list of conditions and the following disclaimer.
|
|
2. Redistributions in binary form must reproduce the above copyright
|
|
notice, this list of conditions and the following disclaimer in
|
|
the documentation and/or other materials provided with the
|
|
distribution.
|
|
3. Neither the name of the OpenBLAS project nor the names of
|
|
its contributors may be used to endorse or promote products
|
|
derived from this software without specific prior written permission.
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
|
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
|
|
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
|
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
**********************************************************************************/
|
|
|
|
|
|
/* The comment below is kept for historical reference; its parameters and figures do not describe the implementation in this file. */
|
|
|
|
/*********************************************************************
|
|
* 2014/07/28 Saar
|
|
* BLASTEST : OK
|
|
* CTEST : OK
|
|
* TEST : OK
|
|
*
|
|
* 2013/10/28 Saar
|
|
* Parameter:
|
|
* SGEMM_DEFAULT_UNROLL_N 4
|
|
* SGEMM_DEFAULT_UNROLL_M 16
|
|
* SGEMM_DEFAULT_P 768
|
|
* SGEMM_DEFAULT_Q 384
|
|
* A_PR1 512
|
|
* B_PR1 512
|
|
*
|
|
*
|
|
* 2014/07/28 Saar
|
|
* Performance at 9216x9216x9216:
|
|
* 1 thread: 102 GFLOPS (SANDYBRIDGE: 59) (MKL: 83)
|
|
* 2 threads: 195 GFLOPS (SANDYBRIDGE: 116) (MKL: 155)
|
|
* 3 threads: 281 GFLOPS (SANDYBRIDGE: 165) (MKL: 230)
|
|
* 4 threads: 366 GFLOPS (SANDYBRIDGE: 223) (MKL: 267)
|
|
*
|
|
*********************************************************************/
|
|
|
|
#include "common.h"
|
|
#include <immintrin.h>
|
|
|
|
|
|
|
|
/*******************************************************************************************
* 8 lines of N
* (no 8-column micro-kernels are implemented in this file)
*******************************************************************************************/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*******************************************************************************************
|
|
* 4 lines of N
|
|
*******************************************************************************************/
|
|
|
|
/*
 * 64x4 micro-tile (AVX-512): 64 rows of packed A against 4 values of packed B
 * per k step.
 *
 * Accumulator naming: rowJ holds tile rows 0-15 for output offset J*ldc
 * (J = 0..3), rowJb rows 16-31, rowJc rows 32-47 and rowJd rows 48-63.
 *
 * These macros expand to plain statement sequences that already end in ';',
 * so call sites invoke them without a trailing semicolon.  They expect these
 * names in the enclosing scope: AO, A1, A2, A3, BO, CO1, ldc, zmm0, zmm1,
 * zmm2, zmm3, zmm5, zmm7 and the 16 row accumulators.
 */

/* Zero all 16 accumulators. */
#define INIT64x4() \
	row0  = _mm512_setzero_ps(); \
	row1  = _mm512_setzero_ps(); \
	row2  = _mm512_setzero_ps(); \
	row3  = _mm512_setzero_ps(); \
	row0b = _mm512_setzero_ps(); \
	row1b = _mm512_setzero_ps(); \
	row2b = _mm512_setzero_ps(); \
	row3b = _mm512_setzero_ps(); \
	row0c = _mm512_setzero_ps(); \
	row1c = _mm512_setzero_ps(); \
	row2c = _mm512_setzero_ps(); \
	row3c = _mm512_setzero_ps(); \
	row0d = _mm512_setzero_ps(); \
	row1d = _mm512_setzero_ps(); \
	row2d = _mm512_setzero_ps(); \
	row3d = _mm512_setzero_ps();

/*
 * One k step: load 16 floats from each of the four A streams (AO, A1, A2, A3),
 * broadcast the 4 B values at BO, accumulate the 64x4 outer product, then
 * advance BO by 4 and every A stream by 16.
 */
#define KERNEL64x4_SUB() \
	zmm0 = _mm512_loadu_ps(AO); \
	zmm1 = _mm512_loadu_ps(A1); \
	zmm5 = _mm512_loadu_ps(A2); \
	zmm7 = _mm512_loadu_ps(A3); \
	zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO)); \
	zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+1)); \
	row0  += zmm0 * zmm2; \
	row1  += zmm0 * zmm3; \
	row0b += zmm1 * zmm2; \
	row1b += zmm1 * zmm3; \
	row0c += zmm5 * zmm2; \
	row1c += zmm5 * zmm3; \
	row0d += zmm7 * zmm2; \
	row1d += zmm7 * zmm3; \
	zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO+2)); \
	zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+3)); \
	row2  += zmm0 * zmm2; \
	row3  += zmm0 * zmm3; \
	row2b += zmm1 * zmm2; \
	row3b += zmm1 * zmm3; \
	row2c += zmm5 * zmm2; \
	row3c += zmm5 * zmm3; \
	row2d += zmm7 * zmm2; \
	row3d += zmm7 * zmm3; \
	BO += 4; \
	AO += 16; \
	A1 += 16; \
	A2 += 16; \
	A3 += 16;

/*
 * C += ALPHA * acc for the 64x4 tile at CO1.  Output J starts at CO1 + J*ldc;
 * its four 16-float segments are at offsets 0, 16, 32 and 48.  Unaligned
 * loads/stores are used, so CO1 needs no particular alignment.
 */
#define SAVE64x4(ALPHA) \
	zmm0 = _mm512_set1_ps(ALPHA); \
	row0  *= zmm0; \
	row1  *= zmm0; \
	row2  *= zmm0; \
	row3  *= zmm0; \
	row0b *= zmm0; \
	row1b *= zmm0; \
	row2b *= zmm0; \
	row3b *= zmm0; \
	row0c *= zmm0; \
	row1c *= zmm0; \
	row2c *= zmm0; \
	row3c *= zmm0; \
	row0d *= zmm0; \
	row1d *= zmm0; \
	row2d *= zmm0; \
	row3d *= zmm0; \
	row0 += _mm512_loadu_ps(CO1 + 0*ldc); \
	row1 += _mm512_loadu_ps(CO1 + 1*ldc); \
	row2 += _mm512_loadu_ps(CO1 + 2*ldc); \
	row3 += _mm512_loadu_ps(CO1 + 3*ldc); \
	_mm512_storeu_ps(CO1 + 0*ldc, row0); \
	_mm512_storeu_ps(CO1 + 1*ldc, row1); \
	_mm512_storeu_ps(CO1 + 2*ldc, row2); \
	_mm512_storeu_ps(CO1 + 3*ldc, row3); \
	row0b += _mm512_loadu_ps(CO1 + 0*ldc + 16); \
	row1b += _mm512_loadu_ps(CO1 + 1*ldc + 16); \
	row2b += _mm512_loadu_ps(CO1 + 2*ldc + 16); \
	row3b += _mm512_loadu_ps(CO1 + 3*ldc + 16); \
	_mm512_storeu_ps(CO1 + 0*ldc + 16, row0b); \
	_mm512_storeu_ps(CO1 + 1*ldc + 16, row1b); \
	_mm512_storeu_ps(CO1 + 2*ldc + 16, row2b); \
	_mm512_storeu_ps(CO1 + 3*ldc + 16, row3b); \
	row0c += _mm512_loadu_ps(CO1 + 0*ldc + 32); \
	row1c += _mm512_loadu_ps(CO1 + 1*ldc + 32); \
	row2c += _mm512_loadu_ps(CO1 + 2*ldc + 32); \
	row3c += _mm512_loadu_ps(CO1 + 3*ldc + 32); \
	_mm512_storeu_ps(CO1 + 0*ldc + 32, row0c); \
	_mm512_storeu_ps(CO1 + 1*ldc + 32, row1c); \
	_mm512_storeu_ps(CO1 + 2*ldc + 32, row2c); \
	_mm512_storeu_ps(CO1 + 3*ldc + 32, row3c); \
	row0d += _mm512_loadu_ps(CO1 + 0*ldc + 48); \
	row1d += _mm512_loadu_ps(CO1 + 1*ldc + 48); \
	row2d += _mm512_loadu_ps(CO1 + 2*ldc + 48); \
	row3d += _mm512_loadu_ps(CO1 + 3*ldc + 48); \
	_mm512_storeu_ps(CO1 + 0*ldc + 48, row0d); \
	_mm512_storeu_ps(CO1 + 1*ldc + 48, row1d); \
	_mm512_storeu_ps(CO1 + 2*ldc + 48, row2d); \
	_mm512_storeu_ps(CO1 + 3*ldc + 48, row3d);
/*
 * 48x4 micro-tile (AVX-512): same layout as the 64x4 tile but with three
 * 16-row A streams (AO, A1, A2) and accumulators rowJ, rowJb, rowJc
 * (tile rows 0-15, 16-31, 32-47 for output offset J*ldc).
 *
 * NOTE(review): the kernel driver in this file does not invoke these macros;
 * leftover rows are handled by the 32- and 16-row tiles instead.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, A1, A2, BO, CO1, ldc, zmm0, zmm1, zmm2, zmm3, zmm5
 * and the 12 row accumulators.
 */

/* Zero all 12 accumulators. */
#define INIT48x4() \
	row0  = _mm512_setzero_ps(); \
	row1  = _mm512_setzero_ps(); \
	row2  = _mm512_setzero_ps(); \
	row3  = _mm512_setzero_ps(); \
	row0b = _mm512_setzero_ps(); \
	row1b = _mm512_setzero_ps(); \
	row2b = _mm512_setzero_ps(); \
	row3b = _mm512_setzero_ps(); \
	row0c = _mm512_setzero_ps(); \
	row1c = _mm512_setzero_ps(); \
	row2c = _mm512_setzero_ps(); \
	row3c = _mm512_setzero_ps();

/* One k step: 16 floats from each of AO, A1, A2 times 4 broadcast B values. */
#define KERNEL48x4_SUB() \
	zmm0 = _mm512_loadu_ps(AO); \
	zmm1 = _mm512_loadu_ps(A1); \
	zmm5 = _mm512_loadu_ps(A2); \
	zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO)); \
	zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+1)); \
	row0  += zmm0 * zmm2; \
	row1  += zmm0 * zmm3; \
	row0b += zmm1 * zmm2; \
	row1b += zmm1 * zmm3; \
	row0c += zmm5 * zmm2; \
	row1c += zmm5 * zmm3; \
	zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO+2)); \
	zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+3)); \
	row2  += zmm0 * zmm2; \
	row3  += zmm0 * zmm3; \
	row2b += zmm1 * zmm2; \
	row3b += zmm1 * zmm3; \
	row2c += zmm5 * zmm2; \
	row3c += zmm5 * zmm3; \
	BO += 4; \
	AO += 16; \
	A1 += 16; \
	A2 += 16;

/* C += ALPHA * acc for the 48x4 tile; segments at offsets 0, 16, 32 of CO1 + J*ldc. */
#define SAVE48x4(ALPHA) \
	zmm0 = _mm512_set1_ps(ALPHA); \
	row0  *= zmm0; \
	row1  *= zmm0; \
	row2  *= zmm0; \
	row3  *= zmm0; \
	row0b *= zmm0; \
	row1b *= zmm0; \
	row2b *= zmm0; \
	row3b *= zmm0; \
	row0c *= zmm0; \
	row1c *= zmm0; \
	row2c *= zmm0; \
	row3c *= zmm0; \
	row0 += _mm512_loadu_ps(CO1 + 0*ldc); \
	row1 += _mm512_loadu_ps(CO1 + 1*ldc); \
	row2 += _mm512_loadu_ps(CO1 + 2*ldc); \
	row3 += _mm512_loadu_ps(CO1 + 3*ldc); \
	_mm512_storeu_ps(CO1 + 0*ldc, row0); \
	_mm512_storeu_ps(CO1 + 1*ldc, row1); \
	_mm512_storeu_ps(CO1 + 2*ldc, row2); \
	_mm512_storeu_ps(CO1 + 3*ldc, row3); \
	row0b += _mm512_loadu_ps(CO1 + 0*ldc + 16); \
	row1b += _mm512_loadu_ps(CO1 + 1*ldc + 16); \
	row2b += _mm512_loadu_ps(CO1 + 2*ldc + 16); \
	row3b += _mm512_loadu_ps(CO1 + 3*ldc + 16); \
	_mm512_storeu_ps(CO1 + 0*ldc + 16, row0b); \
	_mm512_storeu_ps(CO1 + 1*ldc + 16, row1b); \
	_mm512_storeu_ps(CO1 + 2*ldc + 16, row2b); \
	_mm512_storeu_ps(CO1 + 3*ldc + 16, row3b); \
	row0c += _mm512_loadu_ps(CO1 + 0*ldc + 32); \
	row1c += _mm512_loadu_ps(CO1 + 1*ldc + 32); \
	row2c += _mm512_loadu_ps(CO1 + 2*ldc + 32); \
	row3c += _mm512_loadu_ps(CO1 + 3*ldc + 32); \
	_mm512_storeu_ps(CO1 + 0*ldc + 32, row0c); \
	_mm512_storeu_ps(CO1 + 1*ldc + 32, row1c); \
	_mm512_storeu_ps(CO1 + 2*ldc + 32, row2c); \
	_mm512_storeu_ps(CO1 + 3*ldc + 32, row3c);
/*
 * 32x4 micro-tile (AVX-512): two 16-row A streams (AO, A1) against 4 B values
 * per k step.  rowJ holds tile rows 0-15 and rowJb rows 16-31 for output
 * offset J*ldc.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, A1, BO, CO1, ldc, zmm0-zmm3 and the 8 accumulators.
 */

/* Zero all 8 accumulators. */
#define INIT32x4() \
	row0  = _mm512_setzero_ps(); \
	row1  = _mm512_setzero_ps(); \
	row2  = _mm512_setzero_ps(); \
	row3  = _mm512_setzero_ps(); \
	row0b = _mm512_setzero_ps(); \
	row1b = _mm512_setzero_ps(); \
	row2b = _mm512_setzero_ps(); \
	row3b = _mm512_setzero_ps();

/* One k step; advances BO by 4 and AO, A1 by 16. */
#define KERNEL32x4_SUB() \
	zmm0 = _mm512_loadu_ps(AO); \
	zmm1 = _mm512_loadu_ps(A1); \
	zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO)); \
	zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+1)); \
	row0  += zmm0 * zmm2; \
	row1  += zmm0 * zmm3; \
	row0b += zmm1 * zmm2; \
	row1b += zmm1 * zmm3; \
	zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO+2)); \
	zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+3)); \
	row2  += zmm0 * zmm2; \
	row3  += zmm0 * zmm3; \
	row2b += zmm1 * zmm2; \
	row3b += zmm1 * zmm3; \
	BO += 4; \
	AO += 16; \
	A1 += 16;

/* C += ALPHA * acc for the 32x4 tile; segments at offsets 0 and 16 of CO1 + J*ldc. */
#define SAVE32x4(ALPHA) \
	zmm0 = _mm512_set1_ps(ALPHA); \
	row0  *= zmm0; \
	row1  *= zmm0; \
	row2  *= zmm0; \
	row3  *= zmm0; \
	row0b *= zmm0; \
	row1b *= zmm0; \
	row2b *= zmm0; \
	row3b *= zmm0; \
	row0 += _mm512_loadu_ps(CO1 + 0*ldc); \
	row1 += _mm512_loadu_ps(CO1 + 1*ldc); \
	row2 += _mm512_loadu_ps(CO1 + 2*ldc); \
	row3 += _mm512_loadu_ps(CO1 + 3*ldc); \
	_mm512_storeu_ps(CO1 + 0*ldc, row0); \
	_mm512_storeu_ps(CO1 + 1*ldc, row1); \
	_mm512_storeu_ps(CO1 + 2*ldc, row2); \
	_mm512_storeu_ps(CO1 + 3*ldc, row3); \
	row0b += _mm512_loadu_ps(CO1 + 0*ldc + 16); \
	row1b += _mm512_loadu_ps(CO1 + 1*ldc + 16); \
	row2b += _mm512_loadu_ps(CO1 + 2*ldc + 16); \
	row3b += _mm512_loadu_ps(CO1 + 3*ldc + 16); \
	_mm512_storeu_ps(CO1 + 0*ldc + 16, row0b); \
	_mm512_storeu_ps(CO1 + 1*ldc + 16, row1b); \
	_mm512_storeu_ps(CO1 + 2*ldc + 16, row2b); \
	_mm512_storeu_ps(CO1 + 3*ldc + 16, row3b);
/*
 * 16x4 micro-tile (AVX-512): one 16-row A stream against 4 B values per k
 * step; rowJ holds the 16 rows for output offset J*ldc.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, ldc, zmm0, zmm2, zmm3, row0-row3.
 */

/* Zero the 4 accumulators. */
#define INIT16x4() \
	row0 = _mm512_setzero_ps(); \
	row1 = _mm512_setzero_ps(); \
	row2 = _mm512_setzero_ps(); \
	row3 = _mm512_setzero_ps();

/* One k step; advances BO by 4 and AO by 16. */
#define KERNEL16x4_SUB() \
	zmm0 = _mm512_loadu_ps(AO); \
	zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO)); \
	zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+1)); \
	row0 += zmm0 * zmm2; \
	row1 += zmm0 * zmm3; \
	zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO+2)); \
	zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+3)); \
	row2 += zmm0 * zmm2; \
	row3 += zmm0 * zmm3; \
	BO += 4; \
	AO += 16;

/* C += ALPHA * acc for the 16x4 tile at CO1 + J*ldc. */
#define SAVE16x4(ALPHA) \
	zmm0 = _mm512_set1_ps(ALPHA); \
	row0 *= zmm0; \
	row1 *= zmm0; \
	row2 *= zmm0; \
	row3 *= zmm0; \
	row0 += _mm512_loadu_ps(CO1 + 0 * ldc); \
	row1 += _mm512_loadu_ps(CO1 + 1 * ldc); \
	row2 += _mm512_loadu_ps(CO1 + 2 * ldc); \
	row3 += _mm512_loadu_ps(CO1 + 3 * ldc); \
	_mm512_storeu_ps(CO1 + 0 * ldc, row0); \
	_mm512_storeu_ps(CO1 + 1 * ldc, row1); \
	_mm512_storeu_ps(CO1 + 2 * ldc, row2); \
	_mm512_storeu_ps(CO1 + 3 * ldc, row3);
/*******************************************************************************************/
|
|
|
|
/*
 * 8x4 micro-tile (AVX2, 256-bit): one 8-row A stream against 4 B values per
 * k step.  ymm4, ymm6, ymm8, ymm10 accumulate output offsets 0, 1, 2, 3 (*ldc).
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, ldc, ymm0, ymm2, ymm3, ymm4, ymm6, ymm8, ymm10.
 */

/* Zero the 4 accumulators. */
#define INIT8x4() \
	ymm4  = _mm256_setzero_ps(); \
	ymm6  = _mm256_setzero_ps(); \
	ymm8  = _mm256_setzero_ps(); \
	ymm10 = _mm256_setzero_ps();

/* One k step; advances BO by 4 and AO by 8. */
#define KERNEL8x4_SUB() \
	ymm0 = _mm256_loadu_ps(AO); \
	ymm2 = _mm256_broadcastss_ps(_mm_load_ss(BO + 0)); \
	ymm3 = _mm256_broadcastss_ps(_mm_load_ss(BO + 1)); \
	ymm4 += ymm0 * ymm2; \
	ymm6 += ymm0 * ymm3; \
	ymm2 = _mm256_broadcastss_ps(_mm_load_ss(BO + 2)); \
	ymm3 = _mm256_broadcastss_ps(_mm_load_ss(BO + 3)); \
	ymm8  += ymm0 * ymm2; \
	ymm10 += ymm0 * ymm3; \
	BO += 4; \
	AO += 8;

/* C += ALPHA * acc for the 8x4 tile at CO1 + J*ldc. */
#define SAVE8x4(ALPHA) \
	ymm0 = _mm256_set1_ps(ALPHA); \
	ymm4  *= ymm0; \
	ymm6  *= ymm0; \
	ymm8  *= ymm0; \
	ymm10 *= ymm0; \
	ymm4  += _mm256_loadu_ps(CO1 + 0 * ldc); \
	ymm6  += _mm256_loadu_ps(CO1 + 1 * ldc); \
	ymm8  += _mm256_loadu_ps(CO1 + 2 * ldc); \
	ymm10 += _mm256_loadu_ps(CO1 + 3 * ldc); \
	_mm256_storeu_ps(CO1 + 0 * ldc, ymm4); \
	_mm256_storeu_ps(CO1 + 1 * ldc, ymm6); \
	_mm256_storeu_ps(CO1 + 2 * ldc, ymm8); \
	_mm256_storeu_ps(CO1 + 3 * ldc, ymm10);
/*******************************************************************************************/
|
|
|
|
/*
 * 4x4 micro-tile (128-bit vectors): one 4-row A stream against 4 B values per
 * k step; rowJ accumulates output offset J*ldc.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, ldc, xmm0, xmm2, xmm3 (__m128), row0-row3.
 */

/* Zero the 4 accumulators. */
#define INIT4x4() \
	row0 = _mm_setzero_ps(); \
	row1 = _mm_setzero_ps(); \
	row2 = _mm_setzero_ps(); \
	row3 = _mm_setzero_ps();

/* One k step; advances BO by 4 and AO by 4. */
#define KERNEL4x4_SUB() \
	xmm0 = _mm_loadu_ps(AO); \
	xmm2 = _mm_broadcastss_ps(_mm_load_ss(BO + 0)); \
	xmm3 = _mm_broadcastss_ps(_mm_load_ss(BO + 1)); \
	row0 += xmm0 * xmm2; \
	row1 += xmm0 * xmm3; \
	xmm2 = _mm_broadcastss_ps(_mm_load_ss(BO + 2)); \
	xmm3 = _mm_broadcastss_ps(_mm_load_ss(BO + 3)); \
	row2 += xmm0 * xmm2; \
	row3 += xmm0 * xmm3; \
	BO += 4; \
	AO += 4;

/* C += ALPHA * acc for the 4x4 tile at CO1 + J*ldc. */
#define SAVE4x4(ALPHA) \
	xmm0 = _mm_set1_ps(ALPHA); \
	row0 *= xmm0; \
	row1 *= xmm0; \
	row2 *= xmm0; \
	row3 *= xmm0; \
	row0 += _mm_loadu_ps(CO1 + 0 * ldc); \
	row1 += _mm_loadu_ps(CO1 + 1 * ldc); \
	row2 += _mm_loadu_ps(CO1 + 2 * ldc); \
	row3 += _mm_loadu_ps(CO1 + 3 * ldc); \
	_mm_storeu_ps(CO1 + 0 * ldc, row0); \
	_mm_storeu_ps(CO1 + 1 * ldc, row1); \
	_mm_storeu_ps(CO1 + 2 * ldc, row2); \
	_mm_storeu_ps(CO1 + 3 * ldc, row3);
/*******************************************************************************************/
|
|
|
|
/*
 * 2x4 scalar tile: 2 A values against 4 B values per k step.
 * rowJ / rowJb accumulate the first / second of the two rows for output
 * offset J*ldc.  xmm* and row* are plain floats here; the names only mirror
 * the vector paths.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, ldc, xmm0-xmm3 and the 8 accumulators.
 */

/* Zero the 8 accumulators. */
#define INIT2x4() \
	row0 = 0; row0b = 0; row1 = 0; row1b = 0; \
	row2 = 0; row2b = 0; row3 = 0; row3b = 0;

/* One k step; advances BO by 4 and AO by 2. */
#define KERNEL2x4_SUB() \
	xmm0 = *(AO); \
	xmm1 = *(AO + 1); \
	xmm2 = *(BO + 0); \
	xmm3 = *(BO + 1); \
	row0  += xmm0 * xmm2; \
	row0b += xmm1 * xmm2; \
	row1  += xmm0 * xmm3; \
	row1b += xmm1 * xmm3; \
	xmm2 = *(BO + 2); \
	xmm3 = *(BO + 3); \
	row2  += xmm0 * xmm2; \
	row2b += xmm1 * xmm2; \
	row3  += xmm0 * xmm3; \
	row3b += xmm1 * xmm3; \
	BO += 4; \
	AO += 2;

/* C += ALPHA * acc; the two rows of output J are CO1[J*ldc] and CO1[J*ldc + 1]. */
#define SAVE2x4(ALPHA) \
	xmm0 = (ALPHA); \
	row0  *= xmm0; \
	row0b *= xmm0; \
	row1  *= xmm0; \
	row1b *= xmm0; \
	row2  *= xmm0; \
	row2b *= xmm0; \
	row3  *= xmm0; \
	row3b *= xmm0; \
	*(CO1 + 0 * ldc + 0) += row0; \
	*(CO1 + 0 * ldc + 1) += row0b; \
	*(CO1 + 1 * ldc + 0) += row1; \
	*(CO1 + 1 * ldc + 1) += row1b; \
	*(CO1 + 2 * ldc + 0) += row2; \
	*(CO1 + 2 * ldc + 1) += row2b; \
	*(CO1 + 3 * ldc + 0) += row3; \
	*(CO1 + 3 * ldc + 1) += row3b;
/*******************************************************************************************/
|
|
|
|
/*
 * 1x4 scalar tile: 1 A value against 4 B values per k step; rowJ accumulates
 * output offset J*ldc.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, ldc, xmm0, xmm2, xmm3, row0-row3 (floats).
 */

/* Zero the 4 accumulators. */
#define INIT1x4() \
	row0 = 0; row1 = 0; row2 = 0; row3 = 0;

/* One k step; advances BO by 4 and AO by 1. */
#define KERNEL1x4_SUB() \
	xmm0 = *(AO); \
	xmm2 = *(BO + 0); \
	xmm3 = *(BO + 1); \
	row0 += xmm0 * xmm2; \
	row1 += xmm0 * xmm3; \
	xmm2 = *(BO + 2); \
	xmm3 = *(BO + 3); \
	row2 += xmm0 * xmm2; \
	row3 += xmm0 * xmm3; \
	BO += 4; \
	AO += 1;

/* C += ALPHA * acc at CO1[J*ldc] for J = 0..3. */
#define SAVE1x4(ALPHA) \
	xmm0 = (ALPHA); \
	row0 *= xmm0; \
	row1 *= xmm0; \
	row2 *= xmm0; \
	row3 *= xmm0; \
	*(CO1 + 0 * ldc) += row0; \
	*(CO1 + 1 * ldc) += row1; \
	*(CO1 + 2 * ldc) += row2; \
	*(CO1 + 3 * ldc) += row3;
/*******************************************************************************************/
|
|
|
|
/*******************************************************************************************
|
|
* 2 lines of N
|
|
*******************************************************************************************/
|
|
|
|
/*
 * 16x2 micro-tile (AVX-512): one 16-row A stream against 2 B values per k
 * step; row0 / row1 accumulate output offsets 0 and ldc.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, ldc, zmm0, zmm2, zmm3, row0, row1.
 */

/* Zero the 2 accumulators. */
#define INIT16x2() \
	row0 = _mm512_setzero_ps(); \
	row1 = _mm512_setzero_ps();

/* One k step; advances BO by 2 and AO by 16. */
#define KERNEL16x2_SUB() \
	zmm0 = _mm512_loadu_ps(AO); \
	zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO)); \
	zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO + 1)); \
	row0 += zmm0 * zmm2; \
	row1 += zmm0 * zmm3; \
	BO += 2; \
	AO += 16;

/* C += ALPHA * acc at CO1 and CO1 + ldc. */
#define SAVE16x2(ALPHA) \
	zmm0 = _mm512_set1_ps(ALPHA); \
	row0 *= zmm0; \
	row1 *= zmm0; \
	row0 += _mm512_loadu_ps(CO1); \
	row1 += _mm512_loadu_ps(CO1 + ldc); \
	_mm512_storeu_ps(CO1, row0); \
	_mm512_storeu_ps(CO1 + ldc, row1);
/*******************************************************************************************/
|
|
|
|
/*
 * 8x2 micro-tile (AVX2, 256-bit): one 8-row A stream against 2 B values per k
 * step; ymm4 / ymm6 accumulate output offsets 0 and ldc.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, ldc, ymm0, ymm2, ymm3, ymm4, ymm6.
 */

/* Zero the 2 accumulators. */
#define INIT8x2() \
	ymm4 = _mm256_setzero_ps(); \
	ymm6 = _mm256_setzero_ps();

/* One k step; advances BO by 2 and AO by 8. */
#define KERNEL8x2_SUB() \
	ymm0 = _mm256_loadu_ps(AO); \
	ymm2 = _mm256_broadcastss_ps(_mm_load_ss(BO)); \
	ymm3 = _mm256_broadcastss_ps(_mm_load_ss(BO + 1)); \
	ymm4 += ymm0 * ymm2; \
	ymm6 += ymm0 * ymm3; \
	BO += 2; \
	AO += 8;

/* C += ALPHA * acc at CO1 and CO1 + ldc. */
#define SAVE8x2(ALPHA) \
	ymm0 = _mm256_set1_ps(ALPHA); \
	ymm4 *= ymm0; \
	ymm6 *= ymm0; \
	ymm4 += _mm256_loadu_ps(CO1); \
	ymm6 += _mm256_loadu_ps(CO1 + ldc); \
	_mm256_storeu_ps(CO1, ymm4); \
	_mm256_storeu_ps(CO1 + ldc, ymm6);
/*******************************************************************************************/
|
|
|
|
/*
 * 4x2 micro-tile (128-bit vectors): one 4-row A stream against 2 B values per
 * k step; row0 / row1 accumulate output offsets 0 and ldc.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, ldc, xmm0, xmm2, xmm3, row0, row1 (__m128).
 */

/* Zero the 2 accumulators. */
#define INIT4x2() \
	row0 = _mm_setzero_ps(); \
	row1 = _mm_setzero_ps();

/* One k step; advances BO by 2 and AO by 4. */
#define KERNEL4x2_SUB() \
	xmm0 = _mm_loadu_ps(AO); \
	xmm2 = _mm_broadcastss_ps(_mm_load_ss(BO)); \
	xmm3 = _mm_broadcastss_ps(_mm_load_ss(BO + 1)); \
	row0 += xmm0 * xmm2; \
	row1 += xmm0 * xmm3; \
	BO += 2; \
	AO += 4;

/* C += ALPHA * acc at CO1 and CO1 + ldc. */
#define SAVE4x2(ALPHA) \
	xmm0 = _mm_set1_ps(ALPHA); \
	row0 *= xmm0; \
	row1 *= xmm0; \
	row0 += _mm_loadu_ps(CO1); \
	row1 += _mm_loadu_ps(CO1 + ldc); \
	_mm_storeu_ps(CO1, row0); \
	_mm_storeu_ps(CO1 + ldc, row1);
/*******************************************************************************************/
|
|
|
|
|
|
/*
 * 2x2 scalar tile: 2 A values against 2 B values per k step.
 * row0/row0b hold the two rows at offset 0; row1/row1b those at offset ldc.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, ldc, xmm0-xmm3 and the 4 accumulators (floats).
 */

/* Zero the 4 accumulators. */
#define INIT2x2() \
	row0 = 0; row0b = 0; row1 = 0; row1b = 0;

/* One k step; advances BO by 2 and AO by 2. */
#define KERNEL2x2_SUB() \
	xmm0 = *(AO + 0); \
	xmm1 = *(AO + 1); \
	xmm2 = *(BO + 0); \
	xmm3 = *(BO + 1); \
	row0  += xmm0 * xmm2; \
	row0b += xmm1 * xmm2; \
	row1  += xmm0 * xmm3; \
	row1b += xmm1 * xmm3; \
	BO += 2; \
	AO += 2;

/* C += ALPHA * acc at CO1[0], CO1[1], CO1[ldc], CO1[ldc + 1]. */
#define SAVE2x2(ALPHA) \
	xmm0 = (ALPHA); \
	row0  *= xmm0; \
	row0b *= xmm0; \
	row1  *= xmm0; \
	row1b *= xmm0; \
	*(CO1) += row0; \
	*(CO1 + 1) += row0b; \
	*(CO1 + ldc) += row1; \
	*(CO1 + ldc + 1) += row1b;
/*******************************************************************************************/
|
|
|
|
/*
 * 1x2 scalar tile: 1 A value against 2 B values per k step; row0 / row1
 * accumulate output offsets 0 and ldc.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, ldc, xmm0, xmm2, xmm3, row0, row1 (floats).
 */

/* Zero the 2 accumulators. */
#define INIT1x2() \
	row0 = 0; row1 = 0;

/* One k step; advances BO by 2 and AO by 1. */
#define KERNEL1x2_SUB() \
	xmm0 = *(AO); \
	xmm2 = *(BO + 0); \
	xmm3 = *(BO + 1); \
	row0 += xmm0 * xmm2; \
	row1 += xmm0 * xmm3; \
	BO += 2; \
	AO += 1;

/* C += ALPHA * acc at CO1[0] and CO1[ldc]. */
#define SAVE1x2(ALPHA) \
	xmm0 = (ALPHA); \
	row0 *= xmm0; \
	row1 *= xmm0; \
	*(CO1) += row0; \
	*(CO1 + ldc) += row1;
/*******************************************************************************************/
|
|
|
|
/*******************************************************************************************
|
|
* 1 line of N
|
|
*******************************************************************************************/
|
|
|
|
/*
 * 16x1 micro-tile (AVX-512): one 16-row A stream times one broadcast B value
 * per k step, accumulated in row0.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, zmm0, zmm2, row0.
 */

/* Zero the accumulator. */
#define INIT16x1() \
	row0 = _mm512_setzero_ps();

/* One k step; advances BO by 1 and AO by 16. */
#define KERNEL16x1_SUB() \
	zmm0 = _mm512_loadu_ps(AO); \
	zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO)); \
	row0 += zmm0 * zmm2; \
	BO += 1; \
	AO += 16;

/* C += ALPHA * acc at CO1 (16 floats). */
#define SAVE16x1(ALPHA) \
	zmm0 = _mm512_set1_ps(ALPHA); \
	row0 *= zmm0; \
	row0 += _mm512_loadu_ps(CO1); \
	_mm512_storeu_ps(CO1, row0);
/*******************************************************************************************/
|
|
|
|
/*
 * 8x1 micro-tile (AVX2, 256-bit): one 8-row A stream times one broadcast B
 * value per k step, accumulated in ymm4.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, ymm0, ymm2, ymm4.
 */

/* Zero the accumulator. */
#define INIT8x1() \
	ymm4 = _mm256_setzero_ps();

/* One k step; advances BO by 1 and AO by 8. */
#define KERNEL8x1_SUB() \
	ymm0 = _mm256_loadu_ps(AO); \
	ymm2 = _mm256_broadcastss_ps(_mm_load_ss(BO)); \
	ymm4 += ymm0 * ymm2; \
	BO += 1; \
	AO += 8;

/* C += ALPHA * acc at CO1 (8 floats). */
#define SAVE8x1(ALPHA) \
	ymm0 = _mm256_set1_ps(ALPHA); \
	ymm4 *= ymm0; \
	ymm4 += _mm256_loadu_ps(CO1); \
	_mm256_storeu_ps(CO1, ymm4);
/*******************************************************************************************/
|
|
|
|
/*
 * 4x1 micro-tile (128-bit vectors): one 4-row A stream times one broadcast B
 * value per k step, accumulated in row0.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, xmm0, xmm2, row0 (__m128).
 */

/* Zero the accumulator. */
#define INIT4x1() \
	row0 = _mm_setzero_ps();

/* One k step; advances BO by 1 and AO by 4. */
#define KERNEL4x1_SUB() \
	xmm0 = _mm_loadu_ps(AO); \
	xmm2 = _mm_broadcastss_ps(_mm_load_ss(BO)); \
	row0 += xmm0 * xmm2; \
	BO += 1; \
	AO += 4;

/* C += ALPHA * acc at CO1 (4 floats). */
#define SAVE4x1(ALPHA) \
	xmm0 = _mm_set1_ps(ALPHA); \
	row0 *= xmm0; \
	row0 += _mm_loadu_ps(CO1); \
	_mm_storeu_ps(CO1, row0);
/*******************************************************************************************/
|
|
|
|
/*
 * 2x1 scalar tile: 2 A values times 1 B value per k step; row0 / row0b
 * accumulate CO1[0] and CO1[1].
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, xmm0, xmm1, xmm2, row0, row0b (floats).
 */

/* Zero the 2 accumulators. */
#define INIT2x1() \
	row0 = 0; row0b = 0;

/* One k step; advances BO by 1 and AO by 2. */
#define KERNEL2x1_SUB() \
	xmm0 = *(AO + 0); \
	xmm1 = *(AO + 1); \
	xmm2 = *(BO); \
	row0  += xmm0 * xmm2; \
	row0b += xmm1 * xmm2; \
	BO += 1; \
	AO += 2;

/* C += ALPHA * acc at CO1[0] and CO1[1]. */
#define SAVE2x1(ALPHA) \
	xmm0 = (ALPHA); \
	row0  *= xmm0; \
	row0b *= xmm0; \
	*(CO1) += row0; \
	*(CO1 + 1) += row0b;
/*******************************************************************************************/
|
|
|
|
/*
 * 1x1 scalar tile: a plain dot product of the A and B streams over k,
 * accumulated in row0.
 *
 * Statement-sequence macros ending in ';' (invoked without a semicolon).
 * Expected in scope: AO, BO, CO1, xmm0, xmm2, row0 (floats).
 */

/* Zero the accumulator. */
#define INIT1x1() \
	row0 = 0;

/* One k step; advances BO and AO by 1. */
#define KERNEL1x1_SUB() \
	xmm0 = *(AO); \
	xmm2 = *(BO); \
	row0 += xmm0 * xmm2; \
	BO += 1; \
	AO += 1;

/* C += ALPHA * acc at CO1[0]. */
#define SAVE1x1(ALPHA) \
	xmm0 = (ALPHA); \
	row0 *= xmm0; \
	*(CO1) += row0;
/*******************************************************************************************/
|
|
|
|
|
|
/*************************************************************************************
|
|
* GEMM Kernel
|
|
*************************************************************************************/
|
|
|
|
int __attribute__ ((noinline))
|
|
CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float * __restrict A, float * __restrict B, float * __restrict C, BLASLONG ldc)
|
|
{
|
|
unsigned long long M = m, N = n, K = k;
|
|
if (M == 0)
|
|
return 0;
|
|
if (N == 0)
|
|
return 0;
|
|
if (K == 0)
|
|
return 0;
|
|
|
|
|
|
while (N >= 4) {
|
|
float *CO1;
|
|
float *AO;
|
|
int i;
|
|
// L8_10
|
|
CO1 = C;
|
|
C += 4 * ldc;
|
|
|
|
AO = A;
|
|
|
|
i = m;
|
|
while (i >= 64) {
|
|
float *BO;
|
|
float *A1, *A2, *A3;
|
|
// L8_11
|
|
__m512 zmm0, zmm1, zmm2, zmm3, row0, zmm5, row1, zmm7, row2, row3, row0b, row1b, row2b, row3b, row0c, row1c, row2c, row3c, row0d, row1d, row2d, row3d;
|
|
BO = B;
|
|
int kloop = K;
|
|
|
|
A1 = AO + 16 * K;
|
|
A2 = A1 + 16 * K;
|
|
A3 = A2 + 16 * K;
|
|
|
|
INIT64x4()
|
|
|
|
while (kloop > 0) {
|
|
// L12_17
|
|
KERNEL64x4_SUB()
|
|
kloop--;
|
|
}
|
|
// L8_19
|
|
SAVE64x4(alpha)
|
|
CO1 += 64;
|
|
AO += 48 * K;
|
|
|
|
i -= 64;
|
|
}
|
|
while (i >= 32) {
|
|
float *BO;
|
|
float *A1;
|
|
// L8_11
|
|
__m512 zmm0, zmm1, zmm2, zmm3, row0, row1, row2, row3, row0b, row1b, row2b, row3b;
|
|
BO = B;
|
|
int kloop = K;
|
|
|
|
A1 = AO + 16 * K;
|
|
|
|
INIT32x4()
|
|
|
|
while (kloop > 0) {
|
|
// L12_17
|
|
KERNEL32x4_SUB()
|
|
kloop--;
|
|
}
|
|
// L8_19
|
|
SAVE32x4(alpha)
|
|
CO1 += 32;
|
|
AO += 16 * K;
|
|
|
|
i -= 32;
|
|
}
|
|
while (i >= 16) {
|
|
float *BO;
|
|
// L8_11
|
|
__m512 zmm0, zmm2, zmm3, row0, row1, row2, row3;
|
|
BO = B;
|
|
int kloop = K;
|
|
|
|
INIT16x4()
|
|
|
|
while (kloop > 0) {
|
|
// L12_17
|
|
KERNEL16x4_SUB()
|
|
kloop--;
|
|
}
|
|
// L8_19
|
|
SAVE16x4(alpha)
|
|
CO1 += 16;
|
|
|
|
i -= 16;
|
|
}
|
|
while (i >= 8) {
|
|
float *BO;
|
|
// L8_11
|
|
__m256 ymm0, ymm2, ymm3, ymm4, ymm6,ymm8,ymm10;
|
|
BO = B;
|
|
int kloop = K;
|
|
|
|
INIT8x4()
|
|
|
|
while (kloop > 0) {
|
|
// L12_17
|
|
KERNEL8x4_SUB()
|
|
kloop--;
|
|
}
|
|
// L8_19
|
|
SAVE8x4(alpha)
|
|
CO1 += 8;
|
|
|
|
i -= 8;
|
|
}
|
|
while (i >= 4) {
|
|
// L8_11
|
|
float *BO;
|
|
__m128 xmm0, xmm2, xmm3, row0, row1, row2, row3;
|
|
BO = B;
|
|
int kloop = K;
|
|
|
|
INIT4x4()
|
|
// L8_16
|
|
while (kloop > 0) {
|
|
// L12_17
|
|
KERNEL4x4_SUB()
|
|
kloop--;
|
|
}
|
|
// L8_19
|
|
SAVE4x4(alpha)
|
|
CO1 += 4;
|
|
|
|
i -= 4;
|
|
}
|
|
|
|
/**************************************************************************
|
|
* Rest of M
|
|
***************************************************************************/
|
|
|
|
while (i >= 2) {
|
|
float *BO;
|
|
float xmm0, xmm1, xmm2, xmm3, row0, row0b, row1, row1b, row2, row2b, row3, row3b;
|
|
BO = B;
|
|
|
|
INIT2x4()
|
|
int kloop = K;
|
|
|
|
while (kloop > 0) {
|
|
KERNEL2x4_SUB()
|
|
kloop--;
|
|
}
|
|
SAVE2x4(alpha)
|
|
CO1 += 2;
|
|
i -= 2;
|
|
}
|
|
// L13_40
|
|
while (i >= 1) {
|
|
float *BO;
|
|
float xmm0, xmm2, xmm3, row0, row1, row2, row3;
|
|
int kloop = K;
|
|
BO = B;
|
|
INIT1x4()
|
|
|
|
while (kloop > 0) {
|
|
KERNEL1x4_SUB()
|
|
kloop--;
|
|
}
|
|
SAVE1x4(alpha)
|
|
CO1 += 1;
|
|
i -= 1;
|
|
}
|
|
|
|
B += K * 4;
|
|
N -= 4;
|
|
}
|
|
|
|
/**************************************************************************************************/
|
|
|
|
// L8_0
|
|
while (N >= 2) {
|
|
float *CO1;
|
|
float *AO;
|
|
int i;
|
|
// L8_10
|
|
CO1 = C;
|
|
C += 2 * ldc;
|
|
|
|
AO = A;
|
|
|
|
i = m;
|
|
while (i >= 16) {
|
|
float *BO;
|
|
|
|
// L8_11
|
|
__m512 zmm0, zmm2, zmm3, row0, row1;
|
|
BO = B;
|
|
int kloop = K;
|
|
|
|
INIT16x2()
|
|
|
|
while (kloop > 0) {
|
|
// L12_17
|
|
KERNEL16x2_SUB()
|
|
kloop--;
|
|
}
|
|
// L8_19
|
|
SAVE16x2(alpha)
|
|
CO1 += 16;
|
|
|
|
i -= 16;
|
|
}
|
|
while (i >= 8) {
|
|
float *BO;
|
|
__m256 ymm0, ymm2, ymm3, ymm4, ymm6;
|
|
// L8_11
|
|
BO = B;
|
|
int kloop = K;
|
|
|
|
INIT8x2()
|
|
|
|
// L8_16
|
|
while (kloop > 0) {
|
|
// L12_17
|
|
KERNEL8x2_SUB()
|
|
kloop--;
|
|
}
|
|
// L8_19
|
|
SAVE8x2(alpha)
|
|
CO1 += 8;
|
|
|
|
i-=8;
|
|
}
|
|
|
|
while (i >= 4) {
|
|
float *BO;
|
|
__m128 xmm0, xmm2, xmm3, row0, row1;
|
|
// L8_11
|
|
BO = B;
|
|
int kloop = K;
|
|
|
|
INIT4x2()
|
|
|
|
// L8_16
|
|
while (kloop > 0) {
|
|
// L12_17
|
|
KERNEL4x2_SUB()
|
|
kloop--;
|
|
}
|
|
// L8_19
|
|
SAVE4x2(alpha)
|
|
CO1 += 4;
|
|
|
|
i-=4;
|
|
}
|
|
|
|
/**************************************************************************
|
|
* Rest of M
|
|
***************************************************************************/
|
|
|
|
while (i >= 2) {
|
|
float *BO;
|
|
float xmm0, xmm1, xmm2, xmm3, row0, row0b, row1, row1b;
|
|
int kloop = K;
|
|
BO = B;
|
|
|
|
INIT2x2()
|
|
|
|
while (kloop > 0) {
|
|
KERNEL2x2_SUB()
|
|
kloop--;
|
|
}
|
|
SAVE2x2(alpha)
|
|
CO1 += 2;
|
|
i -= 2;
|
|
}
|
|
// L13_40
|
|
while (i >= 1) {
|
|
float *BO;
|
|
float xmm0, xmm2, xmm3, row0, row1;
|
|
int kloop = K;
|
|
BO = B;
|
|
|
|
INIT1x2()
|
|
|
|
while (kloop > 0) {
|
|
KERNEL1x2_SUB()
|
|
kloop--;
|
|
}
|
|
SAVE1x2(alpha)
|
|
CO1 += 1;
|
|
i -= 1;
|
|
}
|
|
|
|
B += K * 2;
|
|
N -= 2;
|
|
}
|
|
|
|
// L8_0
|
|
while (N >= 1) {
|
|
// L8_10
|
|
float *CO1;
|
|
float *AO;
|
|
int i;
|
|
|
|
CO1 = C;
|
|
C += ldc;
|
|
|
|
AO = A;
|
|
|
|
i = m;
|
|
while (i >= 16) {
|
|
float *BO;
|
|
__m512 zmm0, zmm2, row0;
|
|
// L8_11
|
|
BO = B;
|
|
int kloop = K;
|
|
|
|
INIT16x1()
|
|
// L8_16
|
|
while (kloop > 0) {
|
|
// L12_17
|
|
KERNEL16x1_SUB()
|
|
kloop--;
|
|
}
|
|
// L8_19
|
|
SAVE16x1(alpha)
|
|
CO1 += 16;
|
|
|
|
i-= 16;
|
|
}
|
|
while (i >= 8) {
|
|
float *BO;
|
|
__m256 ymm0, ymm2, ymm4;
|
|
// L8_11
|
|
BO = B;
|
|
int kloop = K;
|
|
|
|
INIT8x1()
|
|
// L8_16
|
|
while (kloop > 0) {
|
|
// L12_17
|
|
KERNEL8x1_SUB()
|
|
kloop--;
|
|
}
|
|
// L8_19
|
|
SAVE8x1(alpha)
|
|
CO1 += 8;
|
|
|
|
i-= 8;
|
|
}
|
|
while (i >= 4) {
|
|
float *BO;
|
|
__m128 xmm0, xmm2, row0;
|
|
// L8_11
|
|
BO = B;
|
|
int kloop = K;
|
|
|
|
INIT4x1()
|
|
// L8_16
|
|
while (kloop > 0) {
|
|
// L12_17
|
|
KERNEL4x1_SUB()
|
|
kloop--;
|
|
}
|
|
// L8_19
|
|
SAVE4x1(alpha)
|
|
CO1 += 4;
|
|
|
|
i-= 4;
|
|
}
|
|
|
|
/**************************************************************************
|
|
* Rest of M
|
|
***************************************************************************/
|
|
|
|
while (i >= 2) {
|
|
float *BO;
|
|
float xmm0, xmm1, xmm2, row0, row0b;
|
|
int kloop = K;
|
|
BO = B;
|
|
|
|
INIT2x1()
|
|
|
|
while (kloop > 0) {
|
|
KERNEL2x1_SUB()
|
|
kloop--;
|
|
}
|
|
SAVE2x1(alpha)
|
|
CO1 += 2;
|
|
i -= 2;
|
|
}
|
|
// L13_40
|
|
while (i >= 1) {
|
|
float *BO;
|
|
float xmm0, xmm2, row0;
|
|
int kloop = K;
|
|
|
|
BO = B;
|
|
INIT1x1()
|
|
|
|
|
|
while (kloop > 0) {
|
|
KERNEL1x1_SUB()
|
|
kloop--;
|
|
}
|
|
SAVE1x1(alpha)
|
|
CO1 += 1;
|
|
i -= 1;
|
|
}
|
|
|
|
B += K * 1;
|
|
N -= 1;
|
|
}
|
|
|
|
|
|
return 0;
|
|
}
|
|
|
|
#include "sgemm_direct_skylakex.c"
|