Commit db4cc7cc authored by Martin Reinecke's avatar Martin Reinecke

start cleaning up the gridder

parent e61cfd4d
...@@ -36,12 +36,6 @@ ...@@ -36,12 +36,6 @@
#include "mr_util/infra/mav.h" #include "mr_util/infra/mav.h"
#include "mr_util/math/es_kernel.h" #include "mr_util/math/es_kernel.h"
#if defined(__GNUC__)
#define ALIGNED(align) __attribute__ ((aligned(align)))
#else
#define ALIGNED(align)
#endif
namespace gridder { namespace gridder {
namespace detail { namespace detail {
...@@ -51,10 +45,7 @@ using namespace mr; ...@@ -51,10 +45,7 @@ using namespace mr;
template<size_t ndim> void checkShape template<size_t ndim> void checkShape
(const array<size_t, ndim> &shp1, const array<size_t, ndim> &shp2) (const array<size_t, ndim> &shp1, const array<size_t, ndim> &shp2)
{ { MR_assert(shp1==shp2, "shape mismatch"); }
for (size_t i=0; i<ndim; ++i)
MR_assert(shp1[i]==shp2[i], "shape mismatch");
}
template<typename T> inline T fmod1 (T v) template<typename T> inline T fmod1 (T v)
{ return v-floor(v); } { return v-floor(v); }
...@@ -66,7 +57,7 @@ template<typename T> inline T fmod1 (T v) ...@@ -66,7 +57,7 @@ template<typename T> inline T fmod1 (T v)
template<typename T> void complex2hartley template<typename T> void complex2hartley
(const mav<complex<T>, 2> &grid, mav<T,2> &grid2, size_t nthreads) (const mav<complex<T>, 2> &grid, mav<T,2> &grid2, size_t nthreads)
{ {
checkShape(grid.shape(), grid2.shape()); MR_assert(grid.conformable(grid2), "shape mismatch");
size_t nu=grid.shape(0), nv=grid.shape(1); size_t nu=grid.shape(0), nv=grid.shape(1);
execStatic(nu, nthreads, 0, [&](Scheduler &sched) execStatic(nu, nthreads, 0, [&](Scheduler &sched)
...@@ -87,7 +78,7 @@ template<typename T> void complex2hartley ...@@ -87,7 +78,7 @@ template<typename T> void complex2hartley
template<typename T> void hartley2complex template<typename T> void hartley2complex
(const mav<T,2> &grid, mav<complex<T>,2> &grid2, size_t nthreads) (const mav<T,2> &grid, mav<complex<T>,2> &grid2, size_t nthreads)
{ {
checkShape(grid.shape(), grid2.shape()); MR_assert(grid.conformable(grid2), "shape mismatch");
size_t nu=grid.shape(0), nv=grid.shape(1); size_t nu=grid.shape(0), nv=grid.shape(1);
execStatic(nu, nthreads, 0, [&](Scheduler &sched) execStatic(nu, nthreads, 0, [&](Scheduler &sched)
...@@ -109,7 +100,7 @@ template<typename T> void hartley2complex ...@@ -109,7 +100,7 @@ template<typename T> void hartley2complex
template<typename T> void hartley2_2D(const mav<T,2> &in, template<typename T> void hartley2_2D(const mav<T,2> &in,
mav<T,2> &out, size_t nthreads) mav<T,2> &out, size_t nthreads)
{ {
checkShape(in.shape(), out.shape()); MR_assert(in.conformable(out), "shape mismatch");
size_t nu=in.shape(0), nv=in.shape(1); size_t nu=in.shape(0), nv=in.shape(1);
fmav<T> fin(in), fout(out); fmav<T> fin(in), fout(out);
r2r_separable_hartley(fin, fout, {0,1}, T(1), nthreads); r2r_separable_hartley(fin, fout, {0,1}, T(1), nthreads);
...@@ -282,7 +273,7 @@ class GridderConfig ...@@ -282,7 +273,7 @@ class GridderConfig
template<typename T> void grid2dirty_post(mav<T,2> &tmav, template<typename T> void grid2dirty_post(mav<T,2> &tmav,
mav<T,2> &dirty) const mav<T,2> &dirty) const
{ {
checkShape(dirty.shape(), {nx_dirty,ny_dirty}); checkShape(dirty.shape(), {nx_dirty, ny_dirty});
auto cfu = correction_factors(nu, ofactor, nx_dirty/2+1, supp, nthreads); auto cfu = correction_factors(nu, ofactor, nx_dirty/2+1, supp, nthreads);
auto cfv = correction_factors(nv, ofactor, ny_dirty/2+1, supp, nthreads); auto cfv = correction_factors(nv, ofactor, ny_dirty/2+1, supp, nthreads);
execStatic(nx_dirty, nthreads, 0, [&](Scheduler &sched) execStatic(nx_dirty, nthreads, 0, [&](Scheduler &sched)
...@@ -546,7 +537,7 @@ template<typename T, typename T2=complex<T>> class Helper ...@@ -546,7 +537,7 @@ template<typename T, typename T2=complex<T>> class Helper
public: public:
const T2 *p0r; const T2 *p0r;
T2 *p0w; T2 *p0w;
T kernel[64] ALIGNED(64); T kernel[64] MRUTIL_ALIGNED(64);
static constexpr size_t vlen=native_simd<T>::size(); static constexpr size_t vlen=native_simd<T>::size();
Helper(const GridderConfig &gconf_, const T2 *grid_r_, T2 *grid_w_, Helper(const GridderConfig &gconf_, const T2 *grid_r_, T2 *grid_w_,
...@@ -1095,7 +1086,7 @@ template<typename T> vector<idx_t> getWgtIndices(const Baselines &baselines, ...@@ -1095,7 +1086,7 @@ template<typename T> vector<idx_t> getWgtIndices(const Baselines &baselines,
return res; return res;
} }
template<typename T> void ms2dirty_general(const mav<double,2> &uvw, template<typename T> void ms2dirty(const mav<double,2> &uvw,
const mav<double,1> &freq, const mav<complex<T>,2> &ms, const mav<double,1> &freq, const mav<complex<T>,2> &ms,
const mav<T,2> &wgt, double pixsize_x, double pixsize_y, size_t nu, size_t nv, double epsilon, const mav<T,2> &wgt, double pixsize_x, double pixsize_y, size_t nu, size_t nv, double epsilon,
bool do_wstacking, size_t nthreads, mav<T,2> &dirty, size_t verbosity, bool do_wstacking, size_t nthreads, mav<T,2> &dirty, size_t verbosity,
...@@ -1108,17 +1099,8 @@ template<typename T> void ms2dirty_general(const mav<double,2> &uvw, ...@@ -1108,17 +1099,8 @@ template<typename T> void ms2dirty_general(const mav<double,2> &uvw,
auto serv = makeMsServ(baselines,idx2,ms,wgt); auto serv = makeMsServ(baselines,idx2,ms,wgt);
x2dirty(gconf, serv, dirty, do_wstacking, verbosity); x2dirty(gconf, serv, dirty, do_wstacking, verbosity);
} }
template<typename T> void ms2dirty(const mav<double,2> &uvw,
const mav<double,1> &freq, const mav<complex<T>,2> &ms,
const mav<T,2> &wgt, double pixsize_x, double pixsize_y, double epsilon,
bool do_wstacking, size_t nthreads, mav<T,2> &dirty, size_t verbosity)
{
ms2dirty_general(uvw, freq, ms, wgt, pixsize_x, pixsize_y,
2*dirty.shape(0), 2*dirty.shape(1), epsilon, do_wstacking, nthreads,
dirty, verbosity);
}
template<typename T> void dirty2ms_general(const mav<double,2> &uvw, template<typename T> void dirty2ms(const mav<double,2> &uvw,
const mav<double,1> &freq, const mav<T,2> &dirty, const mav<double,1> &freq, const mav<T,2> &dirty,
const mav<T,2> &wgt, double pixsize_x, double pixsize_y, size_t nu, size_t nv,double epsilon, const mav<T,2> &wgt, double pixsize_x, double pixsize_y, size_t nu, size_t nv,double epsilon,
bool do_wstacking, size_t nthreads, mav<complex<T>,2> &ms, bool do_wstacking, size_t nthreads, mav<complex<T>,2> &ms,
...@@ -1133,21 +1115,10 @@ template<typename T> void dirty2ms_general(const mav<double,2> &uvw, ...@@ -1133,21 +1115,10 @@ template<typename T> void dirty2ms_general(const mav<double,2> &uvw,
auto serv = makeMsServ(baselines,idx2,ms,wgt); auto serv = makeMsServ(baselines,idx2,ms,wgt);
dirty2x(gconf, dirty, serv, do_wstacking, verbosity); dirty2x(gconf, dirty, serv, do_wstacking, verbosity);
} }
template<typename T> void dirty2ms(const mav<double,2> &uvw,
const mav<double,1> &freq, const mav<T,2> &dirty,
const mav<T,2> &wgt, double pixsize_x, double pixsize_y, double epsilon,
bool do_wstacking, size_t nthreads, mav<complex<T>,2> &ms, size_t verbosity)
{
dirty2ms_general(uvw, freq, dirty, wgt, pixsize_x, pixsize_y,
2*dirty.shape(0), 2*dirty.shape(1), epsilon, do_wstacking, nthreads, ms,
verbosity);
}
} // namespace detail } // namespace detail
// public names // public names
using detail::ms2dirty_general;
using detail::dirty2ms_general;
using detail::ms2dirty; using detail::ms2dirty;
using detail::dirty2ms; using detail::dirty2ms;
} // namespace gridder } // namespace gridder
......
...@@ -36,7 +36,7 @@ namespace py = pybind11; ...@@ -36,7 +36,7 @@ namespace py = pybind11;
auto None = py::none(); auto None = py::none();
template<typename T> py::array ms2dirty_general2(const py::array &uvw_, template<typename T> py::array ms2dirty2(const py::array &uvw_,
const py::array &freq_, const py::array &ms_, const py::object &wgt_, const py::array &freq_, const py::array &ms_, const py::object &wgt_,
size_t npix_x, size_t npix_y, double pixsize_x, double pixsize_y, size_t nu, size_t npix_x, size_t npix_y, double pixsize_x, double pixsize_y, size_t nu,
size_t nv, double epsilon, bool do_wstacking, size_t nthreads, size_t nv, double epsilon, bool do_wstacking, size_t nthreads,
...@@ -51,22 +51,22 @@ template<typename T> py::array ms2dirty_general2(const py::array &uvw_, ...@@ -51,22 +51,22 @@ template<typename T> py::array ms2dirty_general2(const py::array &uvw_,
auto dirty2 = to_mav<T,2>(dirty, true); auto dirty2 = to_mav<T,2>(dirty, true);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
ms2dirty_general(uvw,freq,ms,wgt2,pixsize_x,pixsize_y,nu,nv,epsilon, ms2dirty(uvw,freq,ms,wgt2,pixsize_x,pixsize_y,nu,nv,epsilon,
do_wstacking,nthreads,dirty2,verbosity); do_wstacking,nthreads,dirty2,verbosity);
} }
return move(dirty); return move(dirty);
} }
py::array Pyms2dirty_general(const py::array &uvw, py::array Pyms2dirty(const py::array &uvw,
const py::array &freq, const py::array &ms, const py::object &wgt, const py::array &freq, const py::array &ms, const py::object &wgt,
size_t npix_x, size_t npix_y, double pixsize_x, double pixsize_y, size_t nu, size_t npix_x, size_t npix_y, double pixsize_x, double pixsize_y, size_t nu,
size_t nv, double epsilon, bool do_wstacking, size_t nthreads, size_t nv, double epsilon, bool do_wstacking, size_t nthreads,
size_t verbosity) size_t verbosity)
{ {
if (isPyarr<complex<float>>(ms)) if (isPyarr<complex<float>>(ms))
return ms2dirty_general2<float>(uvw, freq, ms, wgt, npix_x, npix_y, return ms2dirty2<float>(uvw, freq, ms, wgt, npix_x, npix_y,
pixsize_x, pixsize_y, nu, nv, epsilon, do_wstacking, nthreads, verbosity); pixsize_x, pixsize_y, nu, nv, epsilon, do_wstacking, nthreads, verbosity);
if (isPyarr<complex<double>>(ms)) if (isPyarr<complex<double>>(ms))
return ms2dirty_general2<double>(uvw, freq, ms, wgt, npix_x, npix_y, return ms2dirty2<double>(uvw, freq, ms, wgt, npix_x, npix_y,
pixsize_x, pixsize_y, nu, nv, epsilon, do_wstacking, nthreads, verbosity); pixsize_x, pixsize_y, nu, nv, epsilon, do_wstacking, nthreads, verbosity);
MR_fail("type matching failed: 'ms' has neither type 'c8' nor 'c16'"); MR_fail("type matching failed: 'ms' has neither type 'c8' nor 'c16'");
} }
...@@ -89,6 +89,13 @@ npix_x, npix_y: int ...@@ -89,6 +89,13 @@ npix_x, npix_y: int
dimensions of the dirty image dimensions of the dirty image
pixsize_x, pixsize_y: float pixsize_x, pixsize_y: float
angular pixel size (in radians) of the dirty image angular pixel size (in radians) of the dirty image
nu, nv: int
dimensions of the (oversampled) intermediate uv grid
These values must be >= 1.2*the dimensions of the dirty image; tupical
oversampling values lie between 1.5 and 2.
Increasing the oversampling factor decreases the kernel support width
required for the desired accuracy, so it typically reduces run-time; on the
other hand, this will increase memory consumption.
epsilon: float epsilon: float
accuracy at which the computation should be done. Must be larger than 2e-13. accuracy at which the computation should be done. Must be larger than 2e-13.
If `ms` has type np.complex64, it must be larger than 1e-5. If `ms` has type np.complex64, it must be larger than 1e-5.
...@@ -107,17 +114,8 @@ Returns ...@@ -107,17 +114,8 @@ Returns
np.array((nxdirty, nydirty), dtype=float of same precision as `ms`) np.array((nxdirty, nydirty), dtype=float of same precision as `ms`)
the dirty image the dirty image
)"""; )""";
py::array Pyms2dirty(const py::array &uvw,
const py::array &freq, const py::array &ms, const py::object &wgt,
size_t npix_x, size_t npix_y, double pixsize_x, double pixsize_y, double epsilon,
bool do_wstacking, size_t nthreads, size_t verbosity)
{
return Pyms2dirty_general(uvw, freq, ms, wgt, npix_x, npix_y,
pixsize_x, pixsize_y, 2*npix_x, 2*npix_y, epsilon, do_wstacking, nthreads,
verbosity);
}
template<typename T> py::array dirty2ms_general2(const py::array &uvw_, template<typename T> py::array dirty2ms2(const py::array &uvw_,
const py::array &freq_, const py::array &dirty_, const py::object &wgt_, const py::array &freq_, const py::array &dirty_, const py::object &wgt_,
double pixsize_x, double pixsize_y, size_t nu, size_t nv, double epsilon, double pixsize_x, double pixsize_y, size_t nu, size_t nv, double epsilon,
bool do_wstacking, size_t nthreads, size_t verbosity) bool do_wstacking, size_t nthreads, size_t verbosity)
...@@ -131,21 +129,21 @@ template<typename T> py::array dirty2ms_general2(const py::array &uvw_, ...@@ -131,21 +129,21 @@ template<typename T> py::array dirty2ms_general2(const py::array &uvw_,
auto ms2 = to_mav<complex<T>,2>(ms, true); auto ms2 = to_mav<complex<T>,2>(ms, true);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
dirty2ms_general(uvw,freq,dirty,wgt2,pixsize_x,pixsize_y,nu,nv,epsilon, dirty2ms(uvw,freq,dirty,wgt2,pixsize_x,pixsize_y,nu,nv,epsilon,
do_wstacking,nthreads,ms2,verbosity); do_wstacking,nthreads,ms2,verbosity);
} }
return move(ms); return move(ms);
} }
py::array Pydirty2ms_general(const py::array &uvw, py::array Pydirty2ms(const py::array &uvw,
const py::array &freq, const py::array &dirty, const py::object &wgt, const py::array &freq, const py::array &dirty, const py::object &wgt,
double pixsize_x, double pixsize_y, size_t nu, size_t nv, double epsilon, double pixsize_x, double pixsize_y, size_t nu, size_t nv, double epsilon,
bool do_wstacking, size_t nthreads, size_t verbosity) bool do_wstacking, size_t nthreads, size_t verbosity)
{ {
if (isPyarr<float>(dirty)) if (isPyarr<float>(dirty))
return dirty2ms_general2<float>(uvw, freq, dirty, wgt, return dirty2ms2<float>(uvw, freq, dirty, wgt,
pixsize_x, pixsize_y, nu, nv, epsilon, do_wstacking, nthreads, verbosity); pixsize_x, pixsize_y, nu, nv, epsilon, do_wstacking, nthreads, verbosity);
if (isPyarr<double>(dirty)) if (isPyarr<double>(dirty))
return dirty2ms_general2<double>(uvw, freq, dirty, wgt, return dirty2ms2<double>(uvw, freq, dirty, wgt,
pixsize_x, pixsize_y, nu, nv, epsilon, do_wstacking, nthreads, verbosity); pixsize_x, pixsize_y, nu, nv, epsilon, do_wstacking, nthreads, verbosity);
MR_fail("type matching failed: 'dirty' has neither type 'f4' nor 'f8'"); MR_fail("type matching failed: 'dirty' has neither type 'f4' nor 'f8'");
} }
...@@ -166,6 +164,13 @@ wgt: np.array((nrows, nchan), same dtype as `dirty`), optional ...@@ -166,6 +164,13 @@ wgt: np.array((nrows, nchan), same dtype as `dirty`), optional
If present, its values are multiplied to the output If present, its values are multiplied to the output
pixsize_x, pixsize_y: float pixsize_x, pixsize_y: float
angular pixel size (in radians) of the dirty image angular pixel size (in radians) of the dirty image
nu, nv: int
dimensions of the (oversampled) intermediate uv grid
These values must be >= 1.2*the dimensions of the dirty image; tupical
oversampling values lie between 1.5 and 2.
Increasing the oversampling factor decreases the kernel support width
required for the desired accuracy, so it typically reduces run-time; on the
other hand, this will increase memory consumption.
epsilon: float epsilon: float
accuracy at which the computation should be done. Must be larger than 2e-13. accuracy at which the computation should be done. Must be larger than 2e-13.
If `dirty` has type np.float32, it must be larger than 1e-5. If `dirty` has type np.float32, it must be larger than 1e-5.
...@@ -184,31 +189,16 @@ Returns ...@@ -184,31 +189,16 @@ Returns
np.array((nrows, nchan,), dtype=complex of same precision as `dirty`) np.array((nrows, nchan,), dtype=complex of same precision as `dirty`)
the measurement set data. the measurement set data.
)"""; )""";
py::array Pydirty2ms(const py::array &uvw,
const py::array &freq, const py::array &dirty, const py::object &wgt,
double pixsize_x, double pixsize_y, double epsilon,
bool do_wstacking, size_t nthreads, size_t verbosity)
{
return Pydirty2ms_general(uvw, freq, dirty, wgt, pixsize_x, pixsize_y,
2*dirty.shape(0), 2*dirty.shape(1), epsilon, do_wstacking, nthreads,
verbosity);
}
void add_nifty_gridder(py::module &msup) void add_nifty_gridder(py::module &msup)
{ {
using namespace pybind11::literals; using namespace pybind11::literals;
auto m = msup.def_submodule("nifty_gridder"); auto m = msup.def_submodule("nifty_gridder");
m.def("ms2dirty", &Pyms2dirty, ms2dirty_DS, "uvw"_a, "freq"_a, "ms"_a, m.def("ms2dirty", &Pyms2dirty, "uvw"_a, "freq"_a, "ms"_a,
"wgt"_a=None, "npix_x"_a, "npix_y"_a, "pixsize_x"_a, "pixsize_y"_a,
"epsilon"_a, "do_wstacking"_a=false, "nthreads"_a=1, "verbosity"_a=0);
m.def("ms2dirty_general", &Pyms2dirty_general, "uvw"_a, "freq"_a, "ms"_a,
"wgt"_a=None, "npix_x"_a, "npix_y"_a, "pixsize_x"_a, "pixsize_y"_a, "nu"_a, "nv"_a, "wgt"_a=None, "npix_x"_a, "npix_y"_a, "pixsize_x"_a, "pixsize_y"_a, "nu"_a, "nv"_a,
"epsilon"_a, "do_wstacking"_a=false, "nthreads"_a=1, "verbosity"_a=0); "epsilon"_a, "do_wstacking"_a=false, "nthreads"_a=1, "verbosity"_a=0);
m.def("dirty2ms", &Pydirty2ms, dirty2ms_DS, "uvw"_a, "freq"_a, "dirty"_a, m.def("dirty2ms", &Pydirty2ms, "uvw"_a, "freq"_a, "dirty"_a,
"wgt"_a=None, "pixsize_x"_a, "pixsize_y"_a, "epsilon"_a,
"do_wstacking"_a=false, "nthreads"_a=1, "verbosity"_a=0);
m.def("dirty2ms_general", &Pydirty2ms_general, "uvw"_a, "freq"_a, "dirty"_a,
"wgt"_a=None, "pixsize_x"_a, "pixsize_y"_a, "nu"_a, "nv"_a, "epsilon"_a, "wgt"_a=None, "pixsize_x"_a, "pixsize_y"_a, "nu"_a, "nv"_a, "epsilon"_a,
"do_wstacking"_a=false, "nthreads"_a=1, "verbosity"_a=0); "do_wstacking"_a=false, "nthreads"_a=1, "verbosity"_a=0);
} }
......
...@@ -37,40 +37,6 @@ def explicit_gridder(uvw, freq, ms, wgt, nxdirty, nydirty, xpixsize, ypixsize, ...@@ -37,40 +37,6 @@ def explicit_gridder(uvw, freq, ms, wgt, nxdirty, nydirty, xpixsize, ypixsize,
return res/n return res/n
@pmp("nxdirty", (30, 128))
@pmp("nydirty", (128, 250))
@pmp("nrow", (2, 27))
@pmp("nchan", (1, 5))
@pmp("epsilon", (1e-1, 1e-3, 1e-5, 1e-7, 5e-13))
@pmp("singleprec", (True, False))
@pmp("wstacking", (True, False))
@pmp("use_wgt", (True, False))
@pmp("nthreads", (1, 2))
def test_adjointness_ms2dirty(nxdirty, nydirty, nrow, nchan, epsilon, singleprec, wstacking, use_wgt, nthreads):
if singleprec and epsilon < 5e-6:
return
np.random.seed(42)
pixsizex = np.pi/180/60/nxdirty*0.2398
pixsizey = np.pi/180/60/nxdirty
speedoflight, f0 = 299792458., 1e9
freq = f0 + np.arange(nchan)*(f0/nchan)
uvw = (np.random.rand(nrow, 3)-0.5)/(pixsizey*f0/speedoflight)
ms = np.random.rand(nrow, nchan)-0.5 + 1j*(np.random.rand(nrow, nchan)-0.5)
wgt = np.random.rand(nrow, nchan) if use_wgt else None
dirty = np.random.rand(nxdirty, nydirty)-0.5
if singleprec:
ms = ms.astype("c8")
dirty = dirty.astype("f4")
if wgt is not None:
wgt = wgt.astype("f4")
dirty2 = ng.ms2dirty(uvw, freq, ms, wgt, nxdirty, nydirty, pixsizex,
pixsizey, epsilon, wstacking, nthreads, 0).astype("f8")
ms2 = ng.dirty2ms(uvw, freq, dirty, wgt, pixsizex, pixsizey, epsilon,
wstacking, nthreads, 0).astype("c16")
tol = 5e-5 if singleprec else 5e-13
assert_allclose(np.vdot(ms, ms2).real, np.vdot(dirty2, dirty), rtol=tol)
@pmp("nxdirty", (30, 128)) @pmp("nxdirty", (30, 128))
@pmp("nydirty", (128, 250)) @pmp("nydirty", (128, 250))
@pmp("ofactor", (1.2, 1.5, 1.7, 2.0)) @pmp("ofactor", (1.2, 1.5, 1.7, 2.0))
...@@ -81,7 +47,7 @@ def test_adjointness_ms2dirty(nxdirty, nydirty, nrow, nchan, epsilon, singleprec ...@@ -81,7 +47,7 @@ def test_adjointness_ms2dirty(nxdirty, nydirty, nrow, nchan, epsilon, singleprec
@pmp("wstacking", (True, False)) @pmp("wstacking", (True, False))
@pmp("use_wgt", (True, False)) @pmp("use_wgt", (True, False))
@pmp("nthreads", (1, 2)) @pmp("nthreads", (1, 2))
def test_adjointness_ms2dirty_general(nxdirty, nydirty, ofactor, nrow, nchan, epsilon, singleprec, wstacking, use_wgt, nthreads): def test_adjointness_ms2dirty(nxdirty, nydirty, ofactor, nrow, nchan, epsilon, singleprec, wstacking, use_wgt, nthreads):
if singleprec and epsilon < 5e-5: if singleprec and epsilon < 5e-5:
return return
np.random.seed(42) np.random.seed(42)
...@@ -103,45 +69,14 @@ def test_adjointness_ms2dirty_general(nxdirty, nydirty, ofactor, nrow, nchan, ep ...@@ -103,45 +69,14 @@ def test_adjointness_ms2dirty_general(nxdirty, nydirty, ofactor, nrow, nchan, ep
dirty = dirty.astype("f4") dirty = dirty.astype("f4")
if wgt is not None: if wgt is not None:
wgt = wgt.astype("f4") wgt = wgt.astype("f4")
dirty2 = ng.ms2dirty_general(uvw, freq, ms, wgt, nxdirty, nydirty, pixsizex, dirty2 = ng.ms2dirty(uvw, freq, ms, wgt, nxdirty, nydirty, pixsizex,
pixsizey, nu, nv, epsilon, wstacking, nthreads, 0).astype("f8") pixsizey, nu, nv, epsilon, wstacking, nthreads, 0).astype("f8")
ms2 = ng.dirty2ms_general(uvw, freq, dirty, wgt, pixsizex, pixsizey, nu, nv, epsilon, ms2 = ng.dirty2ms(uvw, freq, dirty, wgt, pixsizex, pixsizey, nu, nv, epsilon,
wstacking, nthreads, 0).astype("c16") wstacking, nthreads, 0).astype("c16")
tol = 5e-4 if singleprec else 5e-11 tol = 5e-4 if singleprec else 5e-11
assert_allclose(np.vdot(ms, ms2).real, np.vdot(dirty2, dirty), rtol=tol) assert_allclose(np.vdot(ms, ms2).real, np.vdot(dirty2, dirty), rtol=tol)
@pmp('nxdirty', [16, 64])
@pmp('nydirty', [64])
@pmp("nrow", (2, 27))
@pmp("nchan", (1, 5))
@pmp("epsilon", (1e-2, 1e-3, 5e-5, 1e-7, 5e-13))
@pmp("singleprec", (True, False))
@pmp("wstacking", (True, False))
@pmp("use_wgt", (False, True))
@pmp("nthreads", (1, 2))
@pmp("fov", (1., 20.))
def test_ms2dirty_against_wdft(nxdirty, nydirty, nrow, nchan, epsilon, singleprec, wstacking, use_wgt, fov, nthreads):
if singleprec and epsilon < 5e-5:
return
np.random.seed(40)
pixsizex = fov*np.pi/180/nxdirty
pixsizey = fov*np.pi/180/nydirty*1.1
speedoflight, f0 = 299792458., 1e9
freq = f0 + np.arange(nchan)*(f0/nchan)
uvw = (np.random.rand(nrow, 3)-0.5)/(pixsizex*f0/speedoflight)
ms = np.random.rand(nrow, nchan)-0.5 + 1j*(np.random.rand(nrow, nchan)-0.5)
wgt = np.random.rand(nrow, nchan) if use_wgt else None
if singleprec:
ms = ms.astype("c8")
if wgt is not None:
wgt = wgt.astype("f4")
dirty = ng.ms2dirty(uvw, freq, ms, wgt, nxdirty, nydirty, pixsizex,
pixsizey, epsilon, wstacking, nthreads, 0).astype("f8")
ref = explicit_gridder(uvw, freq, ms, wgt, nxdirty, nydirty, pixsizex, pixsizey, wstacking)
assert_allclose(_l2error(dirty, ref), 0, atol=epsilon)
@pmp('nxdirty', [16, 64]) @pmp('nxdirty', [16, 64])
@pmp('nydirty', [64]) @pmp('nydirty', [64])
@pmp('ofactor', [1.2, 1.4, 1.7, 2]) @pmp('ofactor', [1.2, 1.4, 1.7, 2])
...@@ -173,7 +108,7 @@ def test_ms2dirty_against_wdft2(nxdirty, nydirty, ofactor, nrow, nchan, epsilon, ...@@ -173,7 +108,7 @@ def test_ms2dirty_against_wdft2(nxdirty, nydirty, ofactor, nrow, nchan, epsilon,
ms = ms.astype("c8") ms = ms.astype("c8")
if wgt is not None: if wgt is not None:
wgt = wgt.astype("f4") wgt = wgt.astype("f4")
dirty = ng.ms2dirty_general(uvw, freq, ms, wgt, nxdirty, nydirty, pixsizex, dirty = ng.ms2dirty(uvw, freq, ms, wgt, nxdirty, nydirty, pixsizex,
pixsizey, nu, nv, epsilon, wstacking, nthreads, 0).astype("f8") pixsizey, nu, nv, epsilon, wstacking, nthreads, 0).astype("f8")
ref = explicit_gridder(uvw, freq, ms, wgt, nxdirty, nydirty, pixsizex, pixsizey, wstacking) ref = explicit_gridder(uvw, freq, ms, wgt, nxdirty, nydirty, pixsizex, pixsizey, wstacking)
assert_allclose(_l2error(dirty, ref), 0, atol=epsilon) assert_allclose(_l2error(dirty, ref), 0, atol=epsilon)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment