Commit 5a190ae6 authored by Martin Reinecke's avatar Martin Reinecke

more flexible vector lengths

parent d2529537
......@@ -95,15 +95,21 @@ constexpr bool FORWARD = true,
#endif
#endif
template<typename T> struct VLEN { static constexpr size_t val=1; };
#ifndef POCKETFFT_NO_VECTORS
#if (defined(__AVX512F__))
constexpr int VBYTELEN=64;
template<> struct VLEN<float> { static constexpr size_t val=16; };
template<> struct VLEN<double> { static constexpr size_t val=8; };
#elif (defined(__AVX__))
constexpr int VBYTELEN=32;
template<> struct VLEN<float> { static constexpr size_t val=8; };
template<> struct VLEN<double> { static constexpr size_t val=4; };
#elif (defined(__SSE2__))
constexpr int VBYTELEN=16;
template<> struct VLEN<float> { static constexpr size_t val=4; };
template<> struct VLEN<double> { static constexpr size_t val=2; };
#elif (defined(__VSX__))
constexpr int VBYTELEN=16;
template<> struct VLEN<float> { static constexpr size_t val=4; };
template<> struct VLEN<double> { static constexpr size_t val=2; };
#else
#define POCKETFFT_NO_VECTORS
#endif
......@@ -2314,31 +2320,25 @@ template<size_t N, typename Ti, typename To> class multi_iter
#ifndef POCKETFFT_NO_VECTORS
template<typename T> struct VTYPE {};
template<> struct VTYPE<long double>
template<> struct VTYPE<float>
{
using type = long double __attribute__ ((vector_size (sizeof(long double))));
static constexpr size_t vlen=1;
using type = float __attribute__ ((vector_size (VLEN<float>::val*sizeof(float))));
};
template<> struct VTYPE<double>
{
using type = double __attribute__ ((vector_size (VBYTELEN)));
static constexpr size_t vlen=VBYTELEN/sizeof(double);
using type = double __attribute__ ((vector_size (VLEN<double>::val*sizeof(double))));
};
template<> struct VTYPE<float>
template<> struct VTYPE<long double>
{
using type = float __attribute__ ((vector_size (VBYTELEN)));
static constexpr size_t vlen=VBYTELEN/sizeof(float);
using type = long double __attribute__ ((vector_size (VLEN<long double>::val*sizeof(long double))));
};
#else
template<typename T> struct VTYPE
{ static constexpr size_t vlen=1; };
#endif
template<typename T> arr<char> alloc_tmp(const shape_t &shape,
size_t axsize, size_t elemsize)
{
auto othersize = util::prod(shape)/axsize;
auto tmpsize = axsize*((othersize>=VTYPE<T>::vlen) ? VTYPE<T>::vlen : 1);
auto tmpsize = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1);
return arr<char>(tmpsize*elemsize);
}
template<typename T> arr<char> alloc_tmp(const shape_t &shape,
......@@ -2350,7 +2350,7 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape,
{
auto axsize = shape[axes[i]];
auto othersize = fullsize/axsize;
auto sz = axsize*((othersize>=VTYPE<T>::vlen) ? VTYPE<T>::vlen : 1);
auto sz = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1);
if (sz>tmpsize) tmpsize=sz;
}
return arr<char>(tmpsize*elemsize);
......@@ -2370,7 +2370,7 @@ template<typename T> NOINLINE void general_c(
for (size_t iax=0; iax<axes.size(); ++iax)
{
constexpr auto vlen = VTYPE<T>::vlen;
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_c<T>(len));
......@@ -2433,7 +2433,7 @@ template<typename T> NOINLINE void general_hartley(
for (size_t iax=0; iax<axes.size(); ++iax)
{
constexpr auto vlen = VTYPE<T>::vlen;
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axes[iax]);
if ((!plan) || (len!=plan->length()))
plan.reset(new pocketfft_r<T>(len));
......@@ -2494,7 +2494,7 @@ template<typename T> NOINLINE void general_r2c(
size_t POCKETFFT_NTHREADS)
{
pocketfft_r<T> plan(in.shape(axis));
constexpr auto vlen = VTYPE<T>::vlen;
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axis);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis))
......@@ -2545,7 +2545,7 @@ template<typename T> NOINLINE void general_c2r(
size_t POCKETFFT_NTHREADS)
{
pocketfft_r<T> plan(out.shape(axis));
constexpr auto vlen = VTYPE<T>::vlen;
constexpr auto vlen = VLEN<T>::val;
size_t len=out.shape(axis);
#ifdef POCKETFFT_OPENMP
#pragma omp parallel num_threads(util::thread_count(nthreads, in.shape(), axis))
......@@ -2600,7 +2600,7 @@ template<typename T> NOINLINE void general_r(
const ndarr<T> &in, ndarr<T> &out, size_t axis, bool forward, T fct,
size_t POCKETFFT_NTHREADS)
{
constexpr auto vlen = VTYPE<T>::vlen;
constexpr auto vlen = VLEN<T>::val;
size_t len=in.shape(axis);
pocketfft_r<T> plan(len);
#ifdef POCKETFFT_OPENMP
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment