Commit 99be1cb6 authored by Martin Reinecke
Browse files

generalize data types

parent 8610cc13
......@@ -558,7 +558,7 @@ template<typename T> class GridderConfig
{grid.strides(0), grid.strides(1)},
{tmp.strides(0), tmp.strides(1)}, {0,1},
grid.data(),
ptmp, 1.);
ptmp, T(1), 0);
for(size_t i=1; i<(nu+1)/2; ++i)
for(size_t j=1; j<(nv+1)/2; ++j)
{
......@@ -603,7 +603,7 @@ template<typename T> class GridderConfig
pocketfft::r2r_hartley({nu, nv},
{tmp.strides(0), tmp.strides(1)},
{tmp.strides(0), tmp.strides(1)}, {0,1},
tmp.data(), tmp.mutable_data(), 1.);
tmp.data(), tmp.mutable_data(), T(1), 0);
for(size_t i=1; i<(nu+1)/2; ++i)
for(size_t j=1; j<(nv+1)/2; ++j)
{
......@@ -620,38 +620,38 @@ template<typename T> class GridderConfig
}
};
class Helper
template<typename T> class Helper
{
protected:
int nu, nv;
public:
int w;
double beta;
T beta;
protected:
int nsafe, su;
public:
int sv;
vector<double> kernel;
vector<T> kernel;
int iu0, iv0; // start index of the current visibility
int bu0, bv0; // start index of the current buffer
// Recomputes the window start indices (iu0, iv0) and the 2*w kernel values
// for a single sample located at (u_in, v_in).
// NOTE(review): this listing is a commit diff shown without +/- markers; where
// a `double`/`1.` line is immediately followed by a `T`/`T(1)` twin, the first
// appears to be the removed line and the second its replacement -- confirm
// against the repository before compiling.
void NOINLINE update(double u_in, double v_in)
void NOINLINE update(T u_in, T v_in)
{
// fmodulo is defined elsewhere; presumably it wraps the coordinate into
// [0,1) before it is scaled by the grid size nu -- TODO confirm.
auto u = fmodulo(u_in, 1.)*nu;
auto u = fmodulo(u_in, T(1))*nu;
// nu is added before the int() truncation and subtracted afterwards; this
// looks intended to keep the truncated argument non-negative -- verify.
iu0 = int(u-w*0.5 + 1 + nu) - nu;
// Clamp so that the w-wide window ends no later than index nu+nsafe.
if (iu0+w>nu+nsafe) iu0 = nu+nsafe-w;
// Same computation for the v direction, using nv.
auto v = fmodulo(v_in, 1.)*nv;
auto v = fmodulo(v_in, T(1))*nv;
iv0 = int(v-w*0.5 + 1 + nv) - nv;
if (iv0+w>nv+nsafe) iv0 = nv+nsafe-w;
// Scale factor 2/w applied to the index-minus-position offsets below.
double xw=2./w;
T xw=T(2)/w;
// kernel[0..w) receives the scaled u offsets, kernel[w..2w) the v offsets.
for (int i=0; i<w; ++i)
{
kernel[i ] = xw*(iu0+i-u);
kernel[i+w] = xw*(iv0+i-v);
}
// Replace every scaled offset k by exp(beta*sqrt(1-k*k)).
for (auto &k : kernel)
k = exp(beta*sqrt(1.-k*k));
k = exp(beta*sqrt(T(1)-k*k));
}
bool need_to_move() const
......@@ -674,11 +674,31 @@ class Helper
}
};
class WriteHelper: public Helper
template<typename T> class WriteHelper: public Helper<T>
{
protected:
using Helper<T>::nu;
using Helper<T>::nv;
public:
using Helper<T>::w;
using Helper<T>::beta;
protected:
using Helper<T>::nsafe;
using Helper<T>::su;
public:
using Helper<T>::sv;
using Helper<T>::kernel;
using Helper<T>::iu0;
using Helper<T>::iv0;
using Helper<T>::bu0;
using Helper<T>::bv0;
using Helper<T>::need_to_move;
using Helper<T>::update_position;
using Helper<T>::update;
private:
vector<complex<double>> data;
complex<double> *grid;
vector<complex<T>> data;
complex<T> *grid;
void dump()
{
......@@ -701,29 +721,49 @@ class WriteHelper: public Helper
}
public:
// NOTE(review): diff listing without +/- markers -- the `complex<double>`
// declarations appear to be the removed lines and the `complex<T>` ones their
// replacements; confirm against the repository.
// p0: pointer into the local buffer `data`; it is assigned by prep_write() and
// is not initialised by the constructor below.
complex<double> *p0;
// Forwards nu_, nv_ and w to the Helper base, allocates `data` with su*sv
// zero-valued elements and stores the caller-supplied grid pointer.
WriteHelper(int nu_, int nv_, int w, complex<double> *grid_)
: Helper(nu_, nv_, w), data(su*sv, 0.), grid(grid_) {}
complex<T> *p0;
WriteHelper(int nu_, int nv_, int w, complex<T> *grid_)
: Helper<T>(nu_, nv_, w), data(su*sv, T(0)), grid(grid_) {}
// Calls dump() on destruction; dump()'s body is not visible in this listing --
// presumably it transfers the remaining buffer contents to `grid`.
~WriteHelper() { dump(); }
// Prepares the helper for accumulating one sample at (u_in, v_in): refreshes
// the kernel and start indices via update(), relocates the local buffer when
// need_to_move() reports it, and points p0 at the sample's first buffer cell.
// NOTE(review): diff listing without +/- markers -- each `double`/`0.` line
// appears to be the removed version of the `T`/`T(0)` line that follows it.
void prep_write(double u_in, double v_in)
void prep_write(T u_in, T v_in)
{
update(u_in, v_in);
if (need_to_move())
{
// Order matters here: dump() runs before update_position() changes the
// buffer origin, and the buffer is zeroed only afterwards.
dump();
update_position();
fill(data.begin(), data.end(), 0.);
fill(data.begin(), data.end(), T(0));
}
// Offset of (iu0, iv0) relative to the buffer origin (bu0, bv0), with sv as
// the stride between consecutive u rows.
p0 = data.data() + sv*(iu0-bu0) + iv0-bv0;
}
};
class ReadHelper: public Helper
template<typename T> class ReadHelper: public Helper<T>
{
protected:
using Helper<T>::nu;
using Helper<T>::nv;
public:
using Helper<T>::w;
using Helper<T>::beta;
protected:
using Helper<T>::nsafe;
using Helper<T>::su;
public:
using Helper<T>::sv;
using Helper<T>::kernel;
using Helper<T>::iu0;
using Helper<T>::iv0;
using Helper<T>::bu0;
using Helper<T>::bv0;
using Helper<T>::need_to_move;
using Helper<T>::update_position;
using Helper<T>::update;
private:
vector<complex<double>> data;
const complex<double> *grid;
vector<complex<T>> data;
const complex<T> *grid;
void load()
{
......@@ -742,11 +782,11 @@ class ReadHelper: public Helper
}
public:
// NOTE(review): diff listing without +/- markers -- the `complex<double>`
// declarations appear to be the removed lines and the `complex<T>` ones their
// replacements; confirm against the repository.
// p0: read-only pointer, initialised to nullptr by the constructor below.
const complex<double> *p0;
// Forwards nu_, nv_ and w_ to the Helper base, allocates `data` with su*sv
// zero-valued elements, stores the read-only grid pointer and nulls p0.
ReadHelper(int nu_, int nv_, int w_, const complex<double> *grid_)
: Helper(nu_, nv_, w_), data(su*sv,0.), grid(grid_), p0(nullptr) {}
const complex<T> *p0;
ReadHelper(int nu_, int nv_, int w_, const complex<T> *grid_)
: Helper<T>(nu_, nv_, w_), data(su*sv,T(0)), grid(grid_), p0(nullptr) {}
void prep_read(double u_in, double v_in)
void prep_read(T u_in, T v_in)
{
update(u_in, v_in);
if (need_to_move())
......@@ -758,8 +798,9 @@ class ReadHelper: public Helper
}
};
pyarr_c<double> vis2grid(const Baselines<double> &baselines, const GridderConfig<double> &gconf,
const pyarr_c<uint32_t> &idx_, const pyarr_c<complex<double>> &vis_)
template<typename T> pyarr_c<T> vis2grid(const Baselines<T> &baselines,
const GridderConfig<T> &gconf, const pyarr_c<uint32_t> &idx_,
const pyarr_c<complex<T>> &vis_)
{
myassert(idx_.ndim()==1, "idx array must be 1D");
myassert(vis_.ndim()==1, "vis must be 1D");
......@@ -769,29 +810,29 @@ pyarr_c<double> vis2grid(const Baselines<double> &baselines, const GridderConfig
auto idx = idx_.data();
size_t nu=gconf.Nu(), nv=gconf.Nv();
auto res = makearray<complex<double>>({nu, nv});
auto res = makearray<complex<T>>({nu, nv});
auto grid = res.mutable_data();
for (size_t i=0; i<nu*nv; ++i) grid[i] = 0.;
#pragma omp parallel
{
WriteHelper hlp(nu, nv, gconf.W(), grid);
double emb = exp(-2*hlp.beta);
const double * RESTRICT ku = hlp.kernel.data();
const double * RESTRICT kv = hlp.kernel.data()+hlp.w;
WriteHelper<T> hlp(nu, nv, gconf.W(), grid);
T emb = exp(-2*hlp.beta);
const T * RESTRICT ku = hlp.kernel.data();
const T * RESTRICT kv = hlp.kernel.data()+hlp.w;
// Loop over sampling points
#pragma omp for schedule(dynamic,10000)
for (size_t ipart=0; ipart<nvis; ++ipart)
{
UVW<double> coord = baselines.effectiveCoord(idx[ipart]);
UVW<T> coord = baselines.effectiveCoord(idx[ipart]);
hlp.prep_write(coord.u, coord.v);
auto * RESTRICT ptr = hlp.p0;
int w = hlp.w;
auto v(vis[ipart]*emb);
for (int cu=0; cu<w; ++cu)
{
complex<double> tmp(v*ku[cu]);
complex<T> tmp(v*ku[cu]);
for (int cv=0; cv<w; ++cv)
ptr[cv] += tmp*kv[cv];
ptr+=hlp.sv;
......@@ -800,9 +841,10 @@ pyarr_c<double> vis2grid(const Baselines<double> &baselines, const GridderConfig
} // end of parallel region
return complex2hartley(res);
}
pyarr_c<complex<double>> grid2vis(const Baselines<double> &baselines,
const GridderConfig<double> &gconf, const pyarr_c<uint32_t> &idx_,
const pyarr_c<double> &grid0_)
template<typename T> pyarr_c<complex<T>> grid2vis(const Baselines<T> &baselines,
const GridderConfig<T> &gconf, const pyarr_c<uint32_t> &idx_,
const pyarr_c<T> &grid0_)
{
size_t nu=gconf.Nu(), nv=gconf.Nv();
myassert(idx_.ndim()==1, "idx array must be 1D");
......@@ -814,28 +856,28 @@ pyarr_c<complex<double>> grid2vis(const Baselines<double> &baselines,
size_t nvis = size_t(idx_.shape(0));
auto idx = idx_.data();
auto res = makearray<complex<double>>({nvis});
auto res = makearray<complex<T>>({nvis});
auto vis = res.mutable_data();
// Loop over sampling points
#pragma omp parallel
{
ReadHelper hlp(nu, nv, gconf.W(), grid);
double emb = exp(-2*hlp.beta);
const double * RESTRICT ku = hlp.kernel.data();
const double * RESTRICT kv = hlp.kernel.data()+hlp.w;
ReadHelper<T> hlp(nu, nv, gconf.W(), grid);
T emb = exp(-2*hlp.beta);
const T * RESTRICT ku = hlp.kernel.data();
const T * RESTRICT kv = hlp.kernel.data()+hlp.w;
#pragma omp for schedule(dynamic,10000)
for (size_t ipart=0; ipart<nvis; ++ipart)
{
UVW<double> coord = baselines.effectiveCoord(idx[ipart]);
UVW<T> coord = baselines.effectiveCoord(idx[ipart]);
hlp.prep_read(coord.u, coord.v);
complex<double> r = 0.;
complex<T> r = 0;
auto * RESTRICT ptr = hlp.p0;
int w = hlp.w;
for (int cu=0; cu<w; ++cu)
{
complex<double> tmp(0.);
complex<T> tmp(0);
for (int cv=0; cv<w; ++cv)
tmp += ptr[cv] * kv[cv];
r += tmp*ku[cu];
......@@ -847,8 +889,8 @@ pyarr_c<complex<double>> grid2vis(const Baselines<double> &baselines,
return res;
}
pyarr_c<uint32_t> getIndices(const Baselines<double> &baselines,
const GridderConfig<double> &gconf, int chbegin, int chend, double wmin, double wmax)
template<typename T> pyarr_c<uint32_t> getIndices(const Baselines<T> &baselines,
const GridderConfig<T> &gconf, int chbegin, int chend, T wmin, T wmax)
{
auto idx = baselines.getIndices(chbegin, chend, wmin, wmax);
vector<size_t> peano(idx.size());
......@@ -944,8 +986,26 @@ PYBIND11_MODULE(nifty_gridder, m)
.def("Nv", &GridderConfig<double>::Nv)
.def("grid2dirty", &GridderConfig<double>::grid2dirty, "grid"_a)
.def("dirty2grid", &GridderConfig<double>::dirty2grid, "dirty"_a);
m.def("getIndices", getIndices, "baselines"_a, "gconf"_a, "chbegin"_a=-1,
m.def("getIndices", getIndices<double>, "baselines"_a, "gconf"_a, "chbegin"_a=-1,
"chend"_a=-1, "wmin"_a=-1e30, "wmax"_a=1e30);
m.def("vis2grid",&vis2grid<double>, "baselines"_a, "gconf"_a, "idx"_a, "vis"_a);
m.def("grid2vis",&grid2vis<double>, "baselines"_a, "gconf"_a, "idx"_a, "grid"_a);
py::class_<Baselines<float>> (m, "Baselines_f", Baselines_DS)
.def(py::init<pyarr_c<float>, pyarr_c<float>>(), "coord"_a, "scaling"_a)
.def ("ms2vis",&Baselines<float>::ms2vis, BL_ms2vis_DS, "ms"_a, "idx"_a)
.def ("vis2ms",&Baselines<float>::vis2ms, BL_vis2ms_DS, "vis"_a, "idx"_a)
.def ("add_vis_to_ms",&Baselines<float>::add_vis_to_ms, BL_add_vis_to_ms_DS,
"vis"_a, "idx"_a, "ms"_a.noconvert());
py::class_<GridderConfig<float>> (m, "GridderConfig_f")
.def(py::init<size_t, size_t, float, float, float>(),"nxdirty"_a,
"nydirty"_a, "epsilon"_a, "urange"_a, "vrange"_a)
.def("Nu", &GridderConfig<float>::Nu)
.def("Nv", &GridderConfig<float>::Nv)
.def("grid2dirty", &GridderConfig<float>::grid2dirty, "grid"_a)
.def("dirty2grid", &GridderConfig<float>::dirty2grid, "dirty"_a);
m.def("getIndices_f", getIndices<float>, "baselines"_a, "gconf"_a, "chbegin"_a=-1,
"chend"_a=-1, "wmin"_a=-1e30, "wmax"_a=1e30);
m.def("vis2grid",&vis2grid, "baselines"_a, "gconf"_a, "idx"_a, "vis"_a);
m.def("grid2vis",&grid2vis, "baselines"_a, "gconf"_a, "idx"_a, "grid"_a);
m.def("vis2grid_f",&vis2grid<float>, "baselines"_a, "gconf"_a, "idx"_a, "vis"_a);
m.def("grid2vis_f",&grid2vis<float>, "baselines"_a, "gconf"_a, "idx"_a, "grid"_a);
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment