Commit 366d0f43 authored by Andreas Marek's avatar Andreas Marek
Browse files

Error in double precision AVX512 block4 kernel

parent 8fd36a16
...@@ -3574,8 +3574,8 @@ ...@@ -3574,8 +3574,8 @@
THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX512_BLOCK4 .or. & THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX512_BLOCK4 .or. &
THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX512_BLOCK6) then THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX512_BLOCK6) then
stripe_width = ((stripe_width+7)/8)*8 ! Must be a multiple of 16 because of AVX-512 memory alignment of 64 bytes stripe_width = ((stripe_width+15)/16)*16 ! Must be a multiple of 16 because of AVX-512 memory alignment of 64 bytes
! (16 * sizeof(float) == 64) ! (16 * sizeof(float) == 64)
else else
......
...@@ -60,9 +60,9 @@ ...@@ -60,9 +60,9 @@
// Adapted for building a shared-library by Andreas Marek, MPCDF (andreas.marek@mpcdf.mpg.de) // Adapted for building a shared-library by Andreas Marek, MPCDF (andreas.marek@mpcdf.mpg.de)
// -------------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------------
#include "config-f90.h" #include "config-f90.h"
#include <x86intrin.h> #include <x86intrin.h>
#define __forceinline __attribute__((always_inline)) static #define __forceinline __attribute__((always_inline)) static
#ifdef HAVE_AVX512 #ifdef HAVE_AVX512
...@@ -72,7 +72,6 @@ ...@@ -72,7 +72,6 @@
#define _mm512_FMSUB_pd(a,b,c) _mm512_fmsub_pd(a,b,c) #define _mm512_FMSUB_pd(a,b,c) _mm512_fmsub_pd(a,b,c)
#endif #endif
//Forward declaration //Forward declaration
__forceinline void hh_trafo_kernel_8_AVX512_4hv_double(double* q, double* hh, int nb, int ldq, int ldh, double s_1_2, double s_1_3, double s_2_3, double s_1_4, double s_2_4, double s_3_4); __forceinline void hh_trafo_kernel_8_AVX512_4hv_double(double* q, double* hh, int nb, int ldq, int ldh, double s_1_2, double s_1_3, double s_2_3, double s_1_4, double s_2_4, double s_3_4);
__forceinline void hh_trafo_kernel_16_AVX512_4hv_double(double* q, double* hh, int nb, int ldq, int ldh, double s_1_2, double s_1_3, double s_2_3, double s_1_4, double s_2_4, double s_3_4); __forceinline void hh_trafo_kernel_16_AVX512_4hv_double(double* q, double* hh, int nb, int ldq, int ldh, double s_1_2, double s_1_3, double s_2_3, double s_1_4, double s_2_4, double s_3_4);
...@@ -342,10 +341,10 @@ __forceinline void hh_trafo_kernel_32_AVX512_4hv_double(double* q, double* hh, i ...@@ -342,10 +341,10 @@ __forceinline void hh_trafo_kernel_32_AVX512_4hv_double(double* q, double* hh, i
h1 = _mm512_set1_pd(hh[nb-1]); h1 = _mm512_set1_pd(hh[nb-1]);
q1 = _mm512_set1_pd(q[(nb+2)*ldq]); q1 = _mm512_load_pd(&q[(nb+2)*ldq]);
q2 = _mm512_set1_pd(q[((nb+2)*ldq)+8]); q2 = _mm512_load_pd(&q[((nb+2)*ldq)+8]);
q3 = _mm512_set1_pd(q[((nb+2)*ldq)+16]); q3 = _mm512_load_pd(&q[((nb+2)*ldq)+16]);
q4 = _mm512_set1_pd(q[((nb+2)*ldq)+24]); q4 = _mm512_load_pd(&q[((nb+2)*ldq)+24]);
x1 = _mm512_FMA_pd(q1, h1, x1); x1 = _mm512_FMA_pd(q1, h1, x1);
x2 = _mm512_FMA_pd(q2, h1, x2); x2 = _mm512_FMA_pd(q2, h1, x2);
...@@ -751,9 +750,9 @@ __forceinline void hh_trafo_kernel_24_AVX512_4hv_double(double* q, double* hh, i ...@@ -751,9 +750,9 @@ __forceinline void hh_trafo_kernel_24_AVX512_4hv_double(double* q, double* hh, i
h1 = _mm512_set1_pd(hh[nb-1]); h1 = _mm512_set1_pd(hh[nb-1]);
q1 = _mm512_set1_pd(q[(nb+2)*ldq]); q1 = _mm512_load_pd(&q[(nb+2)*ldq]);
q2 = _mm512_set1_pd(q[((nb+2)*ldq)+8]); q2 = _mm512_load_pd(&q[((nb+2)*ldq)+8]);
q3 = _mm512_set1_pd(q[((nb+2)*ldq)+16]); q3 = _mm512_load_pd(&q[((nb+2)*ldq)+16]);
x1 = _mm512_FMA_pd(q1, h1, x1); x1 = _mm512_FMA_pd(q1, h1, x1);
x2 = _mm512_FMA_pd(q2, h1, x2); x2 = _mm512_FMA_pd(q2, h1, x2);
...@@ -1092,8 +1091,8 @@ __forceinline void hh_trafo_kernel_16_AVX512_4hv_double(double* q, double* hh, i ...@@ -1092,8 +1091,8 @@ __forceinline void hh_trafo_kernel_16_AVX512_4hv_double(double* q, double* hh, i
h1 = _mm512_set1_pd(hh[nb-1]); h1 = _mm512_set1_pd(hh[nb-1]);
q1 = _mm512_set1_pd(q[(nb+2)*ldq]); q1 = _mm512_load_pd(&q[(nb+2)*ldq]);
q2 = _mm512_set1_pd(q[((nb+2)*ldq)+8]); q2 = _mm512_load_pd(&q[((nb+2)*ldq)+8]);
x1 = _mm512_FMA_pd(q1, h1, x1); x1 = _mm512_FMA_pd(q1, h1, x1);
x2 = _mm512_FMA_pd(q2, h1, x2); x2 = _mm512_FMA_pd(q2, h1, x2);
...@@ -1345,7 +1344,7 @@ __forceinline void hh_trafo_kernel_8_AVX512_4hv_double(double* q, double* hh, in ...@@ -1345,7 +1344,7 @@ __forceinline void hh_trafo_kernel_8_AVX512_4hv_double(double* q, double* hh, in
h1 = _mm512_set1_pd(hh[nb-1]); h1 = _mm512_set1_pd(hh[nb-1]);
q1 = _mm512_set1_pd(q[(nb+2)*ldq]); q1 = _mm512_load_pd(&q[(nb+2)*ldq]);
x1 = _mm512_FMA_pd(q1, h1, x1); x1 = _mm512_FMA_pd(q1, h1, x1);
...@@ -1461,6 +1460,8 @@ __forceinline void hh_trafo_kernel_8_AVX512_4hv_double(double* q, double* hh, in ...@@ -1461,6 +1460,8 @@ __forceinline void hh_trafo_kernel_8_AVX512_4hv_double(double* q, double* hh, in
q1 = _mm512_NFMA_pd(x1, h1, q1); q1 = _mm512_NFMA_pd(x1, h1, q1);
_mm512_store_pd(&q[(nb+2)*ldq],q1); _mm512_store_pd(&q[(nb+2)*ldq],q1);
} }
...@@ -1708,3 +1709,4 @@ __forceinline void hh_trafo_kernel_4_AVX512_4hv_double(double* q, double* hh, in ...@@ -1708,3 +1709,4 @@ __forceinline void hh_trafo_kernel_4_AVX512_4hv_double(double* q, double* hh, in
_mm256_store_pd(&q[(nb+2)*ldq],q1); _mm256_store_pd(&q[(nb+2)*ldq],q1);
} }
#endif #endif
...@@ -353,7 +353,7 @@ module compute_hh_trafo_real ...@@ -353,7 +353,7 @@ module compute_hh_trafo_real
if ((THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX512_BLOCK2)) then if ((THIS_REAL_ELPA_KERNEL .eq. REAL_ELPA_KERNEL_AVX512_BLOCK2)) then
#endif /* WITH_NO_SPECIFIC_REAL_KERNEL */ #endif /* WITH_NO_SPECIFIC_REAL_KERNEL */
#if defined(WITH_NO_SPECIFIC_REAL_KERNEL) || (defined(WITH_ONE_SPECIFIC_REAL_KERNEL) && !defined(WITH_REAL_AVX512_BLOCK6_KERNEL) && !defined(WITH_RRAL_AVX512_BLOCK4_KERNEL)) #if defined(WITH_NO_SPECIFIC_REAL_KERNEL) || (defined(WITH_ONE_SPECIFIC_REAL_KERNEL) && !defined(WITH_REAL_AVX512_BLOCK6_KERNEL) && !defined(WITH_REAL_AVX512_BLOCK4_KERNEL))
do j = ncols, 2, -2 do j = ncols, 2, -2
w(:,1) = bcast_buffer(1:nbw,j+off) w(:,1) = bcast_buffer(1:nbw,j+off)
w(:,2) = bcast_buffer(1:nbw,j+off-1) w(:,2) = bcast_buffer(1:nbw,j+off-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment