Commit c26a26ea authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'gridder_mr' into 'NIFTy_5'

Gridder mr

See merge request !311
parents 214244d7 ed2787dc
Pipeline #47208 passed with stages
in 18 minutes and 54 seconds
......@@ -14,6 +14,7 @@ RUN apt-get update && apt-get install -y \
# more optional NIFTy dependencies
&& pip3 install pyfftw \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git \
&& pip3 install jupyter \
&& rm -rf /var/lib/apt/lists/*
......
......@@ -52,6 +52,8 @@ Optional dependencies:
- [pyFFTW](https://pypi.python.org/pypi/pyFFTW) for faster Fourier transforms
- [pyHealpix](https://gitlab.mpcdf.mpg.de/ift/pyHealpix) (for harmonic
transforms involving domains on the sphere)
- [nifty_gridder](https://gitlab.mpcdf.mpg.de/ift/nifty_gridder) (for radio
interferometry responses)
- [mpi4py](https://mpi4py.scipy.org) (for MPI-parallel execution)
- [matplotlib](https://matplotlib.org/) (for field plotting)
......@@ -97,6 +99,10 @@ Support for spherical harmonic transforms is added via:
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
Support for the radio interferometry gridder is added via:
pip3 install git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git
MPI support is added via:
sudo apt-get install python3-mpi4py
......
from time import time
import matplotlib.pyplot as plt
import numpy as np
import nifty5 as ift
ift.fft.enable_fftw()
np.random.seed(40)
N0s, a0s, b0s, c0s = [], [], [], []
N1s, a1s, b1s, c1s = [], [], [], []
for ii in range(10, 23):
nu = 1024
nv = 1024
N = int(2**ii)
print('N = {}'.format(N))
uv = np.random.rand(N, 2) - 0.5
vis = np.random.randn(N) + 1j*np.random.randn(N)
uvspace = ift.RGSpace((nu, nv))
visspace = ift.UnstructuredDomain(N)
img = np.random.randn(nu*nv)
img = img.reshape((nu, nv))
img = ift.from_global_data(uvspace, img)
t0 = time()
GM = ift.GridderMaker(uvspace, eps=1e-7)
idx = GM.getReordering(uv)
uv = uv[idx]
vis = vis[idx]
vis = ift.from_global_data(visspace, vis)
op = GM.getFull(uv).adjoint
t1 = time()
op(img).to_global_data()
t2 = time()
op.adjoint(vis).to_global_data()
t3 = time()
N0s.append(N)
a0s.append(t1 - t0)
b0s.append(t2 - t1)
c0s.append(t3 - t2)
t0 = time()
op = ift.NFFT(uvspace, uv)
t1 = time()
op(img).to_global_data()
t2 = time()
op.adjoint(vis).to_global_data()
t3 = time()
N1s.append(N)
a1s.append(t1 - t0)
b1s.append(t2 - t1)
c1s.append(t3 - t2)
print('Measure rest operator')
sc = ift.StatCalculator()
op = GM.getRest().adjoint
for _ in range(10):
t0 = time()
res = op(img)
sc.add(time() - t0)
t_fft = sc.mean
print('FFT shape', res.shape)
plt.scatter(N0s, a0s, label='Gridder mr')
plt.scatter(N1s, a1s, marker='^', label='NFFT')
plt.legend()
# no idea why this is necessary, but if it is omitted, the range is wrong
plt.ylim(min(a0s+a1s), max(a0s+a1s))
plt.ylabel('time [s]')
plt.title('Initialization')
plt.loglog()
plt.savefig('bench0.png')
plt.close()
plt.scatter(N0s, b0s, color='k', marker='^', label='Gridder mr times')
plt.scatter(N1s, b1s, color='r', marker='^', label='NFFT times')
plt.scatter(N0s, c0s, color='k', label='Gridder mr adjoint times')
plt.scatter(N1s, c1s, color='r', label='NFFT adjoint times')
plt.axhline(sc.mean, label='FFT')
plt.axhline(sc.mean + np.sqrt(sc.var))
plt.axhline(sc.mean - np.sqrt(sc.var))
plt.legend()
plt.ylabel('time [s]')
plt.title('Apply')
plt.loglog()
plt.savefig('bench1.png')
plt.close()
......@@ -35,6 +35,10 @@ Support for spherical harmonic transforms is added via::
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
Support for the radio interferometry gridder is added via:
pip3 install git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git
MPI support is added via::
sudo apt-get install python3-mpi4py
......
......@@ -87,6 +87,7 @@ from .library.correlated_fields import CorrelatedField, MfCorrelatedField
from .library.adjust_variances import (make_adjust_variances_hamiltonian,
do_adjust_variances)
from .library.nfft import NFFT
from .library.gridder import GridderMaker
from . import extra
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..domains.unstructured_domain import UnstructuredDomain
from ..fft import hartley
from ..operators.linear_operator import LinearOperator
from ..sugar import from_global_data, makeDomain
class GridderMaker(object):
def __init__(self, domain, eps=1e-15):
domain = makeDomain(domain)
if (len(domain) != 1 or not isinstance(domain[0], RGSpace) or
not len(domain.shape) == 2):
raise ValueError("need domain with exactly one 2D RGSpace")
nu, nv = domain.shape
if nu % 2 != 0 or nv % 2 != 0:
raise ValueError("dimensions must be even")
rat = 3 if eps < 1e-11 else 2
nu2, nv2 = rat*nu, rat*nv
nspread = int(-np.log(eps)/(np.pi*(rat-1)/(rat-.5)) + .5) + 1
nu2 = max([nu2, 2*nspread])
nv2 = max([nv2, 2*nspread])
r2lamb = rat*rat*nspread/(rat*(rat-.5))
oversampled_domain = RGSpace(
[nu2, nv2], distances=[1, 1], harmonic=False)
self._nspread = nspread
self._r2lamb = r2lamb
self._rest = _RestOperator(domain, oversampled_domain, r2lamb)
def getReordering(self, uv):
from nifty_gridder import peanoindex
nu2, nv2 = self._rest._domain.shape
return peanoindex(uv, nu2, nv2)
def getGridder(self, uv):
return RadioGridder(self._rest.domain, self._nspread, self._r2lamb, uv)
def getRest(self):
return self._rest
def getFull(self, uv):
return self.getRest() @ self.getGridder(uv)
class _RestOperator(LinearOperator):
def __init__(self, domain, oversampled_domain, r2lamb):
self._domain = makeDomain(oversampled_domain)
self._target = domain
nu, nv = domain.shape
nu2, nv2 = oversampled_domain.shape
# compute deconvolution operator
rng = np.arange(nu)
k = np.minimum(rng, nu-rng)
c = np.pi*r2lamb/nu2**2
self._deconv_u = np.roll(np.exp(c*k**2), -nu//2).reshape((-1, 1))
rng = np.arange(nv)
k = np.minimum(rng, nv-rng)
c = np.pi*r2lamb/nv2**2
self._deconv_v = np.roll(
np.exp(c*k**2)/r2lamb, -nv//2).reshape((1, -1))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
nu, nv = self._target.shape
res = x.to_global_data()
if mode == self.TIMES:
res = hartley(res)
res = np.roll(res, (nu//2, nv//2), axis=(0, 1))
res = res[:nu, :nv]
res *= self._deconv_u
res *= self._deconv_v
else:
res = res*self._deconv_u
res *= self._deconv_v
nu2, nv2 = self._domain.shape
res = np.pad(res, ((0, nu2-nu), (0, nv2-nv)), mode='constant',
constant_values=0)
res = np.roll(res, (-nu//2, -nv//2), axis=(0, 1))
res = hartley(res)
return from_global_data(self._tgt(mode), res)
class RadioGridder(LinearOperator):
def __init__(self, target, nspread, r2lamb, uv):
self._domain = DomainTuple.make(
UnstructuredDomain((uv.shape[0],)))
self._target = DomainTuple.make(target)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._nspread, self._r2lamb = int(nspread), float(r2lamb)
self._uv = uv # FIXME: should we write-protect this?
def apply(self, x, mode):
from nifty_gridder import (to_grid, to_grid_post,
from_grid, from_grid_pre)
self._check_input(x, mode)
nu2, nv2 = self._target.shape
x = x.to_global_data()
if mode == self.TIMES:
res = to_grid(self._uv, x, nu2, nv2, self._nspread, self._r2lamb)
res = to_grid_post(res)
else:
x = from_grid_pre(x)
res = from_grid(self._uv, x, nu2, nv2, self._nspread, self._r2lamb)
return from_global_data(self._tgt(mode), res)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import pytest
from numpy.testing import assert_allclose
import nifty5 as ift
np.random.seed(40)
pmp = pytest.mark.parametrize
@pmp('nu', [12, 128])
@pmp('nv', [4, 12, 128])
@pmp('N', [1, 10, 100])
def test_gridding(nu, nv, N):
uv = np.random.rand(N, 2) - 0.5
vis = np.random.randn(N) + 1j*np.random.randn(N)
# Nifty
GM = ift.GridderMaker(ift.RGSpace((nu, nv)))
# re-order for performance
idx = GM.getReordering(uv)
uv, vis = uv[idx], vis[idx]
vis2 = ift.from_global_data(ift.UnstructuredDomain(vis.shape), vis)
Op = GM.getFull(uv)
pynu = Op(vis2).to_global_data()
# DFT
x, y = np.meshgrid(
*[-ss/2 + np.arange(ss) for ss in [nu, nv]], indexing='ij')
dft = pynu*0.
for i in range(N):
dft += (vis[i]*np.exp(2j*np.pi*(x*uv[i, 0] + y*uv[i, 1]))).real
assert_allclose(dft, pynu)
@pmp('eps', [1e-2, 1e-6, 1e-15])
@pmp('nu', [12, 128])
@pmp('nv', [4, 12, 128])
@pmp('N', [1, 10, 100])
def test_build(nu, nv, N, eps):
dom = ift.RGSpace([nu, nv])
uv = np.random.rand(N, 2) - 0.5
GM = ift.GridderMaker(dom)
# re-order for performance
idx = GM.getReordering(uv)
uv = uv[idx]
R0 = GM.getGridder(uv)
R1 = GM.getRest()
R = R1@R0
RF = GM.getFull(uv)
# Consistency checks
flt = np.float64
cmplx = np.complex128
ift.extra.consistency_check(R0, cmplx, flt, only_r_linear=True)
ift.extra.consistency_check(R1, flt, flt)
ift.extra.consistency_check(R, cmplx, flt, only_r_linear=True)
ift.extra.consistency_check(RF, cmplx, flt, only_r_linear=True)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment