From ee5193a5ff6fa184c31d7c252c3e2847de64590a Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Sat, 20 Jul 2019 01:08:42 +0100 Subject: [PATCH] Remove redundant "ortho" parameter from python interface --- pypocketfft.cc | 24 +++++++++++------------- test.py | 12 ++++++------ 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/pypocketfft.cc b/pypocketfft.cc index c6d72bf..ee2ce0f 100644 --- a/pypocketfft.cc +++ b/pypocketfft.cc @@ -227,7 +227,7 @@ py::array r2r_fftpack(const py::array &in, const py::object &axes_, } template py::array dct_internal(const py::array &in, - const py::object &axes_, int type, int inorm, bool ortho, py::object &out_, + const py::object &axes_, int type, int inorm, py::object &out_, size_t nthreads) { auto axes = makeaxes(in, axes_); @@ -239,10 +239,9 @@ template py::array dct_internal(const py::array &in, auto d_out=reinterpret_cast(res.mutable_data()); { py::gil_scoped_release release; - // override - if (ortho) inorm=1; T fct = (type==1) ? norm_fct(inorm, dims, axes, 2, -1) : norm_fct(inorm, dims, axes, 2); + bool ortho = inorm == 1; pocketfft::dct(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho, nthreads); } @@ -250,15 +249,15 @@ template py::array dct_internal(const py::array &in, } py::array dct(const py::array &in, int type, const py::object &axes_, - int inorm, bool ortho, py::object &out_, size_t nthreads) + int inorm, py::object &out_, size_t nthreads) { if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type"); - DISPATCH(in, f64, f32, flong, dct_internal, (in, axes_, type, inorm, ortho, - out_, nthreads)) + DISPATCH(in, f64, f32, flong, dct_internal, (in, axes_, type, inorm, out_, + nthreads)) } template py::array dst_internal(const py::array &in, - const py::object &axes_, int type, int inorm, bool ortho, py::object &out_, + const py::object &axes_, int type, int inorm, py::object &out_, size_t nthreads) { auto axes = makeaxes(in, axes_); @@ -270,10 +269,9 @@ template py::array dst_internal(const py::array &in, auto d_out=reinterpret_cast(res.mutable_data()); { py::gil_scoped_release release; - // override - if (ortho) inorm=1; T fct = (type==1) ? norm_fct(inorm, dims, axes, 2, 1) : norm_fct(inorm, dims, axes, 2); + bool ortho = inorm == 1; pocketfft::dst(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho, nthreads); } @@ -281,10 +279,10 @@ template py::array dst_internal(const py::array &in, } py::array dst(const py::array &in, int type, const py::object &axes_, - int inorm, bool ortho, py::object &out_, size_t nthreads) + int inorm, py::object &out_, size_t nthreads) { if ((type<1) || (type>4)) throw invalid_argument("invalid DST type"); - DISPATCH(in, f64, f32, flong, dst_internal, (in, axes_, type, inorm, ortho, + DISPATCH(in, f64, f32, flong, dst_internal, (in, axes_, type, inorm, out_, nthreads)) } @@ -669,7 +667,7 @@ PYBIND11_MODULE(pypocketfft, m) m.def("genuine_hartley", genuine_hartley, genuine_hartley_DS, "a"_a, "axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1); m.def("dct", dct, dct_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0, - "ortho"_a=false, "out"_a=None, "nthreads"_a=1); + "out"_a=None, "nthreads"_a=1); m.def("dst", dst, dst_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0, - "ortho"_a=false, "out"_a=None, "nthreads"_a=1); + "out"_a=None, "nthreads"_a=1); } diff --git a/test.py b/test.py index 3d22f1e..9a309c0 100644 --- a/test.py +++ b/test.py @@ -231,13 +231,13 @@ def testdcst1Dortho(len, type): itp = (0, 1, 3, 2, 4) itype = itp[type] if type != 1 or len > 1: - _assert_close(a, pypocketfft.dct(pypocketfft.dct(c, ortho=True, type=type), ortho=True, type=itype), 2e-18) - _assert_close(a, pypocketfft.dct(pypocketfft.dct(a, ortho=True, type=type), ortho=True, type=itype), 1.5e-15) - _assert_close(b, pypocketfft.dct(pypocketfft.dct(b, ortho=True, type=type), ortho=True, type=itype), 6e-7) + _assert_close(a, pypocketfft.dct(pypocketfft.dct(c, inorm=1, type=type), inorm=1, type=itype), 2e-18) + _assert_close(a, pypocketfft.dct(pypocketfft.dct(a, inorm=1, type=type), inorm=1, type=itype), 1.5e-15) + _assert_close(b, pypocketfft.dct(pypocketfft.dct(b, inorm=1, type=type), inorm=1, type=itype), 6e-7) if type != 1: - _assert_close(a, pypocketfft.dst(pypocketfft.dst(c, ortho=True, type=type), ortho=True, type=itype), 2e-18) - _assert_close(a, pypocketfft.dst(pypocketfft.dst(a, ortho=True, type=type), ortho=True, type=itype), 1.5e-15) - _assert_close(b, pypocketfft.dst(pypocketfft.dst(b, ortho=True, type=type), ortho=True, type=itype), 6e-7) + _assert_close(a, pypocketfft.dst(pypocketfft.dst(c, inorm=1, type=type), inorm=1, type=itype), 2e-18) + _assert_close(a, pypocketfft.dst(pypocketfft.dst(a, inorm=1, type=type), inorm=1, type=itype), 1.5e-15) + _assert_close(b, pypocketfft.dst(pypocketfft.dst(b, inorm=1, type=type), inorm=1, type=itype), 6e-7) # TEMPORARY: separate test for DCT/DST IV, since they are less accurate -- GitLab