From 6203d28e2d0a065d75ff6e8cd6dd415a100b10f3 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Tue, 28 May 2019 10:41:17 +0200
Subject: [PATCH] intriduce flags

---
 nifty_gridder.cc | 34 +++++++++++++++++++++++-----------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/nifty_gridder.cc b/nifty_gridder.cc
index cc58602..1577e14 100644
--- a/nifty_gridder.cc
+++ b/nifty_gridder.cc
@@ -394,19 +394,22 @@ template<typename T> class Baselines
   private:
     vector<UVW<T>> coord;
     vector<T> scaling;
+    vector<bool> flags;
     size_t nrows, nchan;
     size_t channelbits, channelmask;
 
   public:
-    Baselines(const pyarr_c<T> &coord_, const pyarr_c<T> &scaling_)
+    Baselines(const pyarr_c<T> &coord_, const pyarr_c<T> &scaling_,
+       const pyarr_c<bool> &flags_)
       {
       myassert(coord_.ndim()==2, "coord array must be 2D");
       myassert(coord_.shape(1)==3, "coord.shape[1] must be 3");
       myassert(scaling_.ndim()==1, "scaling array must be 1D");
-      nrows = coord_.shape(0)/scaling_.shape(0);
-      myassert(nrows*size_t(scaling_.shape(0))==size_t(coord_.shape(0)),
-        "bad array dimensions");
+      nrows = coord_.shape(0);
       nchan = scaling_.shape(0);
+      myassert(flags_.ndim()==2, "flags array must be 2D");
+      myassert(size_t(flags_.shape(0))==nrows, "flags.shape[0] must be nrows");
+      myassert(size_t(flags_.shape(1))==nchan, "flags.shape[1] must be chan");
       scaling.resize(nchan);
       for (size_t i=0; i<nchan; ++i)
         scaling[i] = scaling_.data()[i];
@@ -414,6 +417,10 @@ template<typename T> class Baselines
       auto cood = coord_.data();
       for (size_t i=0; i<coord.size(); ++i)
         coord[i] = UVW<T>(cood[3*i], cood[3*i+1], cood[3*i+2]);
+      flags.resize(nrows*nchan);
+      auto pflags=flags_.data();
+      for (size_t i=0; i<nrows*nchan; ++i)
+        flags[i] = (pflags[i]!=0);
       channelbits = bits_needed(nchan);
       channelmask = (size_t(1)<<channelbits)-1;
       auto rowbits = bits_needed(nrows);
@@ -427,11 +434,12 @@ template<typename T> class Baselines
       vector<uint32_t> odata;
       for (size_t i=0; i<nrows; ++i)
         for (int j=chbegin; j<chend; ++j)
-          {
-          auto w = coord[i].w*scaling[j];
-          if ((w>=wmin) && (w<wmax))
-            odata.push_back((i<<channelbits)+j);
-          }
+          if (!flags[i*nchan + j])
+            {
+            auto w = coord[i].w*scaling[j];
+            if ((w>=wmin) && (w<wmax))
+              odata.push_back((i<<channelbits)+j);
+            }
       return odata;
       }
 
@@ -915,6 +923,8 @@ coord: np.array((nrows, 3), dtype=np.float)
     u, v and w coordinates for each row
 scaling: np.array((nchannels,), dtype=np.float)
     scaling factor for u, v, w for each individual channel
+flags: np.array((nrows, nchannels), dtype=np.bool)
+    "True" indicates that the value should not be used
 )""";
 
 const char *BL_ms2vis_DS = R"""(
@@ -974,7 +984,8 @@ PYBIND11_MODULE(nifty_gridder, m)
   using namespace pybind11::literals;
 
   py::class_<Baselines<double>> (m, "Baselines", Baselines_DS)
-    .def(py::init<pyarr_c<double>, pyarr_c<double>>(), "coord"_a, "scaling"_a)
+    .def(py::init<const pyarr_c<double> &, const pyarr_c<double> &,
+      const pyarr_c<bool> &>(), "coord"_a, "scaling"_a, "flags"_a)
     .def ("ms2vis",&Baselines<double>::ms2vis, BL_ms2vis_DS, "ms"_a, "idx"_a)
     .def ("vis2ms",&Baselines<double>::vis2ms, BL_vis2ms_DS, "vis"_a, "idx"_a)
     .def ("add_vis_to_ms",&Baselines<double>::add_vis_to_ms, BL_add_vis_to_ms_DS,
@@ -992,7 +1003,8 @@ PYBIND11_MODULE(nifty_gridder, m)
   m.def("grid2vis",&grid2vis<double>, "baselines"_a, "gconf"_a, "idx"_a, "grid"_a);
 
   py::class_<Baselines<float>> (m, "Baselines_f", Baselines_DS)
-    .def(py::init<pyarr_c<float>, pyarr_c<float>>(), "coord"_a, "scaling"_a)
+    .def(py::init<const pyarr_c<float> &, const pyarr_c<float> &,
+      const pyarr_c<bool> &>(), "coord"_a, "scaling"_a, "flags"_a)
     .def ("ms2vis",&Baselines<float>::ms2vis, BL_ms2vis_DS, "ms"_a, "idx"_a)
     .def ("vis2ms",&Baselines<float>::vis2ms, BL_vis2ms_DS, "vis"_a, "idx"_a)
     .def ("add_vis_to_ms",&Baselines<float>::add_vis_to_ms, BL_add_vis_to_ms_DS,
-- 
GitLab