diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 8d144c81cb734404dcb19fa731dad6e567c15617..fcedf151832c024726d4a9b4156947ed308d9bf3 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -2296,9 +2296,12 @@ template class T_dct1 POCKETFFT_NOINLINE T_dct1(size_t length) : fftplan(2*(length-1)) {} - template POCKETFFT_NOINLINE void exec(T c[], T0 fct) const + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=fftplan.length(), n=N/2+1; + if (ortho) + { c[0]*=sqrt2; c[n-1]*=sqrt2; } arr tmp(N); tmp[0] = c[0]; for (size_t i=1; i class T_dct1 c[0] = tmp[0]; for (size_t i=1; i class T_dct2 twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length))); } - template POCKETFFT_NOINLINE void exec(T c[], T0 fct) const + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); - if (N==1) { c[0]*=2*fct; return; } - if (N==2) + if (N==1) + c[0]*=2*fct; + else if (N==2) { T x1 = 2*fct*(c[0]+c[1]); c[1] = sqrt2*fct*(c[0]-c[1]); c[0] = x1; - return; - } - size_t NS2 = (N+1)/2; - for (size_t i=2; i class T_dct3 twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length))); } - template POCKETFFT_NOINLINE void exec(T c[], T0 fct) const + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const { constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); - if (N==1) { c[0]*=fct; return; } - if (N==2) + if (ortho) c[0]*=sqrt2; + if (N==1) + c[0]*=fct; + else if (N==2) { T TSQX = sqrt2*c[1]; c[1] = fct*(c[0]-TSQX); c[0] = fct*(c[0]+TSQX); - return; - } - size_t NS2 = (N+1)/2; - for (size_t k=1, kc=N-1; k class T_dct4 } } - template POCKETFFT_NOINLINE void exec(T c[], T0 fct) const + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const { // FIXME: due to the recurrence below, the algorithm is not very accurate // Trying to find a better one... @@ -2463,7 +2476,7 @@ template class T_dct4 { for (size_t i=0; i class T_dst1 POCKETFFT_NOINLINE T_dst1(size_t length) : fftplan(2*(length+1)) {} - template POCKETFFT_NOINLINE void exec(T c[], T0 fct) const + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const { size_t N=fftplan.length(), n=N/2-1; arr tmp(N); @@ -2525,15 +2538,21 @@ template class T_dst2 POCKETFFT_NOINLINE T_dst2(size_t length) : dct(length) {} - template POCKETFFT_NOINLINE void exec(T c[], T0 fct) const + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); - if (N==1) { c[0]*=2*fct; return; } - for (size_t k=1; k class T_dst3 POCKETFFT_NOINLINE T_dst3(size_t length) : dct(length) {} - template POCKETFFT_NOINLINE void exec(T c[], T0 fct) + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); size_t N=length(); - if (N==1) { c[0]*=fct; return; } - size_t NS2 = N/2; - for (size_t k=0, kc=N-1; k class T_dst4 POCKETFFT_NOINLINE T_dst4(size_t length) : dct(length) {} - template POCKETFFT_NOINLINE void exec(T c[], T0 fct) + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) { size_t N=length(); //if (N==1) { c[0]*=fct; return; } size_t NS2 = N/2; for (size_t k=0, kc=N-1; k POCKETFFT_NOINLINE void general_hartley( template POCKETFFT_NOINLINE void general_dcst( const cndarr &in, ndarr &out, const shape_t &axes, - T fct, size_t POCKETFFT_NTHREADS) + T fct, bool ortho, size_t POCKETFFT_NTHREADS) { shared_ptr plan; @@ -3061,7 +3086,7 @@ template POCKETFFT_NOINLINE void general_dcst( for (size_t i=0; iexec(tdatav, fct); + plan->exec(tdatav, fct, ortho); for (size_t i=0; i POCKETFFT_NOINLINE void general_dcst( it.advance(1); auto tdata = reinterpret_cast(storage.data()); if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place - plan->exec(&out[it.oofs(0)], fct); + plan->exec(&out[it.oofs(0)], fct, ortho); else if (it.stride_out()==sizeof(T)) // compute FFT in output location { for (size_t i=0; iexec(&out[it.oofs(0)], fct); + plan->exec(&out[it.oofs(0)], fct, ortho); } else { for (size_t i=0; iexec(tdata, fct); + plan->exec(tdata, fct, ortho); for (size_t i=0; i void c2c(const shape_t &shape, const stride_t &stride_in, template void dct(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, - int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) { if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type"); if (util::prod(shape)==0) return; @@ -3347,20 +3372,20 @@ template void dct(const shape_t &shape, cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); if (type==1) - general_dcst>(ain, aout, axes, fct, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, nthreads); else if (type==2) - general_dcst>(ain, aout, axes, fct, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, nthreads); else if (type==3) - general_dcst>(ain, aout, axes, fct, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, nthreads); else if (type==4) - general_dcst>(ain, aout, axes, fct, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, nthreads); else throw runtime_error("unsupported DCT type"); } template void dst(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, - int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) { if ((type<1) || (type>4)) throw invalid_argument("invalid DST type"); if (util::prod(shape)==0) return; @@ -3368,13 +3393,13 @@ template void dst(const shape_t &shape, cndarr ain(data_in, shape, stride_in); ndarr aout(data_out, shape, stride_out); if (type==1) - general_dcst>(ain, aout, axes, fct, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, nthreads); else if (type==2) - general_dcst>(ain, aout, axes, fct, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, nthreads); else if (type==3) - general_dcst>(ain, aout, axes, fct, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, nthreads); else if (type==4) - general_dcst>(ain, aout, axes, fct, nthreads); + general_dcst>(ain, aout, axes, fct, ortho, nthreads); else throw runtime_error("unsupported DST type"); } diff --git a/pypocketfft.cc b/pypocketfft.cc index 35102e32c742302238c0a27b87b6ef74d42f872f..c6d72bf7a3cd9c345dd25803ec1c8251c030fa14 100644 --- a/pypocketfft.cc +++ b/pypocketfft.cc @@ -227,7 +227,7 @@ py::array r2r_fftpack(const py::array &in, const py::object &axes_, } template py::array dct_internal(const py::array &in, - const py::object &axes_, int type, int inorm, py::object &out_, + const py::object &axes_, int type, int inorm, bool ortho, py::object &out_, size_t nthreads) { auto axes = makeaxes(in, axes_); @@ -239,24 +239,26 @@ template py::array dct_internal(const py::array &in, auto d_out=reinterpret_cast(res.mutable_data()); { py::gil_scoped_release release; + // override + if (ortho) inorm=1; T fct = (type==1) ? norm_fct(inorm, dims, axes, 2, -1) : norm_fct(inorm, dims, axes, 2); - pocketfft::dct(dims, s_in, s_out, axes, type, d_in, d_out, fct, + pocketfft::dct(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho, nthreads); } return res; } py::array dct(const py::array &in, int type, const py::object &axes_, - int inorm, py::object &out_, size_t nthreads) + int inorm, bool ortho, py::object &out_, size_t nthreads) { if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type"); - DISPATCH(in, f64, f32, flong, dct_internal, (in, axes_, type, inorm, + DISPATCH(in, f64, f32, flong, dct_internal, (in, axes_, type, inorm, ortho, out_, nthreads)) } template py::array dst_internal(const py::array &in, - const py::object &axes_, int type, int inorm, py::object &out_, + const py::object &axes_, int type, int inorm, bool ortho, py::object &out_, size_t nthreads) { auto axes = makeaxes(in, axes_); @@ -268,19 +270,21 @@ template py::array dst_internal(const py::array &in, auto d_out=reinterpret_cast(res.mutable_data()); { py::gil_scoped_release release; + // override + if (ortho) inorm=1; T fct = (type==1) ? norm_fct(inorm, dims, axes, 2, 1) : norm_fct(inorm, dims, axes, 2); - pocketfft::dst(dims, s_in, s_out, axes, type, d_in, d_out, fct, + pocketfft::dst(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho, nthreads); } return res; } py::array dst(const py::array &in, int type, const py::object &axes_, - int inorm, py::object &out_, size_t nthreads) + int inorm, bool ortho, py::object &out_, size_t nthreads) { if ((type<1) || (type>4)) throw invalid_argument("invalid DST type"); - DISPATCH(in, f64, f32, flong, dst_internal, (in, axes_, type, inorm, + DISPATCH(in, f64, f32, flong, dst_internal, (in, axes_, type, inorm, ortho, out_, nthreads)) } @@ -665,7 +669,7 @@ PYBIND11_MODULE(pypocketfft, m) m.def("genuine_hartley", genuine_hartley, genuine_hartley_DS, "a"_a, "axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); m.def("dct", dct, dct_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0, - "out"_a=None, "nthreads"_a=1); + "ortho"_a=false, "out"_a=None, "nthreads"_a=1); m.def("dst", dst, dst_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0, - "out"_a=None, "nthreads"_a=1); + "ortho"_a=false, "out"_a=None, "nthreads"_a=1); } diff --git a/test.py b/test.py index 66efb9256b7ac7409ad15859fed968847174f630..3d22f1ebb4c744b0ab2b6edfa82ae03e2242331e 100644 --- a/test.py +++ b/test.py @@ -222,6 +222,23 @@ def testdcst1D(len, inorm, type): _assert_close(a, pypocketfft.dst(pypocketfft.dst(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-15) _assert_close(b, pypocketfft.dst(pypocketfft.dst(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-7) +@pmp("len", len1D) +@pmp("type", [1, 2, 3]) +def testdcst1Dortho(len, type): + a = np.random.rand(len)-0.5 + b = a.astype(np.float32) + c = a.astype(np.float128) + itp = (0, 1, 3, 2, 4) + itype = itp[type] + if type != 1 or len > 1: + _assert_close(a, pypocketfft.dct(pypocketfft.dct(c, ortho=True, type=type), ortho=True, type=itype), 2e-18) + _assert_close(a, pypocketfft.dct(pypocketfft.dct(a, ortho=True, type=type), ortho=True, type=itype), 1.5e-15) + _assert_close(b, pypocketfft.dct(pypocketfft.dct(b, ortho=True, type=type), ortho=True, type=itype), 6e-7) + if type != 1: + _assert_close(a, pypocketfft.dst(pypocketfft.dst(c, ortho=True, type=type), ortho=True, type=itype), 2e-18) + _assert_close(a, pypocketfft.dst(pypocketfft.dst(a, ortho=True, type=type), ortho=True, type=itype), 1.5e-15) + _assert_close(b, pypocketfft.dst(pypocketfft.dst(b, ortho=True, type=type), ortho=True, type=itype), 6e-7) + # TEMPORARY: separate test for DCT/DST IV, since they are less accurate @pmp("len", len1D)