From 1b51a9ec38fc623f363234cd32bd78e77ab5b41a Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Fri, 8 May 2020 12:30:00 +0200
Subject: [PATCH] allow dumping and loading of fmavs to/from vectors

---
 src/mr_util/infra/mav.h | 50 +++++++++++++++++++++++++++++++++++++++--
 src/mr_util/math/fft.h  | 30 +------------------------
 2 files changed, 49 insertions(+), 31 deletions(-)

diff --git a/src/mr_util/infra/mav.h b/src/mr_util/infra/mav.h
index 35b656d..18f8bb4 100644
--- a/src/mr_util/infra/mav.h
+++ b/src/mr_util/infra/mav.h
@@ -215,6 +215,36 @@ template<size_t ndim> class mav_info
       }
   };
 
+
+class FmavIter
+  {
+  private:
+    fmav_info::shape_t pos;
+    fmav_info arr;
+    ptrdiff_t p;
+    size_t rem;
+
+  public:
+    FmavIter(const fmav_info &arr_)
+      : pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {}
+    void advance()
+      {
+      --rem;
+      for (int i_=int(pos.size())-1; i_>=0; --i_)
+        {
+        auto i = size_t(i_);
+        p += arr.stride(i);
+        if (++pos[i] < arr.shape(i))
+          return;
+        pos[i] = 0;
+        p -= ptrdiff_t(arr.shape(i))*arr.stride(i);
+        }
+      }
+    ptrdiff_t ofs() const { return p; }
+    size_t remaining() const { return rem; }
+  };
+
+
 // "mav" stands for "multidimensional array view"
 template<typename T> class fmav: public fmav_info, public membuf<T>
   {
@@ -293,12 +323,27 @@ template<typename T> class fmav: public fmav_info, public membuf<T>
     fmav subarray(const shape_t &i0, const shape_t &extent)
       {
       auto [nshp, nstr, nofs] = subdata(i0, extent);
-      return fmav(tbuf(*this,nofs), nshp, nstr);
+      return fmav(nshp, nstr, tbuf::d+nofs, *this);
       }
     fmav subarray(const shape_t &i0, const shape_t &extent) const
       {
       auto [nshp, nstr, nofs] = subdata(i0, extent);
-      return fmav(tbuf(*this,nofs),nshp, nstr);
+      return fmav(nshp, nstr, tbuf::d+nofs, *this);
+      }
+    vector<T> dump() const
+      {
+      FmavIter it(*this);
+      vector<T> res(sz);
+      for (size_t i=0; i<sz; ++i, it.advance())
+        res[i] = operator[](it.ofs());
+      return res;
+      }
+    void load (const vector<T> &v)
+      {
+      MR_assert(v.size()==sz, "bad input data size");
+      FmavIter it(*this);
+      for (size_t i=0; i<sz; ++i, it.advance())
+        vraw(it.ofs()) = v[i];
       }
   };
 
@@ -508,6 +553,7 @@ template<typename T, size_t ndim> class MavIter
 using detail_mav::fmav_info;
 using detail_mav::fmav;
 using detail_mav::mav;
+using detail_mav::FmavIter;
 using detail_mav::MavIter;
 
 }
diff --git a/src/mr_util/math/fft.h b/src/mr_util/math/fft.h
index d9905c6..313534f 100644
--- a/src/mr_util/math/fft.h
+++ b/src/mr_util/math/fft.h
@@ -507,34 +507,6 @@ template<size_t N> class multi_iter
     size_t remaining() const { return rem; }
   };
 
-class simple_iter
-  {
-  private:
-    shape_t pos;
-    fmav_info arr;
-    ptrdiff_t p;
-    size_t rem;
-
-  public:
-    simple_iter(const fmav_info &arr_)
-      : pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {}
-    void advance()
-      {
-      --rem;
-      for (int i_=int(pos.size())-1; i_>=0; --i_)
-        {
-        auto i = size_t(i_);
-        p += arr.stride(i);
-        if (++pos[i] < arr.shape(i))
-          return;
-        pos[i] = 0;
-        p -= ptrdiff_t(arr.shape(i))*arr.stride(i);
-        }
-      }
-    ptrdiff_t ofs() const { return p; }
-    size_t remaining() const { return rem; }
-  };
-
 class rev_iter
   {
   private:
@@ -1047,7 +1019,7 @@ template<typename T> void r2r_genuine_hartley(const fmav<T> &in,
   tshp[axes.back()] = tshp[axes.back()]/2+1;
   fmav<std::complex<T>> atmp(tshp);
   r2c(in, atmp, axes, true, fct, nthreads);
-  simple_iter iin(atmp);
+  FmavIter iin(atmp);
   rev_iter iout(out, axes);
   auto vout = out.vdata();
   while(iin.remaining()>0)
-- 
GitLab