diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index eba33984df9556b44581b039dd7226507a040ff6..b74b1c653c748d72c0613506099d5a733ba12e01 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -209,13 +209,19 @@ template struct cmplx { r = tmp; return *this; } - cmplx operator+ (const cmplx &other) const - { return cmplx(r+other.r, i+other.i); } - cmplx operator- (const cmplx &other) const - { return cmplx(r-other.r, i-other.i); } + templatecmplx &operator+= (const cmplx &other) + { r+=other.r; i+=other.i; return *this; } + templatecmplx &operator-= (const cmplx &other) + { r-=other.r; i-=other.i; return *this; } template auto operator* (const T2 &other) const -> cmplx { return {r*other, i*other}; } + template auto operator+ (const cmplx &other) const + -> cmplx + { return {r+other.r, i+other.i}; } + template auto operator- (const cmplx &other) const + -> cmplx + { return {r-other.r, i-other.i}; } template auto operator* (const cmplx &other) const -> cmplx { return {r*other.r-i*other.i, r*other.i + i*other.r}; } @@ -227,8 +233,9 @@ template struct cmplx { : Tres(r*other.r-i*other.i, r*other.i+i*other.r); } }; -template void PMC(cmplx &a, cmplx &b, - const cmplx &c, const cmplx &d) +template inline void PMINPLACE(T &a, T &b) + { T t = a; a+=b; b=t-b; } +template void PMC(T &a, T &b, const T &c, const T &d) { a = c+d; b = c-d; } template cmplx conj(const cmplx &a) { return {a.r, -a.i}; } @@ -845,8 +852,6 @@ template void ROTX135(T &a) else { auto tmp_=a.r; a.r=hsqt2*(-a.r-a.i); a.i=hsqt2*(tmp_-a.i); } } -template inline void PMINPLACE(T &a, T &b) - { T t = a; a.r+=b.r; a.i+=b.i; b.r=t.r-b.r; b.i=t.i-b.i; } template void pass8 (size_t ido, size_t l1, const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, @@ -2293,6 +2298,7 @@ template class pocketfft_r // // sine/cosine transforms // + template class T_dct1 { private: @@ -2302,7 +2308,8 @@ template class T_dct1 POCKETFFT_NOINLINE T_dct1(size_t length) : fftplan(2*(length-1)) {} - template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int /*type*/, bool /*cosine*/) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=fftplan.length(), n=N/2+1; @@ -2323,126 +2330,124 @@ template class T_dct1 size_t length() const { return fftplan.length()/2+1; } }; -template class T_dct2 +template class T_dst1 { private: pocketfft_r fftplan; - vector twiddle; public: - POCKETFFT_NOINLINE T_dct2(size_t length) - : fftplan(length), twiddle(length) - { - constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); - for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool /*cosine*/) const { - constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); - size_t N=length(); - if (N==1) - c[0]*=2*fct; - else if (N==2) - { - T x1 = 2*fct*(c[0]+c[1]); - c[1] = sqrt2*fct*(c[0]-c[1]); - c[0] = x1; - } - else + size_t N=fftplan.length(), n=N/2-1; + arr tmp(N); + tmp[0] = tmp[n+1] = c[0]*0; + for (size_t i=0; i class T_dct3 +template class T_dcst23 { private: pocketfft_r fftplan; vector twiddle; - public: - POCKETFFT_NOINLINE T_dct3(size_t length) - : fftplan(length), twiddle(length) + template POCKETFFT_NOINLINE void exec2(T c[], T0 fct, + bool ortho, bool cosine) const { - constexpr T0 pi = T0(3.141592653589793238462643383279502884197L); - for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const + template POCKETFFT_NOINLINE void exec3(T c[], T0 fct, + bool ortho, bool cosine) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); if (ortho) c[0]*=sqrt2; - if (N==1) - c[0]*=fct; - else if (N==2) + size_t NS2 = (N+1)/2; + if (!cosine) + for (size_t k=0, kc=N-1; k POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int type, bool cosine) const + { type==2 ? exec2(c, fct, ortho, cosine) : exec3(c, fct, ortho, cosine); } + size_t length() const { return fftplan.length(); } }; -template class T_dct4 +template class T_dcst4 { // even length algorithm from // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/ @@ -2453,7 +2458,7 @@ template class T_dct4 arr> C2; public: - POCKETFFT_NOINLINE T_dct4(size_t length) + POCKETFFT_NOINLINE T_dcst4(size_t length) : N(length), fft((N&1) ? nullptr : new pocketfft_c(N/2)), rfft((N&1)? new pocketfft_r(N) : nullptr), @@ -2468,9 +2473,14 @@ template class T_dct4 } } - template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool cosine) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t n2 = N/2; + if (!cosine) + for (size_t k=0, kc=N-1; k class T_dct4 auto SGN_SET = [](T x, size_t i) {return (i%2) ? -x : x;}; arr y(N); - size_t n2 = N/2; - size_t i; { - size_t m; - for (i=0, m=n2; m class T_dct4 y[i] = c[m]; } rfft->forward(y.data(), fct); - for (i=0; i+i+1 class T_dct4 c[i] = sqrt2 * (SGN_SET(cx, (i+1)/2) + SGN_SET(sx, i/2)); c[N-(i+1)] = sqrt2 * (SGN_SET(cx, (i+2)/2) + SGN_SET(sx, (i+1)/2)); } + } c[n2] = sqrt2 * SGN_SET(y[0], (n2+1)/2); // FFTW-derived code ends here } else { - arr> y(N/2); - for(size_t i=0; i> y(n2); + for(size_t i=0; iforward(y.data(), fct); - for(size_t i=0; i class T_dst1 - { - private: - pocketfft_r fftplan; - - public: - POCKETFFT_NOINLINE T_dst1(size_t length) - : fftplan(2*(length+1)) {} - - template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const - { - size_t N=fftplan.length(), n=N/2-1; - arr tmp(N); - tmp[0] = tmp[n+1] = c[0]*0; - for (size_t i=0; i class T_dst2 - { - private: - T_dct2 dct; - - public: - POCKETFFT_NOINLINE T_dst2(size_t length) - : dct(length) {} - - template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const - { - constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); - size_t N=length(); - if (N==1) - c[0]*=2*fct; - else - { + if (!cosine) for (size_t k=1; k class T_dst3 - { - private: - T_dct3 dct; - - public: - POCKETFFT_NOINLINE T_dst3(size_t length) - : dct(length) {} - - template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) - { - constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); - size_t N=length(); - if (ortho) c[0]*=sqrt2; - if (N==1) - c[0]*=fct; - else - { - size_t NS2 = N/2; - for (size_t k=0, kc=N-1; k class T_dst4 - { - private: - T_dct4 dct; - - public: - POCKETFFT_NOINLINE T_dst4(size_t length) - : dct(length) {} - - template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) - { - size_t N=length(); - //if (N==1) { c[0]*=fct; return; } - size_t NS2 = N/2; - for (size_t k=0, kc=N-1; k POCKETFFT_NOINLINE void general_hartley( template POCKETFFT_NOINLINE void general_dcst( const cndarr &in, ndarr &out, const shape_t &axes, - T fct, bool ortho, size_t POCKETFFT_NTHREADS) + T fct, bool ortho, int type, bool cosine, size_t POCKETFFT_NTHREADS) { shared_ptr plan; @@ -3121,7 +3025,7 @@ template POCKETFFT_NOINLINE void general_dcst( for (size_t i=0; iexec(tdatav, fct, ortho); + plan->exec(tdatav, fct, ortho, type, cosine); for (size_t i=0; i POCKETFFT_NOINLINE void general_dcst( it.advance(1); auto tdata = reinterpret_cast(storage.data()); if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place - plan->exec(&out[it.oofs(0)], fct, ortho); + plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine); else if (it.stride_out()==sizeof(T)) // compute FFT in output location { for (size_t i=0; iexec(&out[it.oofs(0)], fct, ortho); + plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine); } else { for (size_t i=0; iexec(tdata, fct, ortho); + plan->exec(tdata, fct, ortho, type, cosine); for (size_t i=0; i void dct(const shape_t &shape, cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); if (type==1) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); - else if (type==2) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); - else if (type==3) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, type, true, nthreads); else if (type==4) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, type, true, nthreads); else - throw runtime_error("unsupported DCT type"); + general_dcst>(ain, aout, axes, fct, ortho, type, true, nthreads); } template void dst(const shape_t &shape, @@ -3428,15 +3328,11 @@ template void dst(const shape_t &shape, cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); if (type==1) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); - else if (type==2) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); - else if (type==3) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, type, false, nthreads); else if (type==4) - general_dcst>(ain, aout, axes, fct, ortho, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, type, false, nthreads); else - throw runtime_error("unsupported DST type"); + general_dcst>(ain, aout, axes, fct, ortho, type, false, nthreads); } template void r2c(const shape_t &shape_in,