From 07c6a8e7f8a8e710da554f7dff368c3df7d1a722 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Tue, 16 Apr 2019 19:11:56 +0200 Subject: [PATCH] add real FFTs --- pocketfft.cc | 947 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 941 insertions(+), 6 deletions(-) diff --git a/pocketfft.cc b/pocketfft.cc index 65c7050..9a9f642 100644 --- a/pocketfft.cc +++ b/pocketfft.cc @@ -227,11 +227,11 @@ class sincos_2pibyn } public: - sincos_2pibyn(size_t n) - : data(2*n) + sincos_2pibyn(size_t n, bool half) + : data(half ? n : 2*n) { sincos_2pibyn_half(n, data.data()); - fill_second_half(n, data.data()); + if (!half) fill_second_half(n, data.data()); } double operator[](size_t idx) const { return data[idx]; } @@ -883,7 +883,7 @@ template<bool bwd, typename T> NOINLINE void pass_all(T c[], T0 fact) NOINLINE void comp_twiddle() { - sincos_2pibyn twid(length); + sincos_2pibyn twid(length, false); size_t l1=1; size_t memofs=0; for (size_t k=0; k<nfct; ++k) @@ -923,6 +923,806 @@ template<bool bwd, typename T> NOINLINE void pass_all(T c[], T0 fact) } }; + +template<typename T0> class rfftp + { + private: + struct fctdata + { + size_t fct; + T0 *tw, *tws; + }; + + size_t length, nfct; + arr<T0> mem; + fctdata fct[NFCT]; + + void add_factor(size_t factor) + { + if (nfct>=NFCT) throw runtime_error("too many prime factors"); + fct[nfct++].fct = factor; + } + +#define WA(x,i) wa[(i)+(x)*(ido-1)] +#define PM(a,b,c,d) { a=c+d; b=c-d; } +/* (a+ib) = conj(c+id) * (e+if) */ +#define MULPM(a,b,c,d,e,f) { a=c*e+d*f; b=c*f-d*e; } + +#define CC(a,b,c) cc[(a)+ido*((b)+l1*(c))] +#define CH(a,b,c) ch[(a)+ido*((b)+cdim*(c))] + +template<typename T> NOINLINE void radf2 (size_t ido, size_t l1, const T * restrict cc, + T * restrict ch, const T0 * restrict wa) + { + constexpr size_t cdim=2; + + for (size_t k=0; k<l1; k++) + PM (CH(0,0,k),CH(ido-1,1,k),CC(0,k,0),CC(0,k,1)) + if ((ido&1)==0) + for (size_t k=0; k<l1; k++) + { + CH( 0,1,k) = -CC(ido-1,k,1); + CH(ido-1,0,k) = CC(ido-1,k,0); + } + if (ido<=2) return; + for (size_t k=0; k<l1; k++) + for (size_t i=2; i<ido; i+=2) + { + size_t ic=ido-i; + T tr2, ti2; + MULPM (tr2,ti2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1)) + PM (CH(i-1,0,k),CH(ic-1,1,k),CC(i-1,k,0),tr2) + PM (CH(i ,0,k),CH(ic ,1,k),ti2,CC(i ,k,0)) + } + } + +template<typename T> NOINLINE void radf3(size_t ido, size_t l1, const T * restrict cc, + T * restrict ch, const T0 * restrict wa) + { + constexpr size_t cdim=3; + constexpr T0 taur=-0.5, taui=0.86602540378443864676; + + for (size_t k=0; k<l1; k++) + { + T cr2=CC(0,k,1)+CC(0,k,2); + CH(0,0,k) = CC(0,k,0)+cr2; + CH(0,2,k) = taui*(CC(0,k,2)-CC(0,k,1)); + CH(ido-1,1,k) = CC(0,k,0)+taur*cr2; + } + if (ido==1) return; + for (size_t k=0; k<l1; k++) + for (size_t i=2; i<ido; i+=2) + { + size_t ic=ido-i; + T di2, di3, dr2, dr3; + MULPM (dr2,di2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1)) // d2=conj(WA0)*CC1 + MULPM (dr3,di3,WA(1,i-2),WA(1,i-1),CC(i-1,k,2),CC(i,k,2)) // d3=conj(WA1)*CC2 + T cr2=dr2+dr3; // c add + T ci2=di2+di3; + CH(i-1,0,k) = CC(i-1,k,0)+cr2; // c add + CH(i ,0,k) = CC(i ,k,0)+ci2; + T tr2 = CC(i-1,k,0)+taur*cr2; // c add + T ti2 = CC(i ,k,0)+taur*ci2; + T tr3 = taui*(di2-di3); // t3 = taui*i*(d3-d2)? + T ti3 = taui*(dr3-dr2); + PM(CH(i-1,2,k),CH(ic-1,1,k),tr2,tr3) // PM(i) = t2+t3 + PM(CH(i ,2,k),CH(ic ,1,k),ti3,ti2) // PM(ic) = conj(t2-t3) + } + } + +template<typename T> NOINLINE void radf4(size_t ido, size_t l1, const T * restrict cc, + T * restrict ch, const T0 * restrict wa) + { + constexpr size_t cdim=4; + constexpr T0 hsqt2=0.70710678118654752440; + + for (size_t k=0; k<l1; k++) + { + T tr1,tr2; + PM (tr1,CH(0,2,k),CC(0,k,3),CC(0,k,1)) + PM (tr2,CH(ido-1,1,k),CC(0,k,0),CC(0,k,2)) + PM (CH(0,0,k),CH(ido-1,3,k),tr2,tr1) + } + if ((ido&1)==0) + for (size_t k=0; k<l1; k++) + { + T ti1=-hsqt2*(CC(ido-1,k,1)+CC(ido-1,k,3)); + T tr1= hsqt2*(CC(ido-1,k,1)-CC(ido-1,k,3)); + PM (CH(ido-1,0,k),CH(ido-1,2,k),CC(ido-1,k,0),tr1) + PM (CH( 0,3,k),CH( 0,1,k),ti1,CC(ido-1,k,2)) + } + if (ido<=2) return; + for (size_t k=0; k<l1; k++) + for (size_t i=2; i<ido; i+=2) + { + size_t ic=ido-i; + T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4; + MULPM(cr2,ci2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1)) + MULPM(cr3,ci3,WA(1,i-2),WA(1,i-1),CC(i-1,k,2),CC(i,k,2)) + MULPM(cr4,ci4,WA(2,i-2),WA(2,i-1),CC(i-1,k,3),CC(i,k,3)) + PM(tr1,tr4,cr4,cr2) + PM(ti1,ti4,ci2,ci4) + PM(tr2,tr3,CC(i-1,k,0),cr3) + PM(ti2,ti3,CC(i ,k,0),ci3) + PM(CH(i-1,0,k),CH(ic-1,3,k),tr2,tr1) + PM(CH(i ,0,k),CH(ic ,3,k),ti1,ti2) + PM(CH(i-1,2,k),CH(ic-1,1,k),tr3,ti4) + PM(CH(i ,2,k),CH(ic ,1,k),tr4,ti3) + } + } + +template<typename T> NOINLINE void radf5(size_t ido, size_t l1, const T * restrict cc, + T * restrict ch, const T0 * restrict wa) + { + constexpr size_t cdim=5; + constexpr T0 tr11= 0.3090169943749474241, ti11=0.95105651629515357212, + tr12=-0.8090169943749474241, ti12=0.58778525229247312917; + + for (size_t k=0; k<l1; k++) + { + T cr2, cr3, ci4, ci5; + PM (cr2,ci5,CC(0,k,4),CC(0,k,1)) + PM (cr3,ci4,CC(0,k,3),CC(0,k,2)) + CH(0,0,k)=CC(0,k,0)+cr2+cr3; + CH(ido-1,1,k)=CC(0,k,0)+tr11*cr2+tr12*cr3; + CH(0,2,k)=ti11*ci5+ti12*ci4; + CH(ido-1,3,k)=CC(0,k,0)+tr12*cr2+tr11*cr3; + CH(0,4,k)=ti12*ci5-ti11*ci4; + } + if (ido==1) return; + for (size_t k=0; k<l1;++k) + for (size_t i=2; i<ido; i+=2) + { + T ci2, di2, ci4, ci5, di3, di4, di5, ci3, cr2, cr3, dr2, dr3, + dr4, dr5, cr5, cr4, ti2, ti3, ti5, ti4, tr2, tr3, tr4, tr5; + size_t ic=ido-i; + MULPM (dr2,di2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1)) + MULPM (dr3,di3,WA(1,i-2),WA(1,i-1),CC(i-1,k,2),CC(i,k,2)) + MULPM (dr4,di4,WA(2,i-2),WA(2,i-1),CC(i-1,k,3),CC(i,k,3)) + MULPM (dr5,di5,WA(3,i-2),WA(3,i-1),CC(i-1,k,4),CC(i,k,4)) + PM(cr2,ci5,dr5,dr2) + PM(ci2,cr5,di2,di5) + PM(cr3,ci4,dr4,dr3) + PM(ci3,cr4,di3,di4) + CH(i-1,0,k)=CC(i-1,k,0)+cr2+cr3; + CH(i ,0,k)=CC(i ,k,0)+ci2+ci3; + tr2=CC(i-1,k,0)+tr11*cr2+tr12*cr3; + ti2=CC(i ,k,0)+tr11*ci2+tr12*ci3; + tr3=CC(i-1,k,0)+tr12*cr2+tr11*cr3; + ti3=CC(i ,k,0)+tr12*ci2+tr11*ci3; + MULPM(tr5,tr4,cr5,cr4,ti11,ti12) + MULPM(ti5,ti4,ci5,ci4,ti11,ti12) + PM(CH(i-1,2,k),CH(ic-1,1,k),tr2,tr5) + PM(CH(i ,2,k),CH(ic ,1,k),ti5,ti2) + PM(CH(i-1,4,k),CH(ic-1,3,k),tr3,tr4) + PM(CH(i ,4,k),CH(ic ,3,k),ti4,ti3) + } + } + +#undef CC +#undef CH +#define C1(a,b,c) cc[(a)+ido*((b)+l1*(c))] +#define C2(a,b) cc[(a)+idl1*(b)] +#define CH2(a,b) ch[(a)+idl1*(b)] +#define CC(a,b,c) cc[(a)+ido*((b)+cdim*(c))] +#define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))] +template<typename T> NOINLINE void radfg(size_t ido, size_t ip, size_t l1, + T0 * restrict cc, T0 * restrict ch, const T * restrict wa, + const T * restrict csarr) + { + const size_t cdim=ip; + size_t ipph=(ip+1)/2; + size_t idl1 = ido*l1; + + if (ido>1) + { + for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc) // 114 + { + size_t is=(j-1)*(ido-1), + is2=(jc-1)*(ido-1); + for (size_t k=0; k<l1; ++k) // 113 + { + size_t idij=is; + size_t idij2=is2; + for (size_t i=1; i<=ido-2; i+=2) // 112 + { + T t1=C1(i,k,j ), t2=C1(i+1,k,j ), + t3=C1(i,k,jc), t4=C1(i+1,k,jc); + T x1=wa[idij]*t1 + wa[idij+1]*t2, + x2=wa[idij]*t2 - wa[idij+1]*t1, + x3=wa[idij2]*t3 + wa[idij2+1]*t4, + x4=wa[idij2]*t4 - wa[idij2+1]*t3; + C1(i ,k,j ) = x1+x3; + C1(i ,k,jc) = x2-x4; + C1(i+1,k,j ) = x2+x4; + C1(i+1,k,jc) = x3-x1; + idij+=2; + idij2+=2; + } + } + } + } + + for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc) // 123 + for (size_t k=0; k<l1; ++k) // 122 + { + T t1=C1(0,k,j), t2=C1(0,k,jc); + C1(0,k,j ) = t1+t2; + C1(0,k,jc) = t2-t1; + } + +//everything in C +//memset(ch,0,ip*l1*ido*sizeof(double)); + + for (size_t l=1,lc=ip-1; l<ipph; ++l,--lc) // 127 + { + for (size_t ik=0; ik<idl1; ++ik) // 124 + { + CH2(ik,l ) = C2(ik,0)+csarr[2*l]*C2(ik,1)+csarr[4*l]*C2(ik,2); + CH2(ik,lc) = csarr[2*l+1]*C2(ik,ip-1)+csarr[4*l+1]*C2(ik,ip-2); + } + size_t iang = 2*l; + size_t j=3, jc=ip-3; + for (; j<ipph-3; j+=4,jc-=4) // 126 + { + iang+=l; if (iang>=ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; + for (size_t ik=0; ik<idl1; ++ik) // 125 + { + CH2(ik,l ) += ar1*C2(ik,j )+ar2*C2(ik,j +1) + +ar3*C2(ik,j +2)+ar4*C2(ik,j +3); + CH2(ik,lc) += ai1*C2(ik,jc)+ai2*C2(ik,jc-1) + +ai3*C2(ik,jc-2)+ai4*C2(ik,jc-3); + } + } + for (; j<ipph-1; j+=2,jc-=2) // 126 + { + iang+=l; if (iang>=ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + for (size_t ik=0; ik<idl1; ++ik) // 125 + { + CH2(ik,l ) += ar1*C2(ik,j )+ar2*C2(ik,j +1); + CH2(ik,lc) += ai1*C2(ik,jc)+ai2*C2(ik,jc-1); + } + } + for (; j<ipph; ++j,--jc) // 126 + { + iang+=l; if (iang>=ip) iang-=ip; + T0 ar=csarr[2*iang], ai=csarr[2*iang+1]; + for (size_t ik=0; ik<idl1; ++ik) // 125 + { + CH2(ik,l ) += ar*C2(ik,j ); + CH2(ik,lc) += ai*C2(ik,jc); + } + } + } + for (size_t ik=0; ik<idl1; ++ik) // 101 + CH2(ik,0) = C2(ik,0); + for (size_t j=1; j<ipph; ++j) // 129 + for (size_t ik=0; ik<idl1; ++ik) // 128 + CH2(ik,0) += C2(ik,j); + +// everything in CH at this point! +//memset(cc,0,ip*l1*ido*sizeof(double)); + + for (size_t k=0; k<l1; ++k) // 131 + for (size_t i=0; i<ido; ++i) // 130 + CC(i,0,k) = CH(i,k,0); + + for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc) // 137 + { + size_t j2=2*j-1; + for (size_t k=0; k<l1; ++k) // 136 + { + CC(ido-1,j2,k) = CH(0,k,j); + CC(0,j2+1,k) = CH(0,k,jc); + } + } + + if (ido==1) return; + + for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc) // 140 + { + size_t j2=2*j-1; + for(size_t k=0; k<l1; ++k) // 139 + for(size_t i=1, ic=ido-i-2; i<=ido-2; i+=2, ic-=2) // 138 + { + CC(i ,j2+1,k) = CH(i ,k,j )+CH(i ,k,jc); + CC(ic ,j2 ,k) = CH(i ,k,j )-CH(i ,k,jc); + CC(i+1 ,j2+1,k) = CH(i+1,k,j )+CH(i+1,k,jc); + CC(ic+1,j2 ,k) = CH(i+1,k,jc)-CH(i+1,k,j ); + } + } + } +#undef C1 +#undef C2 +#undef CH2 + +#undef CH +#undef CC +#define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))] +#define CC(a,b,c) cc[(a)+ido*((b)+cdim*(c))] + +template<typename T> NOINLINE void radb2(size_t ido, size_t l1, const T * restrict cc, + T * restrict ch, const T0 * restrict wa) + { + constexpr size_t cdim=2; + + for (size_t k=0; k<l1; k++) + PM (CH(0,k,0),CH(0,k,1),CC(0,0,k),CC(ido-1,1,k)) + if ((ido&1)==0) + for (size_t k=0; k<l1; k++) + { + CH(ido-1,k,0) = 2.*CC(ido-1,0,k); + CH(ido-1,k,1) =-2.*CC(0 ,1,k); + } + if (ido<=2) return; + for (size_t k=0; k<l1;++k) + for (size_t i=2; i<ido; i+=2) + { + size_t ic=ido-i; + T ti2, tr2; + PM (CH(i-1,k,0),tr2,CC(i-1,0,k),CC(ic-1,1,k)) + PM (ti2,CH(i ,k,0),CC(i ,0,k),CC(ic ,1,k)) + MULPM (CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),ti2,tr2) + } + } + +template<typename T> NOINLINE void radb3(size_t ido, size_t l1, const T * restrict cc, + T * restrict ch, const T0 * restrict wa) + { + constexpr size_t cdim=3; + constexpr T0 taur=-0.5, taui=0.86602540378443864676; + + for (size_t k=0; k<l1; k++) + { + T tr2=2.*CC(ido-1,1,k); + T cr2=CC(0,0,k)+taur*tr2; + CH(0,k,0)=CC(0,0,k)+tr2; + T ci3=2.*taui*CC(0,2,k); + PM (CH(0,k,2),CH(0,k,1),cr2,ci3); + } + if (ido==1) return; + for (size_t k=0; k<l1; k++) + for (size_t i=2; i<ido; i+=2) + { + size_t ic=ido-i; + T tr2=CC(i-1,2,k)+CC(ic-1,1,k); // t2=CC(I) + conj(CC(ic)) + T ti2=CC(i ,2,k)-CC(ic ,1,k); + T cr2=CC(i-1,0,k)+taur*tr2; // c2=CC +taur*t2 + T ci2=CC(i ,0,k)+taur*ti2; + CH(i-1,k,0)=CC(i-1,0,k)+tr2; // CH=CC+t2 + CH(i ,k,0)=CC(i ,0,k)+ti2; + T cr3=taui*(CC(i-1,2,k)-CC(ic-1,1,k));// c3=taui*(CC(i)-conj(CC(ic))) + T ci3=taui*(CC(i ,2,k)+CC(ic ,1,k)); + T di2, di3, dr2, dr3; + PM(dr3,dr2,cr2,ci3) // d2= (cr2-ci3, ci2+cr3) = c2+i*c3 + PM(di2,di3,ci2,cr3) // d3= (cr2+ci3, ci2-cr3) = c2-i*c3 + MULPM(CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),di2,dr2) // ch = WA*d2 + MULPM(CH(i,k,2),CH(i-1,k,2),WA(1,i-2),WA(1,i-1),di3,dr3) + } + } + +template<typename T> NOINLINE void radb4(size_t ido, size_t l1, const T * restrict cc, + T * restrict ch, const T0 * restrict wa) + { + constexpr size_t cdim=4; + constexpr T0 sqrt2=1.41421356237309504880; + + for (size_t k=0; k<l1; k++) + { + T tr1, tr2; + PM (tr2,tr1,CC(0,0,k),CC(ido-1,3,k)) + T tr3=2.*CC(ido-1,1,k); + T tr4=2.*CC(0,2,k); + PM (CH(0,k,0),CH(0,k,2),tr2,tr3) + PM (CH(0,k,3),CH(0,k,1),tr1,tr4) + } + if ((ido&1)==0) + for (size_t k=0; k<l1; k++) + { + T tr1,tr2,ti1,ti2; + PM (ti1,ti2,CC(0 ,3,k),CC(0 ,1,k)) + PM (tr2,tr1,CC(ido-1,0,k),CC(ido-1,2,k)) + CH(ido-1,k,0)=tr2+tr2; + CH(ido-1,k,1)=sqrt2*(tr1-ti1); + CH(ido-1,k,2)=ti2+ti2; + CH(ido-1,k,3)=-sqrt2*(tr1+ti1); + } + if (ido<=2) return; + for (size_t k=0; k<l1;++k) + for (size_t i=2; i<ido; i+=2) + { + T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4; + size_t ic=ido-i; + PM (tr2,tr1,CC(i-1,0,k),CC(ic-1,3,k)) + PM (ti1,ti2,CC(i ,0,k),CC(ic ,3,k)) + PM (tr4,ti3,CC(i ,2,k),CC(ic ,1,k)) + PM (tr3,ti4,CC(i-1,2,k),CC(ic-1,1,k)) + PM (CH(i-1,k,0),cr3,tr2,tr3) + PM (CH(i ,k,0),ci3,ti2,ti3) + PM (cr4,cr2,tr1,tr4) + PM (ci2,ci4,ti1,ti4) + MULPM (CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),ci2,cr2) + MULPM (CH(i,k,2),CH(i-1,k,2),WA(1,i-2),WA(1,i-1),ci3,cr3) + MULPM (CH(i,k,3),CH(i-1,k,3),WA(2,i-2),WA(2,i-1),ci4,cr4) + } + } + +template<typename T>NOINLINE void radb5(size_t ido, size_t l1, const T * restrict cc, + T * restrict ch, const T0 * restrict wa) + { + constexpr size_t cdim=5; + constexpr T0 tr11= 0.3090169943749474241, ti11=0.95105651629515357212, + tr12=-0.8090169943749474241, ti12=0.58778525229247312917; + + for (size_t k=0; k<l1; k++) + { + T ti5=CC(0,2,k)+CC(0,2,k); + T ti4=CC(0,4,k)+CC(0,4,k); + T tr2=CC(ido-1,1,k)+CC(ido-1,1,k); + T tr3=CC(ido-1,3,k)+CC(ido-1,3,k); + CH(0,k,0)=CC(0,0,k)+tr2+tr3; + T cr2=CC(0,0,k)+tr11*tr2+tr12*tr3; + T cr3=CC(0,0,k)+tr12*tr2+tr11*tr3; + T ci4, ci5; + MULPM(ci5,ci4,ti5,ti4,ti11,ti12) + PM(CH(0,k,4),CH(0,k,1),cr2,ci5) + PM(CH(0,k,3),CH(0,k,2),cr3,ci4) + } + if (ido==1) return; + for (size_t k=0; k<l1;++k) + for (size_t i=2; i<ido; i+=2) + { + size_t ic=ido-i; + T tr2, tr3, tr4, tr5, ti2, ti3, ti4, ti5; + PM(tr2,tr5,CC(i-1,2,k),CC(ic-1,1,k)) + PM(ti5,ti2,CC(i ,2,k),CC(ic ,1,k)) + PM(tr3,tr4,CC(i-1,4,k),CC(ic-1,3,k)) + PM(ti4,ti3,CC(i ,4,k),CC(ic ,3,k)) + CH(i-1,k,0)=CC(i-1,0,k)+tr2+tr3; + CH(i ,k,0)=CC(i ,0,k)+ti2+ti3; + T cr2=CC(i-1,0,k)+tr11*tr2+tr12*tr3; + T ci2=CC(i ,0,k)+tr11*ti2+tr12*ti3; + T cr3=CC(i-1,0,k)+tr12*tr2+tr11*tr3; + T ci3=CC(i ,0,k)+tr12*ti2+tr11*ti3; + T ci4, ci5, cr5, cr4; + MULPM(cr5,cr4,tr5,tr4,ti11,ti12) + MULPM(ci5,ci4,ti5,ti4,ti11,ti12) + T dr2, dr3, dr4, dr5, di2, di3, di4, di5; + PM(dr4,dr3,cr3,ci4) + PM(di3,di4,ci3,cr4) + PM(dr5,dr2,cr2,ci5) + PM(di2,di5,ci2,cr5) + MULPM(CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),di2,dr2) + MULPM(CH(i,k,2),CH(i-1,k,2),WA(1,i-2),WA(1,i-1),di3,dr3) + MULPM(CH(i,k,3),CH(i-1,k,3),WA(2,i-2),WA(2,i-1),di4,dr4) + MULPM(CH(i,k,4),CH(i-1,k,4),WA(3,i-2),WA(3,i-1),di5,dr5) + } + } + +#undef CC +#undef CH +#define CC(a,b,c) cc[(a)+ido*((b)+cdim*(c))] +#define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))] +#define C1(a,b,c) cc[(a)+ido*((b)+l1*(c))] +#define C2(a,b) cc[(a)+idl1*(b)] +#define CH2(a,b) ch[(a)+idl1*(b)] + +template<typename T> NOINLINE void radbg(size_t ido, size_t ip, size_t l1, + T * restrict cc, T * restrict ch, const T0 * restrict wa, + const T0 * restrict csarr) + { + const size_t cdim=ip; + size_t ipph=(ip+1)/ 2; + size_t idl1 = ido*l1; + + for (size_t k=0; k<l1; ++k) // 102 + for (size_t i=0; i<ido; ++i) // 101 + CH(i,k,0) = CC(i,0,k); + for (size_t j=1, jc=ip-1; j<ipph; ++j, --jc) // 108 + { + size_t j2=2*j-1; + for (size_t k=0; k<l1; ++k) + { + CH(0,k,j ) = 2*CC(ido-1,j2,k); + CH(0,k,jc) = 2*CC(0,j2+1,k); + } + } + + if (ido!=1) + { + for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc) // 111 + { + size_t j2=2*j-1; + for (size_t k=0; k<l1; ++k) + for (size_t i=1, ic=ido-i-2; i<=ido-2; i+=2, ic-=2) // 109 + { + CH(i ,k,j ) = CC(i ,j2+1,k)+CC(ic ,j2,k); + CH(i ,k,jc) = CC(i ,j2+1,k)-CC(ic ,j2,k); + CH(i+1,k,j ) = CC(i+1,j2+1,k)-CC(ic+1,j2,k); + CH(i+1,k,jc) = CC(i+1,j2+1,k)+CC(ic+1,j2,k); + } + } + } + for (size_t l=1,lc=ip-1; l<ipph; ++l,--lc) + { + for (size_t ik=0; ik<idl1; ++ik) + { + C2(ik,l ) = CH2(ik,0)+csarr[2*l]*CH2(ik,1)+csarr[4*l]*CH2(ik,2); + C2(ik,lc) = csarr[2*l+1]*CH2(ik,ip-1)+csarr[4*l+1]*CH2(ik,ip-2); + } + size_t iang=2*l; + size_t j=3,jc=ip-3; + for(; j<ipph-3; j+=4,jc-=4) + { + iang+=l; if(iang>ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; + for (size_t ik=0; ik<idl1; ++ik) + { + C2(ik,l ) += ar1*CH2(ik,j )+ar2*CH2(ik,j +1) + +ar3*CH2(ik,j +2)+ar4*CH2(ik,j +3); + C2(ik,lc) += ai1*CH2(ik,jc)+ai2*CH2(ik,jc-1) + +ai3*CH2(ik,jc-2)+ai4*CH2(ik,jc-3); + } + } + for(; j<ipph-1; j+=2,jc-=2) + { + iang+=l; if(iang>ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + for (size_t ik=0; ik<idl1; ++ik) + { + C2(ik,l ) += ar1*CH2(ik,j )+ar2*CH2(ik,j +1); + C2(ik,lc) += ai1*CH2(ik,jc)+ai2*CH2(ik,jc-1); + } + } + for(; j<ipph; ++j,--jc) + { + iang+=l; if(iang>ip) iang-=ip; + T0 war=csarr[2*iang], wai=csarr[2*iang+1]; + for (size_t ik=0; ik<idl1; ++ik) + { + C2(ik,l ) += war*CH2(ik,j ); + C2(ik,lc) += wai*CH2(ik,jc); + } + } + } + for (size_t j=1; j<ipph; ++j) + for (size_t ik=0; ik<idl1; ++ik) + CH2(ik,0) += CH2(ik,j); + for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc) // 124 + for (size_t k=0; k<l1; ++k) + { + CH(0,k,j ) = C1(0,k,j)-C1(0,k,jc); + CH(0,k,jc) = C1(0,k,j)+C1(0,k,jc); + } + + if (ido==1) return; + + for (size_t j=1, jc=ip-1; j<ipph; ++j, --jc) // 127 + for (size_t k=0; k<l1; ++k) + for (size_t i=1; i<=ido-2; i+=2) + { + CH(i ,k,j ) = C1(i ,k,j)-C1(i+1,k,jc); + CH(i ,k,jc) = C1(i ,k,j)+C1(i+1,k,jc); + CH(i+1,k,j ) = C1(i+1,k,j)+C1(i ,k,jc); + CH(i+1,k,jc) = C1(i+1,k,j)-C1(i ,k,jc); + } + +// All in CH + + for (size_t j=1; j<ip; ++j) + { + size_t is = (j-1)*(ido-1); + for (size_t k=0; k<l1; ++k) + { + size_t idij = is; + for (size_t i=1; i<=ido-2; i+=2) + { + T t1=CH(i,k,j), t2=CH(i+1,k,j); + CH(i ,k,j) = wa[idij]*t1-wa[idij+1]*t2; + CH(i+1,k,j) = wa[idij]*t2+wa[idij+1]*t1; + idij+=2; + } + } + } + } +#undef C1 +#undef C2 +#undef CH2 + +#undef CC +#undef CH +#undef PM +#undef MULPM +#undef WA + +template<typename T> static void copy_and_norm(T *c, T *p1, size_t n, T0 fct) + { + if (p1!=c) + { + if (fct!=1.) + for (size_t i=0; i<n; ++i) + c[i] = fct*p1[i]; + else + memcpy (c,p1,n*sizeof(T)); + } + else + if (fct!=1.) + for (size_t i=0; i<n; ++i) + c[i] *= fct; + } + + public: + +template<typename T> void forward(T c[], T0 fact) + { + if (length==1) return; + size_t n=length; + size_t l1=n, nf=nfct; + arr<T> ch(n); + T *p1=c, *p2=ch.data(); + + for(size_t k1=0; k1<nf;++k1) + { + size_t k=nf-k1-1; + size_t ip=fct[k].fct; + size_t ido=n / l1; + l1 /= ip; + if(ip==4) + radf4(ido, l1, p1, p2, fct[k].tw); + else if(ip==2) + radf2(ido, l1, p1, p2, fct[k].tw); + else if(ip==3) + radf3(ido, l1, p1, p2, fct[k].tw); + else if(ip==5) + radf5(ido, l1, p1, p2, fct[k].tw); + else + { + radfg(ido, ip, l1, p1, p2, fct[k].tw, fct[k].tws); + swap (p1,p2); + } + swap (p1,p2); + } + copy_and_norm(c,p1,n,fact); + } + +template<typename T> void backward(T c[], T0 fact) + { + if (length==1) return; + size_t n=length; + size_t l1=1, nf=nfct; + arr<T> ch(n); + T *p1=c, *p2=ch.data(); + + for(size_t k=0; k<nf; k++) + { + size_t ip = fct[k].fct, + ido= n/(ip*l1); + if(ip==4) + radb4(ido, l1, p1, p2, fct[k].tw); + else if(ip==2) + radb2(ido, l1, p1, p2, fct[k].tw); + else if(ip==3) + radb3(ido, l1, p1, p2, fct[k].tw); + else if(ip==5) + radb5(ido, l1, p1, p2, fct[k].tw); + else + radbg(ido, ip, l1, p1, p2, fct[k].tw, fct[k].tws); + swap (p1,p2); + l1*=ip; + } + copy_and_norm(c,p1,n,fact); + } + + private: + +void factorize() + { + size_t len=length; + nfct=0; + while ((len%4)==0) + { add_factor(4); len>>=2; } + if ((len%2)==0) + { + len>>=1; + // factor 2 should be at the front of the factor list + add_factor(2); + swap(fct[0].fct, fct[nfct-1].fct); + } + size_t maxl=(size_t)(sqrt((double)len))+1; + for (size_t divisor=3; (len>1)&&(divisor<maxl); divisor+=2) + if ((len%divisor)==0) + { + while ((len%divisor)==0) + { + add_factor(divisor); + length/=divisor; + } + maxl=(size_t)(sqrt((double)len))+1; + } + if (len>1) add_factor(len); + } + +size_t twsize() const + { + size_t twsz=0, l1=1; + for (size_t k=0; k<nfct; ++k) + { + size_t ip=fct[k].fct, ido=length/(l1*ip); + twsz+=(ip-1)*(ido-1); + if (ip>5) twsz+=2*ip; + l1*=ip; + } + return twsz; + } + +NOINLINE void comp_twiddle() + { + sincos_2pibyn twid(length, true); + size_t l1=1; + T0 *ptr=mem.data(); + for (size_t k=0; k<nfct; ++k) + { + size_t ip=fct[k].fct, ido=length/(l1*ip); + if (k<nfct-1) // last factor doesn't need twiddles + { + fct[k].tw=ptr; ptr+=(ip-1)*(ido-1); + for (size_t j=1; j<ip; ++j) + for (size_t i=1; i<=(ido-1)/2; ++i) + { + fct[k].tw[(j-1)*(ido-1)+2*i-2] = twid[2*j*l1*i]; + fct[k].tw[(j-1)*(ido-1)+2*i-1] = twid[2*j*l1*i+1]; + } + } + if (ip>5) // special factors required by *g functions + { + fct[k].tws=ptr; ptr+=2*ip; + fct[k].tws[0] = 1.; + fct[k].tws[1] = 0.; + for (size_t i=1; i<=(ip>>1); ++i) + { + fct[k].tws[2*i ] = twid[2*i*(length/ip)]; + fct[k].tws[2*i+1] = twid[2*i*(length/ip)+1]; + fct[k].tws[2*(ip-i) ] = twid[2*i*(length/ip)]; + fct[k].tws[2*(ip-i)+1] = -twid[2*i*(length/ip)+1]; + } + } + l1*=ip; + } + } + + public: + +NOINLINE rfftp(size_t length_) + : length(length_) + { + if (length==0) throw runtime_error("zero_sized FFT"); + nfct=0; + if (length==1) return; + factorize(); + mem.resize(twsize()); + comp_twiddle(); + } + +}; + template<typename T0> class fftblue { private: @@ -973,7 +1773,7 @@ template<typename T0> class fftblue bk(mem.data()), bkf(mem.data()+2*n) { /* initialize b_k */ - sincos_2pibyn tmp(2*n); + sincos_2pibyn tmp(2*n, false); bk[0] = 1; bk[1] = 0; @@ -1005,6 +1805,36 @@ template<typename T0> class fftblue template<typename T> void forward(T c[], T0 fct) { fft<false>(c,fct); } + + template<typename T> void backward_r(T c[], T0 fct) + { + arr<T> tmp(2*n); + tmp[0]=c[0]; + tmp[1]=0.; + memcpy (tmp.data()+2,c+1, (n-1)*sizeof(T)); + if ((n&1)==0) tmp[n+1]=0.; + for (size_t m=2; m<n; m+=2) + { + tmp[2*n-m]=tmp[m]; + tmp[2*n-m+1]=-tmp[m+1]; + } + fft<true>(tmp.data(),fct); + for (size_t m=0; m<n; ++m) + c[m] = tmp[2*m]; + } + + template<typename T> void forward_r(T c[], T0 fct) + { + arr<T> tmp(2*n); + for (size_t m=0; m<n; ++m) + { + tmp[2*m] = c[m]; + tmp[2*m+1] = 0.; + } + fft<false>(tmp.data(),fct); + c[0] = tmp[0]; + memcpy (c+1, tmp.data()+2, (n-1)*sizeof(T)); + } }; template<typename T0> class pocketfft_c @@ -1016,7 +1846,7 @@ template<typename T0> class pocketfft_c public: NOINLINE pocketfft_c(size_t length) - : packplan(nullptr), blueplan(nullptr), len(length) + : len(length) { if (length==0) throw runtime_error("zero-length FFT requested"); if ((length<50) || (largest_prime_factor(length)<=sqrt(length))) @@ -1047,6 +1877,46 @@ template<typename T0> class pocketfft_c size_t length() const { return len; } }; +template<typename T0> class pocketfft_r + { + private: + unique_ptr<rfftp<T0>> packplan; + unique_ptr<fftblue<T0>> blueplan; + size_t len; + + public: + NOINLINE pocketfft_r(size_t length) + : len(length) + { + if (length==0) throw runtime_error("zero-length FFT requested"); + if ((length<50) || (largest_prime_factor(length)<=sqrt(length))) + { + packplan=unique_ptr<rfftp<T0>>(new rfftp<T0>(length)); + return; + } + double comp1 = 0.5*cost_guess(length); + double comp2 = 2*cost_guess(good_size(2*length-1)); + comp2*=1.5; /* fudge factor that appears to give good overall performance */ + if (comp2<comp1) // use Bluestein + blueplan=unique_ptr<fftblue<T0>>(new fftblue<T0>(length)); + else + packplan=unique_ptr<rfftp<T0>>(new rfftp<T0>(length)); + } + + template<typename T> void backward(T c[], T0 fct) + { + packplan ? packplan->backward(c,fct) + : blueplan->backward_r(c,fct); + } + + template<typename T> void forward(T c[], T0 fct) + { + packplan ? packplan->forward(c,fct) + : blueplan->forward_r(c,fct); + } + + size_t length() const { return len; } + }; struct diminfo { size_t n; int64_t s; }; @@ -1165,10 +2035,35 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape, } } +template<typename T> void pocketfft_general_r(int ndim, const size_t *shape, + const int64_t *stride_in, const int64_t *stride_out, size_t axis, + bool forward, const T *data_in, T *data_out, T fct) + { + // allocate temporary 1D array storage, if necessary + size_t tmpsize = shape[axis]; + arr<T> tdata(tmpsize); + multiarr a_in(ndim, shape, stride_in), a_out(ndim, shape, stride_out); + pocketfft_r<T> plan(shape[axis]); + multi_iter it_in(a_in, axis), it_out(a_out, axis); + while (!it_in.done()) + { + for (size_t i=0; i<it_in.length(); ++i) + tdata[i] = data_in[it_in.offset()+i*it_in.stride()]; + forward ? plan.forward((T *)tdata.data(), fct) + : plan.backward((T *)tdata.data(), fct); + for (size_t i=0; i<it_out.length(); ++i) + data_out[it_out.offset()+i*it_out.stride()] = tdata[i]; + it_in.advance(); + it_out.advance(); + } + } + namespace py = pybind11; auto c64 = py::dtype("complex64"); auto c128 = py::dtype("complex128"); +auto f32 = py::dtype("float32"); +auto f64 = py::dtype("float64"); py::array execute(const py::array &in, bool fwd) { @@ -1208,10 +2103,50 @@ py::array fftn(const py::array &in) { return execute(in, true); } py::array ifftn(const py::array &in) { return execute(in, false); } + +py::array execute_r(const py::array &in, int axis, bool fwd) + { + py::array res; + vector<size_t> dims(in.ndim()); + vector<int64_t> s_i(in.ndim()), s_o(in.ndim()); + for (int i=0; i<in.ndim(); ++i) + dims[i] = in.shape(i); + if (in.dtype().is(f64)) + res = py::array_t<double>(dims); + else if (in.dtype().is(f32)) + res = py::array_t<float>(dims); + else throw runtime_error("unsupported data type"); + for (int i=0; i<in.ndim(); ++i) + { + s_i[i] = in.strides(i)/in.itemsize(); + if (s_i[i]*in.itemsize() != in.strides(i)) + throw runtime_error("weird strides"); + s_o[i] = res.strides(i)/res.itemsize(); + if (s_o[i]*res.itemsize() != res.strides(i)) + throw runtime_error("weird strides"); + } + if (in.dtype().is(f64)) + pocketfft_general_r<double>(in.ndim(), dims.data(), + s_i.data(), s_o.data(), + axis, fwd, (const double *)in.data(), + (double *)res.mutable_data(), 1.); + else + pocketfft_general_r<float>(in.ndim(), dims.data(), + s_i.data(), s_o.data(), + axis, fwd, (const float *)in.data(), + (float *)res.mutable_data(), 1.); + return res; + } +py::array rfftn(const py::array &in, int axis) + { return execute_r(in, axis, true); } +py::array irfftn(const py::array &in, int axis) + { return execute_r(in, axis, false); } } // unnamed namespace PYBIND11_MODULE(pypocketfft, m) { m.def("fftn",&fftn); m.def("ifftn",&ifftn); + m.def("rfftn",&rfftn); + m.def("irfftn",&irfftn); } -- GitLab