Commit 8fd36a16 authored by Andreas Marek's avatar Andreas Marek
Browse files

Error in computing the stripe_width for AVX-512 kernels

parent 6788de0b
...@@ -3544,7 +3544,7 @@ ...@@ -3544,7 +3544,7 @@
#ifdef DOUBLE_PRECISION_REAL #ifdef DOUBLE_PRECISION_REAL
stripe_width = 48 ! Must be a multiple of 4 stripe_width = 48 ! Must be a multiple of 4
#else #else
stripe_width = 96 ! Must be a multiple of 8 stripe_width = 48 ! Must be a multiple of 8
#endif #endif
#ifdef WITH_OPENMP #ifdef WITH_OPENMP
stripe_count = (thread_width-1)/stripe_width + 1 stripe_count = (thread_width-1)/stripe_width + 1
...@@ -3562,7 +3562,7 @@ ...@@ -3562,7 +3562,7 @@
THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX512_BLOCK4 .or. & THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX512_BLOCK4 .or. &
THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX512_BLOCK6) then THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX512_BLOCK6) then
stripe_width = ((stripe_width+3)/8)*8 ! Must be a multiple of 8 because of AVX-512 memory alignment of 64 bytes stripe_width = ((stripe_width+7)/8)*8 ! Must be a multiple of 8 because of AVX-512 memory alignment of 64 bytes
! (8 * sizeof(double) == 64) ! (8 * sizeof(double) == 64)
else else
......
...@@ -216,7 +216,7 @@ void double_hh_trafo_real_avx512_2hv_double(double* q, double* hh, int* pnb, int ...@@ -216,7 +216,7 @@ void double_hh_trafo_real_avx512_2hv_double(double* q, double* hh, int* pnb, int
x3 = _mm512_mul_pd(x3, h1); x3 = _mm512_mul_pd(x3, h1);
x4 = _mm512_mul_pd(x4, h1); x4 = _mm512_mul_pd(x4, h1);
h1 = (__m512d) _mm512_xor_si512((__m512i) tau2, (__m512i) sign); h1 = (__m512d) _mm512_xor_epi64((__m512i) tau2, (__m512i) sign);
h2 = _mm512_mul_pd(h1, vs); h2 = _mm512_mul_pd(h1, vs);
y1 = _mm512_FMA_pd(y1, h1, _mm512_mul_pd(x1,h2)); y1 = _mm512_FMA_pd(y1, h1, _mm512_mul_pd(x1,h2));
y2 = _mm512_FMA_pd(y2, h1, _mm512_mul_pd(x2,h2)); y2 = _mm512_FMA_pd(y2, h1, _mm512_mul_pd(x2,h2));
...@@ -364,7 +364,7 @@ void double_hh_trafo_real_avx512_2hv_double(double* q, double* hh, int* pnb, int ...@@ -364,7 +364,7 @@ void double_hh_trafo_real_avx512_2hv_double(double* q, double* hh, int* pnb, int
x2 = _mm512_mul_pd(x2, h1); x2 = _mm512_mul_pd(x2, h1);
x3 = _mm512_mul_pd(x3, h1); x3 = _mm512_mul_pd(x3, h1);
h1 = (__m512d) _mm512_xor_si512((__m512i) tau2, (__m512i) sign); h1 = (__m512d) _mm512_xor_epi64((__m512i) tau2, (__m512i) sign);
h2 = _mm512_mul_pd(h1, vs); h2 = _mm512_mul_pd(h1, vs);
y1 = _mm512_FMA_pd(y1, h1, _mm512_mul_pd(x1,h2)); y1 = _mm512_FMA_pd(y1, h1, _mm512_mul_pd(x1,h2));
y2 = _mm512_FMA_pd(y2, h1, _mm512_mul_pd(x2,h2)); y2 = _mm512_FMA_pd(y2, h1, _mm512_mul_pd(x2,h2));
...@@ -481,11 +481,11 @@ void double_hh_trafo_real_avx512_2hv_double(double* q, double* hh, int* pnb, int ...@@ -481,11 +481,11 @@ void double_hh_trafo_real_avx512_2hv_double(double* q, double* hh, int* pnb, int
__m512d tau2 = _mm512_set1_pd(hh[ldh]); __m512d tau2 = _mm512_set1_pd(hh[ldh]);
__m512d vs = _mm512_set1_pd(s); __m512d vs = _mm512_set1_pd(s);
h1 = (__m512d) _mm512_xor_si512((__m512i) tau1, (__m512i) sign); h1 = (__m512d) _mm512_xor_epi64((__m512i) tau1, (__m512i) sign);
x1 = _mm512_mul_pd(x1, h1); x1 = _mm512_mul_pd(x1, h1);
x2 = _mm512_mul_pd(x2, h1); x2 = _mm512_mul_pd(x2, h1);
h1 = (__m512d) _mm512_xor_si512((__m512i) tau2, (__m512i) sign); h1 = (__m512d) _mm512_xor_epi64((__m512i) tau2, (__m512i) sign);
h2 = _mm512_mul_pd(h1, vs); h2 = _mm512_mul_pd(h1, vs);
y1 = _mm512_FMA_pd(y1, h1, _mm512_mul_pd(x1,h2)); y1 = _mm512_FMA_pd(y1, h1, _mm512_mul_pd(x1,h2));
...@@ -580,10 +580,10 @@ void double_hh_trafo_real_avx512_2hv_double(double* q, double* hh, int* pnb, int ...@@ -580,10 +580,10 @@ void double_hh_trafo_real_avx512_2hv_double(double* q, double* hh, int* pnb, int
__m512d tau2 = _mm512_set1_pd(hh[ldh]); __m512d tau2 = _mm512_set1_pd(hh[ldh]);
__m512d vs = _mm512_set1_pd(s); __m512d vs = _mm512_set1_pd(s);
h1 = (__m512d) _mm512_xor_si512((__m512i) tau1, (__m512i) sign); h1 = (__m512d) _mm512_xor_epi64((__m512i) tau1, (__m512i) sign);
x1 = _mm512_mul_pd(x1, h1); x1 = _mm512_mul_pd(x1, h1);
h1 = (__m512d) _mm512_xor_si512((__m512i) tau2, (__m512i) sign); h1 = (__m512d) _mm512_xor_epi64((__m512i) tau2, (__m512i) sign);
h2 = _mm512_mul_pd(h1, vs); h2 = _mm512_mul_pd(h1, vs);
y1 = _mm512_FMA_pd(y1, h1, _mm512_mul_pd(x1,h2)); y1 = _mm512_FMA_pd(y1, h1, _mm512_mul_pd(x1,h2));
...@@ -664,9 +664,9 @@ void double_hh_trafo_real_avx512_2hv_double(double* q, double* hh, int* pnb, int ...@@ -664,9 +664,9 @@ void double_hh_trafo_real_avx512_2hv_double(double* q, double* hh, int* pnb, int
__m512d tau2 = _mm512_set1_pd(hh[ldh]); __m512d tau2 = _mm512_set1_pd(hh[ldh]);
__m512d vs = _mm512_set1_pd(s); __m512d vs = _mm512_set1_pd(s);
h1 = (__m512d) _mm512_xor_si512((__m512i) tau1, (__m512i) sign); h1 = (__m512d) _mm512_xor_epi64((__m512i) tau1, (__m512i) sign);
x1 = _mm512_mul_pd(x1, h1); x1 = _mm512_mul_pd(x1, h1);
h1 = (__m512d) _mm512_xor_si512((__m512i) tau2, (__m512i) sign); h1 = (__m512d) _mm512_xor_epi64((__m512i) tau2, (__m512i) sign);
h2 = _mm512_mul_pd(h1, vs); h2 = _mm512_mul_pd(h1, vs);
y1 = _mm512_FMA_pd(y1, h1, _mm512_mul_pd(x1,h2)); y1 = _mm512_FMA_pd(y1, h1, _mm512_mul_pd(x1,h2));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment