Add a 24x8 kernel to the skylakex dgemm implementation
Minor gains for small matrixes, but at 512x512 and above the gain gets more significant.
This commit is contained in:
parent
1938819c25
commit
66b43affbc
|
@ -849,6 +849,207 @@ CNAME(BLASLONG m, BLASLONG n, BLASLONG k, double alpha, double * __restrict__ A,
|
||||||
|
|
||||||
i = m;
|
i = m;
|
||||||
|
|
||||||
|
while (i >= 24) {
|
||||||
|
double *BO;
|
||||||
|
double *A1, *A2;
|
||||||
|
int kloop = K;
|
||||||
|
|
||||||
|
BO = B + 12;
|
||||||
|
A1 = AO + 8 * K;
|
||||||
|
A2 = AO + 16 * K;
|
||||||
|
/*
|
||||||
|
* This is the inner loop for the hot hot path
|
||||||
|
* Written in inline asm because compilers like GCC 8 and earlier
|
||||||
|
* struggle with register allocation and are not good at using
|
||||||
|
* the AVX512 built in broadcast ability (1to8)
|
||||||
|
*/
|
||||||
|
asm(
|
||||||
|
"vxorpd %%zmm1, %%zmm1, %%zmm1\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm2\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm3\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm4\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm5\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm6\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm7\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm8\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm11\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm12\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm13\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm14\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm15\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm16\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm17\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm18\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm21\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm22\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm23\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm24\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm25\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm26\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm27\n"
|
||||||
|
"vmovapd %%zmm1, %%zmm28\n"
|
||||||
|
"jmp .label24\n"
|
||||||
|
".align 32\n"
|
||||||
|
/* Inner math loop */
|
||||||
|
".label24:\n"
|
||||||
|
"vmovupd -128(%[AO]),%%zmm0\n"
|
||||||
|
"vmovupd -128(%[A1]),%%zmm10\n"
|
||||||
|
"vmovupd -128(%[A2]),%%zmm20\n"
|
||||||
|
|
||||||
|
"vbroadcastsd -96(%[BO]), %%zmm9\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm0, %%zmm1\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm10, %%zmm11\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm20, %%zmm21\n"
|
||||||
|
|
||||||
|
"vbroadcastsd -88(%[BO]), %%zmm9\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm0, %%zmm2\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm10, %%zmm12\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm20, %%zmm22\n"
|
||||||
|
|
||||||
|
"vbroadcastsd -80(%[BO]), %%zmm9\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm0, %%zmm3\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm10, %%zmm13\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm20, %%zmm23\n"
|
||||||
|
|
||||||
|
"vbroadcastsd -72(%[BO]), %%zmm9\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm0, %%zmm4\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm10, %%zmm14\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm20, %%zmm24\n"
|
||||||
|
|
||||||
|
"vbroadcastsd -64(%[BO]), %%zmm9\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm0, %%zmm5\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm10, %%zmm15\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm20, %%zmm25\n"
|
||||||
|
|
||||||
|
"vbroadcastsd -56(%[BO]), %%zmm9\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm0, %%zmm6\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm10, %%zmm16\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm20, %%zmm26\n"
|
||||||
|
|
||||||
|
"vbroadcastsd -48(%[BO]), %%zmm9\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm0, %%zmm7\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm10, %%zmm17\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm20, %%zmm27\n"
|
||||||
|
|
||||||
|
"vbroadcastsd -40(%[BO]), %%zmm9\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm0, %%zmm8\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm10, %%zmm18\n"
|
||||||
|
"vfmadd231pd %%zmm9, %%zmm20, %%zmm28\n"
|
||||||
|
"add $64, %[AO]\n"
|
||||||
|
"add $64, %[A1]\n"
|
||||||
|
"add $64, %[A2]\n"
|
||||||
|
"add $64, %[BO]\n"
|
||||||
|
"prefetch 512(%[AO])\n"
|
||||||
|
"prefetch 512(%[A1])\n"
|
||||||
|
"prefetch 512(%[A2])\n"
|
||||||
|
"prefetch 512(%[BO])\n"
|
||||||
|
"subl $1, %[kloop]\n"
|
||||||
|
"jg .label24\n"
|
||||||
|
/* multiply the result by alpha */
|
||||||
|
"vbroadcastsd (%[alpha]), %%zmm9\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm1, %%zmm1\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm2, %%zmm2\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm3, %%zmm3\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm4, %%zmm4\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm5, %%zmm5\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm6, %%zmm6\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm7, %%zmm7\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm8, %%zmm8\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm11, %%zmm11\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm12, %%zmm12\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm13, %%zmm13\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm14, %%zmm14\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm15, %%zmm15\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm16, %%zmm16\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm17, %%zmm17\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm18, %%zmm18\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm21, %%zmm21\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm22, %%zmm22\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm23, %%zmm23\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm24, %%zmm24\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm25, %%zmm25\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm26, %%zmm26\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm27, %%zmm27\n"
|
||||||
|
"vmulpd %%zmm9, %%zmm28, %%zmm28\n"
|
||||||
|
/* And store additively in C */
|
||||||
|
"vaddpd (%[C0]), %%zmm1, %%zmm1\n"
|
||||||
|
"vaddpd (%[C1]), %%zmm2, %%zmm2\n"
|
||||||
|
"vaddpd (%[C2]), %%zmm3, %%zmm3\n"
|
||||||
|
"vaddpd (%[C3]), %%zmm4, %%zmm4\n"
|
||||||
|
"vaddpd (%[C4]), %%zmm5, %%zmm5\n"
|
||||||
|
"vaddpd (%[C5]), %%zmm6, %%zmm6\n"
|
||||||
|
"vaddpd (%[C6]), %%zmm7, %%zmm7\n"
|
||||||
|
"vaddpd (%[C7]), %%zmm8, %%zmm8\n"
|
||||||
|
"vmovupd %%zmm1, (%[C0])\n"
|
||||||
|
"vmovupd %%zmm2, (%[C1])\n"
|
||||||
|
"vmovupd %%zmm3, (%[C2])\n"
|
||||||
|
"vmovupd %%zmm4, (%[C3])\n"
|
||||||
|
"vmovupd %%zmm5, (%[C4])\n"
|
||||||
|
"vmovupd %%zmm6, (%[C5])\n"
|
||||||
|
"vmovupd %%zmm7, (%[C6])\n"
|
||||||
|
"vmovupd %%zmm8, (%[C7])\n"
|
||||||
|
|
||||||
|
"vaddpd 64(%[C0]), %%zmm11, %%zmm11\n"
|
||||||
|
"vaddpd 64(%[C1]), %%zmm12, %%zmm12\n"
|
||||||
|
"vaddpd 64(%[C2]), %%zmm13, %%zmm13\n"
|
||||||
|
"vaddpd 64(%[C3]), %%zmm14, %%zmm14\n"
|
||||||
|
"vaddpd 64(%[C4]), %%zmm15, %%zmm15\n"
|
||||||
|
"vaddpd 64(%[C5]), %%zmm16, %%zmm16\n"
|
||||||
|
"vaddpd 64(%[C6]), %%zmm17, %%zmm17\n"
|
||||||
|
"vaddpd 64(%[C7]), %%zmm18, %%zmm18\n"
|
||||||
|
"vmovupd %%zmm11, 64(%[C0])\n"
|
||||||
|
"vmovupd %%zmm12, 64(%[C1])\n"
|
||||||
|
"vmovupd %%zmm13, 64(%[C2])\n"
|
||||||
|
"vmovupd %%zmm14, 64(%[C3])\n"
|
||||||
|
"vmovupd %%zmm15, 64(%[C4])\n"
|
||||||
|
"vmovupd %%zmm16, 64(%[C5])\n"
|
||||||
|
"vmovupd %%zmm17, 64(%[C6])\n"
|
||||||
|
"vmovupd %%zmm18, 64(%[C7])\n"
|
||||||
|
|
||||||
|
"vaddpd 128(%[C0]), %%zmm21, %%zmm21\n"
|
||||||
|
"vaddpd 128(%[C1]), %%zmm22, %%zmm22\n"
|
||||||
|
"vaddpd 128(%[C2]), %%zmm23, %%zmm23\n"
|
||||||
|
"vaddpd 128(%[C3]), %%zmm24, %%zmm24\n"
|
||||||
|
"vaddpd 128(%[C4]), %%zmm25, %%zmm25\n"
|
||||||
|
"vaddpd 128(%[C5]), %%zmm26, %%zmm26\n"
|
||||||
|
"vaddpd 128(%[C6]), %%zmm27, %%zmm27\n"
|
||||||
|
"vaddpd 128(%[C7]), %%zmm28, %%zmm28\n"
|
||||||
|
"vmovupd %%zmm21, 128(%[C0])\n"
|
||||||
|
"vmovupd %%zmm22, 128(%[C1])\n"
|
||||||
|
"vmovupd %%zmm23, 128(%[C2])\n"
|
||||||
|
"vmovupd %%zmm24, 128(%[C3])\n"
|
||||||
|
"vmovupd %%zmm25, 128(%[C4])\n"
|
||||||
|
"vmovupd %%zmm26, 128(%[C5])\n"
|
||||||
|
"vmovupd %%zmm27, 128(%[C6])\n"
|
||||||
|
"vmovupd %%zmm28, 128(%[C7])\n"
|
||||||
|
|
||||||
|
:
|
||||||
|
[AO] "+r" (AO),
|
||||||
|
[A1] "+r" (A1),
|
||||||
|
[A2] "+r" (A2),
|
||||||
|
[BO] "+r" (BO),
|
||||||
|
[C0] "+r" (CO1),
|
||||||
|
[kloop] "+r" (kloop)
|
||||||
|
:
|
||||||
|
[alpha] "r" (&alpha),
|
||||||
|
[C1] "r" (CO1 + 1 * ldc),
|
||||||
|
[C2] "r" (CO1 + 2 * ldc),
|
||||||
|
[C3] "r" (CO1 + 3 * ldc),
|
||||||
|
[C4] "r" (CO1 + 4 * ldc),
|
||||||
|
[C5] "r" (CO1 + 5 * ldc),
|
||||||
|
[C6] "r" (CO1 + 6 * ldc),
|
||||||
|
[C7] "r" (CO1 + 7 * ldc)
|
||||||
|
|
||||||
|
: "memory", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9",
|
||||||
|
"zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18",
|
||||||
|
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28"
|
||||||
|
);
|
||||||
|
CO1 += 24;
|
||||||
|
AO += 16 * K;
|
||||||
|
i-= 24;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
while (i >= 16) {
|
while (i >= 16) {
|
||||||
double *BO;
|
double *BO;
|
||||||
double *A1;
|
double *A1;
|
||||||
|
|
Loading…
Reference in New Issue