From 1cf25cf5a368a4edc6c3136b6dd2a961a5b5f46b Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Wed, 29 May 2019 21:42:18 +0200
Subject: [PATCH] add complex functions for testing

---
 nifty_gridder.cc | 149 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 147 insertions(+), 2 deletions(-)

diff --git a/nifty_gridder.cc b/nifty_gridder.cc
index 99fd765..049b386 100644
--- a/nifty_gridder.cc
+++ b/nifty_gridder.cc
@@ -459,6 +459,29 @@ template<typename T> class GridderConfig
           }
       return res;
       }
+    pyarr_c<complex<T>> grid2dirty_c(const pyarr_c<complex<T>> &grid) const
+      {
+      myassert(grid.ndim()==2, "grid must be a 2D array");
+      myassert(size_t(grid.shape(0))==nu, "bad 1st dimension");
+      myassert(size_t(grid.shape(1))==nv, "bad 2nd dimension");
+      auto tmp = makearray<complex<T>>({nu, nv});
+      auto ptmp = tmp.mutable_data();
+      pocketfft::c2c({nu,nv},{grid.strides(0),grid.strides(1)},
+        {tmp.strides(0), tmp.strides(1)}, {0,1}, pocketfft::BACKWARD,
+        grid.data(), tmp.mutable_data(), T(1), 0);
+      auto res = makearray<complex<T>>({nx_dirty, ny_dirty});
+      auto pout = res.mutable_data();
+      for (size_t i=0; i<nx_dirty; ++i)
+        for (size_t j=0; j<ny_dirty; ++j)
+          {
+          size_t i2 = nu-nx_dirty/2+i;
+          if (i2>=nu) i2-=nu;
+          size_t j2 = nv-ny_dirty/2+j;
+          if (j2>=nv) j2-=nv;
+          pout[ny_dirty*i + j] = ptmp[nv*i2+j2]*cfu[i]*cfv[j];
+          }
+      return res;
+      }
     pyarr_c<T> dirty2grid(const pyarr_c<T> &dirty) const
       {
       myassert(dirty.ndim()==2, "dirty must be a 2D array");
@@ -481,6 +504,30 @@ template<typename T> class GridderConfig
       hartley2_2D<T>(tmp, tmp);
       return tmp;
       }
+    pyarr_c<complex<T>> dirty2grid_c(const pyarr_c<complex<T>> &dirty) const
+      {
+      myassert(dirty.ndim()==2, "dirty must be a 2D array");
+      myassert(size_t(dirty.shape(0))==nx_dirty, "bad 1st dimension");
+      myassert(size_t(dirty.shape(1))==ny_dirty, "bad 2nd dimension");
+      auto pdirty = dirty.data();
+      auto tmp = makearray<complex<T>>({nu, nv});
+      auto ptmp = tmp.mutable_data();
+      for (size_t i=0; i<nu*nv; ++i)
+        ptmp[i] = 0.;
+      for (size_t i=0; i<nx_dirty; ++i)
+        for (size_t j=0; j<ny_dirty; ++j)
+          {
+          size_t i2 = nu-nx_dirty/2+i;
+          if (i2>=nu) i2-=nu;
+          size_t j2 = nv-ny_dirty/2+j;
+          if (j2>=nv) j2-=nv;
+          ptmp[nv*i2+j2] = pdirty[ny_dirty*i + j]*cfu[i]*cfv[j];
+          }
+      pocketfft::c2c({nu,nv},{tmp.strides(0),tmp.strides(1)},
+        {tmp.strides(0), tmp.strides(1)}, {0,1}, pocketfft::FORWARD,
+        tmp.data(), tmp.mutable_data(), T(1), 0);
+      return tmp;
+      }
   };
 
 template<typename T> class Helper
@@ -710,6 +757,51 @@ template<typename T> pyarr_c<T> vis2grid(const Baselines<T> &baselines,
   return complex2hartley(res);
   }
 
+template<typename T> pyarr_c<complex<T>> vis2grid_c(const Baselines<T> &baselines,
+  const GridderConfig<T> &gconf, const pyarr_c<uint32_t> &idx_,
+  const pyarr_c<complex<T>> &vis_)
+  {
+  myassert(idx_.ndim()==1, "idx array must be 1D");
+  myassert(vis_.ndim()==1, "vis must be 1D");
+  auto vis=vis_.data();
+  myassert(vis_.shape(0)==idx_.shape(0), "bad vis dimension");
+  size_t nvis = size_t(idx_.shape(0));
+  auto idx = idx_.data();
+
+  size_t nu=gconf.Nu(), nv=gconf.Nv();
+  auto res = makearray<complex<T>>({nu, nv});
+  auto grid = res.mutable_data();
+  for (size_t i=0; i<nu*nv; ++i) grid[i] = 0.;
+  T ucorr = gconf.Ucorr(), vcorr=gconf.Vcorr();
+
+#pragma omp parallel
+{
+  WriteHelper<T> hlp(nu, nv, gconf.W(), grid);
+  T emb = exp(-2*hlp.beta);
+  const T * RESTRICT ku = hlp.kernel.data();
+  const T * RESTRICT kv = hlp.kernel.data()+hlp.w;
+
+  // Loop over sampling points
+#pragma omp for schedule(guided,100)
+  for (size_t ipart=0; ipart<nvis; ++ipart)
+    {
+    UVW<T> coord = baselines.effectiveCoord(idx[ipart]);
+    hlp.prep_write(coord.u*ucorr, coord.v*vcorr);
+    auto * RESTRICT ptr = hlp.p0;
+    int w = hlp.w;
+    auto v(vis[ipart]*emb);
+    for (int cu=0; cu<w; ++cu)
+      {
+      complex<T> tmp(v*ku[cu]);
+      for (int cv=0; cv<w; ++cv)
+        ptr[cv] += tmp*kv[cv];
+      ptr+=hlp.sv;
+      }
+    }
+} // end of parallel region
+  return res;
+  }
+
 template<typename T> pyarr_c<complex<T>> grid2vis(const Baselines<T> &baselines,
   const GridderConfig<T> &gconf, const pyarr_c<uint32_t> &idx_,
   const pyarr_c<T> &grid0_)
@@ -758,6 +850,53 @@ template<typename T> pyarr_c<complex<T>> grid2vis(const Baselines<T> &baselines,
   return res;
   }
 
+template<typename T> pyarr_c<complex<T>> grid2vis_c(const Baselines<T> &baselines,
+  const GridderConfig<T> &gconf, const pyarr_c<uint32_t> &idx_,
+  const pyarr_c<complex<T>> &grid_)
+  {
+  size_t nu=gconf.Nu(), nv=gconf.Nv();
+  myassert(idx_.ndim()==1, "idx array must be 1D");
+  auto grid = grid_.data();
+  myassert(grid_.ndim()==2, "data must be 2D");
+  myassert(grid_.shape(0)==int(nu), "bad grid dimension");
+  myassert(grid_.shape(1)==int(nv), "bad grid dimension");
+  size_t nvis = size_t(idx_.shape(0));
+  auto idx = idx_.data();
+
+  auto res = makearray<complex<T>>({nvis});
+  auto vis = res.mutable_data();
+  T ucorr = gconf.Ucorr(), vcorr=gconf.Vcorr();
+
+  // Loop over sampling points
+#pragma omp parallel
+{
+  ReadHelper<T> hlp(nu, nv, gconf.W(), grid);
+  T emb = exp(-2*hlp.beta);
+  const T * RESTRICT ku = hlp.kernel.data();
+  const T * RESTRICT kv = hlp.kernel.data()+hlp.w;
+
+#pragma omp for schedule(guided,100)
+  for (size_t ipart=0; ipart<nvis; ++ipart)
+    {
+    UVW<T> coord = baselines.effectiveCoord(idx[ipart]);
+    hlp.prep_read(coord.u*ucorr, coord.v*vcorr);
+    complex<T> r = 0;
+    auto * RESTRICT ptr = hlp.p0;
+    int w = hlp.w;
+    for (int cu=0; cu<w; ++cu)
+      {
+      complex<T> tmp(0);
+      for (int cv=0; cv<w; ++cv)
+        tmp += ptr[cv] * kv[cv];
+      r += tmp*ku[cu];
+      ptr += hlp.sv;
+      }
+    vis[ipart] = r*emb;
+    }
+}
+  return res;
+  }
+
 template<typename T> pyarr_c<uint32_t> getIndices(const Baselines<T> &baselines,
   const GridderConfig<T> &gconf, const pyarr_c<bool> &flags_, int chbegin,
   int chend, T wmin, T wmax)
@@ -1020,7 +1159,7 @@ PYBIND11_MODULE(nifty_gridder, m)
     .def ("Nrows",&Baselines<double>::Nrows)
     .def ("Nchannels",&Baselines<double>::Nchannels)
     .def ("ms2vis",&Baselines<double>::ms2vis<complex<double>>, BL_ms2vis_DS, "ms"_a, "idx"_a)
-    .def ("ms2vis_f32",&Baselines<double>::ms2vis<float>, "ms"_a, "idx"_a)
+//    .def ("ms2vis_f32",&Baselines<double>::ms2vis<float>, "ms"_a, "idx"_a)
     .def ("vis2ms",&Baselines<double>::vis2ms<complex<double>>, BL_vis2ms_DS, "vis"_a, "idx"_a)
     .def ("add_vis_to_ms",&Baselines<double>::add_vis_to_ms<complex<double>>, BL_add_vis_to_ms_DS,
       "vis"_a, "idx"_a, "ms"_a.noconvert());
@@ -1030,12 +1169,17 @@ PYBIND11_MODULE(nifty_gridder, m)
     .def("Nu", &GridderConfig<double>::Nu)
     .def("Nv", &GridderConfig<double>::Nv)
     .def("grid2dirty", &GridderConfig<double>::grid2dirty, grid2dirty_DS, "grid"_a)
-    .def("dirty2grid", &GridderConfig<double>::dirty2grid, dirty2grid_DS, "dirty"_a);
+    .def("grid2dirty_c", &GridderConfig<double>::grid2dirty_c, "grid"_a)
+    .def("dirty2grid", &GridderConfig<double>::dirty2grid, dirty2grid_DS, "dirty"_a)
+    .def("dirty2grid_c", &GridderConfig<double>::dirty2grid_c, "dirty"_a);
   m.def("getIndices", getIndices<double>, getIndices_DS, "baselines"_a, "gconf"_a,
     "flags"_a, "chbegin"_a=-1, "chend"_a=-1, "wmin"_a=-1e30, "wmax"_a=1e30);
   m.def("vis2grid",&vis2grid<double>, vis2grid_DS, "baselines"_a, "gconf"_a, "idx"_a, "vis"_a);
   m.def("grid2vis",&grid2vis<double>, grid2vis_DS, "baselines"_a, "gconf"_a, "idx"_a, "grid"_a);
+  m.def("vis2grid_c",&vis2grid_c<double>, "baselines"_a, "gconf"_a, "idx"_a, "vis"_a);
+  m.def("grid2vis_c",&grid2vis_c<double>, "baselines"_a, "gconf"_a, "idx"_a, "grid"_a);
 
+#if 0
   py::class_<Baselines<float>> (m, "Baselines_f", Baselines_DS)
     .def(py::init<const pyarr_c<float> &, const pyarr_c<float> &>(),
       "coord"_a, "scaling"_a)
@@ -1056,4 +1200,5 @@ PYBIND11_MODULE(nifty_gridder, m)
     "flags"_a, "chbegin"_a=-1, "chend"_a=-1, "wmin"_a=-1e30, "wmax"_a=1e30);
   m.def("vis2grid_f",&vis2grid<float>, "baselines"_a, "gconf"_a, "idx"_a, "vis"_a);
   m.def("grid2vis_f",&grid2vis<float>, "baselines"_a, "gconf"_a, "idx"_a, "grid"_a);
+#endif
   }
-- 
GitLab