Commit a439a47c authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'new_gridder' into 'NIFTy_5'

New gridder (again!)

See merge request !324
parents afbc22bc b667dc47
Pipeline #50416 passed with stages
in 19 minutes
...@@ -10,11 +10,11 @@ RUN apt-get update && apt-get install -y \ ...@@ -10,11 +10,11 @@ RUN apt-get update && apt-get install -y \
# Testing dependencies # Testing dependencies
python3-pytest-cov jupyter \ python3-pytest-cov jupyter \
# Optional NIFTy dependencies # Optional NIFTy dependencies
libfftw3-dev python3-mpi4py python3-matplotlib \ python3-mpi4py python3-matplotlib \
# more optional NIFTy dependencies # more optional NIFTy dependencies
&& pip3 install pyfftw \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git \ && pip3 install git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git \ && pip3 install git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/mtr/pypocketfft.git \
&& pip3 install jupyter \ && pip3 install jupyter \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
......
...@@ -47,9 +47,9 @@ Installation ...@@ -47,9 +47,9 @@ Installation
- [Python 3](https://www.python.org/) (3.5.x or later) - [Python 3](https://www.python.org/) (3.5.x or later)
- [SciPy](https://www.scipy.org/) - [SciPy](https://www.scipy.org/)
- [pypocketfft](https://gitlab.mpcdf.mpg.de/mtr/pypocketfft)
Optional dependencies: Optional dependencies:
- [pyFFTW](https://pypi.python.org/pypi/pyFFTW) for faster Fourier transforms
- [pyHealpix](https://gitlab.mpcdf.mpg.de/ift/pyHealpix) (for harmonic - [pyHealpix](https://gitlab.mpcdf.mpg.de/ift/pyHealpix) (for harmonic
transforms involving domains on the sphere) transforms involving domains on the sphere)
- [nifty_gridder](https://gitlab.mpcdf.mpg.de/ift/nifty_gridder) (for radio - [nifty_gridder](https://gitlab.mpcdf.mpg.de/ift/nifty_gridder) (for radio
...@@ -73,28 +73,12 @@ NIFTy5 and its mandatory dependencies can be installed via: ...@@ -73,28 +73,12 @@ NIFTy5 and its mandatory dependencies can be installed via:
sudo apt-get install git python3 python3-pip python3-dev sudo apt-get install git python3 python3-pip python3-dev
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/nifty.git@NIFTy_5 pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/nifty.git@NIFTy_5
pip3 install --user git+https://gitlab.mpcdf.mpg.de/mtr/pypocketfft
Plotting support is added via: Plotting support is added via:
sudo apt-get install python3-matplotlib sudo apt-get install python3-matplotlib
NIFTy uses Numpy's FFT implementation by default. For large problems FFTW may be
used because of its higher performance. It can be installed via:
sudo apt-get install libfftw3-dev
pip3 install --user pyfftw
To enable FFTW usage in NIFTy, call
nifty5.fft.enable_fftw()
at the beginning of your code.
(Note: If you encounter problems related to `pyFFTW`, make sure that you are
using a pip-installed `pyFFTW` package. Unfortunately, some distributions are
shipping an incorrectly configured `pyFFTW` package, which does not cooperate
with the installed `FFTW3` libraries.)
Support for spherical harmonic transforms is added via: Support for spherical harmonic transforms is added via:
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
......
...@@ -5,7 +5,6 @@ import numpy as np ...@@ -5,7 +5,6 @@ import numpy as np
import nifty5 as ift import nifty5 as ift
ift.fft.enable_fftw()
np.random.seed(40) np.random.seed(40)
N0s, a0s, b0s, c0s = [], [], [], [] N0s, a0s, b0s, c0s = [], [], [], []
......
...@@ -9,28 +9,12 @@ NIFTy5 and its mandatory dependencies can be installed via:: ...@@ -9,28 +9,12 @@ NIFTy5 and its mandatory dependencies can be installed via::
sudo apt-get install git python3 python3-pip python3-dev sudo apt-get install git python3 python3-pip python3-dev
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/nifty.git@NIFTy_5 pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/nifty.git@NIFTy_5
pip3 install --user git+https://gitlab.mpcdf.mpg.de/mtr/pypocketfft
Plotting support is added via:: Plotting support is added via::
sudo apt-get install python3-matplotlib sudo apt-get install python3-matplotlib
NIFTy uses Numpy's FFT implementation by default. For large problems FFTW may be
used because of its higher performance. It can be installed via::
sudo apt-get install libfftw3-dev
pip3 install --user pyfftw
To enable FFTW usage in NIFTy, call::
nifty5.fft.enable_fftw()
at the beginning of your code.
(Note: If you encounter problems related to `pyFFTW`, make sure that you are
using a pip-installed `pyFFTW` package. Unfortunately, some distributions are
shipping an incorrectly configured `pyFFTW` package, which does not cooperate
with the installed `FFTW3` libraries.)
Support for spherical harmonic transforms is added via:: Support for spherical harmonic transforms is added via::
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
......
...@@ -17,8 +17,10 @@ ...@@ -17,8 +17,10 @@
import numpy as np import numpy as np
from .domain_tuple import DomainTuple
from .field import Field from .field import Field
from .linearization import Linearization from .linearization import Linearization
from .multi_domain import MultiDomain
from .operators.linear_operator import LinearOperator from .operators.linear_operator import LinearOperator
from .sugar import from_random from .sugar import from_random
...@@ -70,12 +72,20 @@ def _full_implementation(op, domain_dtype, target_dtype, atol, rtol, ...@@ -70,12 +72,20 @@ def _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
def _check_linearity(op, domain_dtype, atol, rtol): def _check_linearity(op, domain_dtype, atol, rtol):
fld1 = from_random("normal", op.domain, dtype=domain_dtype) fld1 = from_random("normal", op.domain, dtype=domain_dtype)
fld2 = from_random("normal", op.domain, dtype=domain_dtype) fld2 = from_random("normal", op.domain, dtype=domain_dtype)
alpha = np.random.random() alpha = np.random.random() # FIXME: this can break badly with MPI!
val1 = op(alpha*fld1+fld2) val1 = op(alpha*fld1+fld2)
val2 = alpha*op(fld1)+op(fld2) val2 = alpha*op(fld1)+op(fld2)
_assert_allclose(val1, val2, atol=atol, rtol=rtol) _assert_allclose(val1, val2, atol=atol, rtol=rtol)
def _domain_check(op):
for dd in [op.domain, op.target]:
if not isinstance(dd, (DomainTuple, MultiDomain)):
raise TypeError(
'The domain and the target of an operator need to',
'be instances of either DomainTuple or MultiDomain.')
def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64, def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
atol=0, rtol=1e-7, only_r_linear=False): atol=0, rtol=1e-7, only_r_linear=False):
""" """
...@@ -109,6 +119,7 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64, ...@@ -109,6 +119,7 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
""" """
if not isinstance(op, LinearOperator): if not isinstance(op, LinearOperator):
raise TypeError('This test tests only linear operators.') raise TypeError('This test tests only linear operators.')
_domain_check(op)
_check_linearity(op, domain_dtype, atol, rtol) _check_linearity(op, domain_dtype, atol, rtol)
_full_implementation(op, domain_dtype, target_dtype, atol, rtol, _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
only_r_linear) only_r_linear)
...@@ -162,6 +173,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100): ...@@ -162,6 +173,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100):
tol : float tol : float
Tolerance for the check. Tolerance for the check.
""" """
_domain_check(op)
for _ in range(ntries): for _ in range(ntries):
lin = op(Linearization.make_var(loc)) lin = op(Linearization.make_var(loc))
loc2, lin2 = _get_acceptable_location(op, loc, lin) loc2, lin2 = _get_acceptable_location(op, loc, lin)
......
...@@ -17,107 +17,52 @@ ...@@ -17,107 +17,52 @@
from .utilities import iscomplextype from .utilities import iscomplextype
import numpy as np import numpy as np
import pypocketfft
_nthreads = 1
_use_fftw = False
_fftw_prepped = False
_fft_extra_args = {}
def nthreads():
return _nthreads
def enable_fftw():
global _use_fftw
_use_fftw = True
def set_nthreads(nthr):
global _nthreads
_nthreads = nthr
def disable_fftw():
global _use_fftw
_use_fftw = False
# FIXME this should not be necessary ... no one should call a complex FFT
def _init_pyfftw(): # with a float array.
global _fft_extra_args, _fftw_prepped def _make_complex(a):
if not _fftw_prepped: if a.dtype in (np.complex64, np.complex128):
import pyfftw return a
from pyfftw.interfaces.numpy_fft import fftn, rfftn, ifftn if a.dtype == np.float64:
pyfftw.interfaces.cache.enable() return a.astype(np.complex128)
pyfftw.interfaces.cache.set_keepalive_time(1000.) if a.dtype == np.float32:
# Optional extra arguments for the FFT calls return a.astype(np.complex64)
# if exact reproducibility is needed, raise NotImplementedError
# set "planner_effort" to "FFTW_ESTIMATE"
import os
nthreads = int(os.getenv("OMP_NUM_THREADS", "1"))
_fft_extra_args = dict(planner_effort='FFTW_ESTIMATE',
threads=nthreads)
_fftw_prepped = True
def fftn(a, axes=None): def fftn(a, axes=None):
if _use_fftw: return pypocketfft.fftn(_make_complex(a), axes=axes, nthreads=_nthreads)
from pyfftw.interfaces.numpy_fft import fftn
_init_pyfftw()
return fftn(a, axes=axes, **_fft_extra_args)
else:
return np.fft.fftn(a, axes=axes)
def rfftn(a, axes=None): def rfftn(a, axes=None):
if _use_fftw: return pypocketfft.rfftn(a, axes=axes, nthreads=_nthreads)
from pyfftw.interfaces.numpy_fft import rfftn
_init_pyfftw()
return rfftn(a, axes=axes, **_fft_extra_args)
else:
return np.fft.rfftn(a, axes=axes)
def ifftn(a, axes=None): def ifftn(a, axes=None):
if _use_fftw: # FIXME this is a temporary fix and can be done more elegantly
from pyfftw.interfaces.numpy_fft import ifftn if axes is None:
_init_pyfftw() fct = 1./a.size
return ifftn(a, axes=axes, **_fft_extra_args)
else: else:
return np.fft.ifftn(a, axes=axes) fct = 1./np.prod(np.take(a.shape, axes))
return pypocketfft.ifftn(_make_complex(a), axes=axes, fct=fct,
nthreads=_nthreads)
def hartley(a, axes=None): def hartley(a, axes=None):
# Check if the axes provided are valid given the shape return pypocketfft.hartley2(a, axes=axes, nthreads=_nthreads)
if axes is not None and \
not all(axis < len(a.shape) for axis in axes):
raise ValueError("Provided axes do not match array shape")
if iscomplextype(a.dtype):
raise TypeError("Hartley transform requires real-valued arrays.")
tmp = rfftn(a, axes=axes)
def _fill_array(tmp, res, axes):
if axes is None:
axes = tuple(range(tmp.ndim))
lastaxis = axes[-1]
ntmplast = tmp.shape[lastaxis]
slice1 = (slice(None),)*lastaxis + (slice(0, ntmplast),)
np.add(tmp.real, tmp.imag, out=res[slice1])
def _fill_upper_half(tmp, res, axes):
lastaxis = axes[-1]
nlast = res.shape[lastaxis]
ntmplast = tmp.shape[lastaxis]
nrem = nlast - ntmplast
slice1 = [slice(None)]*lastaxis + [slice(ntmplast, None)]
slice2 = [slice(None)]*lastaxis + [slice(nrem, 0, -1)]
for i in axes[:-1]:
slice1[i] = slice(1, None)
slice2[i] = slice(None, 0, -1)
slice1 = tuple(slice1)
slice2 = tuple(slice2)
np.subtract(tmp[slice2].real, tmp[slice2].imag, out=res[slice1])
for i, ax in enumerate(axes[:-1]):
dim1 = (slice(None),)*ax + (slice(0, 1),)
axes2 = axes[:i] + axes[i+1:]
_fill_upper_half(tmp[dim1], res[dim1], axes2)
_fill_upper_half(tmp, res, axes)
return res
return _fill_array(tmp, np.empty_like(a), axes)
# Do a real-to-complex forward FFT and return the _full_ output array # Do a real-to-complex forward FFT and return the _full_ output array
......
...@@ -26,7 +26,8 @@ from ..sugar import from_global_data, makeDomain ...@@ -26,7 +26,8 @@ from ..sugar import from_global_data, makeDomain
class GridderMaker(object): class GridderMaker(object):
def __init__(self, domain, eps=1e-15): def __init__(self, domain, eps=2e-13):
from nifty_gridder import get_w
domain = makeDomain(domain) domain = makeDomain(domain)
if (len(domain) != 1 or not isinstance(domain[0], RGSpace) or if (len(domain) != 1 or not isinstance(domain[0], RGSpace) or
not len(domain.shape) == 2): not len(domain.shape) == 2):
...@@ -34,20 +35,17 @@ class GridderMaker(object): ...@@ -34,20 +35,17 @@ class GridderMaker(object):
nu, nv = domain.shape nu, nv = domain.shape
if nu % 2 != 0 or nv % 2 != 0: if nu % 2 != 0 or nv % 2 != 0:
raise ValueError("dimensions must be even") raise ValueError("dimensions must be even")
rat = 3 if eps < 1e-11 else 2 nu2, nv2 = 2*nu, 2*nv
nu2, nv2 = rat*nu, rat*nv w = get_w(eps)
nsafe = (w+1)//2
nspread = int(-np.log(eps)/(np.pi*(rat-1)/(rat-.5)) + .5) + 1 nu2 = max([nu2, 2*nsafe])
nu2 = max([nu2, 2*nspread]) nv2 = max([nv2, 2*nsafe])
nv2 = max([nv2, 2*nspread])
r2lamb = rat*rat*nspread/(rat*(rat-.5))
oversampled_domain = RGSpace( oversampled_domain = RGSpace(
[nu2, nv2], distances=[1, 1], harmonic=False) [nu2, nv2], distances=[1, 1], harmonic=False)
self._nspread = nspread self._eps = eps
self._r2lamb = r2lamb self._rest = _RestOperator(domain, oversampled_domain, eps)
self._rest = _RestOperator(domain, oversampled_domain, r2lamb)
def getReordering(self, uv): def getReordering(self, uv):
from nifty_gridder import peanoindex from nifty_gridder import peanoindex
...@@ -55,7 +53,7 @@ class GridderMaker(object): ...@@ -55,7 +53,7 @@ class GridderMaker(object):
return peanoindex(uv, nu2, nv2) return peanoindex(uv, nu2, nv2)
def getGridder(self, uv): def getGridder(self, uv):
return RadioGridder(self._rest.domain, self._nspread, self._r2lamb, uv) return RadioGridder(self._rest.domain, self._eps, uv)
def getRest(self): def getRest(self):
return self._rest return self._rest
...@@ -65,22 +63,22 @@ class GridderMaker(object): ...@@ -65,22 +63,22 @@ class GridderMaker(object):
class _RestOperator(LinearOperator): class _RestOperator(LinearOperator):
def __init__(self, domain, oversampled_domain, r2lamb): def __init__(self, domain, oversampled_domain, eps):
from nifty_gridder import correction_factors
self._domain = makeDomain(oversampled_domain) self._domain = makeDomain(oversampled_domain)
self._target = domain self._target = domain
nu, nv = domain.shape nu, nv = domain.shape
nu2, nv2 = oversampled_domain.shape nu2, nv2 = oversampled_domain.shape
fu = correction_factors(nu2, nu//2+1, eps)
fv = correction_factors(nv2, nv//2+1, eps)
# compute deconvolution operator # compute deconvolution operator
rng = np.arange(nu) rng = np.arange(nu)
k = np.minimum(rng, nu-rng) k = np.minimum(rng, nu-rng)
c = np.pi*r2lamb/nu2**2 self._deconv_u = np.roll(fu[k], -nu//2).reshape((-1, 1))
self._deconv_u = np.roll(np.exp(c*k**2), -nu//2).reshape((-1, 1))
rng = np.arange(nv) rng = np.arange(nv)
k = np.minimum(rng, nv-rng) k = np.minimum(rng, nv-rng)
c = np.pi*r2lamb/nv2**2 self._deconv_v = np.roll(fv[k], -nv//2).reshape((1, -1))
self._deconv_v = np.roll(
np.exp(c*k**2)/r2lamb, -nv//2).reshape((1, -1))
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
...@@ -105,24 +103,20 @@ class _RestOperator(LinearOperator): ...@@ -105,24 +103,20 @@ class _RestOperator(LinearOperator):
class RadioGridder(LinearOperator): class RadioGridder(LinearOperator):
def __init__(self, target, nspread, r2lamb, uv): def __init__(self, target, eps, uv):
self._domain = DomainTuple.make( self._domain = DomainTuple.make(
UnstructuredDomain((uv.shape[0],))) UnstructuredDomain((uv.shape[0],)))
self._target = DomainTuple.make(target) self._target = DomainTuple.make(target)
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
self._nspread, self._r2lamb = int(nspread), float(r2lamb) self._eps = float(eps)
self._uv = uv # FIXME: should we write-protect this? self._uv = uv # FIXME: should we write-protect this?
def apply(self, x, mode): def apply(self, x, mode):
from nifty_gridder import (to_grid, to_grid_post, from nifty_gridder import to_grid, from_grid
from_grid, from_grid_pre)
self._check_input(x, mode) self._check_input(x, mode)
nu2, nv2 = self._target.shape
x = x.to_global_data()
if mode == self.TIMES: if mode == self.TIMES:
res = to_grid(self._uv, x, nu2, nv2, self._nspread, self._r2lamb) nu2, nv2 = self._target.shape
res = to_grid_post(res) res = to_grid(self._uv, x.to_global_data(), nu2, nv2, self._eps)
else: else:
x = from_grid_pre(x) res = from_grid(self._uv, x.to_global_data(), self._eps)
res = from_grid(self._uv, x, nu2, nv2, self._nspread, self._r2lamb)
return from_global_data(self._tgt(mode), res) return from_global_data(self._tgt(mode), res)
...@@ -158,7 +158,7 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -158,7 +158,7 @@ class DiagonalOperator(EndomorphicOperator):
def process_sample(self, samp, from_inverse): def process_sample(self, samp, from_inverse):
if (self._complex or (self._diagmin < 0.) or if (self._complex or (self._diagmin < 0.) or
(self._diagmin == 0. and from_inverse)): (self._diagmin == 0. and from_inverse)):
raise ValueError("operator not positive definite") raise ValueError("operator not positive definite")
if from_inverse: if from_inverse:
res = samp.local_data/np.sqrt(self._ldiag) res = samp.local_data/np.sqrt(self._ldiag)
else: else:
......
...@@ -87,7 +87,7 @@ class ScalingOperator(EndomorphicOperator): ...@@ -87,7 +87,7 @@ class ScalingOperator(EndomorphicOperator):
fct = self._factor fct = self._factor
if (fct.imag != 0. or fct.real < 0. or if (fct.imag != 0. or fct.real < 0. or
(fct.real == 0. and from_inverse)): (fct.real == 0. and from_inverse)):
raise ValueError("operator not positive definite") raise ValueError("operator not positive definite")
return 1./np.sqrt(fct) if from_inverse else np.sqrt(fct) return 1./np.sqrt(fct) if from_inverse else np.sqrt(fct)
# def process_sample(self, samp, from_inverse): # def process_sample(self, samp, from_inverse):
......
...@@ -349,11 +349,10 @@ def _plot2D(f, ax, **kwargs): ...@@ -349,11 +349,10 @@ def _plot2D(f, ax, **kwargs):
rgb = _rgb_data(f.to_global_data()) rgb = _rgb_data(f.to_global_data())
have_rgb = True have_rgb = True
label = kwargs.pop("label", None)
foo = kwargs.pop("norm", None) foo = kwargs.pop("norm", None)
norm = {} if foo is None else {'norm': foo} norm = {} if foo is None else {'norm': foo}
aspect = kwargs.pop("aspect", None)
foo = kwargs.pop("aspect", None)
aspect = {} if foo is None else {'aspect': foo} aspect = {} if foo is None else {'aspect': foo}
ax.set_title(kwargs.pop("title", "")) ax.set_title(kwargs.pop("title", ""))
...@@ -424,7 +423,7 @@ def _plot(f, ax, **kwargs): ...@@ -424,7 +423,7 @@ def _plot(f, ax, **kwargs):
if len(f) == 0: if len(f) == 0:
raise ValueError("need something to plot") raise ValueError("need something to plot")
if not isinstance(f[0], Field): if not isinstance(f[0], Field):
raise TypeError("incorrect data type") raise TypeError("incorrect data type")
dom1 = f[0].domain dom1 = f[0].domain
if (len(dom1) == 1 and if (len(dom1) == 1 and
(isinstance(dom1[0], PowerSpace) or (isinstance(dom1[0], PowerSpace) or
......
...@@ -37,6 +37,8 @@ _pow_spaces = [ift.PowerSpace(ift.RGSpace((17, 38), harmonic=True))] ...@@ -37,6 +37,8 @@ _pow_spaces = [ift.PowerSpace(ift.RGSpace((17, 38), harmonic=True))]
pmp = pytest.mark.parametrize pmp = pytest.mark.parametrize
dtype = list2fixture([np.float64, np.complex128]) dtype = list2fixture([np.float64, np.complex128])
np.random.seed(42)
@pmp('sp', _p_RG_spaces) @pmp('sp', _p_RG_spaces)
def testLOSResponse(sp, dtype): def testLOSResponse(sp, dtype):
...@@ -75,14 +77,14 @@ def testLinearInterpolator(): ...@@ -75,14 +77,14 @@ def testLinearInterpolator():
def testRealizer(sp): def testRealizer(sp):
op = ift.Realizer(sp) op = ift.Realizer(sp)
ift.extra.consistency_check(op, np.complex128, np.float64, ift.extra.consistency_check(op, np.complex128, np.float64,
only_r_linear=True) only_r_linear=True)
@pmp('sp', _h_spaces + _p_spaces + _pow_spaces) @pmp('sp', _h_spaces + _p_spaces + _pow_spaces)
def testConjugationOperator(sp): def testConjugationOperator(sp):
op = ift.ConjugationOperator(sp) op = ift.ConjugationOperator(sp)
ift.extra.consistency_check(op, np.complex128, np.complex128, ift.extra.consistency_check(op, np.complex128, np.complex128,
only_r_linear=True) only_r_linear=True)
@pmp('args', [(ift.RGSpace(10, harmonic=True), 4, 0), (ift.RGSpace( @pmp('args', [(ift.RGSpace(10, harmonic=True), 4, 0), (ift.RGSpace(
......
...@@ -34,22 +34,10 @@ def _get_rtol(tp): ...@@ -34,22 +34,10 @@ def _get_rtol(tp):
pmp = pytest.mark.parametrize pmp = pytest.mark.parametrize
dtype = list2fixture([np.float64, np.float32, np.complex64, np.complex128]) dtype = list2fixture([np.float64, np.float32, np.complex64, np.complex128])
op = list2fixture([ift.HartleyOperator, ift.FFTOperator]) op = list2fixture([ift.HartleyOperator, ift.FFTOperator])
fftw = list2fixture([False, True])
def test_switch():
ift.fft.enable_fftw()
assert_(ift.fft._use_fftw is True)
ift.fft.disable_fftw()
assert_(ift.fft._use_fftw is False)
ift.fft.enable_fftw()
assert_(ift.fft._use_fftw is True)
@pmp('d', [0.1, 1, 3.7]) @pmp('d', [0.1, 1, 3.7])
def test_fft1D(d, dtype, op, fftw): def test_fft1D(d, dtype, op):
if fftw:
ift.fft.enable_fftw()
dim1 = 16 dim1 = 16
tol = _get_rtol(dtype) tol = _get_rtol(dtype)
a = ift.RGSpace(dim1, distances=d) a = ift.RGSpace(dim1, distances=d)
...@@ -69,16 +57,16 @@ def test_fft1D(d, dtype, op, fftw): ...@@ -69,16 +57,16 @@ def test_fft1D(d, dtype, op, fftw):
domain=a, random_type='normal', std=7, mean=3, dtype=dtype) domain=a, random_type='normal', std=7, mean=3, dtype=dtype)
out = fft.inverse_times(fft.times(inp)) out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol) assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
@pmp('dim1', [12, 15]) @pmp('dim1', [12, 15])
@pmp('dim2', [9, 12]) @pmp('dim2', [9, 12])
@pmp('d1', [0.1, 1, 3.7]) @pmp('d1', [0.1, 1, 3.7])
@pmp('d2', [0.4, 1, 2.7]) @pmp('d2', [0.4, 1, 2.7])
def test_fft2D(dim1, dim2, d1, d2, dtype, op, fftw): @pmp('nthreads', [0, 1, 2, 3, 4])
if fftw: def test_fft2D(dim1, dim2, d1, d2, dtype, op, nthreads):
ift.fft.enable_fftw() ift.fft.set_nthreads(nthreads)
assert_(ift.fft.nthreads() == nthreads)
tol = _get_rtol(dtype) tol = _get_rtol(dtype)
a = ift.RGSpace([dim1, dim2], distances=[d1, d2]) a = ift.RGSpace([dim1, dim2], distances=[d1, d2])
b = ift.RGSpace( b = ift.RGSpace(
...@@ -97,13 +85,11 @@ def test_fft2D(dim1, dim2, d1, d2, dtype, op, fftw): ...@@ -97,13 +85,11 @@ def test_fft2D(dim1, dim2, d1, d2, dtype, op, fftw):
domain=a, random_type='normal', std=7, mean=3, dtype=dtype) domain=a, random_type='normal', std=7, mean=3, dtype=dtype)
out = fft.inverse_times(fft.times(inp)) out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol) assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw() ift.fft.set_nthreads(1)
@pmp('index', [0, 1, 2]) @pmp('index', [0, 1, 2])
def test_composed_fft(index, dtype, op, fftw): def test_composed_fft(index, dtype, op):
if fftw:
ift.fft.enable_fftw()
tol = _get_rtol(dtype) tol = _get_rtol(dtype)
a = [a1, a2, a = [a1, a2,
a3] = [ift.RGSpace((32,)), a3] = [ift.RGSpace((32,)),
...@@ -115,7 +101,6 @@ def test_composed_fft(index, dtype, op, fftw): ...@@ -115,7 +101,6 @@ def test_composed_fft(index, dtype, op, fftw):
domain=(a1, a2, a3), random_type='normal', std=7, mean=3, dtype=dtype) domain=(a1, a2, a3), random_type='normal', std=7, mean=3, dtype=dtype)
out = fft.inverse_times(fft.times(inp)) out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol) assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
@pmp('space', [ @pmp('space', [
...@@ -123,9 +108,7 @@ def test_composed_fft(index, dtype, op, fftw): ...@@ -123,9 +108,7 @@ def test_composed_fft(index, dtype, op, fftw):
ift.RGSpace((15, 27), distances=(.7, .33), harmonic=True), ift.RGSpace((15, 27), distances=(.7, .33), harmonic=True),
ift.RGSpace(73, distances=0.5643) ift.RGSpace(73, distances=0.5643)
]) ])
def test_normalisation(space, dtype, op, fftw): def test_normalisation(space, dtype, op):
if fftw:
ift.fft.enable_fftw()
tol = 10*_get_rtol(dtype) tol = 10*_get_rtol(dtype)
cospace = space.get_default_codomain() cospace = space.get_default_codomain()
fft = op(space, cospace) fft = op(space, cospace)
...@@ -138,4 +121,3 @@ def test_normalisation(space, dtype, op, fftw): ...@@ -138,4 +121,3 @@ def test_normalisation(space, dtype, op, fftw):