Commit a6e1eb4f authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'wsclean_adjustments' into 'master'

Wsclean adjustments

See merge request !28
parents 38f63ce9 eb9f681c
This diff is collapsed.
......@@ -284,6 +284,8 @@ w : float
the w value to use
adjoint: bool
if True, apply the complex conjugate of the w screen
divide_by_n: bool, default=True
if True, divide ny n.
Returns
=======
......@@ -314,9 +316,13 @@ class PyGridderConfig: public GridderConfig
{
public:
using GridderConfig::GridderConfig;
PyGridderConfig(size_t nxdirty, size_t nydirty, size_t nu_, size_t nv_,
double epsilon, double pixsize_x, double pixsize_y, size_t nthreads_)
: GridderConfig(nxdirty, nydirty, nu_, nv_, epsilon, pixsize_x, pixsize_y, nthreads_) {}
PyGridderConfig(size_t nxdirty, size_t nydirty, double epsilon,
double pixsize_x, double pixsize_y, size_t nthreads)
: GridderConfig(nxdirty, nydirty, epsilon, pixsize_x, pixsize_y, nthreads) {}
double pixsize_x, double pixsize_y, size_t nthreads_)
: GridderConfig(nxdirty, nydirty, epsilon, pixsize_x, pixsize_y, nthreads_) {}
template<typename T> pyarr<T> apply_taper2(const pyarr<T> &img, bool divide) const
{
......@@ -416,23 +422,25 @@ class PyGridderConfig: public GridderConfig
myfail("type matching failed: 'dirty' has neither type 'c8' nor 'c16'");
}
template<typename T> pyarr<complex<T>> apply_wscreen2
(const pyarr<complex<T>> &dirty, double w, bool adjoint) const
(const pyarr<complex<T>> &dirty, double w, bool adjoint,
bool divide_by_n) const
{
auto dirty2 = make_const_mav<2>(dirty);
auto res = makeArray<complex<T>>({nx_dirty, ny_dirty});
auto res2 = make_mav<2>(res);
{
py::gil_scoped_release release;
GridderConfig::apply_wscreen(dirty2, res2, w, adjoint);
GridderConfig::apply_wscreen(dirty2, res2, w, adjoint, divide_by_n);
}
return res;
}
py::array apply_wscreen(const py::array &dirty, double w, bool adjoint) const
py::array apply_wscreen(const py::array &dirty, double w, bool adjoint,
bool divide_by_n) const
{
if (isPytype<complex<float>>(dirty))
return apply_wscreen2<float>(dirty, w, adjoint);
return apply_wscreen2<float>(dirty, w, adjoint, divide_by_n);
if (isPytype<complex<double>>(dirty))
return apply_wscreen2<double>(dirty, w, adjoint);
return apply_wscreen2<double>(dirty, w, adjoint, divide_by_n);
myfail("type matching failed: 'dirty' has neither type 'c8' nor 'c16'");
}
};
......@@ -883,9 +891,9 @@ py::array Pydirty2vis(const PyBaselines &baselines,
myfail("type matching failed: 'dirty' has neither type 'f4' nor 'f8'");
}
template<typename T> py::array ms2dirty2(const py::array &uvw_,
template<typename T> py::array ms2dirty_general2(const py::array &uvw_,
const py::array &freq_, const py::array &ms_, const py::object &wgt_,
size_t npix_x, size_t npix_y, double pixsize_x, double pixsize_y, double epsilon,
size_t npix_x, size_t npix_y, double pixsize_x, double pixsize_y, size_t nu, size_t nv, double epsilon,
bool do_wstacking, size_t nthreads, size_t verbosity)
{
auto uvw = getPyarr<double>(uvw_, "uvw");
......@@ -900,12 +908,24 @@ template<typename T> py::array ms2dirty2(const py::array &uvw_,
auto dirty2 = make_mav<2>(dirty);
{
py::gil_scoped_release release;
ms2dirty(uvw2,freq2,ms2,wgt2,pixsize_x,pixsize_y,epsilon,do_wstacking,
ms2dirty_general(uvw2,freq2,ms2,wgt2,pixsize_x,pixsize_y, nu, nv, epsilon,do_wstacking,
nthreads,dirty2,verbosity);
}
return dirty;
}
py::array Pyms2dirty_general(const py::array &uvw,
const py::array &freq, const py::array &ms, const py::object &wgt,
size_t npix_x, size_t npix_y, double pixsize_x, double pixsize_y, size_t nu, size_t nv, double epsilon,
bool do_wstacking, size_t nthreads, size_t verbosity)
{
if (isPytype<complex<float>>(ms))
return ms2dirty_general2<float>(uvw, freq, ms, wgt, npix_x, npix_y,
pixsize_x, pixsize_y, nu, nv, epsilon, do_wstacking, nthreads, verbosity);
if (isPytype<complex<double>>(ms))
return ms2dirty_general2<double>(uvw, freq, ms, wgt, npix_x, npix_y,
pixsize_x, pixsize_y, nu, nv, epsilon, do_wstacking, nthreads, verbosity);
myfail("type matching failed: 'ms' has neither type 'c8' nor 'c16'");
}
constexpr auto ms2dirty_DS = R"""(
Converts an MS object to dirty image.
......@@ -948,18 +968,14 @@ py::array Pyms2dirty(const py::array &uvw,
size_t npix_x, size_t npix_y, double pixsize_x, double pixsize_y, double epsilon,
bool do_wstacking, size_t nthreads, size_t verbosity)
{
if (isPytype<complex<float>>(ms))
return ms2dirty2<float>(uvw, freq, ms, wgt, npix_x, npix_y,
pixsize_x, pixsize_y, epsilon, do_wstacking, nthreads, verbosity);
if (isPytype<complex<double>>(ms))
return ms2dirty2<double>(uvw, freq, ms, wgt, npix_x, npix_y,
pixsize_x, pixsize_y, epsilon, do_wstacking, nthreads, verbosity);
myfail("type matching failed: 'ms' has neither type 'c8' nor 'c16'");
return Pyms2dirty_general(uvw, freq, ms, wgt, npix_x, npix_y,
pixsize_x, pixsize_y, 2*npix_x, 2*npix_y, epsilon, do_wstacking, nthreads,
verbosity);
}
template<typename T> py::array dirty2ms2(const py::array &uvw_,
template<typename T> py::array dirty2ms_general2(const py::array &uvw_,
const py::array &freq_, const py::array &dirty_, const py::object &wgt_,
double pixsize_x, double pixsize_y, double epsilon,
double pixsize_x, double pixsize_y, size_t nu, size_t nv, double epsilon,
bool do_wstacking, size_t nthreads, size_t verbosity)
{
auto uvw = getPyarr<double>(uvw_, "uvw");
......@@ -974,12 +990,24 @@ template<typename T> py::array dirty2ms2(const py::array &uvw_,
auto ms2 = make_mav<2>(ms);
{
py::gil_scoped_release release;
dirty2ms(uvw2,freq2,dirty2,wgt2,pixsize_x,pixsize_y,epsilon,do_wstacking,
dirty2ms_general(uvw2,freq2,dirty2,wgt2,pixsize_x,pixsize_y,nu, nv,epsilon,do_wstacking,
nthreads,ms2,verbosity);
}
return ms;
}
py::array Pydirty2ms_general(const py::array &uvw,
const py::array &freq, const py::array &dirty, const py::object &wgt,
double pixsize_x, double pixsize_y, size_t nu, size_t nv, double epsilon,
bool do_wstacking, size_t nthreads, size_t verbosity)
{
if (isPytype<float>(dirty))
return dirty2ms_general2<float>(uvw, freq, dirty, wgt,
pixsize_x, pixsize_y, nu, nv, epsilon, do_wstacking, nthreads, verbosity);
if (isPytype<double>(dirty))
return dirty2ms_general2<double>(uvw, freq, dirty, wgt,
pixsize_x, pixsize_y, nu, nv, epsilon, do_wstacking, nthreads, verbosity);
myfail("type matching failed: 'dirty' has neither type 'f4' nor 'f8'");
}
constexpr auto dirty2ms_DS = R"""(
Converts a dirty image to an MS object.
......@@ -1020,13 +1048,9 @@ py::array Pydirty2ms(const py::array &uvw,
double pixsize_x, double pixsize_y, double epsilon,
bool do_wstacking, size_t nthreads, size_t verbosity)
{
if (isPytype<float>(dirty))
return dirty2ms2<float>(uvw, freq, dirty, wgt,
pixsize_x, pixsize_y, epsilon, do_wstacking, nthreads, verbosity);
if (isPytype<double>(dirty))
return dirty2ms2<double>(uvw, freq, dirty, wgt,
pixsize_x, pixsize_y, epsilon, do_wstacking, nthreads, verbosity);
myfail("type matching failed: 'dirty' has neither type 'f4' nor 'f8'");
return Pydirty2ms_general(uvw, freq, dirty, wgt, pixsize_x, pixsize_y,
2*dirty.shape(0), 2*dirty.shape(1), epsilon, do_wstacking, nthreads,
verbosity);
}
} // unnamed namespace
......@@ -1048,6 +1072,8 @@ PYBIND11_MODULE(nifty_gridder, m)
py::class_<PyGridderConfig> (m, "GridderConfig", GridderConfig_DS)
.def(py::init<size_t, size_t, double, double, double, size_t>(),"nxdirty"_a,
"nydirty"_a, "epsilon"_a, "pixsize_x"_a, "pixsize_y"_a, "nthreads"_a=1)
.def(py::init<size_t, size_t, size_t, size_t, double, double, double, size_t>(),"nxdirty"_a,
"nydirty"_a, "nu"_a, "nv"_a, "epsilon"_a, "pixsize_x"_a, "pixsize_y"_a, "nthreads"_a=1)
.def("Nxdirty", &PyGridderConfig::Nxdirty)
.def("Nydirty", &PyGridderConfig::Nydirty)
.def("Epsilon", &PyGridderConfig::Epsilon)
......@@ -1065,24 +1091,26 @@ PYBIND11_MODULE(nifty_gridder, m)
dirty2grid_DS, "dirty"_a)
.def("dirty2grid_c", &PyGridderConfig::dirty2grid_c, "dirty"_a)
.def("apply_wscreen", &PyGridderConfig::apply_wscreen,
apply_wscreen_DS, "dirty"_a, "w"_a, "adjoint"_a)
apply_wscreen_DS, "dirty"_a, "w"_a, "adjoint"_a, "divide_by_n"_a=true)
// pickle support
.def(py::pickle(
// __getstate__
[](const PyGridderConfig & gc) {
// Encode object state in tuple
return py::make_tuple(gc.Nxdirty(), gc.Nydirty(), gc.Epsilon(),
return py::make_tuple(gc.Nxdirty(), gc.Nydirty(), gc.Nu(), gc.Nv(),
gc.Epsilon(),
gc.Pixsize_x(), gc.Pixsize_y(), gc.Nthreads());
},
// __setstate__
[](py::tuple t) {
myassert(t.size()==6,"Invalid state");
myassert(t.size()==8,"Invalid state");
// Reconstruct from tuple
return PyGridderConfig(t[0].cast<size_t>(), t[1].cast<size_t>(),
t[2].cast<double>(), t[3].cast<double>(),
t[4].cast<double>(), t[5].cast<size_t>());
t[2].cast<size_t>(), t[3].cast<size_t>(),
t[4].cast<double>(), t[5].cast<double>(),
t[6].cast<double>(), t[7].cast<size_t>());
}));
m.def("getIndices", PygetIndices, getIndices_DS, "baselines"_a,
......@@ -1116,7 +1144,13 @@ PYBIND11_MODULE(nifty_gridder, m)
m.def("ms2dirty", &Pyms2dirty, ms2dirty_DS, "uvw"_a, "freq"_a, "ms"_a,
"wgt"_a=None, "npix_x"_a, "npix_y"_a, "pixsize_x"_a, "pixsize_y"_a,
"epsilon"_a, "do_wstacking"_a=false, "nthreads"_a=1, "verbosity"_a=0);
m.def("ms2dirty_general", &Pyms2dirty_general, "uvw"_a, "freq"_a, "ms"_a,
"wgt"_a=None, "npix_x"_a, "npix_y"_a, "pixsize_x"_a, "pixsize_y"_a, "nu"_a, "nv"_a,
"epsilon"_a, "do_wstacking"_a=false, "nthreads"_a=1, "verbosity"_a=0);
m.def("dirty2ms", &Pydirty2ms, dirty2ms_DS, "uvw"_a, "freq"_a, "dirty"_a,
"wgt"_a=None, "pixsize_x"_a, "pixsize_y"_a, "epsilon"_a,
"do_wstacking"_a=false, "nthreads"_a=1, "verbosity"_a=0);
m.def("dirty2ms_general", &Pydirty2ms_general, "uvw"_a, "freq"_a, "dirty"_a,
"wgt"_a=None, "pixsize_x"_a, "pixsize_y"_a, "nu"_a, "nv"_a, "epsilon"_a,
"do_wstacking"_a=false, "nthreads"_a=1, "verbosity"_a=0);
}
......@@ -3387,6 +3387,36 @@ template<typename T> void r2r_separable_hartley(const shape_t &shape,
false);
}
template<typename T> void r2r_genuine_hartley(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const T *data_in, T *data_out, T fct, size_t nthreads=1)
{
if (util::prod(shape)==0) return;
if (axes.size()==1)
return r2r_separable_hartley(shape, stride_in, stride_out, axes, data_in,
data_out, fct, nthreads);
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
shape_t tshp(shape);
tshp[axes.back()] = tshp[axes.back()]/2+1;
arr<complex<T>> tdata(util::prod(tshp));
stride_t tstride(shape.size());
tstride.back()=sizeof(complex<T>);
for (size_t i=tstride.size()-1; i>0; --i)
tstride[i-1]=tstride[i]*ptrdiff_t(tshp[i]);
r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads);
cndarr<cmplx<T>> atmp(tdata.data(), tshp, tstride);
ndarr<T> aout(data_out, shape, stride_out);
simple_iter iin(atmp);
rev_iter iout(aout, axes);
while(iin.remaining()>0)
{
auto v = atmp[iin.ofs()];
aout[iout.ofs()] = v.r+v.i;
aout[iout.rev_ofs()] = v.r-v.i;
iin.advance(); iout.advance();
}
}
} // namespace detail
using detail::FORWARD;
......@@ -3398,6 +3428,7 @@ using detail::c2r;
using detail::r2c;
using detail::r2r_fftpack;
using detail::r2r_separable_hartley;
using detail::r2r_genuine_hartley;
using detail::dct;
using detail::dst;
......
......@@ -89,6 +89,46 @@ def test_adjointness_ms2dirty(nxdirty, nydirty, nrow, nchan, epsilon, singleprec
assert_allclose(np.vdot(ms, ms2).real, np.vdot(dirty2, dirty), rtol=tol)
@pmp("nxdirty", (30, 128))
@pmp("nydirty", (128, 250))
@pmp("ofactor", (1.2, 1.5, 1.7, 2.0))
@pmp("nrow", (2, 27))
@pmp("nchan", (1, 5))
@pmp("epsilon", (1e-1, 1e-3, 1e-5))
@pmp("singleprec", (True, False))
@pmp("wstacking", (True, False))
@pmp("use_wgt", (True, False))
@pmp("nthreads", (1, 2))
def test_adjointness_ms2dirty_general(nxdirty, nydirty, ofactor, nrow, nchan, epsilon, singleprec, wstacking, use_wgt, nthreads):
if singleprec and epsilon < 5e-5:
return
np.random.seed(42)
pixsizex = np.pi/180/60/nxdirty*0.2398
pixsizey = np.pi/180/60/nxdirty
speedoflight, f0 = 299792458., 1e9
freq = f0 + np.arange(nchan)*(f0/nchan)
uvw = (np.random.rand(nrow, 3)-0.5)/(pixsizey*f0/speedoflight)
ms = np.random.rand(nrow, nchan)-0.5 + 1j*(np.random.rand(nrow, nchan)-0.5)
wgt = np.random.rand(nrow, nchan) if use_wgt else None
dirty = np.random.rand(nxdirty, nydirty)-0.5
nu, nv = int(nxdirty*ofactor)+1, int(nydirty*ofactor)+1
if nu&1:
nu+=1
if nv&1:
nv+=1
if singleprec:
ms = ms.astype("c8")
dirty = dirty.astype("f4")
if wgt is not None:
wgt = wgt.astype("f4")
dirty2 = ng.ms2dirty_general(uvw, freq, ms, wgt, nxdirty, nydirty, pixsizex,
pixsizey, nu, nv, epsilon, wstacking, nthreads, 0).astype("f8")
ms2 = ng.dirty2ms_general(uvw, freq, dirty, wgt, pixsizex, pixsizey, nu, nv, epsilon,
wstacking, nthreads, 0).astype("c16")
tol = 5e-4 if singleprec else 5e-11
assert_allclose(np.vdot(ms, ms2).real, np.vdot(dirty2, dirty), rtol=tol)
@pmp('nxdirty', [16, 64])
@pmp('nydirty', [64])
@pmp("nrow", (2, 27))
......@@ -120,6 +160,43 @@ def test_ms2dirty_against_wdft(nxdirty, nydirty, nrow, nchan, epsilon, singlepre
assert_allclose(_l2error(dirty, ref), 0, atol=epsilon)
@pmp('nxdirty', [16, 64])
@pmp('nydirty', [64])
@pmp('ofactor', [1.2, 1.4, 1.7, 2])
@pmp("nrow", (2, 27))
@pmp("nchan", (1, 5))
@pmp("epsilon", (1e-2, 1e-3, 1e-4, 1e-7))
@pmp("singleprec", (True,))
@pmp("wstacking", (True,))
@pmp("use_wgt", (False, True))
@pmp("nthreads", (1, 2))
@pmp("fov", (1., 20.))
def test_ms2dirty_against_wdft2(nxdirty, nydirty, ofactor, nrow, nchan, epsilon, singleprec, wstacking, use_wgt, fov, nthreads):
if singleprec and epsilon < 5e-5:
return
np.random.seed(40)
pixsizex = fov*np.pi/180/nxdirty
pixsizey = fov*np.pi/180/nydirty*1.1
speedoflight, f0 = 299792458., 1e9
freq = f0 + np.arange(nchan)*(f0/nchan)
uvw = (np.random.rand(nrow, 3)-0.5)/(pixsizex*f0/speedoflight)
ms = np.random.rand(nrow, nchan)-0.5 + 1j*(np.random.rand(nrow, nchan)-0.5)
wgt = np.random.rand(nrow, nchan) if use_wgt else None
nu, nv = int(nxdirty*ofactor)+1, int(nydirty*ofactor)+1
if nu&1:
nu+=1
if nv&1:
nv+=1
if singleprec:
ms = ms.astype("c8")
if wgt is not None:
wgt = wgt.astype("f4")
dirty = ng.ms2dirty_general(uvw, freq, ms, wgt, nxdirty, nydirty, pixsizex,
pixsizey, nu, nv, epsilon, wstacking, nthreads, 0).astype("f8")
ref = explicit_gridder(uvw, freq, ms, wgt, nxdirty, nydirty, pixsizex, pixsizey, wstacking)
assert_allclose(_l2error(dirty, ref), 0, atol=epsilon)
@pmp("nxdirty", (128, 300))
@pmp("nydirty", (128, 250))
@pmp("nrow", (1, 100))
......@@ -383,15 +460,18 @@ def test_wplane_against_wdft(nxdirty, nydirty, nchan, nrow, fov):
@pmp("nrow", (10, 100))
@pmp("nchan", (1, 10))
@pmp("fov", (1, 10))
@pmp("yscale", (0.5, 1.))
@pmp("epsilon", (1e-2, 1e-7, 2e-12))
def test_wgridder_against_wdft(nxdirty, nydirty, nchan, nrow, fov, epsilon):
def test_wgridder_against_wdft(nxdirty, nydirty, yscale, nchan, nrow, fov, epsilon):
np.random.seed(40)
pixsize = fov*np.pi/180/nxdirty
pixsize_x = pixsize
pixsize_y = yscale*pixsize
conf = ng.GridderConfig(nxdirty=nxdirty,
nydirty=nydirty,
epsilon=epsilon,
pixsize_x=pixsize,
pixsize_y=0.5*pixsize)
pixsize_x=pixsize_x,
pixsize_y=pixsize_y)
speedoflight, f0 = 299792458., 1e9
freq = f0 + np.arange(nchan)*(f0/nchan)
uvw = (np.random.rand(nrow, 3)-0.5)/(pixsize*f0/speedoflight)
......@@ -401,8 +481,9 @@ def test_wgridder_against_wdft(nxdirty, nydirty, nchan, nrow, fov, epsilon):
ms = np.random.rand(nrow, nchan)-0.5 + 1j*(np.random.rand(nrow, nchan)-0.5)
vis = bl.ms2vis(ms, idx)
res0 = ng.vis2dirty(bl, conf, idx, vis, do_wstacking=True)
res1 = explicit_gridder(uvw, freq, ms, None, nxdirty, nydirty, pixsize,
0.5*pixsize, True)
res1 = explicit_gridder(uvw, freq, ms, None, nxdirty, nydirty, pixsize_x,
pixsize_y, True)
print(_l2error(res0, res1)/epsilon)
assert_(_l2error(res0, res1) < epsilon)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment