Commit 28540189 authored by Martin Reinecke
Browse files

introduce simple_iter and rev_iter to simplify iterating over arrays with Hermitian symmetry

parent d77cdeb5
...@@ -2236,16 +2236,18 @@ template<typename T> class ndarr ...@@ -2236,16 +2236,18 @@ template<typename T> class ndarr
{ return *reinterpret_cast<T *>(d+ofs); } { return *reinterpret_cast<T *>(d+ofs); }
const T &operator[](ptrdiff_t ofs) const const T &operator[](ptrdiff_t ofs) const
{ return *reinterpret_cast<const T *>(cd+ofs); } { return *reinterpret_cast<const T *>(cd+ofs); }
T get(ptrdiff_t ofs) const
{ return *reinterpret_cast<const T *>(cd+ofs); }
void set(ptrdiff_t ofs, const T &val) const
{ *reinterpret_cast<T *>(d+ofs) = val; }
}; };
template<size_t N, typename Ti, typename To> class multi_iter template<size_t N, typename Ti, typename To> class multi_iter
{ {
public: private:
shape_t pos; shape_t pos;
const ndarr<Ti> &iarr; const ndarr<Ti> &iarr;
ndarr<To> &oarr; ndarr<To> &oarr;
private:
ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o; ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o;
size_t idim, rem; size_t idim, rem;
...@@ -2320,6 +2322,97 @@ template<size_t N, typename Ti, typename To> class multi_iter ...@@ -2320,6 +2322,97 @@ template<size_t N, typename Ti, typename To> class multi_iter
bool contiguous_out() const { return str_o==sizeof(To); } bool contiguous_out() const { return str_o==sizeof(To); }
}; };
// Steps through every element of an ndarr, one element per advance() call,
// with the last axis varying fastest. Only the running offset is tracked;
// callers feed ofs() to the array's own accessors to reach the data.
template<typename T> class simple_iter
  {
  private:
    shape_t pos;          // multi-index of the current element
    const ndarr<T> &arr;  // array being traversed (referenced, not copied)
    ptrdiff_t p;          // offset of the current element, built from strides
    size_t rem;           // number of elements not yet stepped past

  public:
    // Starts at index (0,...,0) with arr_.size() elements remaining.
    simple_iter(const ndarr<T> &arr_)
      : pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {}

    // Moves to the next element. The caller is expected to check
    // remaining() first; the counter is decremented unconditionally.
    void advance()
      {
      --rem;
      // Odometer-style increment, starting at the innermost (last) axis.
      for (size_t left=pos.size(); left>0; --left)
        {
        const size_t ax = left-1;
        const auto step = arr.stride(ax);
        p += step;
        if (++pos[ax] < arr.shape(ax))
          return;
        // This axis overflowed: rewind it to 0 and carry into the next one.
        pos[ax] = 0;
        p -= ptrdiff_t(arr.shape(ax))*step;
        }
      }

    // Offset of the current element.
    ptrdiff_t ofs() const { return p; }

    // Number of elements still to be visited.
    size_t remaining() const { return rem; }
  };
// Iterates over an ndarr while tracking two offsets at once: the offset of
// the current element (ofs()) and the offset of its "mirrored" partner
// (rev_ofs()). Along every axis listed in `axes`, index k is mirrored to
// index (n-k) mod n, where n is the extent of that axis; along all other
// axes the index is left unchanged.
// Along the last entry of `axes` only the indices 0 .. n/2 are visited, so
// each element/partner pair along that axis is produced once.
template<typename T> class rev_iter
  {
  private:
    shape_t pos;            // multi-index of the current element
    ndarr<T> &arr;          // array being traversed (referenced, not copied)
    vector<char> rev_axis;  // 1 for axes that are mirrored, 0 otherwise
    vector<char> rev_jump;  // 1 while a mirrored axis still sits at index 0
    size_t last_axis, last_size;  // final entry of axes and its reduced extent
    shape_t shp;            // iteration shape (array shape, last_axis halved)
    ptrdiff_t p, rp;        // offsets of the element and of its mirror image
    size_t rem;             // number of positions not yet stepped past

  public:
    // `axes` must be non-empty and every entry must be a valid axis of arr_:
    // axes.back() is evaluated unconditionally and entries index rev_axis
    // without a bounds check.
    rev_iter(ndarr<T> &arr_, const shape_t &axes)
      : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0),
        rev_jump(arr_.ndim(), 1), p(0), rp(0)
      {
      for (size_t k=0; k<axes.size(); ++k)
        rev_axis[axes[k]] = 1;

      last_axis = axes.back();
      last_size = arr.shape(last_axis)/2 + 1;

      shp = arr.shape();
      shp[last_axis] = last_size;

      // Total number of positions in the (reduced) iteration shape.
      rem = 1;
      for (size_t k=0; k<shp.size(); ++k)
        rem *= shp[k];
      }

    // Moves to the next position. The caller is expected to check
    // remaining() first; the counter is decremented unconditionally.
    void advance()
      {
      --rem;
      // Odometer-style increment, starting at the innermost (last) axis.
      for (size_t left=pos.size(); left>0; --left)
        {
        const size_t ax = left-1;
        const auto step = arr.stride(ax);
        const bool mirrored = rev_axis[ax]!=0;

        p += step;
        if (mirrored)
          {
          // The mirror image walks backwards along this axis ...
          rp -= step;
          if (rev_jump[ax])
            {
            // ... and leaving index 0 wraps it around to index n-1.
            rp += ptrdiff_t(arr.shape(ax))*step;
            rev_jump[ax] = 0;
            }
          }
        else
          rp += step;

        if (++pos[ax] < shp[ax])
          return;

        // This axis overflowed: rewind both offsets to index 0 along it and
        // carry into the next axis.
        pos[ax] = 0;
        p -= ptrdiff_t(shp[ax])*step;
        if (mirrored)
          {
          rp -= ptrdiff_t(arr.shape(ax)-shp[ax])*step;
          rev_jump[ax] = 1;
          }
        else
          rp -= ptrdiff_t(shp[ax])*step;
        }
      }

    // Offset of the current element.
    ptrdiff_t ofs() const { return p; }

    // Offset of the mirrored partner of the current element.
    ptrdiff_t rev_ofs() const { return rp; }

    // Number of positions still to be visited.
    size_t remaining() const { return rem; }
  };
#ifndef POCKETFFT_NO_VECTORS #ifndef POCKETFFT_NO_VECTORS
template<typename T> struct VTYPE {}; template<typename T> struct VTYPE {};
template<> struct VTYPE<float> template<> struct VTYPE<float>
......
...@@ -216,7 +216,6 @@ template<typename T>py::array complex2hartley(const py::array &in, ...@@ -216,7 +216,6 @@ template<typename T>py::array complex2hartley(const py::array &in,
const py::array &tmp, py::object axes_, bool inplace) const py::array &tmp, py::object axes_, bool inplace)
{ {
using namespace pocketfft::detail; using namespace pocketfft::detail;
size_t ndim = size_t(in.ndim());
auto dims_out(copy_shape(in)); auto dims_out(copy_shape(in));
py::array out = inplace ? in : py::array_t<T>(dims_out); py::array out = inplace ? in : py::array_t<T>(dims_out);
ndarr<cmplx<T>> atmp(tmp.data(), copy_shape(tmp), copy_strides(tmp)); ndarr<cmplx<T>> atmp(tmp.data(), copy_shape(tmp), copy_strides(tmp));
...@@ -224,35 +223,15 @@ template<typename T>py::array complex2hartley(const py::array &in, ...@@ -224,35 +223,15 @@ template<typename T>py::array complex2hartley(const py::array &in,
auto axes = makeaxes(in, axes_); auto axes = makeaxes(in, axes_);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
size_t axis = axes.back(); simple_iter<cmplx<T>> iin(atmp);
multi_iter<1,cmplx<T>,T> it(atmp, aout, axis); rev_iter<T> iout(aout, axes);
vector<bool> swp(ndim,false); if (iin.remaining()!=iout.remaining()) throw runtime_error("oops");
for (auto i: axes) while(iin.remaining()>0)
if (i!=axis)
swp[i] = true;
while(it.remaining()>0)
{ {
ptrdiff_t rofs = 0; auto v = atmp.get(iin.ofs());
for (size_t i=0; i<it.pos.size(); ++i) aout.set(iout.ofs(), v.r+v.i);
{ aout.set(iout.rev_ofs(), v.r-v.i);
if (i==axis) continue; iin.advance(); iout.advance();
if (!swp[i])
rofs += ptrdiff_t(it.pos[i])*it.oarr.stride(i);
else
{
auto x = ptrdiff_t((it.pos[i]==0) ? 0 : it.iarr.shape(i)-it.pos[i]);
rofs += x*it.oarr.stride(i);
}
}
it.advance(1);
for (size_t i=0; i<it.length_in(); ++i)
{
auto re = it.in(i).r;
auto im = it.in(i).i;
auto rev_i = ptrdiff_t((i==0) ? 0 : it.length_out()-i);
it.out(i) = re+im;
aout[rofs + rev_i*it.stride_out()] = re-im;
}
} }
} }
return out; return out;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment