Commit 50476533 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'cleanup' into 'master'

Cleanup

See merge request !15
parents 656b2373 88f64fcf
...@@ -209,13 +209,19 @@ template<typename T> struct cmplx { ...@@ -209,13 +209,19 @@ template<typename T> struct cmplx {
r = tmp; r = tmp;
return *this; return *this;
} }
cmplx operator+ (const cmplx &other) const template<typename T2>cmplx &operator+= (const cmplx<T2> &other)
{ return cmplx(r+other.r, i+other.i); } { r+=other.r; i+=other.i; return *this; }
cmplx operator- (const cmplx &other) const template<typename T2>cmplx &operator-= (const cmplx<T2> &other)
{ return cmplx(r-other.r, i-other.i); } { r-=other.r; i-=other.i; return *this; }
template<typename T2> auto operator* (const T2 &other) const template<typename T2> auto operator* (const T2 &other) const
-> cmplx<decltype(r*other)> -> cmplx<decltype(r*other)>
{ return {r*other, i*other}; } { return {r*other, i*other}; }
template<typename T2> auto operator+ (const cmplx<T2> &other) const
-> cmplx<decltype(r+other.r)>
{ return {r+other.r, i+other.i}; }
template<typename T2> auto operator- (const cmplx<T2> &other) const
-> cmplx<decltype(r+other.r)>
{ return {r-other.r, i-other.i}; }
template<typename T2> auto operator* (const cmplx<T2> &other) const template<typename T2> auto operator* (const cmplx<T2> &other) const
-> cmplx<decltype(r+other.r)> -> cmplx<decltype(r+other.r)>
{ return {r*other.r-i*other.i, r*other.i + i*other.r}; } { return {r*other.r-i*other.i, r*other.i + i*other.r}; }
...@@ -227,8 +233,9 @@ template<typename T> struct cmplx { ...@@ -227,8 +233,9 @@ template<typename T> struct cmplx {
: Tres(r*other.r-i*other.i, r*other.i+i*other.r); : Tres(r*other.r-i*other.i, r*other.i+i*other.r);
} }
}; };
template<typename T> void PMC(cmplx<T> &a, cmplx<T> &b, template<typename T> inline void PMINPLACE(T &a, T &b)
const cmplx<T> &c, const cmplx<T> &d) { T t = a; a+=b; b=t-b; }
template<typename T> void PMC(T &a, T &b, const T &c, const T &d)
{ a = c+d; b = c-d; } { a = c+d; b = c-d; }
template<typename T> cmplx<T> conj(const cmplx<T> &a) template<typename T> cmplx<T> conj(const cmplx<T> &a)
{ return {a.r, -a.i}; } { return {a.r, -a.i}; }
...@@ -845,8 +852,6 @@ template <bool fwd, typename T> void ROTX135(T &a) ...@@ -845,8 +852,6 @@ template <bool fwd, typename T> void ROTX135(T &a)
else else
{ auto tmp_=a.r; a.r=hsqt2*(-a.r-a.i); a.i=hsqt2*(tmp_-a.i); } { auto tmp_=a.r; a.r=hsqt2*(-a.r-a.i); a.i=hsqt2*(tmp_-a.i); }
} }
template<typename T> inline void PMINPLACE(T &a, T &b)
{ T t = a; a.r+=b.r; a.i+=b.i; b.r=t.r-b.r; b.i=t.i-b.i; }
template<bool fwd, typename T> void pass8 (size_t ido, size_t l1, template<bool fwd, typename T> void pass8 (size_t ido, size_t l1,
const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
...@@ -2293,6 +2298,7 @@ template<typename T0> class pocketfft_r ...@@ -2293,6 +2298,7 @@ template<typename T0> class pocketfft_r
// //
// sine/cosine transforms // sine/cosine transforms
// //
template<typename T0> class T_dct1 template<typename T0> class T_dct1
{ {
private: private:
...@@ -2302,7 +2308,8 @@ template<typename T0> class T_dct1 ...@@ -2302,7 +2308,8 @@ template<typename T0> class T_dct1
POCKETFFT_NOINLINE T_dct1(size_t length) POCKETFFT_NOINLINE T_dct1(size_t length)
: fftplan(2*(length-1)) {} : fftplan(2*(length-1)) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho,
int /*type*/, bool /*cosine*/) const
{ {
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=fftplan.length(), n=N/2+1; size_t N=fftplan.length(), n=N/2+1;
...@@ -2323,126 +2330,124 @@ template<typename T0> class T_dct1 ...@@ -2323,126 +2330,124 @@ template<typename T0> class T_dct1
size_t length() const { return fftplan.length()/2+1; } size_t length() const { return fftplan.length()/2+1; }
}; };
template<typename T0> class T_dct2 template<typename T0> class T_dst1
{ {
private: private:
pocketfft_r<T0> fftplan; pocketfft_r<T0> fftplan;
vector<T0> twiddle;
public: public:
POCKETFFT_NOINLINE T_dct2(size_t length) POCKETFFT_NOINLINE T_dst1(size_t length)
: fftplan(length), twiddle(length) : fftplan(2*(length+1)) {}
{
constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
for (size_t i=0; i<length; ++i)
twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct,
bool /*ortho*/, int /*type*/, bool /*cosine*/) const
{ {
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=fftplan.length(), n=N/2-1;
size_t N=length(); arr<T> tmp(N);
if (N==1) tmp[0] = tmp[n+1] = c[0]*0;
c[0]*=2*fct; for (size_t i=0; i<n; ++i)
else if (N==2)
{
T x1 = 2*fct*(c[0]+c[1]);
c[1] = sqrt2*fct*(c[0]-c[1]);
c[0] = x1;
}
else
{ {
size_t NS2 = (N+1)/2; tmp[i+1] = c[i];
for (size_t i=2; i<N; i+=2) tmp[N-1-i] = -c[i];
{
T xim1 = T0(0.5)*(c[i-1]+c[i]);
c[i] = T0(0.5)*(c[i]-c[i-1]);
c[i-1] = xim1;
}
fftplan.backward(c, fct);
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
{
T tmp = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k];
c[kc] = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc];
c[k] = tmp;
}
if ((N&1)==0)
c[NS2] = twiddle[NS2-1]*(c[NS2]+c[NS2]);
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
{
T tmp = c[k]+c[kc];
c[kc] = c[k]-c[kc];
c[k] = tmp;
}
c[0] *= 2;
} }
if (ortho) c[0]/=sqrt2; fftplan.forward(tmp.data(), fct);
for (size_t i=0; i<n; ++i)
c[i] = -tmp[2*i+2];
} }
size_t length() const { return fftplan.length(); } size_t length() const { return fftplan.length()/2-1; }
}; };
template<typename T0> class T_dct3 template<typename T0> class T_dcst23
{ {
private: private:
pocketfft_r<T0> fftplan; pocketfft_r<T0> fftplan;
vector<T0> twiddle; vector<T0> twiddle;
public: template<typename T> POCKETFFT_NOINLINE void exec2(T c[], T0 fct,
POCKETFFT_NOINLINE T_dct3(size_t length) bool ortho, bool cosine) const
: fftplan(length), twiddle(length)
{ {
constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
for (size_t i=0; i<length; ++i) size_t N=length();
twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length))); size_t NS2 = (N+1)/2;
if (!cosine)
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
for (size_t i=2; i<N; i+=2)
{
T xim1 = T0(0.5)*(c[i-1]+c[i]);
c[i] = T0(0.5)*(c[i]-c[i-1]);
c[i-1] = xim1;
}
fftplan.backward(c, fct);
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
{
T tmp = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k];
c[kc] = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc];
c[k] = tmp;
}
if ((N&1)==0)
c[NS2] = twiddle[NS2-1]*(c[NS2]+c[NS2]);
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
PMINPLACE(c[k], c[kc]);
c[0] *= 2;
if (!cosine)
for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
swap(c[k], c[kc]);
if (ortho) c[0]/=sqrt2;
} }
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const template<typename T> POCKETFFT_NOINLINE void exec3(T c[], T0 fct,
bool ortho, bool cosine) const
{ {
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length(); size_t N=length();
if (ortho) c[0]*=sqrt2; if (ortho) c[0]*=sqrt2;
if (N==1) size_t NS2 = (N+1)/2;
c[0]*=fct; if (!cosine)
else if (N==2) for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
swap(c[k], c[kc]);
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
PMINPLACE(c[k], c[kc]);
if ((N&1)==0)
c[NS2] = c[NS2]+c[NS2];
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
{ {
T TSQX = sqrt2*c[1]; T tmp = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc];
c[1] = fct*(c[0]-TSQX); c[k] = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k];
c[0] = fct*(c[0]+TSQX); c[kc] = tmp;
} }
else if ((N&1)==0)
c[NS2] = twiddle[NS2-1]*c[NS2];
fftplan.forward(c, fct);
for (size_t i=2; i<N; i+=2)
{ {
size_t NS2 = (N+1)/2; T xim1 = c[i-1]-c[i];
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc) c[i] += c[i-1];
{ c[i-1] = xim1;
T tmp = c[k]-c[kc];
c[k] = c[k]+c[kc];
c[kc] = tmp;
}
if ((N&1)==0)
c[NS2] = c[NS2]+c[NS2];
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
{
T tmp = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc];
c[k] = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k];
c[kc] = tmp;
}
if ((N&1)==0)
c[NS2] = twiddle[NS2-1]*c[NS2];
fftplan.forward(c, fct);
for (size_t i=2; i<N; i+=2)
{
T xim1 = c[i-1]-c[i];
c[i] += c[i-1];
c[i-1] = xim1;
}
} }
if (!cosine)
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
}
public:
POCKETFFT_NOINLINE T_dcst23(size_t length)
: fftplan(length), twiddle(length)
{
constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
for (size_t i=0; i<length; ++i)
twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
} }
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho,
int type, bool cosine) const
{ type==2 ? exec2(c, fct, ortho, cosine) : exec3(c, fct, ortho, cosine); }
size_t length() const { return fftplan.length(); } size_t length() const { return fftplan.length(); }
}; };
template<typename T0> class T_dct4 template<typename T0> class T_dcst4
{ {
// even length algorithm from // even length algorithm from
// https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/ // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/
...@@ -2453,7 +2458,7 @@ template<typename T0> class T_dct4 ...@@ -2453,7 +2458,7 @@ template<typename T0> class T_dct4
arr<cmplx<T0>> C2; arr<cmplx<T0>> C2;
public: public:
POCKETFFT_NOINLINE T_dct4(size_t length) POCKETFFT_NOINLINE T_dcst4(size_t length)
: N(length), : N(length),
fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)), fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)),
rfft((N&1)? new pocketfft_r<T0>(N) : nullptr), rfft((N&1)? new pocketfft_r<T0>(N) : nullptr),
...@@ -2468,9 +2473,14 @@ template<typename T0> class T_dct4 ...@@ -2468,9 +2473,14 @@ template<typename T0> class T_dct4
} }
} }
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct,
bool /*ortho*/, int /*type*/, bool cosine) const
{ {
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t n2 = N/2;
if (!cosine)
for (size_t k=0, kc=N-1; k<n2; ++k, --kc)
swap(c[k], c[kc]);
if (N&1) if (N&1)
{ {
// The following code is derived from the FFTW3 function apply_re11() // The following code is derived from the FFTW3 function apply_re11()
...@@ -2479,11 +2489,9 @@ template<typename T0> class T_dct4 ...@@ -2479,11 +2489,9 @@ template<typename T0> class T_dct4
auto SGN_SET = [](T x, size_t i) {return (i%2) ? -x : x;}; auto SGN_SET = [](T x, size_t i) {return (i%2) ? -x : x;};
arr<T> y(N); arr<T> y(N);
size_t n2 = N/2;
size_t i;
{ {
size_t m; size_t i=0, m=n2;
for (i=0, m=n2; m<N; ++i, m+=4) for (; m<N; ++i, m+=4)
y[i] = c[m]; y[i] = c[m];
for (; m<2*N; ++i, m+=4) for (; m<2*N; ++i, m+=4)
y[i] = -c[2*N-m-1]; y[i] = -c[2*N-m-1];
...@@ -2496,7 +2504,9 @@ template<typename T0> class T_dct4 ...@@ -2496,7 +2504,9 @@ template<typename T0> class T_dct4
y[i] = c[m]; y[i] = c[m];
} }
rfft->forward(y.data(), fct); rfft->forward(y.data(), fct);
for (i=0; i+i+1<n2; ++i) {
size_t i=0;
for (; i+i+1<n2; ++i)
{ {
size_t k = i+i+1; size_t k = i+i+1;
T c1=y[2*k-1], s1=y[2*k], c2=y[2*k+1], s2=y[2*k+2]; T c1=y[2*k-1], s1=y[2*k], c2=y[2*k+1], s2=y[2*k+2];
...@@ -2511,140 +2521,34 @@ template<typename T0> class T_dct4 ...@@ -2511,140 +2521,34 @@ template<typename T0> class T_dct4
c[i] = sqrt2 * (SGN_SET(cx, (i+1)/2) + SGN_SET(sx, i/2)); c[i] = sqrt2 * (SGN_SET(cx, (i+1)/2) + SGN_SET(sx, i/2));
c[N-(i+1)] = sqrt2 * (SGN_SET(cx, (i+2)/2) + SGN_SET(sx, (i+1)/2)); c[N-(i+1)] = sqrt2 * (SGN_SET(cx, (i+2)/2) + SGN_SET(sx, (i+1)/2));
} }
}
c[n2] = sqrt2 * SGN_SET(y[0], (n2+1)/2); c[n2] = sqrt2 * SGN_SET(y[0], (n2+1)/2);
// FFTW-derived code ends here // FFTW-derived code ends here
} }
else else
{ {
arr<cmplx<T>> y(N/2); // even length algorithm from
for(size_t i=0; i<N/2; ++i) // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/
arr<cmplx<T>> y(n2);
for(size_t i=0; i<n2; ++i)
{ {
y[i].Set(c[2*i],c[N-1-2*i]); y[i].Set(c[2*i],c[N-1-2*i]);
y[i] *= C2[i]; y[i] *= C2[i];
} }
fft->forward(y.data(), fct); fft->forward(y.data(), fct);
for(size_t i=0; i<N/2; ++i) for(size_t i=0, ic=n2-1; i<n2; ++i, --ic)
y[i] *= C2[i];
for(size_t i=0; i<N/2; ++i)
{ {
c[2*i] = 2*y[i].r; c[2*i ] = 2*(y[i ].r*C2[i ].r-y[i ].i*C2[i ].i);
c[2*i+1] = -2*y[N/2-1-i].i; c[2*i+1] = -2*(y[ic].i*C2[ic].r+y[ic].r*C2[ic].i);
} }
} }
} if (!cosine)
size_t length() const { return N; }
};
template<typename T0> class T_dst1
{
private:
pocketfft_r<T0> fftplan;
public:
POCKETFFT_NOINLINE T_dst1(size_t length)
: fftplan(2*(length+1)) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const
{
size_t N=fftplan.length(), n=N/2-1;
arr<T> tmp(N);
tmp[0] = tmp[n+1] = c[0]*0;
for (size_t i=0; i<n; ++i)
{
tmp[i+1] = c[i];
tmp[N-1-i] = -c[i];
}
fftplan.forward(tmp.data(), fct);
for (size_t i=0; i<n; ++i)
c[i] = -tmp[2*i+2];
}
size_t length() const { return fftplan.length()/2-1; }
};
template<typename T0> class T_dst2
{
private:
T_dct2<T0> dct;
public:
POCKETFFT_NOINLINE T_dst2(size_t length)
: dct(length) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
if (N==1)
c[0]*=2*fct;
else
{
for (size_t k=1; k<N; k+=2) for (size_t k=1; k<N; k+=2)
c[k] = -c[k]; c[k] = -c[k];
dct.exec(c, fct, false);
for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
swap(c[k], c[kc]);
}
if (ortho) c[0]/=sqrt2;
} }
size_t length() const { return dct.length(); } size_t length() const { return N; }
};
template<typename T0> class T_dst3
{
private:
T_dct3<T0> dct;
public:
POCKETFFT_NOINLINE T_dst3(size_t length)
: dct(length) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho)
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
if (ortho) c[0]*=sqrt2;
if (N==1)
c[0]*=fct;
else
{
size_t NS2 = N/2;
for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
swap(c[k], c[kc]);
dct.exec(c, fct, false);
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
}
}
size_t length() const { return dct.length(); }
};
template<typename T0> class T_dst4
{
private:
T_dct4<T0> dct;
public:
POCKETFFT_NOINLINE T_dst4(size_t length)
: dct(length) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/)
{
size_t N=length();
//if (N==1) { c[0]*=fct; return; }
size_t NS2 = N/2;
for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
swap(c[k], c[kc]);
dct.exec(c, fct, false);
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
}
size_t length() const { return dct.length(); }
}; };
...@@ -3093,7 +2997,7 @@ template<typename T> POCKETFFT_NOINLINE void general_hartley( ...@@ -3093,7 +2997,7 @@ template<typename T> POCKETFFT_NOINLINE void general_hartley(
template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst( template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, const cndarr<T> &in, ndarr<T> &out, const shape_t &axes,
T fct, bool ortho, size_t POCKETFFT_NTHREADS) T fct, bool ortho, int type, bool cosine, size_t POCKETFFT_NTHREADS)
{ {
shared_ptr<Trafo> plan; shared_ptr<Trafo> plan;
...@@ -3121,7 +3025,7 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst( ...@@ -3121,7 +3025,7 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = tin[it.iofs(j,i)]; tdatav[i][j] = tin[it.iofs(j,i)];
plan->exec(tdatav, fct, ortho); plan->exec(tdatav, fct, ortho, type, cosine);
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,i)] = tdatav[i][j]; out[it.oofs(j,i)] = tdatav[i][j];
...@@ -3132,18 +3036,18 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst( ...@@ -3132,18 +3036,18 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
it.advance(1); it.advance(1);
auto tdata = reinterpret_cast<T *>(storage.data()); auto tdata = reinterpret_cast<T *>(storage.data());
if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place
plan->exec(&out[it.oofs(0)], fct, ortho); plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine);
else if (it.stride_out()==sizeof(T)) // compute FFT in output location else if (it.stride_out()==sizeof(T)) // compute FFT in output location
{ {
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tin[it.iofs(i)];