Optimize M < 16 using AVX512 mask
This commit is contained in:
parent
9186456a12
commit
f88470323b
|
@ -31,17 +31,25 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps()
|
#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps()
|
||||||
#define LOAD_A_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[lda * k + i + (M*16)])
|
#define LOAD_A_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[lda * k + i + (M*16)])
|
||||||
|
#define MASK_LOAD_A_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &A[lda * k + i + (M*16)])
|
||||||
#define BROADCAST_LOAD_B_512(M, N) __m512 Bval##N = _mm512_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)]))
|
#define BROADCAST_LOAD_B_512(M, N) __m512 Bval##N = _mm512_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)]))
|
||||||
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N)
|
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N)
|
||||||
#if defined(B0)
|
#if defined(B0)
|
||||||
#define STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
#define STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
||||||
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
|
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
|
||||||
|
#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
||||||
|
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
|
||||||
#else
|
#else
|
||||||
#define STORE_512(M, N) \
|
#define STORE_512(M, N) \
|
||||||
BLASLONG offset##M##N = (j+N)*ldc + i + (M*16); \
|
BLASLONG offset##M##N = (j+N)*ldc + i + (M*16); \
|
||||||
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
||||||
asm("vfmadd231ps (%1, %2, 4), %3, %0": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_512)); \
|
asm("vfmadd231ps (%1, %2, 4), %3, %0": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_512)); \
|
||||||
_mm512_storeu_ps(&C[offset##M##N], result##M##N)
|
_mm512_storeu_ps(&C[offset##M##N], result##M##N)
|
||||||
|
#define MASK_STORE_512(M, N) \
|
||||||
|
BLASLONG offset##M##N = (j+N)*ldc + i + (M*16); \
|
||||||
|
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
||||||
|
asm("vfmadd231ps (%1, %2, 4), %3, %0 %{%4%}": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_512), "k"(mask)); \
|
||||||
|
_mm512_mask_storeu_ps(&C[offset##M##N], mask, result##M##N)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define DECLARE_RESULT_256(M, N) __m256 result##M##N = _mm256_setzero_ps()
|
#define DECLARE_RESULT_256(M, N) __m256 result##M##N = _mm256_setzero_ps()
|
||||||
|
@ -241,6 +249,51 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
|
||||||
STORE_512(0, 0);
|
STORE_512(0, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (M - i > 0) {
|
||||||
|
register __mmask16 mask asm("k1") = (1UL << (M - i)) - 1;
|
||||||
|
for (j = 0; j < n4; j += 4) {
|
||||||
|
DECLARE_RESULT_512(0, 0);
|
||||||
|
DECLARE_RESULT_512(0, 1);
|
||||||
|
DECLARE_RESULT_512(0, 2);
|
||||||
|
DECLARE_RESULT_512(0, 3);
|
||||||
|
for (k = 0; k < K; k++) {
|
||||||
|
MASK_LOAD_A_512(0, x);
|
||||||
|
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1);
|
||||||
|
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3);
|
||||||
|
|
||||||
|
MATMUL_512(0, 0);
|
||||||
|
MATMUL_512(0, 1);
|
||||||
|
MATMUL_512(0, 2);
|
||||||
|
MATMUL_512(0, 3);
|
||||||
|
}
|
||||||
|
MASK_STORE_512(0, 0);
|
||||||
|
MASK_STORE_512(0, 1);
|
||||||
|
MASK_STORE_512(0, 2);
|
||||||
|
MASK_STORE_512(0, 3);
|
||||||
|
}
|
||||||
|
for (; j < n2; j += 2) {
|
||||||
|
DECLARE_RESULT_512(0, 0);
|
||||||
|
DECLARE_RESULT_512(0, 1);
|
||||||
|
for (k = 0; k < K; k++) {
|
||||||
|
MASK_LOAD_A_512(0, x);
|
||||||
|
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1);
|
||||||
|
MATMUL_512(0, 0);
|
||||||
|
MATMUL_512(0, 1);
|
||||||
|
}
|
||||||
|
MASK_STORE_512(0, 0);
|
||||||
|
MASK_STORE_512(0, 1);
|
||||||
|
}
|
||||||
|
for (; j < N; j++) {
|
||||||
|
DECLARE_RESULT_512(0, 0);
|
||||||
|
for (k = 0; k < K; k++) {
|
||||||
|
MASK_LOAD_A_512(0, x);
|
||||||
|
BROADCAST_LOAD_B_512(x, 0);
|
||||||
|
MATMUL_512(0, 0);
|
||||||
|
}
|
||||||
|
MASK_STORE_512(0, 0);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
__m256 alpha_256 = _mm256_broadcastss_ps(_mm_load_ss(&alpha));
|
__m256 alpha_256 = _mm256_broadcastss_ps(_mm_load_ss(&alpha));
|
||||||
#if !defined(B0)
|
#if !defined(B0)
|
||||||
__m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta));
|
__m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta));
|
||||||
|
|
Loading…
Reference in New Issue