Commit e045a9eb authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fixes

parent 1e3079c4
......@@ -1953,7 +1953,7 @@ class multiarr
class multi_iter
{
private:
public:
vector<diminfo> dim;
vector<size_t> pos;
size_t ofs_, len;
......@@ -2538,6 +2538,85 @@ py::array hartley(const py::array &in, py::object axes_, double fct)
return res;
}
template<typename T>py::array complex2hartley(const py::array &in,
const py::array &tmp, py::object axes_)
{
int ndim = in.ndim();
vector<size_t> dims_out(in.ndim());
for (int i=0; i<ndim; ++i)
dims_out[i] = in.shape(i);
py::array out;
if (in.dtype().is(f64))
out = py::array_t<double>(dims_out);
else if (in.dtype().is(f32))
out = py::array_t<float>(dims_out);
else throw runtime_error("unsupported data type");
vector<size_t> dims_tmp(ndim);
vector<int64_t> stride_tmp(ndim), stride_out(ndim);
for (int i=0; i<ndim; ++i)
{
dims_tmp[i] = tmp.shape(i);
stride_tmp[i] = tmp.strides(i)/tmp.itemsize();
if (stride_tmp[i]*tmp.itemsize() != tmp.strides(i))
throw runtime_error("weird strides");
stride_out[i] = out.strides(i)/out.itemsize();
if (stride_out[i]*out.itemsize() != out.strides(i))
throw runtime_error("weird strides");
}
auto axes = makeaxes(in, axes_);
int axis = axes.back();
multiarr a_tmp(ndim, dims_tmp.data(), stride_tmp.data()),
a_out(ndim, dims_out.data(), stride_out.data());
multi_iter it_tmp(a_tmp, axis), it_out(a_out, axis);
const cmplx<T> *tdata = (const cmplx<T> *)tmp.data();
T *odata = (T *)out.mutable_data();
vector<bool> swp(ndim-1,false);
for (auto i: axes)
{
if (i!=axis)
{
auto i2 = i<axis ? i : i-1;
swp[i2] = true;
}
}
while(!it_tmp.done())
{
auto tofs = it_tmp.offset();
auto ofs = it_out.offset();
size_t rofs = 0;
for (size_t i=0; i<it_out.pos.size(); ++i)
{
if (!swp[i])
rofs += it_out.pos[i]*it_out.dim[i].s;
else
{
auto x = (it_out.pos[i]==0) ? 0 : it_out.dim[i].n-it_out.pos[i];
rofs += x*it_out.dim[i].s;
}
}
for (size_t i=0; i<it_tmp.length(); ++i)
{
auto re = tdata[tofs + i*it_tmp.stride()].r;
auto im = tdata[tofs + i*it_tmp.stride()].i;
auto rev_i = (i==0) ? 0 : it_out.length()-i;
odata[ofs + i*it_out.stride()] = re+im;
odata[rofs + rev_i*it_out.stride()] = re-im;
}
it_tmp.advance();
it_out.advance();
}
return out;
}
py::array hartley2(const py::array &in, py::object axes_, double fct)
{
auto tmp = rfftn(in, axes_, fct);
if (in.dtype().is(f64))
return complex2hartley<double>(in, tmp, axes_);
else if (in.dtype().is(f32))
return complex2hartley<float>(in, tmp, axes_);
else throw runtime_error("unsupported data type");
}
const char *pypocketfft_DS = R"DELIM(
Fast Fourier and Hartley transforms.
......@@ -2667,4 +2746,5 @@ PYBIND11_MODULE(pypocketfft, m)
m.def("rfftn",&rfftn, rfftn_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.);
m.def("irfftn",&irfftn, irfftn_DS, "a"_a, "axes"_a=py::none(), "lastsize"_a=0, "fct"_a=1.);
m.def("hartley",&hartley, hartley_DS, "a"_a, "axes"_a=py::none(), "fct"_a=1.);
m.def("hartley2",&hartley2, "a"_a, "axes"_a=py::none(), "fct"_a=1.);
}
......@@ -71,18 +71,33 @@ def test_identity_r2(shp):
vmax = np.max(np.abs(a))
assert_allclose(pypocketfft.rfftn(pypocketfft.irfftn(a),fct=fct), a, atol=2e-15*vmax, rtol=0)
@pmp("shp", shapes3D)
def xtest_hartley(shp):
@pmp("shp", shapes2D+shapes3D)
def test_hartley2(shp):
a=np.random.rand(*shp)-0.5
v1 = pypocketfft.hartley(a)
v1 = pypocketfft.hartley2(a)
v2 = pypocketfft.fftn(a.astype(np.complex128))
vmax = np.max(np.abs(v1))
v2 = v2.real+v2.imag
assert_allclose(v1, v2, atol=2e-15*vmax, rtol=0)
@pmp("shp", shapes3D)
@pmp("shp", shapes)
def test_hartley_identity(shp):
a=np.random.rand(*shp)-0.5
v1 = pypocketfft.hartley(pypocketfft.hartley(a))/a.size
vmax = np.max(np.abs(a))
assert_allclose(a, v1, atol=2e-15*vmax, rtol=0)
@pmp("shp", shapes)
def test_hartley2_identity(shp):
a=np.random.rand(*shp)-0.5
v1 = pypocketfft.hartley2(pypocketfft.hartley2(a))/a.size
vmax = np.max(np.abs(a))
assert_allclose(a, v1, atol=2e-15*vmax, rtol=0)
@pmp("shp", shapes2D)
@pmp("axes", ((0,),(1,),(0,1),(1,0)))
def test_hartley2_2D(shp, axes):
a=np.random.rand(*shp)-0.5
fct = 1./np.prod(np.take(shp,axes))
vmax = np.max(np.abs(a))
assert_allclose(pypocketfft.hartley2(pypocketfft.hartley2(a, axes=axes), axes=axes, fct=fct),a , atol=2e-15*vmax, rtol=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment