Merge pull request #3550 from guowangy/smatrix-mask-fix

Small Matrix: use proper inline asm input constraint for AVX512 mask
This commit is contained in:
Martin Kroeker 2022-02-28 08:28:02 +01:00 committed by GitHub
commit 1ef97c470c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 8 additions and 8 deletions

View File

@ -48,7 +48,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N)
#define MASK_STORE_512(M, N) \
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "Yk"(mask)); \
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N)
#endif
@ -266,7 +266,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
int mm = M - i;
if (!mm) return 0;
if (mm > 4 || K < 16) {
register __mmask8 mask asm("k1") = (1UL << mm) - 1;
register __mmask8 mask = (1UL << mm) - 1;
for (j = 0; j < n6; j += 6) {
DECLARE_RESULT_512(0, 0);
DECLARE_RESULT_512(0, 1);

View File

@ -55,7 +55,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N)
#define MASK_STORE_512(M, N) \
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "Yk"(mask)); \
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N)
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
__m512d tmp##M##N = _mm512_i64gather_pd(vindex_n, &C[(j + N*8)*ldc + i + M], 8); \
@ -303,7 +303,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
}
int mm = M - i;
if (mm >= 6) {
register __mmask16 mask asm("k1") = (1UL << mm) - 1;
register __mmask16 mask = (1UL << mm) - 1;
for (j = 0; j < n8; j += 8) {
DECLARE_RESULT_512(0, 0);
DECLARE_RESULT_512(0, 1);

View File

@ -48,7 +48,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
#define MASK_STORE_512(M, N) \
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "Yk"(mask)); \
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
#endif
@ -267,7 +267,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
int mm = M - i;
if (!mm) return 0;
if (mm > 8 || K < 32) {
register __mmask16 mask asm("k1") = (1UL << mm) - 1;
register __mmask16 mask = (1UL << mm) - 1;
for (j = 0; j < n6; j += 6) {
DECLARE_RESULT_512(0, 0);
DECLARE_RESULT_512(0, 1);

View File

@ -55,7 +55,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
#define MASK_STORE_512(M, N) \
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "Yk"(mask)); \
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
__m512 tmp##M##N = _mm512_i32gather_ps(vindex_n, &C[(j + N*16)*ldc + i + M], 4); \
@ -303,7 +303,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
}
int mm = M - i;
if (mm >= 12) {
register __mmask16 mask asm("k1") = (1UL << mm) - 1;
register __mmask16 mask = (1UL << mm) - 1;
for (j = 0; j < n8; j += 8) {
DECLARE_RESULT_512(0, 0);
DECLARE_RESULT_512(0, 1);