diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 8d144c81cb734404dcb19fa731dad6e567c15617..fcedf151832c024726d4a9b4156947ed308d9bf3 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -2296,9 +2296,12 @@ template<typename T0> class T_dct1 POCKETFFT_NOINLINE T_dct1(size_t length) : fftplan(2*(length-1)) {} - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const + template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=fftplan.length(), n=N/2+1; + if (ortho) + { c[0]*=sqrt2; c[n-1]*=sqrt2; } arr<T> tmp(N); tmp[0] = c[0]; for (size_t i=1; i<n; ++i) @@ -2307,6 +2310,8 @@ template<typename T0> class T_dct1 c[0] = tmp[0]; for (size_t i=1; i<n; ++i) c[i] = tmp[2*i-1]; + if (ortho) + { c[0]/=sqrt2; c[n-1]/=sqrt2; } } size_t length() const { return fftplan.length()/2+1; } @@ -2327,41 +2332,45 @@ template<typename T0> class T_dct2 twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length))); } - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const + template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); - if (N==1) { c[0]*=2*fct; return; } - if (N==2) + if (N==1) + c[0]*=2*fct; + else if (N==2) { T x1 = 2*fct*(c[0]+c[1]); c[1] = sqrt2*fct*(c[0]-c[1]); c[0] = x1; - return; - } - size_t NS2 = (N+1)/2; - for (size_t i=2; i<N; i+=2) - { - T xim1 = T0(0.5)*(c[i-1]+c[i]); - c[i] = T0(0.5)*(c[i]-c[i-1]); - c[i-1] = xim1; - } - fftplan.backward(c, fct); - for (size_t k=1, kc=N-1; k<NS2; ++k, --kc) - { - T tmp = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k]; - c[kc] = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc]; - c[k] = tmp; } - if ((N&1)==0) - c[NS2] = twiddle[NS2-1]*(c[NS2]+c[NS2]); - for (size_t k=1, kc=N-1; k<NS2; ++k, --kc) + else { - T tmp = c[k]+c[kc]; - c[kc] = c[k]-c[kc]; - c[k] = tmp; + size_t NS2 = (N+1)/2; + for (size_t i=2; i<N; i+=2) + { + T xim1 = T0(0.5)*(c[i-1]+c[i]); + c[i] = T0(0.5)*(c[i]-c[i-1]); + c[i-1] = xim1; + } + fftplan.backward(c, fct); + for (size_t k=1, kc=N-1; k<NS2; ++k, --kc) + { + T tmp = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k]; + c[kc] = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc]; + c[k] = tmp; + } + if ((N&1)==0) + c[NS2] = twiddle[NS2-1]*(c[NS2]+c[NS2]); + for (size_t k=1, kc=N-1; k<NS2; ++k, --kc) + { + T tmp = c[k]+c[kc]; + c[kc] = c[k]-c[kc]; + c[k] = tmp; + } + c[0] *= 2; } - c[0] *= 2; + if (ortho) c[0]/=sqrt2; } size_t length() const { return fftplan.length(); } @@ -2382,41 +2391,45 @@ template<typename T0> class T_dct3 twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length))); } - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const + template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); - if (N==1) { c[0]*=fct; return; } - if (N==2) + if (ortho) c[0]*=sqrt2; + if (N==1) + c[0]*=fct; + else if (N==2) { T TSQX = sqrt2*c[1]; c[1] = fct*(c[0]-TSQX); c[0] = fct*(c[0]+TSQX); - return; - } - size_t NS2 = (N+1)/2; - for (size_t k=1, kc=N-1; k<NS2; ++k, --kc) - { - T tmp = c[k]-c[kc]; - c[k] = c[k]+c[kc]; - c[kc] = tmp; } - if ((N&1)==0) - c[NS2] = c[NS2]+c[NS2]; - for (size_t k=1, kc=N-1; k<NS2; ++k, --kc) - { - T tmp = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc]; - c[k] = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k]; - c[kc] = tmp; - } - if ((N&1)==0) - c[NS2] = twiddle[NS2-1]*c[NS2]; - fftplan.forward(c, fct); - for (size_t i=2; i<N; i+=2) + else { - T xim1 = c[i-1]-c[i]; - c[i] += c[i-1]; - c[i-1] = xim1; + size_t NS2 = (N+1)/2; + for (size_t k=1, kc=N-1; k<NS2; ++k, --kc) + { + T tmp = c[k]-c[kc]; + c[k] = c[k]+c[kc]; + c[kc] = tmp; + } + if ((N&1)==0) + c[NS2] = c[NS2]+c[NS2]; + for (size_t k=1, kc=N-1; k<NS2; ++k, --kc) + { + T tmp = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc]; + c[k] = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k]; + c[kc] = tmp; + } + if ((N&1)==0) + c[NS2] = twiddle[NS2-1]*c[NS2]; + fftplan.forward(c, fct); + for (size_t i=2; i<N; i+=2) + { + T xim1 = c[i-1]-c[i]; + c[i] += c[i-1]; + c[i-1] = xim1; + } } } @@ -2455,7 +2468,7 @@ template<typename T0> class T_dct4 } } - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const + template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const { // FIXME: due to the recurrence below, the algorithm is not very accurate // Trying to find a better one... @@ -2463,7 +2476,7 @@ template<typename T0> class T_dct4 { for (size_t i=0; i<N; ++i) c[i] *= C[i]; - dct2.exec(c, fct); + dct2.exec(c, fct, false); for (size_t i=1; i<N; ++i) c[i] = 2*c[i] - c[i-1]; } @@ -2498,7 +2511,7 @@ template<typename T0> class T_dst1 POCKETFFT_NOINLINE T_dst1(size_t length) : fftplan(2*(length+1)) {} - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const + template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const { size_t N=fftplan.length(), n=N/2-1; arr<T> tmp(N); @@ -2525,15 +2538,21 @@ template<typename T0> class T_dst2 POCKETFFT_NOINLINE T_dst2(size_t length) : dct(length) {} - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) const + template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); - if (N==1) { c[0]*=2*fct; return; } - for (size_t k=1; k<N; k+=2) - c[k] = -c[k]; - dct.exec(c, fct); - for (size_t k=0, kc=N-1; k<kc; ++k, --kc) - swap(c[k], c[kc]); + if (N==1) + c[0]*=2*fct; + else + { + for (size_t k=1; k<N; k+=2) + c[k] = -c[k]; + dct.exec(c, fct, false); + for (size_t k=0, kc=N-1; k<kc; ++k, --kc) + swap(c[k], c[kc]); + } + if (ortho) c[0]/=sqrt2; } size_t length() const { return dct.length(); } @@ -2548,16 +2567,22 @@ template<typename T0> class T_dst3 POCKETFFT_NOINLINE T_dst3(size_t length) : dct(length) {} - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) + template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); - if (N==1) { c[0]*=fct; return; } - size_t NS2 = N/2; - for (size_t k=0, kc=N-1; k<NS2; ++k, --kc) - swap(c[k], c[kc]); - dct.exec(c, fct); - for (size_t k=1; k<N; k+=2) - c[k] = -c[k]; + if (ortho) c[0]*=sqrt2; + if (N==1) + c[0]*=fct; + else + { + size_t NS2 = N/2; + for (size_t k=0, kc=N-1; k<NS2; ++k, --kc) + swap(c[k], c[kc]); + dct.exec(c, fct, false); + for (size_t k=1; k<N; k+=2) + c[k] = -c[k]; + } } size_t length() const { return dct.length(); } @@ -2572,14 +2597,14 @@ template<typename T0> class T_dst4 POCKETFFT_NOINLINE T_dst4(size_t length) : dct(length) {} - template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct) + template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) { size_t N=length(); //if (N==1) { c[0]*=fct; return; } size_t NS2 = N/2; for (size_t k=0, kc=N-1; k<NS2; ++k, --kc) swap(c[k], c[kc]); - dct.exec(c, fct); + dct.exec(c, fct, false); for (size_t k=1; k<N; k+=2) c[k] = -c[k]; } @@ -3033,7 +3058,7 @@ template<typename T> POCKETFFT_NOINLINE void general_hartley( template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst( const cndarr<T> &in, ndarr<T> &out, const shape_t &axes, - T fct, size_t POCKETFFT_NTHREADS) + T fct, bool ortho, size_t POCKETFFT_NTHREADS) { shared_ptr<Trafo> plan; @@ -3061,7 +3086,7 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst( for (size_t i=0; i<len; ++i) for (size_t j=0; j<vlen; ++j) tdatav[i][j] = tin[it.iofs(j,i)]; - plan->exec(tdatav, fct); + plan->exec(tdatav, fct, ortho); for (size_t i=0; i<len; ++i) for (size_t j=0; j<vlen; ++j) out[it.oofs(j,i)] = tdatav[i][j]; @@ -3072,18 +3097,18 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst( it.advance(1); auto tdata = reinterpret_cast<T *>(storage.data()); if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place - plan->exec(&out[it.oofs(0)], fct); + plan->exec(&out[it.oofs(0)], fct, ortho); else if (it.stride_out()==sizeof(T)) // compute FFT in output location { for (size_t i=0; i<len; ++i) out[it.oofs(i)] = tin[it.iofs(i)]; - plan->exec(&out[it.oofs(0)], fct); + plan->exec(&out[it.oofs(0)], fct, ortho); } else { for (size_t i=0; i<len; ++i) tdata[i] = tin[it.iofs(i)]; - plan->exec(tdata, fct); + plan->exec(tdata, fct, ortho); for (size_t i=0; i<len; ++i) out[it.oofs(i)] = tdata[i]; } @@ -3339,7 +3364,7 @@ template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in, template<typename T> void dct(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, - int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) { if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type"); if (util::prod(shape)==0) return; @@ -3347,20 +3372,20 @@ template<typename T> void dct(const shape_t &shape, cndarr<T> ain(data_in, shape, stride_in); ndarr<T> aout(data_out, shape, stride_out); if (type==1) - general_dcst<T_dct1<T>>(ain, aout, axes, fct, nthreads); + general_dcst<T_dct1<T>>(ain, aout, axes, fct, ortho, nthreads); else if (type==2) - general_dcst<T_dct2<T>>(ain, aout, axes, fct, nthreads); + general_dcst<T_dct2<T>>(ain, aout, axes, fct, ortho, nthreads); else if (type==3) - general_dcst<T_dct3<T>>(ain, aout, axes, fct, nthreads); + general_dcst<T_dct3<T>>(ain, aout, axes, fct, ortho, nthreads); else if (type==4) - general_dcst<T_dct4<T>>(ain, aout, axes, fct, nthreads); + general_dcst<T_dct4<T>>(ain, aout, axes, fct, ortho, nthreads); else throw runtime_error("unsupported DCT type"); } template<typename T> void dst(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, - int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) { if ((type<1) || (type>4)) throw invalid_argument("invalid DST type"); if (util::prod(shape)==0) return; @@ -3368,13 +3393,13 @@ template<typename T> void dst(const shape_t &shape, cndarr<T> ain(data_in, shape, stride_in); ndarr<T> aout(data_out, shape, stride_out); if (type==1) - general_dcst<T_dst1<T>>(ain, aout, axes, fct, nthreads); + general_dcst<T_dst1<T>>(ain, aout, axes, fct, ortho, nthreads); else if (type==2) - general_dcst<T_dst2<T>>(ain, aout, axes, fct, nthreads); + general_dcst<T_dst2<T>>(ain, aout, axes, fct, ortho, nthreads); else if (type==3) - general_dcst<T_dst3<T>>(ain, aout, axes, fct, nthreads); + general_dcst<T_dst3<T>>(ain, aout, axes, fct, ortho, nthreads); else if (type==4) - general_dcst<T_dst4<T>>(ain, aout, axes, fct, nthreads); + general_dcst<T_dst4<T>>(ain, aout, axes, fct, ortho, nthreads); else throw runtime_error("unsupported DST type"); } diff --git a/pypocketfft.cc b/pypocketfft.cc index 35102e32c742302238c0a27b87b6ef74d42f872f..c6d72bf7a3cd9c345dd25803ec1c8251c030fa14 100644 --- a/pypocketfft.cc +++ b/pypocketfft.cc @@ -227,7 +227,7 @@ py::array r2r_fftpack(const py::array &in, const py::object &axes_, } template<typename T> py::array dct_internal(const py::array &in, - const py::object &axes_, int type, int inorm, py::object &out_, + const py::object &axes_, int type, int inorm, bool ortho, py::object &out_, size_t nthreads) { auto axes = makeaxes(in, axes_); @@ -239,24 +239,26 @@ template<typename T> py::array dct_internal(const py::array &in, auto d_out=reinterpret_cast<T *>(res.mutable_data()); { py::gil_scoped_release release; + // override + if (ortho) inorm=1; T fct = (type==1) ? norm_fct<T>(inorm, dims, axes, 2, -1) : norm_fct<T>(inorm, dims, axes, 2); - pocketfft::dct(dims, s_in, s_out, axes, type, d_in, d_out, fct, + pocketfft::dct(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho, nthreads); } return res; } py::array dct(const py::array &in, int type, const py::object &axes_, - int inorm, py::object &out_, size_t nthreads) + int inorm, bool ortho, py::object &out_, size_t nthreads) { if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type"); - DISPATCH(in, f64, f32, flong, dct_internal, (in, axes_, type, inorm, + DISPATCH(in, f64, f32, flong, dct_internal, (in, axes_, type, inorm, ortho, out_, nthreads)) } template<typename T> py::array dst_internal(const py::array &in, - const py::object &axes_, int type, int inorm, py::object &out_, + const py::object &axes_, int type, int inorm, bool ortho, py::object &out_, size_t nthreads) { auto axes = makeaxes(in, axes_); @@ -268,19 +270,21 @@ template<typename T> py::array dst_internal(const py::array &in, auto d_out=reinterpret_cast<T *>(res.mutable_data()); { py::gil_scoped_release release; + // override + if (ortho) inorm=1; T fct = (type==1) ? norm_fct<T>(inorm, dims, axes, 2, 1) : norm_fct<T>(inorm, dims, axes, 2); - pocketfft::dst(dims, s_in, s_out, axes, type, d_in, d_out, fct, + pocketfft::dst(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho, nthreads); } return res; } py::array dst(const py::array &in, int type, const py::object &axes_, - int inorm, py::object &out_, size_t nthreads) + int inorm, bool ortho, py::object &out_, size_t nthreads) { if ((type<1) || (type>4)) throw invalid_argument("invalid DST type"); - DISPATCH(in, f64, f32, flong, dst_internal, (in, axes_, type, inorm, + DISPATCH(in, f64, f32, flong, dst_internal, (in, axes_, type, inorm, ortho, out_, nthreads)) } @@ -665,7 +669,7 @@ PYBIND11_MODULE(pypocketfft, m) m.def("genuine_hartley", genuine_hartley, genuine_hartley_DS, "a"_a, "axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); m.def("dct", dct, dct_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0, - "out"_a=None, "nthreads"_a=1); + "ortho"_a=false, "out"_a=None, "nthreads"_a=1); m.def("dst", dst, dst_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0, - "out"_a=None, "nthreads"_a=1); + "ortho"_a=false, "out"_a=None, "nthreads"_a=1); } diff --git a/test.py b/test.py index 66efb9256b7ac7409ad15859fed968847174f630..3d22f1ebb4c744b0ab2b6edfa82ae03e2242331e 100644 --- a/test.py +++ b/test.py @@ -222,6 +222,23 @@ def testdcst1D(len, inorm, type): _assert_close(a, pypocketfft.dst(pypocketfft.dst(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-15) _assert_close(b, pypocketfft.dst(pypocketfft.dst(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-7) +@pmp("len", len1D) +@pmp("type", [1, 2, 3]) +def testdcst1Dortho(len, type): + a = np.random.rand(len)-0.5 + b = a.astype(np.float32) + c = a.astype(np.float128) + itp = (0, 1, 3, 2, 4) + itype = itp[type] + if type != 1 or len > 1: + _assert_close(a, pypocketfft.dct(pypocketfft.dct(c, ortho=True, type=type), ortho=True, type=itype), 2e-18) + _assert_close(a, pypocketfft.dct(pypocketfft.dct(a, ortho=True, type=type), ortho=True, type=itype), 1.5e-15) + _assert_close(b, pypocketfft.dct(pypocketfft.dct(b, ortho=True, type=type), ortho=True, type=itype), 6e-7) + if type != 1: + _assert_close(a, pypocketfft.dst(pypocketfft.dst(c, ortho=True, type=type), ortho=True, type=itype), 2e-18) + _assert_close(a, pypocketfft.dst(pypocketfft.dst(a, ortho=True, type=type), ortho=True, type=itype), 1.5e-15) + _assert_close(b, pypocketfft.dst(pypocketfft.dst(b, ortho=True, type=type), ortho=True, type=itype), 6e-7) + # TEMPORARY: separate test for DCT/DST IV, since they are less accurate @pmp("len", len1D)