Commit 5a190ae6 authored by Martin Reinecke

more flexible vector lengths

parent d2529537
...@@ -95,15 +95,21 @@ constexpr bool FORWARD = true, ...@@ -95,15 +95,21 @@ constexpr bool FORWARD = true,
#endif #endif
#endif #endif
template<typename T> struct VLEN { static constexpr size_t val=1; };
#ifndef POCKETFFT_NO_VECTORS #ifndef POCKETFFT_NO_VECTORS
#if (defined(__AVX512F__)) #if (defined(__AVX512F__))
constexpr int VBYTELEN=64; template<> struct VLEN<float> { static constexpr size_t val=16; };
template<> struct VLEN<double> { static constexpr size_t val=8; };
#elif (defined(__AVX__)) #elif (defined(__AVX__))
constexpr int VBYTELEN=32; template<> struct VLEN<float> { static constexpr size_t val=8; };
template<> struct VLEN<double> { static constexpr size_t val=4; };
#elif (defined(__SSE2__)) #elif (defined(__SSE2__))
constexpr int VBYTELEN=16; template<> struct VLEN<float> { static constexpr size_t val=4; };
template<> struct VLEN<double> { static constexpr size_t val=2; };
#elif (defined(__VSX__)) #elif (defined(__VSX__))
constexpr int VBYTELEN=16; template<> struct VLEN<float> { static constexpr size_t val=4; };
template<> struct VLEN<double> { static constexpr size_t val=2; };
#else #else
#define POCKETFFT_NO_VECTORS #define POCKETFFT_NO_VECTORS
#endif #endif
...@@ -2314,31 +2320,25 @@ template<size_t N, typename Ti, typename To> class multi_iter ...@@ -2314,31 +2320,25 @@ template<size_t N, typename Ti, typename To> class multi_iter
#ifndef POCKETFFT_NO_VECTORS #ifndef POCKETFFT_NO_VECTORS
template<typename T> struct VTYPE {}; template<typename T> struct VTYPE {};
template<> struct VTYPE<long double> template<> struct VTYPE<float>
{ {
using type = long double __attribute__ ((vector_size (sizeof(long double)))); using type = float __attribute__ ((vector_size (VLEN<float>::val*sizeof(float))));
static constexpr size_t vlen=1;
}; };
template<> struct VTYPE<double> template<> struct VTYPE<double>
{ {
using type = double __attribute__ ((vector_size (VBYTELEN))); using type = double __attribute__ ((vector_size (VLEN<double>::val*sizeof(double))));
static constexpr size_t vlen=VBYTELEN/sizeof(double);
}; };
template<> struct VTYPE<float> template<> struct VTYPE<long double>
{ {
using type = float __attribute__ ((vector_size (VBYTELEN))); using type = long double __attribute__ ((vector_size (VLEN<long double>::val*sizeof(long double))));
static constexpr size_t vlen=VBYTELEN/sizeof(float);
}; };
#else
template<typename T> struct VTYPE
{ static constexpr size_t vlen=1; };
#endif #endif
template<typename T> arr<char> alloc_tmp(const shape_t &shape, template<typename T> arr<char> alloc_tmp(const shape_t &shape,
size_t axsize, size_t elemsize) size_t axsize, size_t elemsize)
{ {
auto othersize = util::prod(shape)/axsize; auto othersize = util::prod(shape)/axsize;
auto tmpsize = axsize*((othersize>=VTYPE<T>::vlen) ? VTYPE<T>::vlen : 1); auto tmpsize = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1);
return arr<char>(tmpsize*elemsize); return arr<char>(tmpsize*elemsize);
} }
template<typename T> arr<char> alloc_tmp(const shape_t &shape, template<typename T> arr<char> alloc_tmp(const shape_t &shape,
...@@ -2350,7 +2350,7 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape, ...@@ -2350,7 +2350,7 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape,
{ {
auto axsize = shape[axes[i]]; auto axsize = shape[axes[i]];
auto othersize = fullsize/axsize; auto othersize = fullsize/axsize;
auto sz = axsize*((othersize>=VTYPE<T>::vlen) ? VTYPE<T>::vlen : 1); auto sz = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1);
if (sz>tmpsize) tmpsize=sz; if (sz>tmpsize) tmpsize=sz;
} }
return arr<char>(tmpsize*elemsize); return arr<char>(tmpsize*elemsize);
...@@ -2370,7 +2370,7 @@ template<typename T> NOINLINE void general_c( ...@@ -2370,7 +2370,7 @@ template<typename T> NOINLINE void general_c(
for (size_t iax=0; iax<axes.size(); ++iax) for (size_t iax=0; iax<axes.size(); ++iax)
{ {
constexpr auto vlen = VTYPE<T>::vlen; constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]); size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length())) if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_c<T>(len)); plan.reset(new pocketfft_c<T>(len));
...@@ -2433,7 +2433,7 @@ template<typename T> NOINLINE void general_hartley( ...@@ -2433,7 +2433,7 @@ template<typename T> NOINLINE void general_hartley(
for (size_t iax=0; iax<axes.size(); ++iax) for (size_t iax=0; iax<axes.size(); ++iax)
{ {
constexpr auto vlen = VTYPE<T>::vlen; constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]); size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length())) if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_r<T>(len)); plan.reset(new pocketfft_r<T>(len));
...@@ -2494,7 +2494,7 @@ template<typename T> NOINLINE void general_r2c( ...@@ -2494,7 +2494,7 @@ template<typename T> NOINLINE void general_r2c(
size_t POCKETFFT_NTHREADS) size_t POCKETFFT_NTHREADS)
{ {
pocketfft_r<T> plan(in.shape(axis)); pocketfft_r<T> plan(in.shape(axis));
constexpr auto vlen = VTYPE<T>::vlen; constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axis); size_t len=in.shape(axis);
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis)) #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis))
...@@ -2545,7 +2545,7 @@ template<typename T> NOINLINE void general_c2r( ...@@ -2545,7 +2545,7 @@ template<typename T> NOINLINE void general_c2r(
size_t POCKETFFT_NTHREADS) size_t POCKETFFT_NTHREADS)
{ {
pocketfft_r<T> plan(out.shape(axis)); pocketfft_r<T> plan(out.shape(axis));
constexpr auto vlen = VTYPE<T>::vlen; constexpr auto vlen = VLEN<T>::val;
size_t len=out.shape(axis); size_t len=out.shape(axis);
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis)) #pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis))
...@@ -2600,7 +2600,7 @@ template<typename T> NOINLINE void general_r( ...@@ -2600,7 +2600,7 @@ template<typename T> NOINLINE void general_r(
const ndarr<T> &in, ndarr<T> &out, size_t axis, bool forward, T fct, const ndarr<T> &in, ndarr<T> &out, size_t axis, bool forward, T fct,
size_t POCKETFFT_NTHREADS) size_t POCKETFFT_NTHREADS)
{ {
constexpr auto vlen = VTYPE<T>::vlen; constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axis); size_t len=in.shape(axis);
pocketfft_r<T> plan(len); pocketfft_r<T> plan(len);
#ifdef POCKETFFT_OPENMP #ifdef POCKETFFT_OPENMP
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment