various code cleanups and comments

Arjan van de Ven 2018-08-05 02:44:40 +00:00
parent f2810beafb
commit 9c29524f50
2 changed files with 72 additions and 61 deletions

View File

@@ -25,7 +25,7 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
-/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already*/
+/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */
 #ifndef __AVX512CD_
 #pragma GCC target("avx2,fma")
 #endif
@@ -68,17 +68,13 @@ static void dgemv_kernel_4x4( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT
 	__m512d tempY;
 	__m512d sum;
-	sum = _mm512_add_pd(
-			_mm512_add_pd(
-				_mm512_mul_pd(_mm512_loadu_pd(&ap[0][i]), x05),
-				_mm512_mul_pd(_mm512_loadu_pd(&ap[1][i]), x15)),
-			_mm512_add_pd(
-				_mm512_mul_pd(_mm512_loadu_pd(&ap[2][i]), x25),
-				_mm512_mul_pd(_mm512_loadu_pd(&ap[3][i]), x35))
-		);
+	sum = _mm512_loadu_pd(&ap[0][i]) * x05 +
+	      _mm512_loadu_pd(&ap[1][i]) * x15 +
+	      _mm512_loadu_pd(&ap[2][i]) * x25 +
+	      _mm512_loadu_pd(&ap[3][i]) * x35;
 	tempY = _mm512_loadu_pd(&y[i]);
-	tempY = _mm512_add_pd(tempY, _mm512_mul_pd(sum, __alpha5));
+	tempY += sum * __alpha5;
 	_mm512_storeu_pd(&y[i], tempY);
 }
 #endif
@@ -87,17 +83,13 @@ static void dgemv_kernel_4x4( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT
 	__m256d tempY;
 	__m256d sum;
-	sum = _mm256_add_pd(
-			_mm256_add_pd(
-				_mm256_mul_pd(_mm256_loadu_pd(&ap[0][i]), x0),
-				_mm256_mul_pd(_mm256_loadu_pd(&ap[1][i]), x1)),
-			_mm256_add_pd(
-				_mm256_mul_pd(_mm256_loadu_pd(&ap[2][i]), x2),
-				_mm256_mul_pd(_mm256_loadu_pd(&ap[3][i]), x3))
-		);
+	sum = _mm256_loadu_pd(&ap[0][i]) * x0 +
+	      _mm256_loadu_pd(&ap[1][i]) * x1 +
+	      _mm256_loadu_pd(&ap[2][i]) * x2 +
+	      _mm256_loadu_pd(&ap[3][i]) * x3;
 	tempY = _mm256_loadu_pd(&y[i]);
-	tempY = _mm256_add_pd(tempY, _mm256_mul_pd(sum, __alpha));
+	tempY += sum * __alpha;
 	_mm256_storeu_pd(&y[i], tempY);
 }
@@ -124,13 +116,10 @@ static void dgemv_kernel_4x2( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT
 	__m256d tempY;
 	__m256d sum;
-	sum = _mm256_add_pd(
-		_mm256_mul_pd(_mm256_loadu_pd(&ap[0][i]), x0),
-		_mm256_mul_pd(_mm256_loadu_pd(&ap[1][i]), x1)
-	);
+	sum = _mm256_loadu_pd(&ap[0][i]) * x0 + _mm256_loadu_pd(&ap[1][i]) * x1;
 	tempY = _mm256_loadu_pd(&y[i]);
-	tempY = _mm256_add_pd(tempY, _mm256_mul_pd(sum, __alpha));
+	tempY += sum * __alpha;
 	_mm256_storeu_pd(&y[i], tempY);
 }
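
Aside (illustrative, not part of the commit): the rewrites above rely on the fact that, with GCC and Clang, __m256d and __m512d are vector-extension types, so `*`, `+` and `+=` on them generate the same multiply/add (and, with FMA enabled, possibly fused) instructions as the explicit _mm256_mul_pd/_mm256_add_pd calls. A minimal standalone sketch of that equivalence, assuming gcc or clang with -mavx2 -mfma; the function names are made up for the example:

#include <immintrin.h>
#include <stdio.h>

/* y += alpha * (a * x) for 4 doubles, spelled with explicit intrinsics */
static void axpy4_intrinsics(const double *a, const double *x, double *y, double alpha)
{
	__m256d valpha = _mm256_set1_pd(alpha);
	__m256d sum    = _mm256_mul_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(x));
	__m256d vy     = _mm256_loadu_pd(y);

	vy = _mm256_add_pd(vy, _mm256_mul_pd(sum, valpha));
	_mm256_storeu_pd(y, vy);
}

/* the same computation using the GCC/Clang vector operators, as in the cleaned-up kernels */
static void axpy4_operators(const double *a, const double *x, double *y, double alpha)
{
	__m256d valpha = _mm256_set1_pd(alpha);
	__m256d sum    = _mm256_loadu_pd(a) * _mm256_loadu_pd(x);
	__m256d vy     = _mm256_loadu_pd(y);

	vy += sum * valpha;
	_mm256_storeu_pd(y, vy);
}

int main(void)
{
	double a[4] = {1, 2, 3, 4}, x[4] = {5, 6, 7, 8};
	double y1[4] = {0}, y2[4] = {0};

	axpy4_intrinsics(a, x, y1, 2.0);
	axpy4_operators(a, x, y2, 2.0);
	for (int i = 0; i < 4; i++)
		printf("%g %g\n", y1[i], y2[i]);   /* both columns print 10 24 42 64 */
	return 0;
}

The operator form is purely a readability change; note that it is a GCC/Clang extension and does not compile with MSVC.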

View File

@@ -25,6 +25,14 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
+/* Ensure that the compiler knows how to generate AVX2 instructions if it doesn't already */
+#ifndef __AVX512CD_
+#pragma GCC target("avx2,fma")
+#endif
+
+#ifdef __AVX2__
 #include <immintrin.h>
 #define HAVE_KERNEL_4x4 1
@@ -33,13 +41,14 @@ static void dsymv_kernel_4x4(BLASLONG from, BLASLONG to, FLOAT **a, FLOAT *x, FL
 {
-	__m256d temp2_0, temp2_1, temp2_2, temp2_3; // temp2_0 temp2_1 temp2_2 temp2_3
+	__m256d accum_0, accum_1, accum_2, accum_3;
 	__m256d temp1_0, temp1_1, temp1_2, temp1_3;
-	temp2_0 = _mm256_setzero_pd();
-	temp2_1 = _mm256_setzero_pd();
-	temp2_2 = _mm256_setzero_pd();
-	temp2_3 = _mm256_setzero_pd();
+	/* the 256 bit wide accumulator vectors start out as zero */
+	accum_0 = _mm256_setzero_pd();
+	accum_1 = _mm256_setzero_pd();
+	accum_2 = _mm256_setzero_pd();
+	accum_3 = _mm256_setzero_pd();
 	temp1_0 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[0]));
 	temp1_1 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[1]));
@@ -47,15 +56,16 @@ static void dsymv_kernel_4x4(BLASLONG from, BLASLONG to, FLOAT **a, FLOAT *x, FL
 	temp1_3 = _mm256_broadcastsd_pd(_mm_load_sd(&temp1[3]));
 #ifdef __AVX512CD__
-	__m512d temp2_05, temp2_15, temp2_25, temp2_35; // temp2_0 temp2_1 temp2_2 temp2_3
+	__m512d accum_05, accum_15, accum_25, accum_35;
 	__m512d temp1_05, temp1_15, temp1_25, temp1_35;
 	BLASLONG to2;
 	int delta;
-	temp2_05 = _mm512_setzero_pd();
-	temp2_15 = _mm512_setzero_pd();
-	temp2_25 = _mm512_setzero_pd();
-	temp2_35 = _mm512_setzero_pd();
+	/* the 512 bit wide accumulator vectors start out as zero */
+	accum_05 = _mm512_setzero_pd();
+	accum_15 = _mm512_setzero_pd();
+	accum_25 = _mm512_setzero_pd();
+	accum_35 = _mm512_setzero_pd();
 	temp1_05 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[0]));
 	temp1_15 = _mm512_broadcastsd_pd(_mm_load_sd(&temp1[1]));
@@ -80,19 +90,23 @@ static void dsymv_kernel_4x4(BLASLONG from, BLASLONG to, FLOAT **a, FLOAT *x, FL
 		_y += temp1_05 * a0 + temp1_15 * a1 + temp1_25 * a2 + temp1_35 * a3;
-		temp2_05 += _x * a0;
-		temp2_15 += _x * a1;
-		temp2_25 += _x * a2;
-		temp2_35 += _x * a3;
+		accum_05 += _x * a0;
+		accum_15 += _x * a1;
+		accum_25 += _x * a2;
+		accum_35 += _x * a3;
 		_mm512_storeu_pd(&y[from], _y);
 	};
-	temp2_0 = _mm256_add_pd(_mm512_extractf64x4_pd(temp2_05, 0), _mm512_extractf64x4_pd(temp2_05, 1));
-	temp2_1 = _mm256_add_pd(_mm512_extractf64x4_pd(temp2_15, 0), _mm512_extractf64x4_pd(temp2_15, 1));
-	temp2_2 = _mm256_add_pd(_mm512_extractf64x4_pd(temp2_25, 0), _mm512_extractf64x4_pd(temp2_25, 1));
-	temp2_3 = _mm256_add_pd(_mm512_extractf64x4_pd(temp2_35, 0), _mm512_extractf64x4_pd(temp2_35, 1));
+	/*
+	 * we need to fold our 512 bit wide accumulator vectors into 256 bit wide vectors so that the AVX2 code
+	 * below can continue using the intermediate results in its loop
+	 */
+	accum_0 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_05, 0), _mm512_extractf64x4_pd(accum_05, 1));
+	accum_1 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_15, 0), _mm512_extractf64x4_pd(accum_15, 1));
+	accum_2 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_25, 0), _mm512_extractf64x4_pd(accum_25, 1));
+	accum_3 = _mm256_add_pd(_mm512_extractf64x4_pd(accum_35, 0), _mm512_extractf64x4_pd(accum_35, 1));
 #endif
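
Aside (illustrative, not part of the commit): _mm512_extractf64x4_pd pulls out the lower (index 0) or upper (index 1) 256-bit half of a __m512d, so the pair of extracts plus one _mm256_add_pd folds eight partial sums into four, which is exactly what the hunk above does per accumulator. A self-contained sketch, assuming an AVX-512 capable compiler and CPU (e.g. gcc -mavx512f); the variable names are made up for the example:

#include <immintrin.h>
#include <stdio.h>

int main(void)
{
	double data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
	__m512d wide = _mm512_loadu_pd(data);

	/* lane 0 holds elements 0..3, lane 1 holds elements 4..7 */
	__m256d folded = _mm256_add_pd(_mm512_extractf64x4_pd(wide, 0),
	                               _mm512_extractf64x4_pd(wide, 1));

	double out[4];
	_mm256_storeu_pd(out, folded);
	printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  /* prints 6 8 10 12 */
	return 0;
}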
@@ -103,6 +117,7 @@ static void dsymv_kernel_4x4(BLASLONG from, BLASLONG to, FLOAT **a, FLOAT *x, FL
 		_y = _mm256_loadu_pd(&y[from]);
 		_x = _mm256_loadu_pd(&x[from]);
+		/* load 4 rows of matrix data */
 		a0 = _mm256_loadu_pd(&a[0][from]);
 		a1 = _mm256_loadu_pd(&a[1][from]);
 		a2 = _mm256_loadu_pd(&a[2][from]);
@@ -110,33 +125,40 @@ static void dsymv_kernel_4x4(BLASLONG from, BLASLONG to, FLOAT **a, FLOAT *x, FL
 		_y += temp1_0 * a0 + temp1_1 * a1 + temp1_2 * a2 + temp1_3 * a3;
-		temp2_0 += _x * a0;
-		temp2_1 += _x * a1;
-		temp2_2 += _x * a2;
-		temp2_3 += _x * a3;
+		accum_0 += _x * a0;
+		accum_1 += _x * a1;
+		accum_2 += _x * a2;
+		accum_3 += _x * a3;
 		_mm256_storeu_pd(&y[from], _y);
 	};
-	__m128d xmm0, xmm1, xmm2, xmm3;
+	/*
+	 * we now have 4 accumulator vectors. Each vector needs to be summed up element wise and stored in the temp2
+	 * output array. There is no direct instruction for this in 256 bit space, only in 128 space.
+	 */
+	__m128d half_accum0, half_accum1, half_accum2, half_accum3;
-	xmm0 = _mm_add_pd(_mm256_extractf128_pd(temp2_0, 0), _mm256_extractf128_pd(temp2_0, 1));
-	xmm1 = _mm_add_pd(_mm256_extractf128_pd(temp2_1, 0), _mm256_extractf128_pd(temp2_1, 1));
-	xmm2 = _mm_add_pd(_mm256_extractf128_pd(temp2_2, 0), _mm256_extractf128_pd(temp2_2, 1));
-	xmm3 = _mm_add_pd(_mm256_extractf128_pd(temp2_3, 0), _mm256_extractf128_pd(temp2_3, 1));
+	/* Add upper half to lower half of each of the four 256 bit vectors to get to four 128 bit vectors */
+	half_accum0 = _mm_add_pd(_mm256_extractf128_pd(accum_0, 0), _mm256_extractf128_pd(accum_0, 1));
+	half_accum1 = _mm_add_pd(_mm256_extractf128_pd(accum_1, 0), _mm256_extractf128_pd(accum_1, 1));
+	half_accum2 = _mm_add_pd(_mm256_extractf128_pd(accum_2, 0), _mm256_extractf128_pd(accum_2, 1));
+	half_accum3 = _mm_add_pd(_mm256_extractf128_pd(accum_3, 0), _mm256_extractf128_pd(accum_3, 1));
-	xmm0 = _mm_hadd_pd(xmm0, xmm0);
-	xmm1 = _mm_hadd_pd(xmm1, xmm1);
-	xmm2 = _mm_hadd_pd(xmm2, xmm2);
-	xmm3 = _mm_hadd_pd(xmm3, xmm3);
+	/* in 128 bit land there is a hadd operation to do the rest of the element-wise sum in one go */
+	half_accum0 = _mm_hadd_pd(half_accum0, half_accum0);
+	half_accum1 = _mm_hadd_pd(half_accum1, half_accum1);
+	half_accum2 = _mm_hadd_pd(half_accum2, half_accum2);
+	half_accum3 = _mm_hadd_pd(half_accum3, half_accum3);
-	temp2[0] += xmm0[0];
-	temp2[1] += xmm1[0];
-	temp2[2] += xmm2[0];
-	temp2[3] += xmm3[0];
+	/* and store the lowest double value from each of these vectors in the temp2 output */
+	temp2[0] += half_accum0[0];
+	temp2[1] += half_accum1[0];
+	temp2[2] += half_accum2[0];
+	temp2[3] += half_accum3[0];
 }
+#endif
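
Aside (illustrative, not from the commit): the final reduction in the hunk above is a standard horizontal sum of a __m256d: add the upper 128-bit half onto the lower half, then _mm_hadd_pd adds the two remaining doubles in one instruction, leaving the total in element 0. A minimal sketch of that pattern, assuming gcc or clang with -mavx; the helper name is made up:

#include <immintrin.h>
#include <stdio.h>

static double hsum256(__m256d v)
{
	/* fold {v0,v1,v2,v3} down to {v0+v2, v1+v3} */
	__m128d half = _mm_add_pd(_mm256_extractf128_pd(v, 0),
	                          _mm256_extractf128_pd(v, 1));
	/* horizontal add: element 0 becomes (v0+v2) + (v1+v3) */
	half = _mm_hadd_pd(half, half);
	/* element access via GCC/Clang vector indexing, as in the kernel itself */
	return half[0];
}

int main(void)
{
	__m256d v = _mm256_set_pd(4.0, 3.0, 2.0, 1.0);  /* elements, low to high: 1 2 3 4 */
	printf("%g\n", hsum256(v));  /* prints 10 */
	return 0;
}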