diff --git a/README.md b/README.md index 627745c90ddc24d5c7569d7674387589baf20fb9..afe3908dabf8f858215bfe04565ac5954d7cdaea 100644 --- a/README.md +++ b/README.md @@ -39,9 +39,9 @@ Installation - [Python 3](https://www.python.org/) (3.5.x or later) - [SciPy](https://www.scipy.org/) -- [pyFFTW](https://pypi.python.org/pypi/pyFFTW) Optional dependencies: +- [pyFFTW](https://pypi.python.org/pypi/pyFFTW) for faster Fourier transforms - [pyHealpix](https://gitlab.mpcdf.mpg.de/ift/pyHealpix) (for harmonic transforms involving domains on the sphere) - [mpi4py](https://mpi4py.scipy.org) (for MPI-parallel execution) @@ -61,18 +61,29 @@ distributions, the "apt" lines will need slight changes. NIFTy5 and its mandatory dependencies can be installed via: - sudo apt-get install git libfftw3-dev python3 python3-pip python3-dev + sudo apt-get install git python3 python3-pip python3-dev pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_5 +Plotting support is added via: + + pip3 install --user matplotlib + +FFTW support is added via: + + sudo apt-get install libfftw3-dev + pip3 install --user pyfftw + +To actually use FFTW in your Nifty calculations, you need to call + + nifty5.fft.enable_fftw() + +at the beginning of your code. + (Note: If you encounter problems related to `pyFFTW`, make sure that you are using a pip-installed `pyFFTW` package. Unfortunately, some distributions are shipping an incorrectly configured `pyFFTW` package, which does not cooperate with the installed `FFTW3` libraries.) -Plotting support is added via: - - pip3 install --user matplotlib - Support for spherical harmonic transforms is added via: pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git @@ -86,7 +97,7 @@ MPI support is added via: To run the tests, additional packages are required: - sudo apt-get install python3-coverage python3-parameterized python3-pytest python3-pytest-cov + sudo apt-get install python3-coverage python3-pytest python3-pytest-cov Afterwards the tests (including a coverage report) can be run using the following command in the repository root: diff --git a/docs/source/installation.rst b/docs/source/installation.rst index d2f182d1dce965c6ef6216b038979094cdfd4abb..84a5ba8f1539c8513783904c64f36ef6cd016a0e 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -7,18 +7,29 @@ distributions, the "apt" lines will need slight changes. NIFTy5 and its mandatory dependencies can be installed via:: - sudo apt-get install git libfftw3-dev python3 python3-pip python3-dev + sudo apt-get install git python3 python3-pip python3-dev pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_5 +Plotting support is added via:: + + pip3 install --user matplotlib + +FFTW support is added via: + + sudo apt-get install libfftw3-dev + pip3 install --user pyfftw + +To actually use FFTW in your Nifty calculations, you need to call + + nifty5.fft.enable_fftw() + +at the beginning of your code. + (Note: If you encounter problems related to `pyFFTW`, make sure that you are using a pip-installed `pyFFTW` package. Unfortunately, some distributions are shipping an incorrectly configured `pyFFTW` package, which does not cooperate with the installed `FFTW3` libraries.) -Plotting support is added via:: - - pip3 install --user matplotlib - Support for spherical harmonic transforms is added via:: pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git diff --git a/nifty5/domains/domain.py b/nifty5/domains/domain.py index b635ae1ad2dfccd9375d3b7455915ed5d0d5fdf9..de88089718d9577b700ab5cb2adc6e9cacee47cf 100644 --- a/nifty5/domains/domain.py +++ b/nifty5/domains/domain.py @@ -16,10 +16,10 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. from functools import reduce -from ..utilities import NiftyMetaBase +from ..utilities import NiftyMeta -class Domain(NiftyMetaBase()): +class Domain(metaclass=NiftyMeta): """The abstract class repesenting a (structured or unstructured) domain. """ def __repr__(self): diff --git a/nifty5/fft.py b/nifty5/fft.py index 458ffb3dca52b32bd471f66a151c69b166be0026..0fc2381409bc92ef66b28f930de95115fcd4885c 100644 --- a/nifty5/fft.py +++ b/nifty5/fft.py @@ -19,23 +19,63 @@ from .utilities import iscomplextype import numpy as np -_use_fftw = True +_use_fftw = False +_fftw_prepped = False +_fft_extra_args = {} -if _use_fftw: - import pyfftw - from pyfftw.interfaces.numpy_fft import fftn, rfftn, ifftn - pyfftw.interfaces.cache.enable() - pyfftw.interfaces.cache.set_keepalive_time(1000.) - # Optional extra arguments for the FFT calls - # if exact reproducibility is needed, - # set "planner_effort" to "FFTW_ESTIMATE" - import os - nthreads = int(os.getenv("OMP_NUM_THREADS", "1")) - _fft_extra_args = dict(planner_effort='FFTW_ESTIMATE', threads=nthreads) -else: - from numpy.fft import fftn, rfftn, ifftn - _fft_extra_args = {} +def enable_fftw(): + global _use_fftw + _use_fftw = True + + +def disable_fftw(): + global _use_fftw + _use_fftw = False + + +def _init_pyfftw(): + global _fft_extra_args, _fftw_prepped + if not _fftw_prepped: + import pyfftw + from pyfftw.interfaces.numpy_fft import fftn, rfftn, ifftn + pyfftw.interfaces.cache.enable() + pyfftw.interfaces.cache.set_keepalive_time(1000.) + # Optional extra arguments for the FFT calls + # if exact reproducibility is needed, + # set "planner_effort" to "FFTW_ESTIMATE" + import os + nthreads = int(os.getenv("OMP_NUM_THREADS", "1")) + _fft_extra_args = dict(planner_effort='FFTW_ESTIMATE', + threads=nthreads) + _fftw_prepped = True + + +def fftn(a, axes=None): + if _use_fftw: + from pyfftw.interfaces.numpy_fft import fftn + _init_pyfftw() + return fftn(a, axes=axes, **_fft_extra_args) + else: + return np.fft.fftn(a, axes=axes) + + +def rfftn(a, axes=None): + if _use_fftw: + from pyfftw.interfaces.numpy_fft import rfftn + _init_pyfftw() + return rfftn(a, axes=axes, **_fft_extra_args) + else: + return np.fft.rfftn(a, axes=axes) + + +def ifftn(a, axes=None): + if _use_fftw: + from pyfftw.interfaces.numpy_fft import ifftn + _init_pyfftw() + return ifftn(a, axes=axes, **_fft_extra_args) + else: + return np.fft.ifftn(a, axes=axes) def hartley(a, axes=None): @@ -46,7 +86,7 @@ def hartley(a, axes=None): if iscomplextype(a.dtype): raise TypeError("Hartley transform requires real-valued arrays.") - tmp = rfftn(a, axes=axes, **_fft_extra_args) + tmp = rfftn(a, axes=axes) def _fill_array(tmp, res, axes): if axes is None: @@ -89,7 +129,7 @@ def my_fftn_r2c(a, axes=None): if iscomplextype(a.dtype): raise TypeError("Transform requires real-valued input arrays.") - tmp = rfftn(a, axes=axes, **_fft_extra_args) + tmp = rfftn(a, axes=axes) def _fill_complex_array(tmp, res, axes): if axes is None: @@ -123,4 +163,4 @@ def my_fftn_r2c(a, axes=None): def my_fftn(a, axes=None): - return fftn(a, axes=axes, **_fft_extra_args) + return fftn(a, axes=axes) diff --git a/nifty5/minimization/energy.py b/nifty5/minimization/energy.py index 3eaa94ef43f8eb77117e917c6545fb563151baf3..3d42c13af16356f161909757f1c04466191d771f 100644 --- a/nifty5/minimization/energy.py +++ b/nifty5/minimization/energy.py @@ -15,10 +15,10 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. -from ..utilities import NiftyMetaBase +from ..utilities import NiftyMeta -class Energy(NiftyMetaBase()): +class Energy(metaclass=NiftyMeta): """Provides the functional used by minimization schemes. The Energy object is an implementation of a scalar function including its diff --git a/nifty5/minimization/iteration_controllers.py b/nifty5/minimization/iteration_controllers.py index 467e4a5d658be51e14732d0b2cb59f3d4cd47c64..e27201e85501a2612372b5dbb274555dc55dc13d 100644 --- a/nifty5/minimization/iteration_controllers.py +++ b/nifty5/minimization/iteration_controllers.py @@ -16,11 +16,11 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. from ..logger import logger -from ..utilities import NiftyMetaBase +from ..utilities import NiftyMeta import numpy as np -class IterationController(NiftyMetaBase()): +class IterationController(metaclass=NiftyMeta): """The abstract base class for all iteration controllers. An iteration controller is an object that monitors the progress of a minimization iteration. At the begin of the minimization, its start() diff --git a/nifty5/minimization/line_search.py b/nifty5/minimization/line_search.py index fac4170450c4abd1b39fde7329067a2afe4eea53..89a771b931a4b824d63f22353e960449583d9ffa 100644 --- a/nifty5/minimization/line_search.py +++ b/nifty5/minimization/line_search.py @@ -15,7 +15,7 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. -from ..utilities import NiftyMetaBase +from ..utilities import NiftyMeta from ..logger import logger import numpy as np @@ -103,7 +103,7 @@ class LineEnergy(object): return res.real -class LineSearch(NiftyMetaBase()): +class LineSearch(metaclass=NiftyMeta): """Class for finding a step size that satisfies the strong Wolfe conditions. diff --git a/nifty5/minimization/minimizer.py b/nifty5/minimization/minimizer.py index 17ea9f79f383a1d6d7ab53e6ee90964b6f975909..eb8199c808b6ec9ee92b923fc633edc498d20625 100644 --- a/nifty5/minimization/minimizer.py +++ b/nifty5/minimization/minimizer.py @@ -15,10 +15,10 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. -from ..utilities import NiftyMetaBase +from ..utilities import NiftyMeta -class Minimizer(NiftyMetaBase()): +class Minimizer(metaclass=NiftyMeta): """A base class used by all minimizers.""" # MR FIXME: the docstring is partially ignored by Sphinx. Why? diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 3b2c8a102bb8992576bf7134d9056dd35b23fa7f..3c1552a41d3fec62c2a5345e8456b433bdba87d3 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -16,10 +16,10 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. import numpy as np -from ..utilities import NiftyMetaBase, indent +from ..utilities import NiftyMeta, indent -class Operator(NiftyMetaBase()): +class Operator(metaclass=NiftyMeta): """Transforms values defined on one domain into values defined on another domain, and can also provide the Jacobian. """ diff --git a/nifty5/utilities.py b/nifty5/utilities.py index 15e32f45d33f873e80315d98ba82f469ed7f0aa0..899f9989bdd494f1b689b2b68dbd8d467bb1ddda 100644 --- a/nifty5/utilities.py +++ b/nifty5/utilities.py @@ -20,10 +20,9 @@ from itertools import product from functools import reduce import numpy as np -from future.utils import with_metaclass __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space", - "memo", "NiftyMetaBase", "my_sum", "my_lincomb_simple", + "memo", "NiftyMeta", "my_sum", "my_lincomb_simple", "my_lincomb", "indent", "my_product", "frozendict", "special_add_at", "iscomplextype"] @@ -178,10 +177,6 @@ class NiftyMeta(_DocStringInheritor): pass -def NiftyMetaBase(): - return with_metaclass(NiftyMeta, type('NewBase', (object,), {})) - - class frozendict(collections.abc.Mapping): """ An immutable wrapper around dictionaries that implements the complete diff --git a/setup.py b/setup.py index b86435dde2a78a4b75e2ff2a00223bf5659b0f53..79d6f5542966fd5fff64d761ba0dc5de2d910492 100644 --- a/setup.py +++ b/setup.py @@ -39,8 +39,8 @@ setup(name="nifty5", packages=find_packages(include=["nifty5", "nifty5.*"]), zip_safe=True, license="GPLv3", - setup_requires=['future', 'scipy'], - install_requires=['future', 'scipy', 'pyfftw>=0.10.4'], + setup_requires=['scipy'], + install_requires=['scipy'], classifiers=[ "Development Status :: 4 - Beta", "Topic :: Utilities", diff --git a/test/test_operators/test_fft_operator.py b/test/test_operators/test_fft_operator.py index 171f26bfab0ce31629c5a3dcd9560052c4715b16..47b8e6a11722cccca76ef65c742bfd14d87ee826 100644 --- a/test/test_operators/test_fft_operator.py +++ b/test/test_operators/test_fft_operator.py @@ -17,7 +17,7 @@ import numpy as np import pytest -from numpy.testing import assert_allclose +from numpy.testing import assert_, assert_allclose import nifty5 as ift @@ -34,10 +34,22 @@ def _get_rtol(tp): pmp = pytest.mark.parametrize dtype = list2fixture([np.float64, np.float32, np.complex64, np.complex128]) op = list2fixture([ift.HartleyOperator, ift.FFTOperator]) +fftw = list2fixture([False, True]) + + +def test_switch(): + ift.fft.enable_fftw() + assert_(ift.fft._use_fftw is True) + ift.fft.disable_fftw() + assert_(ift.fft._use_fftw is False) + ift.fft.enable_fftw() + assert_(ift.fft._use_fftw is True) @pmp('d', [0.1, 1, 3.7]) -def test_fft1D(d, dtype, op): +def test_fft1D(d, dtype, op, fftw): + if fftw: + ift.fft.enable_fftw() dim1 = 16 tol = _get_rtol(dtype) a = ift.RGSpace(dim1, distances=d) @@ -57,13 +69,16 @@ def test_fft1D(d, dtype, op): domain=a, random_type='normal', std=7, mean=3, dtype=dtype) out = fft.inverse_times(fft.times(inp)) assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol) + ift.fft.disable_fftw() @pmp('dim1', [12, 15]) @pmp('dim2', [9, 12]) @pmp('d1', [0.1, 1, 3.7]) @pmp('d2', [0.4, 1, 2.7]) -def test_fft2D(dim1, dim2, d1, d2, dtype, op): +def test_fft2D(dim1, dim2, d1, d2, dtype, op, fftw): + if fftw: + ift.fft.enable_fftw() tol = _get_rtol(dtype) a = ift.RGSpace([dim1, dim2], distances=[d1, d2]) b = ift.RGSpace( @@ -82,10 +97,13 @@ def test_fft2D(dim1, dim2, d1, d2, dtype, op): domain=a, random_type='normal', std=7, mean=3, dtype=dtype) out = fft.inverse_times(fft.times(inp)) assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol) + ift.fft.disable_fftw() @pmp('index', [0, 1, 2]) -def test_composed_fft(index, dtype, op): +def test_composed_fft(index, dtype, op, fftw): + if fftw: + ift.fft.enable_fftw() tol = _get_rtol(dtype) a = [a1, a2, a3] = [ift.RGSpace((32,)), @@ -97,6 +115,7 @@ def test_composed_fft(index, dtype, op): domain=(a1, a2, a3), random_type='normal', std=7, mean=3, dtype=dtype) out = fft.inverse_times(fft.times(inp)) assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol) + ift.fft.disable_fftw() @pmp('space', [ @@ -104,7 +123,9 @@ def test_composed_fft(index, dtype, op): ift.RGSpace((15, 27), distances=(.7, .33), harmonic=True), ift.RGSpace(73, distances=0.5643) ]) -def test_normalisation(space, dtype, op): +def test_normalisation(space, dtype, op, fftw): + if fftw: + ift.fft.enable_fftw() tol = 10*_get_rtol(dtype) cospace = space.get_default_codomain() fft = op(space, cospace) @@ -117,3 +138,4 @@ def test_normalisation(space, dtype, op): assert_allclose( inp.to_global_data()[zero_idx], out.integrate(), rtol=tol, atol=tol) assert_allclose(out.local_data, out2.local_data, rtol=tol, atol=tol) + ift.fft.disable_fftw()