Unverified Commit e1d0685f authored by Andreas Marek

Single precision AVX/AVX2 BLOCK4 kernel

parent c72fa66a
@@ -304,7 +304,9 @@ contains
if ( (THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_GENERIC) .or. &
(THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_GENERIC_SIMPLE) .or. &
(THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_SSE) .or. &
! (THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_SSE_BLOCK2) .or. &
(THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX_BLOCK2) .or. &
(THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX_BLOCK4) .or. &
(THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_GPU) ) then
else
print *,"At the moment single precision only works with the generic kernels"
@@ -656,7 +658,9 @@ contains
if ( (THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_GENERIC) .or. &
(THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_GENERIC_SIMPLE) .or. &
(THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_SSE) .or. &
! (THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_SSE_BLOCK2) .or. &
(THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX_BLOCK2) .or. &
(THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX_BLOCK4) .or. &
(THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_GPU) ) then
else
print *,"At the moment single precision only works with the generic kernels"
@@ -69,31 +69,33 @@
#ifdef __FMA4__
#define __ELPA_USE_FMA__
#define _mm256_FMA_pd(a,b,c) _mm256_macc_pd(a,b,c)
#define _mm256_NFMA_pd(a,b,c) _mm256_nmacc_pd(a,b,c)
#define _mm256_FMSUB_pd(a,b,c) _mm256_msub(a,b,c)
#define _mm256_FMA_ps(a,b,c) _mm256_macc_ps(a,b,c)
#define _mm256_NFMA_ps(a,b,c) _mm256_nmacc_ps(a,b,c)
#error "This should prop. be _mm256_msub_ps instead of _mm256_msub"
#define _mm256_FMSUB_ps(a,b,c) _mm256_msub_ps(a,b,c)
#endif
#ifdef __AVX2__
#define __ELPA_USE_FMA__
#define _mm256_FMA_pd(a,b,c) _mm256_fmadd_pd(a,b,c)
#define _mm256_NFMA_pd(a,b,c) _mm256_fnmadd_pd(a,b,c)
#define _mm256_FMSUB_pd(a,b,c) _mm256_fmsub_pd(a,b,c)
#define _mm256_FMA_ps(a,b,c) _mm256_fmadd_ps(a,b,c)
#define _mm256_NFMA_ps(a,b,c) _mm256_fnmadd_ps(a,b,c)
#define _mm256_FMSUB_ps(a,b,c) _mm256_fmsub_ps(a,b,c)
#endif
#endif
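/* Illustrative sketch (not part of the patch): with the mappings above,
 * _mm256_FMA_ps(a,b,c) computes a*b+c per lane, _mm256_NFMA_ps(a,b,c)
 * computes c-(a*b), and _mm256_FMSUB_ps(a,b,c) computes a*b-c, regardless
 * of whether the FMA4 (_mm256_macc_ps family) or AVX2/FMA3
 * (_mm256_fmadd_ps family) intrinsics are selected.  A minimal check of
 * the single-precision macros could look like this (the function name is
 * hypothetical): */
#if 0 /* illustrative only */
#include <immintrin.h>
#include <stdio.h>
static void fma_macro_demo(void)
{
	__m256 a = _mm256_set1_ps(2.0f);
	__m256 b = _mm256_set1_ps(3.0f);
	__m256 c = _mm256_set1_ps(1.0f);
	float fma[8], nfma[8];
	_mm256_storeu_ps(fma,  _mm256_FMA_ps(a, b, c));   /* each lane: 2*3+1 =  7 */
	_mm256_storeu_ps(nfma, _mm256_NFMA_ps(a, b, c));  /* each lane: 1-2*3 = -5 */
	printf("FMA: %f NFMA: %f\n", fma[0], nfma[0]);
}
#endif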
//Forward declaration
__forceinline void hh_trafo_kernel_4_AVX_4hv_single(double* q, double* hh, int nb, int ldq, int ldh, double s_1_2, double s_1_3, double s_2_3, double s_1_4, double s_2_4, double s_3_4);
__forceinline void hh_trafo_kernel_8_AVX_4hv_single(double* q, double* hh, int nb, int ldq, int ldh, double s_1_2, double s_1_3, double s_2_3, double s_1_4, double s_2_4, double s_3_4);
__forceinline void hh_trafo_kernel_12_AVX_4hv_single(double* q, double* hh, int nb, int ldq, int ldh, double s_1_2, double s_1_3, double s_2_3, double s_1_4, double s_2_4, double s_3_4);
__forceinline void hh_trafo_kernel_4_AVX_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4);
__forceinline void hh_trafo_kernel_8_AVX_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4);
__forceinline void hh_trafo_kernel_16_AVX_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4);
__forceinline void hh_trafo_kernel_24_AVX_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4);
void quad_hh_trafo_real_avx_avx2_4hv_single_(double* q, double* hh, int* pnb, int* pnq, int* pldq, int* pldh);
void quad_hh_trafo_real_avx_avx2_4hv_single_(float* q, float* hh, int* pnb, int* pnq, int* pldq, int* pldh);
#if 0
void quad_hh_trafo_fast_single_(double* q, double* hh, int* pnb, int* pnq, int* pldq, int* pldh);
void quad_hh_trafo_fast_single_(float* q, float* hh, int* pnb, int* pnq, int* pldq, int* pldh);
#endif
void quad_hh_trafo_real_avx_avx2_4hv_single_(double* q, double* hh, int* pnb, int* pnq, int* pldq, int* pldh)
void quad_hh_trafo_real_avx_avx2_4hv_single_(float* q, float* hh, int* pnb, int* pnq, int* pldq, int* pldh)
{
int i;
int nb = *pnb;
@@ -103,12 +105,12 @@ void quad_hh_trafo_real_avx_avx2_4hv_single_(double* q, double* hh, int* pnb, in
// calculating scalar products to compute
// 4 householder vectors simultaneously
double s_1_2 = hh[(ldh)+1];
double s_1_3 = hh[(ldh*2)+2];
double s_2_3 = hh[(ldh*2)+1];
double s_1_4 = hh[(ldh*3)+3];
double s_2_4 = hh[(ldh*3)+2];
double s_3_4 = hh[(ldh*3)+1];
float s_1_2 = hh[(ldh)+1];
float s_1_3 = hh[(ldh*2)+2];
float s_2_3 = hh[(ldh*2)+1];
float s_1_4 = hh[(ldh*3)+3];
float s_2_4 = hh[(ldh*3)+2];
float s_3_4 = hh[(ldh*3)+1];
// calculate scalar product of first and fourth householder vector
// loop counter = 2
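// Illustrative sketch (not part of the patch): the code elided here unrolls the
// first iterations (loop counters 2 and 3) and then accumulates the remaining
// pairwise scalar products of the four Householder vectors stored column-wise
// in hh (leading dimension ldh), with the same indexing as the disabled
// double-precision variant further below:
#if 0 /* illustrative only */
for (i = 4; i < nb; i++)
{
	s_1_2 += hh[i-1] * hh[(i+ldh)];
	s_2_3 += hh[(ldh)+i-1] * hh[i+(ldh*2)];
	s_3_4 += hh[(ldh*2)+i-1] * hh[i+(ldh*3)];
	s_1_3 += hh[i-2] * hh[i+(ldh*2)];
	s_2_4 += hh[(ldh*1)+i-2] * hh[i+(ldh*3)];
	s_1_4 += hh[i-3] * hh[i+(ldh*3)];
}
#endif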
@@ -145,116 +147,47 @@ void quad_hh_trafo_real_avx_avx2_4hv_single_(double* q, double* hh, int* pnb, in
// printf("s_3_4: %f\n", s_3_4);
// Production level kernel calls with padding
#ifdef __AVX__
for (i = 0; i < nq-8; i+=12)
for (i = 0; i < nq-20; i+=24)
{
hh_trafo_kernel_12_AVX_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
hh_trafo_kernel_24_AVX_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
if (nq == i)
{
return;
}
else
if (nq-i == 20)
{
if (nq-i > 4)
{
hh_trafo_kernel_8_AVX_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
else
{
hh_trafo_kernel_4_AVX_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
hh_trafo_kernel_16_AVX_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
hh_trafo_kernel_4_AVX_4hv_single(&q[i+16], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
#else
for (i = 0; i < nq-4; i+=6)
else if (nq-i == 16)
{
hh_trafo_kernel_6_SSE_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
hh_trafo_kernel_16_AVX_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
if (nq == i)
else if (nq-i == 12)
{
return;
hh_trafo_kernel_8_AVX_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
hh_trafo_kernel_4_AVX_4hv_single(&q[i+8], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
else
else if (nq-i == 8)
{
if (nq-i > 2)
{
hh_trafo_kernel_4_SSE_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
else
{
hh_trafo_kernel_2_SSE_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
hh_trafo_kernel_8_AVX_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
#endif
}
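// Illustrative note (not part of the patch): in the AVX path above, the
// single-precision kernel processes 24 rows of q per iteration (three __m256
// registers of 8 floats each) and, assuming nq is padded to a multiple of 4,
// dispatches the remainder as
//   20 rows -> 16-row kernel + 4-row kernel
//   16 rows -> 16-row kernel
//   12 rows ->  8-row kernel + 4-row kernel
//    8 rows ->  8-row kernel
//    4 rows ->  4-row kernel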
#if 0
void quad_hh_trafo_fast_single_(double* q, double* hh, int* pnb, int* pnq, int* pldq, int* pldh)
{
int i;
int nb = *pnb;
int nq = *pldq;
int ldq = *pldq;
int ldh = *pldh;
// calculating scalar products to compute
// 4 householder vectors simultaneously
double s_1_2 = hh[(ldh)+1];
double s_1_3 = hh[(ldh*2)+2];
double s_2_3 = hh[(ldh*2)+1];
double s_1_4 = hh[(ldh*3)+3];
double s_2_4 = hh[(ldh*3)+2];
double s_3_4 = hh[(ldh*3)+1];
// calculate scalar product of first and fourth householder vector
// loop counter = 2
s_1_2 += hh[2-1] * hh[(2+ldh)];
s_2_3 += hh[(ldh)+2-1] * hh[2+(ldh*2)];
s_3_4 += hh[(ldh*2)+2-1] * hh[2+(ldh*3)];
// loop counter = 3
s_1_2 += hh[3-1] * hh[(3+ldh)];
s_2_3 += hh[(ldh)+3-1] * hh[3+(ldh*2)];
s_3_4 += hh[(ldh*2)+3-1] * hh[3+(ldh*3)];
s_1_3 += hh[3-2] * hh[3+(ldh*2)];
s_2_4 += hh[(ldh*1)+3-2] * hh[3+(ldh*3)];
#pragma ivdep
for (i = 4; i < nb; i++)
else
{
s_1_2 += hh[i-1] * hh[(i+ldh)];
s_2_3 += hh[(ldh)+i-1] * hh[i+(ldh*2)];
s_3_4 += hh[(ldh*2)+i-1] * hh[i+(ldh*3)];
s_1_3 += hh[i-2] * hh[i+(ldh*2)];
s_2_4 += hh[(ldh*1)+i-2] * hh[i+(ldh*3)];
s_1_4 += hh[i-3] * hh[i+(ldh*3)];
hh_trafo_kernel_4_AVX_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
// Production level kernel calls with padding
#ifdef __AVX__
for (i = 0; i < nq; i+=12)
{
hh_trafo_kernel_12_AVX_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
#else
for (i = 0; i < nq; i+=6)
{
hh_trafo_kernel_6_SSE_4hv_single(&q[i], hh, nb, ldq, ldh, s_1_2, s_1_3, s_2_3, s_1_4, s_2_4, s_3_4);
}
#endif
}
#endif
/**
* Unrolled kernel that computes
* 12 rows of Q simultaneously, a
* 24 rows of Q simultaneously, a
* matrix vector product with two householder
* vectors + a rank 1 update is performed
*/
__forceinline void hh_trafo_kernel_12_AVX_4hv_single(double* q, double* hh, int nb, int ldq, int ldh, double s_1_2, double s_1_3, double s_2_3, double s_1_4, double s_2_4, double s_3_4)
__forceinline void hh_trafo_kernel_24_AVX_4hv_single(float* q, float* hh, int nb, int ldq, int ldh, float s_1_2, float s_1_3, float s_2_3, float s_1_4, float s_2_4, float s_3_4)
{
/////////////////////////////////////////////////////
// Matrix Vector Multiplication, Q [12 x nb+3] * hh
@@ -262,518 +195,892 @@ __forceinline void hh_trafo_kernel_12_AVX_4hv_single(double* q, double* hh, int
/////////////////////////////////////////////////////
int i;
__m256d a1_1 = _mm256_load_pd(&q[ldq*3]);
__m256d a2_1 = _mm256_load_pd(&q[ldq*2]);
__m256d a3_1 = _mm256_load_pd(&q[ldq]);
__m256d a4_1 = _mm256_load_pd(&q[0]);
__m256d h_2_1 = _mm256_broadcast_sd(&hh[ldh+1]);
__m256d h_3_2 = _mm256_broadcast_sd(&hh[(ldh*2)+1]);
__m256d h_3_1 = _mm256_broadcast_sd(&hh[(ldh*2)+2]);
__m256d h_4_3 = _mm256_broadcast_sd(&hh[(ldh*3)+1]);
__m256d h_4_2 = _mm256_broadcast_sd(&hh[(ldh*3)+2]);
__m256d h_4_1 = _mm256_broadcast_sd(&hh[(ldh*3)+3]);
#ifdef __ELPA_USE_FMA__
register __m256d w1 = _mm256_FMA_pd(a3_1, h_4_3, a4_1);
w1 = _mm256_FMA_pd(a2_1, h_4_2, w1);
w1 = _mm256_FMA_pd(a1_1, h_4_1, w1);
register __m256d z1 = _mm256_FMA_pd(a2_1, h_3_2, a3_1);
z1 = _mm256_FMA_pd(a1_1, h_3_1, z1);
register __m256d y1 = _mm256_FMA_pd(a1_1, h_2_1, a2_1);
register __m256d x1 = a1_1;
#else
register __m256d w1 = _mm256_add_pd(a4_1, _mm256_mul_pd(a3_1, h_4_3));
w1 = _mm256_add_pd(w1, _mm256_mul_pd(a2_1, h_4_2));
w1 = _mm256_add_pd(w1, _mm256_mul_pd(a1_1, h_4_1));
register __m256d z1 = _mm256_add_pd(a3_1, _mm256_mul_pd(a2_1, h_3_2));
z1 = _mm256_add_pd(z1, _mm256_mul_pd(a1_1, h_3_1));
register __m256d y1 = _mm256_add_pd(a2_1, _mm256_mul_pd(a1_1, h_2_1));
register __m256d x1 = a1_1;
#endif
__m256d a1_2 = _mm256_load_pd(&q[(ldq*3)+4]);
__m256d a2_2 = _mm256_load_pd(&q[(ldq*2)+4]);
__m256d a3_2 = _mm256_load_pd(&q[ldq+4]);
__m256d a4_2 = _mm256_load_pd(&q[0+4]);
#ifdef __ELPA_USE_FMA__
register __m256d w2 = _mm256_FMA_pd(a3_2, h_4_3, a4_2);
w2 = _mm256_FMA_pd(a2_2, h_4_2, w2);
w2 = _mm256_FMA_pd(a1_2, h_4_1, w2);
register __m256d z2 = _mm256_FMA_pd(a2_2, h_3_2, a3_2);
z2 = _mm256_FMA_pd(a1_2, h_3_1, z2);
register __m256d y2 = _mm256_FMA_pd(a1_2, h_2_1, a2_2);
register __m256d x2 = a1_2;
#else
register __m256d w2 = _mm256_add_pd(a4_2, _mm256_mul_pd(a3_2, h_4_3));
w2 = _mm256_add_pd(w2, _mm256_mul_pd(a2_2, h_4_2));
w2 = _mm256_add_pd(w2, _mm256_mul_pd(a1_2, h_4_1));
register __m256d z2 = _mm256_add_pd(a3_2, _mm256_mul_pd(a2_2, h_3_2));
z2 = _mm256_add_pd(z2, _mm256_mul_pd(a1_2, h_3_1));
register __m256d y2 = _mm256_add_pd(a2_2, _mm256_mul_pd(a1_2, h_2_1));
register __m256d x2 = a1_2;
#endif
__m256d a1_3 = _mm256_load_pd(&q[(ldq*3)+8]);
__m256d a2_3 = _mm256_load_pd(&q[(ldq*2)+8]);
__m256d a3_3 = _mm256_load_pd(&q[ldq+8]);
__m256d a4_3 = _mm256_load_pd(&q[0+8]);
#ifdef __ELPA_USE_FMA__
register __m256d w3 = _mm256_FMA_pd(a3_3, h_4_3, a4_3);
w3 = _mm256_FMA_pd(a2_3, h_4_2, w3);
w3 = _mm256_FMA_pd(a1_3, h_4_1, w3);
register __m256d z3 = _mm256_FMA_pd(a2_3, h_3_2, a3_3);
z3 = _mm256_FMA_pd(a1_3, h_3_1, z3);
register __m256d y3 = _mm256_FMA_pd(a1_3, h_2_1, a2_3);
register __m256d x3 = a1_3;
#else
register __m256d w3 = _mm256_add_pd(a4_3, _mm256_mul_pd(a3_3, h_4_3));
w3 = _mm256_add_pd(w3, _mm256_mul_pd(a2_3, h_4_2));
w3 = _mm256_add_pd(w3, _mm256_mul_pd(a1_3, h_4_1));
register __m256d z3 = _mm256_add_pd(a3_3, _mm256_mul_pd(a2_3, h_3_2));
z3 = _mm256_add_pd(z3, _mm256_mul_pd(a1_3, h_3_1));
register __m256d y3 = _mm256_add_pd(a2_3, _mm256_mul_pd(a1_3, h_2_1));
register __m256d x3 = a1_3;
#endif
__m256d q1;
__m256d q2;
__m256d q3;
__m256d h1;
__m256d h2;
__m256d h3;
__m256d h4;
__m256 a1_1 = _mm256_load_ps(&q[ldq*3]); // q(1,4) | .. | q(8,4)
__m256 a2_1 = _mm256_load_ps(&q[ldq*2]); // q(1,3) | q(2,3) | q(3,3) | q(4,3) | q(5,3) | q(6,3) | q(7,3) | q(8,3)
__m256 a3_1 = _mm256_load_ps(&q[ldq]); // q(1,2) | q(2,2) | q(3,2) | q(4,2) | q(5,2) | q(6,2) | q(7,2) | q(8,2)
__m256 a4_1 = _mm256_load_ps(&q[0]); // q(1,1) | q(2,1) | q(3,1) | q(4,1) | q(5,1) | q(6,1) | q(7,1) | q(8,1)
__m256 h_2_1 = _mm256_broadcast_ss(&hh[ldh+1]);
__m256 h_3_2 = _mm256_broadcast_ss(&hh[(ldh*2)+1]);
__m256 h_3_1 = _mm256_broadcast_ss(&hh[(ldh*2)+2]);
__m256 h_4_3 = _mm256_broadcast_ss(&hh[(ldh*3)+1]);
__m256 h_4_2 = _mm256_broadcast_ss(&hh[(ldh*3)+2]);
__m256 h_4_1 = _mm256_broadcast_ss(&hh[(ldh*3)+3]);
#ifdef __ELPA_USE_FMA__
register __m256 w1 = _mm256_FMA_ps(a3_1, h_4_3, a4_1);
w1 = _mm256_FMA_ps(a2_1, h_4_2, w1);
w1 = _mm256_FMA_ps(a1_1, h_4_1, w1);
register __m256 z1 = _mm256_FMA_ps(a2_1, h_3_2, a3_1);
z1 = _mm256_FMA_ps(a1_1, h_3_1, z1);
register __m256 y1 = _mm256_FMA_ps(a1_1, h_2_1, a2_1);
register __m256 x1 = a1_1;
#else
register __m256 w1 = _mm256_add_ps(a4_1, _mm256_mul_ps(a3_1, h_4_3));
w1 = _mm256_add_ps(w1, _mm256_mul_ps(a2_1, h_4_2));
w1 = _mm256_add_ps(w1, _mm256_mul_ps(a1_1, h_4_1));
register __m256 z1 = _mm256_add_ps(a3_1, _mm256_mul_ps(a2_1, h_3_2));
z1 = _mm256_add_ps(z1, _mm256_mul_ps(a1_1, h_3_1));
register __m256 y1 = _mm256_add_ps(a2_1, _mm256_mul_ps(a1_1, h_2_1));
register __m256 x1 = a1_1;
#endif
__m256 a1_2 = _mm256_load_ps(&q[(ldq*3)+8]); // q(9,4) | ... | q(16,4)
__m256 a2_2 = _mm256_load_ps(&q[(ldq*2)+8]);
__m256 a3_2 = _mm256_load_ps(&q[ldq+8]); // q(9,2) | ... | q(16,2)
__m256 a4_2 = _mm256_load_ps(&q[0+8]); // q(9,1) | q(10,1) .... | q(16,1)
#ifdef __ELPA_USE_FMA__
register __m256 w2 = _mm256_FMA_ps(a3_2, h_4_3, a4_2);
w2 = _mm256_FMA_ps(a2_2, h_4_2, w2);
w2 = _mm256_FMA_ps(a1_2, h_4_1, w2);
register __m256 z2 = _mm256_FMA_ps(a2_2, h_3_2, a3_2);
z2 = _mm256_FMA_ps(a1_2, h_3_1, z2);
register __m256 y2 = _mm256_FMA_ps(a1_2, h_2_1, a2_2);
register __m256 x2 = a1_2;
#else
register __m256 w2 = _mm256_add_ps(a4_2, _mm256_mul_ps(a3_2, h_4_3));
w2 = _mm256_add_ps(w2, _mm256_mul_ps(a2_2, h_4_2));
w2 = _mm256_add_ps(w2, _mm256_mul_ps(a1_2, h_4_1));
register __m256 z2 = _mm256_add_ps(a3_2, _mm256_mul_ps(a2_2, h_3_2));
z2 = _mm256_add_ps(z2, _mm256_mul_ps(a1_2, h_3_1));
register __m256 y2 = _mm256_add_ps(a2_2, _mm256_mul_ps(a1_2, h_2_1));
register __m256 x2 = a1_2;
#endif
__m256 a1_3 = _mm256_load_ps(&q[(ldq*3)+16]);
__m256 a2_3 = _mm256_load_ps(&q[(ldq*2)+16]);
__m256 a3_3 = _mm256_load_ps(&q[ldq+16]);
__m256 a4_3 = _mm256_load_ps(&q[0+16]);
#ifdef __ELPA_USE_FMA__
register __m256 w3 = _mm256_FMA_ps(a3_3, h_4_3, a4_3);
w3 = _mm256_FMA_ps(a2_3, h_4_2, w3);
w3 = _mm256_FMA_ps(a1_3, h_4_1, w3);
register __m256 z3 = _mm256_FMA_ps(a2_3, h_3_2, a3_3);
z3 = _mm256_FMA_ps(a1_3, h_3_1, z3);
register __m256 y3 = _mm256_FMA_ps(a1_3, h_2_1, a2_3);
register __m256 x3 = a1_3;
#else
register __m256 w3 = _mm256_add_ps(a4_3, _mm256_mul_ps(a3_3, h_4_3));
w3 = _mm256_add_ps(w3, _mm256_mul_ps(a2_3, h_4_2));
w3 = _mm256_add_ps(w3, _mm256_mul_ps(a1_3, h_4_1));
register __m256 z3 = _mm256_add_ps(a3_3, _mm256_mul_ps(a2_3, h_3_2));
z3 = _mm256_add_ps(z3, _mm256_mul_ps(a1_3, h_3_1));
register __m256 y3 = _mm256_add_ps(a2_3, _mm256_mul_ps(a1_3, h_2_1));
register __m256 x3 = a1_3;
#endif
__m256 q1;
__m256 q2;
__m256 q3;
__m256 h1;
__m256 h2;
__m256 h3;
__m256 h4;
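// Illustrative note (not part of the patch): x1/y1/z1/w1 cover rows 1-8 of the
// 24-row tile, x2/y2/z2/w2 rows 9-16 and x3/y3/z3/w3 rows 17-24, so the
// single-precision loads use offsets +0, +8 and +16 from &q[i*ldq]
// (8 floats per __m256), where the double-precision kernel used +0, +4 and +8.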
for(i = 4; i < nb; i++)
{
h1 = _mm256_broadcast_sd(&hh[i-3]);
q1 = _mm256_load_pd(&q[i*ldq]);
q2 = _mm256_load_pd(&q[(i*ldq)+4]);
q3 = _mm256_load_pd(&q[(i*ldq)+8]);
h1 = _mm256_broadcast_ss(&hh[i-3]);
q1 = _mm256_load_ps(&q[i*ldq]); // | q(i,2) | q(i+1,2) | q(i+2,2) | q(i+3,2) | q(i+4,2) | q(i+5,2) | q(i+6,2) | q(i+7,2)
q2 = _mm256_load_ps(&q[(i*ldq)+8]);
q3 = _mm256_load_ps(&q[(i*ldq)+16]);
#ifdef __ELPA_USE_FMA__
x1 = _mm256_FMA_pd(q1, h1, x1);
x2 = _mm256_FMA_pd(q2, h1, x2);
x3 = _mm256_FMA_pd(q3, h1, x3);
x1 = _mm256_FMA_ps(q1, h1, x1);
x2 = _mm256_FMA_ps(q2, h1, x2);
x3 = _mm256_FMA_ps(q3, h1, x3);
#else
x1 = _mm256_add_pd(x1, _mm256_mul_pd(q1,h1));
x2 = _mm256_add_pd(x2, _mm256_mul_pd(q2,h1));
x3 = _mm256_add_pd(x3, _mm256_mul_pd(q3,h1));
x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1));
x2 = _mm256_add_ps(x2, _mm256_mul_ps(q2,h1));
x3 = _mm256_add_ps(x3, _mm256_mul_ps(q3,h1));
#endif
h2 = _mm256_broadcast_sd(&hh[ldh+i-2]);
h2 = _mm256_broadcast_ss(&hh[ldh+i-2]);
#ifdef __ELPA_USE_FMA__
y1 = _mm256_FMA_pd(q1, h2, y1);
y2 = _mm256_FMA_pd(q2, h2, y2);
y3 = _mm256_FMA_pd(q3, h2, y3);
y1 = _mm256_FMA_ps(q1, h2, y1);
y2 = _mm256_FMA_ps(q2, h2, y2);
y3 = _mm256_FMA_ps(q3, h2, y3);
#else
y1 = _mm256_add_pd(y1, _mm256_mul_pd(q1,h2));
y2 = _mm256_add_pd(y2, _mm256_mul_pd(q2,h2));
y3 = _mm256_add_pd(y3, _mm256_mul_pd(q3,h2));
y1 = _mm256_add_ps(y1, _mm256_mul_ps(q1,h2));
y2 = _mm256_add_ps(y2, _mm256_mul_ps(q2,h2));
y3 = _mm256_add_ps(y3, _mm256_mul_ps(q3,h2));
#endif
h3 = _mm256_broadcast_sd(&hh[(ldh*2)+i-1]);
h3 = _mm256_broadcast_ss(&hh[(ldh*2)+i-1]);
#ifdef __ELPA_USE_FMA__
z1 = _mm256_FMA_pd(q1, h3, z1);
z2 = _mm256_FMA_pd(q2, h3, z2);
z3 = _mm256_FMA_pd(q3, h3, z3);
z1 = _mm256_FMA_ps(q1, h3, z1);
z2 = _mm256_FMA_ps(q2, h3, z2);
z3 = _mm256_FMA_ps(q3, h3, z3);
#else
z1 = _mm256_add_pd(z1, _mm256_mul_pd(q1,h3));
z2 = _mm256_add_pd(z2, _mm256_mul_pd(q2,h3));
z3 = _mm256_add_pd(z3, _mm256_mul_pd(q3,h3));
z1 = _mm256_add_ps(z1, _mm256_mul_ps(q1,h3));
z2 = _mm256_add_ps(z2, _mm256_mul_ps(q2,h3));
z3 = _mm256_add_ps(z3, _mm256_mul_ps(q3,h3));
#endif
h4 = _mm256_broadcast_sd(&hh[(ldh*3)+i]);
h4 = _mm256_broadcast_ss(&hh[(ldh*3)+i]);
#ifdef __ELPA_USE_FMA__
w1 = _mm256_FMA_pd(q1, h4, w1);
w2 = _mm256_FMA_pd(q2, h4, w2);
w3 = _mm256_FMA_pd(q3, h4, w3);
w1 = _mm256_FMA_ps(q1, h4, w1);
w2 = _mm256_FMA_ps(q2, h4, w2);
w3 = _mm256_FMA_ps(q3, h4, w3);
#else
w1 = _mm256_add_pd(w1, _mm256_mul_pd(q1,h4));
w2 = _mm256_add_pd(w2, _mm256_mul_pd(q2,h4));
w3 = _mm256_add_pd(w3, _mm256_mul_pd(q3,h4));
w1 = _mm256_add_ps(w1, _mm256_mul_ps(q1,h4));
w2 = _mm256_add_ps(w2, _mm256_mul_ps(q2,h4));
w3 = _mm256_add_ps(w3, _mm256_mul_ps(q3,h4));
#endif
}
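// Illustrative sketch (not part of the patch): per element of the 24-row tile,
// the vectorized loop above accumulates four dot products against the
// Householder coefficients.  In scalar form, with j the row index within the
// tile and x/y/z/w standing for the lanes of x1..x3, y1..y3, z1..z3, w1..w3:
#if 0 /* illustrative only */
{
	int j;
	for (i = 4; i < nb; i++)
	{
		for (j = 0; j < 24; j++)
		{
			x[j] += q[(i*ldq)+j] * hh[i-3];
			y[j] += q[(i*ldq)+j] * hh[ldh+i-2];
			z[j] += q[(i*ldq)+j] * hh[(ldh*2)+i-1];
			w[j] += q[(i*ldq)+j] * hh[(ldh*3)+i];
		}
	}
}
#endif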
h1 = _mm256_broadcast_sd(&hh[nb-3]);
h1 = _mm256_broadcast_ss(&hh[nb-3]);
q1 = _mm256_load_pd(&q[nb*ldq]);
q2 = _mm256_load_pd(&q[(nb*ldq)+4]);
q3 = _mm256_load_pd(&q[(nb*ldq)+8]);
q1 = _mm256_load_ps(&q[nb*ldq]);
// // careful: we just need another 4 floats, the rest is zeroed
// q2 = _mm256_castps128_ps256(_mm_load_ps(&q[(nb*ldq)+8]));
q2 = _mm256_load_ps(&q[(nb*ldq)+8]);
q3 = _mm256_load_ps(&q[(nb*ldq)+16]);
#ifdef __ELPA_USE_FMA__
x1 = _mm256_FMA_pd(q1, h1, x1);
x2 = _mm256_FMA_pd(q2, h1, x2);
x3 = _mm256_FMA_pd(q3, h1, x3);
x1 = _mm256_FMA_ps(q1, h1, x1);
x2 = _mm256_FMA_ps(q2, h1, x2);
x3 = _mm256_FMA_ps(q3, h1, x3);
#else
x1 = _mm256_add_pd(x1, _mm256_mul_pd(q1,h1));
x2 = _mm256_add_pd(x2, _mm256_mul_pd(q2,h1));
x3 = _mm256_add_pd(x3, _mm256_mul_pd(q3,h1));
x1 = _mm256_add_ps(x1, _mm256_mul_ps(q1,h1));
x2 = _mm256_add_ps(x2, _mm256_mul_ps(q2,h1));
x3 = _mm256_add_ps(x3, _mm256_mul_ps(q3,h1));
#endif
h2 = _mm256_broadcast_sd(&hh[ldh+nb-2]);
h2 = _mm256_broadcast_ss(&hh[ldh+nb-2]);
#ifdef __ELPA_USE_FMA__
y1 = _mm256_FMA_pd(q1, h2, y1);
y2 = _mm256_FMA_pd(q2, h2, y2);
y3 = _mm256_FMA_pd(q3, h2, y3);
y1 = _mm256_FMA_ps(q1, h2, y1);
y2 = _mm256_FMA_ps(q2, h2, y2);
y3 = _mm256_FMA_ps(q3, h2, y3);
#else
y1 = _mm256_add_pd(y1, _mm256_mul_pd(q1,h2));
y2 = _mm256_add_pd(y2, _mm256_mul_pd(q2,h2));
y3 = _mm256_add_pd(y3, _mm256_mul_pd(q3,h2));
y1 = _mm256_add_ps(y1, _mm256_mul_ps(q1,h2));
y2 = _mm256_add_ps(y2, _mm256_mul_ps(q2,h2));
y3 = _mm256_add_ps(y3, _mm256_mul_ps(q3,h2));
#endif
h3 = _mm256_broadcast_sd(&hh[(ldh*2)+nb-1]);
h3 = _mm256_broadcast_ss(&hh[(ldh*2)+nb-1]);
#ifdef __ELPA_USE_FMA__
z1 = _mm256_FMA_pd(q1, h3, z1);
z2 = _mm256_FMA_pd(q2, h3, z2);
z3 = _mm256_FMA_pd(q3, h3, z3);
z1 = _mm256_FMA_ps(q1, h3, z1);
z2 = _mm256_FMA_ps(q2, h3, z2);
z3 = _mm256_FMA_ps(q3, h3, z3);
#else
z1 = _mm256_add_pd(z1, _mm256_mul_pd(q1,h3));
z2 = _mm256_add_pd(z2, _mm256_mul_pd(q2,h3));
z3 = _mm256_add_pd(z3, _mm256_mul_pd(q3,h3));
z1 = _mm256_add_ps(z1, _mm256_mul_ps(q1,h3));
z2 = _mm256_add_ps(z2, _mm256_mul_ps(q2,h3));
z3 = _mm256_add_ps(z3, _mm256_mul_ps(q3,h3));
#endif
h1 = _mm256_broadcast_sd(&hh[nb-2]);
h1 = _mm256_broadcast_ss(&hh[nb-2]);
q1 = _mm256_load_pd(&q[(nb+1)*ldq]);
q2 = _mm256_load_pd(&q[((nb+1)*ldq)+4]);
q3 = _mm256_load_pd(&q[((nb+1)*ldq)+8]);
q1 = _mm256_load_ps(&q[(nb+1)*ldq]);
q2 = _mm256_load_ps(&q[((nb+1)*ldq)+8]);
q3 = _mm256_load_ps(&q[((nb+1)*ldq)+16]);
#ifdef __ELPA_USE_FMA__
x1 = _mm256_FMA_pd(q1, h1, x1);