diff --git a/src/mr_util/infra/mav.h b/src/mr_util/infra/mav.h index 35b656da2e1a441daecbcc8c9fce0bf37fb9f500..18f8bb4865a77fe9ea1ebc7d46d6c798d04559a8 100644 --- a/src/mr_util/infra/mav.h +++ b/src/mr_util/infra/mav.h @@ -215,6 +215,36 @@ template<size_t ndim> class mav_info } }; + +class FmavIter + { + private: + fmav_info::shape_t pos; + fmav_info arr; + ptrdiff_t p; + size_t rem; + + public: + FmavIter(const fmav_info &arr_) + : pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {} + void advance() + { + --rem; + for (int i_=int(pos.size())-1; i_>=0; --i_) + { + auto i = size_t(i_); + p += arr.stride(i); + if (++pos[i] < arr.shape(i)) + return; + pos[i] = 0; + p -= ptrdiff_t(arr.shape(i))*arr.stride(i); + } + } + ptrdiff_t ofs() const { return p; } + size_t remaining() const { return rem; } + }; + + // "mav" stands for "multidimensional array view" template<typename T> class fmav: public fmav_info, public membuf<T> { @@ -293,12 +323,27 @@ template<typename T> class fmav: public fmav_info, public membuf<T> fmav subarray(const shape_t &i0, const shape_t &extent) { auto [nshp, nstr, nofs] = subdata(i0, extent); - return fmav(tbuf(*this,nofs), nshp, nstr); + return fmav(nshp, nstr, tbuf::d+nofs, *this); } fmav subarray(const shape_t &i0, const shape_t &extent) const { auto [nshp, nstr, nofs] = subdata(i0, extent); - return fmav(tbuf(*this,nofs),nshp, nstr); + return fmav(nshp, nstr, tbuf::d+nofs, *this); + } + vector<T> dump() const + { + FmavIter it(*this); + vector<T> res(sz); + for (size_t i=0; i<sz; ++i, it.advance()) + res[i] = operator[](it.ofs()); + return res; + } + void load (const vector<T> &v) + { + MR_assert(v.size()==sz, "bad input data size"); + FmavIter it(*this); + for (size_t i=0; i<sz; ++i, it.advance()) + vraw(it.ofs()) = v[i]; } }; @@ -508,6 +553,7 @@ template<typename T, size_t ndim> class MavIter using detail_mav::fmav_info; using detail_mav::fmav; using detail_mav::mav; +using detail_mav::FmavIter; using detail_mav::MavIter; } diff --git a/src/mr_util/math/fft.h b/src/mr_util/math/fft.h index d9905c6d2fceccc73a672e6c28bafdcc52103524..313534f01550fd11db501dd051ebe4af94318bfd 100644 --- a/src/mr_util/math/fft.h +++ b/src/mr_util/math/fft.h @@ -507,34 +507,6 @@ template<size_t N> class multi_iter size_t remaining() const { return rem; } }; -class simple_iter - { - private: - shape_t pos; - fmav_info arr; - ptrdiff_t p; - size_t rem; - - public: - simple_iter(const fmav_info &arr_) - : pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {} - void advance() - { - --rem; - for (int i_=int(pos.size())-1; i_>=0; --i_) - { - auto i = size_t(i_); - p += arr.stride(i); - if (++pos[i] < arr.shape(i)) - return; - pos[i] = 0; - p -= ptrdiff_t(arr.shape(i))*arr.stride(i); - } - } - ptrdiff_t ofs() const { return p; } - size_t remaining() const { return rem; } - }; - class rev_iter { private: @@ -1047,7 +1019,7 @@ template<typename T> void r2r_genuine_hartley(const fmav<T> &in, tshp[axes.back()] = tshp[axes.back()]/2+1; fmav<std::complex<T>> atmp(tshp); r2c(in, atmp, axes, true, fct, nthreads); - simple_iter iin(atmp); + FmavIter iin(atmp); rev_iter iout(out, axes); auto vout = out.vdata(); while(iin.remaining()>0)