Commit e14677b4 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'internal_double' into 'master'

Restructuring, tuning, bug fixes

See merge request !22
parents 5aa488ad 5c014080
Nifty gridder
Library for high-accuracy gridding/degridding of radio interferometry datasets
(Highly experimental pre-release version!)
Programming aspects
- written in C++11, fully portable
- shared-memory parallelization via OpenMP and C++ threads.
- Python interface available
- kernel computation is performed on the fly, avoiding inaccuracies
due to table lookup and reducing overall memory bandwidth
Numerical aspects
- uses the analytical gridding kernel presented in https://arxiv.org/abs/1808.06736
- uses the "improved W-stacking method" described in https://www.repository.cam.ac.uk/handle/1810/292298 (p. 139ff)
- in combination these two aspects allow extremely accurate gridding/degridding
operations (L2 error compared to explicit DFTs can go below 1e-12) with
reasonable resource consumption
Installation and prerequisites
- Clone the repository
- execute `pip3 install --user .`. This requires the g++ compiler.
`numpy` and `pybind11` should be installed automatically if necessary.
For the unit tests, `pytest` is required.
The real-data demo script additionally requires `casacore` and measurement set data.
import numpy as np
import nifty_gridder as ng
# Some generic nomenclature:
# (If this sounds completely stupid, it's because I have no radio background
# whatsoever; I'm more than happy to change this where needed!)
# nrow : integer (number of rows in a measurement set)
# nchan : integer (number of channels in a measurement set)
# baselines: an object containing
# - a float(nrow,3) array; these are the UVW coordinates in the measurement set
# - a float(nchan) array; this describes how the UVW need to be scaled for every channel
# gconf: an object containing information about a gridding setup,
# i.e. resolution of dirty image, requested accuracy etc.
# ms: a complex(nrow, nchan) array containing the visibilities of a measurement set
# Later on this will probably become (nrow, nchan, npol)
# flags: a bool(nrow, nchan) array. Where True, the corresponding visibilities
# will be ignored.
# idx: a 1D integer array containing indices of selected visibilities. One index
# ranges from 0 to nrow*nchan-1 and encodes row and channel number simultaneously
# to save space.
# vis: a 1D complex array, which is always accompanied by an "idx" array. It contains
# the visibilities of a "ms", extracted at "idx".
# grid: oversampled 2D grid in UV space onto which the visibilities are gridded
# (resolution is higher than that of the dirty image)
# dirty: float(nxdirty, nydirty): the dirty image
# ---------------------------------------------------------------------------
# Demo: grid mock visibilities onto a UV grid, convert to a dirty image and
# run an adjointness check of the gridding/degridding operators.
# ---------------------------------------------------------------------------

f0 = 1e9  # rough observation frequency
npixdirty = 1024
pixsize = np.pi/180/60/npixdirty  # assume 1 arcmin FOV
speedoflight = 3e8

# number of rows in the measurement set
nrow = 10000
# number of channels
nchan = 100

# Frequency for all channels in the measurement set [Hz].
freq = f0 + np.arange(nchan)*(f0/nchan)  # just use silly values for this example

# Invent mock UVW data. For every row, there is one UVW triple
# NOTE: I don't know how to set the w values properly, so I use zeros.
uvw = (np.random.rand(nrow, 3)-0.5) / (pixsize*f0/speedoflight)
uvw[:, 2] = 0.

# Build Baselines object from the geometrical information
baselines = ng.Baselines(coord=uvw, freq=freq)

# Build GridderConfig object describing how gridding should be done
# pixsize_x and pixsize_y are given in radians
gconf = ng.GridderConfig(nxdirty=npixdirty, nydirty=npixdirty, epsilon=1e-7,
                         pixsize_x=pixsize, pixsize_y=pixsize)

# At this point everything about the experimental setup is known and explained
# to the gridder. The only thing still missing is actual data, i.e. visibilities
# and flags.

# Invent mock flags. This is a bool array of shape (nrow, nchan).
# For this test we set it completely to False.
# The builtin `bool` is used as dtype: the `np.bool` alias was removed from
# NumPy and raises AttributeError on current releases.
flags = np.zeros((nrow, nchan), dtype=bool)

# Extract the indices for the subset of channels and w that we want to grid.
# The gconf object is needed here because knowing the gridding parameters helps
# to optimize the ordering of the returned indices.
# For parallel processing it is possible to create multiple index sets, each
# covering a different range of channels, or to generate all w slices in
# parallel.
# If the complete "flags" array does not fit into memory, we can adjust the
# interface: for example, we could just pass the flag sub-array that matches the
# selected channel range.
idx = ng.getIndices(baselines, gconf, flags)

# Invent mock visibilities. Currently we have just NPOL=1, resulting in
# (nrow, nchan) complex visibility values.
# Currently, the gridder code refers to this as "ms". Suggestions with more
# appropriate names are welcome!
ms = np.random.rand(nrow, nchan)-0.5 + 1j*(np.random.rand(nrow, nchan)-0.5)

# Extract the visibility data at the obtained indices from ms.
# For large-scale datasets where ms does not fit into memory, this needs to be
# done differently, but should still be straightforward.
vis = baselines.ms2vis(ms, idx)

# Perform the gridding.
# Many of these operations can be called in parallel with identical baselines
# and gconf arguments; they won't interfere with each other.
grid = ng.vis2grid(baselines, gconf, idx, vis)

# Convert gridded data in UV space to dirty image (i.e. FFT, cropping,
# multiplication with correction function)
dirty = gconf.grid2dirty(grid)

# Adjointness test: push a random image back through the reverse chain
# (dirty2grid -> grid2vis -> vis2ms) and compare the two inner products.
# The two printed numbers are expected to agree if the forward and reverse
# chains are adjoint to each other.
dirty2 = np.random.rand(*dirty.shape)
ms2 = baselines.vis2ms(
    ng.grid2vis(baselines, gconf, idx, gconf.dirty2grid(dirty2)), idx)
print(np.vdot(ms, ms2).real, np.vdot(dirty, dirty2))
import matplotlib.pyplot as plt
import nifty_gridder as ng
import numpy as np
def _l2error(a, b):
    """Return the relative L2 difference between *a* and *b*.

    The result is ``sqrt(sum(|a-b|**2) / sum(|a|**2))``, i.e. the L2 norm of
    the difference normalised by the L2 norm of *a* (the reference).
    """
    return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2))
def explicit_gridder(uvw, freq, ms, nxdirty, nydirty, xpixsize, ypixsize):
    """Compute a dirty image by direct summation over all visibilities.

    For every (row, channel) pair the real part of
    ``ms[row, chan] * exp(2j*pi*phase)`` is accumulated into the image, with
    ``phase = freq[chan]/c * (x*u + y*v - w*(n-1))``; the sum is finally
    divided by ``n = sqrt(1 - x**2 - y**2)``.

    Parameters
    ----------
    uvw : array indexed as uvw[row, 0..2]
    freq : array indexed as freq[chan]
    ms : array indexed as ms[row, chan]
    nxdirty, nydirty : int
        Image dimensions in pixels.
    xpixsize, ypixsize : float
        Pixel sizes used to scale the pixel indices into x and y.

    Returns
    -------
    Array of shape (nxdirty, nydirty).

    The cost is proportional to nrow*nchan*nxdirty*nydirty, so this is only
    usable for small problems.
    """
    speedoflight = 299792458.
    # indexing='ij' makes x and y have shape (nxdirty, nydirty), matching
    # `res`; the default 'xy' ordering would transpose them and break the
    # accumulation for non-square images.
    x, y = np.meshgrid(*[-ss/2 + np.arange(ss) for ss in [nxdirty, nydirty]],
                       indexing='ij')
    x *= xpixsize
    y *= ypixsize
    res = np.zeros((nxdirty, nydirty))
    eps = x**2+y**2
    # n-1 written in this form instead of evaluating sqrt(1-eps)-1 directly
    nm1 = -eps/(np.sqrt(1.-eps)+1.)
    n = nm1+1
    for row in range(ms.shape[0]):
        # Geometry term does not depend on the channel: compute once per row.
        geom = x*uvw[row, 0] + y*uvw[row, 1] - uvw[row, 2]*nm1
        for chan in range(ms.shape[1]):
            phase = freq[chan]/speedoflight * geom
            res += (ms[row, chan]*np.exp(2j*np.pi*phase)).real
    return res/n
# Builds random uvw coordinates, visibilities and a random dirty image, then
# (optionally) compares ng.ms2dirty against explicit_gridder and checks the
# adjointness of ng.ms2dirty / ng.dirty2ms.
# NOTE(review): this copy of the function is damaged: indentation is lost and
# several lines are cut off. The parameter list below is incomplete; the body
# reads `test_against_explicit` and both calls at the bottom pass an eighth
# positional argument, so that is presumably the missing parameter -- confirm
# against the original file.
def test_against_wdft(nrow, nchan, nxdirty, nydirty, fov, epsilon, nthreads,
print("\n\nTesting gridding/degridding with {} rows and {} "
"frequency channels".format(nrow, nchan))
print("Dirty image has {}x{} pixels, "
"FOV={} degrees".format(nxdirty, nydirty, fov))
print("Requested accuracy: {}".format(epsilon))
print("Number of threads: {}".format(nthreads))
speedoflight = 299792458.
# fov is given in degrees (see the message printed above); convert to
# radians per pixel.
xpixsize = fov*np.pi/180/nxdirty
ypixsize = fov*np.pi/180/nydirty
f0 = 1e9
freq = f0 + np.arange(nchan)*(f0/nchan)
# Random mock data: uvw coordinates, complex visibilities and a test image,
# each drawn uniformly and shifted to be centred on zero.
uvw = (np.random.rand(nrow, 3)-0.5)/(xpixsize*f0/speedoflight)
ms = np.random.rand(nrow, nchan)-0.5 + 1j*(np.random.rand(nrow, nchan)-0.5)
tdirty = np.random.rand(nxdirty, nydirty)-0.5
# Looser accuracy requests switch the inputs to single precision.
single = epsilon > 5e-6
if single:
print("\nCalling single-precision functions")
ms = ms.astype("c8")
# NOTE(review): "c8" makes tdirty complex64 although it was created as a real
# array -- confirm whether a real single-precision dtype was intended.
tdirty = tdirty.astype("c8")
# NOTE(review): an `else:` appears to be missing here; as written both
# messages would be printed in the single-precision case.
print("\nCalling double-precision functions")
if test_against_explicit:
print("\nTesting against explicit transform "
"(potentially VERY slow!)...")
truth = explicit_gridder(uvw, freq, ms, nxdirty, nydirty,
xpixsize, ypixsize)
res = ng.ms2dirty(uvw, freq, ms, None, nxdirty, nydirty, xpixsize,
ypixsize, epsilon, nthreads)
print("L2 error between explicit transform and gridder:",
_l2error(truth, res))
# test adjointness
print("\nTesting adjointness of the gridding/degridding operation")
# NOTE(review): the end of the adj1 expression is cut off in this copy.
adj1 = np.vdot(ng.ms2dirty(uvw, freq, ms, None, nxdirty, nydirty,
xpixsize, ypixsize, epsilon, nthreads,
adj2 = np.vdot(ms, ng.dirty2ms(uvw, freq, tdirty, None, xpixsize, ypixsize,
epsilon, nthreads, verbosity=2)).real
# Relative difference of the two inner products.
# NOTE(review): the end of this print call is cut off as well.
print("adjointness test:", np.abs(adj1-adj2)/np.maximum(np.abs(adj1),
# Driver calls: a small non-square case that is also compared against the
# explicit transform, and a larger case that skips that comparison.
test_against_wdft(100, 20, 64, 130, 0.5, 1e-12, 3, True)
test_against_wdft(1000, 300, 1024, 1024, 2., 1e-12, 4, False)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# Copyright(C) 2019 Max-Planck-Society
from os.path import join
from time import time
import matplotlib.pyplot as plt
import nifty_gridder as ng
import numpy as np
from casacore.tables import table
# ---------------------------------------------------------------------------
# Demo: read a measurement set with casacore, combine the two parallel-hand
# correlations and grid the result into a dirty image with ng.ms2dirty.
#
# Assumptions:
# - Only one field
# - Only one spectral window
# - Flag both LL and RR if one is flagged
# ---------------------------------------------------------------------------

# Path of the measurement set to read; must be filled in before running.
name = ''

t = table(name, readonly=True)
uvw = t.getcol("UVW")  # [:, uvw]
ms = np.array(t.getcol("DATA"), dtype=np.complex128)  # [:, ch, corr]
wgt = t.getcol("WEIGHT").astype("f8")
# Flag if one correlation is flagged.
# The builtin `bool` is used as dtype: the `np.bool` alias was removed from
# NumPy and raises AttributeError on current releases.
flags = np.any(np.array(t.getcol('FLAG'), bool), axis=2)  # [:, ch]

# Enforce the single-field / single-spectral-window assumptions.
if len(set(t.getcol('FIELD_ID'))) != 1:
    raise RuntimeError("measurement set must contain exactly one field")
if len(set(t.getcol('DATA_DESC_ID'))) != 1:
    raise RuntimeError(
        "measurement set must contain exactly one data description")

print('# Rows: {}'.format(ms.shape[0]))
print('# Channels: {}'.format(ms.shape[1]))
print('# Correlations: {}'.format(ms.shape[2]))
print("{} % flagged".format(np.sum(flags)/flags.size*100))

t = table(join(name, 'SPECTRAL_WINDOW'), readonly=True)
freq = t.getcol('CHAN_FREQ')[0]

# Select either RR+LL or XX+YY
t = table(join(name, 'POLARIZATION'), readonly=True)
pol = list(t.getcol('CORR_TYPE')[0])
# Codes 5..8 are taken as the first family (5 and 8 being its parallel
# hands); anything else is expected to provide codes 9 and 12.
# list.index raises ValueError if the expected codes are absent.
if set(pol) <= {5, 6, 7, 8}:
    ind = [pol.index(5), pol.index(8)]
else:
    ind = [pol.index(9), pol.index(12)]
ms = np.sum(ms[:, :, ind], axis=2)

# Combine the per-correlation weights of each row into one weight, broadcast
# it over all channels and zero it wherever the data are flagged.
# NOTE(review): the sum runs over all correlations in WEIGHT, not only the
# two selected above -- confirm this is intended.
wgt = 1/np.sum(1/wgt, axis=1)
wgt = np.repeat(wgt[:, None], len(freq), axis=1)
wgt[flags] = 0

npixdirty = 756
DEG2RAD = np.pi/180
pixsize = 2.3/npixdirty*DEG2RAD
nthreads = 4
epsilon = 6e-6

t0 = time()
print('Start gridding...')
# For loose accuracy requirements single-precision inputs are used.
if epsilon > 5e-6:
    ms = ms.astype("c8")
    wgt = wgt.astype("f4")
dirty = ng.ms2dirty(
    uvw, freq, ms, wgt, npixdirty, npixdirty, pixsize, pixsize, epsilon,
    do_wstacking=True, nthreads=nthreads, verbosity=2)
t = time() - t0
print("{} s".format(t))
# Throughput counts only visibilities with non-zero weight.
print("{} visibilities/thread/s".format(np.sum(wgt != 0)/nthreads/t))
This diff is collapsed.
This diff is collapsed.
......@@ -13,7 +13,8 @@ class _deferred_pybind11_include(object):
include_dirs = ['./', _deferred_pybind11_include(True),
extra_compile_args = []
extra_compile_args = ['-Wall', '-Wextra', '-Wfatal-errors', '-Wstrict-aliasing=2', '-Wwrite-strings', '-Wredundant-decls', '-Woverloaded-virtual', '-Wcast-qual', '-Wcast-align', '-Wpointer-arith', '-Wfloat-conversion']
#, '-Wsign-conversion', '-Wconversion'
python_module_link_args = []
if sys.platform == 'darwin':
......@@ -32,7 +33,8 @@ else:
def get_extension_modules():
return [Extension('nifty_gridder',
depends=['pocketfft_hdronly.h', ''],
depends=['pocketfft_hdronly.h', 'gridder_cxx.h',
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment