Commit 9aaf8783 authored by Martin Reinecke's avatar Martin Reinecke

first try

parent a439a47c
Pipeline #49995 failed with stages
in 4 minutes and 59 seconds
......@@ -27,12 +27,9 @@ for ii in range(10, 23):
img = ift.from_global_data(uvspace, img)
t0 = time()
GM = ift.GridderMaker(uvspace, eps=1e-7)
idx = GM.getReordering(uv)
uv = uv[idx]
vis = vis[idx]
GM = ift.GridderMaker(uvspace, eps=1e-7, uv=uv)
vis = ift.from_global_data(visspace, vis)
op = GM.getFull(uv).adjoint
op = GM.getFull().adjoint
t1 = time()
op(img).to_global_data()
t2 = time()
......
......@@ -26,52 +26,45 @@ from ..sugar import from_global_data, makeDomain
class GridderMaker(object):
def __init__(self, domain, eps=2e-13):
from nifty_gridder import get_w
domain = makeDomain(domain)
if (len(domain) != 1 or not isinstance(domain[0], RGSpace) or
not len(domain.shape) == 2):
raise ValueError("need domain with exactly one 2D RGSpace")
nu, nv = domain.shape
if nu % 2 != 0 or nv % 2 != 0:
raise ValueError("dimensions must be even")
nu2, nv2 = 2*nu, 2*nv
w = get_w(eps)
nsafe = (w+1)//2
nu2 = max([nu2, 2*nsafe])
nv2 = max([nv2, 2*nsafe])
oversampled_domain = RGSpace(
[nu2, nv2], distances=[1, 1], harmonic=False)
self._eps = eps
self._rest = _RestOperator(domain, oversampled_domain, eps)
def getReordering(self, uv):
from nifty_gridder import peanoindex
nu2, nv2 = self._rest._domain.shape
return peanoindex(uv, nu2, nv2)
def getGridder(self, uv):
return RadioGridder(self._rest.domain, self._eps, uv)
def __init__(self, dirty_domain, uv, eps=2e-13):
import nifty_gridder
dirty_domain = makeDomain(dirty_domain)
if (len(dirty_domain) != 1 or
not isinstance(dirty_domain[0], RGSpace) or
not len(dirty_domain.shape) == 2):
raise ValueError("need dirty_domain with exactly one 2D RGSpace")
bl = nifty_gridder.Baselines(uv, np.array([1.]));
nxdirty, nydirty = dirty_domain.shape
gconf = nifty_gridder.GridderConfig(nxdirty, nydirty, eps, 1., 1.)
nu = gconf.Nu()
nv = gconf.Nv()
idx = bl.getIndices()
idx = gconf.reorderIndices(idx, bl)
grid_domain = RGSpace([nu, nv], distances=[1, 1], harmonic=False)
self._rest = _RestOperator(dirty_domain, grid_domain, gconf)
self._gridder = RadioGridder(grid_domain, bl, gconf, idx)
def getGridder(self):
return self._gridder
def getRest(self):
return self._rest
def getFull(self, uv):
return self.getRest() @ self.getGridder(uv)
def getFull(self):
return self.getRest() @ self._gridder
class _RestOperator(LinearOperator):
def __init__(self, domain, oversampled_domain, eps):
from nifty_gridder import correction_factors
self._domain = makeDomain(oversampled_domain)
self._target = domain
nu, nv = domain.shape
nu2, nv2 = oversampled_domain.shape
fu = correction_factors(nu2, nu//2+1, eps)
fv = correction_factors(nv2, nv//2+1, eps)
def __init__(self, dirty_domain, grid_domain, gconf):
import nifty_gridder
self._domain = makeDomain(grid_domain)
self._target = makeDomain(dirty_domain)
self._gconf = gconf
fu = gconf.U_corrections()
fv = gconf.V_corrections()
nu, nv = dirty_domain.shape
# compute deconvolution operator
rng = np.arange(nu)
k = np.minimum(rng, nu-rng)
......@@ -82,6 +75,7 @@ class _RestOperator(LinearOperator):
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
import nifty_gridder
self._check_input(x, mode)
nu, nv = self._target.shape
res = x.to_global_data()
......@@ -103,20 +97,21 @@ class _RestOperator(LinearOperator):
class RadioGridder(LinearOperator):
def __init__(self, target, eps, uv):
self._domain = DomainTuple.make(
UnstructuredDomain((uv.shape[0],)))
self._target = DomainTuple.make(target)
def __init__(self, grid_domain, bl, gconf, idx):
self._domain = DomainTuple.make(UnstructuredDomain((idx.shape[0],)))
self._target = DomainTuple.make(grid_domain)
self._bl = bl
self._gconf = gconf
self._idx = idx
self._capability = self.TIMES | self.ADJOINT_TIMES
self._eps = float(eps)
self._uv = uv # FIXME: should we write-protect this?
def apply(self, x, mode):
from nifty_gridder import to_grid, from_grid
import nifty_gridder
self._check_input(x, mode)
if mode == self.TIMES:
nu2, nv2 = self._target.shape
res = to_grid(self._uv, x.to_global_data(), nu2, nv2, self._eps)
res = nifty_gridder.ms2grid(
self._bl, self._gconf, self._idx, x.to_global_data().reshape((-1,1)))
else:
res = from_grid(self._uv, x.to_global_data(), self._eps)
res = nifty_gridder.grid2ms(
self._bl, self._gconf, self._idx, x.to_global_data())
return from_global_data(self._tgt(mode), res)
......@@ -39,13 +39,10 @@ def test_gridding(nu, nv, N, eps):
vis = np.random.randn(N) + 1j*np.random.randn(N)
# Nifty
GM = ift.GridderMaker(ift.RGSpace((nu, nv)), eps=eps)
# re-order for performance
idx = GM.getReordering(uv)
uv, vis = uv[idx], vis[idx]
GM = ift.GridderMaker(ift.RGSpace((nu, nv)), eps=eps, uv=uv)
vis2 = ift.from_global_data(ift.UnstructuredDomain(vis.shape), vis)
Op = GM.getFull(uv)
Op = GM.getFull()
pynu = Op(vis2).to_global_data()
# DFT
x, y = np.meshgrid(
......@@ -63,14 +60,12 @@ def test_gridding(nu, nv, N, eps):
def test_build(nu, nv, N, eps):
dom = ift.RGSpace([nu, nv])
uv = np.random.rand(N, 2) - 0.5
GM = ift.GridderMaker(dom, eps=eps)
GM = ift.GridderMaker(dom, eps=eps, uv=uv)
# re-order for performance
idx = GM.getReordering(uv)
uv = uv[idx]
R0 = GM.getGridder(uv)
R0 = GM.getGridder()
R1 = GM.getRest()
R = R1@R0
RF = GM.getFull(uv)
RF = GM.getFull()
# Consistency checks
flt = np.float64
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment