Commit 1dda33da authored by Martin Reinecke

move genuine_hartley into pocketfft proper

parent 4b075930
...@@ -3387,6 +3387,33 @@ template<typename T> void r2r_separable_hartley(const shape_t &shape, ...@@ -3387,6 +3387,33 @@ template<typename T> void r2r_separable_hartley(const shape_t &shape,
false); false);
} }
template<typename T> void r2r_genuine_hartley(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const T *data_in, T *data_out, T fct, size_t nthreads=1)
{
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
shape_t tshp(shape);
tshp[axes.back()] = tshp[axes.back()]/2+1;
arr<complex<T>> tdata(util::prod(tshp));
stride_t tstride(shape.size());
tstride.back()=sizeof(complex<T>);
for (size_t i=tstride.size()-1; i>0; --i)
tstride[i-1]=tstride[i]*ptrdiff_t(tshp[i]);
r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads);
cndarr<cmplx<T>> atmp(tdata.data(), tshp, tstride);
ndarr<T> aout(data_out, shape, stride_out);
simple_iter iin(atmp);
rev_iter iout(aout, axes);
while(iin.remaining()>0)
{
auto v = atmp[iin.ofs()];
aout[iout.ofs()] = v.r+v.i;
aout[iout.rev_ofs()] = v.r-v.i;
iin.advance(); iout.advance();
}
}
} // namespace detail } // namespace detail
using detail::FORWARD; using detail::FORWARD;
...@@ -3398,6 +3425,7 @@ using detail::c2r; ...@@ -3398,6 +3425,7 @@ using detail::c2r;
using detail::r2c; using detail::r2c;
using detail::r2r_fftpack; using detail::r2r_fftpack;
using detail::r2r_separable_hartley; using detail::r2r_separable_hartley;
using detail::r2r_genuine_hartley;
using detail::dct; using detail::dct;
using detail::dst; using detail::dst;
......
...@@ -346,35 +346,30 @@ py::array separable_hartley(const py::array &in, const py::object &axes_, ...@@ -346,35 +346,30 @@ py::array separable_hartley(const py::array &in, const py::object &axes_,
out_, nthreads)) out_, nthreads))
} }
template<typename T>py::array complex2hartley(const py::array &in, template<typename T> py::array genuine_hartley_internal(const py::array &in,
const py::array &tmp, const py::object &axes_, py::object &out_) const py::object &axes_, int inorm, py::object &out_, size_t nthreads)
{ {
using namespace pocketfft::detail; auto dims(copy_shape(in));
auto dims_out(copy_shape(in)); py::array res = prepare_output<T>(out_, dims);
py::array out = prepare_output<T>(out_, dims_out);
cndarr<cmplx<T>> atmp(tmp.data(), copy_shape(tmp), copy_strides(tmp));
ndarr<T> aout(out.mutable_data(), copy_shape(out), copy_strides(out));
auto axes = makeaxes(in, axes_); auto axes = makeaxes(in, axes_);
auto s_in=copy_strides(in);
auto s_out=copy_strides(res);
auto d_in=reinterpret_cast<const T *>(in.data());
auto d_out=reinterpret_cast<T *>(res.mutable_data());
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
simple_iter iin(atmp); T fct = norm_fct<T>(inorm, dims, axes);
rev_iter iout(aout, axes); pocketfft::r2r_genuine_hartley(dims, s_in, s_out, axes, d_in, d_out, fct,
while(iin.remaining()>0) nthreads);
{
auto v = atmp[iin.ofs()];
aout[iout.ofs()] = v.r+v.i;
aout[iout.rev_ofs()] = v.r-v.i;
iin.advance(); iout.advance();
}
} }
return out; return res;
} }
py::array genuine_hartley(const py::array &in, const py::object &axes_, py::array genuine_hartley(const py::array &in, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads) int inorm, py::object &out_, size_t nthreads)
{ {
auto tmp = r2c(in, axes_, true, inorm, None, nthreads); DISPATCH(in, f64, f32, flong, genuine_hartley_internal, (in, axes_, inorm,
DISPATCH(in, f64, f32, flong, complex2hartley, (in, tmp, axes_, out_)) out_, nthreads))
} }
size_t good_size(size_t n, bool real) size_t good_size(size_t n, bool real)
......
...@@ -197,7 +197,7 @@ def test_genuine_hartley_identity(shp): ...@@ -197,7 +197,7 @@ def test_genuine_hartley_identity(shp):
assert_(_l2error(a, v1) < 1e-15) assert_(_l2error(a, v1) < 1e-15)
@pmp("shp", shapes2D) @pmp("shp", shapes2D+shapes3D)
@pmp("axes", ((0,), (1,), (0, 1), (1, 0))) @pmp("axes", ((0,), (1,), (0, 1), (1, 0)))
def test_genuine_hartley_2D(shp, axes): def test_genuine_hartley_2D(shp, axes):
a = np.random.rand(*shp)-0.5 a = np.random.rand(*shp)-0.5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment