Commit 0360e6ca authored by Martin Reinecke's avatar Martin Reinecke

Merge remote-tracking branch 'origin/NIFTy_5' into pocketfft

parents d1aba839 157402e6
Pipeline #47602 passed with stages
in 17 minutes and 1 second
...@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y \ ...@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y \
# Testing dependencies # Testing dependencies
python3-pytest-cov jupyter \ python3-pytest-cov jupyter \
# Optional NIFTy dependencies # Optional NIFTy dependencies
libfftw3-dev python3-mpi4py python3-matplotlib python3-pynfft \ libfftw3-dev python3-mpi4py python3-matplotlib \
# more optional NIFTy dependencies # more optional NIFTy dependencies
&& pip3 install pyfftw \ && pip3 install pyfftw \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git \ && pip3 install git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git \
......
...@@ -8,7 +8,6 @@ import nifty5 as ift ...@@ -8,7 +8,6 @@ import nifty5 as ift
np.random.seed(40) np.random.seed(40)
N0s, a0s, b0s, c0s = [], [], [], [] N0s, a0s, b0s, c0s = [], [], [], []
N1s, a1s, b1s, c1s = [], [], [], []
for ii in range(10, 23): for ii in range(10, 23):
nu = 1024 nu = 1024
...@@ -44,18 +43,6 @@ for ii in range(10, 23): ...@@ -44,18 +43,6 @@ for ii in range(10, 23):
b0s.append(t2 - t1) b0s.append(t2 - t1)
c0s.append(t3 - t2) c0s.append(t3 - t2)
t0 = time()
op = ift.NFFT(uvspace, uv)
t1 = time()
op(img).to_global_data()
t2 = time()
op.adjoint(vis).to_global_data()
t3 = time()
N1s.append(N)
a1s.append(t1 - t0)
b1s.append(t2 - t1)
c1s.append(t3 - t2)
print('Measure rest operator') print('Measure rest operator')
sc = ift.StatCalculator() sc = ift.StatCalculator()
op = GM.getRest().adjoint op = GM.getRest().adjoint
...@@ -67,10 +54,9 @@ t_fft = sc.mean ...@@ -67,10 +54,9 @@ t_fft = sc.mean
print('FFT shape', res.shape) print('FFT shape', res.shape)
plt.scatter(N0s, a0s, label='Gridder mr') plt.scatter(N0s, a0s, label='Gridder mr')
plt.scatter(N1s, a1s, marker='^', label='NFFT')
plt.legend() plt.legend()
# no idea why this is necessary, but if it is omitted, the range is wrong # no idea why this is necessary, but if it is omitted, the range is wrong
plt.ylim(min(a0s+a1s), max(a0s+a1s)) plt.ylim(min(a0s), max(a0s))
plt.ylabel('time [s]') plt.ylabel('time [s]')
plt.title('Initialization') plt.title('Initialization')
plt.loglog() plt.loglog()
...@@ -78,9 +64,7 @@ plt.savefig('bench0.png') ...@@ -78,9 +64,7 @@ plt.savefig('bench0.png')
plt.close() plt.close()
plt.scatter(N0s, b0s, color='k', marker='^', label='Gridder mr times') plt.scatter(N0s, b0s, color='k', marker='^', label='Gridder mr times')
plt.scatter(N1s, b1s, color='r', marker='^', label='NFFT times')
plt.scatter(N0s, c0s, color='k', label='Gridder mr adjoint times') plt.scatter(N0s, c0s, color='k', label='Gridder mr adjoint times')
plt.scatter(N1s, c1s, color='r', label='NFFT adjoint times')
plt.axhline(sc.mean, label='FFT') plt.axhline(sc.mean, label='FFT')
plt.axhline(sc.mean + np.sqrt(sc.var)) plt.axhline(sc.mean + np.sqrt(sc.var))
plt.axhline(sc.mean - np.sqrt(sc.var)) plt.axhline(sc.mean - np.sqrt(sc.var))
......
...@@ -86,7 +86,6 @@ from .library.wiener_filter_curvature import WienerFilterCurvature ...@@ -86,7 +86,6 @@ from .library.wiener_filter_curvature import WienerFilterCurvature
from .library.correlated_fields import CorrelatedField, MfCorrelatedField from .library.correlated_fields import CorrelatedField, MfCorrelatedField
from .library.adjust_variances import (make_adjust_variances_hamiltonian, from .library.adjust_variances import (make_adjust_variances_hamiltonian,
do_adjust_variances) do_adjust_variances)
from .library.nfft import NFFT
from .library.gridder import GridderMaker from .library.gridder import GridderMaker
from . import extra from . import extra
......
...@@ -103,7 +103,9 @@ class GLSpace(StructuredDomain): ...@@ -103,7 +103,9 @@ class GLSpace(StructuredDomain):
The partner domain The partner domain
""" """
from ..domains.lm_space import LMSpace from ..domains.lm_space import LMSpace
return LMSpace(lmax=self._nlat-1, mmax=self._nlon//2) mmax = self._nlon//2
lmax = max(mmax, self._nlat-1)
return LMSpace(lmax=lmax, mmax=mmax)
def check_codomain(self, codomain): def check_codomain(self, codomain):
"""Raises `TypeError` if `codomain` is not a matching partner domain """Raises `TypeError` if `codomain` is not a matching partner domain
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2018-2019 Max-Planck-Society
#
# Resolve is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import nifty5 as ift
class NFFT(ift.LinearOperator):
"""Performs a non-equidistant Fourier transform, i.e. a Fourier transform
followed by a degridding operation.
Parameters
----------
domain : RGSpace
Domain of the operator. It has to be two-dimensional and have shape
`(2N, 2N)`. The coordinates of the lower left pixel of the dirty image
are `(-N,-N)`, and of the upper right pixel `(N-1,N-1)`.
uv : numpy.ndarray
2D numpy array of type float64 and shape (M,2), where M is the number
of measurements. uv[i,0] and uv[i,1] contain the u and v coordinates
of measurement #i, respectively. All coordinates must lie in the range
`[-0.5; 0,5[`.
"""
def __init__(self, domain, uv):
from pynfft.nfft import NFFT
npix = domain.shape[0]
assert npix == domain.shape[1]
assert len(domain.shape) == 2
assert type(npix) == int, "npix must be integer"
assert npix > 0 and (
npix % 2) == 0, "npix must be an even, positive integer"
assert isinstance(uv, np.ndarray), "uv must be a Numpy array"
assert uv.dtype == np.float64, "uv must be an array of float64"
assert uv.ndim == 2, "uv must be a 2D array"
assert uv.shape[0] > 0, "at least one point needed"
assert uv.shape[1] == 2, "the second dimension of uv must be 2"
assert np.all(uv >= -0.5) and np.all(uv <= 0.5),\
"all coordinates must lie between -0.5 and 0.5"
self._domain = ift.DomainTuple.make(domain)
self._target = ift.DomainTuple.make(
ift.UnstructuredDomain(uv.shape[0]))
self._capability = self.TIMES | self.ADJOINT_TIMES
self.npt = uv.shape[0]
self.plan = NFFT(self.domain.shape, self.npt, m=6)
self.plan.x = uv
self.plan.precompute()
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
self.plan.f_hat = x.to_global_data()
res = self.plan.trafo().copy()
else:
self.plan.f = x.to_global_data()
res = self.plan.adjoint().copy()
return ift.Field.from_global_data(self._tgt(mode), res)
...@@ -191,7 +191,12 @@ class PoissonianEnergy(EnergyOperator): ...@@ -191,7 +191,12 @@ class PoissonianEnergy(EnergyOperator):
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
res = x.sum() - x.log().vdot(self._d) res = x.sum()
tmp = res.val.local_data if isinstance(res, Linearization) else res
# if we have no infinity here, we can continue with the calculation;
# otherwise we know that the result must also be infinity
if not np.isinf(tmp):
res = res - x.log().vdot(self._d)
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return Field.scalar(res) return Field.scalar(res)
if not x.want_metric: if not x.want_metric:
......
...@@ -147,7 +147,16 @@ class FieldAdapter(LinearOperator): ...@@ -147,7 +147,16 @@ class FieldAdapter(LinearOperator):
return MultiField(self._tgt(mode), (x,)) return MultiField(self._tgt(mode), (x,))
def __repr__(self): def __repr__(self):
return 'FieldAdapter' s = 'FieldAdapter'
dom = isinstance(self._domain, MultiDomain)
tgt = isinstance(self._target, MultiDomain)
if dom and tgt:
s += ' {} <- {}'.format(self._target.keys(), self._domain.keys())
elif dom:
s += ' <- {}'.format(self._domain.keys())
elif tgt:
s += ' {} <-'.format(self._target.keys())
return s
class _SlowFieldAdapter(LinearOperator): class _SlowFieldAdapter(LinearOperator):
......
...@@ -293,10 +293,3 @@ def testValueInserter(sp, seed): ...@@ -293,10 +293,3 @@ def testValueInserter(sp, seed):
ind.append(np.random.randint(0, ss-1)) ind.append(np.random.randint(0, ss-1))
op = ift.ValueInserter(sp, ind) op = ift.ValueInserter(sp, ind)
ift.extra.consistency_check(op) ift.extra.consistency_check(op)
def testNFFT():
dom = ift.RGSpace(2*(16,))
uv = np.array([[.2, .4], [-.22, .452]])
op = ift.NFFT(dom, uv)
ift.extra.consistency_check(op)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment