Merge pull request #3550 from guowangy/smatrix-mask-fix
Small Matrix: use proper inline asm input constraint for AVX512 mask
This commit is contained in:
commit
1ef97c470c
|
@ -48,7 +48,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N)
|
||||
#define MASK_STORE_512(M, N) \
|
||||
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
|
||||
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \
|
||||
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "Yk"(mask)); \
|
||||
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N)
|
||||
#endif
|
||||
|
||||
|
@ -266,7 +266,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
|
|||
int mm = M - i;
|
||||
if (!mm) return 0;
|
||||
if (mm > 4 || K < 16) {
|
||||
register __mmask8 mask asm("k1") = (1UL << mm) - 1;
|
||||
register __mmask8 mask = (1UL << mm) - 1;
|
||||
for (j = 0; j < n6; j += 6) {
|
||||
DECLARE_RESULT_512(0, 0);
|
||||
DECLARE_RESULT_512(0, 1);
|
||||
|
|
|
@ -55,7 +55,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N)
|
||||
#define MASK_STORE_512(M, N) \
|
||||
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
|
||||
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \
|
||||
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "Yk"(mask)); \
|
||||
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N)
|
||||
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
|
||||
__m512d tmp##M##N = _mm512_i64gather_pd(vindex_n, &C[(j + N*8)*ldc + i + M], 8); \
|
||||
|
@ -303,7 +303,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
|
|||
}
|
||||
int mm = M - i;
|
||||
if (mm >= 6) {
|
||||
register __mmask16 mask asm("k1") = (1UL << mm) - 1;
|
||||
register __mmask16 mask = (1UL << mm) - 1;
|
||||
for (j = 0; j < n8; j += 8) {
|
||||
DECLARE_RESULT_512(0, 0);
|
||||
DECLARE_RESULT_512(0, 1);
|
||||
|
|
|
@ -48,7 +48,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
|
||||
#define MASK_STORE_512(M, N) \
|
||||
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
||||
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \
|
||||
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "Yk"(mask)); \
|
||||
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
|
||||
#endif
|
||||
|
||||
|
@ -267,7 +267,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
|
|||
int mm = M - i;
|
||||
if (!mm) return 0;
|
||||
if (mm > 8 || K < 32) {
|
||||
register __mmask16 mask asm("k1") = (1UL << mm) - 1;
|
||||
register __mmask16 mask = (1UL << mm) - 1;
|
||||
for (j = 0; j < n6; j += 6) {
|
||||
DECLARE_RESULT_512(0, 0);
|
||||
DECLARE_RESULT_512(0, 1);
|
||||
|
|
|
@ -55,7 +55,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
|
||||
#define MASK_STORE_512(M, N) \
|
||||
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
||||
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \
|
||||
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "Yk"(mask)); \
|
||||
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
|
||||
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
||||
__m512 tmp##M##N = _mm512_i32gather_ps(vindex_n, &C[(j + N*16)*ldc + i + M], 4); \
|
||||
|
@ -303,7 +303,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
|
|||
}
|
||||
int mm = M - i;
|
||||
if (mm >= 12) {
|
||||
register __mmask16 mask asm("k1") = (1UL << mm) - 1;
|
||||
register __mmask16 mask = (1UL << mm) - 1;
|
||||
for (j = 0; j < n8; j += 8) {
|
||||
DECLARE_RESULT_512(0, 0);
|
||||
DECLARE_RESULT_512(0, 1);
|
||||
|
|
Loading…
Reference in New Issue