diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 155e0681aa44bf6931a1827be571529a5df671f1..b1373ee894adf0342bd83b65b9cf75f6f1692ed3 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -2143,6 +2143,18 @@ intel-double-precision-mpi-noopenmp-ftimings-redirect-real-avx512_block2-complex - export LD_LIBRARY_PATH=$MKL_HOME/lib/intel64:$LD_LIBRARY_PATH - make check TEST_FLAGS='1000 500 128' +intel-single-precision-mpi-noopenmp-ftimings-redirect-real-avx512_block2-complex-avx512_block1-kernel-jobs: + tags: + - KNL + script: + - ./autogen.sh + - ./configure FC=mpiifort CC=mpiicc CFLAGS="-O3 -mtune=knl -axMIC-AVX512" FCFLAGS="-O3 -mtune=knl -axMIC-AVX512" SCALAPACK_FCFLAGS="-L$MKLROOT/lib/intel64 -lmkl_scalapack_lp64 -lmkl_intel_lp64 -lmkl_sequential -lmkl_core -lmkl_blacs_intelmpi_lp64 -lpthread -lm -I$MKLROOT/include/intel64/lp64" SCALAPACK_LDFLAGS="-L$MKLROOT/lib/intel64 -lmkl_scalapack_lp64 -lmkl_intel_lp64 -lmkl_sequential -lmkl_core -lmkl_blacs_intelmpi_lp64 -lpthread -lm -Wl,-rpath,$MKLROOT/lib/intel64" --with-real-avx512_block2-kernel-only --with-complex-avx512_block1-kernel-only --enable-single-precision + - /home/elpa/wait_until_midnight.sh + - make -j 8 + - export OMP_NUM_THREADS=1 + - export LD_LIBRARY_PATH=$MKL_HOME/lib/intel64:$LD_LIBRARY_PATH + - make check TEST_FLAGS='1000 500 128' + intel-double-precision-mpi-noopenmp-ftimings-redirect-real-avx512_block4-complex-avx512_block2-kernel-jobs: tags: - KNL diff --git a/Makefile.am b/Makefile.am index 9c37cb5d6266fc61bc45a7371a4e4cc2feb0b49a..5832247bcd213058d775c74e0d435aa947aefeb6 100644 --- a/Makefile.am +++ b/Makefile.am @@ -145,9 +145,9 @@ endif if WITH_REAL_AVX512_BLOCK2_KERNEL libelpa@SUFFIX@_private_la_SOURCES += src/elpa2_kernels/elpa2_kernels_real_avx512_2hv_double_precision.c -#if WANT_SINGLE_PRECISION_REAL -# libelpa@SUFFIX@_private_la_SOURCES += src/elpa2_kernels/elpa2_kernels_real_avx512_2hv_single_precision.c -#endif +if WANT_SINGLE_PRECISION_REAL + libelpa@SUFFIX@_private_la_SOURCES += src/elpa2_kernels/elpa2_kernels_real_avx512_2hv_single_precision.c +endif endif diff --git a/src/elpa2_kernels/elpa2_kernels_real_avx512_2hv_single_precision.c b/src/elpa2_kernels/elpa2_kernels_real_avx512_2hv_single_precision.c index 50c19b62dc339d10dccccdda31c00b61c30ab2a1..2f6a894fa852ee472f8fef168ad0e3e4399e96db 100644 --- a/src/elpa2_kernels/elpa2_kernels_real_avx512_2hv_single_precision.c +++ b/src/elpa2_kernels/elpa2_kernels_real_avx512_2hv_single_precision.c @@ -42,39 +42,32 @@ // any derivatives of ELPA under the same license that we chose for // the original distribution, the GNU Lesser General Public License. // -// Author: Andreas Marek, MPCDF, based on the double precision case of A. 
Heinecke -// +// Author: Andreas Marek (andreas.marek@mpcdf.mpg.de) +// -------------------------------------------------------------------------------------------------- + #include "config-f90.h" + #include #define __forceinline __attribute__((always_inline)) static -#ifdef HAVE_AVX2 - -#ifdef __FMA4__ +#ifdef HAVE_AVX512 #define __ELPA_USE_FMA__ -#define _mm256_FMA_ps(a,b,c) _mm256_macc_ps(a,b,c) +#define _mm512_FMA_ps(a,b,c) _mm512_fmadd_ps(a,b,c) #endif -#ifdef __AVX2__ -#define __ELPA_USE_FMA__ -#define _mm256_FMA_ps(a,b,c) _mm256_fmadd_ps(a,b,c) -#endif - -#endif //Forward declaration -// 4 rows single presision does not work in AVX since it cannot be 32 aligned use sse instead -__forceinline void hh_trafo_kernel_4_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s); -//__forceinline void hh_trafo_kernel_4_sse_instead_of_avx512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s); -__forceinline void hh_trafo_kernel_8_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s); +//__forceinline void hh_trafo_kernel_8_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s); __forceinline void hh_trafo_kernel_16_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s); -__forceinline void hh_trafo_kernel_24_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s); +//__forceinline void hh_trafo_kernel_24_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s); +__forceinline void hh_trafo_kernel_32_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s); +__forceinline void hh_trafo_kernel_48_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s); +__forceinline void hh_trafo_kernel_64_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s); void double_hh_trafo_real_avx512_2hv_single(float* q, float* hh, int* pnb, int* pnq, int* pldq, int* pldh); - /* -!f>#ifdef HAVE_AVX512 +!f>#if defined(HAVE_AVX512) !f> interface !f> subroutine double_hh_trafo_real_avx512_2hv_single(q, hh, pnb, pnq, pldq, pldh) & !f> bind(C, name="double_hh_trafo_real_avx512_2hv_single") @@ -97,14 +90,8 @@ void double_hh_trafo_real_avx512_2hv_single(float* q, float* hh, int* pnb, int* // calculating scalar product to compute // 2 householder vectors simultaneously - // - // Fortran: - // s = hh(2,2)*1 float s = hh[(ldh)+1]*1.0; - // FORTRAN: - // do = 3, nb - // s =s + hh(i,2)*hh(i-1,1) #pragma ivdep for (i = 2; i < nb; i++) { @@ -112,51 +99,37 @@ void double_hh_trafo_real_avx512_2hv_single(float* q, float* hh, int* pnb, int* } // Production level kernel calls with padding - for (i = 0; i < nq-20; i+=24) + for (i = 0; i < nq-48; i+=64) { - hh_trafo_kernel_24_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s); + hh_trafo_kernel_64_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s); } + if (nq == i) { return; } - if (nq-i == 20) - { - hh_trafo_kernel_16_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s); - hh_trafo_kernel_4_AVX512_2hv_single(&q[i+16], hh, nb, ldq, ldh, s); -// hh_trafo_kernel_4_sse_instead_of_avx512_2hv_single(&q[i+8], hh, nb, ldq, ldh, s); - } - else if (nq-i == 16) - { - hh_trafo_kernel_16_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s); - } - else if (nq-i == 12) + if (nq-i == 48) { - hh_trafo_kernel_8_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s); - hh_trafo_kernel_4_AVX512_2hv_single(&q[i+8], hh, nb, ldq, ldh, s); -// hh_trafo_kernel_4_sse_instead_of_avx512_2hv_single(&q[i+8], hh, nb, ldq, ldh, s); + 
hh_trafo_kernel_48_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s); } - else if (nq-i == 8) + else if (nq-i == 32) { - hh_trafo_kernel_8_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s); + hh_trafo_kernel_32_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s); } + else { - hh_trafo_kernel_4_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s); -// hh_trafo_kernel_4_sse_instead_of_avx512_2hv_single(&q[i], hh, nb, ldq, ldh, s); - + hh_trafo_kernel_16_AVX512_2hv_single(&q[i], hh, nb, ldq, ldh, s); } } - - /** * Unrolled kernel that computes - * 24 rows of Q simultaneously, a + * 64 rows of Q simultaneously, a * matrix vector product with two householder * vectors + a rank 2 update is performed */ - __forceinline void hh_trafo_kernel_24_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s) + __forceinline void hh_trafo_kernel_64_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s) { ///////////////////////////////////////////////////// // Matrix Vector Multiplication, Q [24 x nb+1] * hh @@ -164,920 +137,619 @@ void double_hh_trafo_real_avx512_2hv_single(float* q, float* hh, int* pnb, int* ///////////////////////////////////////////////////// int i; // Needed bit mask for floating point sign flip - __m256 sign = (__m256)_mm256_set1_epi32(0x80000000); - - __m256 x1 = _mm256_load_ps(&q[ldq]); //load q(1,2), q(2,2), q(3,2),q(4,2), q(5,2), q(6,2), q(7,2), q(8,2) - - __m256 x2 = _mm256_load_ps(&q[ldq+8]); // load q(9,2) ... q(16,2) - - __m256 x3 = _mm256_load_ps(&q[ldq+16]); // load q(17,2) .. q(24,2) -// __m256 x4 = _mm256_load_ps(&q[ldq+12]); -// __m256 x5 = _mm256_load_ps(&q[ldq+16]); -// __m256 x6 = _mm256_load_ps(&q[ldq+20]); - - __m256 h1 = _mm256_broadcast_ss(&hh[ldh+1]); // h1 = hh(2,2) | hh(2,2) | hh(2,2) | hh(2,2) | hh(2,2) | hh(2,2) | hh(2,2) | hh(2,2) - __m256 h2; - -#ifdef __ELPA_USE_FMA__ - __m256 q1 = _mm256_load_ps(q); // q1 = q(1,1), q(2,1), q(3,1), q(4,1), q(5,1), q(6,1), q(7,1), q(8,1) - __m256 y1 = _mm256_FMA_ps(x1, h1, q1); // y1 = q(1,2) * h(2,2) + q(1,1) | q(2,2) * h(2,2) + q(2,1) | .... | q(8,2) * h(2,2) + q(8,1) - __m256 q2 = _mm256_load_ps(&q[8]); // q2 = q(9,1) | .... | q(16,1) - __m256 y2 = _mm256_FMA_ps(x2, h1, q2); // y2 = q(9,2) * hh(2,2) + q(9,1) | ... | q(16,2) * h(2,2) + q(16,1) - __m256 q3 = _mm256_load_ps(&q[16]); // q3 = q(17,1) | ... | q(24,1) - __m256 y3 = _mm256_FMA_ps(x3, h1, q3); // y3 = q(17,2) * hh(2,2) + q(17,1) ... | q(24,2) * hh(2,2) + q(24,1) -// __m256 q4 = _mm256_load_ps(&q[12]); -// __m256 y4 = _mm256_FMA_ps(x4, h1, q4); -// __m256 q5 = _mm256_load_ps(&q[16]); -// __m256 y5 = _mm256_FMA_ps(x5, h1, q5); -// __m256 q6 = _mm256_load_ps(&q[20]); -// __m256 y6 = _mm256_FMA_ps(x6, h1, q6); -#else - __m256 q1 = _mm256_load_ps(q); // q1 = q(1,1), q(2,1), q(3,1), q(4,1), q(5,1), q(6,1), q(7,1), q(8,1) - __m256 y1 = _mm256_add_ps(q1, _mm256_mul_ps(x1, h1)); // y1 = q(1,2) * h(2,2) + q(1,1) | q(2,2) * h(2,2) + q(2,1) | .... | q(8,2) * h(2,2) + q(8,1) - __m256 q2 = _mm256_load_ps(&q[8]); // q2 = q(9,1) | .... | q(16,1) - __m256 y2 = _mm256_add_ps(q2, _mm256_mul_ps(x2, h1)); // y2 = q(9,2) * hh(2,2) + q(9,1) | ... | q(16,2) * h(2,2) + q(16,1) - __m256 q3 = _mm256_load_ps(&q[16]); // q3 = q(17,1) | ... | q(24,1) - __m256 y3 = _mm256_add_ps(q3, _mm256_mul_ps(x3, h1)); // y3 = q(17,2) * hh(2,2) + q(17,1) ... 
| q(24,2) * hh(2,2) + q(24,1) -// __m256 q4 = _mm256_load_ps(&q[12]); -// __m256 y4 = _mm256_add_ps(q4, _mm256_mul_ps(x4, h1)); -// __m256 q5 = _mm256_load_ps(&q[16]); -// __m256 y5 = _mm256_add_ps(q5, _mm256_mul_ps(x5, h1)); -// __m256 q6 = _mm256_load_ps(&q[20]); -// __m256 y6 = _mm256_add_ps(q6, _mm256_mul_ps(x6, h1)); -#endif - for(i = 2; i < nb; i++) - { - h1 = _mm256_broadcast_ss(&hh[i-1]); // h1 = hh(i-1,1) | ... | hh(i-1,1) - h2 = _mm256_broadcast_ss(&hh[ldh+i]); // h2 = hh(i,2) | ... | hh(i,2) -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[i*ldq]); // q1 = q(1,i) | q(2,i) | q(3,i) | ... | q(8,i) - x1 = _mm256_FMA_ps(q1, h1, x1); - y1 = _mm256_FMA_ps(q1, h2, y1); - q2 = _mm256_load_ps(&q[(i*ldq)+8]); - x2 = _mm256_FMA_ps(q2, h1, x2); - y2 = _mm256_FMA_ps(q2, h2, y2); - q3 = _mm256_load_ps(&q[(i*ldq)+16]); - x3 = _mm256_FMA_ps(q3, h1, x3); - y3 = _mm256_FMA_ps(q3, h2, y3); -// q4 = _mm256_load_ps(&q[(i*ldq)+12]); -// x4 = _mm256_FMA_ps(q4, h1, x4); -// y4 = _mm256_FMA_ps(q4, h2, y4); -// q5 = _mm256_load_ps(&q[(i*ldq)+16]); -// x5 = _mm256_FMA_ps(q5, h1, x5); -// y5 = _mm256_FMA_ps(q5, h2, y5); -// q6 = _mm256_load_ps(&q[(i*ldq)+20]); -// x6 = _mm256_FMA_ps(q6, h1, x6); -// y6 = _mm256_FMA_ps(q6, h2, y6); -#else - q1 = _mm256_load_ps(&q[i*ldq]); // q1 = q(1,i) | q(2,i) | q(3,i) | ... | q(8,i) - x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1)); // x1 = q(1,i) * hh(i-1,1) + x1 | ... | q(8,i) ** hh(i-1,1) * x1 - y1 = _mm256_add_ps(y1, _mm256_mul_ps(q1,h2)); // y1 = q(1,i) * hh(i,2) + y1 | ... - q2 = _mm256_load_ps(&q[(i*ldq)+8]); - x2 = _mm256_add_ps(x2, _mm256_mul_ps(q2,h1)); - y2 = _mm256_add_ps(y2, _mm256_mul_ps(q2,h2)); - q3 = _mm256_load_ps(&q[(i*ldq)+16]); - x3 = _mm256_add_ps(x3, _mm256_mul_ps(q3,h1)); - y3 = _mm256_add_ps(y3, _mm256_mul_ps(q3,h2)); -// q4 = _mm256_load_ps(&q[(i*ldq)+12]); -// x4 = _mm256_add_ps(x4, _mm256_mul_ps(q4,h1)); -// y4 = _mm256_add_ps(y4, _mm256_mul_ps(q4,h2)); -// q5 = _mm256_load_ps(&q[(i*ldq)+16]); -// x5 = _mm256_add_ps(x5, _mm256_mul_ps(q5,h1)); -// y5 = _mm256_add_ps(y5, _mm256_mul_ps(q5,h2)); -// q6 = _mm256_load_ps(&q[(i*ldq)+20]); -// x6 = _mm256_add_ps(x6, _mm256_mul_ps(q6,h1)); -// y6 = _mm256_add_ps(y6, _mm256_mul_ps(q6,h2)); -#endif - } - h1 = _mm256_broadcast_ss(&hh[nb-1]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[nb*ldq]); - x1 = _mm256_FMA_ps(q1, h1, x1); - q2 = _mm256_load_ps(&q[(nb*ldq)+8]); - x2 = _mm256_FMA_ps(q2, h1, x2); - q3 = _mm256_load_ps(&q[(nb*ldq)+16]); - x3 = _mm256_FMA_ps(q3, h1, x3); -// q4 = _mm256_load_ps(&q[(nb*ldq)+12]); -// x4 = _mm256_FMA_ps(q4, h1, x4); -// q5 = _mm256_load_ps(&q[(nb*ldq)+16]); -// x5 = _mm256_FMA_ps(q5, h1, x5); -// q6 = _mm256_load_ps(&q[(nb*ldq)+20]); -// x6 = _mm256_FMA_ps(q6, h1, x6); -#else - q1 = _mm256_load_ps(&q[nb*ldq]); - x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1)); - q2 = _mm256_load_ps(&q[(nb*ldq)+8]); - x2 = _mm256_add_ps(x2, _mm256_mul_ps(q2,h1)); - q3 = _mm256_load_ps(&q[(nb*ldq)+16]); - x3 = _mm256_add_ps(x3, _mm256_mul_ps(q3,h1)); -// q4 = _mm256_load_ps(&q[(nb*ldq)+12]); -// x4 = _mm256_add_ps(x4, _mm256_mul_ps(q4,h1)); -// q5 = _mm256_load_ps(&q[(nb*ldq)+16]); -// x5 = _mm256_add_ps(x5, _mm256_mul_ps(q5,h1)); -// q6 = _mm256_load_ps(&q[(nb*ldq)+20]); -// x6 = _mm256_add_ps(x6, _mm256_mul_ps(q6,h1)); -#endif - - ///////////////////////////////////////////////////// - // Rank-2 update of Q [24 x nb+1] - ///////////////////////////////////////////////////// + // carefull here + __m512 sign = (__m512d)_mm512_set1_epi32(0x80000000); - __m256 tau1 = _mm256_broadcast_ss(hh); 
- __m256 tau2 = _mm256_broadcast_ss(&hh[ldh]); - __m256 vs = _mm256_broadcast_ss(&s); - - -// carefull here - - h1 = _mm256_xor_ps(tau1, sign); - x1 = _mm256_mul_ps(x1, h1); - x2 = _mm256_mul_ps(x2, h1); - x3 = _mm256_mul_ps(x3, h1); -// x4 = _mm256_mul_ps(x4, h1); -// x5 = _mm256_mul_ps(x5, h1); -// x6 = _mm256_mul_ps(x6, h1); - h1 = _mm256_xor_ps(tau2, sign); - h2 = _mm256_mul_ps(h1, vs); -#ifdef __ELPA_USE_FMA__ - y1 = _mm256_FMA_ps(y1, h1, _mm256_mul_ps(x1,h2)); - y2 = _mm256_FMA_ps(y2, h1, _mm256_mul_ps(x2,h2)); - y3 = _mm256_FMA_ps(y3, h1, _mm256_mul_ps(x3,h2)); -// y4 = _mm256_FMA_ps(y4, h1, _mm256_mul_ps(x4,h2)); -// y5 = _mm256_FMA_ps(y5, h1, _mm256_mul_ps(x5,h2)); -// y6 = _mm256_FMA_ps(y6, h1, _mm256_mul_ps(x6,h2)); -#else - y1 = _mm256_add_ps(_mm256_mul_ps(y1,h1), _mm256_mul_ps(x1,h2)); - y2 = _mm256_add_ps(_mm256_mul_ps(y2,h1), _mm256_mul_ps(x2,h2)); - y3 = _mm256_add_ps(_mm256_mul_ps(y3,h1), _mm256_mul_ps(x3,h2)); -// y4 = _mm256_add_ps(_mm256_mul_ps(y4,h1), _mm256_mul_ps(x4,h2)); -// y5 = _mm256_add_ps(_mm256_mul_ps(y5,h1), _mm256_mul_ps(x5,h2)); -// y6 = _mm256_add_ps(_mm256_mul_ps(y6,h1), _mm256_mul_ps(x6,h2)); -#endif + __m512 x1 = _mm512_load_ps(&q[ldq]); + __m512 x2 = _mm512_load_ps(&q[ldq+32]); + __m512 x3 = _mm512_load_ps(&q[ldq+48]); + __m512 x4 = _mm512_load_ps(&q[ldq+64]); - q1 = _mm256_load_ps(q); - q1 = _mm256_add_ps(q1, y1); - _mm256_store_ps(q,q1); - q2 = _mm256_load_ps(&q[8]); - q2 = _mm256_add_ps(q2, y2); - _mm256_store_ps(&q[8],q2); - q3 = _mm256_load_ps(&q[16]); - q3 = _mm256_add_ps(q3, y3); - _mm256_store_ps(&q[16],q3); -// q4 = _mm256_load_ps(&q[12]); -// q4 = _mm256_add_ps(q4, y4); -// _mm256_store_ps(&q[12],q4); -// q5 = _mm256_load_ps(&q[16]); -// q5 = _mm256_add_ps(q5, y5); -// _mm256_store_ps(&q[16],q5); -// q6 = _mm256_load_ps(&q[20]); -// q6 = _mm256_add_ps(q6, y6); -// _mm256_store_ps(&q[20],q6); - - h2 = _mm256_broadcast_ss(&hh[ldh+1]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[ldq]); - q1 = _mm256_add_ps(q1, _mm256_FMA_ps(y1, h2, x1)); - _mm256_store_ps(&q[ldq],q1); - q2 = _mm256_load_ps(&q[ldq+8]); - q2 = _mm256_add_ps(q2, _mm256_FMA_ps(y2, h2, x2)); - _mm256_store_ps(&q[ldq+8],q2); - q3 = _mm256_load_ps(&q[ldq+16]); - q3 = _mm256_add_ps(q3, _mm256_FMA_ps(y3, h2, x3)); - _mm256_store_ps(&q[ldq+16],q3); -// q4 = _mm256_load_ps(&q[ldq+12]); -// q4 = _mm256_add_ps(q4, _mm256_FMA_ps(y4, h2, x4)); -// _mm256_store_ps(&q[ldq+12],q4); -// q5 = _mm256_load_ps(&q[ldq+16]); -// q5 = _mm256_add_ps(q5, _mm256_FMA_ps(y5, h2, x5)); -// _mm256_store_ps(&q[ldq+16],q5); -// q6 = _mm256_load_ps(&q[ldq+20]); -// q6 = _mm256_add_ps(q6, _mm256_FMA_ps(y6, h2, x6)); -// _mm256_store_ps(&q[ldq+20],q6); -#else - q1 = _mm256_load_ps(&q[ldq]); - q1 = _mm256_add_ps(q1, _mm256_add_ps(x1, _mm256_mul_ps(y1, h2))); - _mm256_store_ps(&q[ldq],q1); - q2 = _mm256_load_ps(&q[ldq+8]); - q2 = _mm256_add_ps(q2, _mm256_add_ps(x2, _mm256_mul_ps(y2, h2))); - _mm256_store_ps(&q[ldq+8],q2); - q3 = _mm256_load_ps(&q[ldq+16]); - q3 = _mm256_add_ps(q3, _mm256_add_ps(x3, _mm256_mul_ps(y3, h2))); - _mm256_store_ps(&q[ldq+16],q3); -// q4 = _mm256_load_ps(&q[ldq+12]); -// q4 = _mm256_add_ps(q4, _mm256_add_ps(x4, _mm256_mul_ps(y4, h2))); -// _mm256_store_ps(&q[ldq+12],q4); -// q5 = _mm256_load_ps(&q[ldq+16]); -// q5 = _mm256_add_ps(q5, _mm256_add_ps(x5, _mm256_mul_ps(y5, h2))); -// _mm256_store_ps(&q[ldq+16],q5); -// q6 = _mm256_load_ps(&q[ldq+20]); -// q6 = _mm256_add_ps(q6, _mm256_add_ps(x6, _mm256_mul_ps(y6, h2))); -// _mm256_store_ps(&q[ldq+20],q6); -#endif - for (i = 2; i < nb; i++) 
- { - h1 = _mm256_broadcast_ss(&hh[i-1]); - h2 = _mm256_broadcast_ss(&hh[ldh+i]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[i*ldq]); - q1 = _mm256_FMA_ps(x1, h1, q1); - q1 = _mm256_FMA_ps(y1, h2, q1); - _mm256_store_ps(&q[i*ldq],q1); - q2 = _mm256_load_ps(&q[(i*ldq)+8]); - q2 = _mm256_FMA_ps(x2, h1, q2); - q2 = _mm256_FMA_ps(y2, h2, q2); - _mm256_store_ps(&q[(i*ldq)+8],q2); - q3 = _mm256_load_ps(&q[(i*ldq)+16]); - q3 = _mm256_FMA_ps(x3, h1, q3); - q3 = _mm256_FMA_ps(y3, h2, q3); - _mm256_store_ps(&q[(i*ldq)+16],q3); -// q4 = _mm256_load_ps(&q[(i*ldq)+12]); -// q4 = _mm256_FMA_ps(x4, h1, q4); -// q4 = _mm256_FMA_ps(y4, h2, q4); -// _mm256_store_ps(&q[(i*ldq)+12],q4); -// q5 = _mm256_load_ps(&q[(i*ldq)+16]); -/// q5 = _mm256_FMA_ps(x5, h1, q5); -// q5 = _mm256_FMA_ps(y5, h2, q5); -// _mm256_store_ps(&q[(i*ldq)+16],q5); -// q6 = _mm256_load_ps(&q[(i*ldq)+20]); -// q6 = _mm256_FMA_ps(x6, h1, q6); -// q6 = _mm256_FMA_ps(y6, h2, q6); -// _mm256_store_ps(&q[(i*ldq)+20],q6); -#else - q1 = _mm256_load_ps(&q[i*ldq]); - q1 = _mm256_add_ps(q1, _mm256_add_ps(_mm256_mul_ps(x1,h1), _mm256_mul_ps(y1, h2))); - _mm256_store_ps(&q[i*ldq],q1); - q2 = _mm256_load_ps(&q[(i*ldq)+8]); - q2 = _mm256_add_ps(q2, _mm256_add_ps(_mm256_mul_ps(x2,h1), _mm256_mul_ps(y2, h2))); - _mm256_store_ps(&q[(i*ldq)+8],q2); - q3 = _mm256_load_ps(&q[(i*ldq)+16]); - q3 = _mm256_add_ps(q3, _mm256_add_ps(_mm256_mul_ps(x3,h1), _mm256_mul_ps(y3, h2))); - _mm256_store_ps(&q[(i*ldq)+16],q3); -// q4 = _mm256_load_ps(&q[(i*ldq)+12]); -// q4 = _mm256_add_ps(q4, _mm256_add_ps(_mm256_mul_ps(x4,h1), _mm256_mul_ps(y4, h2))); -// _mm256_store_ps(&q[(i*ldq)+12],q4); -// q5 = _mm256_load_ps(&q[(i*ldq)+16]); -// q5 = _mm256_add_ps(q5, _mm256_add_ps(_mm256_mul_ps(x5,h1), _mm256_mul_ps(y5, h2))); -// _mm256_store_ps(&q[(i*ldq)+16],q5); -// q6 = _mm256_load_ps(&q[(i*ldq)+20]); -// q6 = _mm256_add_ps(q6, _mm256_add_ps(_mm256_mul_ps(x6,h1), _mm256_mul_ps(y6, h2))); -// _mm256_store_ps(&q[(i*ldq)+20],q6); -#endif - } + __m512 h1 = _mm512_set1_ps(hh[ldh+1]); + __m512 h2; - h1 = _mm256_broadcast_ss(&hh[nb-1]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[nb*ldq]); - q1 = _mm256_FMA_ps(x1, h1, q1); - _mm256_store_ps(&q[nb*ldq],q1); - q2 = _mm256_load_ps(&q[(nb*ldq)+8]); - q2 = _mm256_FMA_ps(x2, h1, q2); - _mm256_store_ps(&q[(nb*ldq)+8],q2); - q3 = _mm256_load_ps(&q[(nb*ldq)+16]); - q3 = _mm256_FMA_ps(x3, h1, q3); - _mm256_store_ps(&q[(nb*ldq)+16],q3); -// q4 = _mm256_load_ps(&q[(nb*ldq)+12]); -/// q4 = _mm256_FMA_ps(x4, h1, q4); -// _mm256_store_ps(&q[(nb*ldq)+12],q4); -// q5 = _mm256_load_ps(&q[(nb*ldq)+16]); -// q5 = _mm256_FMA_ps(x5, h1, q5); -// _mm256_store_ps(&q[(nb*ldq)+16],q5); -// q6 = _mm256_load_ps(&q[(nb*ldq)+20]); -// q6 = _mm256_FMA_ps(x6, h1, q6); -// _mm256_store_ps(&q[(nb*ldq)+20],q6); -#else - q1 = _mm256_load_ps(&q[nb*ldq]); - q1 = _mm256_add_ps(q1, _mm256_mul_ps(x1, h1)); - _mm256_store_ps(&q[nb*ldq],q1); - q2 = _mm256_load_ps(&q[(nb*ldq)+8]); - q2 = _mm256_add_ps(q2, _mm256_mul_ps(x2, h1)); - _mm256_store_ps(&q[(nb*ldq)+8],q2); - q3 = _mm256_load_ps(&q[(nb*ldq)+16]); - q3 = _mm256_add_ps(q3, _mm256_mul_ps(x3, h1)); - _mm256_store_ps(&q[(nb*ldq)+16],q3); -// q4 = _mm256_load_ps(&q[(nb*ldq)+12]); -// q4 = _mm256_add_ps(q4, _mm256_mul_ps(x4, h1)); -// _mm256_store_ps(&q[(nb*ldq)+12],q4); -// q5 = _mm256_load_ps(&q[(nb*ldq)+16]); -// q5 = _mm256_add_ps(q5, _mm256_mul_ps(x5, h1)); -// _mm256_store_ps(&q[(nb*ldq)+16],q5); -// q6 = _mm256_load_ps(&q[(nb*ldq)+20]); -// q6 = _mm256_add_ps(q6, _mm256_mul_ps(x6, h1)); -// 
_mm256_store_ps(&q[(nb*ldq)+20],q6); -#endif -} - -/** - * Unrolled kernel that computes - * 16 rows of Q simultaneously, a - * matrix vector product with two householder - * vectors + a rank 2 update is performed - */ - __forceinline void hh_trafo_kernel_16_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s) -{ - ///////////////////////////////////////////////////// - // Matrix Vector Multiplication, Q [16 x nb+1] * hh - // hh contains two householder vectors, with offset 1 - ///////////////////////////////////////////////////// - int i; - // Needed bit mask for floating point sign flip - __m256 sign = (__m256)_mm256_set1_epi32(0x80000000); - - __m256 x1 = _mm256_load_ps(&q[ldq]); - __m256 x2 = _mm256_load_ps(&q[ldq+8]); -// __m256 x3 = _mm256_load_ps(&q[ldq+16]); -// __m256 x4 = _mm256_load_ps(&q[ldq+12]); - - __m256 h1 = _mm256_broadcast_ss(&hh[ldh+1]); - __m256 h2; - -#ifdef __ELPA_USE_FMA__ - __m256 q1 = _mm256_load_ps(q); - __m256 y1 = _mm256_FMA_ps(x1, h1, q1); - __m256 q2 = _mm256_load_ps(&q[8]); - __m256 y2 = _mm256_FMA_ps(x2, h1, q2); -// __m256 q3 = _mm256_load_ps(&q[16]); -// __m256 y3 = _mm256_FMA_ps(x3, h1, q3); -// __m256 q4 = _mm256_load_ps(&q[12]); -// __m256 y4 = _mm256_FMA_ps(x4, h1, q4); -#else - __m256 q1 = _mm256_load_ps(q); - __m256 y1 = _mm256_add_ps(q1, _mm256_mul_ps(x1, h1)); - __m256 q2 = _mm256_load_ps(&q[8]); - __m256 y2 = _mm256_add_ps(q2, _mm256_mul_ps(x2, h1)); -// __m256 q3 = _mm256_load_ps(&q[16]); -// __m256 y3 = _mm256_add_ps(q3, _mm256_mul_ps(x3, h1)); -// __m256 q4 = _mm256_load_ps(&q[12]); -// __m256 y4 = _mm256_add_ps(q4, _mm256_mul_ps(x4, h1)); -#endif + __m512 q1 = _mm512_load_ps(q); + __m512 y1 = _mm512_FMA_ps(x1, h1, q1); + __m512 q2 = _mm512_load_ps(&q[16]); + __m512 y2 = _mm512_FMA_ps(x2, h1, q2); + __m512 q3 = _mm512_load_ps(&q[32]); + __m512 y3 = _mm512_FMA_ps(x3, h1, q3); + __m512 q4 = _mm512_load_ps(&q[48]); + __m512 y4 = _mm512_FMA_ps(x4, h1, q4); for(i = 2; i < nb; i++) { - h1 = _mm256_broadcast_ss(&hh[i-1]); - h2 = _mm256_broadcast_ss(&hh[ldh+i]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[i*ldq]); - x1 = _mm256_FMA_ps(q1, h1, x1); - y1 = _mm256_FMA_ps(q1, h2, y1); - q2 = _mm256_load_ps(&q[(i*ldq)+8]); - x2 = _mm256_FMA_ps(q2, h1, x2); - y2 = _mm256_FMA_ps(q2, h2, y2); -// q3 = _mm256_load_ps(&q[(i*ldq)+8]); -// x3 = _mm256_FMA_ps(q3, h1, x3); -// y3 = _mm256_FMA_ps(q3, h2, y3); -// q4 = _mm256_load_ps(&q[(i*ldq)+12]); -// x4 = _mm256_FMA_ps(q4, h1, x4); -// y4 = _mm256_FMA_ps(q4, h2, y4); -#else - q1 = _mm256_load_ps(&q[i*ldq]); - x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1)); - y1 = _mm256_add_ps(y1, _mm256_mul_ps(q1,h2)); - q2 = _mm256_load_ps(&q[(i*ldq)+8]); - x2 = _mm256_add_ps(x2, _mm256_mul_ps(q2,h1)); - y2 = _mm256_add_ps(y2, _mm256_mul_ps(q2,h2)); -// q3 = _mm256_load_ps(&q[(i*ldq)+8]); -// x3 = _mm256_add_ps(x3, _mm256_mul_ps(q3,h1)); -// y3 = _mm256_add_ps(y3, _mm256_mul_ps(q3,h2)); -// q4 = _mm256_load_ps(&q[(i*ldq)+12]); -// x4 = _mm256_add_ps(x4, _mm256_mul_ps(q4,h1)); -// y4 = _mm256_add_ps(y4, _mm256_mul_ps(q4,h2)); -#endif + h1 = _mm512_set1_ps(hh[i-1]); + h2 = _mm512_set1_ps(hh[ldh+i]); + + q1 = _mm512_load_ps(&q[i*ldq]); + x1 = _mm512_FMA_ps(q1, h1, x1); + y1 = _mm512_FMA_ps(q1, h2, y1); + q2 = _mm512_load_ps(&q[(i*ldq)+16]); + x2 = _mm512_FMA_ps(q2, h1, x2); + y2 = _mm512_FMA_ps(q2, h2, y2); + q3 = _mm512_load_ps(&q[(i*ldq)+32]); + x3 = _mm512_FMA_ps(q3, h1, x3); + y3 = _mm512_FMA_ps(q3, h2, y3); + q4 = _mm512_load_ps(&q[(i*ldq)+48]); + x4 = _mm512_FMA_ps(q4, h1, x4); + y4 = 
_mm512_FMA_ps(q4, h2, y4); + } - h1 = _mm256_broadcast_ss(&hh[nb-1]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[nb*ldq]); - x1 = _mm256_FMA_ps(q1, h1, x1); - q2 = _mm256_load_ps(&q[(nb*ldq)+8]); - x2 = _mm256_FMA_ps(q2, h1, x2); -// q3 = _mm256_load_ps(&q[(nb*ldq)+8]); -// x3 = _mm256_FMA_ps(q3, h1, x3); -// q4 = _mm256_load_ps(&q[(nb*ldq)+12]); -// x4 = _mm256_FMA_ps(q4, h1, x4); -#else - q1 = _mm256_load_ps(&q[nb*ldq]); - x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1)); - q2 = _mm256_load_ps(&q[(nb*ldq)+8]); - x2 = _mm256_add_ps(x2, _mm256_mul_ps(q2,h1)); -// q3 = _mm256_load_ps(&q[(nb*ldq)+8]); -// x3 = _mm256_add_ps(x3, _mm256_mul_ps(q3,h1)); -// q4 = _mm256_load_ps(&q[(nb*ldq)+12]); -// x4 = _mm256_add_ps(x4, _mm256_mul_ps(q4,h1)); -#endif + h1 = _mm512_set1_ps(hh[nb-1]); + + q1 = _mm512_load_ps(&q[nb*ldq]); + x1 = _mm512_FMA_ps(q1, h1, x1); + q2 = _mm512_load_ps(&q[(nb*ldq)+16]); + x2 = _mm512_FMA_ps(q2, h1, x2); + q3 = _mm512_load_ps(&q[(nb*ldq)+32]); + x3 = _mm512_FMA_ps(q3, h1, x3); + q4 = _mm512_load_ps(&q[(nb*ldq)+48]); + x4 = _mm512_FMA_ps(q4, h1, x4); + ///////////////////////////////////////////////////// - // Rank-2 update of Q [16 x nb+1] + // Rank-2 update of Q [24 x nb+1] ///////////////////////////////////////////////////// - __m256 tau1 = _mm256_broadcast_ss(hh); - __m256 tau2 = _mm256_broadcast_ss(&hh[ldh]); - __m256 vs = _mm256_broadcast_ss(&s); - - -// carefulle - - h1 = _mm256_xor_ps(tau1, sign); - x1 = _mm256_mul_ps(x1, h1); - x2 = _mm256_mul_ps(x2, h1); -// x3 = _mm256_mul_ps(x3, h1); -// x4 = _mm256_mul_ps(x4, h1); - h1 = _mm256_xor_ps(tau2, sign); - h2 = _mm256_mul_ps(h1, vs); -#ifdef __ELPA_USE_FMA__ - y1 = _mm256_FMA_ps(y1, h1, _mm256_mul_ps(x1,h2)); - y2 = _mm256_FMA_ps(y2, h1, _mm256_mul_ps(x2,h2)); -// y3 = _mm256_FMA_ps(y3, h1, _mm256_mul_ps(x3,h2)); -// y4 = _mm256_FMA_ps(y4, h1, _mm256_mul_ps(x4,h2)); -#else - y1 = _mm256_add_ps(_mm256_mul_ps(y1,h1), _mm256_mul_ps(x1,h2)); - y2 = _mm256_add_ps(_mm256_mul_ps(y2,h1), _mm256_mul_ps(x2,h2)); -// y3 = _mm256_add_ps(_mm256_mul_ps(y3,h1), _mm256_mul_ps(x3,h2)); -// y4 = _mm256_add_ps(_mm256_mul_ps(y4,h1), _mm256_mul_ps(x4,h2)); -#endif - - q1 = _mm256_load_ps(q); - q1 = _mm256_add_ps(q1, y1); - _mm256_store_ps(q,q1); - q2 = _mm256_load_ps(&q[8]); - q2 = _mm256_add_ps(q2, y2); - _mm256_store_ps(&q[8],q2); -// q3 = _mm256_load_psa(&q[8]); -// q3 = _mm256_add_ps(q3, y3); -// _mm256_store_ps(&q[8],q3); -// q4 = _mm256_load_ps(&q[12]); -// q4 = _mm256_add_ps(q4, y4); -// _mm256_store_ps(&q[12],q4); - - h2 = _mm256_broadcast_ss(&hh[ldh+1]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[ldq]); - q1 = _mm256_add_ps(q1, _mm256_FMA_ps(y1, h2, x1)); - _mm256_store_ps(&q[ldq],q1); - q2 = _mm256_load_ps(&q[ldq+8]); - q2 = _mm256_add_ps(q2, _mm256_FMA_ps(y2, h2, x2)); - _mm256_store_ps(&q[ldq+8],q2); -// q3 = _mm256_load_ps(&q[ldq+8]); -// q3 = _mm256_add_ps(q3, _mm256_FMA_ps(y3, h2, x3)); -// _mm256_store_ps(&q[ldq+8],q3); -// q4 = _mm256_load_ps(&q[ldq+12]); -// q4 = _mm256_add_ps(q4, _mm256_FMA_ps(y4, h2, x4)); -// _mm256_store_ps(&q[ldq+12],q4); -#else - q1 = _mm256_load_ps(&q[ldq]); - q1 = _mm256_add_ps(q1, _mm256_add_ps(x1, _mm256_mul_ps(y1, h2))); - _mm256_store_ps(&q[ldq],q1); - q2 = _mm256_load_ps(&q[ldq+8]); - q2 = _mm256_add_ps(q2, _mm256_add_ps(x2, _mm256_mul_ps(y2, h2))); - _mm256_store_ps(&q[ldq+8],q2); -// q3 = _mm256_load_ps(&q[ldq+8]); -// q3 = _mm256_add_ps(q3, _mm256_add_ps(x3, _mm256_mul_ps(y3, h2))); -// _mm256_store_ps(&q[ldq+8],q3); -// q4 = _mm256_load_ps(&q[ldq+12]); -// q4 = 
_mm256_add_ps(q4, _mm256_add_ps(x4, _mm256_mul_ps(y4, h2))); -// _mm256_store_ps(&q[ldq+12],q4); -#endif + __m512 tau1 = _mm512_set1_ps(hh[0]); + __m512 tau2 = _mm512_set1_ps(hh[ldh]); + __m512 vs = _mm512_set1_ps(s); + + h1 = (__m512) _mm512_xor_epi32((__m512i) tau1, (__m512i) sign); + x1 = _mm512_mul_ps(x1, h1); + x2 = _mm512_mul_ps(x2, h1); + x3 = _mm512_mul_ps(x3, h1); + x4 = _mm512_mul_ps(x4, h1); + + h1 = (__m512) _mm512_xor_epi32((__m512i) tau2, (__m512i) sign); + h2 = _mm512_mul_ps(h1, vs); + y1 = _mm512_FMA_ps(y1, h1, _mm512_mul_ps(x1,h2)); + y2 = _mm512_FMA_ps(y2, h1, _mm512_mul_ps(x2,h2)); + y3 = _mm512_FMA_ps(y3, h1, _mm512_mul_ps(x3,h2)); + y4 = _mm512_FMA_ps(y4, h1, _mm512_mul_ps(x4,h2)); + + q1 = _mm512_load_ps(q); + q1 = _mm512_add_ps(q1, y1); + _mm512_store_ps(q,q1); + q2 = _mm512_load_ps(&q[16]); + q2 = _mm512_add_ps(q2, y2); + _mm512_store_ps(&q[16],q2); + q3 = _mm512_load_ps(&q[32]); + q3 = _mm512_add_ps(q3, y3); + _mm512_store_ps(&q[32],q3); + q4 = _mm512_load_ps(&q[48]); + q4 = _mm512_add_ps(q4, y4); + _mm512_store_ps(&q[48],q4); + + h2 = _mm512_set1_ps(hh[ldh+1]); + + q1 = _mm512_load_ps(&q[ldq]); + q1 = _mm512_add_ps(q1, _mm512_FMA_ps(y1, h2, x1)); + _mm512_store_ps(&q[ldq],q1); + q2 = _mm512_load_ps(&q[ldq+16]); + q2 = _mm512_add_ps(q2, _mm512_FMA_ps(y2, h2, x2)); + _mm512_store_ps(&q[ldq+16],q2); + q3 = _mm512_load_ps(&q[ldq+32]); + q3 = _mm512_add_ps(q3, _mm512_FMA_ps(y3, h2, x3)); + _mm512_store_ps(&q[ldq+32],q3); + q4 = _mm512_load_ps(&q[ldq+48]); + q4 = _mm512_add_ps(q4, _mm512_FMA_ps(y4, h2, x4)); + _mm512_store_ps(&q[ldq+48],q4); for (i = 2; i < nb; i++) { - h1 = _mm256_broadcast_ss(&hh[i-1]); - h2 = _mm256_broadcast_ss(&hh[ldh+i]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[i*ldq]); - q1 = _mm256_FMA_ps(x1, h1, q1); - q1 = _mm256_FMA_ps(y1, h2, q1); - _mm256_store_ps(&q[i*ldq],q1); - q2 = _mm256_load_ps(&q[(i*ldq)+8]); - q2 = _mm256_FMA_ps(x2, h1, q2); - q2 = _mm256_FMA_ps(y2, h2, q2); - _mm256_store_ps(&q[(i*ldq)+8],q2); -// q3 = _mm256_load_ps(&q[(i*ldq)+8]); -// q3 = _mm256_FMA_ps(x3, h1, q3); -// q3 = _mm256_FMA_ps(y3, h2, q3); -// _mm256_store_ps(&q[(i*ldq)+8],q3); -// q4 = _mm256_load_ps(&q[(i*ldq)+12]); -// q4 = _mm256_FMA_ps(x4, h1, q4); -// q4 = _mm256_FMA_ps(y4, h2, q4); -// _mm256_store_ps(&q[(i*ldq)+12],q4); -#else - q1 = _mm256_load_ps(&q[i*ldq]); - q1 = _mm256_add_ps(q1, _mm256_add_ps(_mm256_mul_ps(x1,h1), _mm256_mul_ps(y1, h2))); - _mm256_store_ps(&q[i*ldq],q1); - q2 = _mm256_load_ps(&q[(i*ldq)+8]); - q2 = _mm256_add_ps(q2, _mm256_add_ps(_mm256_mul_ps(x2,h1), _mm256_mul_ps(y2, h2))); - _mm256_store_ps(&q[(i*ldq)+8],q2); -// q3 = _mm256_load_ps(&q[(i*ldq)+8]); -// q3 = _mm256_add_ps(q3, _mm256_add_ps(_mm256_mul_ps(x3,h1), _mm256_mul_ps(y3, h2))); -// _mm256_store_ps(&q[(i*ldq)+8],q3); -// q4 = _mm256_load_ps(&q[(i*ldq)+12]); -// q4 = _mm256_add_ps(q4, _mm256_add_ps(_mm256_mul_ps(x4,h1), _mm256_mul_ps(y4, h2))); -// _mm256_store_ps(&q[(i*ldq)+12],q4); -#endif + h1 = _mm512_set1_ps(hh[i-1]); + h2 = _mm512_set1_ps(hh[ldh+i]); + + q1 = _mm512_load_ps(&q[i*ldq]); + q1 = _mm512_FMA_ps(x1, h1, q1); + q1 = _mm512_FMA_ps(y1, h2, q1); + _mm512_store_ps(&q[i*ldq],q1); + q2 = _mm512_load_ps(&q[(i*ldq)+16]); + q2 = _mm512_FMA_ps(x2, h1, q2); + q2 = _mm512_FMA_ps(y2, h2, q2); + _mm512_store_ps(&q[(i*ldq)+16],q2); + q3 = _mm512_load_ps(&q[(i*ldq)+32]); + q3 = _mm512_FMA_ps(x3, h1, q3); + q3 = _mm512_FMA_ps(y3, h2, q3); + _mm512_store_ps(&q[(i*ldq)+32],q3); + q4 = _mm512_load_ps(&q[(i*ldq)+48]); + q4 = _mm512_FMA_ps(x4, h1, q4); + q4 = 
_mm512_FMA_ps(y4, h2, q4); + _mm512_store_ps(&q[(i*ldq)+48],q4); + } - h1 = _mm256_broadcast_ss(&hh[nb-1]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[nb*ldq]); - q1 = _mm256_FMA_ps(x1, h1, q1); - _mm256_store_ps(&q[nb*ldq],q1); - q2 = _mm256_load_ps(&q[(nb*ldq)+8]); - q2 = _mm256_FMA_ps(x2, h1, q2); - _mm256_store_ps(&q[(nb*ldq)+8],q2); -// q3 = _mm256_load_ps(&q[(nb*ldq)+8]); -// q3 = _mm256_FMA_ps(x3, h1, q3); -// _mm256_store_ps(&q[(nb*ldq)+8],q3); -// q4 = _mm256_load_ps(&q[(nb*ldq)+12]); -// q4 = _mm256_FMA_ps(x4, h1, q4); -// _mm256_store_ps(&q[(nb*ldq)+12],q4); -#else - q1 = _mm256_load_ps(&q[nb*ldq]); - q1 = _mm256_add_ps(q1, _mm256_mul_ps(x1, h1)); - _mm256_store_ps(&q[nb*ldq],q1); - q2 = _mm256_load_ps(&q[(nb*ldq)+8]); - q2 = _mm256_add_ps(q2, _mm256_mul_ps(x2, h1)); - _mm256_store_ps(&q[(nb*ldq)+8],q2); -// q3 = _mm256_load_ps(&q[(nb*ldq)+8]); -// q3 = _mm256_add_ps(q3, _mm256_mul_ps(x3, h1)); -// _mm256_store_ps(&q[(nb*ldq)+8],q3); -// q4 = _mm256_load_ps(&q[(nb*ldq)+12]); -// q4 = _mm256_add_ps(q4, _mm256_mul_ps(x4, h1)); -// _mm256_store_ps(&q[(nb*ldq)+12],q4); -#endif + h1 = _mm512_set1_ps(hh[nb-1]); + + q1 = _mm512_load_ps(&q[nb*ldq]); + q1 = _mm512_FMA_ps(x1, h1, q1); + _mm512_store_ps(&q[nb*ldq],q1); + q2 = _mm512_load_ps(&q[(nb*ldq)+16]); + q2 = _mm512_FMA_ps(x2, h1, q2); + _mm512_store_ps(&q[(nb*ldq)+16],q2); + q3 = _mm512_load_ps(&q[(nb*ldq)+32]); + q3 = _mm512_FMA_ps(x3, h1, q3); + _mm512_store_ps(&q[(nb*ldq)+32],q3); + q4 = _mm512_load_ps(&q[(nb*ldq)+48]); + q4 = _mm512_FMA_ps(x4, h1, q4); + _mm512_store_ps(&q[(nb*ldq)+48],q4); + } /** * Unrolled kernel that computes - * 8 rows of Q simultaneously, a + * 48 rows of Q simultaneously, a * matrix vector product with two householder * vectors + a rank 2 update is performed */ - __forceinline void hh_trafo_kernel_8_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s) + __forceinline void hh_trafo_kernel_48_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s) { ///////////////////////////////////////////////////// - // Matrix Vector Multiplication, Q [8 x nb+1] * hh + // Matrix Vector Multiplication, Q [24 x nb+1] * hh // hh contains two householder vectors, with offset 1 ///////////////////////////////////////////////////// int i; // Needed bit mask for floating point sign flip - __m256 sign = (__m256)_mm256_set1_epi32(0x80000000); - - __m256 x1 = _mm256_load_ps(&q[ldq]); -// __m256 x2 = _mm256_load_ps(&q[ldq+8]); - - __m256 h1 = _mm256_broadcast_ss(&hh[ldh+1]); - __m256 h2; - -#ifdef __ELPA_USE_FMA__ - __m256 q1 = _mm256_load_ps(q); - __m256 y1 = _mm256_FMA_ps(x1, h1, q1); -// __m256 q2 = _mm256_load_ps(&q[4]); -// __m256 y2 = _mm256_FMA_ps(x2, h1, q2); -#else - __m256 q1 = _mm256_load_ps(q); - __m256 y1 = _mm256_add_ps(q1, _mm256_mul_ps(x1, h1)); -// __m256 q2 = _mm256_load_ps(&q[4]); -// __m256 y2 = _mm256_add_ps(q2, _mm256_mul_ps(x2, h1)); -#endif + // carefull here + __m512 sign = (__m512)_mm512_set1_epi32(0x80000000); + + __m512 x1 = _mm512_load_ps(&q[ldq]); + __m512 x2 = _mm512_load_ps(&q[ldq+32]); + __m512 x3 = _mm512_load_ps(&q[ldq+48]); +// __m512 x4 = _mm512_load_ps(&q[ldq+64]); + + + __m512 h1 = _mm512_set1_ps(hh[ldh+1]); + __m512 h2; + + __m512 q1 = _mm512_load_ps(q); + __m512 y1 = _mm512_FMA_ps(x1, h1, q1); + __m512 q2 = _mm512_load_ps(&q[16]); + __m512 y2 = _mm512_FMA_ps(x2, h1, q2); + __m512 q3 = _mm512_load_ps(&q[32]); + __m512 y3 = _mm512_FMA_ps(x3, h1, q3); +// __m512 q4 = _mm512_load_ps(&q[48]); +// __m512 y4 = _mm512_FMA_ps(x4, h1, q4); for(i = 2; i < 
nb; i++) { - h1 = _mm256_broadcast_ss(&hh[i-1]); - h2 = _mm256_broadcast_ss(&hh[ldh+i]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[i*ldq]); - x1 = _mm256_FMA_ps(q1, h1, x1); - y1 = _mm256_FMA_ps(q1, h2, y1); -// q2 = _mm256_load_ps(&q[(i*ldq)+4]); -// x2 = _mm256_FMA_ps(q2, h1, x2); -// y2 = _mm256_FMA_ps(q2, h2, y2); -#else - q1 = _mm256_load_ps(&q[i*ldq]); - x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1)); - y1 = _mm256_add_ps(y1, _mm256_mul_ps(q1,h2)); -// q2 = _mm256_load_ps(&q[(i*ldq)+4]); -// x2 = _mm256_add_ps(x2, _mm256_mul_ps(q2,h1)); -// y2 = _mm256_add_ps(y2, _mm256_mul_ps(q2,h2)); -#endif + h1 = _mm512_set1_ps(hh[i-1]); + h2 = _mm512_set1_ps(hh[ldh+i]); + + q1 = _mm512_load_ps(&q[i*ldq]); + x1 = _mm512_FMA_ps(q1, h1, x1); + y1 = _mm512_FMA_ps(q1, h2, y1); + q2 = _mm512_load_ps(&q[(i*ldq)+16]); + x2 = _mm512_FMA_ps(q2, h1, x2); + y2 = _mm512_FMA_ps(q2, h2, y2); + q3 = _mm512_load_ps(&q[(i*ldq)+32]); + x3 = _mm512_FMA_ps(q3, h1, x3); + y3 = _mm512_FMA_ps(q3, h2, y3); +// q4 = _mm512_load_ps(&q[(i*ldq)+48]); +// x4 = _mm512_FMA_ps(q4, h1, x4); +// y4 = _mm512_FMA_ps(q4, h2, y4); + } - h1 = _mm256_broadcast_ss(&hh[nb-1]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[nb*ldq]); - x1 = _mm256_FMA_ps(q1, h1, x1); -// q2 = _mm256_load_ps(&q[(nb*ldq)+4]); -// x2 = _mm256_FMA_ps(q2, h1, x2); -#else - q1 = _mm256_load_ps(&q[nb*ldq]); - x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1)); -// q2 = _mm256_load_ps(&q[(nb*ldq)+4]); -// x2 = _mm256_add_ps(x2, _mm256_mul_ps(q2,h1)); -#endif + h1 = _mm512_set1_ps(hh[nb-1]); + + q1 = _mm512_load_ps(&q[nb*ldq]); + x1 = _mm512_FMA_ps(q1, h1, x1); + q2 = _mm512_load_ps(&q[(nb*ldq)+16]); + x2 = _mm512_FMA_ps(q2, h1, x2); + q3 = _mm512_load_ps(&q[(nb*ldq)+32]); + x3 = _mm512_FMA_ps(q3, h1, x3); +// q4 = _mm512_load_ps(&q[(nb*ldq)+48]); +// x4 = _mm512_FMA_ps(q4, h1, x4); + ///////////////////////////////////////////////////// - // Rank-2 update of Q [8 x nb+1] + // Rank-2 update of Q [24 x nb+1] ///////////////////////////////////////////////////// - __m256 tau1 = _mm256_broadcast_ss(hh); - __m256 tau2 = _mm256_broadcast_ss(&hh[ldh]); - __m256 vs = _mm256_broadcast_ss(&s); - -// carefulle - - h1 = _mm256_xor_ps(tau1, sign); - x1 = _mm256_mul_ps(x1, h1); -// x2 = _mm256_mul_ps(x2, h1); - h1 = _mm256_xor_ps(tau2, sign); - h2 = _mm256_mul_ps(h1, vs); -#ifdef __ELPA_USE_FMA__ - y1 = _mm256_FMA_ps(y1, h1, _mm256_mul_ps(x1,h2)); -// y2 = _mm256_FMA_ps(y2, h1, _mm256_mul_ps(x2,h2)); -#else - y1 = _mm256_add_ps(_mm256_mul_ps(y1,h1), _mm256_mul_ps(x1,h2)); -// y2 = _mm256_add_ps(_mm256_mul_ps(y2,h1), _mm256_mul_ps(x2,h2)); -#endif - - q1 = _mm256_load_ps(q); - q1 = _mm256_add_ps(q1, y1); - _mm256_store_ps(q,q1); -// q2 = _mm256_load_ps(&q[4]); -// q2 = _mm256_add_ps(q2, y2); -// _mm256_store_ps(&q[4],q2); - - h2 = _mm256_broadcast_ss(&hh[ldh+1]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[ldq]); - q1 = _mm256_add_ps(q1, _mm256_FMA_ps(y1, h2, x1)); - _mm256_store_ps(&q[ldq],q1); -// q2 = _mm256_load_ps(&q[ldq+4]); -// q2 = _mm256_add_ps(q2, _mm256_FMA_ps(y2, h2, x2)); -// _mm256_store_ps(&q[ldq+4],q2); -#else - q1 = _mm256_load_ps(&q[ldq]); - q1 = _mm256_add_ps(q1, _mm256_add_ps(x1, _mm256_mul_ps(y1, h2))); - _mm256_store_ps(&q[ldq],q1); -// q2 = _mm256_load_ps(&q[ldq+4]); -// q2 = _mm256_add_ps(q2, _mm256_add_ps(x2, _mm256_mul_ps(y2, h2))); -// _mm256_store_ps(&q[ldq+4],q2); -#endif + __m512 tau1 = _mm512_set1_ps(hh[0]); + __m512 tau2 = _mm512_set1_ps(hh[ldh]); + __m512 vs = _mm512_set1_ps(s); + + h1 = (__m512) _mm512_xor_epi32((__m512i) tau1, 
(__m512i) sign); + x1 = _mm512_mul_ps(x1, h1); + x2 = _mm512_mul_ps(x2, h1); + x3 = _mm512_mul_ps(x3, h1); +// x4 = _mm512_mul_ps(x4, h1); + + h1 = (__m512) _mm512_xor_epi32((__m512i) tau2, (__m512i) sign); + h2 = _mm512_mul_ps(h1, vs); + y1 = _mm512_FMA_ps(y1, h1, _mm512_mul_ps(x1,h2)); + y2 = _mm512_FMA_ps(y2, h1, _mm512_mul_ps(x2,h2)); + y3 = _mm512_FMA_ps(y3, h1, _mm512_mul_ps(x3,h2)); +// y4 = _mm512_FMA_ps(y4, h1, _mm512_mul_ps(x4,h2)); + + q1 = _mm512_load_ps(q); + q1 = _mm512_add_ps(q1, y1); + _mm512_store_ps(q,q1); + q2 = _mm512_load_ps(&q[16]); + q2 = _mm512_add_ps(q2, y2); + _mm512_store_ps(&q[16],q2); + q3 = _mm512_load_ps(&q[32]); + q3 = _mm512_add_ps(q3, y3); + _mm512_store_ps(&q[32],q3); +// q4 = _mm512_load_ps(&q[48]); +// q4 = _mm512_add_ps(q4, y4); +// _mm512_store_ps(&q[48],q4); + + h2 = _mm512_set1_ps(hh[ldh+1]); + + q1 = _mm512_load_ps(&q[ldq]); + q1 = _mm512_add_ps(q1, _mm512_FMA_ps(y1, h2, x1)); + _mm512_store_ps(&q[ldq],q1); + q2 = _mm512_load_ps(&q[ldq+16]); + q2 = _mm512_add_ps(q2, _mm512_FMA_ps(y2, h2, x2)); + _mm512_store_ps(&q[ldq+16],q2); + q3 = _mm512_load_ps(&q[ldq+32]); + q3 = _mm512_add_ps(q3, _mm512_FMA_ps(y3, h2, x3)); + _mm512_store_ps(&q[ldq+32],q3); +// q4 = _mm512_load_ps(&q[ldq+48]); +// q4 = _mm512_add_ps(q4, _mm512_FMA_ps(y4, h2, x4)); +// _mm512_store_ps(&q[ldq+48],q4); for (i = 2; i < nb; i++) { - h1 = _mm256_broadcast_ss(&hh[i-1]); - h2 = _mm256_broadcast_ss(&hh[ldh+i]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[i*ldq]); - q1 = _mm256_FMA_ps(x1, h1, q1); - q1 = _mm256_FMA_ps(y1, h2, q1); - _mm256_store_ps(&q[i*ldq],q1); -// q2 = _mm256_load_ps(&q[(i*ldq)+4]); -// q2 = _mm256_FMA_ps(x2, h1, q2); -// q2 = _mm256_FMA_ps(y2, h2, q2); -// _mm256_store_ps(&q[(i*ldq)+4],q2); -#else - q1 = _mm256_load_ps(&q[i*ldq]); - q1 = _mm256_add_ps(q1, _mm256_add_ps(_mm256_mul_ps(x1,h1), _mm256_mul_ps(y1, h2))); - _mm256_store_ps(&q[i*ldq],q1); -// q2 = _mm256_load_ps(&q[(i*ldq)+4]); -// q2 = _mm256_add_ps(q2, _mm256_add_ps(_mm256_mul_ps(x2,h1), _mm256_mul_ps(y2, h2))); -// _mm256_store_ps(&q[(i*ldq)+4],q2); -#endif + h1 = _mm512_set1_ps(hh[i-1]); + h2 = _mm512_set1_ps(hh[ldh+i]); + + q1 = _mm512_load_ps(&q[i*ldq]); + q1 = _mm512_FMA_ps(x1, h1, q1); + q1 = _mm512_FMA_ps(y1, h2, q1); + _mm512_store_ps(&q[i*ldq],q1); + q2 = _mm512_load_ps(&q[(i*ldq)+16]); + q2 = _mm512_FMA_ps(x2, h1, q2); + q2 = _mm512_FMA_ps(y2, h2, q2); + _mm512_store_ps(&q[(i*ldq)+16],q2); + q3 = _mm512_load_ps(&q[(i*ldq)+32]); + q3 = _mm512_FMA_ps(x3, h1, q3); + q3 = _mm512_FMA_ps(y3, h2, q3); + _mm512_store_ps(&q[(i*ldq)+32],q3); +// q4 = _mm512_load_ps(&q[(i*ldq)+48]); +// q4 = _mm512_FMA_ps(x4, h1, q4); +// q4 = _mm512_FMA_ps(y4, h2, q4); +// _mm512_store_ps(&q[(i*ldq)+48],q4); + } - h1 = _mm256_broadcast_ss(&hh[nb-1]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_load_ps(&q[nb*ldq]); - q1 = _mm256_FMA_ps(x1, h1, q1); - _mm256_store_ps(&q[nb*ldq],q1); -// q2 = _mm256_load_ps(&q[(nb*ldq)+4]); -// q2 = _mm256_FMA_ps(x2, h1, q2); -// _mm256_store_ps(&q[(nb*ldq)+4],q2); -#else - q1 = _mm256_load_ps(&q[nb*ldq]); - q1 = _mm256_add_ps(q1, _mm256_mul_ps(x1, h1)); - _mm256_store_ps(&q[nb*ldq],q1); -// q2 = _mm256_load_ps(&q[(nb*ldq)+4]); -// q2 = _mm256_add_ps(q2, _mm256_mul_ps(x2, h1)); -// _mm256_store_ps(&q[(nb*ldq)+4],q2); -#endif + h1 = _mm512_set1_ps(hh[nb-1]); + + q1 = _mm512_load_ps(&q[nb*ldq]); + q1 = _mm512_FMA_ps(x1, h1, q1); + _mm512_store_ps(&q[nb*ldq],q1); + q2 = _mm512_load_ps(&q[(nb*ldq)+16]); + q2 = _mm512_FMA_ps(x2, h1, q2); + _mm512_store_ps(&q[(nb*ldq)+16],q2); + q3 = 
_mm512_load_ps(&q[(nb*ldq)+32]); + q3 = _mm512_FMA_ps(x3, h1, q3); + _mm512_store_ps(&q[(nb*ldq)+32],q3); +// q4 = _mm512_load_ps(&q[(nb*ldq)+48]); +// q4 = _mm512_FMA_ps(x4, h1, q4); +// _mm512_store_ps(&q[(nb*ldq)+48],q4); + } + /** * Unrolled kernel that computes - * 4 rows of Q simultaneously, a + * 32 rows of Q simultaneously, a * matrix vector product with two householder * vectors + a rank 2 update is performed */ -__forceinline void hh_trafo_kernel_4_sse_instead_of_avx512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s) + __forceinline void hh_trafo_kernel_32_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s) { ///////////////////////////////////////////////////// - // Matrix Vector Multiplication, Q [4 x nb+1] * hh + // Matrix Vector Multiplication, Q [24 x nb+1] * hh // hh contains two householder vectors, with offset 1 ///////////////////////////////////////////////////// int i; // Needed bit mask for floating point sign flip // carefull here - __m128 sign = _mm_castsi128_ps(_mm_set_epi32(0x80000000, 0x80000000, 0x80000000, 0x80000000)); + __m512 sign = (__m512)_mm512_set1_epi32(0x80000000); - __m128 x1 = _mm_load_ps(&q[ldq]); + __m512 x1 = _mm512_load_ps(&q[ldq]); + __m512 x2 = _mm512_load_ps(&q[ldq+32]); +// __m512 x3 = _mm512_load_ps(&q[ldq+48]); +// __m512 x4 = _mm512_load_ps(&q[ldq+64]); - __m128 x2 = _mm_castpd_ps(_mm_loaddup_pd( (double *) &hh[ldh+1])); - __m128 x3 ; - __m128 h1 = _mm_moveldup_ps(x2); - __m128 h2; - __m128 q1 = _mm_load_ps(q); - __m128 y1 = _mm_add_ps(q1, _mm_mul_ps(x1, h1)); + __m512 h1 = _mm512_set1_ps(hh[ldh+1]); + __m512 h2; + + __m512 q1 = _mm512_load_ps(q); + __m512 y1 = _mm512_FMA_ps(x1, h1, q1); + __m512 q2 = _mm512_load_ps(&q[16]); + __m512 y2 = _mm512_FMA_ps(x2, h1, q2); +// __m512 q3 = _mm512_load_ps(&q[32]); +// __m512 y3 = _mm512_FMA_ps(x3, h1, q3); +// __m512 q4 = _mm512_load_ps(&q[48]); +// __m512 y4 = _mm512_FMA_ps(x4, h1, q4); for(i = 2; i < nb; i++) { + h1 = _mm512_set1_ps(hh[i-1]); + h2 = _mm512_set1_ps(hh[ldh+i]); + + q1 = _mm512_load_ps(&q[i*ldq]); + x1 = _mm512_FMA_ps(q1, h1, x1); + y1 = _mm512_FMA_ps(q1, h2, y1); + q2 = _mm512_load_ps(&q[(i*ldq)+16]); + x2 = _mm512_FMA_ps(q2, h1, x2); + y2 = _mm512_FMA_ps(q2, h2, y2); +// q3 = _mm512_load_ps(&q[(i*ldq)+32]); +// x3 = _mm512_FMA_ps(q3, h1, x3); +// y3 = _mm512_FMA_ps(q3, h2, y3); +// q4 = _mm512_load_ps(&q[(i*ldq)+48]); +// x4 = _mm512_FMA_ps(q4, h1, x4); +// y4 = _mm512_FMA_ps(q4, h2, y4); - x2 = _mm_castpd_ps(_mm_loaddup_pd( (double *) &hh[i-1])); - h1 = _mm_moveldup_ps(x2); - - x3 = _mm_castpd_ps(_mm_loaddup_pd( (double *) &hh[ldh+i])); - h2 = _mm_moveldup_ps(x3); - - q1 = _mm_load_ps(&q[i*ldq]); - x1 = _mm_add_ps(x1, _mm_mul_ps(q1,h1)); - y1 = _mm_add_ps(y1, _mm_mul_ps(q1,h2)); } - x2 = _mm_castpd_ps(_mm_loaddup_pd( (double *) &hh[nb-1])); - h1 = _mm_moveldup_ps(x2); + h1 = _mm512_set1_ps(hh[nb-1]); + + q1 = _mm512_load_ps(&q[nb*ldq]); + x1 = _mm512_FMA_ps(q1, h1, x1); + q2 = _mm512_load_ps(&q[(nb*ldq)+16]); + x2 = _mm512_FMA_ps(q2, h1, x2); +// q3 = _mm512_load_ps(&q[(nb*ldq)+32]); +// x3 = _mm512_FMA_ps(q3, h1, x3); +// q4 = _mm512_load_ps(&q[(nb*ldq)+48]); +// x4 = _mm512_FMA_ps(q4, h1, x4); - q1 = _mm_load_ps(&q[nb*ldq]); - x1 = _mm_add_ps(x1, _mm_mul_ps(q1,h1)); ///////////////////////////////////////////////////// - // Rank-2 update of Q [12 x nb+1] + // Rank-2 update of Q [24 x nb+1] ///////////////////////////////////////////////////// - x2 = _mm_castpd_ps(_mm_loaddup_pd( (double *) hh)); - __m128 tau1 = _mm_moveldup_ps(x2); - - x3 
= _mm_castpd_ps(_mm_loaddup_pd( (double *) &hh[ldh])); - __m128 tau2 = _mm_moveldup_ps(x3); - - __m128 x4 = _mm_castpd_ps(_mm_loaddup_pd( (double *) &s)); - __m128 vs = _mm_moveldup_ps(x4); - - h1 = _mm_xor_ps(tau1, sign); - x1 = _mm_mul_ps(x1, h1); - h1 = _mm_xor_ps(tau2, sign); - h2 = _mm_mul_ps(h1, vs); - - y1 = _mm_add_ps(_mm_mul_ps(y1,h1), _mm_mul_ps(x1,h2)); - - q1 = _mm_load_ps(q); - q1 = _mm_add_ps(q1, y1); - _mm_store_ps(q,q1); - - x2 = _mm_castpd_ps(_mm_loaddup_pd( (double *) &hh[ldh+1])); - h2 = _mm_moveldup_ps(x2); - - q1 = _mm_load_ps(&q[ldq]); - q1 = _mm_add_ps(q1, _mm_add_ps(x1, _mm_mul_ps(y1, h2))); - _mm_store_ps(&q[ldq],q1); + __m512 tau1 = _mm512_set1_ps(hh[0]); + __m512 tau2 = _mm512_set1_ps(hh[ldh]); + __m512 vs = _mm512_set1_ps(s); + + h1 = (__m512) _mm512_xor_epi32((__m512i) tau1, (__m512i) sign); + x1 = _mm512_mul_ps(x1, h1); + x2 = _mm512_mul_ps(x2, h1); +// x3 = _mm512_mul_ps(x3, h1); +// x4 = _mm512_mul_ps(x4, h1); + + h1 = (__m512) _mm512_xor_epi32((__m512i) tau2, (__m512i) sign); + h2 = _mm512_mul_ps(h1, vs); + y1 = _mm512_FMA_ps(y1, h1, _mm512_mul_ps(x1,h2)); + y2 = _mm512_FMA_ps(y2, h1, _mm512_mul_ps(x2,h2)); +// y3 = _mm512_FMA_ps(y3, h1, _mm512_mul_ps(x3,h2)); +// y4 = _mm512_FMA_ps(y4, h1, _mm512_mul_ps(x4,h2)); + + q1 = _mm512_load_ps(q); + q1 = _mm512_add_ps(q1, y1); + _mm512_store_ps(q,q1); + q2 = _mm512_load_ps(&q[16]); + q2 = _mm512_add_ps(q2, y2); + _mm512_store_ps(&q[16],q2); +// q3 = _mm512_load_ps(&q[32]); +// q3 = _mm512_add_ps(q3, y3); +// _mm512_store_ps(&q[32],q3); +// q4 = _mm512_load_ps(&q[48]); +// q4 = _mm512_add_ps(q4, y4); +// _mm512_store_ps(&q[48],q4); + + h2 = _mm512_set1_ps(hh[ldh+1]); + + q1 = _mm512_load_ps(&q[ldq]); + q1 = _mm512_add_ps(q1, _mm512_FMA_ps(y1, h2, x1)); + _mm512_store_ps(&q[ldq],q1); + q2 = _mm512_load_ps(&q[ldq+16]); + q2 = _mm512_add_ps(q2, _mm512_FMA_ps(y2, h2, x2)); + _mm512_store_ps(&q[ldq+16],q2); +// q3 = _mm512_load_ps(&q[ldq+32]); +// q3 = _mm512_add_ps(q3, _mm512_FMA_ps(y3, h2, x3)); +// _mm512_store_ps(&q[ldq+32],q3); +// q4 = _mm512_load_ps(&q[ldq+48]); +// q4 = _mm512_add_ps(q4, _mm512_FMA_ps(y4, h2, x4)); +// _mm512_store_ps(&q[ldq+48],q4); for (i = 2; i < nb; i++) { - x2 = _mm_castpd_ps(_mm_loaddup_pd( (double *) &hh[i-1])); - h1 = _mm_moveldup_ps(x2); - - x3 = _mm_castpd_ps(_mm_loaddup_pd( (double *) &hh[ldh+i])); - h2 = _mm_moveldup_ps(x3); + h1 = _mm512_set1_ps(hh[i-1]); + h2 = _mm512_set1_ps(hh[ldh+i]); + + q1 = _mm512_load_ps(&q[i*ldq]); + q1 = _mm512_FMA_ps(x1, h1, q1); + q1 = _mm512_FMA_ps(y1, h2, q1); + _mm512_store_ps(&q[i*ldq],q1); + q2 = _mm512_load_ps(&q[(i*ldq)+16]); + q2 = _mm512_FMA_ps(x2, h1, q2); + q2 = _mm512_FMA_ps(y2, h2, q2); + _mm512_store_ps(&q[(i*ldq)+16],q2); +// q3 = _mm512_load_ps(&q[(i*ldq)+32]); +// q3 = _mm512_FMA_ps(x3, h1, q3); +// q3 = _mm512_FMA_ps(y3, h2, q3); +// _mm512_store_ps(&q[(i*ldq)+32],q3); +// q4 = _mm512_load_ps(&q[(i*ldq)+48]); +// q4 = _mm512_FMA_ps(x4, h1, q4); +// q4 = _mm512_FMA_ps(y4, h2, q4); +// _mm512_store_ps(&q[(i*ldq)+48],q4); - q1 = _mm_load_ps(&q[i*ldq]); - q1 = _mm_add_ps(q1, _mm_add_ps(_mm_mul_ps(x1,h1), _mm_mul_ps(y1, h2))); - _mm_store_ps(&q[i*ldq],q1); } - x2 = _mm_castpd_ps(_mm_loaddup_pd( (double *) &hh[nb-1])); - h1 = _mm_moveldup_ps(x2); + h1 = _mm512_set1_ps(hh[nb-1]); + + q1 = _mm512_load_ps(&q[nb*ldq]); + q1 = _mm512_FMA_ps(x1, h1, q1); + _mm512_store_ps(&q[nb*ldq],q1); + q2 = _mm512_load_ps(&q[(nb*ldq)+16]); + q2 = _mm512_FMA_ps(x2, h1, q2); + _mm512_store_ps(&q[(nb*ldq)+16],q2); +// q3 = _mm512_load_ps(&q[(nb*ldq)+32]); +// q3 = 
_mm512_FMA_ps(x3, h1, q3); +// _mm512_store_ps(&q[(nb*ldq)+32],q3); +// q4 = _mm512_load_ps(&q[(nb*ldq)+48]); +// q4 = _mm512_FMA_ps(x4, h1, q4); +// _mm512_store_ps(&q[(nb*ldq)+48],q4); - q1 = _mm_load_ps(&q[nb*ldq]); - q1 = _mm_add_ps(q1, _mm_mul_ps(x1, h1)); - _mm_store_ps(&q[nb*ldq],q1); } + /** * Unrolled kernel that computes - * 4 rows of Q simultaneously, a + * 16 rows of Q simultaneously, a * matrix vector product with two householder * vectors + a rank 2 update is performed */ - __forceinline void hh_trafo_kernel_4_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s) + __forceinline void hh_trafo_kernel_16_AVX512_2hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s) { ///////////////////////////////////////////////////// - // Matrix Vector Multiplication, Q [4 x nb+1] * hh + // Matrix Vector Multiplication, Q [24 x nb+1] * hh // hh contains two householder vectors, with offset 1 ///////////////////////////////////////////////////// int i; // Needed bit mask for floating point sign flip - __m256 sign = (__m256)_mm256_set1_epi32(0x80000000); + // carefull here + __m512 sign = (__m512)_mm512_set1_epi32(0x80000000); - __m256 x1 = _mm256_castps128_ps256(_mm_load_ps(&q[ldq])); + __m512 x1 = _mm512_load_ps(&q[ldq]); +// __m512 x2 = _mm512_load_ps(&q[ldq+32]); +// __m512 x3 = _mm512_load_ps(&q[ldq+48]); +// __m512 x4 = _mm512_load_ps(&q[ldq+64]); - __m256 h1 = _mm256_broadcast_ss(&hh[ldh+1]); - __m256 h2; -#ifdef __ELPA_USE_FMA__ - __m256 q1 = _mm256_castps128_ps256(_mm_load_ps(q)); - __m256 y1 = _mm256_FMA_ps(x1, h1, q1); -#else - __m256 q1 = _mm256_castps128_ps256(_mm_load_ps(q)); - __m256 y1 = _mm256_add_ps(q1, _mm256_mul_ps(x1, h1)); -#endif + __m512 h1 = _mm512_set1_ps(hh[ldh+1]); + __m512 h2; + + __m512 q1 = _mm512_load_ps(q); + __m512 y1 = _mm512_FMA_ps(x1, h1, q1); +// __m512 q2 = _mm512_load_ps(&q[16]); +// __m512 y2 = _mm512_FMA_ps(x2, h1, q2); +// __m512 q3 = _mm512_load_ps(&q[32]); +// __m512 y3 = _mm512_FMA_ps(x3, h1, q3); +// __m512 q4 = _mm512_load_ps(&q[48]); +// __m512 y4 = _mm512_FMA_ps(x4, h1, q4); for(i = 2; i < nb; i++) { - h1 = _mm256_broadcast_ss(&hh[i-1]); - h2 = _mm256_broadcast_ss(&hh[ldh+i]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_castps128_ps256(_mm_load_ps(&q[i*ldq])); - x1 = _mm256_FMA_ps(q1, h1, x1); - y1 = _mm256_FMA_ps(q1, h2, y1); -#else - q1 = _mm256_castps128_ps256(_mm_load_ps(&q[i*ldq])); - x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1)); - y1 = _mm256_add_ps(y1, _mm256_mul_ps(q1,h2)); -#endif + h1 = _mm512_set1_ps(hh[i-1]); + h2 = _mm512_set1_ps(hh[ldh+i]); + + q1 = _mm512_load_ps(&q[i*ldq]); + x1 = _mm512_FMA_ps(q1, h1, x1); + y1 = _mm512_FMA_ps(q1, h2, y1); +// q2 = _mm512_load_ps(&q[(i*ldq)+16]); +// x2 = _mm512_FMA_ps(q2, h1, x2); +// y2 = _mm512_FMA_ps(q2, h2, y2); +// q3 = _mm512_load_ps(&q[(i*ldq)+32]); +// x3 = _mm512_FMA_ps(q3, h1, x3); +// y3 = _mm512_FMA_ps(q3, h2, y3); +// q4 = _mm512_load_ps(&q[(i*ldq)+48]); +// x4 = _mm512_FMA_ps(q4, h1, x4); +// y4 = _mm512_FMA_ps(q4, h2, y4); + } - h1 = _mm256_broadcast_ss(&hh[nb-1]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_castps128_ps256(_mm_load_ps(&q[nb*ldq])); - x1 = _mm256_FMA_ps(q1, h1, x1); -#else - q1 = _mm256_castps128_ps256(_mm_load_ps(&q[nb*ldq])); - x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1)); -#endif + h1 = _mm512_set1_ps(hh[nb-1]); + + q1 = _mm512_load_ps(&q[nb*ldq]); + x1 = _mm512_FMA_ps(q1, h1, x1); +// q2 = _mm512_load_ps(&q[(nb*ldq)+16]); +// x2 = _mm512_FMA_ps(q2, h1, x2); +// q3 = _mm512_load_ps(&q[(nb*ldq)+32]); +// x3 = _mm512_FMA_ps(q3, h1, x3); 
+// q4 = _mm512_load_ps(&q[(nb*ldq)+48]); +// x4 = _mm512_FMA_ps(q4, h1, x4); + ///////////////////////////////////////////////////// - // Rank-2 update of Q [4 x nb+1] + // Rank-2 update of Q [24 x nb+1] ///////////////////////////////////////////////////// - __m256 tau1 = _mm256_broadcast_ss(hh); - __m256 tau2 = _mm256_broadcast_ss(&hh[ldh]); - __m256 vs = _mm256_broadcast_ss(&s); - - h1 = _mm256_xor_ps(tau1, sign); - x1 = _mm256_mul_ps(x1, h1); - h1 = _mm256_xor_ps(tau2, sign); - h2 = _mm256_mul_ps(h1, vs); -#ifdef __ELPA_USE_FMA__ - y1 = _mm256_FMA_ps(y1, h1, _mm256_mul_ps(x1,h2)); -#else - y1 = _mm256_add_ps(_mm256_mul_ps(y1,h1), _mm256_mul_ps(x1,h2)); -#endif - - q1 = _mm256_castps128_ps256(_mm_load_ps(q)); - q1 = _mm256_add_ps(q1, y1); - _mm_store_ps(q, _mm256_castps256_ps128(q1)); -// _mm256_store_ps(q,q1); - - h2 = _mm256_broadcast_ss(&hh[ldh+1]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_castps128_ps256(_mm_load_ps(&q[ldq])); - q1 = _mm256_add_ps(q1, _mm256_FMA_ps(y1, h2, x1)); - _mm_store_ps(&q[ldq], _mm256_castps256_ps128(q1)); -// _mm256_store_ps(&q[ldq],q1); -#else - q1 = _mm256_castps128_ps256(_mm_load_ps(&q[ldq])); - q1 = _mm256_add_ps(q1, _mm256_add_ps(x1, _mm256_mul_ps(y1, h2))); - _mm_store_ps(&q[ldq], _mm256_castps256_ps128(q1)); - -// _mm256_store_ps(&q[ldq],q1); -#endif + __m512 tau1 = _mm512_set1_ps(hh[0]); + __m512 tau2 = _mm512_set1_ps(hh[ldh]); + __m512 vs = _mm512_set1_ps(s); + + h1 = (__m512) _mm512_xor_epi32((__m512i) tau1, (__m512i) sign); + x1 = _mm512_mul_ps(x1, h1); +// x2 = _mm512_mul_ps(x2, h1); +// x3 = _mm512_mul_ps(x3, h1); +// x4 = _mm512_mul_ps(x4, h1); + + h1 = (__m512) _mm512_xor_epi32((__m512i) tau2, (__m512i) sign); + h2 = _mm512_mul_ps(h1, vs); + y1 = _mm512_FMA_ps(y1, h1, _mm512_mul_ps(x1,h2)); +// y2 = _mm512_FMA_ps(y2, h1, _mm512_mul_ps(x2,h2)); +// y3 = _mm512_FMA_ps(y3, h1, _mm512_mul_ps(x3,h2)); +// y4 = _mm512_FMA_ps(y4, h1, _mm512_mul_ps(x4,h2)); + + q1 = _mm512_load_ps(q); + q1 = _mm512_add_ps(q1, y1); + _mm512_store_ps(q,q1); +// q2 = _mm512_load_ps(&q[16]); +// q2 = _mm512_add_ps(q2, y2); +// _mm512_store_ps(&q[16],q2); +// q3 = _mm512_load_ps(&q[32]); +// q3 = _mm512_add_ps(q3, y3); +// _mm512_store_ps(&q[32],q3); +// q4 = _mm512_load_ps(&q[48]); +// q4 = _mm512_add_ps(q4, y4); +// _mm512_store_ps(&q[48],q4); + + h2 = _mm512_set1_ps(hh[ldh+1]); + + q1 = _mm512_load_ps(&q[ldq]); + q1 = _mm512_add_ps(q1, _mm512_FMA_ps(y1, h2, x1)); + _mm512_store_ps(&q[ldq],q1); +// q2 = _mm512_load_ps(&q[ldq+16]); +// q2 = _mm512_add_ps(q2, _mm512_FMA_ps(y2, h2, x2)); +// _mm512_store_ps(&q[ldq+16],q2); +// q3 = _mm512_load_ps(&q[ldq+32]); +// q3 = _mm512_add_ps(q3, _mm512_FMA_ps(y3, h2, x3)); +// _mm512_store_ps(&q[ldq+32],q3); +// q4 = _mm512_load_ps(&q[ldq+48]); +// q4 = _mm512_add_ps(q4, _mm512_FMA_ps(y4, h2, x4)); +// _mm512_store_ps(&q[ldq+48],q4); for (i = 2; i < nb; i++) { - h1 = _mm256_broadcast_ss(&hh[i-1]); - h2 = _mm256_broadcast_ss(&hh[ldh+i]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_castps128_ps256(_mm_load_ps(&q[i*ldq])); - q1 = _mm256_FMA_ps(x1, h1, q1); - q1 = _mm256_FMA_ps(y1, h2, q1); - _mm_store_ps(&q[i*ldq], _mm256_castps256_ps128(q1)); -// _mm256_store_ps(&q[i*ldq],q1); -#else - q1 = _mm256_castps128_ps256(_mm_load_ps(&q[i*ldq])); - q1 = _mm256_add_ps(q1, _mm256_add_ps(_mm256_mul_ps(x1,h1), _mm256_mul_ps(y1, h2))); - _mm_store_ps(&q[i*ldq], _mm256_castps256_ps128(q1)); -// _mm256_store_ps(&q[i*ldq],q1); -#endif + h1 = _mm512_set1_ps(hh[i-1]); + h2 = _mm512_set1_ps(hh[ldh+i]); + + q1 = _mm512_load_ps(&q[i*ldq]); + q1 = 
_mm512_FMA_ps(x1, h1, q1); + q1 = _mm512_FMA_ps(y1, h2, q1); + _mm512_store_ps(&q[i*ldq],q1); +// q2 = _mm512_load_ps(&q[(i*ldq)+16]); +// q2 = _mm512_FMA_ps(x2, h1, q2); +// q2 = _mm512_FMA_ps(y2, h2, q2); +// _mm512_store_ps(&q[(i*ldq)+16],q2); +// q3 = _mm512_load_ps(&q[(i*ldq)+32]); +// q3 = _mm512_FMA_ps(x3, h1, q3); +// q3 = _mm512_FMA_ps(y3, h2, q3); +// _mm512_store_ps(&q[(i*ldq)+32],q3); +// q4 = _mm512_load_ps(&q[(i*ldq)+48]); +// q4 = _mm512_FMA_ps(x4, h1, q4); +// q4 = _mm512_FMA_ps(y4, h2, q4); +// _mm512_store_ps(&q[(i*ldq)+48],q4); + } - h1 = _mm256_broadcast_ss(&hh[nb-1]); -#ifdef __ELPA_USE_FMA__ - q1 = _mm256_castps128_ps256(_mm_load_ps(&q[nb*ldq])); - q1 = _mm256_FMA_ps(x1, h1, q1); - _mm_store_ps(&q[nb*ldq], _mm256_castps256_ps128(q1)); -// _mm256_store_ps(&q[nb*ldq],q1); -#else - q1 = _mm256_castps128_ps256(_mm_load_ps(&q[nb*ldq])); - q1 = _mm256_add_ps(q1, _mm256_mul_ps(x1, h1)); - _mm_store_ps(&q[nb*ldq], _mm256_castps256_ps128(q1)); -// _mm256_store_ps(&q[nb*ldq],q1); -#endif + h1 = _mm512_set1_ps(hh[nb-1]); + + q1 = _mm512_load_ps(&q[nb*ldq]); + q1 = _mm512_FMA_ps(x1, h1, q1); + _mm512_store_ps(&q[nb*ldq],q1); +// q2 = _mm512_load_ps(&q[(nb*ldq)+16]); +// q2 = _mm512_FMA_ps(x2, h1, q2); +// _mm512_store_ps(&q[(nb*ldq)+16],q2); +// q3 = _mm512_load_ps(&q[(nb*ldq)+32]); +// q3 = _mm512_FMA_ps(x3, h1, q3); +// _mm512_store_ps(&q[(nb*ldq)+32],q3); +// q4 = _mm512_load_ps(&q[(nb*ldq)+48]); +// q4 = _mm512_FMA_ps(x4, h1, q4); +// _mm512_store_ps(&q[(nb*ldq)+48],q4); + } diff --git a/src/mod_compute_hh_trafo_real.F90 b/src/mod_compute_hh_trafo_real.F90 index eafa94cc13edcff10d4508cd139e5df42565b7f2..c88682b6599459cec259f400803eeed581234c50 100644 --- a/src/mod_compute_hh_trafo_real.F90 +++ b/src/mod_compute_hh_trafo_real.F90 @@ -1059,30 +1059,30 @@ module compute_hh_trafo_real #endif /* WITH_NO_SPECIFIC_REAL_KERNEL */ #endif /* WITH_REAL_AVX_BLOCK2_KERNEL || WITH_REAL_AVX2_BLOCK2_KERNEL */ -!#if defined(WITH_REAL_AVX512_BLOCK2_KERNEL) -!#if defined(WITH_NO_SPECIFIC_REAL_KERNEL) -! if (THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX512_BLOCK2) then -!#endif /* WITH_NO_SPECIFIC_REAL_KERNEL */ -! -!#if defined(WITH_NO_SPECIFIC_REAL_KERNEL) || (defined(WITH_ONE_SPECIFIC_REAL_KERNEL) && !defined(WITH_REAL_AVX512_BLOCK6_KERNEL) && !defined(WITH_REAL_AVX512_BLOCK4_KERNEL)) -! do j = ncols, 2, -2 -! w(:,1) = bcast_buffer(1:nbw,j+off) -! w(:,2) = bcast_buffer(1:nbw,j+off-1) -!#ifdef WITH_OPENMP -! call double_hh_trafo_real_avx512_2hv_single(c_loc(a(1,j+off+a_off-1,istripe,my_thread)), & -! w, nbw, nl, stripe_width, nbw) -!#else -! call double_hh_trafo_real_avx512_2hv_single(c_loc(a(1,j+off+a_off-1,istripe)), & -! w, nbw, nl, stripe_width, nbw) -!#endif -! enddo -! -!#endif /* defined(WITH_NO_SPECIFIC_REAL_KERNEL) || (defined(WITH_ONE_SPECIFIC_REAL_KERNEL) && !defined(WITH_REAL_AVX512_BLOCK6_KERNEL) && !defined(WITH_REAL_AVX512_BLOCK4_KERNEL) ) */ -! -!#if defined(WITH_NO_SPECIFIC_REAL_KERNEL) -! endif -!#endif /* WITH_NO_SPECIFIC_REAL_KERNEL */ -!#endif /* WITH_REAL_AVX512_BLOCK2_KERNEL */ +#if defined(WITH_REAL_AVX512_BLOCK2_KERNEL) +#if defined(WITH_NO_SPECIFIC_REAL_KERNEL) + if (THIS_REAL_ELPA_KERNEL .eq. 
REAL_ELPA_KERNEL_AVX512_BLOCK2) then +#endif /* WITH_NO_SPECIFIC_REAL_KERNEL */ + +#if defined(WITH_NO_SPECIFIC_REAL_KERNEL) || (defined(WITH_ONE_SPECIFIC_REAL_KERNEL) && !defined(WITH_REAL_AVX512_BLOCK6_KERNEL) && !defined(WITH_REAL_AVX512_BLOCK4_KERNEL)) + do j = ncols, 2, -2 + w(:,1) = bcast_buffer(1:nbw,j+off) + w(:,2) = bcast_buffer(1:nbw,j+off-1) +#ifdef WITH_OPENMP + call double_hh_trafo_real_avx512_2hv_single(c_loc(a(1,j+off+a_off-1,istripe,my_thread)), & + w, nbw, nl, stripe_width, nbw) +#else + call double_hh_trafo_real_avx512_2hv_single(c_loc(a(1,j+off+a_off-1,istripe)), & + w, nbw, nl, stripe_width, nbw) +#endif + enddo + +#endif /* defined(WITH_NO_SPECIFIC_REAL_KERNEL) || (defined(WITH_ONE_SPECIFIC_REAL_KERNEL) && !defined(WITH_REAL_AVX512_BLOCK6_KERNEL) && !defined(WITH_REAL_AVX512_BLOCK4_KERNEL) ) */ + +#if defined(WITH_NO_SPECIFIC_REAL_KERNEL) + endif +#endif /* WITH_NO_SPECIFIC_REAL_KERNEL */ +#endif /* WITH_REAL_AVX512_BLOCK2_KERNEL */ #if defined(WITH_REAL_BGP_KERNEL) #if defined(WITH_NO_SPECIFIC_REAL_KERNEL)
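
Reviewer note, not part of the patch: the new hh_trafo_kernel_{64,48,32,16}_AVX512_2hv_single routines above are unrolled intrinsics variants of one and the same rank-2 Householder update; they differ only in how many 16-float __m512 columns are processed per call, and the sign flip of tau is done by XOR-ing the IEEE sign bit (0x80000000) via _mm512_xor_epi32. As a plain-C reference for what the vector code computes, a minimal unvectorized sketch is given below. The function name double_hh_trafo_ref_single and the per-row loop are illustrative only (they are not introduced by this commit); the layout assumptions (column-major q with leading dimension ldq, two Householder vectors in hh with leading dimension ldh, scalar product s precomputed by the driver) follow the kernel file above.

/* Illustrative scalar sketch of the double-Householder transform implemented
 * by the AVX512 2hv single-precision kernels in this patch.
 *   q  : nq x (nb+1) block, column-major, leading dimension ldq
 *   hh : two Householder vectors of length nb, leading dimension ldh
 *   s  : s = hh(2,2) + sum_{i=3..nb} hh(i,2)*hh(i-1,1), as computed in
 *        double_hh_trafo_real_avx512_2hv_single
 * Names and the row loop are for illustration only. */
static void double_hh_trafo_ref_single(float *q, const float *hh,
                                       int nb, int nq, int ldq, int ldh, float s)
{
  for (int r = 0; r < nq; r++)
  {
    /* x = Q*v1 and y = Q*v2, exploiting the implicit leading 1 of each vector */
    float x = q[ldq + r];                       /* q(r,2), since hh(1,1) = 1   */
    float y = q[r] + q[ldq + r] * hh[ldh + 1];  /* q(r,1) + q(r,2)*hh(2,2)     */
    for (int i = 2; i < nb; i++)
    {
      x += q[i * ldq + r] * hh[i - 1];
      y += q[i * ldq + r] * hh[ldh + i];
    }
    x += q[nb * ldq + r] * hh[nb - 1];

    /* Rank-2 update: scale by -tau1 / -tau2 (the sign-bit XOR in the
     * vectorized code) and couple x into y through the scalar product s. */
    const float tau1 = hh[0];
    const float tau2 = hh[ldh];
    x = -tau1 * x;
    y = -tau2 * y + (-tau2 * s) * x;

    /* Apply the update to the nb+1 columns of q touched by the two vectors. */
    q[r]       += y;
    q[ldq + r] += x + y * hh[ldh + 1];
    for (int i = 2; i < nb; i++)
      q[i * ldq + r] += x * hh[i - 1] + y * hh[ldh + i];
    q[nb * ldq + r] += x * hh[nb - 1];
  }
}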