Commit 01e0bb11 authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge branch 'fft_tweaks' into 'NIFTy_5'

Make FFTW optional

See merge request ift/nifty-dev!178
parents 872bd200 09c17c1e
......@@ -39,9 +39,9 @@ Installation
- [Python 3](https://www.python.org/) (3.5.x or later)
- [SciPy](https://www.scipy.org/)
- [pyFFTW](https://pypi.python.org/pypi/pyFFTW)
Optional dependencies:
- [pyFFTW](https://pypi.python.org/pypi/pyFFTW) for faster Fourier transforms
- [pyHealpix](https://gitlab.mpcdf.mpg.de/ift/pyHealpix) (for harmonic
transforms involving domains on the sphere)
- [mpi4py](https://mpi4py.scipy.org) (for MPI-parallel execution)
......@@ -61,18 +61,29 @@ distributions, the "apt" lines will need slight changes.
NIFTy5 and its mandatory dependencies can be installed via:
sudo apt-get install git libfftw3-dev python3 python3-pip python3-dev
sudo apt-get install git python3 python3-pip python3-dev
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_5
Plotting support is added via:
pip3 install --user matplotlib
FFTW support is added via:
sudo apt-get install libfftw3-dev
pip3 install --user pyfftw
To actually use FFTW in your Nifty calculations, you need to call
nifty5.fft.enable_fftw()
at the beginning of your code.
(Note: If you encounter problems related to `pyFFTW`, make sure that you are
using a pip-installed `pyFFTW` package. Unfortunately, some distributions are
shipping an incorrectly configured `pyFFTW` package, which does not cooperate
with the installed `FFTW3` libraries.)
Plotting support is added via:
pip3 install --user matplotlib
Support for spherical harmonic transforms is added via:
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
......@@ -86,7 +97,7 @@ MPI support is added via:
To run the tests, additional packages are required:
sudo apt-get install python3-coverage python3-parameterized python3-pytest python3-pytest-cov
sudo apt-get install python3-coverage python3-pytest python3-pytest-cov
Afterwards the tests (including a coverage report) can be run using the
following command in the repository root:
......
......@@ -7,18 +7,29 @@ distributions, the "apt" lines will need slight changes.
NIFTy5 and its mandatory dependencies can be installed via::
sudo apt-get install git libfftw3-dev python3 python3-pip python3-dev
sudo apt-get install git python3 python3-pip python3-dev
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_5
Plotting support is added via::
pip3 install --user matplotlib
FFTW support is added via:
sudo apt-get install libfftw3-dev
pip3 install --user pyfftw
To actually use FFTW in your Nifty calculations, you need to call
nifty5.fft.enable_fftw()
at the beginning of your code.
(Note: If you encounter problems related to `pyFFTW`, make sure that you are
using a pip-installed `pyFFTW` package. Unfortunately, some distributions are
shipping an incorrectly configured `pyFFTW` package, which does not cooperate
with the installed `FFTW3` libraries.)
Plotting support is added via::
pip3 install --user matplotlib
Support for spherical harmonic transforms is added via::
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
......
......@@ -16,10 +16,10 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from functools import reduce
from ..utilities import NiftyMetaBase
from ..utilities import NiftyMeta
class Domain(NiftyMetaBase()):
class Domain(metaclass=NiftyMeta):
"""The abstract class repesenting a (structured or unstructured) domain.
"""
def __repr__(self):
......
......@@ -19,10 +19,24 @@ from .utilities import iscomplextype
import numpy as np
_use_fftw = True
_use_fftw = False
_fftw_prepped = False
_fft_extra_args = {}
if _use_fftw:
def enable_fftw():
global _use_fftw
_use_fftw = True
def disable_fftw():
global _use_fftw
_use_fftw = False
def _init_pyfftw():
global _fft_extra_args, _fftw_prepped
if not _fftw_prepped:
import pyfftw
from pyfftw.interfaces.numpy_fft import fftn, rfftn, ifftn
pyfftw.interfaces.cache.enable()
......@@ -32,10 +46,36 @@ if _use_fftw:
# set "planner_effort" to "FFTW_ESTIMATE"
import os
nthreads = int(os.getenv("OMP_NUM_THREADS", "1"))
_fft_extra_args = dict(planner_effort='FFTW_ESTIMATE', threads=nthreads)
else:
from numpy.fft import fftn, rfftn, ifftn
_fft_extra_args = {}
_fft_extra_args = dict(planner_effort='FFTW_ESTIMATE',
threads=nthreads)
_fftw_prepped = True
def fftn(a, axes=None):
if _use_fftw:
from pyfftw.interfaces.numpy_fft import fftn
_init_pyfftw()
return fftn(a, axes=axes, **_fft_extra_args)
else:
return np.fft.fftn(a, axes=axes)
def rfftn(a, axes=None):
if _use_fftw:
from pyfftw.interfaces.numpy_fft import rfftn
_init_pyfftw()
return rfftn(a, axes=axes, **_fft_extra_args)
else:
return np.fft.rfftn(a, axes=axes)
def ifftn(a, axes=None):
if _use_fftw:
from pyfftw.interfaces.numpy_fft import ifftn
_init_pyfftw()
return ifftn(a, axes=axes, **_fft_extra_args)
else:
return np.fft.ifftn(a, axes=axes)
def hartley(a, axes=None):
......@@ -46,7 +86,7 @@ def hartley(a, axes=None):
if iscomplextype(a.dtype):
raise TypeError("Hartley transform requires real-valued arrays.")
tmp = rfftn(a, axes=axes, **_fft_extra_args)
tmp = rfftn(a, axes=axes)
def _fill_array(tmp, res, axes):
if axes is None:
......@@ -89,7 +129,7 @@ def my_fftn_r2c(a, axes=None):
if iscomplextype(a.dtype):
raise TypeError("Transform requires real-valued input arrays.")
tmp = rfftn(a, axes=axes, **_fft_extra_args)
tmp = rfftn(a, axes=axes)
def _fill_complex_array(tmp, res, axes):
if axes is None:
......@@ -123,4 +163,4 @@ def my_fftn_r2c(a, axes=None):
def my_fftn(a, axes=None):
return fftn(a, axes=axes, **_fft_extra_args)
return fftn(a, axes=axes)
......@@ -15,10 +15,10 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..utilities import NiftyMetaBase
from ..utilities import NiftyMeta
class Energy(NiftyMetaBase()):
class Energy(metaclass=NiftyMeta):
"""Provides the functional used by minimization schemes.
The Energy object is an implementation of a scalar function including its
......
......@@ -16,11 +16,11 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..logger import logger
from ..utilities import NiftyMetaBase
from ..utilities import NiftyMeta
import numpy as np
class IterationController(NiftyMetaBase()):
class IterationController(metaclass=NiftyMeta):
"""The abstract base class for all iteration controllers.
An iteration controller is an object that monitors the progress of a
minimization iteration. At the begin of the minimization, its start()
......
......@@ -15,7 +15,7 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..utilities import NiftyMetaBase
from ..utilities import NiftyMeta
from ..logger import logger
import numpy as np
......@@ -103,7 +103,7 @@ class LineEnergy(object):
return res.real
class LineSearch(NiftyMetaBase()):
class LineSearch(metaclass=NiftyMeta):
"""Class for finding a step size that satisfies the strong Wolfe
conditions.
......
......@@ -15,10 +15,10 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..utilities import NiftyMetaBase
from ..utilities import NiftyMeta
class Minimizer(NiftyMetaBase()):
class Minimizer(metaclass=NiftyMeta):
"""A base class used by all minimizers."""
# MR FIXME: the docstring is partially ignored by Sphinx. Why?
......
......@@ -16,10 +16,10 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..utilities import NiftyMetaBase, indent
from ..utilities import NiftyMeta, indent
class Operator(NiftyMetaBase()):
class Operator(metaclass=NiftyMeta):
"""Transforms values defined on one domain into values defined on another
domain, and can also provide the Jacobian.
"""
......
......@@ -20,10 +20,9 @@ from itertools import product
from functools import reduce
import numpy as np
from future.utils import with_metaclass
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMetaBase", "my_sum", "my_lincomb_simple",
"memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
"my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype"]
......@@ -178,10 +177,6 @@ class NiftyMeta(_DocStringInheritor):
pass
def NiftyMetaBase():
return with_metaclass(NiftyMeta, type('NewBase', (object,), {}))
class frozendict(collections.abc.Mapping):
"""
An immutable wrapper around dictionaries that implements the complete
......
......@@ -39,8 +39,8 @@ setup(name="nifty5",
packages=find_packages(include=["nifty5", "nifty5.*"]),
zip_safe=True,
license="GPLv3",
setup_requires=['future', 'scipy'],
install_requires=['future', 'scipy', 'pyfftw>=0.10.4'],
setup_requires=['scipy'],
install_requires=['scipy'],
classifiers=[
"Development Status :: 4 - Beta",
"Topic :: Utilities",
......
......@@ -17,7 +17,7 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose
from numpy.testing import assert_, assert_allclose
import nifty5 as ift
......@@ -34,10 +34,22 @@ def _get_rtol(tp):
pmp = pytest.mark.parametrize
dtype = list2fixture([np.float64, np.float32, np.complex64, np.complex128])
op = list2fixture([ift.HartleyOperator, ift.FFTOperator])
fftw = list2fixture([False, True])
def test_switch():
ift.fft.enable_fftw()
assert_(ift.fft._use_fftw is True)
ift.fft.disable_fftw()
assert_(ift.fft._use_fftw is False)
ift.fft.enable_fftw()
assert_(ift.fft._use_fftw is True)
@pmp('d', [0.1, 1, 3.7])
def test_fft1D(d, dtype, op):
def test_fft1D(d, dtype, op, fftw):
if fftw:
ift.fft.enable_fftw()
dim1 = 16
tol = _get_rtol(dtype)
a = ift.RGSpace(dim1, distances=d)
......@@ -57,13 +69,16 @@ def test_fft1D(d, dtype, op):
domain=a, random_type='normal', std=7, mean=3, dtype=dtype)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
@pmp('dim1', [12, 15])
@pmp('dim2', [9, 12])
@pmp('d1', [0.1, 1, 3.7])
@pmp('d2', [0.4, 1, 2.7])
def test_fft2D(dim1, dim2, d1, d2, dtype, op):
def test_fft2D(dim1, dim2, d1, d2, dtype, op, fftw):
if fftw:
ift.fft.enable_fftw()
tol = _get_rtol(dtype)
a = ift.RGSpace([dim1, dim2], distances=[d1, d2])
b = ift.RGSpace(
......@@ -82,10 +97,13 @@ def test_fft2D(dim1, dim2, d1, d2, dtype, op):
domain=a, random_type='normal', std=7, mean=3, dtype=dtype)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
@pmp('index', [0, 1, 2])
def test_composed_fft(index, dtype, op):
def test_composed_fft(index, dtype, op, fftw):
if fftw:
ift.fft.enable_fftw()
tol = _get_rtol(dtype)
a = [a1, a2,
a3] = [ift.RGSpace((32,)),
......@@ -97,6 +115,7 @@ def test_composed_fft(index, dtype, op):
domain=(a1, a2, a3), random_type='normal', std=7, mean=3, dtype=dtype)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
@pmp('space', [
......@@ -104,7 +123,9 @@ def test_composed_fft(index, dtype, op):
ift.RGSpace((15, 27), distances=(.7, .33), harmonic=True),
ift.RGSpace(73, distances=0.5643)
])
def test_normalisation(space, dtype, op):
def test_normalisation(space, dtype, op, fftw):
if fftw:
ift.fft.enable_fftw()
tol = 10*_get_rtol(dtype)
cospace = space.get_default_codomain()
fft = op(space, cospace)
......@@ -117,3 +138,4 @@ def test_normalisation(space, dtype, op):
assert_allclose(
inp.to_global_data()[zero_idx], out.integrate(), rtol=tol, atol=tol)
assert_allclose(out.local_data, out2.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment