Commit 28540189 by Martin Reinecke

### introduce simple_iter and rev_iter to simplify iterating over arrays with Hermitian symmetry

parent d77cdeb5
 ... ... @@ -2236,16 +2236,18 @@ template class ndarr { return *reinterpret_cast(d+ofs); } const T &operator[](ptrdiff_t ofs) const { return *reinterpret_cast(cd+ofs); } T get(ptrdiff_t ofs) const { return *reinterpret_cast(cd+ofs); } void set(ptrdiff_t ofs, const T &val) const { *reinterpret_cast(d+ofs) = val; } }; template class multi_iter { public: private: shape_t pos; const ndarr &iarr; ndarr &oarr; private: ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o; size_t idim, rem; ... ... @@ -2320,6 +2322,97 @@ template class multi_iter bool contiguous_out() const { return str_o==sizeof(To); } }; template class simple_iter { private: shape_t pos; const ndarr &arr; ptrdiff_t p; size_t rem; public: simple_iter(const ndarr &arr_) : pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {} void advance() { --rem; for (int i_=int(pos.size())-1; i_>=0; --i_) { auto i = size_t(i_); p += arr.stride(i); if (++pos[i] < arr.shape(i)) return; pos[i] = 0; p -= ptrdiff_t(arr.shape(i))*arr.stride(i); } } ptrdiff_t ofs() const { return p; } size_t remaining() const { return rem; } }; template class rev_iter { private: shape_t pos; ndarr &arr; vector rev_axis; vector rev_jump; size_t last_axis, last_size; shape_t shp; ptrdiff_t p, rp; size_t rem; public: rev_iter(ndarr &arr_, const shape_t &axes) : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0), rev_jump(arr_.ndim(), 1), p(0), rp(0) { for (auto ax: axes) rev_axis[ax]=1; last_axis = axes.back(); last_size = arr.shape(last_axis)/2 + 1; shp = arr.shape(); shp[last_axis] = last_size; rem=1; for (auto i: shp) rem *= i; } void advance() { --rem; for (int i_=int(pos.size())-1; i_>=0; --i_) { auto i = size_t(i_); p += arr.stride(i); if (!rev_axis[i]) rp += arr.stride(i); else { rp -= arr.stride(i); if (rev_jump[i]) { rp += ptrdiff_t(arr.shape(i))*arr.stride(i); rev_jump[i] = 0; } } if (++pos[i] < shp[i]) return; pos[i] = 0; p -= ptrdiff_t(shp[i])*arr.stride(i); if (rev_axis[i]) { rp -= ptrdiff_t(arr.shape(i)-shp[i])*arr.stride(i); rev_jump[i] = 1; } else rp -= ptrdiff_t(shp[i])*arr.stride(i); } } ptrdiff_t ofs() const { return p; } ptrdiff_t rev_ofs() const { return rp; } size_t remaining() const { return rem; } }; #ifndef POCKETFFT_NO_VECTORS template struct VTYPE {}; template<> struct VTYPE ... ...
 ... ... @@ -216,7 +216,6 @@ templatepy::array complex2hartley(const py::array &in, const py::array &tmp, py::object axes_, bool inplace) { using namespace pocketfft::detail; size_t ndim = size_t(in.ndim()); auto dims_out(copy_shape(in)); py::array out = inplace ? in : py::array_t(dims_out); ndarr> atmp(tmp.data(), copy_shape(tmp), copy_strides(tmp)); ... ... @@ -224,35 +223,15 @@ templatepy::array complex2hartley(const py::array &in, auto axes = makeaxes(in, axes_); { py::gil_scoped_release release; size_t axis = axes.back(); multi_iter<1,cmplx,T> it(atmp, aout, axis); vector swp(ndim,false); for (auto i: axes) if (i!=axis) swp[i] = true; while(it.remaining()>0) simple_iter> iin(atmp); rev_iter iout(aout, axes); if (iin.remaining()!=iout.remaining()) throw runtime_error("oops"); while(iin.remaining()>0) { ptrdiff_t rofs = 0; for (size_t i=0; i
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!