diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 2c76ad863ba3bc8521c6f06d6d7f312001d2cbad..7515f79413ff99516511fd384148fea3c30fb8b9 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -3294,7 +3294,7 @@ template void c2c(const shape_t &shape, const stride_t &stride_in, general_c(ain, aout, axes, forward, fct, nthreads); } -template void r2r_dct(const shape_t &shape, +template void dct(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) { @@ -3315,7 +3315,7 @@ template void r2r_dct(const shape_t &shape, throw runtime_error("unsupported DCT type"); } -template void r2r_dst(const shape_t &shape, +template void dst(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, int type, const T *data_in, T *data_out, T fct, size_t nthreads=1) { @@ -3442,8 +3442,8 @@ using detail::c2r; using detail::r2c; using detail::r2r_fftpack; using detail::r2r_separable_hartley; -using detail::r2r_dct; -using detail::r2r_dst; +using detail::dct; +using detail::dst; } // namespace pocketfft diff --git a/pypocketfft.cc b/pypocketfft.cc index fa57a574b7011427b4f1ff0b84f7c7f01e304d5f..c2b4b2cf3413da93c3c3123a515ae804e96b9801 100644 --- a/pypocketfft.cc +++ b/pypocketfft.cc @@ -226,7 +226,7 @@ py::array r2r_fftpack(const py::array &in, const py::object &axes_, real2hermitian, forward, inorm, out_, nthreads)) } -template py::array r2r_dct_internal(const py::array &in, +template py::array dct_internal(const py::array &in, const py::object &axes_, int type, int inorm, py::object &out_, size_t nthreads) { @@ -241,21 +241,21 @@ template py::array r2r_dct_internal(const py::array &in, py::gil_scoped_release release; T fct = (type==1) ? norm_fct(inorm, dims, axes, 2, -1) : norm_fct(inorm, dims, axes, 2); - pocketfft::r2r_dct(dims, s_in, s_out, axes, type, d_in, d_out, fct, + pocketfft::dct(dims, s_in, s_out, axes, type, d_in, d_out, fct, nthreads); } return res; } -py::array r2r_dct(const py::array &in, int type, const py::object &axes_, +py::array dct(const py::array &in, int type, const py::object &axes_, int inorm, py::object &out_, size_t nthreads) { if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type"); - DISPATCH(in, f64, f32, flong, r2r_dct_internal, (in, axes_, type, inorm, + DISPATCH(in, f64, f32, flong, dct_internal, (in, axes_, type, inorm, out_, nthreads)) } -template py::array r2r_dst_internal(const py::array &in, +template py::array dst_internal(const py::array &in, const py::object &axes_, int type, int inorm, py::object &out_, size_t nthreads) { @@ -270,17 +270,17 @@ template py::array r2r_dst_internal(const py::array &in, py::gil_scoped_release release; T fct = (type==1) ? norm_fct(inorm, dims, axes, 2, 1) : norm_fct(inorm, dims, axes, 2); - pocketfft::r2r_dst(dims, s_in, s_out, axes, type, d_in, d_out, fct, + pocketfft::dst(dims, s_in, s_out, axes, type, d_in, d_out, fct, nthreads); } return res; } -py::array r2r_dst(const py::array &in, int type, const py::object &axes_, +py::array dst(const py::array &in, int type, const py::object &axes_, int inorm, py::object &out_, size_t nthreads) { if ((type<1) || (type>4)) throw invalid_argument("invalid DST type"); - DISPATCH(in, f64, f32, flong, r2r_dst_internal, (in, axes_, type, inorm, + DISPATCH(in, f64, f32, flong, dst_internal, (in, axes_, type, inorm, out_, nthreads)) } @@ -600,8 +600,8 @@ PYBIND11_MODULE(pypocketfft, m) "axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); m.def("genuine_hartley", genuine_hartley, genuine_hartley_DS, "a"_a, "axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); - m.def("r2r_dct", r2r_dct, /*r2r_dct_DS,*/ "a"_a, "type"_a, "axes"_a=None, + m.def("dct", dct, /*dct_DS,*/ "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); - m.def("r2r_dst", r2r_dst, /*r2r_dst_DS,*/ "a"_a, "type"_a, "axes"_a=None, + m.def("dst", dst, /*dst_DS,*/ "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); } diff --git a/test.py b/test.py index 218fbd1fcc078a5afa5a6fb3b3272e10fab7328d..66efb9256b7ac7409ad15859fed968847174f630 100644 --- a/test.py +++ b/test.py @@ -215,12 +215,12 @@ def testdcst1D(len, inorm, type): itp = (0, 1, 3, 2, 4) itype = itp[type] if type != 1 or len > 1: - _assert_close(a, pypocketfft.r2r_dct(pypocketfft.r2r_dct(c, inorm=inorm, type=type), inorm=2-inorm, type=itype), 2e-18) - _assert_close(a, pypocketfft.r2r_dct(pypocketfft.r2r_dct(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-15) - _assert_close(b, pypocketfft.r2r_dct(pypocketfft.r2r_dct(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-7) - _assert_close(a, pypocketfft.r2r_dst(pypocketfft.r2r_dst(c, inorm=inorm, type=type), inorm=2-inorm, type=itype), 2e-18) - _assert_close(a, pypocketfft.r2r_dst(pypocketfft.r2r_dst(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-15) - _assert_close(b, pypocketfft.r2r_dst(pypocketfft.r2r_dst(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-7) + _assert_close(a, pypocketfft.dct(pypocketfft.dct(c, inorm=inorm, type=type), inorm=2-inorm, type=itype), 2e-18) + _assert_close(a, pypocketfft.dct(pypocketfft.dct(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-15) + _assert_close(b, pypocketfft.dct(pypocketfft.dct(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-7) + _assert_close(a, pypocketfft.dst(pypocketfft.dst(c, inorm=inorm, type=type), inorm=2-inorm, type=itype), 2e-18) + _assert_close(a, pypocketfft.dst(pypocketfft.dst(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-15) + _assert_close(b, pypocketfft.dst(pypocketfft.dst(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-7) # TEMPORARY: separate test for DCT/DST IV, since they are less accurate @@ -234,9 +234,9 @@ def testdcst1D4(len, inorm, type): itp = (0, 1, 3, 2, 4) itype = itp[type] if type != 1 or len > 1: - _assert_close(a, pypocketfft.r2r_dct(pypocketfft.r2r_dct(c, inorm=inorm, type=type), inorm=2-inorm, type=itype), 2e-16) - _assert_close(a, pypocketfft.r2r_dct(pypocketfft.r2r_dct(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-13) - _assert_close(b, pypocketfft.r2r_dct(pypocketfft.r2r_dct(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-5) - _assert_close(a, pypocketfft.r2r_dst(pypocketfft.r2r_dst(c, inorm=inorm, type=type), inorm=2-inorm, type=itype), 2e-16) - _assert_close(a, pypocketfft.r2r_dst(pypocketfft.r2r_dst(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-13) - _assert_close(b, pypocketfft.r2r_dst(pypocketfft.r2r_dst(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-5) + _assert_close(a, pypocketfft.dct(pypocketfft.dct(c, inorm=inorm, type=type), inorm=2-inorm, type=itype), 2e-16) + _assert_close(a, pypocketfft.dct(pypocketfft.dct(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-13) + _assert_close(b, pypocketfft.dct(pypocketfft.dct(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-5) + _assert_close(a, pypocketfft.dst(pypocketfft.dst(c, inorm=inorm, type=type), inorm=2-inorm, type=itype), 2e-16) + _assert_close(a, pypocketfft.dst(pypocketfft.dst(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-13) + _assert_close(b, pypocketfft.dst(pypocketfft.dst(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-5)