Commit d9642e63 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

generalize data types

parent bb94f74d
......@@ -389,16 +389,16 @@ template<typename T> struct UVW
{ return UVW(u*fct, v*fct, w*fct); }
};
class Baselines
template<typename T> class Baselines
{
private:
vector<UVW<double>> coord;
vector<double> scaling;
vector<UVW<T>> coord;
vector<T> scaling;
size_t nrows, nchan;
size_t channelbits, channelmask;
public:
Baselines(const pyarr_c<double> &coord_, const pyarr_c<double> &scaling_)
Baselines(const pyarr_c<T> &coord_, const pyarr_c<T> &scaling_)
{
myassert(coord_.ndim()==2, "coord array must be 2D");
myassert(coord_.shape(1)==3, "coord.shape[1] must be 3");
......@@ -413,14 +413,14 @@ class Baselines
coord.resize(nrows);
auto cood = coord_.data();
for (size_t i=0; i<coord.size(); ++i)
coord[i] = UVW<double>(cood[3*i], cood[3*i+1], cood[3*i+2]);
coord[i] = UVW<T>(cood[3*i], cood[3*i+1], cood[3*i+2]);
channelbits = bits_needed(nchan);
channelmask = (size_t(1)<<channelbits)-1;
auto rowbits = bits_needed(nrows);
myassert(rowbits+channelbits<=8*sizeof(uint32_t), "Ti too small");
}
vector<uint32_t> getIndices(int chbegin, int chend, double wmin,
double wmax) const
vector<uint32_t> getIndices(int chbegin, int chend, T wmin,
T wmax) const
{
if (chbegin<0) chbegin=0;
if (chend<0) chend=nchan;
......@@ -435,7 +435,7 @@ class Baselines
return odata;
}
UVW<double> effectiveCoord(uint32_t index) const
UVW<T> effectiveCoord(uint32_t index) const
{ return coord[index>>channelbits]*scaling[index&channelmask]; }
size_t Nrows() const { return nrows; }
size_t Nchannels() const { return nchan; }
......@@ -446,7 +446,7 @@ class Baselines
size_t offset(uint32_t index) const
{ return (index>>channelbits)*nchan + (index&channelmask); }
pyarr_c<complex<double>> ms2vis(const pyarr_c<complex<double>> &ms_,
pyarr_c<complex<T>> ms2vis(const pyarr_c<complex<T>> &ms_,
const pyarr_c<uint32_t> &idx_) const
{
myassert(idx_.ndim()==1, "idx array must be 1D");
......@@ -457,14 +457,14 @@ class Baselines
auto idx = idx_.data();
auto ms = ms_.data();
auto res=makearray<complex<double>>({nvis});
auto res=makearray<complex<T>>({nvis});
auto vis = res.mutable_data();
for (size_t i=0; i<nvis; ++i)
vis[i] = ms[offset(idx[i])];
return res;
}
pyarr_c<complex<double>> vis2ms(const pyarr_c<complex<double>> &vis_,
pyarr_c<complex<T>> vis2ms(const pyarr_c<complex<T>> &vis_,
const pyarr_c<uint32_t> &idx_) const
{
myassert(idx_.ndim()==1, "idx array must be 1D");
......@@ -474,17 +474,17 @@ class Baselines
auto idx = idx_.data();
auto vis = vis_.data();
auto res = makearray<complex<double>>({nrows, nchan});
auto res = makearray<complex<T>>({nrows, nchan});
auto ms = res.mutable_data();
for (size_t i=0; i<nrows*nchan; ++i)
ms[i] = 0.;
ms[i] = complex<T>(0);
for (size_t i=0; i<nvis; ++i)
ms[offset(idx[i])] = vis[i];
return res;
}
pyarr_c<complex<double>> add_vis_to_ms(const pyarr_c<complex<double>> &vis_,
const pyarr_c<uint32_t> &idx_, pyarr_c<complex<double>> &ms_) const
pyarr_c<complex<T>> add_vis_to_ms(const pyarr_c<complex<T>> &vis_,
const pyarr_c<uint32_t> &idx_, pyarr_c<complex<T>> &ms_) const
{
myassert(idx_.ndim()==1, "idx array must be 1D");
myassert(vis_.ndim()==1, "vis array must be 1D");
......@@ -758,7 +758,7 @@ class ReadHelper: public Helper
}
};
pyarr_c<double> vis2grid(const Baselines &baselines, const GridderConfig &gconf,
pyarr_c<double> vis2grid(const Baselines<double> &baselines, const GridderConfig &gconf,
const pyarr_c<uint32_t> &idx_, const pyarr_c<complex<double>> &vis_)
{
myassert(idx_.ndim()==1, "idx array must be 1D");
......@@ -800,7 +800,7 @@ pyarr_c<double> vis2grid(const Baselines &baselines, const GridderConfig &gconf,
} // end of parallel region
return complex2hartley(res);
}
pyarr_c<complex<double>> grid2vis(const Baselines &baselines,
pyarr_c<complex<double>> grid2vis(const Baselines<double> &baselines,
const GridderConfig &gconf, const pyarr_c<uint32_t> &idx_,
const pyarr_c<double> &grid0_)
{
......@@ -847,7 +847,7 @@ pyarr_c<complex<double>> grid2vis(const Baselines &baselines,
return res;
}
pyarr_c<uint32_t> getIndices(const Baselines &baselines,
pyarr_c<uint32_t> getIndices(const Baselines<double> &baselines,
const GridderConfig &gconf, int chbegin, int chend, double wmin, double wmax)
{
auto idx = baselines.getIndices(chbegin, chend, wmin, wmax);
......@@ -931,11 +931,11 @@ PYBIND11_MODULE(nifty_gridder, m)
{
using namespace pybind11::literals;
py::class_<Baselines> (m, "Baselines", Baselines_DS)
py::class_<Baselines<double>> (m, "Baselines", Baselines_DS)
.def(py::init<pyarr_c<double>, pyarr_c<double>>(), "coord"_a, "scaling"_a)
.def ("ms2vis",&Baselines::ms2vis, BL_ms2vis_DS, "ms"_a, "idx"_a)
.def ("vis2ms",&Baselines::vis2ms, BL_vis2ms_DS, "vis"_a, "idx"_a)
.def ("add_vis_to_ms",&Baselines::add_vis_to_ms, BL_add_vis_to_ms_DS,
.def ("ms2vis",&Baselines<double>::ms2vis, BL_ms2vis_DS, "ms"_a, "idx"_a)
.def ("vis2ms",&Baselines<double>::vis2ms, BL_vis2ms_DS, "vis"_a, "idx"_a)
.def ("add_vis_to_ms",&Baselines<double>::add_vis_to_ms, BL_add_vis_to_ms_DS,
"vis"_a, "idx"_a, "ms"_a.noconvert());
py::class_<GridderConfig> (m, "GridderConfig")
.def(py::init<size_t, size_t, double, double, double>(),"nxdirty"_a,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment