Commit 8b1423c6 authored by Philipp Arras's avatar Philipp Arras
Browse files

Make convention for uv in gridder compatible with NIFTy spaces

parent 2d9e6511
Pipeline #50993 passed with stages
in 7 minutes and 51 seconds
...@@ -34,10 +34,14 @@ class GridderMaker(object): ...@@ -34,10 +34,14 @@ class GridderMaker(object):
raise ValueError("uv must be a 2D array") raise ValueError("uv must be a 2D array")
if uv.shape[1] != 2: if uv.shape[1] != 2:
raise ValueError("second dimension of uv must have length 2") raise ValueError("second dimension of uv must have length 2")
dstx, dsty = dirty_domain[0].distances
# wasteful hack to adjust to shape required by nifty_gridder # wasteful hack to adjust to shape required by nifty_gridder
uvw = np.empty((uv.shape[0],3), dtype=np.float64) uvw = np.empty((uv.shape[0],3), dtype=np.float64)
uvw[:,0:2] = uv uvw[:,0:2] = uv
uvw[:,2] = 0. uvw[:,2] = 0.
# Scale uv such that 0<uv<=1 which is assmued by nifty_gridder
uvw[:, 0] = uvw[:,0]*dstx
uvw[:, 1] = uvw[:,1]*dsty
speedOfLight = 299792458. speedOfLight = 299792458.
bl = nifty_gridder.Baselines(uvw, np.array([speedOfLight])) bl = nifty_gridder.Baselines(uvw, np.array([speedOfLight]))
nxdirty, nydirty = dirty_domain.shape nxdirty, nydirty = dirty_domain.shape
......
...@@ -39,20 +39,21 @@ def test_gridding(nu, nv, N, eps): ...@@ -39,20 +39,21 @@ def test_gridding(nu, nv, N, eps):
vis = np.random.randn(N) + 1j*np.random.randn(N) vis = np.random.randn(N) + 1j*np.random.randn(N)
# Nifty # Nifty
GM = ift.GridderMaker(ift.RGSpace((nu, nv)), uv=uv, eps=eps) dom = ift.RGSpace((nu, nv), distances=(0.2, 1.12))
dstx, dsty = dom.distances
uv[:,0] = uv[:,0]/dstx
uv[:,1] = uv[:,1]/dsty
GM = ift.GridderMaker(dom, uv=uv, eps=eps)
vis2 = ift.from_global_data(ift.UnstructuredDomain(vis.shape), vis) vis2 = ift.from_global_data(ift.UnstructuredDomain(vis.shape), vis)
Op = GM.getFull() Op = GM.getFull()
pynu = Op(vis2).to_global_data() pynu = Op(vis2).to_global_data()
import matplotlib.pyplot as plt
plt.imshow(pynu)
plt.show()
# DFT # DFT
x, y = np.meshgrid( x, y = np.meshgrid(
*[-ss/2 + np.arange(ss) for ss in [nu, nv]], indexing='ij') *[-ss/2 + np.arange(ss) for ss in [nu, nv]], indexing='ij')
dft = pynu*0. dft = pynu*0.
for i in range(N): for i in range(N):
dft += (vis[i]*np.exp(2j*np.pi*(x*uv[i, 0] + y*uv[i, 1]))).real dft += (vis[i]*np.exp(2j*np.pi*(x*uv[i, 0]*dstx + y*uv[i, 1]*dsty))).real
assert_(_l2error(dft, pynu) < eps) assert_(_l2error(dft, pynu) < eps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment