Commit 5225392a authored by Andreas Marek

Single precision kernel for AVX512 real block4

parent 6a210b6a
......@@ -2167,6 +2167,19 @@ intel-double-precision-mpi-noopenmp-ftimings-redirect-real-avx512_block4-complex
- export LD_LIBRARY_PATH=$MKL_HOME/lib/intel64:$LD_LIBRARY_PATH
- make check TEST_FLAGS='1000 500 128'
intel-single-precision-mpi-noopenmp-ftimings-redirect-real-avx512_block4-complex-avx512_block2-kernel-jobs:
tags:
- KNL
script:
- ./autogen.sh
- ./configure FC=mpiifort CC=mpiicc CFLAGS="-O3 -mtune=knl -axMIC-AVX512" FCFLAGS="-O3 -mtune=knl -axMIC-AVX512" SCALAPACK_FCFLAGS="-L$MKLROOT/lib/intel64 -lmkl_scalapack_lp64 -lmkl_intel_lp64 -lmkl_sequential -lmkl_core -lmkl_blacs_intelmpi_lp64 -lpthread -lm -I$MKLROOT/include/intel64/lp64" SCALAPACK_LDFLAGS="-L$MKLROOT/lib/intel64 -lmkl_scalapack_lp64 -lmkl_intel_lp64 -lmkl_sequential -lmkl_core -lmkl_blacs_intelmpi_lp64 -lpthread -lm -Wl,-rpath,$MKLROOT/lib/intel64" --with-real-avx512_block4-kernel-only --with-complex-avx512_block2-kernel-only --enable-single-precision
- /home/elpa/wait_until_midnight.sh
- make -j 8
- export OMP_NUM_THREADS=1
- export LD_LIBRARY_PATH=$MKL_HOME/lib/intel64:$LD_LIBRARY_PATH
- make check TEST_FLAGS='1000 500 128'
intel-double-precision-mpi-noopenmp-ftimings-redirect-real-avx512_block6-complex-avx512_block1-kernel-jobs:
tags:
- KNL
......
......@@ -174,9 +174,9 @@ endif
if WITH_REAL_AVX512_BLOCK4_KERNEL
libelpa@SUFFIX@_private_la_SOURCES += src/elpa2_kernels/elpa2_kernels_real_avx512_4hv_double_precision.c
#if WANT_SINGLE_PRECISION_REAL
# libelpa@SUFFIX@_private_la_SOURCES += src/elpa2_kernels/elpa2_kernels_real_avx512_4hv_single_precision.c
#endif
if WANT_SINGLE_PRECISION_REAL
libelpa@SUFFIX@_private_la_SOURCES += src/elpa2_kernels/elpa2_kernels_real_avx512_4hv_single_precision.c
endif
endif
......
......@@ -42,55 +42,47 @@
// any derivatives of ELPA under the same license that we chose for
// the original distribution, the GNU Lesser General Public License.
//
// Author: Andreas Marek, MPCDF, based on the double precision case of A. Heinecke
//
// Author: Andreas Marek (andreas.marek@mpcdf.mpg.de)
// --------------------------------------------------------------------------------------------------
#include "config-f90.h"
#include <x86intrin.h>
#define __forceinline __attribute__((always_inline)) static
#ifdef HAVE_AVX2
#ifdef __FMA4__
#define __ELPA_USE_FMA__
#define _mm256_FMA_ps(a,b,c) _mm256_macc_ps(a,b,c)
#define _mm256_NFMA_ps(a,b,c) _mm256_nmacc_ps(a,b,c)
#error "This should prop. be _mm256_msub_ps instead of _mm256_msub"
#define _mm256_FMSUB_ps(a,b,c) _mm256_msub_ps(a,b,c)
#endif
#define __forceinline __attribute__((always_inline)) static
#ifdef __AVX2__
#ifdef HAVE_AVX512
#define __ELPA_USE_FMA__
#define _mm256_FMA_ps(a,b,c) _mm256_fmadd_ps(a,b,c)
#define _mm256_NFMA_ps(a,b,c) _mm256_fnmadd_ps(a,b,c)
#define _mm256_FMSUB_ps(a,b,c) _mm256_fmsub_ps(a,b,c)
#endif
#define _mm512_FMA_ps(a,b,c) _mm512_fmadd_ps(a,b,c)
#define _mm512_NFMA_ps(a,b,c) _mm512_fnmadd_ps(a,b,c)
#define _mm512_FMSUB_ps(a,b,c) _mm512_fmsub_ps(a,b,c)
#endif
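// The _mm512_FMA_ps / _mm512_NFMA_ps / _mm512_FMSUB_ps wrappers follow the same
// naming scheme as the _mm256_* macros of the AVX/AVX2 kernels above, so the
// AVX-512 body below can keep the familiar FMA spelling while mapping directly
// onto the native AVX-512 fused multiply-add intrinsics.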
//Forward declaration
__forceinline void hh_trafo_kernel_4_AVX512_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4);
__forceinline void hh_trafo_kernel_8_AVX512_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4);
//__forceinline void hh_trafo_kernel_8_AVX512_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4);
__forceinline void hh_trafo_kernel_16_AVX512_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4);
__forceinline void hh_trafo_kernel_24_AVX512_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4);
//__forceinline void hh_trafo_kernel_24_AVX512_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4);
__forceinline void hh_trafo_kernel_32_AVX512_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4);
__forceinline void hh_trafo_kernel_48_AVX512_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4);
__forceinline void hh_trafo_kernel_64_AVX512_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4);
void quad_hh_trafo_real_avx512_4hv_single(float* q, float* hh, int* pnb, int* pnq, int* pldq, int* pldh);
void quad_hh_trafo_real_avx512_4hv_single(float* q, float* hh, int* pnb, int* pnq, int* pldq, int* pldh);
/*
!f>#ifdef HAVE_AVX512
!f>#if defined(HAVE_AVX512)
!f> interface
!f> subroutine quad_hh_trafo_real_avx512_4hv_single(q, hh, pnb, pnq, pldq, pldh) &
!f> bind(C, name="quad_hh_trafo_real_avx512_4hv_single")
!f> use, intrinsic :: iso_c_binding
!f> integer(kind=c_int) :: pnb, pnq, pldq, pldh
!f> type(c_ptr), value :: q
!f> real(kind=c_float) :: hh(pnb,6)
!f> real(kind=c_float) :: hh(pnb,6)
!f> end subroutine
!f> end interface
!f>#endif
*/
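For orientation, a minimal C-side usage sketch of the exported routine, following the by-reference convention of the Fortran interface above; the sizes, the aligned allocation, and the driver itself are illustrative assumptions, not taken from this commit:

#include <stdlib.h>

void quad_hh_trafo_real_avx512_4hv_single(float* q, float* hh, int* pnb, int* pnq, int* pldq, int* pldh);

int main(void)
{
	/* assumed, illustrative sizes: nb = Householder length, nq = rows of q, leading dimensions ldq/ldh */
	int nb = 64, nq = 128, ldq = 128, ldh = 64;
	/* q holds nb+3 columns of length ldq; hh holds 6 columns of length ldh (cf. hh(pnb,6) above) */
	float *q  = aligned_alloc(64, sizeof(float) * (size_t)ldq * (nb + 3));
	float *hh = aligned_alloc(64, sizeof(float) * (size_t)ldh * 6);
	/* ... fill q and hh with the actual Householder data before the call ... */
	quad_hh_trafo_real_avx512_4hv_single(q, hh, &nb, &nq, &ldq, &ldh);
	free(q);
	free(hh);
	return 0;
}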
void quad_hh_trafo_real_avx512_4hv_single(float* q, float* hh, int* pnb, int* pnq, int* pldq, int* pldh)
{
int i;
......@@ -136,594 +128,473 @@ void quad_hh_trafo_real_avx512_4hv_single(float* q, float* hh, int* pnb, int* pn
}
// Production level kernel calls with padding
for (i = 0; i < nq-20; i+=24)
for (i = 0; i < nq-48; i+=64)
{
hh_trafo_kernel_24_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
hh_trafo_kernel_64_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
if (nq == i)
{
return;
}
if (nq-i == 20)
{
hh_trafo_kernel_16_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
hh_trafo_kernel_4_AVX512_4hv_single(&q[i+16], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
else if (nq-i == 16)
{
hh_trafo_kernel_16_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
else if (nq-i == 12)
if (nq-i == 48)
{
hh_trafo_kernel_8_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
hh_trafo_kernel_4_AVX512_4hv_single(&q[i+8], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
hh_trafo_kernel_48_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
else if (nq-i == 8)
else if (nq-i == 32)
{
hh_trafo_kernel_8_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
hh_trafo_kernel_32_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
else
else
{
hh_trafo_kernel_4_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
hh_trafo_kernel_16_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
}
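Because the diff interleaves the removed 24-row dispatch with the new 64-row one, the control flow intended after this commit is easier to read in one piece; a sketch of the new driver body (assuming, as the padding comment above suggests, that nq is left as a multiple of 16):

for (i = 0; i < nq-48; i+=64)
{
	hh_trafo_kernel_64_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
if (nq == i)
{
	return;
}
if (nq-i == 48)
{
	hh_trafo_kernel_48_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
else if (nq-i == 32)
{
	hh_trafo_kernel_32_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
else
{
	hh_trafo_kernel_16_AVX512_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}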
/**
* Unrolled kernel that computes
* 24 rows of Q simultaneously, a
* 64 rows of Q simultaneously, a
* matrix-vector product with four Householder
* vectors plus a rank-1 update is performed
*/
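In matrix terms, each 64-row block Q_b of q is updated with the product of the four elementary reflectors stored in hh; a rough sketch of the intended math (the kernel evaluates it in fused form via the precomputed scalars s_i_j rather than forming the reflectors explicitly):

Q_b \leftarrow Q_b \prod_{k=1}^{4} \left( I - \tau_k\, v_k v_k^{T} \right),
\qquad \tau_k = \mathrm{hh}[(k-1)\,\mathrm{ldh}],
\qquad s_{ij} \approx v_i^{T} v_j .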
__forceinline void hh_trafo_kernel_24_AVX512_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4)
__forceinline void hh_trafo_kernel_64_AVX512_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4)
{
/////////////////////////////////////////////////////
// Matrix Vector Multiplication, Q [12 x nb+3] * hh
// Matrix Vector Multiplication, Q [64 x nb+3] * hh
// hh contains four householder vectors
/////////////////////////////////////////////////////
int i;
__m256 a1_1 = _mm256_load_ps(&q[ldq*3]); // q(1,4) | .. | q(8,4)
__m256 a2_1 = _mm256_load_ps(&q[ldq*2]); // q(1,3) | q(2,3) | q(3,3) | q(4,3) | q(5,3) | q(6,3) | q(7,3) | q(8,3)
__m256 a3_1 = _mm256_load_ps(&q[ldq]); // q(1,2) | q(2,2) | q(3,2) | q(4,2) | q(5,2) | q(6,2) | q(7,2) | q(8,2)
__m256 a4_1 = _mm256_load_ps(&q[0]); // q(1,1) | q(2,1) | q(3,1) | q(4,1) | q(5,1) | q(6,1) | q(7,1) | q(8,1)
__m256 h_2_1 = _mm256_broadcast_ss(&hh[ldh+1]);
__m256 h_3_2 = _mm256_broadcast_ss(&hh[(ldh*2)+1]);
__m256 h_3_1 = _mm256_broadcast_ss(&hh[(ldh*2)+2]);
__m256 h_4_3 = _mm256_broadcast_ss(&hh[(ldh*3)+1]);
__m256 h_4_2 = _mm256_broadcast_ss(&hh[(ldh*3)+2]);
__m256 h_4_1 = _mm256_broadcast_ss(&hh[(ldh*3)+3]);
#ifdef __ELPA_USE_FMA__
register __m256 w1 = _mm256_FMA_ps(a3_1, h_4_3, a4_1);
w1 = _mm256_FMA_ps(a2_1, h_4_2, w1);
w1 = _mm256_FMA_ps(a1_1, h_4_1, w1);
register __m256 z1 = _mm256_FMA_ps(a2_1, h_3_2, a3_1);
z1 = _mm256_FMA_ps(a1_1, h_3_1, z1);
register __m256 y1 = _mm256_FMA_ps(a1_1, h_2_1, a2_1);
register __m256 x1 = a1_1;
#else
register __m256 w1 = _mm256_add_ps(a4_1, _mm256_mul_ps(a3_1, h_4_3));
w1 = _mm256_add_ps(w1, _mm256_mul_ps(a2_1, h_4_2));
w1 = _mm256_add_ps(w1, _mm256_mul_ps(a1_1, h_4_1));
register __m256 z1 = _mm256_add_ps(a3_1, _mm256_mul_ps(a2_1, h_3_2));
z1 = _mm256_add_ps(z1, _mm256_mul_ps(a1_1, h_3_1));
register __m256 y1 = _mm256_add_ps(a2_1, _mm256_mul_ps(a1_1, h_2_1));
register __m256 x1 = a1_1;
#endif
__m256 a1_2 = _mm256_load_ps(&q[(ldq*3)+8]); // q(9,4) | ... | q(16,4)
__m256 a2_2 = _mm256_load_ps(&q[(ldq*2)+8]);
__m256 a3_2 = _mm256_load_ps(&q[ldq+8]); // q(9,2) | ... | q(16,2)
__m256 a4_2 = _mm256_load_ps(&q[0+8]); // q(9,1) | q(10,1) .... | q(16,1)
#ifdef __ELPA_USE_FMA__
register __m256 w2 = _mm256_FMA_ps(a3_2, h_4_3, a4_2);
w2 = _mm256_FMA_ps(a2_2, h_4_2, w2);
w2 = _mm256_FMA_ps(a1_2, h_4_1, w2);
register __m256 z2 = _mm256_FMA_ps(a2_2, h_3_2, a3_2);
z2 = _mm256_FMA_ps(a1_2, h_3_1, z2);
register __m256 y2 = _mm256_FMA_ps(a1_2, h_2_1, a2_2);
register __m256 x2 = a1_2;
#else
register __m256 w2 = _mm256_add_ps(a4_2, _mm256_mul_ps(a3_2, h_4_3));
w2 = _mm256_add_ps(w2, _mm256_mul_ps(a2_2, h_4_2));
w2 = _mm256_add_ps(w2, _mm256_mul_ps(a1_2, h_4_1));
register __m256 z2 = _mm256_add_ps(a3_2, _mm256_mul_ps(a2_2, h_3_2));
z2 = _mm256_add_ps(z2, _mm256_mul_ps(a1_2, h_3_1));
register __m256 y2 = _mm256_add_ps(a2_2, _mm256_mul_ps(a1_2, h_2_1));
register __m256 x2 = a1_2;
#endif
__m256 a1_3 = _mm256_load_ps(&q[(ldq*3)+16]);
__m256 a2_3 = _mm256_load_ps(&q[(ldq*2)+16]);
__m256 a3_3 = _mm256_load_ps(&q[ldq+16]);
__m256 a4_3 = _mm256_load_ps(&q[0+16]);
#ifdef __ELPA_USE_FMA__
register __m256 w3 = _mm256_FMA_ps(a3_3, h_4_3, a4_3);
w3 = _mm256_FMA_ps(a2_3, h_4_2, w3);
w3 = _mm256_FMA_ps(a1_3, h_4_1, w3);
register __m256 z3 = _mm256_FMA_ps(a2_3, h_3_2, a3_3);
z3 = _mm256_FMA_ps(a1_3, h_3_1, z3);
register __m256 y3 = _mm256_FMA_ps(a1_3, h_2_1, a2_3);
register __m256 x3 = a1_3;
#else
register __m256 w3 = _mm256_add_ps(a4_3, _mm256_mul_ps(a3_3, h_4_3));
w3 = _mm256_add_ps(w3, _mm256_mul_ps(a2_3, h_4_2));
w3 = _mm256_add_ps(w3, _mm256_mul_ps(a1_3, h_4_1));
register __m256 z3 = _mm256_add_ps(a3_3, _mm256_mul_ps(a2_3, h_3_2));
z3 = _mm256_add_ps(z3, _mm256_mul_ps(a1_3, h_3_1));
register __m256 y3 = _mm256_add_ps(a2_3, _mm256_mul_ps(a1_3, h_2_1));
register __m256 x3 = a1_3;
#endif
__m256 q1;
__m256 q2;
__m256 q3;
__m256 h1;
__m256 h2;
__m256 h3;
__m256 h4;
__m512 a1_1 = _mm512_load_ps(&q[ldq*3]);
__m512 a2_1 = _mm512_load_ps(&q[ldq*2]);
__m512 a3_1 = _mm512_load_ps(&q[ldq]);
__m512 a4_1 = _mm512_load_ps(&q[0]);
__m512 a1_2 = _mm512_load_ps(&q[(ldq*3)+16]);
__m512 a2_2 = _mm512_load_ps(&q[(ldq*2)+16]);
__m512 a3_2 = _mm512_load_ps(&q[ldq+16]);
__m512 a4_2 = _mm512_load_ps(&q[0+16]);
__m512 a1_3 = _mm512_load_ps(&q[(ldq*3)+32]);
__m512 a2_3 = _mm512_load_ps(&q[(ldq*2)+32]);
__m512 a3_3 = _mm512_load_ps(&q[ldq+32]);
__m512 a4_3 = _mm512_load_ps(&q[0+32]);
__m512 a1_4 = _mm512_load_ps(&q[(ldq*3)+48]);
__m512 a2_4 = _mm512_load_ps(&q[(ldq*2)+48]);
__m512 a3_4 = _mm512_load_ps(&q[ldq+48]);
__m512 a4_4 = _mm512_load_ps(&q[0+48]);
__m512 h_2_1 = _mm512_set1_ps(hh[ldh+1]);
__m512 h_3_2 = _mm512_set1_ps(hh[(ldh*2)+1]);
__m512 h_3_1 = _mm512_set1_ps(hh[(ldh*2)+2]);
__m512 h_4_3 = _mm512_set1_ps(hh[(ldh*3)+1]);
__m512 h_4_2 = _mm512_set1_ps(hh[(ldh*3)+2]);
__m512 h_4_1 = _mm512_set1_ps(hh[(ldh*3)+3]);
__m512 w1 = _mm512_FMA_ps(a3_1, h_4_3, a4_1);
w1 = _mm512_FMA_ps(a2_1, h_4_2, w1);
w1 = _mm512_FMA_ps(a1_1, h_4_1, w1);
__m512 z1 = _mm512_FMA_ps(a2_1, h_3_2, a3_1);
z1 = _mm512_FMA_ps(a1_1, h_3_1, z1);
__m512 y1 = _mm512_FMA_ps(a1_1, h_2_1, a2_1);
__m512 x1 = a1_1;
__m512 w2 = _mm512_FMA_ps(a3_2, h_4_3, a4_2);
w2 = _mm512_FMA_ps(a2_2, h_4_2, w2);
w2 = _mm512_FMA_ps(a1_2, h_4_1, w2);
__m512 z2 = _mm512_FMA_ps(a2_2, h_3_2, a3_2);
z2 = _mm512_FMA_ps(a1_2, h_3_1, z2);
__m512 y2 = _mm512_FMA_ps(a1_2, h_2_1, a2_2);
__m512 x2 = a1_2;
__m512 w3 = _mm512_FMA_ps(a3_3, h_4_3, a4_3);
w3 = _mm512_FMA_ps(a2_3, h_4_2, w3);
w3 = _mm512_FMA_ps(a1_3, h_4_1, w3);
__m512 z3 = _mm512_FMA_ps(a2_3, h_3_2, a3_3);
z3 = _mm512_FMA_ps(a1_3, h_3_1, z3);
__m512 y3 = _mm512_FMA_ps(a1_3, h_2_1, a2_3);
__m512 x3 = a1_3;
__m512 w4 = _mm512_FMA_ps(a3_4, h_4_3, a4_4);
w4 = _mm512_FMA_ps(a2_4, h_4_2, w4);
w4 = _mm512_FMA_ps(a1_4, h_4_1, w4);
__m512 z4 = _mm512_FMA_ps(a2_4, h_3_2, a3_4);
z4 = _mm512_FMA_ps(a1_4, h_3_1, z4);
__m512 y4 = _mm512_FMA_ps(a1_4, h_2_1, a2_4);
__m512 x4 = a1_4;
__m512 q1;
__m512 q2;
__m512 q3;
__m512 q4;
__m512 h1;
__m512 h2;
__m512 h3;
__m512 h4;
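// Per 16-float lane group (suffix 1..4), x*, y*, z* and w* accumulate the
// products of this 64-row block of q with the four Householder vectors
// (roughly x = Q_b*v1, y = Q_b*v2, z = Q_b*v3, w = Q_b*v4); the loop below
// adds column i of q, weighted by the matching entries of hh, to each of them.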
for(i = 4; i < nb; i++)
{
h1 = _mm256_broadcast_ss(&hh[i-3]);
q1 = _mm256_load_ps(&q[i*ldq]); // | q(i,2) | q(i+1,2) | q(i+2,2) | q(i+3,2) | q(i+4,2) | q(i+5,2) | q(i+5,2) | q(i+7,2)
q2 = _mm256_load_ps(&q[(i*ldq)+8]);
q3 = _mm256_load_ps(&q[(i*ldq)+16]);
#ifdef __ELPA_USE_FMA__
x1 = _mm256_FMA_ps(q1, h1, x1);
x2 = _mm256_FMA_ps(q2, h1, x2);
x3 = _mm256_FMA_ps(q3, h1, x3);
#else
x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1));
x2 = _mm256_add_ps(x2, _mm256_mul_ps(q2,h1));
x3 = _mm256_add_ps(x3, _mm256_mul_ps(q3,h1));
#endif
h1 = _mm512_set1_ps(hh[i-3]);
h2 = _mm512_set1_ps(hh[ldh+i-2]);
h3 = _mm512_set1_ps(hh[(ldh*2)+i-1]);
h4 = _mm512_set1_ps(hh[(ldh*3)+i]);
q1 = _mm512_load_ps(&q[i*ldq]);
q2 = _mm512_load_ps(&q[(i*ldq)+16]);
q3 = _mm512_load_ps(&q[(i*ldq)+32]);
q4 = _mm512_load_ps(&q[(i*ldq)+48]);
x1 = _mm512_FMA_ps(q1, h1, x1);
y1 = _mm512_FMA_ps(q1, h2, y1);
z1 = _mm512_FMA_ps(q1, h3, z1);
w1 = _mm512_FMA_ps(q1, h4, w1);
x2 = _mm512_FMA_ps(q2, h1, x2);
y2 = _mm512_FMA_ps(q2, h2, y2);
z2 = _mm512_FMA_ps(q2, h3, z2);
w2 = _mm512_FMA_ps(q2, h4, w2);
x3 = _mm512_FMA_ps(q3, h1, x3);
y3 = _mm512_FMA_ps(q3, h2, y3);
z3 = _mm512_FMA_ps(q3, h3, z3);
w3 = _mm512_FMA_ps(q3, h4, w3);
x4 = _mm512_FMA_ps(q4, h1, x4);
y4 = _mm512_FMA_ps(q4, h2, y4);
z4 = _mm512_FMA_ps(q4, h3, z4);
w4 = _mm512_FMA_ps(q4, h4, w4);
h2 = _mm256_broadcast_ss(&hh[ldh+i-2]);
#ifdef __ELPA_USE_FMA__
y1 = _mm256_FMA_ps(q1, h2, y1);
y2 = _mm256_FMA_ps(q2, h2, y2);
y3 = _mm256_FMA_ps(q3, h2, y3);
#else
y1 = _mm256_add_ps(y1, _mm256_mul_ps(q1,h2));
y2 = _mm256_add_ps(y2, _mm256_mul_ps(q2,h2));
y3 = _mm256_add_ps(y3, _mm256_mul_ps(q3,h2));
#endif
}
h3 = _mm256_broadcast_ss(&hh[(ldh*2)+i-1]);
#ifdef __ELPA_USE_FMA__
z1 = _mm256_FMA_ps(q1, h3, z1);
z2 = _mm256_FMA_ps(q2, h3, z2);
z3 = _mm256_FMA_ps(q3, h3, z3);
#else
z1 = _mm256_add_ps(z1, _mm256_mul_ps(q1,h3));
z2 = _mm256_add_ps(z2, _mm256_mul_ps(q2,h3));
z3 = _mm256_add_ps(z3, _mm256_mul_ps(q3,h3));
#endif
h1 = _mm512_set1_ps(hh[nb-3]);
h2 = _mm512_set1_ps(hh[ldh+nb-2]);
h3 = _mm512_set1_ps(hh[(ldh*2)+nb-1]);
h4 = _mm256_broadcast_ss(&hh[(ldh*3)+i]);
#ifdef __ELPA_USE_FMA__
w1 = _mm256_FMA_ps(q1, h4, w1);
w2 = _mm256_FMA_ps(q2, h4, w2);
w3 = _mm256_FMA_ps(q3, h4, w3);
#else
w1 = _mm256_add_ps(w1, _mm256_mul_ps(q1,h4));
w2 = _mm256_add_ps(w2, _mm256_mul_ps(q2,h4));
w3 = _mm256_add_ps(w3, _mm256_mul_ps(q3,h4));
#endif
}
q1 = _mm512_load_ps(&q[nb*ldq]);
q2 = _mm512_load_ps(&q[(nb*ldq)+16]);
q3 = _mm512_load_ps(&q[(nb*ldq)+32]);
q4 = _mm512_load_ps(&q[(nb*ldq)+48]);
h1 = _mm256_broadcast_ss(&hh[nb-3]);
q1 = _mm256_load_ps(&q[nb*ldq]);
// // careful: we just need another 4 floats, the rest is zeroed
// q2 = _mm256_castps128_ps256(_mm_load_ps(&q[(nb*ldq)+8]));
q2 = _mm256_load_ps(&q[(nb*ldq)+8]);
q3 = _mm256_load_ps(&q[(nb*ldq)+16]);
#ifdef __ELPA_USE_FMA__
x1 = _mm256_FMA_ps(q1, h1, x1);
x2 = _mm256_FMA_ps(q2, h1, x2);
x3 = _mm256_FMA_ps(q3, h1, x3);
#else
x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1));
x2 = _mm256_add_ps(x2, _mm256_mul_ps(q2,h1));
x3 = _mm256_add_ps(x3, _mm256_mul_ps(q3,h1));
#endif
x1 = _mm512_FMA_ps(q1, h1, x1);
y1 = _mm512_FMA_ps(q1, h2, y1);
z1 = _mm512_FMA_ps(q1, h3, z1);
h2 = _mm256_broadcast_ss(&hh[ldh+nb-2]);
#ifdef __FMA4_
y1 = _mm256_FMA_ps(q1, h2, y1);
y2 = _mm256_FMA_ps(q2, h2, y2);
y3 = _mm256_FMA_ps(q3, h2, y3);
#else
y1 = _mm256_add_ps(y1, _mm256_mul_ps(q1,h2));
y2 = _mm256_add_ps(y2, _mm256_mul_ps(q2,h2));
y3 = _mm256_add_ps(y3, _mm256_mul_ps(q3,h2));
#endif
x2 = _mm512_FMA_ps(q2, h1, x2);
y2 = _mm512_FMA_ps(q2, h2, y2);
z2 = _mm512_FMA_ps(q2, h3, z2);
h3 = _mm256_broadcast_ss(&hh[(ldh*2)+nb-1]);
#ifdef __ELPA_USE_FMA__
z1 = _mm256_FMA_ps(q1, h3, z1);
z2 = _mm256_FMA_ps(q2, h3, z2);
z3 = _mm256_FMA_ps(q3, h3, z3);
#else
z1 = _mm256_add_ps(z1, _mm256_mul_ps(q1,h3));
z2 = _mm256_add_ps(z2, _mm256_mul_ps(q2,h3));
z3 = _mm256_add_ps(z3, _mm256_mul_ps(q3,h3));
#endif
x3 = _mm512_FMA_ps(q3, h1, x3);
y3 = _mm512_FMA_ps(q3, h2, y3);
z3 = _mm512_FMA_ps(q3, h3, z3);
h1 = _mm256_broadcast_ss(&hh[nb-2]);
x4 = _mm512_FMA_ps(q4, h1, x4);
y4 = _mm512_FMA_ps(q4, h2, y4);
z4 = _mm512_FMA_ps(q4, h3, z4);
q1 = _mm256_load_ps(&q[(nb+1)*ldq]);
q2 = _mm256_load_ps(&q[((nb+1)*ldq)+8]);
q3 = _mm256_load_ps(&q[((nb+1)*ldq)+16]);
#ifdef __ELPA_USE_FMA__
x1 = _mm256_FMA_ps(q1, h1, x1);
x2 = _mm256_FMA_ps(q2, h1, x2);
x3 = _mm256_FMA_ps(q3, h1, x3);
#else
x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1));
x2 = _mm256_add_ps(x2, _mm256_mul_ps(q2,h1));
x3 = _mm256_add_ps(x3, _mm256_mul_ps(q3,h1));
#endif
h1 = _mm512_set1_ps(hh[nb-2]);
h2 = _mm512_set1_ps(hh[(ldh*1)+nb-1]);
h2 = _mm256_broadcast_ss(&hh[(ldh*1)+nb-1]);
q1 = _mm512_load_ps(&q[(nb+1)*ldq]);
q2 = _mm512_load_ps(&q[((nb+1)*ldq)+16]);
q3 = _mm512_load_ps(&q[((nb+1)*ldq)+32]);
q4 = _mm512_load_ps(&q[((nb+1)*ldq)+48]);
#ifdef __ELPA_USE_FMA__
y1 = _mm256_FMA_ps(q1, h2, y1);
y2 = _mm256_FMA_ps(q2, h2, y2);
// y3 = _mm256_FMA_ps(q3, h2, y3);
#else
y1 = _mm256_add_ps(y1, _mm256_mul_ps(q1,h2));
y2 = _mm256_add_ps(y2, _mm256_mul_ps(q2,h2));
y3 = _mm256_add_ps(y3, _mm256_mul_ps(q3,h2));
#endif
x1 = _mm512_FMA_ps(q1, h1, x1);
y1 = _mm512_FMA_ps(q1, h2, y1);
x2 = _mm512_FMA_ps(q2, h1, x2);
y2 = _mm512_FMA_ps(q2, h2, y2);
x3 = _mm512_FMA_ps(q3, h1, x3);
y3 = _mm512_FMA_ps(q3, h2, y3);
x4 = _mm512_FMA_ps(q4, h1, x4);
y4 = _mm512_FMA_ps(q4, h2, y4);
h1 = _mm256_broadcast_ss(&hh[nb-1]);
q1 = _mm256_load_ps(&q[(nb+2)*ldq]);
// q2 = _mm256_castps128_ps256(_mm_load_ps(&q[((nb+2)*ldq)+8]));
q2 = _mm256_load_ps(&q[((nb+2)*ldq)+8]);
q3 = _mm256_load_ps(&q[((nb+2)*ldq)+16]);
#ifdef __ELPA_USE_FMA__
x1 = _mm256_FMA_ps(q1, h1, x1);
x2 = _mm256_FMA_ps(q2, h1, x2);
x3 = _mm256_FMA_ps(q3, h1, x3);
#else
x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1));
x2 = _mm256_add_ps(x2, _mm256_mul_ps(q2,h1));
x3 = _mm256_add_ps(x3, _mm256_mul_ps(q3,h1));
#endif
h1 = _mm512_set1_ps(hh[nb-1]);
q1 = _mm512_load_ps(&q[(nb+2)*ldq]);
q2 = _mm512_load_ps(&q[((nb+2)*ldq)+16]);
q3 = _mm512_load_ps(&q[((nb+2)*ldq)+32]);
q4 = _mm512_load_ps(&q[((nb+2)*ldq)+48]);
x1 = _mm512_FMA_ps(q1, h1, x1);
x2 = _mm512_FMA_ps(q2, h1, x2);
x3 = _mm512_FMA_ps(q3, h1, x3);
x4 = _mm512_FMA_ps(q4, h1, x4);
/////////////////////////////////////////////////////
// Rank-1 update of Q [12 x nb+3]
// Rank-1 update of Q [64 x nb+3]
/////////////////////////////////////////////////////
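// tau1..tau4 are the scalar factors of the four reflectors (hh[0], hh[ldh],
// hh[2*ldh], hh[3*ldh]); the vs_i_j broadcasts hold the precomputed inner
// products s_i_j of the Householder vectors, which couple the four rank-1
// updates so they can be applied to q in a single fused pass.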
__m256 tau1 = _mm256_broadcast_ss(&hh[0]);
__m512 tau1 = _mm512_set1_ps(hh[0]);
__m512 tau2 = _mm512_set1_ps(hh[ldh]);
__m512 tau3 = _mm512_set1_ps(hh[ldh*2]);
__m512 tau4 = _mm512_set1_ps(hh[ldh*3]);
h1 = tau1;
x1 = _mm256_mul_ps(x1, h1);
x2 = _mm256_mul_ps(x2, h1);
x3 = _mm256_mul_ps(x3, h1);
__m512 vs_1_2 = _mm512_set1_ps(s_1_2);
__m512 vs_1_3 = _mm512_set1_ps(s_1_3);
__m512 vs_2_3 = _mm512_set1_ps(s_2_3);
__m512 vs_1_4 = _mm512_set1_ps(s_1_4);
__m512 vs_2_4 = _mm512_set1_ps(s_2_4);
__m512 vs_3_4 = _mm512_set1_ps(s_3_4);
__m256 tau2 = _mm256_broadcast_ss(&hh[ldh]);
__m256 vs_1_2 = _mm256_broadcast_ss(&s_1_2);
h1 = tau1;
x1 = _mm512_mul_ps(x1, h1);
x2 = _mm512_mul_ps(x2, h1);
x3 = _mm512_mul_ps(x3, h1);