Commit e820a0a6 authored by Martin Reinecke's avatar Martin Reinecke

switch to header-only version

parent 188984a0
......@@ -4,2402 +4,29 @@
*/
/*
* Main implementation file.
* Python interface.
*
* Copyright (C) 2004-2019 Max-Planck-Society
* Copyright (C) 2019 Max-Planck-Society
* \author Martin Reinecke
*/
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <iostream>
#include <memory>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#ifdef __GNUC__
#define NOINLINE __attribute__((noinline))
#define restrict __restrict__
#else
#define NOINLINE
#define restrict
#endif
#pragma GCC visibility push(hidden)
namespace pocketfft_private {
using namespace std;
template<typename T> struct arr
{
private:
T *p;
size_t sz;
static T *ralloc(size_t num)
{
if (num==0) return nullptr;
#ifdef POCKETFFT_NO_VECTORS
void *res = malloc(num*sizeof(T));
if (!res) throw bad_alloc();
#else
#if 0
void *res = aligned_alloc(64,num*sizeof(T));
if (!res) throw bad_alloc();
#else
void *res(nullptr);
if (posix_memalign(&res, 64, num*sizeof(T))!=0)
throw bad_alloc();
#endif
#endif
return reinterpret_cast<T *>(res);
}
static void dealloc(T *ptr)
{ free(ptr); }
public:
arr() : p(0), sz(0) {}
arr(size_t n) : p(ralloc(n)), sz(n) {}
arr(arr &&other)
: p(other.p), sz(other.sz)
{ other.p=nullptr; other.sz=0; }
~arr() { dealloc(p); }
void resize(size_t n)
{
if (n==sz) return;
dealloc(p);
p = ralloc(n);
sz = n;
}
T &operator[](size_t idx) { return p[idx]; }
const T &operator[](size_t idx) const { return p[idx]; }
T *data() { return p; }
const T *data() const { return p; }
size_t size() const { return sz; }
};
template<typename T> struct cmplx {
T r, i;
cmplx() {}
cmplx(T r_, T i_) : r(r_), i(i_) {}
void Set(T r_, T i_) { r=r_; i=i_; }
void Set(T r_) { r=r_; i=T(0); }
cmplx &operator+= (const cmplx &other)
{ r+=other.r; i+=other.i; return *this; }
template<typename T2>cmplx &operator*= (T2 other)
{ r*=other; i*=other; return *this; }
cmplx operator+ (const cmplx &other) const
{ return cmplx(r+other.r, i+other.i); }
cmplx operator- (const cmplx &other) const
{ return cmplx(r-other.r, i-other.i); }
template<typename T2> auto operator* (const T2 &other) const
-> cmplx<decltype(r*other)>
{ return {r*other, i*other}; }
template<typename T2> auto operator* (const cmplx<T2> &other) const
-> cmplx<decltype(r+other.r)>
{ return {r*other.r-i*other.i, r*other.i + i*other.r}; }
template<bool bwd, typename T2> auto special_mul (const cmplx<T2> &other) const
-> cmplx<decltype(r+other.r)>
{
return bwd ? cmplx(r*other.r-i*other.i, r*other.i + i*other.r)
: cmplx(r*other.r+i*other.i, i*other.r - r*other.i);
}
};
template<typename T> void PMC(cmplx<T> &a, cmplx<T> &b,
const cmplx<T> &c, const cmplx<T> &d)
{ a = c+d; b = c-d; }
template<typename T> cmplx<T> conj(const cmplx<T> &a)
{ return {a.r, -a.i}; }
template<typename T> void ROT90(cmplx<T> &a)
{ auto tmp_=a.r; a.r=-a.i; a.i=tmp_; }
template<typename T> void ROTM90(cmplx<T> &a)
{ auto tmp_=-a.r; a.r=a.i; a.i=tmp_; }
//
// twiddle factor section
//
template<typename T> class sincos_2pibyn
{
private:
arr<T> data;
// adapted from https://stackoverflow.com/questions/42792939/
// CAUTION: this function only works for arguments in the range
// [-0.25; 0.25]!
void my_sincosm1pi (double a, double *restrict res)
{
double s = a * a;
/* Approximate cos(pi*x)-1 for x in [-0.25,0.25] */
double r = -1.0369917389758117e-4;
r = fma (r, s, 1.9294935641298806e-3);
r = fma (r, s, -2.5806887942825395e-2);
r = fma (r, s, 2.3533063028328211e-1);
r = fma (r, s, -1.3352627688538006e+0);
r = fma (r, s, 4.0587121264167623e+0);
r = fma (r, s, -4.9348022005446790e+0);
double c = r*s;
/* Approximate sin(pi*x) for x in [-0.25,0.25] */
r = 4.6151442520157035e-4;
r = fma (r, s, -7.3700183130883555e-3);
r = fma (r, s, 8.2145868949323936e-2);
r = fma (r, s, -5.9926452893214921e-1);
r = fma (r, s, 2.5501640398732688e+0);
r = fma (r, s, -5.1677127800499516e+0);
s = s * a;
r = r * s;
s = fma (a, 3.1415926535897931e+0, r);
res[0] = c;
res[1] = s;
}
NOINLINE void calc_first_octant(size_t den, T * restrict res)
{
size_t n = (den+4)>>3;
if (n==0) return;
res[0]=1.; res[1]=0.;
if (n==1) return;
size_t l1=(size_t)sqrt(n);
arr<double> tmp(2*l1);
for (size_t i=1; i<l1; ++i)
{
my_sincosm1pi((2.*i)/den,&tmp[2*i]);
res[2*i ] = tmp[2*i]+1.;
res[2*i+1] = tmp[2*i+1];
}
size_t start=l1;
while(start<n)
{
double cs[2];
my_sincosm1pi((2.*start)/den,cs);
res[2*start] = cs[0]+1.;
res[2*start+1] = cs[1];
size_t end = l1;
if (start+end>n) end = n-start;
for (size_t i=1; i<end; ++i)
{
double csx[2]={tmp[2*i], tmp[2*i+1]};
res[2*(start+i)] = ((cs[0]*csx[0] - cs[1]*csx[1] + cs[0]) + csx[0]) + 1.;
res[2*(start+i)+1] = (cs[0]*csx[1] + cs[1]*csx[0]) + cs[1] + csx[1];
}
start += l1;
}
}
NOINLINE void calc_first_quadrant(size_t n, T * restrict res)
{
T * restrict p = res+n;
calc_first_octant(n<<1, p);
size_t ndone=(n+2)>>2;
size_t i=0, idx1=0, idx2=2*ndone-2;
for (; i+1<ndone; i+=2, idx1+=2, idx2-=2)
{
res[idx1] = p[2*i];
res[idx1+1] = p[2*i+1];
res[idx2] = p[2*i+3];
res[idx2+1] = p[2*i+2];
}
if (i!=ndone)
{
res[idx1 ] = p[2*i];
res[idx1+1] = p[2*i+1];
}
}
NOINLINE void calc_first_half(size_t n, T * restrict res)
{
int ndone=(n+1)>>1;
T * p = res+n-1;
calc_first_octant(n<<2, p);
int i4=0, in=n, i=0;
for (; i4<=in-i4; ++i, i4+=4) // octant 0
{
res[2*i] = p[2*i4]; res[2*i+1] = p[2*i4+1];
}
for (; i4-in <= 0; ++i, i4+=4) // octant 1
{
int xm = in-i4;
res[2*i] = p[2*xm+1]; res[2*i+1] = p[2*xm];
}
for (; i4<=3*in-i4; ++i, i4+=4) // octant 2
{
int xm = i4-in;
res[2*i] = -p[2*xm+1]; res[2*i+1] = p[2*xm];
}
for (; i<ndone; ++i, i4+=4) // octant 3
{
int xm = 2*in-i4;
res[2*i] = -p[2*xm]; res[2*i+1] = p[2*xm+1];
}
}
NOINLINE void fill_first_quadrant(size_t n, T * restrict res)
{
const double hsqt2 = 0.707106781186547524400844362104849;
size_t quart = n>>2;
if ((n&7)==0)
res[quart] = res[quart+1] = hsqt2;
for (size_t i=2, j=2*quart-2; i<quart; i+=2, j-=2)
{
res[j ] = res[i+1];
res[j+1] = res[i ];
}
}
NOINLINE void fill_first_half(size_t n, T * restrict res)
{
size_t half = n>>1;
if ((n&3)==0)
for (size_t i=0; i<half; i+=2)
{
res[i+half] = -res[i+1];
res[i+half+1] = res[i ];
}
else
for (size_t i=2, j=2*half-2; i<half; i+=2, j-=2)
{
res[j ] = -res[i ];
res[j+1] = res[i+1];
}
}
NOINLINE void fill_second_half(size_t n, T * restrict res)
{
if ((n&1)==0)
for (size_t i=0; i<n; ++i)
res[i+n] = -res[i];
else
for (size_t i=2, j=2*n-2; i<n; i+=2, j-=2)
{
res[j ] = res[i ];
res[j+1] = -res[i+1];
}
}
NOINLINE void sincos_2pibyn_half(size_t n, T * restrict res)
{
if ((n&3)==0)
{
calc_first_octant(n, res);
fill_first_quadrant(n, res);
fill_first_half(n, res);
}
else if ((n&1)==0)
{
calc_first_quadrant(n, res);
fill_first_half(n, res);
}
else
calc_first_half(n, res);
}
public:
NOINLINE sincos_2pibyn(size_t n, bool half)
: data(2*n)
{
sincos_2pibyn_half(n, data.data());
if (!half) fill_second_half(n, data.data());
}
T operator[](size_t idx) const { return data[idx]; }
const T *rdata() const { return data; }
const cmplx<T> *cdata() const
{ return reinterpret_cast<const cmplx<T> *>(data.data()); }
};
NOINLINE size_t largest_prime_factor (size_t n)
{
size_t res=1;
while ((n&1)==0)
{ res=2; n>>=1; }
size_t limit=size_t(sqrt(n+0.01));
for (size_t x=3; x<=limit; x+=2)
while ((n/x)*x==n)
{
res=x;
n/=x;
limit=size_t(sqrt(n+0.01));
}
if (n>1) res=n;
return res;
}
NOINLINE double cost_guess (size_t n)
{
const double lfp=1.1; // penalty for non-hardcoded larger factors
size_t ni=n;
double result=0.;
while ((n&1)==0)
{ result+=2; n>>=1; }
size_t limit=size_t(sqrt(n+0.01));
for (size_t x=3; x<=limit; x+=2)
while ((n/x)*x==n)
{
result+= (x<=5) ? x : lfp*x; // penalize larger prime factors
n/=x;
limit=size_t(sqrt(n+0.01));
}
if (n>1) result+=(n<=5) ? n : lfp*n;
return result*ni;
}
/* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */
NOINLINE size_t good_size(size_t n)
{
if (n<=12) return n;
size_t bestfac=2*n;
for (size_t f2=1; f2<bestfac; f2*=2)
for (size_t f23=f2; f23<bestfac; f23*=3)
for (size_t f235=f23; f235<bestfac; f235*=5)
for (size_t f2357=f235; f2357<bestfac; f2357*=7)
for (size_t f235711=f2357; f235711<bestfac; f235711*=11)
if (f235711>=n) bestfac=f235711;
return bestfac;
}
#define CH(a,b,c) ch[(a)+ido*((b)+l1*(c))]
#define CC(a,b,c) cc[(a)+ido*((b)+cdim*(c))]
#define WA(x,i) wa[(i)-1+(x)*(ido-1)]
constexpr size_t NFCT=25;
//
// complex FFTPACK transforms
//
template<typename T0> class cfftp
{
private:
struct fctdata
{
size_t fct;
cmplx<T0> *tw, *tws;
};
size_t length, nfct;
arr<cmplx<T0>> mem;
fctdata fct[NFCT];
void add_factor(size_t factor)
{
if (nfct>=NFCT) throw runtime_error("too many prime factors");
fct[nfct++].fct = factor;
}
template<bool bwd, typename T> NOINLINE void pass2 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{
constexpr size_t cdim=2;
if (ido==1)
for (size_t k=0; k<l1; ++k)
{
CH(0,k,0) = CC(0,0,k)+CC(0,1,k);
CH(0,k,1) = CC(0,0,k)-CC(0,1,k);
}
else
for (size_t k=0; k<l1; ++k)
{
CH(0,k,0) = CC(0,0,k)+CC(0,1,k);
CH(0,k,1) = CC(0,0,k)-CC(0,1,k);
for (size_t i=1; i<ido; ++i)
{
CH(i,k,0) = CC(i,0,k)+CC(i,1,k);
CH(i,k,1) = (CC(i,0,k)-CC(i,1,k)).template special_mul<bwd>(WA(0,i));
}
}
}
#define PREP3(idx) \
T t0 = CC(idx,0,k), t1, t2; \
PMC (t1,t2,CC(idx,1,k),CC(idx,2,k)); \
CH(idx,k,0)=t0+t1;
#define PARTSTEP3a(u1,u2,twr,twi) \
{ \
T ca,cb; \
ca=t0+t1*twr; \
cb=t2*twi; ROT90(cb); \
PMC(CH(0,k,u1),CH(0,k,u2),ca,cb) ;\
}
#define PARTSTEP3b(u1,u2,twr,twi) \
{ \
T ca,cb,da,db; \
ca=t0+t1*twr; \
cb=t2*twi; ROT90(cb); \
PMC(da,db,ca,cb); \
CH(i,k,u1) = da.template special_mul<bwd>(WA(u1-1,i)); \
CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \
}
template<bool bwd, typename T> NOINLINE void pass3 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{
constexpr size_t cdim=3;
constexpr T0 tw1r=-0.5, tw1i= (bwd ? 1.: -1.) * 0.86602540378443864676;
if (ido==1)
for (size_t k=0; k<l1; ++k)
{
PREP3(0)
PARTSTEP3a(1,2,tw1r,tw1i)
}
else
for (size_t k=0; k<l1; ++k)
{
{
PREP3(0)
PARTSTEP3a(1,2,tw1r,tw1i)
}
for (size_t i=1; i<ido; ++i)
{
PREP3(i)
PARTSTEP3b(1,2,tw1r,tw1i)
}
}
}
#undef PARTSTEP3b
#undef PARTSTEP3a
#undef PREP3
template<bool bwd, typename T> NOINLINE void pass4 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{
constexpr size_t cdim=4;
if (ido==1)
for (size_t k=0; k<l1; ++k)
{
T t1, t2, t3, t4;
PMC(t2,t1,CC(0,0,k),CC(0,2,k));
PMC(t3,t4,CC(0,1,k),CC(0,3,k));
bwd ? ROT90(t4) : ROTM90(t4);
PMC(CH(0,k,0),CH(0,k,2),t2,t3);
PMC(CH(0,k,1),CH(0,k,3),t1,t4);
}
else
for (size_t k=0; k<l1; ++k)
{
{
T t1, t2, t3, t4;
PMC(t2,t1,CC(0,0,k),CC(0,2,k));
PMC(t3,t4,CC(0,1,k),CC(0,3,k));
bwd ? ROT90(t4) : ROTM90(t4);
PMC(CH(0,k,0),CH(0,k,2),t2,t3);
PMC(CH(0,k,1),CH(0,k,3),t1,t4);
}
for (size_t i=1; i<ido; ++i)
{
T c2, c3, c4, t1, t2, t3, t4;
T cc0=CC(i,0,k), cc1=CC(i,1,k),cc2=CC(i,2,k),cc3=CC(i,3,k);
PMC(t2,t1,cc0,cc2);
PMC(t3,t4,cc1,cc3);
bwd ? ROT90(t4) : ROTM90(t4);
cmplx<T0> wa0=WA(0,i), wa1=WA(1,i),wa2=WA(2,i);
PMC(CH(i,k,0),c3,t2,t3);
PMC(c2,c4,t1,t4);
CH(i,k,1) = c2.template special_mul<bwd>(wa0);
CH(i,k,2) = c3.template special_mul<bwd>(wa1);
CH(i,k,3) = c4.template special_mul<bwd>(wa2);
}
}
}
#define PREP5(idx) \
T t0 = CC(idx,0,k), t1, t2, t3, t4; \
PMC (t1,t4,CC(idx,1,k),CC(idx,4,k)); \
PMC (t2,t3,CC(idx,2,k),CC(idx,3,k)); \
CH(idx,k,0).r=t0.r+t1.r+t2.r; \
CH(idx,k,0).i=t0.i+t1.i+t2.i;
#define PARTSTEP5a(u1,u2,twar,twbr,twai,twbi) \
{ \
T ca,cb; \
ca.r=t0.r+twar*t1.r+twbr*t2.r; \
ca.i=t0.i+twar*t1.i+twbr*t2.i; \
cb.i=twai*t4.r twbi*t3.r; \
cb.r=-(twai*t4.i twbi*t3.i); \
PMC(CH(0,k,u1),CH(0,k,u2),ca,cb); \
}
#define PARTSTEP5b(u1,u2,twar,twbr,twai,twbi) \
{ \
T ca,cb,da,db; \
ca.r=t0.r+twar*t1.r+twbr*t2.r; \
ca.i=t0.i+twar*t1.i+twbr*t2.i; \
cb.i=twai*t4.r twbi*t3.r; \
cb.r=-(twai*t4.i twbi*t3.i); \
PMC(da,db,ca,cb); \
CH(i,k,u1) = da.template special_mul<bwd>(WA(u1-1,i)); \
CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \
}
template<bool bwd, typename T> NOINLINE void pass5 (size_t ido, size_t l1,
const T * restrict cc, T * restrict ch, const cmplx<T0> * restrict wa)
{
constexpr size_t cdim=5;
constexpr T0 tw1r= 0.3090169943749474241,
tw1i= (bwd ? 1.: -1.) * 0.95105651629515357212,
tw2r= -0.8090169943749474241,
tw2i= (bwd ? 1.: -1.) * 0.58778525229247312917;
if (ido==1)
for (size_t k=0; k<l1; ++k)
{
PREP5(0)
PARTSTEP5a(1,4,tw1r,tw2r,+tw1i,+tw2i)
PARTSTEP5a(2,3,tw2r,tw1r,+tw2i,-tw1i)
}
else
for (size_t k=0; k<l1; ++k)
{
{
PREP5(0)
PARTSTEP5a(1,4,tw1r,tw2r,+tw1i,+tw2i)
PARTSTEP5a(2,3,tw2r,tw1r,+tw2i,-tw1i)
}
for (size_t i=1; i<ido; ++i)
{
PREP5(i)
PARTSTEP5b(1,4,tw1r,tw2r,+tw1i,+tw2i)
PARTSTEP5b(2,3,tw2r,tw1r,+tw2i,-tw1i)
}
}
}
#undef PARTSTEP5b
#undef PARTSTEP5a
#undef PREP5
#define PREP7(idx) \
T t1 = CC(idx,0,k), t2, t3, t4, t5, t6, t7; \
PMC (t2,t7,CC(idx,1,k),CC(idx,6,k)); \
PMC (t3,t6,CC(idx,2,k),CC(idx,5,k)); \
PMC (t4,t5,CC(idx,3,k),CC(idx,4,k)); \
CH(idx,k,0).r=t1.r+t2.r+t3.r+t4.r; \
CH(idx,k,0).i=t1.i+t2.i+t3.i+t4.i;
#define PARTSTEP7a0(u1,u2,x1,x2,x3,y1,y2,y3,out1,out2) \
{ \
T ca,cb; \
ca.r=t1.r+x1*t2.r+x2*t3.r+x3*t4.r; \
ca.i=t1.i+x1*t2.i+x2*t3.i+x3*t4.i; \
cb.i=y1*t7.r y2*t6.r y3*t5.r; \
cb.r=-(y1*t7.i y2*t6.i y3*t5.i); \
PMC(out1,out2,ca,cb); \
}
#define PARTSTEP7a(u1,u2,x1,x2,x3,y1,y2,y3) \
PARTSTEP7a0(u1,u2,x1,x2,x3,y1,y2,y3,CH(0,k,u1),CH(0,k,u2))
#define PARTSTEP7(u1,u2,x1,x2,x3,y1,y2,y3) \
{ \
T da,db; \
PARTSTEP7a0(u1,u2,x1,x2,x3,y1,y2,y3,da,db) \
CH(i,k,u1) = da.template special_mul<bwd>(WA(u1-1,i)); \
CH(i,k,u2) = db.template special_mul<bwd>(WA(u2-1,i)); \