#include #include #include #include #include #include #ifdef __GNUC__ #define NOINLINE __attribute__((noinline)) #define restrict __restrict__ #else #define NOINLINE #define restrict #endif using namespace std; namespace { template struct arr { private: T *p; size_t sz; static T *ralloc(size_t num) { return (T *)aligned_alloc(64,num*sizeof(T)); } static void dealloc(T *ptr) { free(ptr); } public: arr() : p(0), sz(0) {} arr(size_t n) : p(ralloc(n)), sz(n) { if ((!p) && (n!=0)) throw bad_alloc(); } ~arr() { dealloc(p); } void resize(size_t n) { if (n==sz) return; dealloc(p); p = ralloc(n); sz = n; } T &operator[](size_t idx) { return p[idx]; } const T &operator[](size_t idx) const { return p[idx]; } T *data() { return p; } const T *data() const { return p; } size_t size() const { return sz; } }; class sincos_2pibyn { private: arr data; // adapted from https://stackoverflow.com/questions/42792939/ // CAUTION: this function only works for arguments in the range [-0.25; 0.25]! void my_sincosm1pi (double a, double *restrict res) { double s = a * a; /* Approximate cos(pi*x)-1 for x in [-0.25,0.25] */ double r = -1.0369917389758117e-4; r = fma (r, s, 1.9294935641298806e-3); r = fma (r, s, -2.5806887942825395e-2); r = fma (r, s, 2.3533063028328211e-1); r = fma (r, s, -1.3352627688538006e+0); r = fma (r, s, 4.0587121264167623e+0); r = fma (r, s, -4.9348022005446790e+0); double c = r*s; /* Approximate sin(pi*x) for x in [-0.25,0.25] */ r = 4.6151442520157035e-4; r = fma (r, s, -7.3700183130883555e-3); r = fma (r, s, 8.2145868949323936e-2); r = fma (r, s, -5.9926452893214921e-1); r = fma (r, s, 2.5501640398732688e+0); r = fma (r, s, -5.1677127800499516e+0); s = s * a; r = r * s; s = fma (a, 3.1415926535897931e+0, r); res[0] = c; res[1] = s; } NOINLINE void calc_first_octant(size_t den, double * restrict res) { size_t n = (den+4)>>3; if (n==0) return; res[0]=1.; res[1]=0.; if (n==1) return; size_t l1=(size_t)sqrt(n); for (size_t i=1; in) end = n-start; for (size_t i=1; i>2; size_t i=0, idx1=0, idx2=2*ndone-2; for (; i+1>1; double * p = res+n-1; calc_first_octant(n<<2, p); int i4=0, in=n, i=0; for (; i4<=in-i4; ++i, i4+=4) // octant 0 { res[2*i] = p[2*i4]; res[2*i+1] = p[2*i4+1]; } for (; i4-in <= 0; ++i, i4+=4) // octant 1 { int xm = in-i4; res[2*i] = p[2*xm+1]; res[2*i+1] = p[2*xm]; } for (; i4<=3*in-i4; ++i, i4+=4) // octant 2 { int xm = i4-in; res[2*i] = -p[2*xm+1]; res[2*i+1] = p[2*xm]; } for (; i>2; if ((n&7)==0) res[quart] = res[quart+1] = hsqt2; for (size_t i=2, j=2*quart-2; i>1; if ((n&3)==0) for (size_t i=0; i>=1; } size_t limit=size_t(sqrt(n+0.01)); for (size_t x=3; x<=limit; x+=2) while ((n/x)*x==n) { res=x; n/=x; limit=size_t(sqrt(n+0.01)); } if (n>1) res=n; return res; } NOINLINE double cost_guess (size_t n) { const double lfp=1.1; // penalty for non-hardcoded larger factors size_t ni=n; double result=0.; while ((n&1)==0) { result+=2; n>>=1; } size_t limit=size_t(sqrt(n+0.01)); for (size_t x=3; x<=limit; x+=2) while ((n/x)*x==n) { result+= (x<=5) ? x : lfp*x; // penalize larger prime factors n/=x; limit=size_t(sqrt(n+0.01)); } if (n>1) result+=(n<=5) ? n : lfp*n; return result*ni; } /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */ NOINLINE size_t good_size(size_t n) { if (n<=12) return n; size_t bestfac=2*n; for (size_t f2=1; f2=n) bestfac=f235711; return bestfac; } template struct cmplx { T r, i; cmplx() {} cmplx(T r_, T i_) : r(r_), i(i_) {} cmplx &operator+= (const cmplx &other) { r+=other.r; i+=other.i; return *this; } templatecmplx &operator*= (T2 other) { r*=other; i*=other; return *this; } cmplx operator+ (const cmplx &other) const { return cmplx(r+other.r, i+other.i); } cmplx operator- (const cmplx &other) const { return cmplx(r-other.r, i-other.i); } template auto operator* (const T2 &other) const -> cmplx { return {r*other, i*other}; } template auto operator* (const cmplx &other) const -> cmplx { return {r*other.r-i*other.i, r*other.i + i*other.r}; } template auto special_mul (const cmplx &other) const -> cmplx { if (bwd) return {r*other.r-i*other.i, r*other.i + i*other.r}; else return {r*other.r+i*other.i, i*other.r - r*other.i}; } }; template void PMC(cmplx &a, cmplx &b, const cmplx &c, const cmplx &d) { a = c+d; b = c-d; } template cmplx conj(const cmplx &a) { return {a.r, -a.i}; } template void ROT90(cmplx &a) { auto tmp_=a.r; a.r=-a.i; a.i=tmp_; } template void ROTM90(cmplx &a) { auto tmp_=-a.r; a.r=a.i; a.i=tmp_; } #define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))] #define CC(a,b,c) cc[(a)+ido*((b)+cdim*(c))] #define WA(x,i) wa[(i)-1+(x)*(ido-1)] constexpr size_t NFCT=25; template class cfftp { private: struct fctdata { size_t fct; cmplx *tw, *tws; }; size_t length, nfct; arr> mem; fctdata fct[NFCT]; void add_factor(size_t factor) { if (nfct>=NFCT) throw runtime_error("too many prime factors"); fct[nfct++].fct = factor; } template NOINLINE void pass2 (size_t ido, size_t l1, const T * restrict cc, T * restrict ch, const cmplx * restrict wa) { constexpr size_t cdim=2; if (ido==1) for (size_t k=0; k(WA(0,i)); } } } #define PREP3(idx) \ T t0 = CC(idx,0,k), t1, t2; \ PMC (t1,t2,CC(idx,1,k),CC(idx,2,k)); \ CH(idx,k,0)=t0+t1; #define PARTSTEP3a(u1,u2,twr,twi) \ { \ T ca,cb; \ ca=t0+t1*twr; \ cb=t2*twi; ROT90(cb); \ PMC(CH(0,k,u1),CH(0,k,u2),ca,cb) ;\ } #define PARTSTEP3b(u1,u2,twr,twi) \ { \ T ca,cb,da,db; \ ca=t0+t1*twr; \ cb=t2*twi; ROT90(cb); \ PMC(da,db,ca,cb); \ CH(i,k,u1) = da.template special_mul(WA(u1-1,i)); \ CH(i,k,u2) = db.template special_mul(WA(u2-1,i)); \ } template NOINLINE void pass3 (size_t ido, size_t l1, const T * restrict cc, T * restrict ch, const cmplx * restrict wa) { constexpr size_t cdim=3; constexpr T0 tw1r=-0.5, tw1i= (bwd ? 1.: -1.) * 0.86602540378443864676; if (ido==1) for (size_t k=0; k NOINLINE void pass4 (size_t ido, size_t l1, const T * restrict cc, T * restrict ch, const cmplx * restrict wa) { constexpr size_t cdim=4; if (ido==1) for (size_t k=0; k wa0=WA(0,i), wa1=WA(1,i),wa2=WA(2,i); PMC(CH(i,k,0),c3,t2,t3); PMC(c2,c4,t1,t4); CH(i,k,1) = c2.template special_mul(wa0); CH(i,k,2) = c3.template special_mul(wa1); CH(i,k,3) = c4.template special_mul(wa2); } } } #define PREP5(idx) \ T t0 = CC(idx,0,k), t1, t2, t3, t4; \ PMC (t1,t4,CC(idx,1,k),CC(idx,4,k)); \ PMC (t2,t3,CC(idx,2,k),CC(idx,3,k)); \ CH(idx,k,0).r=t0.r+t1.r+t2.r; \ CH(idx,k,0).i=t0.i+t1.i+t2.i; #define PARTSTEP5a(u1,u2,twar,twbr,twai,twbi) \ { \ T ca,cb; \ ca.r=t0.r+twar*t1.r+twbr*t2.r; \ ca.i=t0.i+twar*t1.i+twbr*t2.i; \ cb.i=twai*t4.r twbi*t3.r; \ cb.r=-(twai*t4.i twbi*t3.i); \ PMC(CH(0,k,u1),CH(0,k,u2),ca,cb); \ } #define PARTSTEP5b(u1,u2,twar,twbr,twai,twbi) \ { \ T ca,cb,da,db; \ ca.r=t0.r+twar*t1.r+twbr*t2.r; \ ca.i=t0.i+twar*t1.i+twbr*t2.i; \ cb.i=twai*t4.r twbi*t3.r; \ cb.r=-(twai*t4.i twbi*t3.i); \ PMC(da,db,ca,cb); \ CH(i,k,u1) = da.template special_mul(WA(u1-1,i)); \ CH(i,k,u2) = db.template special_mul(WA(u2-1,i)); \ } template NOINLINE void pass5 (size_t ido, size_t l1, const T * restrict cc, T * restrict ch, const cmplx * restrict wa) { constexpr size_t cdim=5; constexpr T0 tw1r= 0.3090169943749474241, tw1i= (bwd ? 1.: -1.) * 0.95105651629515357212, tw2r= -0.8090169943749474241, tw2i= (bwd ? 1.: -1.) * 0.58778525229247312917; if (ido==1) for (size_t k=0; k(WA(u1-1,i)); \ CH(i,k,u2) = db.template special_mul(WA(u2-1,i)); \ } template NOINLINE void pass7(size_t ido, size_t l1, const T * restrict cc, T * restrict ch, const cmplx * restrict wa) { constexpr size_t cdim=7; constexpr T0 tw1r= 0.623489801858733530525, tw1i= (bwd ? 1. : -1.) * 0.7818314824680298087084, tw2r= -0.222520933956314404289, tw2i= (bwd ? 1. : -1.) * 0.9749279121818236070181, tw3r= -0.9009688679024191262361, tw3i= (bwd ? 1. : -1.) * 0.4338837391175581204758; if (ido==1) for (size_t k=0; k(WA(u1-1,i)); \ CH(i,k,u2) = db.template special_mul(WA(u2-1,i)); \ } template NOINLINE void pass11 (size_t ido, size_t l1, const T * restrict cc, T * restrict ch, const cmplx * restrict wa) { const size_t cdim=11; const T0 tw1r = 0.8412535328311811688618, tw1i = (bwd ? 1. : -1.) * 0.5406408174555975821076, tw2r = 0.4154150130018864255293, tw2i = (bwd ? 1. : -1.) * 0.9096319953545183714117, tw3r = -0.1423148382732851404438, tw3i = (bwd ? 1. : -1.) * 0.9898214418809327323761, tw4r = -0.6548607339452850640569, tw4i = (bwd ? 1. : -1.) * 0.755749574354258283774, tw5r = -0.9594929736144973898904, tw5i = (bwd ? 1. : -1.) * 0.2817325568414296977114; if (ido==1) for (size_t k=0; k NOINLINE void passg (size_t ido, size_t ip, size_t l1, T * restrict cc, T * restrict ch, const cmplx * restrict wa, const cmplx * restrict csarr) { const size_t cdim=ip; size_t ipph = (ip+1)/2; size_t idl1 = ido*l1; arr> wal(ip); wal[0] = cmplx(1., 0.); for (size_t i=1; i(csarr[i].r,bwd ? csarr[i].i : -csarr[i].i); for (size_t k=0; kip) iwal-=ip; cmplx xwal=wal[iwal]; iwal+=l; if (iwal>ip) iwal-=ip; cmplx xwal2=wal[iwal]; for (size_t ik=0; ikip) iwal-=ip; cmplx xwal=wal[iwal]; for (size_t ik=0; ik(wa[idij]); idij=(jc-1)*(ido-1)+i-1; CX(i,k,jc) = x2.template special_mul(wa[idij]); } } } } #undef CH2 #undef CX2 #undef CX template NOINLINE void pass_all(T c[], T0 fact) { if (length==1) return; size_t l1=1; arr ch(length); T *p1=c, *p2=ch.data(); for(size_t k1=0; k1 (ido, l1, p1, p2, fct[k1].tw); else if(ip==2) pass2(ido, l1, p1, p2, fct[k1].tw); else if(ip==3) pass3 (ido, l1, p1, p2, fct[k1].tw); else if(ip==5) pass5 (ido, l1, p1, p2, fct[k1].tw); else if(ip==7) pass7 (ido, l1, p1, p2, fct[k1].tw); else if(ip==11) pass11 (ido, l1, p1, p2, fct[k1].tw); else { passg(ido, ip, l1, p1, p2, fct[k1].tw, fct[k1].tws); swap(p1,p2); } swap(p1,p2); l1=l2; } if (p1!=c) { if (fact!=1.) for (size_t i=0; i NOINLINE void forward(T c[], T0 fct) { pass_all(c, fct); } template NOINLINE void backward(T c[], T0 fct) { pass_all(c, fct); } private: NOINLINE void factorize() { nfct=0; size_t len=length; while ((len&3)==0) { add_factor(4); len>>=2; } if ((len&1)==0) { len>>=1; // factor 2 should be at the front of the factor list add_factor(2); swap(fct[0].fct, fct[nfct-1].fct); } size_t maxl=(size_t)(sqrt((double)len))+1; for (size_t divisor=3; (len>1)&&(divisor1) add_factor(len); } NOINLINE size_t twsize() const { size_t twsize=0, l1=1; for (size_t k=0; k11) twsize+=ip; l1*=ip; } return twsize; } NOINLINE void comp_twiddle() { sincos_2pibyn twid(length); size_t l1=1; size_t memofs=0; for (size_t k=0; k11) { fct[k].tws=mem.data()+memofs; memofs+=ip; for (size_t j=0; j class fftblue { private: size_t n, n2; cfftp plan; arr mem; T0 *bk, *bkf; template NOINLINE void fft(T c[], T0 fct) { constexpr int isign = bwd ? 1 : -1; arr akf(2*n2); /* initialize a_k and FFT it */ for (size_t m=0; m<2*n; m+=2) { akf[m] = c[m]*bk[m] -isign*c[m+1]*bk[m+1]; akf[m+1] = isign*c[m]*bk[m+1] + c[m+1]*bk[m]; } for (size_t m=2*n; m<2*n2; ++m) akf[m]=0.*c[0]; plan.forward ((cmplx *)akf.data(),1.); /* do the convolution */ for (size_t m=0; m<2*n2; m+=2) { T im = -isign*akf[m]*bkf[m+1] + akf[m+1]*bkf[m]; akf[m ] = akf[m]*bkf[m] + isign*akf[m+1]*bkf[m+1]; akf[m+1] = im; } /* inverse FFT */ plan.backward ((cmplx *)akf.data(),1.); /* multiply by b_k */ for (size_t m=0; m<2*n; m+=2) { c[m] = fct*(bk[m] *akf[m] - isign*bk[m+1]*akf[m+1]); c[m+1] = fct*(isign*bk[m+1]*akf[m] + bk[m] *akf[m+1]); } } public: NOINLINE fftblue(size_t length) : n(length), n2(good_size(n*2-1)), plan(n2), mem(2*(n+n2)), bk(mem.data()), bkf(mem.data()+2*n) { /* initialize b_k */ sincos_2pibyn tmp(2*n); bk[0] = 1; bk[1] = 0; size_t coeff=0; for (size_t m=1; m=2*n) coeff-=2*n; bk[2*m ] = tmp[2*coeff ]; bk[2*m+1] = tmp[2*coeff+1]; } /* initialize the zero-padded, Fourier transformed b_k. Add normalisation. */ T0 xn2 = 1./n2; bkf[0] = bk[0]*xn2; bkf[1] = bk[1]*xn2; for (size_t m=2; m<2*n; m+=2) { bkf[m] = bkf[2*n2-m] = bk[m] *xn2; bkf[m+1] = bkf[2*n2-m+1] = bk[m+1] *xn2; } for (size_t m=2*n;m<=(2*n2-2*n+1);++m) bkf[m]=0.; plan.forward((cmplx *)bkf,1.); } template void backward(T c[], T0 fct) { fft(c,fct); } template void forward(T c[], T0 fct) { fft(c,fct); } }; template class pocketfft_c { private: unique_ptr> packplan; unique_ptr> blueplan; size_t len; public: NOINLINE pocketfft_c(size_t length) : packplan(nullptr), blueplan(nullptr), len(length) { if (length==0) throw runtime_error("zero-length FFT requested"); if ((length<50) || (largest_prime_factor(length)<=sqrt(length))) { packplan=unique_ptr>(new cfftp(length)); return; } double comp1 = cost_guess(length); double comp2 = 2*cost_guess(good_size(2*length-1)); comp2*=1.5; /* fudge factor that appears to give good overall performance */ if (comp2>(new fftblue(length)); else packplan=unique_ptr>(new cfftp(length)); } template void backward(T c[], T0 fct) { packplan ? packplan->backward((cmplx *)c,fct) : blueplan->backward(c,fct); } template void forward(T c[], T0 fct) { packplan ? packplan->forward((cmplx *)c,fct) : blueplan->forward(c,fct); } size_t length() const { return len; } }; struct diminfo { size_t n, s; }; class multiarr { private: vector dim; public: multiarr (size_t ndim_, const size_t *n, const size_t *s) { dim.reserve(ndim_); for (size_t i=0; i dim; vector pos; size_t ofs_, len, str; bool done_; public: multi_iter(const multiarr &arr, size_t idim) : pos(arr.ndim()-1, 0), ofs_(0), len(arr.size(idim)), str(arr.stride(idim)), done_(false) { dim.reserve(arr.ndim()-1); for (size_t i=0; i=0; --i) { ++pos[i]; ofs_ += dim[i].s; if (pos[i] < dim[i].n) return; pos[i] = 0; ofs_ -= dim[i].n*dim[i].s; } done_ = true; } bool done() const { return done_; } size_t offset() const { return ofs_; } size_t length() const { return len; } size_t stride() const { return str; } }; template void pocketfft_general_c(int ndim, const size_t *shape, const size_t *stride_in, const size_t *stride_out, int nax, const size_t *axes, bool forward, const cmplx *data_in, cmplx *data_out, T fct) { // allocate temporary 1D array storage, if necessary size_t tmpsize = 0; for (int iax=0; iax0) || (data_in==data_out)) inplace = true; if (!inplace) if (shape[axes[iax]] > tmpsize) tmpsize = shape[axes[iax]]; } arr> tdata(tmpsize); multiarr a_in(ndim, shape, stride_in), a_out(ndim, shape, stride_out); unique_ptr> plan; for (int iax=0; iaxlength())) plan.reset(new pocketfft_c(it_in.length())); while (!it_in.done()) { if (inplace) { cmplx *d = data_out+it_in.offset(); forward ? plan->forward((T *)d, fct) : plan->backward((T *)d, fct); } else { for (size_t i=0; iforward((T *)tdata.data(), fct) : plan->backward((T *)tdata.data(), fct); for (size_t i=0; i #include namespace py = pybind11; namespace { py::array execute(const py::array &in, bool fwd) { py::array res; vector dims(in.ndim()), s_i(in.ndim()), s_o(in.ndim()), axes(in.ndim()); for (int i=0; i>(dims); else if (in.dtype().is(py::dtype("complex64"))) res = py::array_t>(dims); else throw runtime_error("unsupported data type"); for (int i=0; i(in.ndim(), dims.data(), s_i.data(), s_o.data(), in.ndim(), axes.data(), fwd, (const cmplx *)in.data(), (cmplx *)res.mutable_data(), 1.); else pocketfft_general_c(in.ndim(), dims.data(), s_i.data(), s_o.data(), in.ndim(), axes.data(), fwd, (const cmplx *)in.data(), (cmplx *)res.mutable_data(), 1.); return res; } py::array fftn(const py::array &in) { return execute(in, true); } py::array ifftn(const py::array &in) { return execute(in, false); } } // unnamed namespace PYBIND11_MODULE(pypocketfft, m) { m.def("fftn",&fftn); m.def("ifftn",&ifftn); } #else int main() { int sz=100; vector> data(sz), data2(sz); size_t shape[] = {sz}; size_t stride[] = {1}; size_t axes[] = {0}; pocketfft_general_c(1, shape, stride, stride, 1, axes, true, data.data(), data2.data(), 1.); } #endif