Commit 5a4895c8 authored by Martin Reinecke

Merge remote-tracking branch 'origin/master' into get_taper

parents dabb3cbc 27b2d163
@@ -23,8 +23,9 @@
#include <pybind11/numpy.h>
#include <iostream>
#include <algorithm>
#include <cstdlib>
#include <cmath>
#define POCKETFFT_OPENMP
#include "pocketfft_hdronly.h"
#ifdef __GNUC__
@@ -63,6 +64,32 @@ template<typename T> inline T fmodulo (T v1, T v2)
return (tmp==v2) ? T(0) : tmp;
}
static size_t nthreads = 1;
constexpr auto set_nthreads_DS = R"""(
Specifies the number of threads to be used by the module
Parameters
==========
nthreads: int
the number of threads to be used. Must be >=1.
)""";
void set_nthreads(size_t nthreads_)
{
myassert(nthreads_>=1, "nthreads must be >= 1");
nthreads = nthreads_;
}
constexpr auto get_nthreads_DS = R"""(
Returns the number of threads used by the module
Returns
=======
int : the number of threads used by the module
)""";
size_t get_nthreads()
{ return nthreads; }
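These two accessors manage a single module-global thread count that every OpenMP region and pocketfft call below consults. A minimal sketch of how the pair might be exercised from C++ (the same functions are also exported to Python at the bottom of the file); the value 4 is just an example:
    set_nthreads(4);            // subsequent parallel regions use up to 4 threads
    size_t n = get_nthreads();  // n == 4
    myassert(n >= 1, "thread count must stay positive");
    set_nthreads(1);            // restore the single-threaded default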
//
// Utilities for Gauss-Legendre quadrature
//
@@ -81,7 +108,7 @@ void legendre_prep(int n, vector<double> &x, vector<double> &w)
double t0 = 1 - (1-1./n) / (8.*n*n);
double t1 = 1./(4.*n+2.);
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
int i;
#pragma omp for schedule(dynamic,100)
@@ -115,7 +142,7 @@ void legendre_prep(int n, vector<double> &x, vector<double> &w)
if (dobreak) break;
if (abs(dx)<=eps) dobreak=1;
if (++j>=100) throw runtime_error("convergence problem");
myassert(++j<100, "convergence problem");
}
x[m-i] = x0;
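For context, the elided body of this loop is a Newton iteration on the Legendre polynomial P_n; the convergence check on dx shown above presumably guards the usual update

    x_{k+1} = x_k - \frac{P_n(x_k)}{P_n'(x_k)},

started from the asymptotic guess built from the t0 and t1 of the previous hunk (t0 = 1 - (n-1)/(8 n^3), t1 = 1/(4n+2)), with the quadrature weight w_i = 2 / ((1 - x_i^2) P_n'(x_i)^2) once a root x_i has converged. The only behavioural change here is that reaching 100 iterations now trips myassert instead of throwing runtime_error directly.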
@@ -213,7 +240,7 @@ template<typename T> pyarr_c<T> complex2hartley
auto grid2 = res.mutable_data();
{
py::gil_scoped_release release;
#pragma omp parallel for
#pragma omp parallel for num_threads(nthreads)
for (size_t u=0; u<nu; ++u)
{
size_t xu = (u==0) ? 0 : nu-u;
@@ -241,7 +268,7 @@ template<typename T> pyarr_c<complex<T>> hartley2complex
auto grid2 = res.mutable_data();
{
py::gil_scoped_release release;
#pragma omp parallel for
#pragma omp parallel for num_threads(nthreads)
for (size_t u=0; u<nu; ++u)
{
size_t xu = (u==0) ? 0 : nu-u;
@@ -268,8 +295,9 @@ template<typename T> void hartley2_2D(const pyarr_c<T> &in, pyarr_c<T> &out)
auto ptmp = out.mutable_data();
{
py::gil_scoped_release release;
pocketfft::r2r_separable_hartley({nu, nv}, s_i, s_o, {0,1}, d_i, ptmp, T(1), 0);
#pragma omp parallel for
pocketfft::r2r_separable_hartley({nu, nv}, s_i, s_o, {0,1}, d_i, ptmp, T(1),
nthreads);
#pragma omp parallel for num_threads(nthreads)
for(size_t i=1; i<(nu+1)/2; ++i)
for(size_t j=1; j<(nv+1)/2; ++j)
{
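pocketfft's r2r_separable_hartley applies the 1-D Hartley transform along each axis separately, i.e. with kernel cas(2\pi u x/n_u)\,cas(2\pi v y/n_v); the loop whose head is shown above (its body is elided in this hunk) presumably recombines that into the true 2-D Hartley transform through the standard four-point identity

    H(u,v) = \tfrac{1}{2}\left[ T(u,v) + T(n_u-u,\,v) + T(u,\,n_v-v) - T(n_u-u,\,n_v-v) \right],

which is why only one quadrant (i < (n_u+1)/2, j < (n_v+1)/2) needs to be visited. The functional change in this hunk is simply that the transform and the fix-up loop now honour the module-wide nthreads setting instead of the previous constant 0.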
@@ -299,7 +327,7 @@ vector<double> correction_factors (size_t n, size_t nval, size_t w)
for (auto &v:psi)
v = exp(beta*(sqrt(1-v*v)-1.));
vector<double> res(nval);
#pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static) num_threads(nthreads)
for (size_t k=0; k<nval; ++k)
{
double tmp=0;
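For reference, the kernel sampled into psi above is

    \psi(x) = e^{\beta(\sqrt{1 - x^2} - 1)},

an "exponential of semicircle" style taper; the per-k sum started here (its body is elided in this hunk) presumably accumulates its transform to yield the grid-correction factors used to de-taper the dirty image. The change itself only adds the explicit num_threads(nthreads) clause to the parallel loop.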
@@ -397,7 +425,7 @@ template<typename T> class Baselines
auto vis = res.mutable_data();
{
py::gil_scoped_release release;
#pragma omp parallel for
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nvis; ++i)
{
auto t = idx(i);
@@ -439,7 +467,7 @@ template<typename T> class Baselines
auto ms = res.template mutable_unchecked<2>();
{
py::gil_scoped_release release;
#pragma omp parallel for
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nvis; ++i)
{
auto t = idx(i);
@@ -613,7 +641,7 @@ template<typename T> class GridderConfig
auto ptmp = tmp.mutable_data();
pocketfft::c2c({nu,nv},{grid.strides(0),grid.strides(1)},
{tmp.strides(0), tmp.strides(1)}, {0,1}, pocketfft::BACKWARD,
grid.data(), tmp.mutable_data(), T(1), 0);
grid.data(), tmp.mutable_data(), T(1), nthreads);
auto res = makeArray<complex<T>>({nx_dirty, ny_dirty});
auto pout = res.mutable_data();
{
@@ -675,7 +703,7 @@ template<typename T> class GridderConfig
ptmp[nv*i2+j2] = pdirty[ny_dirty*i + j]*cfu[i]*cfv[j];
}
pocketfft::c2c({nu,nv}, strides, strides, {0,1}, pocketfft::FORWARD,
ptmp, ptmp, T(1), 0);
ptmp, ptmp, T(1), nthreads);
}
return tmp;
}
@@ -690,19 +718,19 @@ template<typename T> class GridderConfig
}
};
template<typename T> class Helper
template<typename T, typename T2=complex<T>> class Helper
{
private:
const GridderConfig<T> &gconf;
int nu, nv, nsafe, w;
T beta;
const complex<T> *grid_r;
complex<T> *grid_w;
const T2 *grid_r;
T2 *grid_w;
int su, sv;
int iu0, iv0; // start index of the current visibility
int bu0, bv0; // start index of the current buffer
vector<complex<T>> rbuf, wbuf;
vector<T2> rbuf, wbuf;
void dump() const
{
@@ -742,12 +770,11 @@ template<typename T> class Helper
}
public:
const complex<T> *p0r;
complex<T> *p0w;
const T2 *p0r;
T2 *p0w;
vector<T> kernel;
Helper(const GridderConfig<T> &gconf_, const complex<T> *grid_r_,
complex<T> *grid_w_)
Helper(const GridderConfig<T> &gconf_, const T2 *grid_r_, T2 *grid_w_)
: gconf(gconf_), nu(gconf.Nu()), nv(gconf.Nv()), nsafe(gconf.Nsafe()),
w(gconf.W()), beta(gconf.Beta()), grid_r(grid_r_), grid_w(grid_w_),
su(2*nsafe+(1<<logsquare)), sv(2*nsafe+(1<<logsquare)),
@@ -830,7 +857,7 @@ template<typename T> pyarr_c<complex<T>> vis2grid_c(
T beta = gconf.Beta();
size_t w = gconf.W();
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, nullptr, grid);
T emb = exp(-2*beta);
@@ -929,7 +956,7 @@ template<typename T> pyarr_c<complex<T>> ms2grid_c(
T beta = gconf.Beta();
size_t w = gconf.W();
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, nullptr, grid);
T emb = exp(-2*beta);
@@ -989,7 +1016,7 @@ template<typename T> pyarr_c<complex<T>> ms2grid_c_wgt(
T beta = gconf.Beta();
size_t w = gconf.W();
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, nullptr, grid);
T emb = exp(-2*beta);
@@ -1046,7 +1073,7 @@ template<typename T> pyarr_c<complex<T>> grid2vis_c(
size_t w = gconf.W();
// Loop over sampling points
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, grid, nullptr);
T emb = exp(-2*beta);
@@ -1122,7 +1149,7 @@ template<typename T> pyarr_c<complex<T>> grid2ms_c(const Baselines<T> &baselines
size_t w = gconf.W();
// Loop over sampling points
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, grid, nullptr);
T emb = exp(-2*beta);
@@ -1184,7 +1211,7 @@ template<typename T> pyarr_c<complex<T>> grid2ms_c_wgt(
size_t w = gconf.W();
// Loop over sampling points
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, grid, nullptr);
T emb = exp(-2*beta);
@@ -1245,7 +1272,7 @@ template<typename T> pyarr_c<complex<T>> apply_holo(
size_t w = gconf.W();
// Loop over sampling points
#pragma omp parallel
#pragma omp parallel num_threads(nthreads)
{
Helper<T> hlp(gconf, grid, ogrid);
T emb = exp(-2*beta);
@@ -1283,6 +1310,63 @@ template<typename T> pyarr_c<complex<T>> apply_holo(
return res;
}
template<typename T> pyarr_c<T> get_correlations(
const Baselines<T> &baselines, const GridderConfig<T> &gconf,
const pyarr<uint32_t> &idx_, int du, int dv)
{
size_t nu=gconf.Nu(), nv=gconf.Nv();
size_t w = gconf.W();
myassert(size_t(abs(du))<w, "|du| must be smaller than W");
myassert(size_t(abs(dv))<w, "|dv| must be smaller than W");
checkArray(idx_, "idx", {0});
size_t nvis = size_t(idx_.shape(0));
auto idx = idx_.template unchecked<1>();
auto res = makeArray<T>({nu, nv});
auto ogrid = res.mutable_data();
{
py::gil_scoped_release release;
T beta = gconf.Beta();
for (size_t i=0; i<nu*nv; ++i) ogrid[i] = 0.;
size_t u0, u1, v0, v1;
if (du>=0)
{ u0=0; u1=w-du; }
else
{ u0=-du; u1=w; }
if (dv>=0)
{ v0=0; v1=w-dv; }
else
{ v0=-dv; v1=w; }
// Loop over sampling points
#pragma omp parallel num_threads(nthreads)
{
Helper<T,T> hlp(gconf, nullptr, ogrid);
T emb = exp(-2*beta);
int jump = hlp.lineJump();
const T * RESTRICT ku = hlp.kernel.data();
const T * RESTRICT kv = hlp.kernel.data()+w;
#pragma omp for schedule(guided,100)
for (size_t ipart=0; ipart<nvis; ++ipart)
{
UVW<T> coord = baselines.effectiveCoord(idx(ipart));
hlp.prep(coord.u, coord.v);
auto * RESTRICT wptr = hlp.p0w + u0*jump;
for (size_t cu=u0; cu<u1; ++cu)
{
auto f1=ku[cu]*ku[cu+du]*emb*emb;
for (size_t cv=v0; cv<v1; ++cv)
wptr[cv] += f1*kv[cv]*kv[cv+dv];
wptr += jump;
}
}
}
}
return res;
}
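Reading the inner loops directly: for each selected visibility, and for every kernel cell (c_u, c_v) whose shifted partner (c_u+du, c_v+dv) still lies inside the support of width W (which is what the u0/u1 and v0/v1 bounds enforce), the real-valued output grid accumulates

    e^{-4\beta}\, k_u(c_u)\, k_u(c_u+du)\, k_v(c_v)\, k_v(c_v+dv),

i.e. the pointwise product of the gridding kernel with a copy of itself shifted by (du, dv), deposited at the visibility's grid position. The accumulation goes through the new Helper<T,T> instantiation, which writes into a real rather than a complex grid.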
constexpr auto getIndices_DS = R"""(
Selects a subset of entries from a `Baselines` object.
@@ -1411,8 +1495,7 @@ PYBIND11_MODULE(nifty_gridder, m)
},
// __setstate__
[](py::tuple t) {
if(t.size() != 5)
{ throw std::runtime_error("Invalid state"); }
myassert(t.size()==5,"Invalid state");
// Reconstruct from tuple
return GridderConfig<double>(t[0].cast<size_t>(), t[1].cast<size_t>(),
@@ -1436,7 +1519,7 @@ PYBIND11_MODULE(nifty_gridder, m)
m.def("grid2ms_wgt",&grid2ms_wgt<double>, "baselines"_a, "gconf"_a, "idx"_a,
"grid"_a, "wgt"_a, "ms_in"_a=None);
m.def("vis2grid_c",&vis2grid_c<double>, vis2grid_c_DS, "baselines"_a,
"gconf"_a, "idx"_a, "vis"_a, "grid_in"_a);
"gconf"_a, "idx"_a, "vis"_a, "grid_in"_a=None);
m.def("ms2grid_c",&ms2grid_c<double>, ms2grid_c_DS, "baselines"_a, "gconf"_a,
"idx"_a, "ms"_a, "grid_in"_a=None);
m.def("ms2grid_c_wgt",&ms2grid_c_wgt<double>, "baselines"_a, "gconf"_a,
@@ -1449,4 +1532,8 @@ PYBIND11_MODULE(nifty_gridder, m)
"idx"_a, "grid"_a, "wgt"_a, "ms_in"_a=None);
m.def("apply_holo",&apply_holo<double>, "baselines"_a, "gconf"_a, "idx"_a,
"grid"_a);
m.def("get_correlations", &get_correlations<double>, "baselines"_a, "gconf"_a,
"idx"_a, "du"_a, "dv"_a);
m.def("set_nthreads",&set_nthreads, set_nthreads_DS, "nthreads"_a);
m.def("get_nthreads",&get_nthreads, get_nthreads_DS);
}