Skip to content
Snippets Groups Projects
Commit 859c79be authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fix gridder

parent 8e9c5244
No related branches found
No related tags found
1 merge request!327Power grid
Pipeline #50129 failed
......@@ -16,11 +16,11 @@ for ii in range(10, 26):
print('N = {}'.format(N))
uvw = np.random.rand(N, 3) - 0.5
vis = np.random.randn(N) + 1j*np.random.randn(N)
vis = (np.random.randn(N) + 1j*np.random.randn(N)).reshape((-1,1))
uvspace = ift.RGSpace((nu, nv))
visspace = ift.UnstructuredDomain(N)
visspace = ift.UnstructuredDomain((N,1))
img = np.random.randn(nu*nv)
img = img.reshape((nu, nv))
......
......@@ -77,7 +77,8 @@ class _RestOperator(LinearOperator):
class RadioGridder(LinearOperator):
def __init__(self, grid_domain, bl, gconf, idx):
self._domain = DomainTuple.make(UnstructuredDomain((idx.shape[0],)))
self._domain = DomainTuple.make(
UnstructuredDomain((bl.Nrows(),bl.Nchannels())))
self._target = DomainTuple.make(grid_domain)
self._bl = bl
self._gconf = gconf
......@@ -88,11 +89,10 @@ class RadioGridder(LinearOperator):
import nifty_gridder
self._check_input(x, mode)
if mode == self.TIMES:
x = x.to_global_data().reshape((-1, 1))
x = self._bl.ms2vis(x, self._idx)
x = self._bl.ms2vis(x.to_global_data(), self._idx)
res = nifty_gridder.vis2grid(self._bl, self._gconf, self._idx, x)
else:
res = nifty_gridder.grid2vis(self._bl, self._gconf, self._idx,
x.to_global_data())
res = self._bl.vis2ms(res, self._idx).reshape((-1,))
res = self._bl.vis2ms(res, self._idx)
return from_global_data(self._tgt(mode), res)
......@@ -37,7 +37,7 @@ def _l2error(a, b):
@pmp('channel_fact', [1, 1.2])
def test_gridding(nu, nv, N, eps, channel_fact):
uvw = np.random.rand(N, 3) - 0.5
vis = (np.random.randn(N) + 1j*np.random.randn(N))
vis = (np.random.randn(N) + 1j*np.random.randn(N)).reshape((-1,1))
# Nifty
GM = ift.GridderMaker(ift.RGSpace((nu, nv)), uvw=uvw,
......@@ -60,11 +60,14 @@ def test_gridding(nu, nv, N, eps, channel_fact):
@pmp('nu', [12, 128])
@pmp('nv', [4, 12, 128])
@pmp('N', [1, 10, 100])
def test_build(nu, nv, N, eps):
@pmp('cfact', [np.array([1.]), np.array([0.3, 0.5, 2.3])])
def test_build(nu, nv, N, eps, cfact):
dom = ift.RGSpace([nu, nv])
uvw = np.random.rand(N, 3) - 0.5
GM = ift.GridderMaker(dom, uvw=uvw, channel_fact=np.array([1.]), eps=eps,
flags=np.zeros((N, 1), dtype=np.bool))
flags=np.zeros((N, cfact.shape[0]), dtype=np.bool)
flags[0,0]=True
GM = ift.GridderMaker(dom, uvw=uvw, channel_fact=cfact, eps=eps,
flags=flags)
R0 = GM.getGridder()
R1 = GM.getRest()
R = R1@R0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment