From 007a96f6fc15421c82b01bff7c05bbd9d2a2a68b Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Mon, 29 Apr 2019 11:53:00 +0200
Subject: [PATCH] simplify: drop multi_iter_io/multioffset_io in favor of multi_iter

---
 pocketfft.cc | 79 ++++++----------------------------------------------
 1 file changed, 8 insertions(+), 71 deletions(-)

diff --git a/pocketfft.cc b/pocketfft.cc
index 73f7b3a..8488848 100644
--- a/pocketfft.cc
+++ b/pocketfft.cc
@@ -1968,63 +1968,14 @@ template<typename T0> class pocketfft_r
 using shape_t = vector<size_t>;
 using stride_t = vector<int64_t>;
 
-struct diminfo_io
-  { size_t n; int64_t s_i, s_o; };
-
-class multi_iter_io
+template<size_t N, typename Ti, typename To> class multi_iter
   {
   public:
+    struct diminfo_io
+       { size_t n; int64_t s_i, s_o; };
     vector<diminfo_io> dim;
     shape_t pos;
-    int64_t ofs_i, ofs_o;
-    size_t len_i, len_o;
-    int64_t str_i, str_o;
-    size_t rem;
-
-  public:
-    multi_iter_io(const shape_t &shape_i, const stride_t &stride_i,
-      const shape_t &shape_o, const stride_t &stride_o, size_t idim)
-      : pos(shape_i.size()-1, 0), ofs_i(0), ofs_o(0), len_i(shape_i[idim]),
-        len_o(shape_o[idim]), str_i(stride_i[idim]),
-        str_o(stride_o[idim]), rem(1)
-      {
-      dim.reserve(shape_i.size()-1);
-      for (size_t i=0; i<shape_i.size(); ++i)
-        if (i!=idim)
-          {
-          dim.push_back({shape_i[i], stride_i[i], stride_o[i]});
-          rem *= shape_i[i];
-          }
-      }
-    void advance()
-      {
-      if (--rem==0) return;
-      for (int i=pos.size()-1; i>=0; --i)
-        {
-        ++pos[i];
-        ofs_i += dim[i].s_i;
-        ofs_o += dim[i].s_o;
-        if (pos[i] < dim[i].n)
-          return;
-        pos[i] = 0;
-        ofs_i -= dim[i].n*dim[i].s_i;
-        ofs_o -= dim[i].n*dim[i].s_o;
-        }
-      }
-    int64_t offset_in() const { return ofs_i; }
-    int64_t offset_out() const { return ofs_o; }
-    size_t length_in() const { return len_i; }
-    size_t length_out() const { return len_o; }
-    int64_t stride_in() const { return str_i; }
-    int64_t stride_out() const { return str_o; }
-    size_t remaining() const { return rem; }
-  };
-
-template<size_t N, typename Ti, typename To> class multi_iter
-  {
   private:
-    vector<diminfo_io> dim;
-    shape_t pos;
     const Ti *p_ii, *p_i[N];
     To *p_oi, *p_o[N];
     size_t len_i, len_o;
@@ -2148,18 +2099,6 @@ template<typename T> arr<char> alloc_tmp(const shape_t &shape,
   return arr<char>(tmpsize*elemsize);
   }
 
-template<size_t vlen> struct multioffset_io
-  {
-  int64_t ofs_i[vlen], ofs_o[vlen];
-  multioffset_io(multi_iter_io &it)
-    {
-    for (size_t i=0; i<vlen; ++i)
-      { ofs_i[i] = it.offset_in(); ofs_o[i] = it.offset_out(); it.advance(); }
-    }
-  int64_t in(size_t i) const { return ofs_i[i]; }
-  int64_t out(size_t i) const { return ofs_o[i]; }
-  };
-
 template<typename T> NOINLINE void pocketfft_general_c(const shape_t &shape,
   const stride_t &stride_in, const stride_t &stride_out,
   const shape_t &axes, bool forward, const cmplx<T> *data_in,
@@ -2536,17 +2475,15 @@ template<typename T>py::array complex2hartley(const py::array &in,
   auto stride_tmp(copy_strides(tmp)), stride_out(copy_strides(out));
   auto axes = makeaxes(in, axes_);
   size_t axis = axes.back();
-  multi_iter_io it(dims_tmp, stride_tmp, dims_out, stride_out, axis);
   const cmplx<T> *tdata = (const cmplx<T> *)tmp.data();
   T *odata = (T *)out.mutable_data();
+  multi_iter<1,cmplx<T>,T> it(tdata, dims_tmp, stride_tmp, odata, dims_out, stride_out, axis);
   vector<bool> swp(ndim-1,false);
   for (auto i: axes)
     if (i!=axis)
       swp[i<axis ? i : i-1] = true;
   while(it.remaining()>0)
     {
-    auto tofs = it.offset_in();
-    auto ofs = it.offset_out();
     int64_t rofs = 0;
     for (size_t i=0; i<it.pos.size(); ++i)
       {
@@ -2558,15 +2495,15 @@ template<typename T>py::array complex2hartley(const py::array &in,
         rofs += x*it.dim[i].s_o;
         }
       }
+    it.advance(1);
     for (size_t i=0; i<it.length_in(); ++i)
       {
-      auto re = tdata[tofs + i*it.stride_in()].r;
-      auto im = tdata[tofs + i*it.stride_in()].i;
+      auto re = it.in(0,i).r;
+      auto im = it.in(0,i).i;
       auto rev_i = (i==0) ? 0 : it.length_out()-i;
-      odata[ofs + i*it.stride_out()] = re+im;
+      it.out(0,i) = re+im;
       odata[rofs + rev_i*it.stride_out()] = re-im;
       }
-    it.advance();
     }
   return out;
   }
-- 
GitLab