Commit 1b51a9ec authored by Martin Reinecke's avatar Martin Reinecke
Browse files

allow dumping and loading of fmavs to/from vectors

parent 6a69a231
Pipeline #74560 passed with stages
in 12 minutes and 9 seconds
......@@ -215,6 +215,36 @@ template<size_t ndim> class mav_info
}
};
class FmavIter
{
private:
fmav_info::shape_t pos;
fmav_info arr;
ptrdiff_t p;
size_t rem;
public:
FmavIter(const fmav_info &arr_)
: pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {}
void advance()
{
--rem;
for (int i_=int(pos.size())-1; i_>=0; --i_)
{
auto i = size_t(i_);
p += arr.stride(i);
if (++pos[i] < arr.shape(i))
return;
pos[i] = 0;
p -= ptrdiff_t(arr.shape(i))*arr.stride(i);
}
}
ptrdiff_t ofs() const { return p; }
size_t remaining() const { return rem; }
};
// "mav" stands for "multidimensional array view"
template<typename T> class fmav: public fmav_info, public membuf<T>
{
......@@ -293,12 +323,27 @@ template<typename T> class fmav: public fmav_info, public membuf<T>
fmav subarray(const shape_t &i0, const shape_t &extent)
{
auto [nshp, nstr, nofs] = subdata(i0, extent);
return fmav(tbuf(*this,nofs), nshp, nstr);
return fmav(nshp, nstr, tbuf::d+nofs, *this);
}
fmav subarray(const shape_t &i0, const shape_t &extent) const
{
auto [nshp, nstr, nofs] = subdata(i0, extent);
return fmav(tbuf(*this,nofs),nshp, nstr);
return fmav(nshp, nstr, tbuf::d+nofs, *this);
}
vector<T> dump() const
{
FmavIter it(*this);
vector<T> res(sz);
for (size_t i=0; i<sz; ++i, it.advance())
res[i] = operator[](it.ofs());
return res;
}
void load (const vector<T> &v)
{
MR_assert(v.size()==sz, "bad input data size");
FmavIter it(*this);
for (size_t i=0; i<sz; ++i, it.advance())
vraw(it.ofs()) = v[i];
}
};
......@@ -508,6 +553,7 @@ template<typename T, size_t ndim> class MavIter
using detail_mav::fmav_info;
using detail_mav::fmav;
using detail_mav::mav;
using detail_mav::FmavIter;
using detail_mav::MavIter;
}
......
......@@ -507,34 +507,6 @@ template<size_t N> class multi_iter
size_t remaining() const { return rem; }
};
class simple_iter
{
private:
shape_t pos;
fmav_info arr;
ptrdiff_t p;
size_t rem;
public:
simple_iter(const fmav_info &arr_)
: pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {}
void advance()
{
--rem;
for (int i_=int(pos.size())-1; i_>=0; --i_)
{
auto i = size_t(i_);
p += arr.stride(i);
if (++pos[i] < arr.shape(i))
return;
pos[i] = 0;
p -= ptrdiff_t(arr.shape(i))*arr.stride(i);
}
}
ptrdiff_t ofs() const { return p; }
size_t remaining() const { return rem; }
};
class rev_iter
{
private:
......@@ -1047,7 +1019,7 @@ template<typename T> void r2r_genuine_hartley(const fmav<T> &in,
tshp[axes.back()] = tshp[axes.back()]/2+1;
fmav<std::complex<T>> atmp(tshp);
r2c(in, atmp, axes, true, fct, nthreads);
simple_iter iin(atmp);
FmavIter iin(atmp);
rev_iter iout(out, axes);
auto vout = out.vdata();
while(iin.remaining()>0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment