From c05ca91676b66e6f516c493a593df9aa023ded89 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Thu, 16 May 2019 13:12:14 +0200 Subject: [PATCH] fixes --- nifty5/library/gridder.py | 26 +++++++++++++------------- test/test_operators/test_nft.py | 10 +++++++--- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/nifty5/library/gridder.py b/nifty5/library/gridder.py index 17b2b3702..dee33550b 100644 --- a/nifty5/library/gridder.py +++ b/nifty5/library/gridder.py @@ -63,21 +63,21 @@ class GridderMaker(object): class _RestOperator(LinearOperator): def __init__(self, domain, oversampled_domain, eps): + from nifty_gridder import correction_factors self._domain = makeDomain(oversampled_domain) self._target = domain nu, nv = domain.shape nu2, nv2 = oversampled_domain.shape + fu = correction_factors(nu2, nu//2+1, eps) + fv = correction_factors(nv2, nv//2+1, eps) # compute deconvolution operator -# rng = np.arange(nu) -# k = np.minimum(rng, nu-rng) -# c = np.pi*r2lamb/nu2**2 -# self._deconv_u = np.roll(np.exp(c*k**2), -nu//2).reshape((-1, 1)) -# rng = np.arange(nv) -# k = np.minimum(rng, nv-rng) -# c = np.pi*r2lamb/nv2**2 -# self._deconv_v = np.roll( -# np.exp(c*k**2)/r2lamb, -nv//2).reshape((1, -1)) + rng = np.arange(nu) + k = np.minimum(rng, nu-rng) + self._deconv_u = np.roll(fu[k], -nu//2).reshape((-1, 1)) + rng = np.arange(nv) + k = np.minimum(rng, nv-rng) + self._deconv_v = np.roll(fv[k], -nv//2).reshape((1, -1)) self._capability = self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): @@ -88,11 +88,11 @@ class _RestOperator(LinearOperator): res = hartley(res) res = np.roll(res, (nu//2, nv//2), axis=(0, 1)) res = res[:nu, :nv] -# res *= self._deconv_u -# res *= self._deconv_v + res *= self._deconv_u + res *= self._deconv_v else: -# res = res*self._deconv_u -# res *= self._deconv_v + res = res*self._deconv_u + res *= self._deconv_v nu2, nv2 = self._domain.shape res = np.pad(res, ((0, nu2-nu), (0, nv2-nv)), mode='constant', constant_values=0) diff --git a/test/test_operators/test_nft.py b/test/test_operators/test_nft.py index f62738b4a..deab5845c 100644 --- a/test/test_operators/test_nft.py +++ b/test/test_operators/test_nft.py @@ -25,16 +25,20 @@ np.random.seed(40) pmp = pytest.mark.parametrize +def _l2error(a,b): + return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2)) + +@pmp('eps', [1e-2, 1e-6, 1e-7, 1e-15]) @pmp('nu', [12, 128]) @pmp('nv', [4, 12, 128]) @pmp('N', [1, 10, 100]) -def test_gridding(nu, nv, N): +def test_gridding(nu, nv, N, eps): uv = np.random.rand(N, 2) - 0.5 vis = np.random.randn(N) + 1j*np.random.randn(N) # Nifty - GM = ift.GridderMaker(ift.RGSpace((nu, nv))) + GM = ift.GridderMaker(ift.RGSpace((nu, nv)),eps=eps) # re-order for performance idx = GM.getReordering(uv) uv, vis = uv[idx], vis[idx] @@ -48,7 +52,7 @@ def test_gridding(nu, nv, N): dft = pynu*0. for i in range(N): dft += (vis[i]*np.exp(2j*np.pi*(x*uv[i, 0] + y*uv[i, 1]))).real - assert_allclose(dft, pynu) + assert(_l2error(dft,pynu)<max(1e-13,10*eps)) @pmp('eps', [1e-2, 1e-6, 1e-15]) -- GitLab