From 07c6a8e7f8a8e710da554f7dff368c3df7d1a722 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Tue, 16 Apr 2019 19:11:56 +0200
Subject: [PATCH] add real FFTs

---
 pocketfft.cc | 947 ++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 941 insertions(+), 6 deletions(-)

diff --git a/pocketfft.cc b/pocketfft.cc
index 65c7050..9a9f642 100644
--- a/pocketfft.cc
+++ b/pocketfft.cc
@@ -227,11 +227,11 @@ class sincos_2pibyn
       }
 
   public:
-    sincos_2pibyn(size_t n)
-      : data(2*n)
+    sincos_2pibyn(size_t n, bool half)
+      : data(half ? n : 2*n)
       {
       sincos_2pibyn_half(n, data.data());
-      fill_second_half(n, data.data());
+      if (!half) fill_second_half(n, data.data());
       }
 
     double operator[](size_t idx) const { return data[idx]; }
@@ -883,7 +883,7 @@ template<bool bwd, typename T> NOINLINE void pass_all(T c[], T0 fact)
 
     NOINLINE void comp_twiddle()
       {
-      sincos_2pibyn twid(length);
+      sincos_2pibyn twid(length, false);
       size_t l1=1;
       size_t memofs=0;
       for (size_t k=0; k<nfct; ++k)
@@ -923,6 +923,806 @@ template<bool bwd, typename T> NOINLINE void pass_all(T c[], T0 fact)
       }
   };
 
+
+template<typename T0> class rfftp
+  {
+  private:
+    struct fctdata
+      {
+      size_t fct;
+      T0 *tw, *tws;
+      };
+
+    size_t length, nfct;
+    arr<T0> mem;
+    fctdata fct[NFCT];
+
+    void add_factor(size_t factor)
+      {
+      if (nfct>=NFCT) throw runtime_error("too many prime factors");
+      fct[nfct++].fct = factor;
+      }
+
+#define WA(x,i) wa[(i)+(x)*(ido-1)]
+#define PM(a,b,c,d) { a=c+d; b=c-d; }
+/* (a+ib) = conj(c+id) * (e+if) */
+#define MULPM(a,b,c,d,e,f) { a=c*e+d*f; b=c*f-d*e; }
+
+#define CC(a,b,c) cc[(a)+ido*((b)+l1*(c))]
+#define CH(a,b,c) ch[(a)+ido*((b)+cdim*(c))]
+
+template<typename T> NOINLINE void radf2 (size_t ido, size_t l1, const T * restrict cc,
+  T * restrict ch, const T0 * restrict wa)
+  {
+  constexpr size_t cdim=2;
+
+  for (size_t k=0; k<l1; k++)
+    PM (CH(0,0,k),CH(ido-1,1,k),CC(0,k,0),CC(0,k,1))
+  if ((ido&1)==0)
+    for (size_t k=0; k<l1; k++)
+      {
+      CH(    0,1,k) = -CC(ido-1,k,1);
+      CH(ido-1,0,k) =  CC(ido-1,k,0);
+      }
+  if (ido<=2) return;
+  for (size_t k=0; k<l1; k++)
+    for (size_t i=2; i<ido; i+=2)
+      {
+      size_t ic=ido-i;
+      T tr2, ti2;
+      MULPM (tr2,ti2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1))
+      PM (CH(i-1,0,k),CH(ic-1,1,k),CC(i-1,k,0),tr2)
+      PM (CH(i  ,0,k),CH(ic  ,1,k),ti2,CC(i  ,k,0))
+      }
+  }
+
+template<typename T> NOINLINE void radf3(size_t ido, size_t l1, const T * restrict cc,
+  T * restrict ch, const T0 * restrict wa)
+  {
+  constexpr size_t cdim=3;
+  constexpr T0 taur=-0.5, taui=0.86602540378443864676;
+
+  for (size_t k=0; k<l1; k++)
+    {
+    T cr2=CC(0,k,1)+CC(0,k,2);
+    CH(0,0,k) = CC(0,k,0)+cr2;
+    CH(0,2,k) = taui*(CC(0,k,2)-CC(0,k,1));
+    CH(ido-1,1,k) = CC(0,k,0)+taur*cr2;
+    }
+  if (ido==1) return;
+  for (size_t k=0; k<l1; k++)
+    for (size_t i=2; i<ido; i+=2)
+      {
+      size_t ic=ido-i;
+      T di2, di3, dr2, dr3;
+      MULPM (dr2,di2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1)) // d2=conj(WA0)*CC1
+      MULPM (dr3,di3,WA(1,i-2),WA(1,i-1),CC(i-1,k,2),CC(i,k,2)) // d3=conj(WA1)*CC2
+      T cr2=dr2+dr3; // c add
+      T ci2=di2+di3;
+      CH(i-1,0,k) = CC(i-1,k,0)+cr2; // c add
+      CH(i  ,0,k) = CC(i  ,k,0)+ci2;
+      T tr2 = CC(i-1,k,0)+taur*cr2; // c add
+      T ti2 = CC(i  ,k,0)+taur*ci2;
+      T tr3 = taui*(di2-di3);  // t3 = taui*i*(d3-d2)?
+      T ti3 = taui*(dr3-dr2);
+      PM(CH(i-1,2,k),CH(ic-1,1,k),tr2,tr3) // PM(i) = t2+t3
+      PM(CH(i  ,2,k),CH(ic  ,1,k),ti3,ti2) // PM(ic) = conj(t2-t3)
+      }
+  }
+
+template<typename T> NOINLINE void radf4(size_t ido, size_t l1, const T * restrict cc,
+  T * restrict ch, const T0 * restrict wa)
+  {
+  constexpr size_t cdim=4;
+  constexpr T0 hsqt2=0.70710678118654752440;
+
+  for (size_t k=0; k<l1; k++)
+    {
+    T tr1,tr2;
+    PM (tr1,CH(0,2,k),CC(0,k,3),CC(0,k,1))
+    PM (tr2,CH(ido-1,1,k),CC(0,k,0),CC(0,k,2))
+    PM (CH(0,0,k),CH(ido-1,3,k),tr2,tr1)
+    }
+  if ((ido&1)==0)
+    for (size_t k=0; k<l1; k++)
+      {
+      T ti1=-hsqt2*(CC(ido-1,k,1)+CC(ido-1,k,3));
+      T tr1= hsqt2*(CC(ido-1,k,1)-CC(ido-1,k,3));
+      PM (CH(ido-1,0,k),CH(ido-1,2,k),CC(ido-1,k,0),tr1)
+      PM (CH(    0,3,k),CH(    0,1,k),ti1,CC(ido-1,k,2))
+      }
+  if (ido<=2) return;
+  for (size_t k=0; k<l1; k++)
+    for (size_t i=2; i<ido; i+=2)
+      {
+      size_t ic=ido-i;
+      T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4;
+      MULPM(cr2,ci2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1))
+      MULPM(cr3,ci3,WA(1,i-2),WA(1,i-1),CC(i-1,k,2),CC(i,k,2))
+      MULPM(cr4,ci4,WA(2,i-2),WA(2,i-1),CC(i-1,k,3),CC(i,k,3))
+      PM(tr1,tr4,cr4,cr2)
+      PM(ti1,ti4,ci2,ci4)
+      PM(tr2,tr3,CC(i-1,k,0),cr3)
+      PM(ti2,ti3,CC(i  ,k,0),ci3)
+      PM(CH(i-1,0,k),CH(ic-1,3,k),tr2,tr1)
+      PM(CH(i  ,0,k),CH(ic  ,3,k),ti1,ti2)
+      PM(CH(i-1,2,k),CH(ic-1,1,k),tr3,ti4)
+      PM(CH(i  ,2,k),CH(ic  ,1,k),tr4,ti3)
+      }
+  }
+
+template<typename T> NOINLINE void radf5(size_t ido, size_t l1, const T * restrict cc,
+  T * restrict ch, const T0 * restrict wa)
+  {
+  constexpr size_t cdim=5;
+  constexpr T0 tr11= 0.3090169943749474241, ti11=0.95105651629515357212,
+               tr12=-0.8090169943749474241, ti12=0.58778525229247312917;
+
+  for (size_t k=0; k<l1; k++)
+    {
+    T cr2, cr3, ci4, ci5;
+    PM (cr2,ci5,CC(0,k,4),CC(0,k,1))
+    PM (cr3,ci4,CC(0,k,3),CC(0,k,2))
+    CH(0,0,k)=CC(0,k,0)+cr2+cr3;
+    CH(ido-1,1,k)=CC(0,k,0)+tr11*cr2+tr12*cr3;
+    CH(0,2,k)=ti11*ci5+ti12*ci4;
+    CH(ido-1,3,k)=CC(0,k,0)+tr12*cr2+tr11*cr3;
+    CH(0,4,k)=ti12*ci5-ti11*ci4;
+    }
+  if (ido==1) return;
+  for (size_t k=0; k<l1;++k)
+    for (size_t i=2; i<ido; i+=2)
+      {
+      T ci2, di2, ci4, ci5, di3, di4, di5, ci3, cr2, cr3, dr2, dr3,
+        dr4, dr5, cr5, cr4, ti2, ti3, ti5, ti4, tr2, tr3, tr4, tr5;
+      size_t ic=ido-i;
+      MULPM (dr2,di2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1))
+      MULPM (dr3,di3,WA(1,i-2),WA(1,i-1),CC(i-1,k,2),CC(i,k,2))
+      MULPM (dr4,di4,WA(2,i-2),WA(2,i-1),CC(i-1,k,3),CC(i,k,3))
+      MULPM (dr5,di5,WA(3,i-2),WA(3,i-1),CC(i-1,k,4),CC(i,k,4))
+      PM(cr2,ci5,dr5,dr2)
+      PM(ci2,cr5,di2,di5)
+      PM(cr3,ci4,dr4,dr3)
+      PM(ci3,cr4,di3,di4)
+      CH(i-1,0,k)=CC(i-1,k,0)+cr2+cr3;
+      CH(i  ,0,k)=CC(i  ,k,0)+ci2+ci3;
+      tr2=CC(i-1,k,0)+tr11*cr2+tr12*cr3;
+      ti2=CC(i  ,k,0)+tr11*ci2+tr12*ci3;
+      tr3=CC(i-1,k,0)+tr12*cr2+tr11*cr3;
+      ti3=CC(i  ,k,0)+tr12*ci2+tr11*ci3;
+      MULPM(tr5,tr4,cr5,cr4,ti11,ti12)
+      MULPM(ti5,ti4,ci5,ci4,ti11,ti12)
+      PM(CH(i-1,2,k),CH(ic-1,1,k),tr2,tr5)
+      PM(CH(i  ,2,k),CH(ic  ,1,k),ti5,ti2)
+      PM(CH(i-1,4,k),CH(ic-1,3,k),tr3,tr4)
+      PM(CH(i  ,4,k),CH(ic  ,3,k),ti4,ti3)
+      }
+  }
+
+#undef CC
+#undef CH
+#define C1(a,b,c) cc[(a)+ido*((b)+l1*(c))]
+#define C2(a,b) cc[(a)+idl1*(b)]
+#define CH2(a,b) ch[(a)+idl1*(b)]
+#define CC(a,b,c) cc[(a)+ido*((b)+cdim*(c))]
+#define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))]
+template<typename T> NOINLINE void radfg(size_t ido, size_t ip, size_t l1,
+  T0 * restrict cc, T0 * restrict ch, const T * restrict wa,
+  const T * restrict csarr)
+  {
+  const size_t cdim=ip;
+  size_t ipph=(ip+1)/2;
+  size_t idl1 = ido*l1;
+
+  if (ido>1)
+    {
+    for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)              // 114
+      {
+      size_t is=(j-1)*(ido-1),
+             is2=(jc-1)*(ido-1);
+      for (size_t k=0; k<l1; ++k)                            // 113
+        {
+        size_t idij=is;
+        size_t idij2=is2;
+        for (size_t i=1; i<=ido-2; i+=2)                      // 112
+          {
+          T t1=C1(i,k,j ), t2=C1(i+1,k,j ),
+                 t3=C1(i,k,jc), t4=C1(i+1,k,jc);
+          T x1=wa[idij]*t1 + wa[idij+1]*t2,
+                 x2=wa[idij]*t2 - wa[idij+1]*t1,
+                 x3=wa[idij2]*t3 + wa[idij2+1]*t4,
+                 x4=wa[idij2]*t4 - wa[idij2+1]*t3;
+          C1(i  ,k,j ) = x1+x3;
+          C1(i  ,k,jc) = x2-x4;
+          C1(i+1,k,j ) = x2+x4;
+          C1(i+1,k,jc) = x3-x1;
+          idij+=2;
+          idij2+=2;
+          }
+        }
+      }
+    }
+
+  for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)                // 123
+    for (size_t k=0; k<l1; ++k)                              // 122
+      {
+      T t1=C1(0,k,j), t2=C1(0,k,jc);
+      C1(0,k,j ) = t1+t2;
+      C1(0,k,jc) = t2-t1;
+      }
+
+//everything in C
+//memset(ch,0,ip*l1*ido*sizeof(double));
+
+  for (size_t l=1,lc=ip-1; l<ipph; ++l,--lc)                 // 127
+    {
+    for (size_t ik=0; ik<idl1; ++ik)                         // 124
+      {
+      CH2(ik,l ) = C2(ik,0)+csarr[2*l]*C2(ik,1)+csarr[4*l]*C2(ik,2);
+      CH2(ik,lc) = csarr[2*l+1]*C2(ik,ip-1)+csarr[4*l+1]*C2(ik,ip-2);
+      }
+    size_t iang = 2*l;
+    size_t j=3, jc=ip-3;
+    for (; j<ipph-3; j+=4,jc-=4)              // 126
+      {
+      iang+=l; if (iang>=ip) iang-=ip;
+      T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1];
+      iang+=l; if (iang>=ip) iang-=ip;
+      T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1];
+      iang+=l; if (iang>=ip) iang-=ip;
+      T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1];
+      iang+=l; if (iang>=ip) iang-=ip;
+      T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1];
+      for (size_t ik=0; ik<idl1; ++ik)                       // 125
+        {
+        CH2(ik,l ) += ar1*C2(ik,j )+ar2*C2(ik,j +1)
+                     +ar3*C2(ik,j +2)+ar4*C2(ik,j +3);
+        CH2(ik,lc) += ai1*C2(ik,jc)+ai2*C2(ik,jc-1)
+                     +ai3*C2(ik,jc-2)+ai4*C2(ik,jc-3);
+        }
+      }
+    for (; j<ipph-1; j+=2,jc-=2)              // 126
+      {
+      iang+=l; if (iang>=ip) iang-=ip;
+      T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1];
+      iang+=l; if (iang>=ip) iang-=ip;
+      T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1];
+      for (size_t ik=0; ik<idl1; ++ik)                       // 125
+        {
+        CH2(ik,l ) += ar1*C2(ik,j )+ar2*C2(ik,j +1);
+        CH2(ik,lc) += ai1*C2(ik,jc)+ai2*C2(ik,jc-1);
+        }
+      }
+    for (; j<ipph; ++j,--jc)              // 126
+      {
+      iang+=l; if (iang>=ip) iang-=ip;
+      T0 ar=csarr[2*iang], ai=csarr[2*iang+1];
+      for (size_t ik=0; ik<idl1; ++ik)                       // 125
+        {
+        CH2(ik,l ) += ar*C2(ik,j );
+        CH2(ik,lc) += ai*C2(ik,jc);
+        }
+      }
+    }
+  for (size_t ik=0; ik<idl1; ++ik)                         // 101
+    CH2(ik,0) = C2(ik,0);
+  for (size_t j=1; j<ipph; ++j)                              // 129
+    for (size_t ik=0; ik<idl1; ++ik)                         // 128
+      CH2(ik,0) += C2(ik,j);
+
+// everything in CH at this point!
+//memset(cc,0,ip*l1*ido*sizeof(double));
+
+  for (size_t k=0; k<l1; ++k)                                // 131
+    for (size_t i=0; i<ido; ++i)                             // 130
+      CC(i,0,k) = CH(i,k,0);
+
+  for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)                // 137
+    {
+    size_t j2=2*j-1;
+    for (size_t k=0; k<l1; ++k)                              // 136
+      {
+      CC(ido-1,j2,k) = CH(0,k,j);
+      CC(0,j2+1,k) = CH(0,k,jc);
+      }
+    }
+
+  if (ido==1) return;
+
+  for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)                // 140
+    {
+    size_t j2=2*j-1;
+    for(size_t k=0; k<l1; ++k)                               // 139
+      for(size_t i=1, ic=ido-i-2; i<=ido-2; i+=2, ic-=2)      // 138
+        {
+        CC(i   ,j2+1,k) = CH(i  ,k,j )+CH(i  ,k,jc);
+        CC(ic  ,j2  ,k) = CH(i  ,k,j )-CH(i  ,k,jc);
+        CC(i+1 ,j2+1,k) = CH(i+1,k,j )+CH(i+1,k,jc);
+        CC(ic+1,j2  ,k) = CH(i+1,k,jc)-CH(i+1,k,j );
+        }
+    }
+  }
+#undef C1
+#undef C2
+#undef CH2
+
+#undef CH
+#undef CC
+#define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))]
+#define CC(a,b,c) cc[(a)+ido*((b)+cdim*(c))]
+
+template<typename T> NOINLINE void radb2(size_t ido, size_t l1, const T * restrict cc,
+  T * restrict ch, const T0 * restrict wa)
+  {
+  constexpr size_t cdim=2;
+
+  for (size_t k=0; k<l1; k++)
+    PM (CH(0,k,0),CH(0,k,1),CC(0,0,k),CC(ido-1,1,k))
+  if ((ido&1)==0)
+    for (size_t k=0; k<l1; k++)
+      {
+      CH(ido-1,k,0) = 2.*CC(ido-1,0,k);
+      CH(ido-1,k,1) =-2.*CC(0    ,1,k);
+      }
+  if (ido<=2) return;
+  for (size_t k=0; k<l1;++k)
+    for (size_t i=2; i<ido; i+=2)
+      {
+      size_t ic=ido-i;
+      T ti2, tr2;
+      PM (CH(i-1,k,0),tr2,CC(i-1,0,k),CC(ic-1,1,k))
+      PM (ti2,CH(i  ,k,0),CC(i  ,0,k),CC(ic  ,1,k))
+      MULPM (CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),ti2,tr2)
+      }
+  }
+
+template<typename T> NOINLINE void radb3(size_t ido, size_t l1, const T * restrict cc,
+  T * restrict ch, const T0 * restrict wa)
+  {
+  constexpr size_t cdim=3;
+  constexpr T0 taur=-0.5, taui=0.86602540378443864676;
+
+  for (size_t k=0; k<l1; k++)
+    {
+    T tr2=2.*CC(ido-1,1,k);
+    T cr2=CC(0,0,k)+taur*tr2;
+    CH(0,k,0)=CC(0,0,k)+tr2;
+    T ci3=2.*taui*CC(0,2,k);
+    PM (CH(0,k,2),CH(0,k,1),cr2,ci3);
+    }
+  if (ido==1) return;
+  for (size_t k=0; k<l1; k++)
+    for (size_t i=2; i<ido; i+=2)
+      {
+      size_t ic=ido-i;
+      T tr2=CC(i-1,2,k)+CC(ic-1,1,k); // t2=CC(I) + conj(CC(ic))
+      T ti2=CC(i  ,2,k)-CC(ic  ,1,k);
+      T cr2=CC(i-1,0,k)+taur*tr2;     // c2=CC +taur*t2
+      T ci2=CC(i  ,0,k)+taur*ti2;
+      CH(i-1,k,0)=CC(i-1,0,k)+tr2;         // CH=CC+t2
+      CH(i  ,k,0)=CC(i  ,0,k)+ti2;
+      T cr3=taui*(CC(i-1,2,k)-CC(ic-1,1,k));// c3=taui*(CC(i)-conj(CC(ic)))
+      T ci3=taui*(CC(i  ,2,k)+CC(ic  ,1,k));
+      T di2, di3, dr2, dr3;
+      PM(dr3,dr2,cr2,ci3) // d2= (cr2-ci3, ci2+cr3) = c2+i*c3
+      PM(di2,di3,ci2,cr3) // d3= (cr2+ci3, ci2-cr3) = c2-i*c3
+      MULPM(CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),di2,dr2) // ch = WA*d2
+      MULPM(CH(i,k,2),CH(i-1,k,2),WA(1,i-2),WA(1,i-1),di3,dr3)
+      }
+  }
+
+template<typename T> NOINLINE void radb4(size_t ido, size_t l1, const T * restrict cc,
+  T * restrict ch, const T0 * restrict wa)
+  {
+  constexpr size_t cdim=4;
+  constexpr T0 sqrt2=1.41421356237309504880;
+
+  for (size_t k=0; k<l1; k++)
+    {
+    T tr1, tr2;
+    PM (tr2,tr1,CC(0,0,k),CC(ido-1,3,k))
+    T tr3=2.*CC(ido-1,1,k);
+    T tr4=2.*CC(0,2,k);
+    PM (CH(0,k,0),CH(0,k,2),tr2,tr3)
+    PM (CH(0,k,3),CH(0,k,1),tr1,tr4)
+    }
+  if ((ido&1)==0)
+    for (size_t k=0; k<l1; k++)
+      {
+      T tr1,tr2,ti1,ti2;
+      PM (ti1,ti2,CC(0    ,3,k),CC(0    ,1,k))
+      PM (tr2,tr1,CC(ido-1,0,k),CC(ido-1,2,k))
+      CH(ido-1,k,0)=tr2+tr2;
+      CH(ido-1,k,1)=sqrt2*(tr1-ti1);
+      CH(ido-1,k,2)=ti2+ti2;
+      CH(ido-1,k,3)=-sqrt2*(tr1+ti1);
+      }
+  if (ido<=2) return;
+  for (size_t k=0; k<l1;++k)
+    for (size_t i=2; i<ido; i+=2)
+      {
+      T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4;
+      size_t ic=ido-i;
+      PM (tr2,tr1,CC(i-1,0,k),CC(ic-1,3,k))
+      PM (ti1,ti2,CC(i  ,0,k),CC(ic  ,3,k))
+      PM (tr4,ti3,CC(i  ,2,k),CC(ic  ,1,k))
+      PM (tr3,ti4,CC(i-1,2,k),CC(ic-1,1,k))
+      PM (CH(i-1,k,0),cr3,tr2,tr3)
+      PM (CH(i  ,k,0),ci3,ti2,ti3)
+      PM (cr4,cr2,tr1,tr4)
+      PM (ci2,ci4,ti1,ti4)
+      MULPM (CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),ci2,cr2)
+      MULPM (CH(i,k,2),CH(i-1,k,2),WA(1,i-2),WA(1,i-1),ci3,cr3)
+      MULPM (CH(i,k,3),CH(i-1,k,3),WA(2,i-2),WA(2,i-1),ci4,cr4)
+      }
+  }
+
+template<typename T>NOINLINE void radb5(size_t ido, size_t l1, const T * restrict cc,
+  T * restrict ch, const T0 * restrict wa)
+  {
+  constexpr size_t cdim=5;
+  constexpr T0 tr11= 0.3090169943749474241, ti11=0.95105651629515357212,
+               tr12=-0.8090169943749474241, ti12=0.58778525229247312917;
+
+  for (size_t k=0; k<l1; k++)
+    {
+    T ti5=CC(0,2,k)+CC(0,2,k);
+    T ti4=CC(0,4,k)+CC(0,4,k);
+    T tr2=CC(ido-1,1,k)+CC(ido-1,1,k);
+    T tr3=CC(ido-1,3,k)+CC(ido-1,3,k);
+    CH(0,k,0)=CC(0,0,k)+tr2+tr3;
+    T cr2=CC(0,0,k)+tr11*tr2+tr12*tr3;
+    T cr3=CC(0,0,k)+tr12*tr2+tr11*tr3;
+    T ci4, ci5;
+    MULPM(ci5,ci4,ti5,ti4,ti11,ti12)
+    PM(CH(0,k,4),CH(0,k,1),cr2,ci5)
+    PM(CH(0,k,3),CH(0,k,2),cr3,ci4)
+    }
+  if (ido==1) return;
+  for (size_t k=0; k<l1;++k)
+    for (size_t i=2; i<ido; i+=2)
+      {
+      size_t ic=ido-i;
+      T tr2, tr3, tr4, tr5, ti2, ti3, ti4, ti5;
+      PM(tr2,tr5,CC(i-1,2,k),CC(ic-1,1,k))
+      PM(ti5,ti2,CC(i  ,2,k),CC(ic  ,1,k))
+      PM(tr3,tr4,CC(i-1,4,k),CC(ic-1,3,k))
+      PM(ti4,ti3,CC(i  ,4,k),CC(ic  ,3,k))
+      CH(i-1,k,0)=CC(i-1,0,k)+tr2+tr3;
+      CH(i  ,k,0)=CC(i  ,0,k)+ti2+ti3;
+      T cr2=CC(i-1,0,k)+tr11*tr2+tr12*tr3;
+      T ci2=CC(i  ,0,k)+tr11*ti2+tr12*ti3;
+      T cr3=CC(i-1,0,k)+tr12*tr2+tr11*tr3;
+      T ci3=CC(i  ,0,k)+tr12*ti2+tr11*ti3;
+      T ci4, ci5, cr5, cr4;
+      MULPM(cr5,cr4,tr5,tr4,ti11,ti12)
+      MULPM(ci5,ci4,ti5,ti4,ti11,ti12)
+      T dr2, dr3, dr4, dr5, di2, di3, di4, di5;
+      PM(dr4,dr3,cr3,ci4)
+      PM(di3,di4,ci3,cr4)
+      PM(dr5,dr2,cr2,ci5)
+      PM(di2,di5,ci2,cr5)
+      MULPM(CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),di2,dr2)
+      MULPM(CH(i,k,2),CH(i-1,k,2),WA(1,i-2),WA(1,i-1),di3,dr3)
+      MULPM(CH(i,k,3),CH(i-1,k,3),WA(2,i-2),WA(2,i-1),di4,dr4)
+      MULPM(CH(i,k,4),CH(i-1,k,4),WA(3,i-2),WA(3,i-1),di5,dr5)
+      }
+  }
+
+#undef CC
+#undef CH
+#define CC(a,b,c) cc[(a)+ido*((b)+cdim*(c))]
+#define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))]
+#define C1(a,b,c) cc[(a)+ido*((b)+l1*(c))]
+#define C2(a,b) cc[(a)+idl1*(b)]
+#define CH2(a,b) ch[(a)+idl1*(b)]
+
+template<typename T> NOINLINE void radbg(size_t ido, size_t ip, size_t l1,
+  T * restrict cc, T * restrict ch, const T0 * restrict wa,
+  const T0 * restrict csarr)
+  {
+  const size_t cdim=ip;
+  size_t ipph=(ip+1)/ 2;
+  size_t idl1 = ido*l1;
+
+  for (size_t k=0; k<l1; ++k)        // 102
+    for (size_t i=0; i<ido; ++i)     // 101
+      CH(i,k,0) = CC(i,0,k);
+  for (size_t j=1, jc=ip-1; j<ipph; ++j, --jc)   // 108
+    {
+    size_t j2=2*j-1;
+    for (size_t k=0; k<l1; ++k)
+      {
+      CH(0,k,j ) = 2*CC(ido-1,j2,k);
+      CH(0,k,jc) = 2*CC(0,j2+1,k);
+      }
+    }
+
+  if (ido!=1)
+    {
+    for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)   // 111
+      {
+      size_t j2=2*j-1;
+      for (size_t k=0; k<l1; ++k)
+        for (size_t i=1, ic=ido-i-2; i<=ido-2; i+=2, ic-=2)      // 109
+          {
+          CH(i  ,k,j ) = CC(i  ,j2+1,k)+CC(ic  ,j2,k);
+          CH(i  ,k,jc) = CC(i  ,j2+1,k)-CC(ic  ,j2,k);
+          CH(i+1,k,j ) = CC(i+1,j2+1,k)-CC(ic+1,j2,k);
+          CH(i+1,k,jc) = CC(i+1,j2+1,k)+CC(ic+1,j2,k);
+          }
+      }
+    }
+  for (size_t l=1,lc=ip-1; l<ipph; ++l,--lc)
+    {
+    for (size_t ik=0; ik<idl1; ++ik)
+      {
+      C2(ik,l ) = CH2(ik,0)+csarr[2*l]*CH2(ik,1)+csarr[4*l]*CH2(ik,2);
+      C2(ik,lc) = csarr[2*l+1]*CH2(ik,ip-1)+csarr[4*l+1]*CH2(ik,ip-2);
+      }
+    size_t iang=2*l;
+    size_t j=3,jc=ip-3;
+    for(; j<ipph-3; j+=4,jc-=4)
+      {
+      iang+=l; if(iang>ip) iang-=ip;
+      T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1];
+      iang+=l; if(iang>ip) iang-=ip;
+      T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1];
+      iang+=l; if(iang>ip) iang-=ip;
+      T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1];
+      iang+=l; if(iang>ip) iang-=ip;
+      T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1];
+      for (size_t ik=0; ik<idl1; ++ik)
+        {
+        C2(ik,l ) += ar1*CH2(ik,j )+ar2*CH2(ik,j +1)
+                    +ar3*CH2(ik,j +2)+ar4*CH2(ik,j +3);
+        C2(ik,lc) += ai1*CH2(ik,jc)+ai2*CH2(ik,jc-1)
+                    +ai3*CH2(ik,jc-2)+ai4*CH2(ik,jc-3);
+        }
+      }
+    for(; j<ipph-1; j+=2,jc-=2)
+      {
+      iang+=l; if(iang>ip) iang-=ip;
+      T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1];
+      iang+=l; if(iang>ip) iang-=ip;
+      T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1];
+      for (size_t ik=0; ik<idl1; ++ik)
+        {
+        C2(ik,l ) += ar1*CH2(ik,j )+ar2*CH2(ik,j +1);
+        C2(ik,lc) += ai1*CH2(ik,jc)+ai2*CH2(ik,jc-1);
+        }
+      }
+    for(; j<ipph; ++j,--jc)
+      {
+      iang+=l; if(iang>ip) iang-=ip;
+      T0 war=csarr[2*iang], wai=csarr[2*iang+1];
+      for (size_t ik=0; ik<idl1; ++ik)
+        {
+        C2(ik,l ) += war*CH2(ik,j );
+        C2(ik,lc) += wai*CH2(ik,jc);
+        }
+      }
+    }
+  for (size_t j=1; j<ipph; ++j)
+    for (size_t ik=0; ik<idl1; ++ik)
+      CH2(ik,0) += CH2(ik,j);
+  for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)   // 124
+    for (size_t k=0; k<l1; ++k)
+      {
+      CH(0,k,j ) = C1(0,k,j)-C1(0,k,jc);
+      CH(0,k,jc) = C1(0,k,j)+C1(0,k,jc);
+      }
+
+  if (ido==1) return;
+
+  for (size_t j=1, jc=ip-1; j<ipph; ++j, --jc)  // 127
+    for (size_t k=0; k<l1; ++k)
+      for (size_t i=1; i<=ido-2; i+=2)
+        {
+        CH(i  ,k,j ) = C1(i  ,k,j)-C1(i+1,k,jc);
+        CH(i  ,k,jc) = C1(i  ,k,j)+C1(i+1,k,jc);
+        CH(i+1,k,j ) = C1(i+1,k,j)+C1(i  ,k,jc);
+        CH(i+1,k,jc) = C1(i+1,k,j)-C1(i  ,k,jc);
+        }
+
+// All in CH
+
+  for (size_t j=1; j<ip; ++j)
+    {
+    size_t is = (j-1)*(ido-1);
+    for (size_t k=0; k<l1; ++k)
+      {
+      size_t idij = is;
+      for (size_t i=1; i<=ido-2; i+=2)
+        {
+        T t1=CH(i,k,j), t2=CH(i+1,k,j);
+        CH(i  ,k,j) = wa[idij]*t1-wa[idij+1]*t2;
+        CH(i+1,k,j) = wa[idij]*t2+wa[idij+1]*t1;
+        idij+=2;
+        }
+      }
+    }
+  }
+#undef C1
+#undef C2
+#undef CH2
+
+#undef CC
+#undef CH
+#undef PM
+#undef MULPM
+#undef WA
+
+template<typename T> static void copy_and_norm(T *c, T *p1, size_t n, T0 fct)
+  {
+  if (p1!=c)
+    {
+    if (fct!=1.)
+      for (size_t i=0; i<n; ++i)
+        c[i] = fct*p1[i];
+    else
+      memcpy (c,p1,n*sizeof(T));
+    }
+  else
+    if (fct!=1.)
+      for (size_t i=0; i<n; ++i)
+        c[i] *= fct;
+  }
+
+  public:
+
+template<typename T> void forward(T c[], T0 fact)
+  {
+  if (length==1) return;
+  size_t n=length;
+  size_t l1=n, nf=nfct;
+  arr<T> ch(n);
+  T *p1=c, *p2=ch.data();
+
+  for(size_t k1=0; k1<nf;++k1)
+    {
+    size_t k=nf-k1-1;
+    size_t ip=fct[k].fct;
+    size_t ido=n / l1;
+    l1 /= ip;
+    if(ip==4)
+      radf4(ido, l1, p1, p2, fct[k].tw);
+    else if(ip==2)
+      radf2(ido, l1, p1, p2, fct[k].tw);
+    else if(ip==3)
+      radf3(ido, l1, p1, p2, fct[k].tw);
+    else if(ip==5)
+      radf5(ido, l1, p1, p2, fct[k].tw);
+    else
+      {
+      radfg(ido, ip, l1, p1, p2, fct[k].tw, fct[k].tws);
+      swap (p1,p2);
+      }
+    swap (p1,p2);
+    }
+  copy_and_norm(c,p1,n,fact);
+  }
+
+template<typename T> void backward(T c[], T0 fact)
+  {
+  if (length==1) return;
+  size_t n=length;
+  size_t l1=1, nf=nfct;
+  arr<T> ch(n);
+  T *p1=c, *p2=ch.data();
+
+  for(size_t k=0; k<nf; k++)
+    {
+    size_t ip = fct[k].fct,
+           ido= n/(ip*l1);
+    if(ip==4)
+      radb4(ido, l1, p1, p2, fct[k].tw);
+    else if(ip==2)
+      radb2(ido, l1, p1, p2, fct[k].tw);
+    else if(ip==3)
+      radb3(ido, l1, p1, p2, fct[k].tw);
+    else if(ip==5)
+      radb5(ido, l1, p1, p2, fct[k].tw);
+    else
+      radbg(ido, ip, l1, p1, p2, fct[k].tw, fct[k].tws);
+    swap (p1,p2);
+    l1*=ip;
+    }
+  copy_and_norm(c,p1,n,fact);
+  }
+
+  private:
+
+void factorize()
+  {
+  size_t len=length;
+  nfct=0;
+  while ((len%4)==0)
+    { add_factor(4); len>>=2; }
+  if ((len%2)==0)
+    {
+    len>>=1;
+    // factor 2 should be at the front of the factor list
+    add_factor(2);
+    swap(fct[0].fct, fct[nfct-1].fct);
+    }
+  size_t maxl=(size_t)(sqrt((double)len))+1;
+  for (size_t divisor=3; (len>1)&&(divisor<maxl); divisor+=2)
+    if ((len%divisor)==0)
+      {
+      while ((len%divisor)==0)
+        {
+        add_factor(divisor);
+        length/=divisor;
+        }
+      maxl=(size_t)(sqrt((double)len))+1;
+      }
+  if (len>1) add_factor(len);
+  }
+
+size_t twsize() const
+  {
+  size_t twsz=0, l1=1;
+  for (size_t k=0; k<nfct; ++k)
+    {
+    size_t ip=fct[k].fct, ido=length/(l1*ip);
+    twsz+=(ip-1)*(ido-1);
+    if (ip>5) twsz+=2*ip;
+    l1*=ip;
+    }
+  return twsz;
+  }
+
+NOINLINE void comp_twiddle()
+  {
+  sincos_2pibyn twid(length, true);
+  size_t l1=1;
+  T0 *ptr=mem.data();
+  for (size_t k=0; k<nfct; ++k)
+    {
+    size_t ip=fct[k].fct, ido=length/(l1*ip);
+    if (k<nfct-1) // last factor doesn't need twiddles
+      {
+      fct[k].tw=ptr; ptr+=(ip-1)*(ido-1);
+      for (size_t j=1; j<ip; ++j)
+        for (size_t i=1; i<=(ido-1)/2; ++i)
+          {
+          fct[k].tw[(j-1)*(ido-1)+2*i-2] = twid[2*j*l1*i];
+          fct[k].tw[(j-1)*(ido-1)+2*i-1] = twid[2*j*l1*i+1];
+          }
+      }
+    if (ip>5) // special factors required by *g functions
+      {
+      fct[k].tws=ptr; ptr+=2*ip;
+      fct[k].tws[0] = 1.;
+      fct[k].tws[1] = 0.;
+      for (size_t i=1; i<=(ip>>1); ++i)
+        {
+        fct[k].tws[2*i  ] = twid[2*i*(length/ip)];
+        fct[k].tws[2*i+1] = twid[2*i*(length/ip)+1];
+        fct[k].tws[2*(ip-i)  ] = twid[2*i*(length/ip)];
+        fct[k].tws[2*(ip-i)+1] = -twid[2*i*(length/ip)+1];
+        }
+      }
+    l1*=ip;
+    }
+  }
+
+  public:
+
+NOINLINE rfftp(size_t length_)
+  : length(length_)
+  {
+  if (length==0) throw runtime_error("zero_sized FFT");
+  nfct=0;
+  if (length==1) return;
+  factorize();
+  mem.resize(twsize());
+  comp_twiddle();
+  }
+
+};
+
 template<typename T0> class fftblue
   {
   private:
@@ -973,7 +1773,7 @@ template<typename T0> class fftblue
         bk(mem.data()), bkf(mem.data()+2*n)
       {
       /* initialize b_k */
-      sincos_2pibyn tmp(2*n);
+      sincos_2pibyn tmp(2*n, false);
       bk[0] = 1;
       bk[1] = 0;
 
@@ -1005,6 +1805,36 @@ template<typename T0> class fftblue
 
     template<typename T> void forward(T c[], T0 fct)
       { fft<false>(c,fct); }
+
+    template<typename T> void backward_r(T c[], T0 fct)
+      {
+      arr<T> tmp(2*n);
+      tmp[0]=c[0];
+      tmp[1]=0.;
+      memcpy (tmp.data()+2,c+1, (n-1)*sizeof(T));
+      if ((n&1)==0) tmp[n+1]=0.;
+      for (size_t m=2; m<n; m+=2)
+        {
+        tmp[2*n-m]=tmp[m];
+        tmp[2*n-m+1]=-tmp[m+1];
+        }
+      fft<true>(tmp.data(),fct);
+      for (size_t m=0; m<n; ++m)
+        c[m] = tmp[2*m];
+      }
+
+    template<typename T> void forward_r(T c[], T0 fct)
+      {
+      arr<T> tmp(2*n);
+      for (size_t m=0; m<n; ++m)
+        {
+        tmp[2*m] = c[m];
+        tmp[2*m+1] = 0.;
+        }
+      fft<false>(tmp.data(),fct);
+      c[0] = tmp[0];
+      memcpy (c+1, tmp.data()+2, (n-1)*sizeof(T));
+      }
     };
 
 template<typename T0> class pocketfft_c
@@ -1016,7 +1846,7 @@ template<typename T0> class pocketfft_c
 
   public:
     NOINLINE pocketfft_c(size_t length)
-      : packplan(nullptr), blueplan(nullptr), len(length)
+      : len(length)
       {
       if (length==0) throw runtime_error("zero-length FFT requested");
       if ((length<50) || (largest_prime_factor(length)<=sqrt(length)))
@@ -1047,6 +1877,46 @@ template<typename T0> class pocketfft_c
 
     size_t length() const { return len; }
   };
+template<typename T0> class pocketfft_r
+  {
+  private:
+    unique_ptr<rfftp<T0>> packplan;
+    unique_ptr<fftblue<T0>> blueplan;
+    size_t len;
+
+  public:
+    NOINLINE pocketfft_r(size_t length)
+      : len(length)
+      {
+      if (length==0) throw runtime_error("zero-length FFT requested");
+      if ((length<50) || (largest_prime_factor(length)<=sqrt(length)))
+        {
+        packplan=unique_ptr<rfftp<T0>>(new rfftp<T0>(length));
+        return;
+        }
+      double comp1 = 0.5*cost_guess(length);
+      double comp2 = 2*cost_guess(good_size(2*length-1));
+      comp2*=1.5; /* fudge factor that appears to give good overall performance */
+      if (comp2<comp1) // use Bluestein
+        blueplan=unique_ptr<fftblue<T0>>(new fftblue<T0>(length));
+      else
+        packplan=unique_ptr<rfftp<T0>>(new rfftp<T0>(length));
+      }
+
+    template<typename T> void backward(T c[], T0 fct)
+      {
+      packplan ? packplan->backward(c,fct)
+               : blueplan->backward_r(c,fct);
+      }
+
+    template<typename T> void forward(T c[], T0 fct)
+      {
+      packplan ? packplan->forward(c,fct)
+               : blueplan->forward_r(c,fct);
+      }
+
+    size_t length() const { return len; }
+  };
 
 struct diminfo
   { size_t n; int64_t s; };
@@ -1165,10 +2035,35 @@ template<typename T> void pocketfft_general_c(int ndim, const size_t *shape,
     }
   }
 
+template<typename T> void pocketfft_general_r(int ndim, const size_t *shape,
+  const int64_t *stride_in, const int64_t *stride_out, size_t axis,
+  bool forward, const T *data_in, T *data_out, T fct)
+  {
+  // allocate temporary 1D array storage, if necessary
+  size_t tmpsize = shape[axis];
+  arr<T> tdata(tmpsize);
+  multiarr a_in(ndim, shape, stride_in), a_out(ndim, shape, stride_out);
+  pocketfft_r<T> plan(shape[axis]);
+  multi_iter it_in(a_in, axis), it_out(a_out, axis);
+  while (!it_in.done())
+    {
+    for (size_t i=0; i<it_in.length(); ++i)
+      tdata[i] = data_in[it_in.offset()+i*it_in.stride()];
+    forward ? plan.forward((T *)tdata.data(), fct)
+            : plan.backward((T *)tdata.data(), fct);
+    for (size_t i=0; i<it_out.length(); ++i)
+      data_out[it_out.offset()+i*it_out.stride()] = tdata[i];
+    it_in.advance();
+    it_out.advance();
+    }
+  }
+
 namespace py = pybind11;
 
 auto c64 = py::dtype("complex64");
 auto c128 = py::dtype("complex128");
+auto f32 = py::dtype("float32");
+auto f64 = py::dtype("float64");
 
 py::array execute(const py::array &in, bool fwd)
   {
@@ -1208,10 +2103,50 @@ py::array fftn(const py::array &in)
   { return execute(in, true); }
 py::array ifftn(const py::array &in)
   { return execute(in, false); }
+
+py::array execute_r(const py::array &in, int axis, bool fwd)
+  {
+  py::array res;
+  vector<size_t> dims(in.ndim());
+  vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
+  for (int i=0; i<in.ndim(); ++i)
+    dims[i] = in.shape(i);
+  if (in.dtype().is(f64))
+    res = py::array_t<double>(dims);
+  else if (in.dtype().is(f32))
+    res = py::array_t<float>(dims);
+  else throw runtime_error("unsupported data type");
+  for (int i=0; i<in.ndim(); ++i)
+    {
+    s_i[i] = in.strides(i)/in.itemsize();
+    if (s_i[i]*in.itemsize() != in.strides(i))
+      throw runtime_error("weird strides");
+    s_o[i] = res.strides(i)/res.itemsize();
+    if (s_o[i]*res.itemsize() != res.strides(i))
+      throw runtime_error("weird strides");
+    }
+  if (in.dtype().is(f64))
+    pocketfft_general_r<double>(in.ndim(), dims.data(),
+    s_i.data(), s_o.data(),
+    axis, fwd, (const double *)in.data(),
+    (double *)res.mutable_data(), 1.);
+  else
+    pocketfft_general_r<float>(in.ndim(), dims.data(),
+    s_i.data(), s_o.data(),
+    axis, fwd, (const float *)in.data(),
+    (float *)res.mutable_data(), 1.);
+  return res;
+  }
+py::array rfftn(const py::array &in, int axis)
+  { return execute_r(in, axis, true); }
+py::array irfftn(const py::array &in, int axis)
+  { return execute_r(in, axis, false); }
 } // unnamed namespace
 
 PYBIND11_MODULE(pypocketfft, m)
   {
   m.def("fftn",&fftn);
   m.def("ifftn",&ifftn);
+  m.def("rfftn",&rfftn);
+  m.def("irfftn",&irfftn);
   }
-- 
GitLab