Commit bb94f74d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

generalize data types

parent 98819486
......@@ -299,13 +299,14 @@ size_t get_w(double epsilon)
throw runtime_error("requested epsilon too small - minimum is 2e-13");
}
pyarr_c<double> complex2hartley (const pyarr_c<complex<double>> &grid_)
template<typename T> pyarr_c<T> complex2hartley
(const pyarr_c<complex<T>> &grid_)
{
myassert(grid_.ndim()==2, "grid array must be 2D");
size_t nu = size_t(grid_.shape(0)), nv = size_t(grid_.shape(1));
auto grid = grid_.data();
auto res = makearray<double>({nu,nv});
auto res = makearray<T>({nu,nv});
auto grid2 = res.mutable_data();
for (size_t u=0; u<nu; ++u)
{
......@@ -315,20 +316,21 @@ pyarr_c<double> complex2hartley (const pyarr_c<complex<double>> &grid_)
size_t xv = (v==0) ? 0 : nv-v;
size_t i1 = u*nv+v;
size_t i2 = xu*nv+xv;
grid2[i1] = 0.5*(grid[i1].real()+grid[i1].imag()+
grid[i2].real()-grid[i2].imag());
grid2[i1] = T(0.5)*(grid[i1].real()+grid[i1].imag()+
grid[i2].real()-grid[i2].imag());
}
}
return res;
}
pyarr_c<complex<double>> hartley2complex (const pyarr_c<double> &grid_)
template<typename T> pyarr_c<complex<T>> hartley2complex
(const pyarr_c<T> &grid_)
{
myassert(grid_.ndim()==2, "grid array must be 2D");
size_t nu = size_t(grid_.shape(0)), nv = size_t(grid_.shape(1));
auto grid = grid_.data();
auto res=makearray<complex<double>>({nu, nv});
auto res=makearray<complex<T>>({nu, nv});
auto grid2 = res.mutable_data();
for (size_t u=0; u<nu; ++u)
{
......@@ -338,9 +340,9 @@ pyarr_c<complex<double>> hartley2complex (const pyarr_c<double> &grid_)
size_t xv = (v==0) ? 0 : nv-v;
size_t i1 = u*nv+v;
size_t i2 = xu*nv+xv;
double v1 = 0.5*grid[i1];
double v2 = 0.5*grid[i2];
grid2[i1] = complex<double>(v1+v2, v1-v2);
T v1 = T(0.5)*grid[i1];
T v2 = T(0.5)*grid[i2];
grid2[i1] = complex<T>(v1+v2, v1-v2);
}
}
return res;
......@@ -370,27 +372,27 @@ vector<double> correction_factors (size_t n, size_t nval, size_t w)
return res;
}
struct UV
template<typename T> struct UV
{
double u, v;
T u, v;
UV () {}
UV (double u_, double v_) : u(u_), v(v_) {}
UV operator* (double fct) const
UV (T u_, T v_) : u(u_), v(v_) {}
UV operator* (T fct) const
{ return UV(u*fct, v*fct); }
};
struct UVW
template<typename T> struct UVW
{
double u, v, w;
T u, v, w;
UVW () {}
UVW (double u_, double v_, double w_) : u(u_), v(v_), w(w_) {}
UVW operator* (double fct) const
UVW (T u_, T v_, T w_) : u(u_), v(v_), w(w_) {}
UVW operator* (T fct) const
{ return UVW(u*fct, v*fct, w*fct); }
};
class Baselines
{
private:
vector<UVW> coord;
vector<UVW<double>> coord;
vector<double> scaling;
size_t nrows, nchan;
size_t channelbits, channelmask;
......@@ -411,7 +413,7 @@ class Baselines
coord.resize(nrows);
auto cood = coord_.data();
for (size_t i=0; i<coord.size(); ++i)
coord[i] = UVW(cood[3*i], cood[3*i+1], cood[3*i+2]);
coord[i] = UVW<double>(cood[3*i], cood[3*i+1], cood[3*i+2]);
channelbits = bits_needed(nchan);
channelmask = (size_t(1)<<channelbits)-1;
auto rowbits = bits_needed(nrows);
......@@ -433,7 +435,7 @@ class Baselines
return odata;
}
UVW effectiveCoord(uint32_t index) const
UVW<double> effectiveCoord(uint32_t index) const
{ return coord[index>>channelbits]*scaling[index&channelmask]; }
size_t Nrows() const { return nrows; }
size_t Nchannels() const { return nchan; }
......@@ -540,7 +542,7 @@ class GridderConfig
size_t Nu() const { return nu; }
size_t Nv() const { return nv; }
size_t W() const { return w; }
size_t coord2peano(const UVW &coord) const
size_t coord2peano(const UVW<double> &coord) const
{
double u=fmodulo(coord.u*ucorr, 1.)*nu,
v=fmodulo(coord.v*vcorr, 1.)*nv;
......@@ -782,7 +784,7 @@ pyarr_c<double> vis2grid(const Baselines &baselines, const GridderConfig &gconf,
#pragma omp for schedule(dynamic,10000)
for (size_t ipart=0; ipart<nvis; ++ipart)
{
UVW coord = baselines.effectiveCoord(idx[ipart]);
UVW<double> coord = baselines.effectiveCoord(idx[ipart]);
hlp.prep_write(coord.u, coord.v);
auto * RESTRICT ptr = hlp.p0;
int w = hlp.w;
......@@ -826,7 +828,7 @@ pyarr_c<complex<double>> grid2vis(const Baselines &baselines,
#pragma omp for schedule(dynamic,10000)
for (size_t ipart=0; ipart<nvis; ++ipart)
{
UVW coord = baselines.effectiveCoord(idx[ipart]);
UVW<double> coord = baselines.effectiveCoord(idx[ipart]);
hlp.prep_read(coord.u, coord.v);
complex<double> r = 0.;
auto * RESTRICT ptr = hlp.p0;
......@@ -867,10 +869,10 @@ Class storing UVW coordinates and channel information.
Parameters
==========
coord: np.array((nrows, 2), dtype=np.float)
u and v coordinates for each row
coord: np.array((nrows, 3), dtype=np.float)
u, v and w coordinates for each row
scaling: np.array((nchannels,), dtype=np.float)
scaling factor for u and v for each individual channel
scaling factor for u, v, w for each individual channel
)""";
const char *BL_ms2vis_DS = R"""(
......@@ -890,7 +892,37 @@ np.array((nvis,), dtype=np.complex)
)""";
const char *BL_vis2ms_DS = R"""(
TBW
Produces a new MS with the provided visibilities set.
Parameters
==========
vis: np.array((nvis,), dtype=np.complex)
The visibility data for the index array
idx: np.array((nvis,), dtype=np.uint32)
the indices to be inserted
Returns
=======
np.array((nrows, nchannels), dtype=np.complex)
the measurement set's visibility data (0 where not covered by idx)
)""";
const char *BL_add_vis_to_ms_DS = R"""(
Adds a set of visibilities to an existing MS.
Parameters
==========
vis: np.array((nvis,), dtype=np.complex)
The visibility data for the index array
idx: np.array((nvis,), dtype=np.uint32)
the indices to be inserted
ms: np.array((nrows, nchannels), dtype=np.complex)
a MS
Returns
=======
np.array((nrows, nchannels), dtype=np.complex)
the updated MS
)""";
} // unnamed namespace
......@@ -903,7 +935,8 @@ PYBIND11_MODULE(nifty_gridder, m)
.def(py::init<pyarr_c<double>, pyarr_c<double>>(), "coord"_a, "scaling"_a)
.def ("ms2vis",&Baselines::ms2vis, BL_ms2vis_DS, "ms"_a, "idx"_a)
.def ("vis2ms",&Baselines::vis2ms, BL_vis2ms_DS, "vis"_a, "idx"_a)
.def ("add_vis_to_ms",&Baselines::add_vis_to_ms, "vis"_a, "idx"_a, "ms"_a.noconvert());
.def ("add_vis_to_ms",&Baselines::add_vis_to_ms, BL_add_vis_to_ms_DS,
"vis"_a, "idx"_a, "ms"_a.noconvert());
py::class_<GridderConfig> (m, "GridderConfig")
.def(py::init<size_t, size_t, double, double, double>(),"nxdirty"_a,
"nydirty"_a, "epsilon"_a, "urange"_a, "vrange"_a)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment