Skip to content
Snippets Groups Projects
Commit 2d9e6511 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

revert to old gridder interface (almost)

parent ed095aa2
Branches
Tags
1 merge request!327Power grid
Pipeline #50985 passed
...@@ -10,30 +10,24 @@ np.random.seed(40) ...@@ -10,30 +10,24 @@ np.random.seed(40)
N0s, a0s, b0s, c0s = [], [], [], [] N0s, a0s, b0s, c0s = [], [], [], []
for ii in range(10, 26): for ii in range(10, 26):
fovx = 0.0001
fovy = 0.0002
nu = 1024 nu = 1024
nv = 1024 nv = 1024
N = int(2**ii) N = int(2**ii)
print('N = {}'.format(N)) print('N = {}'.format(N))
nchan=16
nrow=N//nchan uv = np.random.rand(N, 2) - 0.5
freq = 1e9+1e6*np.arange(nchan) vis = np.random.randn(N) + 1j*np.random.randn(N)
uvw = np.random.rand(nrow, 3) - 0.5
vis = (np.random.randn(N) + 1j*np.random.randn(N)).reshape((nrow,nchan))
uvspace = ift.RGSpace((nu, nv)) uvspace = ift.RGSpace((nu, nv))
visspace = ift.UnstructuredDomain((N//nchan,nchan)) visspace = ift.UnstructuredDomain(N)
img = np.random.randn(nu*nv) img = np.random.randn(nu*nv)
img = img.reshape((nu, nv)) img = img.reshape((nu, nv))
img = ift.from_global_data(uvspace, img) img = ift.from_global_data(uvspace, img)
t0 = time() t0 = time()
GM = ift.GridderMaker(uvspace, eps=1e-7, uvw=uvw, GM = ift.GridderMaker(uvspace, eps=1e-7, uv=uv)
freq=freq, fovx=fovx, fovy=fovy,
flags=np.zeros((N//nchan, nchan), dtype=np.bool))
vis = ift.from_global_data(visspace, vis) vis = ift.from_global_data(visspace, vis)
op = GM.getFull().adjoint op = GM.getFull().adjoint
t1 = time() t1 = time()
......
...@@ -24,20 +24,28 @@ import numpy as np ...@@ -24,20 +24,28 @@ import numpy as np
class GridderMaker(object): class GridderMaker(object):
def __init__(self, dirty_domain, uvw, freq, fovx, fovy, flags, eps=2e-13): def __init__(self, dirty_domain, uv, eps=2e-13):
import nifty_gridder import nifty_gridder
dirty_domain = makeDomain(dirty_domain) dirty_domain = makeDomain(dirty_domain)
if (len(dirty_domain) != 1 or not isinstance(dirty_domain[0], RGSpace) if (len(dirty_domain) != 1 or not isinstance(dirty_domain[0], RGSpace)
or not len(dirty_domain.shape) == 2): or not len(dirty_domain.shape) == 2):
raise ValueError("need dirty_domain with exactly one 2D RGSpace") raise ValueError("need dirty_domain with exactly one 2D RGSpace")
if freq.ndim != 1: if uv.ndim != 2:
raise ValueError("freq must be a 1D array") raise ValueError("uv must be a 2D array")
bl = nifty_gridder.Baselines(uvw, freq) if uv.shape[1] != 2:
raise ValueError("second dimension of uv must have length 2")
# wasteful hack to adjust to shape required by nifty_gridder
uvw = np.empty((uv.shape[0],3), dtype=np.float64)
uvw[:,0:2] = uv
uvw[:,2] = 0.
speedOfLight = 299792458.
bl = nifty_gridder.Baselines(uvw, np.array([speedOfLight]))
nxdirty, nydirty = dirty_domain.shape nxdirty, nydirty = dirty_domain.shape
gconf = nifty_gridder.GridderConfig(nxdirty, nydirty, eps, fovx, fovy) nxd, nyd = dirty_domain.shape
nu = gconf.Nu() gconf = nifty_gridder.GridderConfig(nxdirty, nydirty, eps, 1., 1.)
nv = gconf.Nv() nu, nv = gconf.Nu(), gconf.Nv()
self._idx = nifty_gridder.getIndices(bl, gconf, flags) self._idx = nifty_gridder.getIndices(
bl, gconf, np.zeros((uv.shape[0],1),dtype=np.bool))
self._bl = bl self._bl = bl
grid_domain = RGSpace([nu, nv], distances=[1, 1], harmonic=False) grid_domain = RGSpace([nu, nv], distances=[1, 1], harmonic=False)
...@@ -78,7 +86,7 @@ class _RestOperator(LinearOperator): ...@@ -78,7 +86,7 @@ class _RestOperator(LinearOperator):
class RadioGridder(LinearOperator): class RadioGridder(LinearOperator):
def __init__(self, grid_domain, bl, gconf, idx): def __init__(self, grid_domain, bl, gconf, idx):
self._domain = DomainTuple.make( self._domain = DomainTuple.make(
UnstructuredDomain((bl.Nrows(),bl.Nchannels()))) UnstructuredDomain((bl.Nrows())))
self._target = DomainTuple.make(grid_domain) self._target = DomainTuple.make(grid_domain)
self._bl = bl self._bl = bl
self._gconf = gconf self._gconf = gconf
...@@ -89,10 +97,10 @@ class RadioGridder(LinearOperator): ...@@ -89,10 +97,10 @@ class RadioGridder(LinearOperator):
import nifty_gridder import nifty_gridder
self._check_input(x, mode) self._check_input(x, mode)
if mode == self.TIMES: if mode == self.TIMES:
x = self._bl.ms2vis(x.to_global_data(), self._idx) x = self._bl.ms2vis(x.to_global_data().reshape((-1, 1)), self._idx)
res = nifty_gridder.vis2grid(self._bl, self._gconf, self._idx, x) res = nifty_gridder.vis2grid(self._bl, self._gconf, self._idx, x)
else: else:
res = nifty_gridder.grid2vis(self._bl, self._gconf, self._idx, res = nifty_gridder.grid2vis(self._bl, self._gconf, self._idx,
x.to_global_data()) x.to_global_data())
res = self._bl.vis2ms(res, self._idx) res = self._bl.vis2ms(res, self._idx).reshape((-1,))
return from_global_data(self._tgt(mode), res) return from_global_data(self._tgt(mode), res)
...@@ -29,37 +29,30 @@ pmp = pytest.mark.parametrize ...@@ -29,37 +29,30 @@ pmp = pytest.mark.parametrize
def _l2error(a, b): def _l2error(a, b):
return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2)) return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2))
speedOfLight = 299792458.
@pmp('eps', [1e-2, 1e-4, 1e-7, 1e-10, 1e-11, 1e-12, 2e-13]) @pmp('eps', [1e-2, 1e-4, 1e-7, 1e-10, 1e-11, 1e-12, 2e-13])
@pmp('nu', [12, 128]) @pmp('nu', [12, 128])
@pmp('nv', [4, 12, 128]) @pmp('nv', [4, 12, 128])
@pmp('N', [1, 10, 100]) @pmp('N', [1, 10, 100])
@pmp('freq', [1e9]) def test_gridding(nu, nv, N, eps):
def test_gridding(nu, nv, N, eps, freq): uv = np.random.rand(N, 2) - 0.5
fovx = 0.0001 vis = np.random.randn(N) + 1j*np.random.randn(N)
fovy = 0.0002
uvw = (np.random.rand(N, 3) - 0.5)
uvw[:,0] /= fovx*freq/speedOfLight
uvw[:,1] /= fovy*freq/speedOfLight
vis = (np.random.randn(N) + 1j*np.random.randn(N)).reshape((-1,1))
# Nifty # Nifty
GM = ift.GridderMaker(ift.RGSpace((nu, nv)), uvw=uvw, GM = ift.GridderMaker(ift.RGSpace((nu, nv)), uv=uv, eps=eps)
freq=np.array([freq]), eps=eps, fovx=fovx, fovy=fovy,
flags=np.zeros((N, 1), dtype=np.bool))
vis2 = ift.from_global_data(ift.UnstructuredDomain(vis.shape), vis) vis2 = ift.from_global_data(ift.UnstructuredDomain(vis.shape), vis)
Op = GM.getFull() Op = GM.getFull()
pynu = Op(vis2).to_global_data() pynu = Op(vis2).to_global_data()
import matplotlib.pyplot as plt
plt.imshow(pynu)
plt.show()
# DFT # DFT
x, y = np.meshgrid( x, y = np.meshgrid(
*[-ss/2 + np.arange(ss) for ss in [nu, nv]], indexing='ij') *[-ss/2 + np.arange(ss) for ss in [nu, nv]], indexing='ij')
x *= fovx*freq/speedOfLight
y *= fovy*freq/speedOfLight
dft = pynu*0. dft = pynu*0.
for i in range(N): for i in range(N):
dft += (vis[i]*np.exp(2j*np.pi*(x*uvw[i, 0] + y*uvw[i, 1]))).real dft += (vis[i]*np.exp(2j*np.pi*(x*uv[i, 0] + y*uv[i, 1]))).real
assert_(_l2error(dft, pynu) < eps) assert_(_l2error(dft, pynu) < eps)
...@@ -67,15 +60,10 @@ def test_gridding(nu, nv, N, eps, freq): ...@@ -67,15 +60,10 @@ def test_gridding(nu, nv, N, eps, freq):
@pmp('nu', [12, 128]) @pmp('nu', [12, 128])
@pmp('nv', [4, 12, 128]) @pmp('nv', [4, 12, 128])
@pmp('N', [1, 10, 100]) @pmp('N', [1, 10, 100])
@pmp('freq', [np.array([1e9]), np.array([1e9, 2e9, 2.5e9])]) def test_build(nu, nv, N, eps):
def test_build(nu, nv, N, eps, freq):
dom = ift.RGSpace([nu, nv]) dom = ift.RGSpace([nu, nv])
fov = np.pi/180/60 uv = np.random.rand(N, 2) - 0.5
uvw = np.random.rand(N, 3) - 0.5 GM = ift.GridderMaker(dom, uv=uv, eps=eps)
flags=np.zeros((N, freq.shape[0]), dtype=np.bool)
flags[0,0]=True
GM = ift.GridderMaker(dom, uvw=uvw, freq=freq, eps=eps,
flags=flags, fovx=fov, fovy=fov)
R0 = GM.getGridder() R0 = GM.getGridder()
R1 = GM.getRest() R1 = GM.getRest()
R = R1@R0 R = R1@R0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment