Commit 8bbee24d authored by Martin Reinecke's avatar Martin Reinecke

try to add ortho support

parent 913949e9
This diff is collapsed.
......@@ -227,7 +227,7 @@ py::array r2r_fftpack(const py::array &in, const py::object &axes_,
}
template<typename T> py::array dct_internal(const py::array &in,
const py::object &axes_, int type, int inorm, py::object &out_,
const py::object &axes_, int type, int inorm, bool ortho, py::object &out_,
size_t nthreads)
{
auto axes = makeaxes(in, axes_);
......@@ -239,24 +239,26 @@ template<typename T> py::array dct_internal(const py::array &in,
auto d_out=reinterpret_cast<T *>(res.mutable_data());
{
py::gil_scoped_release release;
// override
if (ortho) inorm=1;
T fct = (type==1) ? norm_fct<T>(inorm, dims, axes, 2, -1)
: norm_fct<T>(inorm, dims, axes, 2);
pocketfft::dct(dims, s_in, s_out, axes, type, d_in, d_out, fct,
pocketfft::dct(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho,
nthreads);
}
return res;
}
py::array dct(const py::array &in, int type, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads)
int inorm, bool ortho, py::object &out_, size_t nthreads)
{
if ((type<1) || (type>4)) throw invalid_argument("invalid DCT type");
DISPATCH(in, f64, f32, flong, dct_internal, (in, axes_, type, inorm,
DISPATCH(in, f64, f32, flong, dct_internal, (in, axes_, type, inorm, ortho,
out_, nthreads))
}
template<typename T> py::array dst_internal(const py::array &in,
const py::object &axes_, int type, int inorm, py::object &out_,
const py::object &axes_, int type, int inorm, bool ortho, py::object &out_,
size_t nthreads)
{
auto axes = makeaxes(in, axes_);
......@@ -268,19 +270,21 @@ template<typename T> py::array dst_internal(const py::array &in,
auto d_out=reinterpret_cast<T *>(res.mutable_data());
{
py::gil_scoped_release release;
// override
if (ortho) inorm=1;
T fct = (type==1) ? norm_fct<T>(inorm, dims, axes, 2, 1)
: norm_fct<T>(inorm, dims, axes, 2);
pocketfft::dst(dims, s_in, s_out, axes, type, d_in, d_out, fct,
pocketfft::dst(dims, s_in, s_out, axes, type, d_in, d_out, fct, ortho,
nthreads);
}
return res;
}
py::array dst(const py::array &in, int type, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads)
int inorm, bool ortho, py::object &out_, size_t nthreads)
{
if ((type<1) || (type>4)) throw invalid_argument("invalid DST type");
DISPATCH(in, f64, f32, flong, dst_internal, (in, axes_, type, inorm,
DISPATCH(in, f64, f32, flong, dst_internal, (in, axes_, type, inorm, ortho,
out_, nthreads))
}
......@@ -665,7 +669,7 @@ PYBIND11_MODULE(pypocketfft, m)
m.def("genuine_hartley", genuine_hartley, genuine_hartley_DS, "a"_a,
"axes"_a=None, "inorm"_a=0, "out"_a=None, "nthreads"_a=1);
m.def("dct", dct, dct_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0,
"out"_a=None, "nthreads"_a=1);
"ortho"_a=false, "out"_a=None, "nthreads"_a=1);
m.def("dst", dst, dst_DS, "a"_a, "type"_a, "axes"_a=None, "inorm"_a=0,
"out"_a=None, "nthreads"_a=1);
"ortho"_a=false, "out"_a=None, "nthreads"_a=1);
}
......@@ -222,6 +222,23 @@ def testdcst1D(len, inorm, type):
_assert_close(a, pypocketfft.dst(pypocketfft.dst(a, inorm=inorm, type=type), inorm=2-inorm, type=itype), 1.5e-15)
_assert_close(b, pypocketfft.dst(pypocketfft.dst(b, inorm=inorm, type=type), inorm=2-inorm, type=itype), 6e-7)
@pmp("len", len1D)
@pmp("type", [1, 2, 3])
def testdcst1Dortho(len, type):
a = np.random.rand(len)-0.5
b = a.astype(np.float32)
c = a.astype(np.float128)
itp = (0, 1, 3, 2, 4)
itype = itp[type]
if type != 1 or len > 1:
_assert_close(a, pypocketfft.dct(pypocketfft.dct(c, ortho=True, type=type), ortho=True, type=itype), 2e-18)
_assert_close(a, pypocketfft.dct(pypocketfft.dct(a, ortho=True, type=type), ortho=True, type=itype), 1.5e-15)
_assert_close(b, pypocketfft.dct(pypocketfft.dct(b, ortho=True, type=type), ortho=True, type=itype), 6e-7)
if type != 1:
_assert_close(a, pypocketfft.dst(pypocketfft.dst(c, ortho=True, type=type), ortho=True, type=itype), 2e-18)
_assert_close(a, pypocketfft.dst(pypocketfft.dst(a, ortho=True, type=type), ortho=True, type=itype), 1.5e-15)
_assert_close(b, pypocketfft.dst(pypocketfft.dst(b, ortho=True, type=type), ortho=True, type=itype), 6e-7)
# TEMPORARY: separate test for DCT/DST IV, since they are less accurate
@pmp("len", len1D)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment