Commit c2da68d4 authored by Martin Reinecke

consolidate trigonometric transforms to save cache space and precomputation time

parent d556ae21
...@@ -2293,6 +2293,7 @@ template<typename T0> class pocketfft_r ...@@ -2293,6 +2293,7 @@ template<typename T0> class pocketfft_r
// //
// sine/cosine transforms // sine/cosine transforms
// //
template<typename T0> class T_dct1 template<typename T0> class T_dct1
{ {
private: private:
...@@ -2302,7 +2303,8 @@ template<typename T0> class T_dct1 ...@@ -2302,7 +2303,8 @@ template<typename T0> class T_dct1
POCKETFFT_NOINLINE T_dct1(size_t length) POCKETFFT_NOINLINE T_dct1(size_t length)
: fftplan(2*(length-1)) {} : fftplan(2*(length-1)) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho,
int /*type*/, bool /*cosine*/) const
{ {
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=fftplan.length(), n=N/2+1; size_t N=fftplan.length(), n=N/2+1;
...@@ -2323,22 +2325,42 @@ template<typename T0> class T_dct1 ...@@ -2323,22 +2325,42 @@ template<typename T0> class T_dct1
size_t length() const { return fftplan.length()/2+1; } size_t length() const { return fftplan.length()/2+1; }
}; };
template<typename T0> class T_dct2 template<typename T0> class T_dst1
{ {
private: private:
pocketfft_r<T0> fftplan; pocketfft_r<T0> fftplan;
vector<T0> twiddle;
public: public:
POCKETFFT_NOINLINE T_dct2(size_t length) POCKETFFT_NOINLINE T_dst1(size_t length)
: fftplan(length), twiddle(length) : fftplan(2*(length+1)) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct,
bool /*ortho*/, int /*type*/, bool /*cosine*/) const
{ {
constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); size_t N=fftplan.length(), n=N/2-1;
for (size_t i=0; i<length; ++i) arr<T> tmp(N);
twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length))); tmp[0] = tmp[n+1] = c[0]*0;
for (size_t i=0; i<n; ++i)
{
tmp[i+1] = c[i];
tmp[N-1-i] = -c[i];
}
fftplan.forward(tmp.data(), fct);
for (size_t i=0; i<n; ++i)
c[i] = -tmp[2*i+2];
} }
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const size_t length() const { return fftplan.length()/2-1; }
};
template<typename T0> class T_dcst23
{
private:
pocketfft_r<T0> fftplan;
vector<T0> twiddle;
template<typename T> POCKETFFT_NOINLINE void exec2c(T c[], T0 fct,
bool ortho) const
{ {
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length(); size_t N=length();
...@@ -2368,25 +2390,8 @@ template<typename T0> class T_dct2 ...@@ -2368,25 +2390,8 @@ template<typename T0> class T_dct2
if (ortho) c[0]/=sqrt2; if (ortho) c[0]/=sqrt2;
} }
size_t length() const { return fftplan.length(); } template<typename T> POCKETFFT_NOINLINE void exec3c(T c[], T0 fct,
}; bool ortho) const
template<typename T0> class T_dct3
{
private:
pocketfft_r<T0> fftplan;
vector<T0> twiddle;
public:
POCKETFFT_NOINLINE T_dct3(size_t length)
: fftplan(length), twiddle(length)
{
constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
for (size_t i=0; i<length; ++i)
twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
{ {
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length(); size_t N=length();
...@@ -2417,10 +2422,55 @@ template<typename T0> class T_dct3 ...@@ -2417,10 +2422,55 @@ template<typename T0> class T_dct3
} }
} }
// Sine flavour of the type-2 transform, built on top of exec2c:
// the odd-indexed inputs are negated, the cosine kernel is run without
// orthonormal scaling, and the output is mirrored end-to-end.
// When ortho is set, the first element of the mirrored output is divided
// by sqrt(2).
template<typename T> POCKETFFT_NOINLINE void exec2s(T c[], T0 fct,
  bool ortho) const
  {
  constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
  const size_t len=length();

  // flip the sign of every odd-indexed input sample
  size_t odd=1;
  while (odd<len)
    {
    c[odd] = -c[odd];
    odd+=2;
    }

  // ortho handling is done here, so the cosine kernel runs unnormalised
  exec2c(c, fct, false);

  // mirror the result in place
  size_t lo=0, hi=len-1;
  while (lo<hi)
    {
    swap(c[lo], c[hi]);
    ++lo;
    --hi;
    }

  if (ortho)
    c[0]/=sqrt2;
  }
// Sine flavour of the type-3 transform, built on top of exec3c:
// optionally pre-scale the first input by sqrt(2) (ortho), mirror the
// input end-to-end, run the cosine kernel without orthonormal scaling,
// and finally negate the odd-indexed outputs.
template<typename T> POCKETFFT_NOINLINE void exec3s(T c[], T0 fct,
  bool ortho) const
  {
  constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
  const size_t len=length();

  if (ortho)
    c[0]*=sqrt2;

  // mirror the input in place; only the first half needs visiting
  const size_t half=len/2;
  for (size_t lo=0; lo<half; ++lo)
    swap(c[lo], c[len-1-lo]);

  // ortho handling was done above, so the cosine kernel runs unnormalised
  exec3c(c, fct, false);

  // flip the sign of every odd-indexed output sample
  size_t odd=1;
  while (odd<len)
    {
    c[odd] = -c[odd];
    odd+=2;
    }
  }
public:
// Builds a plan for transforms of the given length: a real FFT plan of the
// same length plus a table of `length` twiddle factors, where entry i holds
// cos(0.5*pi*(i+1)/length).
POCKETFFT_NOINLINE T_dcst23(size_t length)
  : fftplan(length), twiddle(length)
  {
  constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
  size_t pos=0;
  for (auto &tw : twiddle)
    {
    tw = T0(cos(0.5*pi*T0(pos+1)/T0(length)));
    ++pos;
    }
  }
// Entry point shared by the type-2 and type-3 cosine/sine transforms.
// type==2 selects the type-2 kernels; any other value falls through to the
// type-3 kernels. `cosine` picks the cosine (true) or sine (false) variant.
// c, fct and ortho are forwarded unchanged to the selected kernel.
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho,
  int type, bool cosine) const
  {
  if (type==2)
    {
    if (cosine)
      exec2c(c, fct, ortho);
    else
      exec2s(c, fct, ortho);
    return;
    }
  if (cosine)
    exec3c(c, fct, ortho);
  else
    exec3s(c, fct, ortho);
  }
size_t length() const { return fftplan.length(); } size_t length() const { return fftplan.length(); }
}; };
template<typename T0> class T_dct4 template<typename T0> class T_dcst4
{ {
// even length algorithm from // even length algorithm from
// https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/ // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/
...@@ -2430,23 +2480,7 @@ template<typename T0> class T_dct4 ...@@ -2430,23 +2480,7 @@ template<typename T0> class T_dct4
unique_ptr<pocketfft_r<T0>> rfft; unique_ptr<pocketfft_r<T0>> rfft;
arr<cmplx<T0>> C2; arr<cmplx<T0>> C2;
public: template<typename T> POCKETFFT_NOINLINE void exec4c(T c[], T0 fct) const
POCKETFFT_NOINLINE T_dct4(size_t length)
: N(length),
fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)),
rfft((N&1)? new pocketfft_r<T0>(N) : nullptr),
C2((N&1) ? 0 : N/2)
{
constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
if ((N&1)==0)
for (size_t i=0; i<N/2; ++i)
{
T0 ang = -pi/T0(N)*(T0(i)+T0(0.125));
C2[i].Set(cos(ang), sin(ang));
}
}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const
{ {
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
if (N&1) if (N&1)
...@@ -2512,106 +2546,38 @@ template<typename T0> class T_dct4 ...@@ -2512,106 +2546,38 @@ template<typename T0> class T_dct4
} }
} }
size_t length() const { return N; } template<typename T> POCKETFFT_NOINLINE void exec4s(T c[], T0 fct) const
};
template<typename T0> class T_dst1
{
private:
pocketfft_r<T0> fftplan;
public:
POCKETFFT_NOINLINE T_dst1(size_t length)
: fftplan(2*(length+1)) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const
{ {
size_t N=fftplan.length(), n=N/2-1;
arr<T> tmp(N);
tmp[0] = tmp[n+1] = c[0]*0;
for (size_t i=0; i<n; ++i)
{
tmp[i+1] = c[i];
tmp[N-1-i] = -c[i];
}
fftplan.forward(tmp.data(), fct);
for (size_t i=0; i<n; ++i)
c[i] = -tmp[2*i+2];
}
size_t length() const { return fftplan.length()/2-1; }
};
template<typename T0> class T_dst2
{
private:
T_dct2<T0> dct;
public:
POCKETFFT_NOINLINE T_dst2(size_t length)
: dct(length) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length(); size_t N=length();
for (size_t k=1; k<N; k+=2)
c[k] = -c[k];
dct.exec(c, fct, false);
for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
swap(c[k], c[kc]);
if (ortho) c[0]/=sqrt2;
}
size_t length() const { return dct.length(); }
};
template<typename T0> class T_dst3
{
private:
T_dct3<T0> dct;
public:
POCKETFFT_NOINLINE T_dst3(size_t length)
: dct(length) {}
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho)
{
constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
size_t N=length();
if (ortho) c[0]*=sqrt2;
size_t NS2 = N/2; size_t NS2 = N/2;
for (size_t k=0, kc=N-1; k<NS2; ++k, --kc) for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
swap(c[k], c[kc]); swap(c[k], c[kc]);
dct.exec(c, fct, false); exec4c(c, fct);
for (size_t k=1; k<N; k+=2) for (size_t k=1; k<N; k+=2)
c[k] = -c[k]; c[k] = -c[k];
} }
size_t length() const { return dct.length(); }
};
template<typename T0> class T_dst4
{
private:
T_dct4<T0> dct;
public: public:
POCKETFFT_NOINLINE T_dst4(size_t length) POCKETFFT_NOINLINE T_dcst4(size_t length)
: dct(length) {} : N(length),
fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)),
template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) rfft((N&1)? new pocketfft_r<T0>(N) : nullptr),
C2((N&1) ? 0 : N/2)
{ {
size_t N=length(); constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
size_t NS2 = N/2; if ((N&1)==0)
for (size_t k=0, kc=N-1; k<NS2; ++k, --kc) for (size_t i=0; i<N/2; ++i)
swap(c[k], c[kc]); {
dct.exec(c, fct, false); T0 ang = -pi/T0(N)*(T0(i)+T0(0.125));
for (size_t k=1; k<N; k+=2) C2[i].Set(cos(ang), sin(ang));
c[k] = -c[k]; }
} }
size_t length() const { return dct.length(); } template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct,
bool /*ortho*/, int /*type*/, bool cosine) const
{ cosine ? exec4c(c, fct) : exec4s(c, fct); }
size_t length() const { return N; }
}; };
...@@ -3060,7 +3026,7 @@ template<typename T> POCKETFFT_NOINLINE void general_hartley( ...@@ -3060,7 +3026,7 @@ template<typename T> POCKETFFT_NOINLINE void general_hartley(
template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst( template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, const cndarr<T> &in, ndarr<T> &out, const shape_t &axes,
T fct, bool ortho, size_t POCKETFFT_NTHREADS) T fct, bool ortho, int type, bool cosine, size_t POCKETFFT_NTHREADS)
{ {
shared_ptr<Trafo> plan; shared_ptr<Trafo> plan;
...@@ -3088,7 +3054,7 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst( ...@@ -3088,7 +3054,7 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
tdatav[i][j] = tin[it.iofs(j,i)]; tdatav[i][j] = tin[it.iofs(j,i)];
plan->exec(tdatav, fct, ortho); plan->exec(tdatav, fct, ortho, type, cosine);
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,i)] = tdatav[i][j]; out[it.oofs(j,i)] = tdatav[i][j];
...@@ -3099,18 +3065,18 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst( ...@@ -3099,18 +3065,18 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
it.advance(1); it.advance(1);
auto tdata = reinterpret_cast<T *>(storage.data()); auto tdata = reinterpret_cast<T *>(storage.data());
if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place
plan->exec(&out[it.oofs(0)], fct, ortho); plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine);
else if (it.stride_out()==sizeof(T)) // compute FFT in output location else if (it.stride_out()==sizeof(T)) // compute FFT in output location
{ {
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tin[it.iofs(i)]; out[it.oofs(i)] = tin[it.iofs(i)];
plan->exec(&out[it.oofs(0)], fct, ortho); plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine);
} }
else else
{ {
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
tdata[i] = tin[it.iofs(i)]; tdata[i] = tin[it.iofs(i)];
plan->exec(tdata, fct, ortho); plan->exec(tdata, fct, ortho, type, cosine);
for (size_t i=0; i<len; ++i) for (size_t i=0; i<len; ++i)
out[it.oofs(i)] = tdata[i]; out[it.oofs(i)] = tdata[i];
} }
...@@ -3374,13 +3340,13 @@ template<typename T> void dct(const shape_t &shape, ...@@ -3374,13 +3340,13 @@ template<typename T> void dct(const shape_t &shape,
cndarr<T> ain(data_in, shape, stride_in); cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out); ndarr<T> aout(data_out, shape, stride_out);
if (type==1) if (type==1)
general_dcst<T_dct1<T>>(ain, aout, axes, fct, ortho, nthreads); general_dcst<T_dct1<T>>(ain, aout, axes, fct, ortho, type, true, nthreads);
else if (type==2) else if (type==2)
general_dcst<T_dct2<T>>(ain, aout, axes, fct, ortho, nthreads); general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, true, nthreads);
else if (type==3) else if (type==3)
general_dcst<T_dct3<T>>(ain, aout, axes, fct, ortho, nthreads); general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, true, nthreads);
else if (type==4) else if (type==4)
general_dcst<T_dct4<T>>(ain, aout, axes, fct, ortho, nthreads); general_dcst<T_dcst4<T>>(ain, aout, axes, fct, ortho, type, true, nthreads);
else else
throw runtime_error("unsupported DCT type"); throw runtime_error("unsupported DCT type");
} }
...@@ -3395,13 +3361,13 @@ template<typename T> void dst(const shape_t &shape, ...@@ -3395,13 +3361,13 @@ template<typename T> void dst(const shape_t &shape,
cndarr<T> ain(data_in, shape, stride_in); cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out); ndarr<T> aout(data_out, shape, stride_out);
if (type==1) if (type==1)
general_dcst<T_dst1<T>>(ain, aout, axes, fct, ortho, nthreads); general_dcst<T_dst1<T>>(ain, aout, axes, fct, ortho, type, false, nthreads);
else if (type==2) else if (type==2)
general_dcst<T_dst2<T>>(ain, aout, axes, fct, ortho, nthreads); general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, false, nthreads);
else if (type==3) else if (type==3)
general_dcst<T_dst3<T>>(ain, aout, axes, fct, ortho, nthreads); general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, false, nthreads);
else if (type==4) else if (type==4)
general_dcst<T_dst4<T>>(ain, aout, axes, fct, ortho, nthreads); general_dcst<T_dcst4<T>>(ain, aout, axes, fct, ortho, type, false, nthreads);
else else
throw runtime_error("unsupported DST type"); throw runtime_error("unsupported DST type");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment