diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 474a7a8325643eeaead13d3e91d621c4489aef6f..fd146c710d31fd1f9b63f88c0ad1b73dd8c37654 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -95,15 +95,21 @@ constexpr bool FORWARD = true, #endif #endif +template<typename T> struct VLEN { static constexpr size_t val=1; }; + #ifndef POCKETFFT_NO_VECTORS #if (defined(__AVX512F__)) -constexpr int VBYTELEN=64; +template<> struct VLEN<float> { static constexpr size_t val=16; }; +template<> struct VLEN<double> { static constexpr size_t val=8; }; #elif (defined(__AVX__)) -constexpr int VBYTELEN=32; +template<> struct VLEN<float> { static constexpr size_t val=8; }; +template<> struct VLEN<double> { static constexpr size_t val=4; }; #elif (defined(__SSE2__)) -constexpr int VBYTELEN=16; +template<> struct VLEN<float> { static constexpr size_t val=4; }; +template<> struct VLEN<double> { static constexpr size_t val=2; }; #elif (defined(__VSX__)) -constexpr int VBYTELEN=16; +template<> struct VLEN<float> { static constexpr size_t val=4; }; +template<> struct VLEN<double> { static constexpr size_t val=2; }; #else #define POCKETFFT_NO_VECTORS #endif @@ -2314,31 +2320,25 @@ template<size_t N, typename Ti, typename To> class multi_iter #ifndef POCKETFFT_NO_VECTORS template<typename T> struct VTYPE {}; -template<> struct VTYPE<long double> +template<> struct VTYPE<float> { - using type = long double __attribute__ ((vector_size (sizeof(long double)))); - static constexpr size_t vlen=1; + using type = float __attribute__ ((vector_size (VLEN<float>::val*sizeof(float)))); }; template<> struct VTYPE<double> { - using type = double __attribute__ ((vector_size (VBYTELEN))); - static constexpr size_t vlen=VBYTELEN/sizeof(double); + using type = double __attribute__ ((vector_size (VLEN<double>::val*sizeof(double)))); }; -template<> struct VTYPE<float> +template<> struct VTYPE<long double> { - using type = float __attribute__ ((vector_size (VBYTELEN))); - static constexpr size_t vlen=VBYTELEN/sizeof(float); + using type = long double __attribute__ ((vector_size (VLEN<long double>::val*sizeof(long double)))); }; -#else -template<typename T> struct VTYPE - { static constexpr size_t vlen=1; }; #endif template<typename T> arr<char> alloc_tmp(const shape_t &shape, size_t axsize, size_t elemsize) { auto othersize = util::prod(shape)/axsize; - auto tmpsize = axsize*((othersize>=VTYPE<T>::vlen) ? VTYPE<T>::vlen : 1); + auto tmpsize = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1); return arr<char>(tmpsize*elemsize); } template<typename T> arr<char> alloc_tmp(const shape_t &shape, @@ -2350,7 +2350,7 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape, { auto axsize = shape[axes[i]]; auto othersize = fullsize/axsize; - auto sz = axsize*((othersize>=VTYPE<T>::vlen) ? VTYPE<T>::vlen : 1); + auto sz = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1); if (sz>tmpsize) tmpsize=sz; } return arr<char>(tmpsize*elemsize); @@ -2370,7 +2370,7 @@ template<typename T> NOINLINE void general_c( for (size_t iax=0; iax<axes.size(); ++iax) { - constexpr auto vlen = VTYPE<T>::vlen; + constexpr auto vlen = VLEN<T>::val; size_t len=in.shape(axes[iax]); if ((!plan) || (len!=plan->length())) plan.reset(new pocketfft_c<T>(len)); @@ -2433,7 +2433,7 @@ template<typename T> NOINLINE void general_hartley( for (size_t iax=0; iax<axes.size(); ++iax) { - constexpr auto vlen = VTYPE<T>::vlen; + constexpr auto vlen = VLEN<T>::val; size_t len=in.shape(axes[iax]); if ((!plan) || (len!=plan->length())) plan.reset(new pocketfft_r<T>(len)); @@ -2494,7 +2494,7 @@ template<typename T> NOINLINE void general_r2c( size_t POCKETFFT_NTHREADS) { pocketfft_r<T> plan(in.shape(axis)); - constexpr auto vlen = VTYPE<T>::vlen; + constexpr auto vlen = VLEN<T>::val; size_t len=in.shape(axis); #ifdef POCKETFFT_OPENMP #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis)) @@ -2545,7 +2545,7 @@ template<typename T> NOINLINE void general_c2r( size_t POCKETFFT_NTHREADS) { pocketfft_r<T> plan(out.shape(axis)); - constexpr auto vlen = VTYPE<T>::vlen; + constexpr auto vlen = VLEN<T>::val; size_t len=out.shape(axis); #ifdef POCKETFFT_OPENMP #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis)) @@ -2600,7 +2600,7 @@ template<typename T> NOINLINE void general_r( const ndarr<T> &in, ndarr<T> &out, size_t axis, bool forward, T fct, size_t POCKETFFT_NTHREADS) { - constexpr auto vlen = VTYPE<T>::vlen; + constexpr auto vlen = VLEN<T>::val; size_t len=in.shape(axis); pocketfft_r<T> plan(len); #ifdef POCKETFFT_OPENMP