Commit 1dda33da authored by Martin Reinecke

move genuine_hartley into pocketfft proper

parent 4b075930
......@@ -3387,6 +3387,33 @@ template<typename T> void r2r_separable_hartley(const shape_t &shape,
false);
}
template<typename T> void r2r_genuine_hartley(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const T *data_in, T *data_out, T fct, size_t nthreads=1)
{
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
shape_t tshp(shape);
tshp[axes.back()] = tshp[axes.back()]/2+1;
arr<complex<T>> tdata(util::prod(tshp));
stride_t tstride(shape.size());
tstride.back()=sizeof(complex<T>);
for (size_t i=tstride.size()-1; i>0; --i)
tstride[i-1]=tstride[i]*ptrdiff_t(tshp[i]);
r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads);
cndarr<cmplx<T>> atmp(tdata.data(), tshp, tstride);
ndarr<T> aout(data_out, shape, stride_out);
simple_iter iin(atmp);
rev_iter iout(aout, axes);
while(iin.remaining()>0)
{
auto v = atmp[iin.ofs()];
aout[iout.ofs()] = v.r+v.i;
aout[iout.rev_ofs()] = v.r-v.i;
iin.advance(); iout.advance();
}
}
} // namespace detail
using detail::FORWARD;
......@@ -3398,6 +3425,7 @@ using detail::c2r;
using detail::r2c;
using detail::r2r_fftpack;
using detail::r2r_separable_hartley;
using detail::r2r_genuine_hartley;
using detail::dct;
using detail::dst;
......
......@@ -346,35 +346,30 @@ py::array separable_hartley(const py::array &in, const py::object &axes_,
out_, nthreads))
}
// NOTE(review): this span comes from a diff hunk rendered without +/- markers.
// Lines of complex2hartley and of genuine_hartley_internal are interleaved
// below (two template headers share one body), so it does not compile as
// shown. complex2hartley appears to be the removed helper, and
// genuine_hartley_internal its replacement -- confirm against the repository.
//
// genuine_hartley_internal<T>: pybind11 wrapper that prepares an output array
// from `out_` with the shape copied from `in`, computes the factor via
// norm_fct<T>(inorm, dims, axes) and calls pocketfft::r2r_genuine_hartley on
// the raw data pointers; the array named `res` is returned.
template<typename T>py::array complex2hartley(const py::array &in,
const py::array &tmp, const py::object &axes_, py::object &out_)
template<typename T> py::array genuine_hartley_internal(const py::array &in,
const py::object &axes_, int inorm, py::object &out_, size_t nthreads)
{
using namespace pocketfft::detail;
auto dims_out(copy_shape(in));
py::array out = prepare_output<T>(out_, dims_out);
cndarr<cmplx<T>> atmp(tmp.data(), copy_shape(tmp), copy_strides(tmp));
ndarr<T> aout(out.mutable_data(), copy_shape(out), copy_strides(out));
// Shape is taken from the input; `res` is obtained from prepare_output.
auto dims(copy_shape(in));
py::array res = prepare_output<T>(out_, dims);
auto axes = makeaxes(in, axes_);
// Strides and typed raw pointers handed to the pocketfft routine.
auto s_in=copy_strides(in);
auto s_out=copy_strides(res);
auto d_in=reinterpret_cast<const T *>(in.data());
auto d_out=reinterpret_cast<T *>(res.mutable_data());
{
// The GIL is released for the lifetime of `release`, i.e. this scope.
py::gil_scoped_release release;
simple_iter iin(atmp);
rev_iter iout(aout, axes);
while(iin.remaining()>0)
{
auto v = atmp[iin.ofs()];
aout[iout.ofs()] = v.r+v.i;
aout[iout.rev_ofs()] = v.r-v.i;
iin.advance(); iout.advance();
}
// Factor derived from inorm, the array shape and the transformed axes.
T fct = norm_fct<T>(inorm, dims, axes);
pocketfft::r2r_genuine_hartley(dims, s_in, s_out, axes, d_in, d_out, fct,
nthreads);
}
return out;
return res;
}
// Entry point for the genuine Hartley transform on a py::array.
// NOTE(review): this span comes from a diff hunk rendered without +/- markers.
// It shows two bodies at once: one that runs r2c into `tmp` and dispatches to
// complex2hartley, and one that dispatches straight to
// genuine_hartley_internal with (in, axes_, inorm, out_, nthreads). Only one
// of them can be present in the real file -- presumably the latter; confirm.
// DISPATCH presumably picks the template instantiation (f64/f32/flong) from
// the dtype of `in` -- verify against the macro definition.
py::array genuine_hartley(const py::array &in, const py::object &axes_,
int inorm, py::object &out_, size_t nthreads)
{
auto tmp = r2c(in, axes_, true, inorm, None, nthreads);
DISPATCH(in, f64, f32, flong, complex2hartley, (in, tmp, axes_, out_))
DISPATCH(in, f64, f32, flong, genuine_hartley_internal, (in, axes_, inorm,
out_, nthreads))
}
size_t good_size(size_t n, bool real)
......
......@@ -197,7 +197,7 @@ def test_genuine_hartley_identity(shp):
assert_(_l2error(a, v1) < 1e-15)
@pmp("shp", shapes2D)
@pmp("shp", shapes2D+shapes3D)
@pmp("axes", ((0,), (1,), (0, 1), (1, 0)))
def test_genuine_hartley_2D(shp, axes):
a = np.random.rand(*shp)-0.5
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment