Commit f3d8625e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

try to fix data type checks and w sign

parent 9b17aae0
......@@ -1328,7 +1328,7 @@ template<typename T, typename Serv> void x2dirty(
nullmav<T,1>(), wcur, dw);
}
gconf.grid2dirty_c(cmav(grid), tdirty);
gconf.apply_wscreen(cmav(tdirty), tdirty, wcur, false);
gconf.apply_wscreen(cmav(tdirty), tdirty, wcur, true);
#pragma omp parallel for num_threads(nthreads)
for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j)
......@@ -1418,7 +1418,7 @@ template<typename T, typename Serv> void dirty2x(
for (size_t i=0; i<nx_dirty; ++i)
for (size_t j=0; j<ny_dirty; ++j)
tdirty2(i,j) = tdirty(i,j);
gconf.apply_wscreen(cmav(tdirty2), tdirty2, wcur, true);
gconf.apply_wscreen(cmav(tdirty2), tdirty2, wcur, false);
gconf.dirty2grid_c(cmav(tdirty2), grid);
grid2vis_c(srv.getBaselines(), gconf, cmav(idx_loc), cmav(grid), vis_loc,
nullmav<T,1>(), wcur, dw);
......
......@@ -36,17 +36,29 @@ auto None = py::none();
template<typename T>
using pyarr = py::array_t<T, 0>;
template<typename T> bool isPytype(const py::array &arr)
{
auto t1=arr.dtype();
auto t2=pybind11::dtype::of<T>();
auto k1=t1.kind();
auto k2=t2.kind();
auto s1=t1.itemsize();
auto s2=t2.itemsize();
return (k1==k2)&&(s1==s2);
}
template<typename T> pyarr<T> getPyarr(const py::array &arr, const string &name)
{
auto t1=arr.dtype();
auto t2=pybind11::dtype::of<T>();
myassert(t1.is(t2),
"type mismatch for array '", name, "': expected '", t2.kind(),
t2.itemsize(), "', but got '", t1.kind(), t1.itemsize(), "'.");
auto k1=t1.kind();
auto k2=t2.kind();
auto s1=t1.itemsize();
auto s2=t2.itemsize();
myassert((k1==k2)&&(s1==s2),
"type mismatch for array '", name, "': expected '", k2, s2,
"', but got '", k1, s1, "'.");
return arr.cast<pyarr<T>>();
}
template<typename T> bool isPytype(const py::array &arr)
{ return arr.dtype().is(pybind11::dtype::of<T>()); }
template<typename T> pyarr<T> makeArray(const vector<size_t> &shape)
{ return pyarr<T>(shape); }
......
......@@ -37,7 +37,7 @@ def _dft(uvw, vis, conf, apply_w):
for ii, vv in enumerate(vis):
phase = x*uvw[ii, 0] + y*uvw[ii, 1]
if apply_w:
phase += uvw[ii, 2]*nm1
phase -= uvw[ii, 2]*nm1
dirty += (vv*np.exp(2j*np.pi*phase)).real
if apply_w:
return dirty/n
......@@ -314,7 +314,7 @@ def test_wstack_against_wdft(nxdirty, nydirty, nchan, nrow, fov):
if len(iind) == 0:
continue
dd = conf.grid2dirty_c(ng.vis2grid_c(bl, conf, iind, bl.ms2vis(ms, iind)))
res0 += conf.apply_wscreen(dd, 0.5*(ws[ii+1]+ws[ii]), adjoint=False).real
res0 += conf.apply_wscreen(dd, 0.5*(ws[ii+1]+ws[ii]), adjoint=True).real
res1 = _dft(uvw, vis, conf, True)
assert_(_l2error(res0, res1) < 1e-4) # Very inaccurate
......@@ -346,7 +346,7 @@ def test_wplane_against_wdft(nxdirty, nydirty, nchan, nrow, fov):
w = np.min(uvw[:, 2])
iind = ng.getIndices(bl, conf, flags)
dd = conf.grid2dirty_c(ng.vis2grid_c(bl, conf, iind, bl.ms2vis(ms, iind)))
res0 = conf.apply_wscreen(dd, w, adjoint=False).real
res0 = conf.apply_wscreen(dd, w, adjoint=True).real
res1 = _dft(uvw, vis, conf, True)
assert_(_l2error(res0, res1) < epsilon)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment