Merge pull request #3550 from guowangy/smatrix-mask-fix
Small Matrix: use proper inline asm input constraint for AVX512 mask
This commit is contained in:
commit
1ef97c470c
|
@ -48,7 +48,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N)
|
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N)
|
||||||
#define MASK_STORE_512(M, N) \
|
#define MASK_STORE_512(M, N) \
|
||||||
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
|
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
|
||||||
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \
|
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "Yk"(mask)); \
|
||||||
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N)
|
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -266,7 +266,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
|
||||||
int mm = M - i;
|
int mm = M - i;
|
||||||
if (!mm) return 0;
|
if (!mm) return 0;
|
||||||
if (mm > 4 || K < 16) {
|
if (mm > 4 || K < 16) {
|
||||||
register __mmask8 mask asm("k1") = (1UL << mm) - 1;
|
register __mmask8 mask = (1UL << mm) - 1;
|
||||||
for (j = 0; j < n6; j += 6) {
|
for (j = 0; j < n6; j += 6) {
|
||||||
DECLARE_RESULT_512(0, 0);
|
DECLARE_RESULT_512(0, 0);
|
||||||
DECLARE_RESULT_512(0, 1);
|
DECLARE_RESULT_512(0, 1);
|
||||||
|
|
|
@ -55,7 +55,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N)
|
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N)
|
||||||
#define MASK_STORE_512(M, N) \
|
#define MASK_STORE_512(M, N) \
|
||||||
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
|
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
|
||||||
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \
|
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "Yk"(mask)); \
|
||||||
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N)
|
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N)
|
||||||
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
|
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
|
||||||
__m512d tmp##M##N = _mm512_i64gather_pd(vindex_n, &C[(j + N*8)*ldc + i + M], 8); \
|
__m512d tmp##M##N = _mm512_i64gather_pd(vindex_n, &C[(j + N*8)*ldc + i + M], 8); \
|
||||||
|
@ -303,7 +303,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
|
||||||
}
|
}
|
||||||
int mm = M - i;
|
int mm = M - i;
|
||||||
if (mm >= 6) {
|
if (mm >= 6) {
|
||||||
register __mmask16 mask asm("k1") = (1UL << mm) - 1;
|
register __mmask16 mask = (1UL << mm) - 1;
|
||||||
for (j = 0; j < n8; j += 8) {
|
for (j = 0; j < n8; j += 8) {
|
||||||
DECLARE_RESULT_512(0, 0);
|
DECLARE_RESULT_512(0, 0);
|
||||||
DECLARE_RESULT_512(0, 1);
|
DECLARE_RESULT_512(0, 1);
|
||||||
|
|
|
@ -48,7 +48,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
|
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
|
||||||
#define MASK_STORE_512(M, N) \
|
#define MASK_STORE_512(M, N) \
|
||||||
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
||||||
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \
|
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "Yk"(mask)); \
|
||||||
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
|
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -267,7 +267,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
|
||||||
int mm = M - i;
|
int mm = M - i;
|
||||||
if (!mm) return 0;
|
if (!mm) return 0;
|
||||||
if (mm > 8 || K < 32) {
|
if (mm > 8 || K < 32) {
|
||||||
register __mmask16 mask asm("k1") = (1UL << mm) - 1;
|
register __mmask16 mask = (1UL << mm) - 1;
|
||||||
for (j = 0; j < n6; j += 6) {
|
for (j = 0; j < n6; j += 6) {
|
||||||
DECLARE_RESULT_512(0, 0);
|
DECLARE_RESULT_512(0, 0);
|
||||||
DECLARE_RESULT_512(0, 1);
|
DECLARE_RESULT_512(0, 1);
|
||||||
|
|
|
@ -55,7 +55,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
|
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
|
||||||
#define MASK_STORE_512(M, N) \
|
#define MASK_STORE_512(M, N) \
|
||||||
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
||||||
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \
|
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "Yk"(mask)); \
|
||||||
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
|
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
|
||||||
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
|
||||||
__m512 tmp##M##N = _mm512_i32gather_ps(vindex_n, &C[(j + N*16)*ldc + i + M], 4); \
|
__m512 tmp##M##N = _mm512_i32gather_ps(vindex_n, &C[(j + N*16)*ldc + i + M], 4); \
|
||||||
|
@ -303,7 +303,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
|
||||||
}
|
}
|
||||||
int mm = M - i;
|
int mm = M - i;
|
||||||
if (mm >= 12) {
|
if (mm >= 12) {
|
||||||
register __mmask16 mask asm("k1") = (1UL << mm) - 1;
|
register __mmask16 mask = (1UL << mm) - 1;
|
||||||
for (j = 0; j < n8; j += 8) {
|
for (j = 0; j < n8; j += 8) {
|
||||||
DECLARE_RESULT_512(0, 0);
|
DECLARE_RESULT_512(0, 0);
|
||||||
DECLARE_RESULT_512(0, 1);
|
DECLARE_RESULT_512(0, 1);
|
||||||
|
|
Loading…
Reference in New Issue