diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index b875db83f9599ef4cb4fa58a96f9e2791e2cc7f2..ac2ac7f85dd449bfb1aaa2ec408b522f1fd7b3d3 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -2293,6 +2293,7 @@ template class pocketfft_r // // sine/cosine transforms // + template class T_dct1 { private: @@ -2302,7 +2303,8 @@ template class T_dct1 POCKETFFT_NOINLINE T_dct1(size_t length) : fftplan(2*(length-1)) {} - template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int /*type*/, bool /*cosine*/) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=fftplan.length(), n=N/2+1; @@ -2323,22 +2325,42 @@ template class T_dct1 size_t length() const { return fftplan.length()/2+1; } }; -template class T_dct2 +template class T_dst1 { private: pocketfft_r fftplan; - vector twiddle; public: - POCKETFFT_NOINLINE T_dct2(size_t length) - : fftplan(length), twiddle(length) + POCKETFFT_NOINLINE T_dst1(size_t length) + : fftplan(2*(length+1)) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool /*cosine*/) const { - constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); - for (size_t i=0; i tmp(N); + tmp[0] = tmp[n+1] = c[0]*0; + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const + size_t length() const { return fftplan.length()/2-1; } + }; + +template class T_dcst23 + { + private: + pocketfft_r fftplan; + vector twiddle; + + template POCKETFFT_NOINLINE void exec2c(T c[], T0 fct, + bool ortho) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); @@ -2368,25 +2390,8 @@ template class T_dct2 if (ortho) c[0]/=sqrt2; } - size_t length() const { return fftplan.length(); } - }; - -template class T_dct3 - { - private: - pocketfft_r fftplan; - vector twiddle; - - public: - POCKETFFT_NOINLINE T_dct3(size_t length) - : fftplan(length), twiddle(length) - { - constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); - for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const + template POCKETFFT_NOINLINE void exec3c(T c[], T0 fct, + bool ortho) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); @@ -2417,10 +2422,55 @@ template class T_dct3 } } + template POCKETFFT_NOINLINE void exec2s(T c[], T0 fct, + bool ortho) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=length(); + for (size_t k=1; k POCKETFFT_NOINLINE void exec3s(T c[], T0 fct, + bool ortho) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=length(); + if (ortho) c[0]*=sqrt2; + size_t NS2 = N/2; + for (size_t k=0, kc=N-1; k POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int type, bool cosine) const + { + if (type==2) + cosine ? exec2c(c, fct, ortho) : exec2s(c, fct, ortho); + else + cosine ? exec3c(c, fct, ortho) : exec3s(c, fct, ortho); + } + size_t length() const { return fftplan.length(); } }; -template class T_dct4 +template class T_dcst4 { // even length algorithm from // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/ @@ -2430,23 +2480,7 @@ template class T_dct4 unique_ptr> rfft; arr> C2; - public: - POCKETFFT_NOINLINE T_dct4(size_t length) - : N(length), - fft((N&1) ? nullptr : new pocketfft_c(N/2)), - rfft((N&1)? new pocketfft_r(N) : nullptr), - C2((N&1) ? 0 : N/2) - { - constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); - if ((N&1)==0) - for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const + template POCKETFFT_NOINLINE void exec4c(T c[], T0 fct) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); if (N&1) @@ -2512,106 +2546,38 @@ template class T_dct4 } } - size_t length() const { return N; } - }; - -template class T_dst1 - { - private: - pocketfft_r fftplan; - - public: - POCKETFFT_NOINLINE T_dst1(size_t length) - : fftplan(2*(length+1)) {} - - template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const + template POCKETFFT_NOINLINE void exec4s(T c[], T0 fct) const { - size_t N=fftplan.length(), n=N/2-1; - arr tmp(N); - tmp[0] = tmp[n+1] = c[0]*0; - for (size_t i=0; i class T_dst2 - { - private: - T_dct2 dct; - - public: - POCKETFFT_NOINLINE T_dst2(size_t length) - : dct(length) {} - - template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const - { - constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); - for (size_t k=1; k class T_dst3 - { - private: - T_dct3 dct; - - public: - POCKETFFT_NOINLINE T_dst3(size_t length) - : dct(length) {} - - template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) - { - constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); - size_t N=length(); - if (ortho) c[0]*=sqrt2; size_t NS2 = N/2; for (size_t k=0, kc=N-1; k class T_dst4 - { - private: - T_dct4 dct; - public: - POCKETFFT_NOINLINE T_dst4(size_t length) - : dct(length) {} - - template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) + POCKETFFT_NOINLINE T_dcst4(size_t length) + : N(length), + fft((N&1) ? nullptr : new pocketfft_c(N/2)), + rfft((N&1)? new pocketfft_r(N) : nullptr), + C2((N&1) ? 0 : N/2) { - size_t N=length(); - size_t NS2 = N/2; - for (size_t k=0, kc=N-1; k POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool cosine) const + { cosine ? exec4c(c, fct) : exec4s(c, fct); } + + size_t length() const { return N; } }; @@ -3060,7 +3026,7 @@ template POCKETFFT_NOINLINE void general_hartley( template POCKETFFT_NOINLINE void general_dcst( const cndarr &in, ndarr &out, const shape_t &axes, - T fct, bool ortho, size_t POCKETFFT_NTHREADS) + T fct, bool ortho, int type, bool cosine, size_t POCKETFFT_NTHREADS) { shared_ptr plan; @@ -3088,7 +3054,7 @@ template POCKETFFT_NOINLINE void general_dcst( for (size_t i=0; iexec(tdatav, fct, ortho); + plan->exec(tdatav, fct, ortho, type, cosine); for (size_t i=0; i POCKETFFT_NOINLINE void general_dcst( it.advance(1); auto tdata = reinterpret_cast(storage.data()); if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place - plan->exec(&out[it.oofs(0)], fct, ortho); + plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine); else if (it.stride_out()==sizeof(T)) // compute FFT in output location { for (size_t i=0; iexec(&out[it.oofs(0)], fct, ortho); + plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine); } else { for (size_t i=0; iexec(tdata, fct, ortho); + plan->exec(tdata, fct, ortho, type, cosine); for (size_t i=0; i void dct(const shape_t &shape, cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); if (type==1) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, type, true, nthreads); else if (type==2) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, type, true, nthreads); else if (type==3) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, type, true, nthreads); else if (type==4) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, type, true, nthreads); else throw runtime_error("unsupported DCT type"); } @@ -3395,13 +3361,13 @@ template void dst(const shape_t &shape, cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); if (type==1) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, type, false, nthreads); else if (type==2) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, type, false, nthreads); else if (type==3) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, type, false, nthreads); else if (type==4) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, type, false, nthreads); else throw runtime_error("unsupported DST type"); }