Commit fd55bdb5 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more interface tweaking

parent ba67cb5d
......@@ -1354,11 +1354,11 @@ template<typename T, typename Serv> void x2dirty(
gconf.grid2dirty(cmav(rgrid), dirty);
}
}
template<typename T> void vis2dirty_wstack(const Baselines &baselines,
template<typename T> void vis2dirty(const Baselines &baselines,
const GridderConfig &gconf, const const_mav<uint32_t,1> &idx,
const const_mav<complex<T>,1> &vis, mav<T,2> &dirty)
const const_mav<complex<T>,1> &vis, mav<T,2> &dirty, bool do_wstacking)
{
x2dirty(gconf, makeVisServ(baselines, idx, vis, nullmav<T,1>()), dirty, true);
x2dirty(gconf, makeVisServ(baselines, idx, vis, nullmav<T,1>()), dirty, do_wstacking);
}
template<typename T, typename Serv> void dirty2x(
......@@ -1444,11 +1444,11 @@ template<typename T, typename Serv> void dirty2x(
grid2x_c(gconf, cmav(grid2), srv);
}
}
template<typename T> void dirty2vis_wstack(const Baselines &baselines,
template<typename T> void dirty2vis(const Baselines &baselines,
const GridderConfig &gconf, const const_mav<uint32_t,1> &idx,
const const_mav<T,2> &dirty, mav<complex<T>,1> &vis)
const const_mav<T,2> &dirty, mav<complex<T>,1> &vis, bool do_wstacking)
{
dirty2x(gconf, dirty, makeVisServ(baselines, idx, vis, nullmav<T,1>()), true);
dirty2x(gconf, dirty, makeVisServ(baselines, idx, vis, nullmav<T,1>()), do_wstacking);
}
......@@ -1601,8 +1601,8 @@ using detail::ms2grid_c;
using detail::grid2ms_c;
using detail::vis2grid_c;
using detail::grid2vis_c;
using detail::vis2dirty_wstack;
using detail::dirty2vis_wstack;
using detail::vis2dirty;
using detail::dirty2vis;
using detail::gridding;
using detail::degridding;
using detail::streamDump__;
......
......@@ -610,24 +610,34 @@ template<typename T> pyarr<complex<T>> Pygrid2ms(const PyBaselines &baselines,
return Pygrid2ms_c(baselines, gconf, idx_, grid2_, ms_in, wgt_);
}
template<typename T> pyarr<complex<T>> Pyapply_holo(
template<typename T> pyarr<complex<T>> apply_holo2(
const PyBaselines &baselines, const PyGridderConfig &gconf,
const pyarr<uint32_t> &idx_, const pyarr<complex<T>> &grid_,
const py::object &wgt_)
const py::array &idx_, const py::array &grid_, const py::object &wgt_)
{
auto idx = make_const_mav<1>(idx_);
auto grid = make_const_mav<2>(grid_);
pyarr<T> wgt2 = providePotentialArray<T>(wgt_, "wgt", {idx.shape(0)});
auto wgt=make_const_mav<1>(wgt2);
auto res = makeArray<complex<T>>({grid.shape(0),grid.shape(1)});
auto ogrid = make_mav<2>(res);
auto idx = getPyarr<uint32_t>(idx_, "idx");
auto idx2 = make_const_mav<1>(idx);
auto grid = getPyarr<complex<T>>(grid_, "grid");
auto grid2 = make_const_mav<2>(grid);
auto wgt = providePotentialArray<T>(wgt_, "wgt", {idx2.shape(0)});
auto wgt2 = make_const_mav<1>(wgt);
auto res = makeArray<complex<T>>({grid2.shape(0),grid2.shape(1)});
auto res2 = make_mav<2>(res);
{
py::gil_scoped_release release;
apply_holo(baselines, gconf, idx, grid, ogrid, wgt);
apply_holo(baselines, gconf, idx2, grid2, res2, wgt2);
}
return res;
}
py::array Pyapply_holo(
const PyBaselines &baselines, const PyGridderConfig &gconf,
const py::array &idx, const py::array &grid, const py::object &wgt)
{
if (isPytype<complex<float>>(grid))
return apply_holo2<float>(baselines, gconf, idx, grid, wgt);
if (isPytype<complex<double>>(grid))
return apply_holo2<double>(baselines, gconf, idx, grid, wgt);
myfail("type matching failed: 'grid' has neither type 'c8' nor 'c16'");
}
template<typename T> pyarr<T> Pyget_correlations(
const PyBaselines &baselines, const PyGridderConfig &gconf,
......@@ -692,39 +702,60 @@ pyarr<uint32_t> PygetIndices(const PyBaselines &baselines,
return res;
}
template<typename T> pyarr<T> Pyvis2dirty_wstack(const PyBaselines &baselines,
const PyGridderConfig &gconf, const pyarr<uint32_t> &idx_,
const pyarr<complex<T>> &vis_)
template<typename T> pyarr<T> vis2dirty2(const PyBaselines &baselines,
const PyGridderConfig &gconf, const py::array &idx_,
const py::array &vis_, bool do_wstacking)
{
auto nx_dirty=gconf.Nxdirty();
auto ny_dirty=gconf.Nydirty();
auto idx2=make_const_mav<1>(idx_);
auto vis2=make_const_mav<1>(vis_);
auto dirty = makeArray<T>({nx_dirty, ny_dirty});
auto dirty2=make_mav<2>(dirty);
auto idx = getPyarr<uint32_t>(idx_, "idx");
auto idx2 = make_const_mav<1>(idx);
auto dirty = makeArray<T>({gconf.Nxdirty(), gconf.Nydirty()});
auto dirty2 = make_mav<2>(dirty);
auto vis = getPyarr<complex<T>>(vis_, "vis");
auto vis2 = make_const_mav<1>(vis);
{
py::gil_scoped_release release;
vis2dirty_wstack<T>(baselines, gconf, idx2, vis2, dirty2);
vis2dirty<T>(baselines, gconf, idx2, vis2, dirty2, do_wstacking);
}
return dirty;
}
py::array Pyvis2dirty(const PyBaselines &baselines,
const PyGridderConfig &gconf, const py::array &idx,
const py::array &vis, bool do_wstacking)
{
if (isPytype<complex<float>>(vis))
return vis2dirty2<float>(baselines, gconf, idx, vis, do_wstacking);
if (isPytype<complex<double>>(vis))
return vis2dirty2<double>(baselines, gconf, idx, vis, do_wstacking);
myfail("type matching failed: 'vis' has neither type 'c8' nor 'c16'");
}
template<typename T> pyarr<complex<T>> Pydirty2vis_wstack(const PyBaselines &baselines,
template<typename T> pyarr<complex<T>> dirty2vis2(const PyBaselines &baselines,
const PyGridderConfig &gconf, const pyarr<uint32_t> &idx_,
const pyarr<T> &dirty_)
const pyarr<T> &dirty_, bool do_wstacking)
{
auto idx2=make_const_mav<1>(idx_);
auto nvis = idx2.shape(0);
auto vis = makeArray<complex<T>>({nvis});
auto vis2=make_mav<1>(vis);
auto dirty2=make_const_mav<2>(dirty_);
auto idx = getPyarr<uint32_t>(idx_, "idx");
auto idx2 = make_const_mav<1>(idx);
auto dirty = getPyarr<T>(dirty_, "dirty");
auto dirty2 = make_const_mav<2>(dirty_);
auto vis = makeArray<complex<T>>({idx2.shape(0)});
auto vis2 = make_mav<1>(vis);
{
py::gil_scoped_release release;
vis2.fill(0);
dirty2vis_wstack<T>(baselines, gconf, idx2, dirty2, vis2);
dirty2vis<T>(baselines, gconf, idx2, dirty2, vis2, do_wstacking);
}
return vis;
}
py::array Pydirty2vis(const PyBaselines &baselines,
const PyGridderConfig &gconf, const py::array &idx,
const py::array &dirty, bool do_wstacking)
{
if (isPytype<float>(dirty))
return dirty2vis2<float>(baselines, gconf, idx, dirty, do_wstacking);
if (isPytype<double>(dirty))
return dirty2vis2<double>(baselines, gconf, idx, dirty, do_wstacking);
myfail("type matching failed: 'dirty' has neither type 'f4' nor 'f8'");
}
template<typename T> py::array gridding2(const py::array &uvw_,
const py::array &freq_, const py::array &ms_, const py::object &wgt_,
......@@ -875,14 +906,14 @@ PYBIND11_MODULE(nifty_gridder, m)
"grid"_a, "wgt"_a=None);
m.def("grid2ms_c",&Pygrid2ms_c<double>, "baselines"_a, "gconf"_a, "idx"_a,
"grid"_a, "ms_in"_a=None, "wgt"_a=None);
m.def("apply_holo",&Pyapply_holo<double>, "baselines"_a, "gconf"_a, "idx"_a,
m.def("apply_holo",&Pyapply_holo, "baselines"_a, "gconf"_a, "idx"_a,
"grid"_a, "wgt"_a=None);
m.def("get_correlations", &Pyget_correlations<double>, "baselines"_a, "gconf"_a,
"idx"_a, "du"_a, "dv"_a, "wgt"_a=None);
m.def("vis2dirty_wstack",&Pyvis2dirty_wstack<double>, "baselines"_a, "gconf"_a,
"idx"_a, "vis"_a);
m.def("dirty2vis_wstack",&Pydirty2vis_wstack<double>, "baselines"_a, "gconf"_a,
"idx"_a, "dirty"_a);
m.def("vis2dirty",&Pyvis2dirty, "baselines"_a, "gconf"_a,
"idx"_a, "vis"_a, "do_wstacking"_a=false);
m.def("dirty2vis",&Pydirty2vis, "baselines"_a, "gconf"_a,
"idx"_a, "dirty"_a, "do_wstacking"_a=false);
m.def("gridding",&Pygridding,"uvw"_a,"freq"_a,"ms"_a,
"wgt"_a=None,"npix_x"_a,"npix_y"_a,"pixsize_x"_a,"pixsize_y"_a,"epsilon"_a,
......
......@@ -76,8 +76,8 @@ def test_adjointness_wgridding(nxdirty, nydirty, nrow, nchan, epsilon):
bl, conf, idx = _init_gridder(nxdirty, nydirty, epsilon, nchan, nrow)
vis = np.random.rand(*idx.shape)-0.5 + 1j*(np.random.rand(*idx.shape)-0.5)
dirty = np.random.rand(conf.Nxdirty(), conf.Nydirty())-0.5
dirty2 = ng.vis2dirty_wstack(bl, conf, idx, vis)
vis2 = ng.dirty2vis_wstack(bl, conf, idx, dirty)
dirty2 = ng.vis2dirty(bl, conf, idx, vis, do_wstacking=True)
vis2 = ng.dirty2vis(bl, conf, idx, dirty, do_wstacking=True)
assert_allclose(np.vdot(vis, vis2).real, np.vdot(dirty2, dirty),
rtol=2e-13)
......@@ -238,7 +238,7 @@ def test_correlations(nxdirty, nydirty, nrow, nchan, epsilon, du, dv, weight):
(0, 0),
(2*nxdirty-1, 2*nydirty-1))
for pp in pts:
grid = np.zeros((2*nxdirty, 2*nydirty))
grid = np.zeros((2*nxdirty, 2*nydirty))+0j
grid[pp] = 1
y1 = ng.apply_holo(bl, conf, idx, grid, wgt)
ind = (pp[0]+du) % (2*nxdirty), (pp[1]+dv) % (2*nydirty)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment