Commit 736bf3b3 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'better_wscreen' into 'master'

try to improve accuracy of wscreen formula

See merge request !12
parents 6acee462 4e97d7e2
......@@ -594,8 +594,11 @@ template<typename T> class GridderConfig
complex<T> wscreen(double x, double y, double w, bool adjoint) const
{
constexpr double pi = 3.141592653589793238462643383279502884197;
double n = cos(sqrt(x+y)), xn = 1./n;
double phase = 2*pi*w*(n-1);
double eps = sqrt(x+y);
double s = sin(eps);
double nm1 = -s*s/(1.+cos(eps));
double n = nm1+1., xn = 1./n;
double phase = 2*pi*w*nm1;
if (adjoint) phase *= -1;
return complex<T>(cos(phase)*xn, sin(phase)*xn);
}
......
......@@ -16,7 +16,6 @@ def _init_gridder(nxdirty, nydirty, epsilon, nchan, nrow):
speedoflight, f0 = 3e8, 1e9
freq = f0 + np.arange(nchan)*(f0/nchan)
uvw = (np.random.rand(nrow, 3)-0.5)/(pixsize*f0/speedoflight)
uvw[:, 2] = 0.
baselines = ng.Baselines(coord=uvw, freq=freq)
flags = np.zeros((nrow, nchan), dtype=np.bool)
idx = ng.getIndices(baselines, conf, flags)
......@@ -32,6 +31,10 @@ def _wscreen(npix, dst, w):
return wscreen
def _l2error(a, b):
return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2))
@pmp("nxdirty", (128, 300))
@pmp("nydirty", (128, 250))
@pmp("nrow", (1, 10, 10000))
......@@ -168,10 +171,11 @@ def test_pickling():
@pmp('nx', [4, 18, 54])
@pmp('dx', [1., 0.13, 132])
@pmp('fov', [0.1, 1, 5]) # deg
@pmp('w', [0, 10, 8489])
def test_wstacking(nx, dx, w):
def test_wstacking(nx, fov, w):
np.random.seed(42)
dx = fov*np.pi/180/nx
ny, dy = nx, dx
conf = ng.GridderConfig(nx, ny, 1e-7, dx, dy)
x, y = conf.Nu(), conf.Nv()
......@@ -232,3 +236,84 @@ def test_correlations(nxdirty, nydirty, nrow, nchan, epsilon, du, dv, weight):
ind = (pp[0]+du) % (2*nxdirty), (pp[1]+dv) % (2*nydirty)
assert_allclose(y0[pp], y1[ind].real)
assert_allclose(np.zeros_like(y1), y1.imag)
@pmp('epsilon', [1e-2, 1e-4, 1e-7, 1e-10, 1e-11, 1e-12, 2e-13])
@pmp('nxdirty', [12, 128])
@pmp('nydirty', [4, 12, 128])
@pmp("nrow", (10, 100))
@pmp("nchan", (1, 10))
def test_against_dft(nxdirty, nydirty, epsilon, nchan, nrow):
bl, conf, idx = _init_gridder(nxdirty, nydirty, epsilon, nchan, nrow)
ms = np.random.rand(nrow, nchan)-0.5 + 1j*(np.random.rand(nrow, nchan)-0.5)
vis = bl.ms2vis(ms, idx)
res0 = conf.grid2dirty(ng.vis2grid(bl, conf, idx, vis))
x, y = np.meshgrid(*[-ss/2 + np.arange(ss) for ss in [nxdirty, nydirty]],
indexing='ij')
x *= conf.Pixsize_x()
y *= conf.Pixsize_y()
res1 = np.zeros_like(res0)
uvw = bl.effectiveuvw(idx)
for ii in idx:
phase = x*uvw[ii, 0] + y*uvw[ii, 1]
res1 += (vis[ii]*np.exp(2j*np.pi*phase)).real
assert_(_l2error(res0, res1) < epsilon)
@pmp('nxdirty', [16, 64])
@pmp('nydirty', [64])
@pmp("nrow", (10, 100, 1000))
@pmp("nchan", (1, 10))
@pmp("fov", (1,))
def test_against_wdft(nxdirty, nydirty, nchan, nrow, fov):
epsilon = 1e-7
np.random.seed(40)
pixsize = fov*np.pi/180/nxdirty
conf = ng.GridderConfig(nxdirty=nxdirty,
nydirty=nydirty,
epsilon=epsilon,
pixsize_x=pixsize,
pixsize_y=pixsize)
speedoflight, f0 = 3e8, 1e9
freq = f0 + np.arange(nchan)*(f0/nchan)
uvw = (np.random.rand(nrow, 3)-0.5)/(pixsize*f0/speedoflight)
bl = ng.Baselines(coord=uvw, freq=freq)
flags = np.zeros((nrow, nchan), dtype=np.bool)
idx = ng.getIndices(bl, conf, flags)
ms = np.random.rand(nrow, nchan)-0.5 + 1j*(np.random.rand(nrow, nchan)-0.5)
vis = bl.ms2vis(ms, idx)
uvw = bl.effectiveuvw(idx)
res0 = np.zeros((nxdirty, nydirty))
mi, ma = np.min(uvw[:, 2]), np.max(uvw[:, 2])
nplanes = 10000
ws = mi + np.arange(nplanes)*(ma-mi)/(nplanes-1)
for ii in range(len(ws)-1):
wkp1 = ws[ii+1]
if ii == nplanes-2:
wkp1 += abs(wkp1)
jj = ng.getIndices(bl, conf, flags, wmin=ws[ii], wmax=wkp1)
if len(jj) == 0:
continue
dd = conf.grid2dirty_c(ng.vis2grid_c(bl, conf, jj, bl.ms2vis(ms, jj)))
wforplane = 0.5*(ws[ii+1]+ws[ii])
res0 += conf.apply_wscreen(dd, wforplane, adjoint=False).real
# Compute dft with w term
x, y = np.meshgrid(*[-ss/2 + np.arange(ss) for ss in [nxdirty, nydirty]],
indexing='ij')
x *= conf.Pixsize_x()
y *= conf.Pixsize_y()
res1 = np.zeros_like(res0)
eps = np.sqrt(x**2+y**2)
s = np.sin(eps)
nm1 = -s*s/(1+np.cos(eps))
n = nm1+1
for ii in idx:
phase = x*uvw[ii, 0] + y*uvw[ii, 1] + uvw[ii, 2]*nm1
res1 += (vis[ii]*np.exp(2j*np.pi*phase)).real
res1 /= n
assert_(_l2error(res0, res1) < 1e-4)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment