From ef22e88626175d9a3db696b91d27723ca63037ba Mon Sep 17 00:00:00 2001
From: Chichi Lalescu <clalesc1@jhu.edu>
Date: Thu, 24 Dec 2015 20:53:20 +0100
Subject: [PATCH] reimplement sample as method of rFFTW_interpolator

---
 bfps/cpp/rFFTW_interpolator.cpp | 30 ++++++++++++++++++++++++++++++
 bfps/cpp/rFFTW_interpolator.hpp |  1 +
 bfps/cpp/rFFTW_particles.cpp    | 25 +------------------------
 bfps/cpp/rFFTW_particles.hpp    |  1 -
 4 files changed, 32 insertions(+), 25 deletions(-)

diff --git a/bfps/cpp/rFFTW_interpolator.cpp b/bfps/cpp/rFFTW_interpolator.cpp
index 46673eda..fb0cc7ef 100644
--- a/bfps/cpp/rFFTW_interpolator.cpp
+++ b/bfps/cpp/rFFTW_interpolator.cpp
@@ -143,6 +143,36 @@ void rFFTW_interpolator<rnumber, interp_neighbours>::get_grid_coordinates(
     }
 }
 
+template <class rnumber, int interp_neighbours>
+void rFFTW_interpolator<rnumber, interp_neighbours>::sample(
+        const int nparticles,
+        const int pdimension,
+        const double t,
+        const double *__restrict__ x,
+        double *__restrict__ y,
+        const int *deriv)
+{
+    /* get grid coordinates */
+    int *xg = new int[3*nparticles];
+    double *xx = new double[3*nparticles];
+    double *yy =  new double[3*nparticles];
+    std::fill_n(yy, 3*nparticles, 0.0);
+    this->get_grid_coordinates(nparticles, pdimension, x, xg, xx);
+    /* perform interpolation */
+    for (int p=0; p<nparticles; p++)
+        this->operator()(t, xg + p*3, xx + p*3, yy + p*3, deriv);
+    MPI_Allreduce(
+            yy,
+            y,
+            3*nparticles,
+            MPI_DOUBLE,
+            MPI_SUM,
+            this->descriptor->comm);
+    delete[] yy;
+    delete[] xg;
+    delete[] xx;
+}
+
 template <class rnumber, int interp_neighbours>
 void rFFTW_interpolator<rnumber, interp_neighbours>::operator()(
         const double t,
diff --git a/bfps/cpp/rFFTW_interpolator.hpp b/bfps/cpp/rFFTW_interpolator.hpp
index c7efb9d8..1f0d3065 100644
--- a/bfps/cpp/rFFTW_interpolator.hpp
+++ b/bfps/cpp/rFFTW_interpolator.hpp
@@ -79,6 +79,7 @@ class rFFTW_interpolator
         /* interpolate field at an array of locations */
         void sample(
                 const int nparticles,
+                const int pdimension,
                 const double t,
                 const double *__restrict__ x,
                 double *__restrict__ y,
diff --git a/bfps/cpp/rFFTW_particles.cpp b/bfps/cpp/rFFTW_particles.cpp
index 4d0ab85b..463ee897 100644
--- a/bfps/cpp/rFFTW_particles.cpp
+++ b/bfps/cpp/rFFTW_particles.cpp
@@ -154,7 +154,7 @@ void rFFTW_particles<particle_type, rnumber, interp_neighbours>::get_rhs(double
     switch(particle_type)
     {
         case VELOCITY_TRACER:
-            this->sample_vec_field(this->vel, t, x, y);
+            this->vel->sample(this->nparticles, this->ncomponents, t, x, y);
             break;
     }
 }
@@ -255,29 +255,6 @@ void rFFTW_particles<particle_type, rnumber, interp_neighbours>::step()
 }
 
 
-
-template <int particle_type, class rnumber, int interp_neighbours>
-void rFFTW_particles<particle_type, rnumber, interp_neighbours>::get_grid_coordinates(double *x, int *xg, double *xx)
-{
-    static double grid_size[] = {this->dx, this->dy, this->dz};
-    double tval;
-    std::fill_n(xg, this->nparticles*3, 0);
-    std::fill_n(xx, this->nparticles*3, 0.0);
-    for (int p=0; p<this->nparticles; p++)
-    {
-        for (int c=0; c<3; c++)
-        {
-            tval = floor(x[p*this->ncomponents+c]/grid_size[c]);
-            xg[p*3+c] = MOD(int(tval), this->fs->rd->sizes[2-c]);
-            xx[p*3+c] = (x[p*this->ncomponents+c] - tval*grid_size[c]) / grid_size[c];
-        }
-        /*xg[p*3+2] -= this->fs->rd->starts[0];
-        if (this->myrank == this->fs->rd->rank[0] &&
-            xg[p*3+2] > this->fs->rd->subsizes[0])
-            xg[p*3+2] -= this->fs->rd->sizes[0];*/
-    }
-}
-
 template <int particle_type, class rnumber, int interp_neighbours>
 void rFFTW_particles<particle_type, rnumber, interp_neighbours>::read(hid_t data_file_id)
 {
diff --git a/bfps/cpp/rFFTW_particles.hpp b/bfps/cpp/rFFTW_particles.hpp
index e658b0f3..fdf20c2a 100644
--- a/bfps/cpp/rFFTW_particles.hpp
+++ b/bfps/cpp/rFFTW_particles.hpp
@@ -90,7 +90,6 @@ class rFFTW_particles
         void get_rhs(double *__restrict__ x, double *__restrict__ rhs);
         void get_rhs(double t, double *__restrict__ x, double *__restrict__ rhs);
 
-        void get_grid_coordinates(double *__restrict__ x, int *__restrict__ xg, double *__restrict__ xx);
         void sample_vec_field(
             rFFTW_interpolator<rnumber, interp_neighbours> *vec,
             double t,
-- 
GitLab