Commit 3b88ba48 authored by Martin Reinecke's avatar Martin Reinecke

seems to work. needs lots of polishing

parent 9aaf8783
Pipeline #50076 failed with stages
in 4 minutes and 59 seconds
......@@ -9,13 +9,13 @@ np.random.seed(40)
N0s, a0s, b0s, c0s = [], [], [], []
for ii in range(10, 23):
nu = 1024
nv = 1024
for ii in range(10, 26):
nu = 2048
nv = 2048
N = int(2**ii)
print('N = {}'.format(N))
uv = np.random.rand(N, 2) - 0.5
uvw = np.random.rand(N, 3) - 0.5
vis = np.random.randn(N) + 1j*np.random.randn(N)
uvspace = ift.RGSpace((nu, nv))
......@@ -27,7 +27,7 @@ for ii in range(10, 23):
img = ift.from_global_data(uvspace, img)
t0 = time()
GM = ift.GridderMaker(uvspace, eps=1e-7, uv=uv)
GM = ift.GridderMaker(uvspace, eps=1e-7, uvw=uvw, channel_fact=np.array([1.]))
vis = ift.from_global_data(visspace, vis)
op = GM.getFull().adjoint
t1 = time()
......@@ -35,6 +35,7 @@ for ii in range(10, 23):
t2 = time()
op.adjoint(vis).to_global_data()
t3 = time()
print(t2-t1, t3-t2)
N0s.append(N)
a0s.append(t1 - t0)
b0s.append(t2 - t1)
......
......@@ -26,20 +26,21 @@ from ..sugar import from_global_data, makeDomain
class GridderMaker(object):
def __init__(self, dirty_domain, uv, eps=2e-13):
def __init__(self, dirty_domain, uvw, channel_fact, eps=2e-13):
import nifty_gridder
dirty_domain = makeDomain(dirty_domain)
if (len(dirty_domain) != 1 or
not isinstance(dirty_domain[0], RGSpace) or
not len(dirty_domain.shape) == 2):
raise ValueError("need dirty_domain with exactly one 2D RGSpace")
bl = nifty_gridder.Baselines(uv, np.array([1.]));
if channel_fact.ndim != 1:
raise ValueError("channel_fact must be a 1D array")
bl = nifty_gridder.Baselines(uvw, channel_fact);
nxdirty, nydirty = dirty_domain.shape
gconf = nifty_gridder.GridderConfig(nxdirty, nydirty, eps, 1., 1.)
nu = gconf.Nu()
nv = gconf.Nv()
idx = bl.getIndices()
idx = gconf.reorderIndices(idx, bl)
idx = nifty_gridder.getIndices(bl, gconf)
grid_domain = RGSpace([nu, nv], distances=[1, 1], harmonic=False)
......@@ -62,43 +63,23 @@ class _RestOperator(LinearOperator):
self._domain = makeDomain(grid_domain)
self._target = makeDomain(dirty_domain)
self._gconf = gconf
fu = gconf.U_corrections()
fv = gconf.V_corrections()
nu, nv = dirty_domain.shape
# compute deconvolution operator
rng = np.arange(nu)
k = np.minimum(rng, nu-rng)
self._deconv_u = np.roll(fu[k], -nu//2).reshape((-1, 1))
rng = np.arange(nv)
k = np.minimum(rng, nv-rng)
self._deconv_v = np.roll(fv[k], -nv//2).reshape((1, -1))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
import nifty_gridder
self._check_input(x, mode)
nu, nv = self._target.shape
res = x.to_global_data()
if mode == self.TIMES:
res = hartley(res)
res = np.roll(res, (nu//2, nv//2), axis=(0, 1))
res = res[:nu, :nv]
res *= self._deconv_u
res *= self._deconv_v
res = self._gconf.grid2dirty(res)
else:
res = res*self._deconv_u
res *= self._deconv_v
nu2, nv2 = self._domain.shape
res = np.pad(res, ((0, nu2-nu), (0, nv2-nv)), mode='constant',
constant_values=0)
res = np.roll(res, (-nu//2, -nv//2), axis=(0, 1))
res = hartley(res)
res = self._gconf.dirty2grid(res)
return from_global_data(self._tgt(mode), res)
class RadioGridder(LinearOperator):
def __init__(self, grid_domain, bl, gconf, idx):
self._domain = DomainTuple.make(UnstructuredDomain((idx.shape[0],)))
self._domain = DomainTuple.make(UnstructuredDomain(
(idx.shape[0],)))
self._target = DomainTuple.make(grid_domain)
self._bl = bl
self._gconf = gconf
......@@ -109,9 +90,9 @@ class RadioGridder(LinearOperator):
import nifty_gridder
self._check_input(x, mode)
if mode == self.TIMES:
res = nifty_gridder.ms2grid(
self._bl, self._gconf, self._idx, x.to_global_data().reshape((-1,1)))
res = nifty_gridder.vis2grid(
self._bl, self._gconf, self._idx, x.to_global_data())
else:
res = nifty_gridder.grid2ms(
res = nifty_gridder.grid2vis(
self._bl, self._gconf, self._idx, x.to_global_data())
return from_global_data(self._tgt(mode), res)
......@@ -35,11 +35,13 @@ def _l2error(a, b):
@pmp('nv', [4, 12, 128])
@pmp('N', [1, 10, 100])
def test_gridding(nu, nv, N, eps):
uv = np.random.rand(N, 2) - 0.5
vis = np.random.randn(N) + 1j*np.random.randn(N)
uvw = np.random.rand(N, 3) - 0.5
ms = (np.random.randn(N) + 1j*np.random.randn(N)).reshape((-1,1))
# FIXME temporary!
vis = np.ones((N,))+1j*np.ones((N,))
# Nifty
GM = ift.GridderMaker(ift.RGSpace((nu, nv)), eps=eps, uv=uv)
GM = ift.GridderMaker(ift.RGSpace((nu, nv)), uvw=uvw, channel_fact=np.array([1.]), eps=eps)
vis2 = ift.from_global_data(ift.UnstructuredDomain(vis.shape), vis)
Op = GM.getFull()
......@@ -49,7 +51,7 @@ def test_gridding(nu, nv, N, eps):
*[-ss/2 + np.arange(ss) for ss in [nu, nv]], indexing='ij')
dft = pynu*0.
for i in range(N):
dft += (vis[i]*np.exp(2j*np.pi*(x*uv[i, 0] + y*uv[i, 1]))).real
dft += (vis[i]*np.exp(2j*np.pi*(x*uvw[i, 0] + y*uvw[i, 1]))).real
assert_(_l2error(dft, pynu) < eps)
......@@ -59,9 +61,8 @@ def test_gridding(nu, nv, N, eps):
@pmp('N', [1, 10, 100])
def test_build(nu, nv, N, eps):
dom = ift.RGSpace([nu, nv])
uv = np.random.rand(N, 2) - 0.5
GM = ift.GridderMaker(dom, eps=eps, uv=uv)
# re-order for performance
uvw = np.random.rand(N, 3) - 0.5
GM = ift.GridderMaker(dom, uvw=uvw, channel_fact=np.array([1.]), eps=eps)
R0 = GM.getGridder()
R1 = GM.getRest()
R = R1@R0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment