Skip to content
Snippets Groups Projects
Commit 61c1c853 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

prepare for new gridder (broken)

parent 60672433
No related branches found
No related tags found
1 merge request!324New gridder (again!)
Pipeline #48530 failed
...@@ -34,20 +34,17 @@ class GridderMaker(object): ...@@ -34,20 +34,17 @@ class GridderMaker(object):
nu, nv = domain.shape nu, nv = domain.shape
if nu % 2 != 0 or nv % 2 != 0: if nu % 2 != 0 or nv % 2 != 0:
raise ValueError("dimensions must be even") raise ValueError("dimensions must be even")
rat = 3 if eps < 1e-11 else 2 nu2, nv2 = 2*nu, 2*nv
nu2, nv2 = rat*nu, rat*nv w = int(-np.log10(eps)+1.9999)
nsafe = (w+1)//2
nspread = int(-np.log(eps)/(np.pi*(rat-1)/(rat-.5)) + .5) + 1 nu2 = max([nu2, 2*nsafe])
nu2 = max([nu2, 2*nspread]) nv2 = max([nv2, 2*nsafe])
nv2 = max([nv2, 2*nspread])
r2lamb = rat*rat*nspread/(rat*(rat-.5))
oversampled_domain = RGSpace( oversampled_domain = RGSpace(
[nu2, nv2], distances=[1, 1], harmonic=False) [nu2, nv2], distances=[1, 1], harmonic=False)
self._nspread = nspread self._eps = eps
self._r2lamb = r2lamb self._rest = _RestOperator(domain, oversampled_domain, eps)
self._rest = _RestOperator(domain, oversampled_domain, r2lamb)
def getReordering(self, uv): def getReordering(self, uv):
from nifty_gridder import peanoindex from nifty_gridder import peanoindex
...@@ -55,7 +52,7 @@ class GridderMaker(object): ...@@ -55,7 +52,7 @@ class GridderMaker(object):
return peanoindex(uv, nu2, nv2) return peanoindex(uv, nu2, nv2)
def getGridder(self, uv): def getGridder(self, uv):
return RadioGridder(self._rest.domain, self._nspread, self._r2lamb, uv) return RadioGridder(self._rest.domain, self._eps, uv)
def getRest(self): def getRest(self):
return self._rest return self._rest
...@@ -65,22 +62,22 @@ class GridderMaker(object): ...@@ -65,22 +62,22 @@ class GridderMaker(object):
class _RestOperator(LinearOperator): class _RestOperator(LinearOperator):
def __init__(self, domain, oversampled_domain, r2lamb): def __init__(self, domain, oversampled_domain, eps):
self._domain = makeDomain(oversampled_domain) self._domain = makeDomain(oversampled_domain)
self._target = domain self._target = domain
nu, nv = domain.shape nu, nv = domain.shape
nu2, nv2 = oversampled_domain.shape nu2, nv2 = oversampled_domain.shape
# compute deconvolution operator # compute deconvolution operator
rng = np.arange(nu) # rng = np.arange(nu)
k = np.minimum(rng, nu-rng) # k = np.minimum(rng, nu-rng)
c = np.pi*r2lamb/nu2**2 # c = np.pi*r2lamb/nu2**2
self._deconv_u = np.roll(np.exp(c*k**2), -nu//2).reshape((-1, 1)) # self._deconv_u = np.roll(np.exp(c*k**2), -nu//2).reshape((-1, 1))
rng = np.arange(nv) # rng = np.arange(nv)
k = np.minimum(rng, nv-rng) # k = np.minimum(rng, nv-rng)
c = np.pi*r2lamb/nv2**2 # c = np.pi*r2lamb/nv2**2
self._deconv_v = np.roll( # self._deconv_v = np.roll(
np.exp(c*k**2)/r2lamb, -nv//2).reshape((1, -1)) # np.exp(c*k**2)/r2lamb, -nv//2).reshape((1, -1))
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
...@@ -91,11 +88,11 @@ class _RestOperator(LinearOperator): ...@@ -91,11 +88,11 @@ class _RestOperator(LinearOperator):
res = hartley(res) res = hartley(res)
res = np.roll(res, (nu//2, nv//2), axis=(0, 1)) res = np.roll(res, (nu//2, nv//2), axis=(0, 1))
res = res[:nu, :nv] res = res[:nu, :nv]
res *= self._deconv_u # res *= self._deconv_u
res *= self._deconv_v # res *= self._deconv_v
else: else:
res = res*self._deconv_u # res = res*self._deconv_u
res *= self._deconv_v # res *= self._deconv_v
nu2, nv2 = self._domain.shape nu2, nv2 = self._domain.shape
res = np.pad(res, ((0, nu2-nu), (0, nv2-nv)), mode='constant', res = np.pad(res, ((0, nu2-nu), (0, nv2-nv)), mode='constant',
constant_values=0) constant_values=0)
...@@ -105,12 +102,12 @@ class _RestOperator(LinearOperator): ...@@ -105,12 +102,12 @@ class _RestOperator(LinearOperator):
class RadioGridder(LinearOperator): class RadioGridder(LinearOperator):
def __init__(self, target, nspread, r2lamb, uv): def __init__(self, target, eps, uv):
self._domain = DomainTuple.make( self._domain = DomainTuple.make(
UnstructuredDomain((uv.shape[0],))) UnstructuredDomain((uv.shape[0],)))
self._target = DomainTuple.make(target) self._target = DomainTuple.make(target)
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
self._nspread, self._r2lamb = int(nspread), float(r2lamb) self._eps = float(eps)
self._uv = uv # FIXME: should we write-protect this? self._uv = uv # FIXME: should we write-protect this?
def apply(self, x, mode): def apply(self, x, mode):
...@@ -120,9 +117,9 @@ class RadioGridder(LinearOperator): ...@@ -120,9 +117,9 @@ class RadioGridder(LinearOperator):
nu2, nv2 = self._target.shape nu2, nv2 = self._target.shape
x = x.to_global_data() x = x.to_global_data()
if mode == self.TIMES: if mode == self.TIMES:
res = to_grid(self._uv, x, nu2, nv2, self._nspread, self._r2lamb) res = to_grid(self._uv, x, nu2, nv2, self._eps)
res = to_grid_post(res) res = to_grid_post(res)
else: else:
x = from_grid_pre(x) x = from_grid_pre(x)
res = from_grid(self._uv, x, nu2, nv2, self._nspread, self._r2lamb) res = from_grid(self._uv, x, nu2, nv2, self._eps)
return from_global_data(self._tgt(mode), res) return from_global_data(self._tgt(mode), res)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment