Commit 9aaf8783 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

first try

parent a439a47c
Pipeline #49995 failed with stages
in 4 minutes and 59 seconds
...@@ -27,12 +27,9 @@ for ii in range(10, 23): ...@@ -27,12 +27,9 @@ for ii in range(10, 23):
img = ift.from_global_data(uvspace, img) img = ift.from_global_data(uvspace, img)
t0 = time() t0 = time()
GM = ift.GridderMaker(uvspace, eps=1e-7) GM = ift.GridderMaker(uvspace, eps=1e-7, uv=uv)
idx = GM.getReordering(uv)
uv = uv[idx]
vis = vis[idx]
vis = ift.from_global_data(visspace, vis) vis = ift.from_global_data(visspace, vis)
op = GM.getFull(uv).adjoint op = GM.getFull().adjoint
t1 = time() t1 = time()
op(img).to_global_data() op(img).to_global_data()
t2 = time() t2 = time()
......
...@@ -26,52 +26,45 @@ from ..sugar import from_global_data, makeDomain ...@@ -26,52 +26,45 @@ from ..sugar import from_global_data, makeDomain
class GridderMaker(object): class GridderMaker(object):
def __init__(self, domain, eps=2e-13): def __init__(self, dirty_domain, uv, eps=2e-13):
from nifty_gridder import get_w import nifty_gridder
domain = makeDomain(domain) dirty_domain = makeDomain(dirty_domain)
if (len(domain) != 1 or not isinstance(domain[0], RGSpace) or if (len(dirty_domain) != 1 or
not len(domain.shape) == 2): not isinstance(dirty_domain[0], RGSpace) or
raise ValueError("need domain with exactly one 2D RGSpace") not len(dirty_domain.shape) == 2):
nu, nv = domain.shape raise ValueError("need dirty_domain with exactly one 2D RGSpace")
if nu % 2 != 0 or nv % 2 != 0: bl = nifty_gridder.Baselines(uv, np.array([1.]));
raise ValueError("dimensions must be even") nxdirty, nydirty = dirty_domain.shape
nu2, nv2 = 2*nu, 2*nv gconf = nifty_gridder.GridderConfig(nxdirty, nydirty, eps, 1., 1.)
w = get_w(eps) nu = gconf.Nu()
nsafe = (w+1)//2 nv = gconf.Nv()
nu2 = max([nu2, 2*nsafe]) idx = bl.getIndices()
nv2 = max([nv2, 2*nsafe]) idx = gconf.reorderIndices(idx, bl)
oversampled_domain = RGSpace( grid_domain = RGSpace([nu, nv], distances=[1, 1], harmonic=False)
[nu2, nv2], distances=[1, 1], harmonic=False)
self._rest = _RestOperator(dirty_domain, grid_domain, gconf)
self._eps = eps self._gridder = RadioGridder(grid_domain, bl, gconf, idx)
self._rest = _RestOperator(domain, oversampled_domain, eps)
def getGridder(self):
def getReordering(self, uv): return self._gridder
from nifty_gridder import peanoindex
nu2, nv2 = self._rest._domain.shape
return peanoindex(uv, nu2, nv2)
def getGridder(self, uv):
return RadioGridder(self._rest.domain, self._eps, uv)
def getRest(self): def getRest(self):
return self._rest return self._rest
def getFull(self, uv): def getFull(self):
return self.getRest() @ self.getGridder(uv) return self.getRest() @ self._gridder
class _RestOperator(LinearOperator): class _RestOperator(LinearOperator):
def __init__(self, domain, oversampled_domain, eps): def __init__(self, dirty_domain, grid_domain, gconf):
from nifty_gridder import correction_factors import nifty_gridder
self._domain = makeDomain(oversampled_domain) self._domain = makeDomain(grid_domain)
self._target = domain self._target = makeDomain(dirty_domain)
nu, nv = domain.shape self._gconf = gconf
nu2, nv2 = oversampled_domain.shape fu = gconf.U_corrections()
fv = gconf.V_corrections()
fu = correction_factors(nu2, nu//2+1, eps) nu, nv = dirty_domain.shape
fv = correction_factors(nv2, nv//2+1, eps)
# compute deconvolution operator # compute deconvolution operator
rng = np.arange(nu) rng = np.arange(nu)
k = np.minimum(rng, nu-rng) k = np.minimum(rng, nu-rng)
...@@ -82,6 +75,7 @@ class _RestOperator(LinearOperator): ...@@ -82,6 +75,7 @@ class _RestOperator(LinearOperator):
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
import nifty_gridder
self._check_input(x, mode) self._check_input(x, mode)
nu, nv = self._target.shape nu, nv = self._target.shape
res = x.to_global_data() res = x.to_global_data()
...@@ -103,20 +97,21 @@ class _RestOperator(LinearOperator): ...@@ -103,20 +97,21 @@ class _RestOperator(LinearOperator):
class RadioGridder(LinearOperator): class RadioGridder(LinearOperator):
def __init__(self, target, eps, uv): def __init__(self, grid_domain, bl, gconf, idx):
self._domain = DomainTuple.make( self._domain = DomainTuple.make(UnstructuredDomain((idx.shape[0],)))
UnstructuredDomain((uv.shape[0],))) self._target = DomainTuple.make(grid_domain)
self._target = DomainTuple.make(target) self._bl = bl
self._gconf = gconf
self._idx = idx
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
self._eps = float(eps)
self._uv = uv # FIXME: should we write-protect this?
def apply(self, x, mode): def apply(self, x, mode):
from nifty_gridder import to_grid, from_grid import nifty_gridder
self._check_input(x, mode) self._check_input(x, mode)
if mode == self.TIMES: if mode == self.TIMES:
nu2, nv2 = self._target.shape res = nifty_gridder.ms2grid(
res = to_grid(self._uv, x.to_global_data(), nu2, nv2, self._eps) self._bl, self._gconf, self._idx, x.to_global_data().reshape((-1,1)))
else: else:
res = from_grid(self._uv, x.to_global_data(), self._eps) res = nifty_gridder.grid2ms(
self._bl, self._gconf, self._idx, x.to_global_data())
return from_global_data(self._tgt(mode), res) return from_global_data(self._tgt(mode), res)
...@@ -39,13 +39,10 @@ def test_gridding(nu, nv, N, eps): ...@@ -39,13 +39,10 @@ def test_gridding(nu, nv, N, eps):
vis = np.random.randn(N) + 1j*np.random.randn(N) vis = np.random.randn(N) + 1j*np.random.randn(N)
# Nifty # Nifty
GM = ift.GridderMaker(ift.RGSpace((nu, nv)), eps=eps) GM = ift.GridderMaker(ift.RGSpace((nu, nv)), eps=eps, uv=uv)
# re-order for performance
idx = GM.getReordering(uv)
uv, vis = uv[idx], vis[idx]
vis2 = ift.from_global_data(ift.UnstructuredDomain(vis.shape), vis) vis2 = ift.from_global_data(ift.UnstructuredDomain(vis.shape), vis)
Op = GM.getFull(uv) Op = GM.getFull()
pynu = Op(vis2).to_global_data() pynu = Op(vis2).to_global_data()
# DFT # DFT
x, y = np.meshgrid( x, y = np.meshgrid(
...@@ -63,14 +60,12 @@ def test_gridding(nu, nv, N, eps): ...@@ -63,14 +60,12 @@ def test_gridding(nu, nv, N, eps):
def test_build(nu, nv, N, eps): def test_build(nu, nv, N, eps):
dom = ift.RGSpace([nu, nv]) dom = ift.RGSpace([nu, nv])
uv = np.random.rand(N, 2) - 0.5 uv = np.random.rand(N, 2) - 0.5
GM = ift.GridderMaker(dom, eps=eps) GM = ift.GridderMaker(dom, eps=eps, uv=uv)
# re-order for performance # re-order for performance
idx = GM.getReordering(uv) R0 = GM.getGridder()
uv = uv[idx]
R0 = GM.getGridder(uv)
R1 = GM.getRest() R1 = GM.getRest()
R = R1@R0 R = R1@R0
RF = GM.getFull(uv) RF = GM.getFull()
# Consistency checks # Consistency checks
flt = np.float64 flt = np.float64
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment