Commit 50476533 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'cleanup' into 'master'

Cleanup

See merge request !15
parents 656b2373 88f64fcf
......@@ -209,13 +209,19 @@ template<typename T> struct cmplx {
r = tmp;
return *this;
}
cmplx operator+ (const cmplx &other) const
{ return cmplx(r+other.r, i+other.i); }
cmplx operator- (const cmplx &other) const
{ return cmplx(r-other.r, i-other.i); }
template<typename T2>cmplx &operator+= (const cmplx<T2> &other)
{ r+=other.r; i+=other.i; return *this; }
template<typename T2>cmplx &operator-= (const cmplx<T2> &other)
{ r-=other.r; i-=other.i; return *this; }
template<typename T2> auto operator* (const T2 &other) const
-> cmplx<decltype(r*other)>
{ return {r*other, i*other}; }
template<typename T2> auto operator+ (const cmplx<T2> &other) const
-> cmplx<decltype(r+other.r)>
{ return {r+other.r, i+other.i}; }
template<typename T2> auto operator- (const cmplx<T2> &other) const
-> cmplx<decltype(r+other.r)>
{ return {r-other.r, i-other.i}; }
template<typename T2> auto operator* (const cmplx<T2> &other) const
-> cmplx<decltype(r+other.r)>
{ return {r*other.r-i*other.i, r*other.i + i*other.r}; }
......@@ -227,8 +233,9 @@ template<typename T> struct cmplx {
: Tres(r*other.r-i*other.i, r*other.i+i*other.r);
}
};
template<typename T> void PMC(cmplx<T> &a, cmplx<T> &b,
const cmplx<T> &c, const cmplx<T> &d)
template<typename T> inline void PMINPLACE(T &a, T &b)
{ T t = a; a+=b; b=t-b; }
template<typename T> void PMC(T &a, T &b, const T &c, const T &d)
{ a = c+d; b = c-d; }
template<typename T> cmplx<T> conj(const cmplx<T> &a)
{ return {a.r, -a.i}; }
......@@ -845,8 +852,6 @@ template <bool fwd, typename T> void ROTX135(T &a)
else
{ auto tmp_=a.r; a.r=hsqt2*(-a.r-a.i); a.i=hsqt2*(tmp_-a.i); }
}
template<typename T> inline void PMINPLACE(T &a, T &b)
{ T t = a; a.r+=b.r; a.i+=b.i; b.r=t.r-b.r; b.i=t.i-b.i; }
template<bool fwd, typename T> void pass8 (size_t ido, size_t l1,
const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
......@@ -2293,6 +2298,7 @@ template<typename T0> class pocketfft_r
//
// sine/cosine transforms
//
template<typename T0> class T_dct1
{
private:
......@@ -2302,7 +2308,8 @@ template<typename T0> class T_dct1
POCKETFFT_NOINLINE T_dct1(size_t length)
: fftplan(2*(length-1)) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho,
int /*type*/, bool /*cosine*/) const
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=fftplan.length(), n=N/2+1;
......@@ -2323,126 +2330,124 @@ template<typename T0> class T_dct1
size_t length() const { return fftplan.length()/2+1; }
};
template<typename T0> class T_dct2
template<typename T0> class T_dst1
{
private:
pocketfft_r<T0> fftplan;
vector<T0> twiddle;
public:
POCKETFFT_NOINLINE T_dct2(size_t length)
: fftplan(length), twiddle(length)
{
constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
for (size_t i=0; i<length; ++i)
twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
}
POCKETFFT_NOINLINE T_dst1(size_t length)
: fftplan(2*(length+1)) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct,
bool /*ortho*/, int /*type*/, bool /*cosine*/) const
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
if (N==1)
c[0]*=2*fct;
else if (N==2)
{
T x1 = 2*fct*(c[0]+c[1]);
c[1] = sqrt2*fct*(c[0]-c[1]);
c[0] = x1;
}
else
size_t N=fftplan.length(), n=N/2-1;
arr<T> tmp(N);
tmp[0] = tmp[n+1] = c[0]*0;
for (size_t i=0; i<n; ++i)
{
size_t NS2 = (N+1)/2;
for (size_t i=2; i<N; i+=2)
{
T xim1 = T0(0.5)*(c[i-1]+c[i]);
c[i] = T0(0.5)*(c[i]-c[i-1]);
c[i-1] = xim1;
}
fftplan.backward(c, fct);
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
{
T tmp = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k];
c[kc] = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc];
c[k] = tmp;
}
if ((N&1)==0)
c[NS2] = twiddle[NS2-1]*(c[NS2]+c[NS2]);
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
{
T tmp = c[k]+c[kc];
c[kc] = c[k]-c[kc];
c[k] = tmp;
}
c[0] *= 2;
tmp[i+1] = c[i];
tmp[N-1-i] = -c[i];
}
if (ortho) c[0]/=sqrt2;
fftplan.forward(tmp.data(), fct);
for (size_t i=0; i<n; ++i)
c[i] = -tmp[2*i+2];
}
size_t length() const { return fftplan.length(); }
size_t length() const { return fftplan.length()/2-1; }
};
template<typename T0> class T_dct3
template<typename T0> class T_dcst23
{
private:
pocketfft_r<T0> fftplan;
vector<T0> twiddle;
public:
POCKETFFT_NOINLINE T_dct3(size_t length)
: fftplan(length), twiddle(length)
template<typename T> POCKETFFT_NOINLINE void exec2(T c[], T0 fct,
bool ortho, bool cosine) const
{
constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
for (size_t i=0; i<length; ++i)
twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
size_t NS2 = (N+1)/2;
if (!cosine)
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
for (size_t i=2; i<N; i+=2)
{
T xim1 = T0(0.5)*(c[i-1]+c[i]);
c[i] = T0(0.5)*(c[i]-c[i-1]);
c[i-1] = xim1;
}
fftplan.backward(c, fct);
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
{
T tmp = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k];
c[kc] = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc];
c[k] = tmp;
}
if ((N&1)==0)
c[NS2] = twiddle[NS2-1]*(c[NS2]+c[NS2]);
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
PMINPLACE(c[k], c[kc]);
c[0] *= 2;
if (!cosine)
for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
swap(c[k], c[kc]);
if (ortho) c[0]/=sqrt2;
}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
template<typename T> POCKETFFT_NOINLINE void exec3(T c[], T0 fct,
bool ortho, bool cosine) const
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
if (ortho) c[0]*=sqrt2;
if (N==1)
c[0]*=fct;
else if (N==2)
size_t NS2 = (N+1)/2;
if (!cosine)
for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
swap(c[k], c[kc]);
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
PMINPLACE(c[k], c[kc]);
if ((N&1)==0)
c[NS2] = c[NS2]+c[NS2];
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
{
T TSQX = sqrt2*c[1];
c[1] = fct*(c[0]-TSQX);
c[0] = fct*(c[0]+TSQX);
T tmp = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc];
c[k] = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k];
c[kc] = tmp;
}
else
if ((N&1)==0)
c[NS2] = twiddle[NS2-1]*c[NS2];
fftplan.forward(c, fct);
for (size_t i=2; i<N; i+=2)
{
size_t NS2 = (N+1)/2;
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
{
T tmp = c[k]-c[kc];
c[k] = c[k]+c[kc];
c[kc] = tmp;
}
if ((N&1)==0)
c[NS2] = c[NS2]+c[NS2];
for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
{
T tmp = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc];
c[k] = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k];
c[kc] = tmp;
}
if ((N&1)==0)
c[NS2] = twiddle[NS2-1]*c[NS2];
fftplan.forward(c, fct);
for (size_t i=2; i<N; i+=2)
{
T xim1 = c[i-1]-c[i];
c[i] += c[i-1];
c[i-1] = xim1;
}
T xim1 = c[i-1]-c[i];
c[i] += c[i-1];
c[i-1] = xim1;
}
if (!cosine)
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
}
public:
POCKETFFT_NOINLINE T_dcst23(size_t length)
: fftplan(length), twiddle(length)
{
constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
for (size_t i=0; i<length; ++i)
twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho,
int type, bool cosine) const
{ type==2 ? exec2(c, fct, ortho, cosine) : exec3(c, fct, ortho, cosine); }
size_t length() const { return fftplan.length(); }
};
template<typename T0> class T_dct4
template<typename T0> class T_dcst4
{
// even length algorithm from
// https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/
......@@ -2453,7 +2458,7 @@ template<typename T0> class T_dct4
arr<cmplx<T0>> C2;
public:
POCKETFFT_NOINLINE T_dct4(size_t length)
POCKETFFT_NOINLINE T_dcst4(size_t length)
: N(length),
fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)),
rfft((N&1)? new pocketfft_r<T0>(N) : nullptr),
......@@ -2468,9 +2473,14 @@ template<typename T0> class T_dct4
}
}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct,
bool /*ortho*/, int /*type*/, bool cosine) const
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t n2 = N/2;
if (!cosine)
for (size_t k=0, kc=N-1; k<n2; ++k, --kc)
swap(c[k], c[kc]);
if (N&1)
{
// The following code is derived from the FFTW3 function apply_re11()
......@@ -2479,11 +2489,9 @@ template<typename T0> class T_dct4
auto SGN_SET = [](T x, size_t i) {return (i%2) ? -x : x;};
arr<T> y(N);
size_t n2 = N/2;
size_t i;
{
size_t m;
for (i=0, m=n2; m<N; ++i, m+=4)
size_t i=0, m=n2;
for (; m<N; ++i, m+=4)
y[i] = c[m];
for (; m<2*N; ++i, m+=4)
y[i] = -c[2*N-m-1];
......@@ -2496,7 +2504,9 @@ template<typename T0> class T_dct4
y[i] = c[m];
}
rfft->forward(y.data(), fct);
for (i=0; i+i+1<n2; ++i)
{
size_t i=0;
for (; i+i+1<n2; ++i)
{
size_t k = i+i+1;
T c1=y[2*k-1], s1=y[2*k], c2=y[2*k+1], s2=y[2*k+2];
......@@ -2511,140 +2521,34 @@ template<typename T0> class T_dct4
c[i] = sqrt2 * (SGN_SET(cx, (i+1)/2) + SGN_SET(sx, i/2));
c[N-(i+1)] = sqrt2 * (SGN_SET(cx, (i+2)/2) + SGN_SET(sx, (i+1)/2));
}
}
c[n2] = sqrt2 * SGN_SET(y[0], (n2+1)/2);
// FFTW-derived code ends here
}
else
{
arr<cmplx<T>> y(N/2);
for(size_t i=0; i<N/2; ++i)
// even length algorithm from
// https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/
arr<cmplx<T>> y(n2);
for(size_t i=0; i<n2; ++i)
{
y[i].Set(c[2*i],c[N-1-2*i]);
y[i] *= C2[i];
}
fft->forward(y.data(), fct);
for(size_t i=0; i<N/2; ++i)
y[i] *= C2[i];
for(size_t i=0; i<N/2; ++i)
for(size_t i=0, ic=n2-1; i<n2; ++i, --ic)
{
c[2*i] = 2*y[i].r;
c[2*i+1] = -2*y[N/2-1-i].i;
c[2*i ] = 2*(y[i ].r*C2[i ].r-y[i ].i*C2[i ].i);
c[2*i+1] = -2*(y[ic].i*C2[ic].r+y[ic].r*C2[ic].i);
}
}
}
size_t length() const { return N; }
};
template<typename T0> class T_dst1
{
private:
pocketfft_r<T0> fftplan;
public:
POCKETFFT_NOINLINE T_dst1(size_t length)
: fftplan(2*(length+1)) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const
{
size_t N=fftplan.length(), n=N/2-1;
arr<T> tmp(N);
tmp[0] = tmp[n+1] = c[0]*0;
for (size_t i=0; i<n; ++i)
{
tmp[i+1] = c[i];
tmp[N-1-i] = -c[i];
}
fftplan.forward(tmp.data(), fct);
for (size_t i=0; i<n; ++i)
c[i] = -tmp[2*i+2];
}
size_t length() const { return fftplan.length()/2-1; }
};
template<typename T0> class T_dst2
{
private:
T_dct2<T0> dct;
public:
POCKETFFT_NOINLINE T_dst2(size_t length)
: dct(length) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
if (N==1)
c[0]*=2*fct;
else
{
if (!cosine)
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
dct.exec(c, fct, false);
for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
swap(c[k], c[kc]);
}
if (ortho) c[0]/=sqrt2;
}
size_t length() const { return dct.length(); }
};
template<typename T0> class T_dst3
{
private:
T_dct3<T0> dct;
public:
POCKETFFT_NOINLINE T_dst3(size_t length)
: dct(length) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho)
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
if (ortho) c[0]*=sqrt2;
if (N==1)
c[0]*=fct;
else
{
size_t NS2 = N/2;
for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
swap(c[k], c[kc]);
dct.exec(c, fct, false);
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
}
}
size_t length() const { return dct.length(); }
};
template<typename T0> class T_dst4
{
private:
T_dct4<T0> dct;
public:
POCKETFFT_NOINLINE T_dst4(size_t length)
: dct(length) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/)
{
size_t N=length();
//if (N==1) { c[0]*=fct; return; }
size_t NS2 = N/2;
for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
swap(c[k], c[kc]);
dct.exec(c, fct, false);
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
}
size_t length() const { return dct.length(); }
size_t length() const { return N; }
};
......@@ -3093,7 +2997,7 @@ template<typename T> POCKETFFT_NOINLINE void general_hartley(
template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes,
T fct, bool ortho, size_t POCKETFFT_NTHREADS)
T fct, bool ortho, int type, bool cosine, size_t POCKETFFT_NTHREADS)
{
shared_ptr<Trafo> plan;
......@@ -3121,7 +3025,7 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = tin[it.iofs(j,i)];
plan->exec(tdatav, fct, ortho);
plan->exec(tdatav, fct, ortho, type, cosine);
for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,i)] = tdatav[i][j];
......@@ -3132,18 +3036,18 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
it.advance(1);
auto tdata = reinterpret_cast<T *>(storage.data());
if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place
plan->exec(&out[it.oofs(0)], fct, ortho);
plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine);
else if (it.stride_out()==sizeof(T)) // compute FFT in output location
{
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tin[it.iofs(i)];
plan->exec(&out[it.oofs(0)], fct, ortho);
plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine);
}
else
{
for (size_t i=0; i<len; ++i)
tdata[i] = tin[it.iofs(i)];
plan->exec(tdata, fct, ortho);
plan->exec(tdata, fct, ortho, type, cosine);
for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tdata[i];
}
......@@ -3407,15 +3311,11 @@ template<typename T> void dct(const shape_t &shape,
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
if (type==1)
general_dcst<T_dct1<T>>(ain, aout, axes, fct, ortho, nthreads);
else if (type==2)
general_dcst<T_dct2<T>>(ain, aout, axes, fct, ortho, nthreads);
else if (type==3)
general_dcst<T_dct3<T>>(ain, aout, axes, fct, ortho, nthreads);
general_dcst<T_dct1<T>>(ain, aout, axes, fct, ortho, type, true, nthreads);
else if (type==4)
general_dcst<T_dct4<T>>(ain, aout, axes, fct, ortho, nthreads);
general_dcst<T_dcst4<T>>(ain, aout, axes, fct, ortho, type, true, nthreads);
else
throw runtime_error("unsupported DCT type");
general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, true, nthreads);
}
template<typename T> void dst(const shape_t &shape,
......@@ -3428,15 +3328,11 @@ template<typename T> void dst(const shape_t &shape,
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
if (type==1)
general_dcst<T_dst1<T>>(ain, aout, axes, fct, ortho, nthreads);
else if (type==2)
general_dcst<T_dst2<T>>(ain, aout, axes, fct, ortho, nthreads);
else if (type==3)
general_dcst<T_dst3<T>>(ain, aout, axes, fct, ortho, nthreads);
general_dcst<T_dst1<T>>(ain, aout, axes, fct, ortho, type, false, nthreads);
else if (type==4)
general_dcst<T_dst4<T>>(ain, aout, axes, fct, ortho, nthreads);
general_dcst<T_dcst4<T>>(ain, aout, axes, fct, ortho, type, false, nthreads);
else
throw runtime_error("unsupported DST type");
general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, false, nthreads);
}
template<typename T> void r2c(const shape_t &shape_in,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment