Commit 39de2274 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

synchronize

parent 2fd77f20
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <cstdlib> #include <cstdlib>
#include <algorithm>
#include <stdexcept> #include <stdexcept>
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -34,13 +33,14 @@ ...@@ -34,13 +33,14 @@
namespace pocketfft { namespace pocketfft {
using namespace std; using shape_t = std::vector<size_t>;
using stride_t = std::vector<ptrdiff_t>;
using shape_t = vector<size_t>;
using stride_t = vector<ptrdiff_t>;
namespace detail { namespace detail {
using namespace std;
#ifndef POCKETFFT_NO_VECTORS #ifndef POCKETFFT_NO_VECTORS
#if (defined(__AVX512F__)) #if (defined(__AVX512F__))
#define HAVE_VECSUPPORT #define HAVE_VECSUPPORT
...@@ -56,7 +56,7 @@ constexpr int VBYTELEN=16; ...@@ -56,7 +56,7 @@ constexpr int VBYTELEN=16;
#endif #endif
#endif #endif
template<typename T> struct arr template<typename T> class arr
{ {
private: private:
T *p; T *p;
...@@ -228,16 +228,11 @@ template<typename T> class sincos_2pibyn ...@@ -228,16 +228,11 @@ template<typename T> class sincos_2pibyn
size_t i=0, idx1=0, idx2=2*ndone-2; size_t i=0, idx1=0, idx2=2*ndone-2;
for (; i+1<ndone; i+=2, idx1+=2, idx2-=2) for (; i+1<ndone; i+=2, idx1+=2, idx2-=2)
{ {
res[idx1] = p[2*i]; res[idx1] = p[2*i ]; res[idx1+1] = p[2*i+1];
res[idx1+1] = p[2*i+1]; res[idx2] = p[2*i+3]; res[idx2+1] = p[2*i+2];
res[idx2] = p[2*i+3];
res[idx2+1] = p[2*i+2];
} }
if (i!=ndone) if (i!=ndone)
{ { res[idx1] = p[2*i]; res[idx1+1] = p[2*i+1]; }
res[idx1 ] = p[2*i];
res[idx1+1] = p[2*i+1];
}
} }
void calc_first_half(size_t n, T * restrict res) void calc_first_half(size_t n, T * restrict res)
...@@ -247,24 +242,13 @@ template<typename T> class sincos_2pibyn ...@@ -247,24 +242,13 @@ template<typename T> class sincos_2pibyn
calc_first_octant(n<<2, p); calc_first_octant(n<<2, p);
int i4=0, in=n, i=0; int i4=0, in=n, i=0;
for (; i4<=in-i4; ++i, i4+=4) // octant 0 for (; i4<=in-i4; ++i, i4+=4) // octant 0
{ { res[2*i] = p[2*i4]; res[2*i+1] = p[2*i4+1]; }
res[2*i] = p[2*i4]; res[2*i+1] = p[2*i4+1];
}
for (; i4-in <= 0; ++i, i4+=4) // octant 1 for (; i4-in <= 0; ++i, i4+=4) // octant 1
{ { int xm = in-i4; res[2*i] = p[2*xm+1]; res[2*i+1] = p[2*xm]; }
int xm = in-i4;
res[2*i] = p[2*xm+1]; res[2*i+1] = p[2*xm];
}
for (; i4<=3*in-i4; ++i, i4+=4) // octant 2 for (; i4<=3*in-i4; ++i, i4+=4) // octant 2
{ { int xm = i4-in; res[2*i] = -p[2*xm+1]; res[2*i+1] = p[2*xm]; }
int xm = i4-in;
res[2*i] = -p[2*xm+1]; res[2*i+1] = p[2*xm];
}
for (; i<ndone; ++i, i4+=4) // octant 3 for (; i<ndone; ++i, i4+=4) // octant 3
{ { int xm = 2*in-i4; res[2*i] = -p[2*xm]; res[2*i+1] = p[2*xm+1]; }
int xm = 2*in-i4;
res[2*i] = -p[2*xm]; res[2*i+1] = p[2*xm+1];
}
} }
void fill_first_quadrant(size_t n, T * restrict res) void fill_first_quadrant(size_t n, T * restrict res)
...@@ -274,10 +258,7 @@ template<typename T> class sincos_2pibyn ...@@ -274,10 +258,7 @@ template<typename T> class sincos_2pibyn
if ((n&7)==0) if ((n&7)==0)
res[quart] = res[quart+1] = hsqt2; res[quart] = res[quart+1] = hsqt2;
for (size_t i=2, j=2*quart-2; i<quart; i+=2, j-=2) for (size_t i=2, j=2*quart-2; i<quart; i+=2, j-=2)
{ { res[j] = res[i+1]; res[j+1] = res[i]; }
res[j ] = res[i+1];
res[j+1] = res[i ];
}
} }
NOINLINE void fill_first_half(size_t n, T * restrict res) NOINLINE void fill_first_half(size_t n, T * restrict res)
...@@ -285,16 +266,10 @@ template<typename T> class sincos_2pibyn ...@@ -285,16 +266,10 @@ template<typename T> class sincos_2pibyn
size_t half = n>>1; size_t half = n>>1;
if ((n&3)==0) if ((n&3)==0)
for (size_t i=0; i<half; i+=2) for (size_t i=0; i<half; i+=2)
{ { res[i+half] = -res[i+1]; res[i+half+1] = res[i]; }
res[i+half] = -res[i+1];
res[i+half+1] = res[i ];
}
else else
for (size_t i=2, j=2*half-2; i<half; i+=2, j-=2) for (size_t i=2, j=2*half-2; i<half; i+=2, j-=2)
{ { res[j] = -res[i]; res[j+1] = res[i+1]; }
res[j ] = -res[i ];
res[j+1] = res[i+1];
}
} }
void fill_second_half(size_t n, T * restrict res) void fill_second_half(size_t n, T * restrict res)
...@@ -304,10 +279,7 @@ template<typename T> class sincos_2pibyn ...@@ -304,10 +279,7 @@ template<typename T> class sincos_2pibyn
res[i+n] = -res[i]; res[i+n] = -res[i];
else else
for (size_t i=2, j=2*n-2; i<n; i+=2, j-=2) for (size_t i=2, j=2*n-2; i<n; i+=2, j-=2)
{ { res[j] = res[i]; res[j+1] = -res[i+1]; }
res[j ] = res[i ];
res[j+1] = -res[i+1];
}
} }
NOINLINE void sincos_2pibyn_half(size_t n, T * restrict res) NOINLINE void sincos_2pibyn_half(size_t n, T * restrict res)
...@@ -411,8 +383,6 @@ struct util // hack to avoid duplicate symbols ...@@ -411,8 +383,6 @@ struct util // hack to avoid duplicate symbols
#define CC(a,b,c) cc[(a)+ido*((b)+cdim*(c))] #define CC(a,b,c) cc[(a)+ido*((b)+cdim*(c))]
#define WA(x,i) wa[(i)-1+(x)*(ido-1)] #define WA(x,i) wa[(i)-1+(x)*(ido-1)]
constexpr size_t NFCT=25;
// //
// complex FFTPACK transforms // complex FFTPACK transforms
// //
...@@ -427,15 +397,12 @@ template<typename T0> class cfftp ...@@ -427,15 +397,12 @@ template<typename T0> class cfftp
cmplx<T0> *tw, *tws; cmplx<T0> *tw, *tws;
}; };
size_t length, nfct; size_t length;
arr<cmplx<T0>> mem; arr<cmplx<T0>> mem;
fctdata fact[NFCT]; vector<fctdata> fact;
void add_factor(size_t factor) void add_factor(size_t factor)
{ { fact.push_back({factor, nullptr, nullptr}); }
if (nfct>=NFCT) throw runtime_error("too many prime factors");
fact[nfct++].fct = factor;
}
template<bool bwd, typename T> void pass2 (size_t ido, size_t l1, template<bool bwd, typename T> void pass2 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa) const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
...@@ -883,7 +850,7 @@ template<bool bwd, typename T> void pass_all(T c[], T0 fct) ...@@ -883,7 +850,7 @@ template<bool bwd, typename T> void pass_all(T c[], T0 fct)
arr<T> ch(length); arr<T> ch(length);
T *p1=c, *p2=ch.data(); T *p1=c, *p2=ch.data();
for(size_t k1=0; k1<nfct; k1++) for(size_t k1=0; k1<fact.size(); k1++)
{ {
size_t ip=fact[k1].fct; size_t ip=fact[k1].fct;
size_t l2=ip*l1; size_t l2=ip*l1;
...@@ -936,7 +903,6 @@ template<bool bwd, typename T> void pass_all(T c[], T0 fct) ...@@ -936,7 +903,6 @@ template<bool bwd, typename T> void pass_all(T c[], T0 fct)
private: private:
NOINLINE void factorize() NOINLINE void factorize()
{ {
nfct=0;
size_t len=length; size_t len=length;
while ((len&3)==0) while ((len&3)==0)
{ add_factor(4); len>>=2; } { add_factor(4); len>>=2; }
...@@ -945,7 +911,7 @@ template<bool bwd, typename T> void pass_all(T c[], T0 fct) ...@@ -945,7 +911,7 @@ template<bool bwd, typename T> void pass_all(T c[], T0 fct)
len>>=1; len>>=1;
// factor 2 should be at the front of the factor list // factor 2 should be at the front of the factor list
add_factor(2); add_factor(2);
swap(fact[0].fct, fact[nfct-1].fct); swap(fact[0].fct, fact.back().fct);
} }
size_t maxl = size_t(sqrt(double(len)))+1; size_t maxl = size_t(sqrt(double(len)))+1;
for (size_t divisor=3; (len>1)&&(divisor<maxl); divisor+=2) for (size_t divisor=3; (len>1)&&(divisor<maxl); divisor+=2)
...@@ -964,7 +930,7 @@ template<bool bwd, typename T> void pass_all(T c[], T0 fct) ...@@ -964,7 +930,7 @@ template<bool bwd, typename T> void pass_all(T c[], T0 fct)
size_t twsize() const size_t twsize() const
{ {
size_t twsize=0, l1=1; size_t twsize=0, l1=1;
for (size_t k=0; k<nfct; ++k) for (size_t k=0; k<fact.size(); ++k)
{ {
size_t ip=fact[k].fct, ido= length/(l1*ip); size_t ip=fact[k].fct, ido= length/(l1*ip);
twsize+=(ip-1)*(ido-1); twsize+=(ip-1)*(ido-1);
...@@ -981,7 +947,7 @@ template<bool bwd, typename T> void pass_all(T c[], T0 fct) ...@@ -981,7 +947,7 @@ template<bool bwd, typename T> void pass_all(T c[], T0 fct)
auto twiddle = twid.cdata(); auto twiddle = twid.cdata();
size_t l1=1; size_t l1=1;
size_t memofs=0; size_t memofs=0;
for (size_t k=0; k<nfct; ++k) for (size_t k=0; k<fact.size(); ++k)
{ {
size_t ip=fact[k].fct, ido=length/(l1*ip); size_t ip=fact[k].fct, ido=length/(l1*ip);
fact[k].tw=mem.data()+memofs; fact[k].tw=mem.data()+memofs;
...@@ -1002,7 +968,7 @@ template<bool bwd, typename T> void pass_all(T c[], T0 fct) ...@@ -1002,7 +968,7 @@ template<bool bwd, typename T> void pass_all(T c[], T0 fct)
public: public:
NOINLINE cfftp(size_t length_) NOINLINE cfftp(size_t length_)
: length(length_), nfct(0) : length(length_)
{ {
if (length==0) throw runtime_error("zero length FFT requested"); if (length==0) throw runtime_error("zero length FFT requested");
if (length==1) return; if (length==1) return;
...@@ -1025,15 +991,12 @@ template<typename T0> class rfftp ...@@ -1025,15 +991,12 @@ template<typename T0> class rfftp
T0 *tw, *tws; T0 *tw, *tws;
}; };
size_t length, nfct; size_t length;
arr<T0> mem; arr<T0> mem;
fctdata fact[NFCT]; vector<fctdata> fact;
void add_factor(size_t factor) void add_factor(size_t factor)
{ { fact.push_back({factor, nullptr, nullptr}); }
if (nfct>=NFCT) throw runtime_error("too many prime factors");
fact[nfct++].fct = factor;
}
#define WA(x,i) wa[(i)+(x)*(ido-1)] #define WA(x,i) wa[(i)+(x)*(ido-1)]
#define PM(a,b,c,d) { a=c+d; b=c-d; } #define PM(a,b,c,d) { a=c+d; b=c-d; }
...@@ -1646,7 +1609,7 @@ template<typename T> void radbg(size_t ido, size_t ip, size_t l1, ...@@ -1646,7 +1609,7 @@ template<typename T> void radbg(size_t ido, size_t ip, size_t l1,
#undef PM #undef PM
#undef WA #undef WA
template<typename T> void copy_and_norm(T *c, T *p1, size_t n, T0 fct) template<typename T> void copy_and_norm(T *c, T *p1, size_t n, T0 fct)
{ {
if (p1!=c) if (p1!=c)
{ {
...@@ -1663,12 +1626,11 @@ template<typename T> void copy_and_norm(T *c, T *p1, size_t n, T0 fct) ...@@ -1663,12 +1626,11 @@ template<typename T> void copy_and_norm(T *c, T *p1, size_t n, T0 fct)
} }
public: public:
template<typename T> void forward(T c[], T0 fct)
template<typename T> void forward(T c[], T0 fct)
{ {
if (length==1) { c[0]*=fct; return; } if (length==1) { c[0]*=fct; return; }
size_t n=length; size_t n=length;
size_t l1=n, nf=nfct; size_t l1=n, nf=fact.size();
arr<T> ch(n); arr<T> ch(n);
T *p1=c, *p2=ch.data(); T *p1=c, *p2=ch.data();
...@@ -1687,20 +1649,17 @@ template<typename T> void forward(T c[], T0 fct) ...@@ -1687,20 +1649,17 @@ template<typename T> void forward(T c[], T0 fct)
else if(ip==5) else if(ip==5)
radf5(ido, l1, p1, p2, fact[k].tw); radf5(ido, l1, p1, p2, fact[k].tw);
else else
{ { radfg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws); swap (p1,p2); }
radfg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws);
swap (p1,p2);
}
swap (p1,p2); swap (p1,p2);
} }
copy_and_norm(c,p1,n,fct); copy_and_norm(c,p1,n,fct);
} }
template<typename T> void backward(T c[], T0 fct) template<typename T> void backward(T c[], T0 fct)
{ {
if (length==1) { c[0]*=fct; return; } if (length==1) { c[0]*=fct; return; }
size_t n=length; size_t n=length;
size_t l1=1, nf=nfct; size_t l1=1, nf=fact.size();
arr<T> ch(n); arr<T> ch(n);
T *p1=c, *p2=ch.data(); T *p1=c, *p2=ch.data();
...@@ -1725,11 +1684,9 @@ template<typename T> void backward(T c[], T0 fct) ...@@ -1725,11 +1684,9 @@ template<typename T> void backward(T c[], T0 fct)
} }
private: private:
void factorize()
void factorize()
{ {
size_t len=length; size_t len=length;
nfct=0;
while ((len%4)==0) while ((len%4)==0)
{ add_factor(4); len>>=2; } { add_factor(4); len>>=2; }
if ((len%2)==0) if ((len%2)==0)
...@@ -1737,7 +1694,7 @@ void factorize() ...@@ -1737,7 +1694,7 @@ void factorize()
len>>=1; len>>=1;
// factor 2 should be at the front of the factor list // factor 2 should be at the front of the factor list
add_factor(2); add_factor(2);
swap(fact[0].fct, fact[nfct-1].fct); swap(fact[0].fct, fact.back().fct);
} }
size_t maxl = size_t(sqrt(double(len)))+1; size_t maxl = size_t(sqrt(double(len)))+1;
for (size_t divisor=3; (len>1)&&(divisor<maxl); divisor+=2) for (size_t divisor=3; (len>1)&&(divisor<maxl); divisor+=2)
...@@ -1753,10 +1710,10 @@ void factorize() ...@@ -1753,10 +1710,10 @@ void factorize()
if (len>1) add_factor(len); if (len>1) add_factor(len);
} }
size_t twsize() const size_t twsize() const
{ {
size_t twsz=0, l1=1; size_t twsz=0, l1=1;
for (size_t k=0; k<nfct; ++k) for (size_t k=0; k<fact.size(); ++k)
{ {
size_t ip=fact[k].fct, ido=length/(l1*ip); size_t ip=fact[k].fct, ido=length/(l1*ip);
twsz+=(ip-1)*(ido-1); twsz+=(ip-1)*(ido-1);
...@@ -1766,15 +1723,15 @@ size_t twsize() const ...@@ -1766,15 +1723,15 @@ size_t twsize() const
return twsz; return twsz;
} }
void comp_twiddle() void comp_twiddle()
{ {
sincos_2pibyn<T0> twid(length, true); sincos_2pibyn<T0> twid(length, true);
size_t l1=1; size_t l1=1;
T0 *ptr=mem.data(); T0 *ptr=mem.data();
for (size_t k=0; k<nfct; ++k) for (size_t k=0; k<fact.size(); ++k)
{ {
size_t ip=fact[k].fct, ido=length/(l1*ip); size_t ip=fact[k].fct, ido=length/(l1*ip);
if (k<nfct-1) // last factor doesn't need twiddles if (k<fact.size()-1) // last factor doesn't need twiddles
{ {
fact[k].tw=ptr; ptr+=(ip-1)*(ido-1); fact[k].tw=ptr; ptr+=(ip-1)*(ido-1);
for (size_t j=1; j<ip; ++j) for (size_t j=1; j<ip; ++j)
...@@ -1802,18 +1759,15 @@ void comp_twiddle() ...@@ -1802,18 +1759,15 @@ void comp_twiddle()
} }
public: public:
NOINLINE rfftp(size_t length_)
NOINLINE rfftp(size_t length_)
: length(length_) : length(length_)
{ {
if (length==0) throw runtime_error("zero-sized FFT"); if (length==0) throw runtime_error("zero-sized FFT");
nfct=0;
if (length==1) return; if (length==1) return;
factorize(); factorize();
mem.resize(twsize()); mem.resize(twsize());
comp_twiddle(); comp_twiddle();
} }
}; };
// //
...@@ -1942,16 +1896,10 @@ template<typename T0> class pocketfft_c ...@@ -1942,16 +1896,10 @@ template<typename T0> class pocketfft_c
} }
template<typename T> NOINLINE void backward(cmplx<T> c[], T0 fct) template<typename T> NOINLINE void backward(cmplx<T> c[], T0 fct)
{ { packplan ? packplan->backward(c,fct) : blueplan->backward(c,fct); }
packplan ? packplan->backward(c,fct)
: blueplan->backward(c,fct);
}
template<typename T> NOINLINE void forward(cmplx<T> c[], T0 fct) template<typename T> NOINLINE void forward(cmplx<T> c[], T0 fct)
{ { packplan ? packplan->forward(c,fct) : blueplan->forward(c,fct); }
packplan ? packplan->forward(c,fct)
: blueplan->forward(c,fct);
}
size_t length() const { return len; } size_t length() const { return len; }
}; };
...@@ -2040,10 +1988,8 @@ template<size_t N, typename Ti, typename To> class multi_iter ...@@ -2040,10 +1988,8 @@ template<size_t N, typename Ti, typename To> class multi_iter
ndarr<To> &oarr; ndarr<To> &oarr;
private: private:
ptrdiff_t p_ii, p_i[N], str_i; ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o;
ptrdiff_t p_oi, p_o[N], str_o; size_t idim, rem;
size_t idim;
size_t rem;
void advance_i() void advance_i()
{ {
...@@ -2128,8 +2074,8 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape, ...@@ -2128,8 +2074,8 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape,
{ {
auto axsize = shape[axes[i]]; auto axsize = shape[axes[i]];
auto othersize = fullsize/axsize; auto othersize = fullsize/axsize;
tmpsize = max(tmpsize, auto sz = axsize*((othersize>=VTYPE<T>::vlen) ? VTYPE<T>::vlen : 1);
axsize*((othersize>=VTYPE<T>::vlen) ? VTYPE<T>::vlen : 1)); if (sz>tmpsize) tmpsize=sz;
} }
return arr<char>(tmpsize*elemsize); return arr<char>(tmpsize*elemsize);
} }
...@@ -2157,12 +2103,8 @@ template<typename T> NOINLINE void general_c( ...@@ -2157,12 +2103,8 @@ template<typename T> NOINLINE void general_c(
auto tdatav = reinterpret_cast<cmplx<vtype> *>(storage.data()); auto tdatav = reinterpret_cast<cmplx<vtype> *>(storage.data());
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
{ { tdatav[i].r[j] = it.in(j,i).r; tdatav[i].i[j] = it.in(j,i).i; }
tdatav[i].r[j] = it.in(j,i).r; forward ? plan->forward (tdatav, fct) : plan->backward(tdatav, fct);
tdatav[i].i[j] = it.in(j,i).i;
}
forward ? plan->forward (tdatav, fct)
: plan->backward(tdatav, fct);
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
it.out(j,i).Set(tdatav[i].r[j],tdatav[i].i[j]); it.out(j,i).Set(tdatav[i].r[j],tdatav[i].i[j]);
...@@ -2186,8 +2128,7 @@ template<typename T> NOINLINE void general_c( ...@@ -2186,8 +2128,7 @@ template<typename T> NOINLINE void general_c(
{ {
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
tdata[i] = it.in(i); tdata[i] = it.in(i);
forward ? plan->forward (tdata, fct) forward ? plan->forward (tdata, fct) : plan->backward(tdata, fct);
: plan->backward(tdata, fct);
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
it.out(i) = tdata[i]; it.out(i) = tdata[i];
} }
...@@ -2245,10 +2186,7 @@ template<typename T> NOINLINE void general_hartley( ...@@ -2245,10 +2186,7 @@ template<typename T> NOINLINE void general_hartley(
it.out(0) = tdata[0]; it.out(0) = tdata[0];
size_t i=1, i1=1, i2=len-1; size_t i=1, i1=1, i2=len-1;
for (i=1; i<len-1; i+=2, ++i1, --i2) for (i=1; i<len-1; i+=2, ++i1, --i2)
{ { it.out(i1) = tdata[i]+tdata[i+1]; it.out(i2) = tdata[i]-tdata[i+1]; }
it.out(i1) = tdata[i]+tdata[i+1];
it.out(i2) = tdata[i]-tdata[i+1];
}
if (i<len) if (i<len)
it.out(i1) = tdata[i]; it.out(i1) = tdata[i];
} }
...@@ -2324,10 +2262,7 @@ template<typename T> NOINLINE void general_c2r( ...@@ -2324,10 +2262,7 @@ template<typename T> NOINLINE void general_c2r(
size_t i=1, ii=1; size_t i=1, ii=1;
for (; i<len-1; i+=2, ++ii) for (; i<len-1; i+=2, ++ii)
for (size_t j=