diff --git a/bfps/NavierStokes.py b/bfps/NavierStokes.py index 643a3d2ff241a5b13884117fa09f1c5744b61a1c..29d9a8fc1bf4ca9cc2285d7e3851f8ace1ba583a 100644 --- a/bfps/NavierStokes.py +++ b/bfps/NavierStokes.py @@ -237,11 +237,11 @@ class NavierStokes(bfps.fluid_base.fluid_particle_base): kcut = None, neighbours = 1, name = 'particle_field'): - self.particle_variables += ('interpolator<{0}> *vel_{1}, *acc_{1};\n' + - '{0} *{1}_tmp;\n').format(self.C_dtype, name) - self.particle_start += ('vel_{0} = new interpolator<{1}>(fs, {2});\n' + - 'acc_{0} = new interpolator<{1}>(fs, {2});\n' + - '{0}_tmp = new {1}[acc_{0}->src_descriptor->local_size];\n').format(name, self.C_dtype, neighbours+1) + self.particle_variables += ('interpolator<{0}, {1}> *vel_{2}, *acc_{2};\n' + + '{0} *{2}_tmp;\n').format(self.C_dtype, neighbours, name) + self.particle_start += ('vel_{0} = new interpolator<{1}, {2}>(fs);\n' + + 'acc_{0} = new interpolator<{1}, {2}>(fs);\n' + + '{0}_tmp = new {1}[acc_{0}->unbuffered_descriptor->local_size];\n').format(name, self.C_dtype, neighbours) self.particle_end += ('delete vel_{0};\n' + 'delete acc_{0};\n' + 'delete[] {0}_tmp;\n').format(name) diff --git a/bfps/cpp/interpolator.cpp b/bfps/cpp/interpolator.cpp index 066b8c308a9be18220275cd2505ccd8d9d622a00..7609387b43fadeb34503379dfaae659322ef7ae7 100644 --- a/bfps/cpp/interpolator.cpp +++ b/bfps/cpp/interpolator.cpp @@ -26,23 +26,21 @@ #include "interpolator.hpp" -template <class rnumber> -interpolator<rnumber>::interpolator( - fluid_solver_base<rnumber> *fs, - const int bw) +template <class rnumber, int interp_neighbours> +interpolator<rnumber, interp_neighbours>::interpolator( + fluid_solver_base<rnumber> *fs) { int tdims[4]; - this->buffer_width = bw; - this->src_descriptor = fs->rd; - this->buffer_size = this->buffer_width*this->src_descriptor->slice_size; - tdims[0] = this->buffer_width*2*this->src_descriptor->nprocs + this->src_descriptor->sizes[0]; - tdims[1] = this->src_descriptor->sizes[1]; - tdims[2] = this->src_descriptor->sizes[2]; - tdims[3] = this->src_descriptor->sizes[3]; + this->unbuffered_descriptor = fs->rd; + this->buffer_size = (interp_neighbours+1)*this->unbuffered_descriptor->slice_size; + tdims[0] = (interp_neighbours+1)*2*this->unbuffered_descriptor->nprocs + this->unbuffered_descriptor->sizes[0]; + tdims[1] = this->unbuffered_descriptor->sizes[1]; + tdims[2] = this->unbuffered_descriptor->sizes[2]; + tdims[3] = this->unbuffered_descriptor->sizes[3]; this->descriptor = new field_descriptor<rnumber>( 4, tdims, - this->src_descriptor->mpi_dtype, - this->src_descriptor->comm); + this->unbuffered_descriptor->mpi_dtype, + this->unbuffered_descriptor->comm); this->f = new rnumber[this->descriptor->local_size]; //if (sizeof(rnumber) == 4) // this->f = fftwf_alloc_real(this->descriptor->local_size); @@ -50,74 +48,84 @@ interpolator<rnumber>::interpolator( // this->f = fftw_alloc_real(this->descriptor->local_size); } -template <class rnumber> -interpolator<rnumber>::~interpolator() +template <class rnumber, int interp_neighbours> +interpolator<rnumber, interp_neighbours>::~interpolator() { delete[] this->f; delete this->descriptor; } -template <class rnumber> -int interpolator<rnumber>::read_rFFTW(void *void_src) +template <class rnumber, int interp_neighbours> +int interpolator<rnumber, interp_neighbours>::read_rFFTW(void *void_src) { rnumber *src = (rnumber*)void_src; rnumber *dst = this->f; /* do big copy of middle stuff */ std::copy(src, - src + this->src_descriptor->local_size, + src + this->unbuffered_descriptor->local_size, dst + this->buffer_size); MPI_Datatype MPI_RNUM = (sizeof(rnumber) == 4) ? MPI_FLOAT : MPI_DOUBLE; int rsrc; /* get upper slices */ - for (int rdst = 0; rdst < this->src_descriptor->nprocs; rdst++) + for (int rdst = 0; rdst < this->unbuffered_descriptor->nprocs; rdst++) { - rsrc = this->src_descriptor->rank[(this->src_descriptor->all_start0[rdst] + - this->src_descriptor->all_size0[rdst]) % - this->src_descriptor->sizes[0]]; - if (this->src_descriptor->myrank == rsrc) + rsrc = this->unbuffered_descriptor->rank[(this->unbuffered_descriptor->all_start0[rdst] + + this->unbuffered_descriptor->all_size0[rdst]) % + this->unbuffered_descriptor->sizes[0]]; + if (this->unbuffered_descriptor->myrank == rsrc) MPI_Send( src, this->buffer_size, MPI_RNUM, rdst, - 2*(rsrc*this->src_descriptor->nprocs + rdst), + 2*(rsrc*this->unbuffered_descriptor->nprocs + rdst), this->descriptor->comm); - if (this->src_descriptor->myrank == rdst) + if (this->unbuffered_descriptor->myrank == rdst) MPI_Recv( - dst + this->buffer_size + this->src_descriptor->local_size, + dst + this->buffer_size + this->unbuffered_descriptor->local_size, this->buffer_size, MPI_RNUM, rsrc, - 2*(rsrc*this->src_descriptor->nprocs + rdst), + 2*(rsrc*this->unbuffered_descriptor->nprocs + rdst), this->descriptor->comm, MPI_STATUS_IGNORE); } /* get lower slices */ - for (int rdst = 0; rdst < this->src_descriptor->nprocs; rdst++) + for (int rdst = 0; rdst < this->unbuffered_descriptor->nprocs; rdst++) { - rsrc = this->src_descriptor->rank[MOD(this->src_descriptor->all_start0[rdst] - 1, - this->src_descriptor->sizes[0])]; - if (this->src_descriptor->myrank == rsrc) + rsrc = this->unbuffered_descriptor->rank[MOD(this->unbuffered_descriptor->all_start0[rdst] - 1, + this->unbuffered_descriptor->sizes[0])]; + if (this->unbuffered_descriptor->myrank == rsrc) MPI_Send( - src + this->src_descriptor->local_size - this->buffer_size, + src + this->unbuffered_descriptor->local_size - this->buffer_size, this->buffer_size, MPI_RNUM, rdst, - 2*(rsrc*this->src_descriptor->nprocs + rdst)+1, - this->src_descriptor->comm); - if (this->src_descriptor->myrank == rdst) + 2*(rsrc*this->unbuffered_descriptor->nprocs + rdst)+1, + this->unbuffered_descriptor->comm); + if (this->unbuffered_descriptor->myrank == rdst) MPI_Recv( dst, this->buffer_size, MPI_RNUM, rsrc, - 2*(rsrc*this->src_descriptor->nprocs + rdst)+1, - this->src_descriptor->comm, + 2*(rsrc*this->unbuffered_descriptor->nprocs + rdst)+1, + this->unbuffered_descriptor->comm, MPI_STATUS_IGNORE); } return EXIT_SUCCESS; } -template class interpolator<float>; -template class interpolator<double>; +template class interpolator<float, 1>; +template class interpolator<float, 2>; +template class interpolator<float, 3>; +template class interpolator<float, 4>; +template class interpolator<float, 5>; +template class interpolator<float, 6>; +template class interpolator<double, 1>; +template class interpolator<double, 2>; +template class interpolator<double, 3>; +template class interpolator<double, 4>; +template class interpolator<double, 5>; +template class interpolator<double, 6>; diff --git a/bfps/cpp/interpolator.hpp b/bfps/cpp/interpolator.hpp index 65c017d057475b7bf052ef463241e476efc11f4e..e6576453500b69354aa617b0c11a7ec4d2ba024e 100644 --- a/bfps/cpp/interpolator.hpp +++ b/bfps/cpp/interpolator.hpp @@ -43,19 +43,17 @@ typedef void (*base_polynomial_values)( double fraction, double *destination); -template <class rnumber> +template <class rnumber, int interp_neighbours> class interpolator { public: - int buffer_width; ptrdiff_t buffer_size; field_descriptor<rnumber> *descriptor; - field_descriptor<rnumber> *src_descriptor; + field_descriptor<rnumber> *unbuffered_descriptor; rnumber *f; interpolator( - fluid_solver_base<rnumber> *FSOLVER, - const int BUFFER_WIDTH); + fluid_solver_base<rnumber> *FSOLVER); ~interpolator(); /* destroys input */