diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index b875db83f9599ef4cb4fa58a96f9e2791e2cc7f2..ac2ac7f85dd449bfb1aaa2ec408b522f1fd7b3d3 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -2293,6 +2293,7 @@ template<typename T0> class pocketfft_r // // sine/cosine transforms // + template<typename T0> class T_dct1 { private: @@ -2302,7 +2303,8 @@ template<typename T0> class T_dct1 POCKETFFT_NOINLINE T_dct1(size_t length) : fftplan(2*(length-1)) {} - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const + template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int /*type*/, bool /*cosine*/) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=fftplan.length(), n=N/2+1; @@ -2323,22 +2325,42 @@ template<typename T0> class T_dct1 size_t length() const { return fftplan.length()/2+1; } }; -template<typename T0> class T_dct2 +template<typename T0> class T_dst1 { private: pocketfft_r<T0> fftplan; - vector<T0> twiddle; public: - POCKETFFT_NOINLINE T_dct2(size_t length) - : fftplan(length), twiddle(length) + POCKETFFT_NOINLINE T_dst1(size_t length) + : fftplan(2*(length+1)) {} + + template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool /*cosine*/) const { - constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); - for (size_t i=0; i<length; ++i) - twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length))); + size_t N=fftplan.length(), n=N/2-1; + arr<T> tmp(N); + tmp[0] = tmp[n+1] = c[0]*0; + for (size_t i=0; i<n; ++i) + { + tmp[i+1] = c[i]; + tmp[N-1-i] = -c[i]; + } + fftplan.forward(tmp.data(), fct); + for (size_t i=0; i<n; ++i) + c[i] = -tmp[2*i+2]; } - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const + size_t length() const { return fftplan.length()/2-1; } + }; + +template<typename T0> class T_dcst23 + { + private: + pocketfft_r<T0> fftplan; + vector<T0> twiddle; + + template<typename T> POCKETFFT_NOINLINE void exec2c(T c[], T0 fct, + bool ortho) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); @@ -2368,25 +2390,8 @@ template<typename T0> class T_dct2 if (ortho) c[0]/=sqrt2; } - size_t length() const { return fftplan.length(); } - }; - -template<typename T0> class T_dct3 - { - private: - pocketfft_r<T0> fftplan; - vector<T0> twiddle; - - public: - POCKETFFT_NOINLINE T_dct3(size_t length) - : fftplan(length), twiddle(length) - { - constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); - for (size_t i=0; i<length; ++i) - twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length))); - } - - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const + template<typename T> POCKETFFT_NOINLINE void exec3c(T c[], T0 fct, + bool ortho) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); @@ -2417,10 +2422,55 @@ template<typename T0> class T_dct3 } } + template<typename T> POCKETFFT_NOINLINE void exec2s(T c[], T0 fct, + bool ortho) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=length(); + for (size_t k=1; k<N; k+=2) + c[k] = -c[k]; + exec2c(c, fct, false); + for (size_t k=0, kc=N-1; k<kc; ++k, --kc) + swap(c[k], c[kc]); + if (ortho) c[0]/=sqrt2; + } + + template<typename T> POCKETFFT_NOINLINE void exec3s(T c[], T0 fct, + bool ortho) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=length(); + if (ortho) c[0]*=sqrt2; + size_t NS2 = N/2; + for (size_t k=0, kc=N-1; k<NS2; ++k, --kc) + swap(c[k], c[kc]); + exec3c(c, fct, false); + for (size_t k=1; k<N; k+=2) + c[k] = -c[k]; + } + + public: + POCKETFFT_NOINLINE T_dcst23(size_t length) + : fftplan(length), twiddle(length) + { + constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); + for (size_t i=0; i<length; ++i) + twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length))); + } + + template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int type, bool cosine) const + { + if (type==2) + cosine ? exec2c(c, fct, ortho) : exec2s(c, fct, ortho); + else + cosine ? exec3c(c, fct, ortho) : exec3s(c, fct, ortho); + } + size_t length() const { return fftplan.length(); } }; -template<typename T0> class T_dct4 +template<typename T0> class T_dcst4 { // even length algorithm from // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/ @@ -2430,23 +2480,7 @@ template<typename T0> class T_dct4 unique_ptr<pocketfft_r<T0>> rfft; arr<cmplx<T0>> C2; - public: - POCKETFFT_NOINLINE T_dct4(size_t length) - : N(length), - fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)), - rfft((N&1)? new pocketfft_r<T0>(N) : nullptr), - C2((N&1) ? 0 : N/2) - { - constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); - if ((N&1)==0) - for (size_t i=0; i<N/2; ++i) - { - T0 ang = -pi/T0(N)*(T0(i)+T0(0.125)); - C2[i].Set(cos(ang), sin(ang)); - } - } - - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const + template<typename T> POCKETFFT_NOINLINE void exec4c(T c[], T0 fct) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); if (N&1) @@ -2512,106 +2546,38 @@ template<typename T0> class T_dct4 } } - size_t length() const { return N; } - }; - -template<typename T0> class T_dst1 - { - private: - pocketfft_r<T0> fftplan; - - public: - POCKETFFT_NOINLINE T_dst1(size_t length) - : fftplan(2*(length+1)) {} - - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const + template<typename T> POCKETFFT_NOINLINE void exec4s(T c[], T0 fct) const { - size_t N=fftplan.length(), n=N/2-1; - arr<T> tmp(N); - tmp[0] = tmp[n+1] = c[0]*0; - for (size_t i=0; i<n; ++i) - { - tmp[i+1] = c[i]; - tmp[N-1-i] = -c[i]; - } - fftplan.forward(tmp.data(), fct); - for (size_t i=0; i<n; ++i) - c[i] = -tmp[2*i+2]; - } - - size_t length() const { return fftplan.length()/2-1; } - }; - -template<typename T0> class T_dst2 - { - private: - T_dct2<T0> dct; - - public: - POCKETFFT_NOINLINE T_dst2(size_t length) - : dct(length) {} - - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const - { - constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); - for (size_t k=1; k<N; k+=2) - c[k] = -c[k]; - dct.exec(c, fct, false); - for (size_t k=0, kc=N-1; k<kc; ++k, --kc) - swap(c[k], c[kc]); - if (ortho) c[0]/=sqrt2; - } - - size_t length() const { return dct.length(); } - }; - -template<typename T0> class T_dst3 - { - private: - T_dct3<T0> dct; - - public: - POCKETFFT_NOINLINE T_dst3(size_t length) - : dct(length) {} - - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) - { - constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); - size_t N=length(); - if (ortho) c[0]*=sqrt2; size_t NS2 = N/2; for (size_t k=0, kc=N-1; k<NS2; ++k, --kc) swap(c[k], c[kc]); - dct.exec(c, fct, false); + exec4c(c, fct); for (size_t k=1; k<N; k+=2) c[k] = -c[k]; } - size_t length() const { return dct.length(); } - }; - -template<typename T0> class T_dst4 - { - private: - T_dct4<T0> dct; - public: - POCKETFFT_NOINLINE T_dst4(size_t length) - : dct(length) {} - - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) + POCKETFFT_NOINLINE T_dcst4(size_t length) + : N(length), + fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)), + rfft((N&1)? new pocketfft_r<T0>(N) : nullptr), + C2((N&1) ? 0 : N/2) { - size_t N=length(); - size_t NS2 = N/2; - for (size_t k=0, kc=N-1; k<NS2; ++k, --kc) - swap(c[k], c[kc]); - dct.exec(c, fct, false); - for (size_t k=1; k<N; k+=2) - c[k] = -c[k]; + constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); + if ((N&1)==0) + for (size_t i=0; i<N/2; ++i) + { + T0 ang = -pi/T0(N)*(T0(i)+T0(0.125)); + C2[i].Set(cos(ang), sin(ang)); + } } - size_t length() const { return dct.length(); } + template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool cosine) const + { cosine ? exec4c(c, fct) : exec4s(c, fct); } + + size_t length() const { return N; } }; @@ -3060,7 +3026,7 @@ template<typename T> POCKETFFT_NOINLINE void general_hartley( template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst( const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, - T fct, bool ortho, size_t POCKETFFT_NTHREADS) + T fct, bool ortho, int type, bool cosine, size_t POCKETFFT_NTHREADS) { shared_ptr<Trafo> plan; @@ -3088,7 +3054,7 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst( for (size_t i=0; i<len; ++i) for (size_t j=0; j<vlen; ++j) tdatav[i][j] = tin[it.iofs(j,i)]; - plan->exec(tdatav, fct, ortho); + plan->exec(tdatav, fct, ortho, type, cosine); for (size_t i=0; i<len; ++i) for (size_t j=0; j<vlen; ++j) out[it.oofs(j,i)] = tdatav[i][j]; @@ -3099,18 +3065,18 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst( it.advance(1); auto tdata = reinterpret_cast<T *>(storage.data()); if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place - plan->exec(&out[it.oofs(0)], fct, ortho); + plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine); else if (it.stride_out()==sizeof(T)) // compute FFT in output location { for (size_t i=0; i<len; ++i) out[it.oofs(i)] = tin[it.iofs(i)]; - plan->exec(&out[it.oofs(0)], fct, ortho); + plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine); } else { for (size_t i=0; i<len; ++i) tdata[i] = tin[it.iofs(i)]; - plan->exec(tdata, fct, ortho); + plan->exec(tdata, fct, ortho, type, cosine); for (size_t i=0; i<len; ++i) out[it.oofs(i)] = tdata[i]; } @@ -3374,13 +3340,13 @@ template<typename T> void dct(const shape_t &shape, cndarr<T> ain(data_in, shape, stride_in); ndarr<T> aout(data_out, shape, stride_out); if (type==1) - general_dcst<T_dct1<T>>(ain, aout, axes, fct, ortho, nthreads); + general_dcst<T_dct1<T>>(ain, aout, axes, fct, ortho, type, true, nthreads); else if (type==2) - general_dcst<T_dct2<T>>(ain, aout, axes, fct, ortho, nthreads); + general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, true, nthreads); else if (type==3) - general_dcst<T_dct3<T>>(ain, aout, axes, fct, ortho, nthreads); + general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, true, nthreads); else if (type==4) - general_dcst<T_dct4<T>>(ain, aout, axes, fct, ortho, nthreads); + general_dcst<T_dcst4<T>>(ain, aout, axes, fct, ortho, type, true, nthreads); else throw runtime_error("unsupported DCT type"); } @@ -3395,13 +3361,13 @@ template<typename T> void dst(const shape_t &shape, cndarr<T> ain(data_in, shape, stride_in); ndarr<T> aout(data_out, shape, stride_out); if (type==1) - general_dcst<T_dst1<T>>(ain, aout, axes, fct, ortho, nthreads); + general_dcst<T_dst1<T>>(ain, aout, axes, fct, ortho, type, false, nthreads); else if (type==2) - general_dcst<T_dst2<T>>(ain, aout, axes, fct, ortho, nthreads); + general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, false, nthreads); else if (type==3) - general_dcst<T_dst3<T>>(ain, aout, axes, fct, ortho, nthreads); + general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, false, nthreads); else if (type==4) - general_dcst<T_dst4<T>>(ain, aout, axes, fct, ortho, nthreads); + general_dcst<T_dcst4<T>>(ain, aout, axes, fct, ortho, type, false, nthreads); else throw runtime_error("unsupported DST type"); }