Commit 01e0bb11 authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge branch 'fft_tweaks' into 'NIFTy_5'

Make FFTW optional

See merge request ift/nifty-dev!178
parents 872bd200 09c17c1e
...@@ -39,9 +39,9 @@ Installation ...@@ -39,9 +39,9 @@ Installation
- [Python 3](https://www.python.org/) (3.5.x or later) - [Python 3](https://www.python.org/) (3.5.x or later)
- [SciPy](https://www.scipy.org/) - [SciPy](https://www.scipy.org/)
- [pyFFTW](https://pypi.python.org/pypi/pyFFTW)
Optional dependencies: Optional dependencies:
- [pyFFTW](https://pypi.python.org/pypi/pyFFTW) for faster Fourier transforms
- [pyHealpix](https://gitlab.mpcdf.mpg.de/ift/pyHealpix) (for harmonic - [pyHealpix](https://gitlab.mpcdf.mpg.de/ift/pyHealpix) (for harmonic
transforms involving domains on the sphere) transforms involving domains on the sphere)
- [mpi4py](https://mpi4py.scipy.org) (for MPI-parallel execution) - [mpi4py](https://mpi4py.scipy.org) (for MPI-parallel execution)
...@@ -61,18 +61,29 @@ distributions, the "apt" lines will need slight changes. ...@@ -61,18 +61,29 @@ distributions, the "apt" lines will need slight changes.
NIFTy5 and its mandatory dependencies can be installed via: NIFTy5 and its mandatory dependencies can be installed via:
sudo apt-get install git libfftw3-dev python3 python3-pip python3-dev sudo apt-get install git python3 python3-pip python3-dev
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_5 pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_5
Plotting support is added via:
pip3 install --user matplotlib
FFTW support is added via:
sudo apt-get install libfftw3-dev
pip3 install --user pyfftw
To actually use FFTW in your Nifty calculations, you need to call
nifty5.fft.enable_fftw()
at the beginning of your code.
(Note: If you encounter problems related to `pyFFTW`, make sure that you are (Note: If you encounter problems related to `pyFFTW`, make sure that you are
using a pip-installed `pyFFTW` package. Unfortunately, some distributions are using a pip-installed `pyFFTW` package. Unfortunately, some distributions are
shipping an incorrectly configured `pyFFTW` package, which does not cooperate shipping an incorrectly configured `pyFFTW` package, which does not cooperate
with the installed `FFTW3` libraries.) with the installed `FFTW3` libraries.)
Plotting support is added via:
pip3 install --user matplotlib
Support for spherical harmonic transforms is added via: Support for spherical harmonic transforms is added via:
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
...@@ -86,7 +97,7 @@ MPI support is added via: ...@@ -86,7 +97,7 @@ MPI support is added via:
To run the tests, additional packages are required: To run the tests, additional packages are required:
sudo apt-get install python3-coverage python3-parameterized python3-pytest python3-pytest-cov sudo apt-get install python3-coverage python3-pytest python3-pytest-cov
Afterwards the tests (including a coverage report) can be run using the Afterwards the tests (including a coverage report) can be run using the
following command in the repository root: following command in the repository root:
......
...@@ -7,18 +7,29 @@ distributions, the "apt" lines will need slight changes. ...@@ -7,18 +7,29 @@ distributions, the "apt" lines will need slight changes.
NIFTy5 and its mandatory dependencies can be installed via:: NIFTy5 and its mandatory dependencies can be installed via::
sudo apt-get install git libfftw3-dev python3 python3-pip python3-dev sudo apt-get install git python3 python3-pip python3-dev
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_5 pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_5
Plotting support is added via::
pip3 install --user matplotlib
FFTW support is added via:
sudo apt-get install libfftw3-dev
pip3 install --user pyfftw
To actually use FFTW in your Nifty calculations, you need to call
nifty5.fft.enable_fftw()
at the beginning of your code.
(Note: If you encounter problems related to `pyFFTW`, make sure that you are (Note: If you encounter problems related to `pyFFTW`, make sure that you are
using a pip-installed `pyFFTW` package. Unfortunately, some distributions are using a pip-installed `pyFFTW` package. Unfortunately, some distributions are
shipping an incorrectly configured `pyFFTW` package, which does not cooperate shipping an incorrectly configured `pyFFTW` package, which does not cooperate
with the installed `FFTW3` libraries.) with the installed `FFTW3` libraries.)
Plotting support is added via::
pip3 install --user matplotlib
Support for spherical harmonic transforms is added via:: Support for spherical harmonic transforms is added via::
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
......
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from functools import reduce from functools import reduce
from ..utilities import NiftyMetaBase from ..utilities import NiftyMeta
class Domain(NiftyMetaBase()): class Domain(metaclass=NiftyMeta):
"""The abstract class repesenting a (structured or unstructured) domain. """The abstract class repesenting a (structured or unstructured) domain.
""" """
def __repr__(self): def __repr__(self):
......
...@@ -19,23 +19,63 @@ from .utilities import iscomplextype ...@@ -19,23 +19,63 @@ from .utilities import iscomplextype
import numpy as np import numpy as np
_use_fftw = True _use_fftw = False
_fftw_prepped = False
_fft_extra_args = {}
if _use_fftw: def enable_fftw():
import pyfftw global _use_fftw
from pyfftw.interfaces.numpy_fft import fftn, rfftn, ifftn _use_fftw = True
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(1000.)
# Optional extra arguments for the FFT calls def disable_fftw():
# if exact reproducibility is needed, global _use_fftw
# set "planner_effort" to "FFTW_ESTIMATE" _use_fftw = False
import os
nthreads = int(os.getenv("OMP_NUM_THREADS", "1"))
_fft_extra_args = dict(planner_effort='FFTW_ESTIMATE', threads=nthreads) def _init_pyfftw():
else: global _fft_extra_args, _fftw_prepped
from numpy.fft import fftn, rfftn, ifftn if not _fftw_prepped:
_fft_extra_args = {} import pyfftw
from pyfftw.interfaces.numpy_fft import fftn, rfftn, ifftn
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(1000.)
# Optional extra arguments for the FFT calls
# if exact reproducibility is needed,
# set "planner_effort" to "FFTW_ESTIMATE"
import os
nthreads = int(os.getenv("OMP_NUM_THREADS", "1"))
_fft_extra_args = dict(planner_effort='FFTW_ESTIMATE',
threads=nthreads)
_fftw_prepped = True
def fftn(a, axes=None):
if _use_fftw:
from pyfftw.interfaces.numpy_fft import fftn
_init_pyfftw()
return fftn(a, axes=axes, **_fft_extra_args)
else:
return np.fft.fftn(a, axes=axes)
def rfftn(a, axes=None):
if _use_fftw:
from pyfftw.interfaces.numpy_fft import rfftn
_init_pyfftw()
return rfftn(a, axes=axes, **_fft_extra_args)
else:
return np.fft.rfftn(a, axes=axes)
def ifftn(a, axes=None):
if _use_fftw:
from pyfftw.interfaces.numpy_fft import ifftn
_init_pyfftw()
return ifftn(a, axes=axes, **_fft_extra_args)
else:
return np.fft.ifftn(a, axes=axes)
def hartley(a, axes=None): def hartley(a, axes=None):
...@@ -46,7 +86,7 @@ def hartley(a, axes=None): ...@@ -46,7 +86,7 @@ def hartley(a, axes=None):
if iscomplextype(a.dtype): if iscomplextype(a.dtype):
raise TypeError("Hartley transform requires real-valued arrays.") raise TypeError("Hartley transform requires real-valued arrays.")
tmp = rfftn(a, axes=axes, **_fft_extra_args) tmp = rfftn(a, axes=axes)
def _fill_array(tmp, res, axes): def _fill_array(tmp, res, axes):
if axes is None: if axes is None:
...@@ -89,7 +129,7 @@ def my_fftn_r2c(a, axes=None): ...@@ -89,7 +129,7 @@ def my_fftn_r2c(a, axes=None):
if iscomplextype(a.dtype): if iscomplextype(a.dtype):
raise TypeError("Transform requires real-valued input arrays.") raise TypeError("Transform requires real-valued input arrays.")
tmp = rfftn(a, axes=axes, **_fft_extra_args) tmp = rfftn(a, axes=axes)
def _fill_complex_array(tmp, res, axes): def _fill_complex_array(tmp, res, axes):
if axes is None: if axes is None:
...@@ -123,4 +163,4 @@ def my_fftn_r2c(a, axes=None): ...@@ -123,4 +163,4 @@ def my_fftn_r2c(a, axes=None):
def my_fftn(a, axes=None): def my_fftn(a, axes=None):
return fftn(a, axes=axes, **_fft_extra_args) return fftn(a, axes=axes)
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..utilities import NiftyMetaBase from ..utilities import NiftyMeta
class Energy(NiftyMetaBase()): class Energy(metaclass=NiftyMeta):
"""Provides the functional used by minimization schemes. """Provides the functional used by minimization schemes.
The Energy object is an implementation of a scalar function including its The Energy object is an implementation of a scalar function including its
......
...@@ -16,11 +16,11 @@ ...@@ -16,11 +16,11 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..logger import logger from ..logger import logger
from ..utilities import NiftyMetaBase from ..utilities import NiftyMeta
import numpy as np import numpy as np
class IterationController(NiftyMetaBase()): class IterationController(metaclass=NiftyMeta):
"""The abstract base class for all iteration controllers. """The abstract base class for all iteration controllers.
An iteration controller is an object that monitors the progress of a An iteration controller is an object that monitors the progress of a
minimization iteration. At the begin of the minimization, its start() minimization iteration. At the begin of the minimization, its start()
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..utilities import NiftyMetaBase from ..utilities import NiftyMeta
from ..logger import logger from ..logger import logger
import numpy as np import numpy as np
...@@ -103,7 +103,7 @@ class LineEnergy(object): ...@@ -103,7 +103,7 @@ class LineEnergy(object):
return res.real return res.real
class LineSearch(NiftyMetaBase()): class LineSearch(metaclass=NiftyMeta):
"""Class for finding a step size that satisfies the strong Wolfe """Class for finding a step size that satisfies the strong Wolfe
conditions. conditions.
......
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..utilities import NiftyMetaBase from ..utilities import NiftyMeta
class Minimizer(NiftyMetaBase()): class Minimizer(metaclass=NiftyMeta):
"""A base class used by all minimizers.""" """A base class used by all minimizers."""
# MR FIXME: the docstring is partially ignored by Sphinx. Why? # MR FIXME: the docstring is partially ignored by Sphinx. Why?
......
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np import numpy as np
from ..utilities import NiftyMetaBase, indent from ..utilities import NiftyMeta, indent
class Operator(NiftyMetaBase()): class Operator(metaclass=NiftyMeta):
"""Transforms values defined on one domain into values defined on another """Transforms values defined on one domain into values defined on another
domain, and can also provide the Jacobian. domain, and can also provide the Jacobian.
""" """
......
...@@ -20,10 +20,9 @@ from itertools import product ...@@ -20,10 +20,9 @@ from itertools import product
from functools import reduce from functools import reduce
import numpy as np import numpy as np
from future.utils import with_metaclass
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space", __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMetaBase", "my_sum", "my_lincomb_simple", "memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
"my_lincomb", "indent", "my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype"] "my_product", "frozendict", "special_add_at", "iscomplextype"]
...@@ -178,10 +177,6 @@ class NiftyMeta(_DocStringInheritor): ...@@ -178,10 +177,6 @@ class NiftyMeta(_DocStringInheritor):
pass pass
def NiftyMetaBase():
return with_metaclass(NiftyMeta, type('NewBase', (object,), {}))
class frozendict(collections.abc.Mapping): class frozendict(collections.abc.Mapping):
""" """
An immutable wrapper around dictionaries that implements the complete An immutable wrapper around dictionaries that implements the complete
......
...@@ -39,8 +39,8 @@ setup(name="nifty5", ...@@ -39,8 +39,8 @@ setup(name="nifty5",
packages=find_packages(include=["nifty5", "nifty5.*"]), packages=find_packages(include=["nifty5", "nifty5.*"]),
zip_safe=True, zip_safe=True,
license="GPLv3", license="GPLv3",
setup_requires=['future', 'scipy'], setup_requires=['scipy'],
install_requires=['future', 'scipy', 'pyfftw>=0.10.4'], install_requires=['scipy'],
classifiers=[ classifiers=[
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Topic :: Utilities", "Topic :: Utilities",
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import numpy as np import numpy as np
import pytest import pytest
from numpy.testing import assert_allclose from numpy.testing import assert_, assert_allclose
import nifty5 as ift import nifty5 as ift
...@@ -34,10 +34,22 @@ def _get_rtol(tp): ...@@ -34,10 +34,22 @@ def _get_rtol(tp):
pmp = pytest.mark.parametrize pmp = pytest.mark.parametrize
dtype = list2fixture([np.float64, np.float32, np.complex64, np.complex128]) dtype = list2fixture([np.float64, np.float32, np.complex64, np.complex128])
op = list2fixture([ift.HartleyOperator, ift.FFTOperator]) op = list2fixture([ift.HartleyOperator, ift.FFTOperator])
fftw = list2fixture([False, True])
def test_switch():
ift.fft.enable_fftw()
assert_(ift.fft._use_fftw is True)
ift.fft.disable_fftw()
assert_(ift.fft._use_fftw is False)
ift.fft.enable_fftw()
assert_(ift.fft._use_fftw is True)
@pmp('d', [0.1, 1, 3.7]) @pmp('d', [0.1, 1, 3.7])
def test_fft1D(d, dtype, op): def test_fft1D(d, dtype, op, fftw):
if fftw:
ift.fft.enable_fftw()
dim1 = 16 dim1 = 16
tol = _get_rtol(dtype) tol = _get_rtol(dtype)
a = ift.RGSpace(dim1, distances=d) a = ift.RGSpace(dim1, distances=d)
...@@ -57,13 +69,16 @@ def test_fft1D(d, dtype, op): ...@@ -57,13 +69,16 @@ def test_fft1D(d, dtype, op):
domain=a, random_type='normal', std=7, mean=3, dtype=dtype) domain=a, random_type='normal', std=7, mean=3, dtype=dtype)
out = fft.inverse_times(fft.times(inp)) out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol) assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
@pmp('dim1', [12, 15]) @pmp('dim1', [12, 15])
@pmp('dim2', [9, 12]) @pmp('dim2', [9, 12])
@pmp('d1', [0.1, 1, 3.7]) @pmp('d1', [0.1, 1, 3.7])
@pmp('d2', [0.4, 1, 2.7]) @pmp('d2', [0.4, 1, 2.7])
def test_fft2D(dim1, dim2, d1, d2, dtype, op): def test_fft2D(dim1, dim2, d1, d2, dtype, op, fftw):
if fftw:
ift.fft.enable_fftw()
tol = _get_rtol(dtype) tol = _get_rtol(dtype)
a = ift.RGSpace([dim1, dim2], distances=[d1, d2]) a = ift.RGSpace([dim1, dim2], distances=[d1, d2])
b = ift.RGSpace( b = ift.RGSpace(
...@@ -82,10 +97,13 @@ def test_fft2D(dim1, dim2, d1, d2, dtype, op): ...@@ -82,10 +97,13 @@ def test_fft2D(dim1, dim2, d1, d2, dtype, op):
domain=a, random_type='normal', std=7, mean=3, dtype=dtype) domain=a, random_type='normal', std=7, mean=3, dtype=dtype)
out = fft.inverse_times(fft.times(inp)) out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol) assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
@pmp('index', [0, 1, 2]) @pmp('index', [0, 1, 2])
def test_composed_fft(index, dtype, op): def test_composed_fft(index, dtype, op, fftw):
if fftw:
ift.fft.enable_fftw()
tol = _get_rtol(dtype) tol = _get_rtol(dtype)
a = [a1, a2, a = [a1, a2,
a3] = [ift.RGSpace((32,)), a3] = [ift.RGSpace((32,)),
...@@ -97,6 +115,7 @@ def test_composed_fft(index, dtype, op): ...@@ -97,6 +115,7 @@ def test_composed_fft(index, dtype, op):
domain=(a1, a2, a3), random_type='normal', std=7, mean=3, dtype=dtype) domain=(a1, a2, a3), random_type='normal', std=7, mean=3, dtype=dtype)
out = fft.inverse_times(fft.times(inp)) out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol) assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
@pmp('space', [ @pmp('space', [
...@@ -104,7 +123,9 @@ def test_composed_fft(index, dtype, op): ...@@ -104,7 +123,9 @@ def test_composed_fft(index, dtype, op):
ift.RGSpace((15, 27), distances=(.7, .33), harmonic=True), ift.RGSpace((15, 27), distances=(.7, .33), harmonic=True),
ift.RGSpace(73, distances=0.5643) ift.RGSpace(73, distances=0.5643)
]) ])
def test_normalisation(space, dtype, op): def test_normalisation(space, dtype, op, fftw):
if fftw:
ift.fft.enable_fftw()
tol = 10*_get_rtol(dtype) tol = 10*_get_rtol(dtype)
cospace = space.get_default_codomain() cospace = space.get_default_codomain()
fft = op(space, cospace) fft = op(space, cospace)
...@@ -117,3 +138,4 @@ def test_normalisation(space, dtype, op): ...@@ -117,3 +138,4 @@ def test_normalisation(space, dtype, op):
assert_allclose( assert_allclose(
inp.to_global_data()[zero_idx], out.integrate(), rtol=tol, atol=tol) inp.to_global_data()[zero_idx], out.integrate(), rtol=tol, atol=tol)
assert_allclose(out.local_data, out2.local_data, rtol=tol, atol=tol) assert_allclose(out.local_data, out2.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment