diff --git a/README.md b/README.md index 627745c90ddc24d5c7569d7674387589baf20fb9..afe3908dabf8f858215bfe04565ac5954d7cdaea 100644 --- a/README.md +++ b/README.md @@ -39,9 +39,9 @@ Installation - [Python 3](https://www.python.org/) (3.5.x or later) - [SciPy](https://www.scipy.org/) -- [pyFFTW](https://pypi.python.org/pypi/pyFFTW) Optional dependencies: +- [pyFFTW](https://pypi.python.org/pypi/pyFFTW) for faster Fourier transforms - [pyHealpix](https://gitlab.mpcdf.mpg.de/ift/pyHealpix) (for harmonic transforms involving domains on the sphere) - [mpi4py](https://mpi4py.scipy.org) (for MPI-parallel execution) @@ -61,18 +61,29 @@ distributions, the "apt" lines will need slight changes. NIFTy5 and its mandatory dependencies can be installed via: - sudo apt-get install git libfftw3-dev python3 python3-pip python3-dev + sudo apt-get install git python3 python3-pip python3-dev pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_5 +Plotting support is added via: + + pip3 install --user matplotlib + +FFTW support is added via: + + sudo apt-get install libfftw3-dev + pip3 install --user pyfftw + +To actually use FFTW in your Nifty calculations, you need to call + + nifty5.fft.enable_fftw() + +at the beginning of your code. + (Note: If you encounter problems related to `pyFFTW`, make sure that you are using a pip-installed `pyFFTW` package. Unfortunately, some distributions are shipping an incorrectly configured `pyFFTW` package, which does not cooperate with the installed `FFTW3` libraries.) -Plotting support is added via: - - pip3 install --user matplotlib - Support for spherical harmonic transforms is added via: pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git @@ -86,7 +97,7 @@ MPI support is added via: To run the tests, additional packages are required: - sudo apt-get install python3-coverage python3-parameterized python3-pytest python3-pytest-cov + sudo apt-get install python3-coverage python3-pytest python3-pytest-cov Afterwards the tests (including a coverage report) can be run using the following command in the repository root: diff --git a/docs/source/ift.rst b/docs/source/ift.rst index 4622abbac0e274e5bae45600a7dfdf3c1bd9d5cf..c94fdb51950978a11678d776687436553f3b9d0a 100644 --- a/docs/source/ift.rst +++ b/docs/source/ift.rst @@ -13,15 +13,16 @@ There is a full toolbox of methods that can be used, like the classical approxim .. tip:: *In-a-nutshell introductions to information field theory* can be found in [2]_, [3]_, [4]_, and [5]_, with the latter probably being the most didactical. -.. [1] T.A. Enßlin et al. (2009), "Information field theory for cosmological perturbation reconstruction and nonlinear signal analysis", PhysRevD.80.105005, 09/2009; `arXiv:0806.3474 <http://www.arxiv.org/abs/0806.3474>`_ +.. [1] T.A. Enßlin et al. (2009), "Information field theory for cosmological perturbation reconstruction and nonlinear signal analysis", PhysRevD.80.105005, 09/2009; `[arXiv:0806.3474] <http://www.arxiv.org/abs/0806.3474>`_ -.. [2] T.A. Enßlin (2013), "Information field theory", proceedings of MaxEnt 2012 -- the 32nd International Workshop on Bayesian Inference and Maximum Entropy Methods in Science and Engineering; AIP Conference Proceedings, Volume 1553, Issue 1, p.184; `arXiv:1301.2556 <http://arxiv.org/abs/1301.2556>`_ +.. [2] T.A. Enßlin (2013), "Information field theory", proceedings of MaxEnt 2012 -- the 32nd International Workshop on Bayesian Inference and Maximum Entropy Methods in Science and Engineering; AIP Conference Proceedings, Volume 1553, Issue 1, p.184; `[arXiv:1301.2556] <http://arxiv.org/abs/1301.2556>`_ -.. [3] T.A. Enßlin (2014), "Astrophysical data analysis with information field theory", AIP Conference Proceedings, Volume 1636, Issue 1, p.49; `arXiv:1405.7701 <http://arxiv.org/abs/1405.7701>`_ +.. [3] T.A. Enßlin (2014), "Astrophysical data analysis with information field theory", AIP Conference Proceedings, Volume 1636, Issue 1, p.49; `[arXiv:1405.7701] <http://arxiv.org/abs/1405.7701>`_ .. [4] Wikipedia contributors (2018), `"Information field theory" <https://en.wikipedia.org/w/index.php?title=Information_field_theory&oldid=876731720>`_, Wikipedia, The Free Encyclopedia. -.. [5] T.A. Enßlin (2019), "Information theory for fields", accepted by Annalen der Physik; `arXiv:1804.03350 <http://arxiv.org/abs/1804.03350>`_ +.. [5] T.A. Enßlin (2019), "Information theory for fields", accepted by Annalen der Physik; `[DOI] <https://doi.org/10.1002/andp.201800127>`_, `[arXiv:1804.03350] <http://arxiv.org/abs/1804.03350>`_ + Discretized continuum --------------------- @@ -103,7 +104,6 @@ and the measurement equation is linear in both signal and noise, with :math:`{R}` being the measurement response, which maps the continous signal field into the discrete data space. This is called a free theory, as the information Hamiltonian -associate professor (*FIXME*: really?) .. math:: @@ -135,7 +135,7 @@ the posterior covariance operator, and j = R^\dagger N^{-1} d -the information source. The operation in :math:`{d= D\,R^\dagger N^{-1} d}` is also called the generalized Wiener filter. +the information source. The operation in :math:`{m = D\,R^\dagger N^{-1} d}` is also called the generalized Wiener filter. NIFTy permits to define the involved operators :math:`{R}`, :math:`{R^\dagger}`, :math:`{S}`, and :math:`{N}` implicitly, as routines that can be applied to vectors, but which do not require the explicit storage of the matrix elements of the operators. diff --git a/docs/source/index.rst b/docs/source/index.rst index afbbf8edb99b37b0387684ee5e3a8a43fa8fd0fb..a6a6da8804c3b1f59256fd067f8e9023b8bc3fb3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,7 +1,7 @@ NIFTy -- Numerical Information Field Theory =========================================== -**NIFTy** [1]_, "\ **N**\umerical **I**\nformation **F**\ield **T**\heor\ **y**\ ", is a versatile library designed to enable the development of signal inference algorithms that are independent of the underlying spatial grid and its resolution. +**NIFTy** [1]_, [2]_, "\ **N**\umerical **I**\nformation **F**\ield **T**\heor\ **y**\ ", is a versatile library designed to enable the development of signal inference algorithms that are independent of the underlying spatial grid and its resolution. Its object-oriented framework is written in Python, although it accesses libraries written in C++ and C for efficiency. NIFTy offers a toolkit that abstracts discretized representations of continuous spaces, fields in these spaces, and operators acting on fields into classes. @@ -13,7 +13,9 @@ The set of spaces on which NIFTy operates comprises point sets, *n*-dimensional References ---------- -.. [1] Steininger et al., "NIFTy 3 - Numerical Information Field Theory - A Python framework for multicomponent signal inference on HPC clusters", 2017, submitted to PLOS One; `[arXiv:1708.01073] <https://arxiv.org/abs/1708.01073>`_ +.. [1] Selig et al., "NIFTY - Numerical Information Field Theory. A versatile PYTHON library for signal inference ", 2013, Astronmy and Astrophysics 554, 26; `[DOI] <https://ui.adsabs.harvard.edu/link_gateway/2013A&A...554A..26S/doi:10.1051/0004-6361/201321236>`_, `[arXiv:1301.4499] <https://arxiv.org/abs/1301.4499>`_ + +.. [2] Steininger et al., "NIFTy 3 - Numerical Information Field Theory - A Python framework for multicomponent signal inference on HPC clusters", 2017, accepted by Annalen der Physik; `[arXiv:1708.01073] <https://arxiv.org/abs/1708.01073>`_ Contents ........ diff --git a/docs/source/installation.rst b/docs/source/installation.rst index d2f182d1dce965c6ef6216b038979094cdfd4abb..84a5ba8f1539c8513783904c64f36ef6cd016a0e 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -7,18 +7,29 @@ distributions, the "apt" lines will need slight changes. NIFTy5 and its mandatory dependencies can be installed via:: - sudo apt-get install git libfftw3-dev python3 python3-pip python3-dev + sudo apt-get install git python3 python3-pip python3-dev pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_5 +Plotting support is added via:: + + pip3 install --user matplotlib + +FFTW support is added via: + + sudo apt-get install libfftw3-dev + pip3 install --user pyfftw + +To actually use FFTW in your Nifty calculations, you need to call + + nifty5.fft.enable_fftw() + +at the beginning of your code. + (Note: If you encounter problems related to `pyFFTW`, make sure that you are using a pip-installed `pyFFTW` package. Unfortunately, some distributions are shipping an incorrectly configured `pyFFTW` package, which does not cooperate with the installed `FFTW3` libraries.) -Plotting support is added via:: - - pip3 install --user matplotlib - Support for spherical harmonic transforms is added via:: pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git diff --git a/nifty5/domains/domain.py b/nifty5/domains/domain.py index b635ae1ad2dfccd9375d3b7455915ed5d0d5fdf9..de88089718d9577b700ab5cb2adc6e9cacee47cf 100644 --- a/nifty5/domains/domain.py +++ b/nifty5/domains/domain.py @@ -16,10 +16,10 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. from functools import reduce -from ..utilities import NiftyMetaBase +from ..utilities import NiftyMeta -class Domain(NiftyMetaBase()): +class Domain(metaclass=NiftyMeta): """The abstract class repesenting a (structured or unstructured) domain. """ def __repr__(self): diff --git a/nifty5/fft.py b/nifty5/fft.py index 458ffb3dca52b32bd471f66a151c69b166be0026..0fc2381409bc92ef66b28f930de95115fcd4885c 100644 --- a/nifty5/fft.py +++ b/nifty5/fft.py @@ -19,23 +19,63 @@ from .utilities import iscomplextype import numpy as np -_use_fftw = True +_use_fftw = False +_fftw_prepped = False +_fft_extra_args = {} -if _use_fftw: - import pyfftw - from pyfftw.interfaces.numpy_fft import fftn, rfftn, ifftn - pyfftw.interfaces.cache.enable() - pyfftw.interfaces.cache.set_keepalive_time(1000.) - # Optional extra arguments for the FFT calls - # if exact reproducibility is needed, - # set "planner_effort" to "FFTW_ESTIMATE" - import os - nthreads = int(os.getenv("OMP_NUM_THREADS", "1")) - _fft_extra_args = dict(planner_effort='FFTW_ESTIMATE', threads=nthreads) -else: - from numpy.fft import fftn, rfftn, ifftn - _fft_extra_args = {} +def enable_fftw(): + global _use_fftw + _use_fftw = True + + +def disable_fftw(): + global _use_fftw + _use_fftw = False + + +def _init_pyfftw(): + global _fft_extra_args, _fftw_prepped + if not _fftw_prepped: + import pyfftw + from pyfftw.interfaces.numpy_fft import fftn, rfftn, ifftn + pyfftw.interfaces.cache.enable() + pyfftw.interfaces.cache.set_keepalive_time(1000.) + # Optional extra arguments for the FFT calls + # if exact reproducibility is needed, + # set "planner_effort" to "FFTW_ESTIMATE" + import os + nthreads = int(os.getenv("OMP_NUM_THREADS", "1")) + _fft_extra_args = dict(planner_effort='FFTW_ESTIMATE', + threads=nthreads) + _fftw_prepped = True + + +def fftn(a, axes=None): + if _use_fftw: + from pyfftw.interfaces.numpy_fft import fftn + _init_pyfftw() + return fftn(a, axes=axes, **_fft_extra_args) + else: + return np.fft.fftn(a, axes=axes) + + +def rfftn(a, axes=None): + if _use_fftw: + from pyfftw.interfaces.numpy_fft import rfftn + _init_pyfftw() + return rfftn(a, axes=axes, **_fft_extra_args) + else: + return np.fft.rfftn(a, axes=axes) + + +def ifftn(a, axes=None): + if _use_fftw: + from pyfftw.interfaces.numpy_fft import ifftn + _init_pyfftw() + return ifftn(a, axes=axes, **_fft_extra_args) + else: + return np.fft.ifftn(a, axes=axes) def hartley(a, axes=None): @@ -46,7 +86,7 @@ def hartley(a, axes=None): if iscomplextype(a.dtype): raise TypeError("Hartley transform requires real-valued arrays.") - tmp = rfftn(a, axes=axes, **_fft_extra_args) + tmp = rfftn(a, axes=axes) def _fill_array(tmp, res, axes): if axes is None: @@ -89,7 +129,7 @@ def my_fftn_r2c(a, axes=None): if iscomplextype(a.dtype): raise TypeError("Transform requires real-valued input arrays.") - tmp = rfftn(a, axes=axes, **_fft_extra_args) + tmp = rfftn(a, axes=axes) def _fill_complex_array(tmp, res, axes): if axes is None: @@ -123,4 +163,4 @@ def my_fftn_r2c(a, axes=None): def my_fftn(a, axes=None): - return fftn(a, axes=axes, **_fft_extra_args) + return fftn(a, axes=axes) diff --git a/nifty5/linearization.py b/nifty5/linearization.py index babcd420d2609e5037b069a738735bdd77cb9f91..1affc0b65134603bb47de2a69428310d33acb2c7 100644 --- a/nifty5/linearization.py +++ b/nifty5/linearization.py @@ -461,7 +461,7 @@ class Linearization(object): if len(constants) == 0: return Linearization.make_var(field, want_metric) else: - ops = [ScalingOperator(0. if key in constants else 1., dom) - for key, dom in field.domain.items()] - bdop = BlockDiagonalOperator(field.domain, tuple(ops)) + ops = {key: ScalingOperator(0. if key in constants else 1., dom) + for key, dom in field.domain.items()} + bdop = BlockDiagonalOperator(field.domain, ops) return Linearization(field, bdop, want_metric=want_metric) diff --git a/nifty5/minimization/energy.py b/nifty5/minimization/energy.py index 3eaa94ef43f8eb77117e917c6545fb563151baf3..3d42c13af16356f161909757f1c04466191d771f 100644 --- a/nifty5/minimization/energy.py +++ b/nifty5/minimization/energy.py @@ -15,10 +15,10 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. -from ..utilities import NiftyMetaBase +from ..utilities import NiftyMeta -class Energy(NiftyMetaBase()): +class Energy(metaclass=NiftyMeta): """Provides the functional used by minimization schemes. The Energy object is an implementation of a scalar function including its diff --git a/nifty5/minimization/iteration_controllers.py b/nifty5/minimization/iteration_controllers.py index 467e4a5d658be51e14732d0b2cb59f3d4cd47c64..e27201e85501a2612372b5dbb274555dc55dc13d 100644 --- a/nifty5/minimization/iteration_controllers.py +++ b/nifty5/minimization/iteration_controllers.py @@ -16,11 +16,11 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. from ..logger import logger -from ..utilities import NiftyMetaBase +from ..utilities import NiftyMeta import numpy as np -class IterationController(NiftyMetaBase()): +class IterationController(metaclass=NiftyMeta): """The abstract base class for all iteration controllers. An iteration controller is an object that monitors the progress of a minimization iteration. At the begin of the minimization, its start() diff --git a/nifty5/minimization/line_search.py b/nifty5/minimization/line_search.py index fac4170450c4abd1b39fde7329067a2afe4eea53..89a771b931a4b824d63f22353e960449583d9ffa 100644 --- a/nifty5/minimization/line_search.py +++ b/nifty5/minimization/line_search.py @@ -15,7 +15,7 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. -from ..utilities import NiftyMetaBase +from ..utilities import NiftyMeta from ..logger import logger import numpy as np @@ -103,7 +103,7 @@ class LineEnergy(object): return res.real -class LineSearch(NiftyMetaBase()): +class LineSearch(metaclass=NiftyMeta): """Class for finding a step size that satisfies the strong Wolfe conditions. diff --git a/nifty5/minimization/minimizer.py b/nifty5/minimization/minimizer.py index 17ea9f79f383a1d6d7ab53e6ee90964b6f975909..eb8199c808b6ec9ee92b923fc633edc498d20625 100644 --- a/nifty5/minimization/minimizer.py +++ b/nifty5/minimization/minimizer.py @@ -15,10 +15,10 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. -from ..utilities import NiftyMetaBase +from ..utilities import NiftyMeta -class Minimizer(NiftyMetaBase()): +class Minimizer(metaclass=NiftyMeta): """A base class used by all minimizers.""" # MR FIXME: the docstring is partially ignored by Sphinx. Why? diff --git a/nifty5/operators/block_diagonal_operator.py b/nifty5/operators/block_diagonal_operator.py index be29fc3ae6111be3832ad98917892009488de68f..72e3a72e6ff92fc73a7376b3bf5f9927053158d6 100644 --- a/nifty5/operators/block_diagonal_operator.py +++ b/nifty5/operators/block_diagonal_operator.py @@ -24,17 +24,16 @@ class BlockDiagonalOperator(EndomorphicOperator): """ Parameters ---------- + domain : MultiDomain + Domain and target of the operator. operators : dict - Dictionary with operators domain names as keys and LinearOperators as - items. + Dictionary with subdomain names as keys and LinearOperators as items. """ def __init__(self, domain, operators): if not isinstance(domain, MultiDomain): raise TypeError("MultiDomain expected") - if not isinstance(operators, tuple): - raise TypeError("tuple expected") self._domain = domain - self._ops = operators + self._ops = tuple(operators[key] for key in domain.keys()) self._capability = self._all_ops for op in self._ops: if op is not None: @@ -55,13 +54,14 @@ class BlockDiagonalOperator(EndomorphicOperator): def _combine_chain(self, op): if self._domain != op._domain: raise ValueError("domain mismatch") - res = tuple(v1(v2) for v1, v2 in zip(self._ops, op._ops)) + res = {key: v1(v2) + for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)} return BlockDiagonalOperator(self._domain, res) def _combine_sum(self, op, selfneg, opneg): from ..operators.sum_operator import SumOperator if self._domain != op._domain: raise ValueError("domain mismatch") - res = tuple(SumOperator.make([v1, v2], [selfneg, opneg]) - for v1, v2 in zip(self._ops, op._ops)) + res = {key: SumOperator.make([v1, v2], [selfneg, opneg]) + for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)} return BlockDiagonalOperator(self._domain, res) diff --git a/nifty5/operators/energy_operators.py b/nifty5/operators/energy_operators.py index 6788366f5dd0ec4b33f6d63f13d788810d5d5c95..bf66afcc8080c47837dd9956478d74e0e6c3f141 100644 --- a/nifty5/operators/energy_operators.py +++ b/nifty5/operators/energy_operators.py @@ -318,7 +318,7 @@ class Hamiltonian(EnergyOperator): class AveragedEnergy(EnergyOperator): - """Computes Kullbach-Leibler (KL) divergence or Gibbs free energies. + """Computes Kullback-Leibler (KL) divergence or Gibbs free energies. A sample-averaged energy, e.g. an Hamiltonian, approximates the relevant part of a KL to be used in Variational Bayes inference if the samples are diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 3b2c8a102bb8992576bf7134d9056dd35b23fa7f..3c1552a41d3fec62c2a5345e8456b433bdba87d3 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -16,10 +16,10 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. import numpy as np -from ..utilities import NiftyMetaBase, indent +from ..utilities import NiftyMeta, indent -class Operator(NiftyMetaBase()): +class Operator(metaclass=NiftyMeta): """Transforms values defined on one domain into values defined on another domain, and can also provide the Jacobian. """ diff --git a/nifty5/probing.py b/nifty5/probing.py index 7dfef05f15494fa220a502c5ffe68ea016c6c08f..4675ff4a2847de45bd3e8553975546fa7a394219 100644 --- a/nifty5/probing.py +++ b/nifty5/probing.py @@ -16,6 +16,8 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. from .field import Field +from .operators.endomorphic_operator import EndomorphicOperator +from .operators.operator import Operator class StatCalculator(object): @@ -69,6 +71,29 @@ class StatCalculator(object): def probe_with_posterior_samples(op, post_op, nprobes): + '''FIXME + + Parameters + ---------- + op : EndomorphicOperator + FIXME + post_op : Operator + FIXME + nprobes : int + Number of samples which shall be drawn. + + Returns + ------- + List of Field + List of two fields: the mean and the variance. + ''' + if not isinstance(op, EndomorphicOperator): + raise TypeError + if post_op is not None: + if not isinstance(post_op, Operator): + raise TypeError + if post_op.domain is not op.target: + raise ValueError sc = StatCalculator() for i in range(nprobes): if post_op is None: @@ -82,6 +107,28 @@ def probe_with_posterior_samples(op, post_op, nprobes): def probe_diagonal(op, nprobes, random_type="pm1"): + '''Probes the diagonal of an endomorphic operator. + + The operator is called on a user-specified number of randomly generated + input vectors :math:`v_i`, producing :math:`r_i`. The estimated diagonal + is the mean of :math:`r_i^\dagger v_i`. + + Parameters + ---------- + op: EndomorphicOperator + The operator to be probed. + nprobes: int + The number of probes to be used. + random_type: str + The kind of random number distribution to be used for the probing. + The default value `pm1` causes the probing vector to be randomly + filled with values of +1 and -1. + + Returns + ------- + Field + The estimated diagonal. + ''' sc = StatCalculator() for i in range(nprobes): input = Field.from_random(random_type, op.domain) diff --git a/nifty5/sugar.py b/nifty5/sugar.py index 68484a326e56b8b1ebb4e9a177c2b88d674e1d0c..7e1342b3b5f073b266c9d1f4630a63ecc72af10b 100644 --- a/nifty5/sugar.py +++ b/nifty5/sugar.py @@ -363,7 +363,7 @@ def makeOp(input): return DiagonalOperator(input) if isinstance(input, MultiField): return BlockDiagonalOperator( - input.domain, tuple(makeOp(val) for val in input.values())) + input.domain, {key: makeOp(val) for key, val in enumerate(input)}) raise NotImplementedError diff --git a/nifty5/utilities.py b/nifty5/utilities.py index 15e32f45d33f873e80315d98ba82f469ed7f0aa0..899f9989bdd494f1b689b2b68dbd8d467bb1ddda 100644 --- a/nifty5/utilities.py +++ b/nifty5/utilities.py @@ -20,10 +20,9 @@ from itertools import product from functools import reduce import numpy as np -from future.utils import with_metaclass __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space", - "memo", "NiftyMetaBase", "my_sum", "my_lincomb_simple", + "memo", "NiftyMeta", "my_sum", "my_lincomb_simple", "my_lincomb", "indent", "my_product", "frozendict", "special_add_at", "iscomplextype"] @@ -178,10 +177,6 @@ class NiftyMeta(_DocStringInheritor): pass -def NiftyMetaBase(): - return with_metaclass(NiftyMeta, type('NewBase', (object,), {})) - - class frozendict(collections.abc.Mapping): """ An immutable wrapper around dictionaries that implements the complete diff --git a/setup.py b/setup.py index b86435dde2a78a4b75e2ff2a00223bf5659b0f53..79d6f5542966fd5fff64d761ba0dc5de2d910492 100644 --- a/setup.py +++ b/setup.py @@ -39,8 +39,8 @@ setup(name="nifty5", packages=find_packages(include=["nifty5", "nifty5.*"]), zip_safe=True, license="GPLv3", - setup_requires=['future', 'scipy'], - install_requires=['future', 'scipy', 'pyfftw>=0.10.4'], + setup_requires=['scipy'], + install_requires=['scipy'], classifiers=[ "Development Status :: 4 - Beta", "Topic :: Utilities", diff --git a/test/test_multi_field.py b/test/test_multi_field.py index a316b3604421558ffb1c53a3c59b3fae9e7b2fcf..6990c038b08b137d5c1e986a5ff2c5e032f6fc15 100644 --- a/test/test_multi_field.py +++ b/test/test_multi_field.py @@ -43,7 +43,8 @@ def test_dataconv(): def test_blockdiagonal(): - op = ift.BlockDiagonalOperator(dom, (ift.ScalingOperator(20., dom["d1"]),)) + op = ift.BlockDiagonalOperator( + dom, {"d1": ift.ScalingOperator(20., dom["d1"])}) op2 = op(op) ift.extra.consistency_check(op2) assert_equal(type(op2), ift.BlockDiagonalOperator) diff --git a/test/test_operators/test_fft_operator.py b/test/test_operators/test_fft_operator.py index 171f26bfab0ce31629c5a3dcd9560052c4715b16..47b8e6a11722cccca76ef65c742bfd14d87ee826 100644 --- a/test/test_operators/test_fft_operator.py +++ b/test/test_operators/test_fft_operator.py @@ -17,7 +17,7 @@ import numpy as np import pytest -from numpy.testing import assert_allclose +from numpy.testing import assert_, assert_allclose import nifty5 as ift @@ -34,10 +34,22 @@ def _get_rtol(tp): pmp = pytest.mark.parametrize dtype = list2fixture([np.float64, np.float32, np.complex64, np.complex128]) op = list2fixture([ift.HartleyOperator, ift.FFTOperator]) +fftw = list2fixture([False, True]) + + +def test_switch(): + ift.fft.enable_fftw() + assert_(ift.fft._use_fftw is True) + ift.fft.disable_fftw() + assert_(ift.fft._use_fftw is False) + ift.fft.enable_fftw() + assert_(ift.fft._use_fftw is True) @pmp('d', [0.1, 1, 3.7]) -def test_fft1D(d, dtype, op): +def test_fft1D(d, dtype, op, fftw): + if fftw: + ift.fft.enable_fftw() dim1 = 16 tol = _get_rtol(dtype) a = ift.RGSpace(dim1, distances=d) @@ -57,13 +69,16 @@ def test_fft1D(d, dtype, op): domain=a, random_type='normal', std=7, mean=3, dtype=dtype) out = fft.inverse_times(fft.times(inp)) assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol) + ift.fft.disable_fftw() @pmp('dim1', [12, 15]) @pmp('dim2', [9, 12]) @pmp('d1', [0.1, 1, 3.7]) @pmp('d2', [0.4, 1, 2.7]) -def test_fft2D(dim1, dim2, d1, d2, dtype, op): +def test_fft2D(dim1, dim2, d1, d2, dtype, op, fftw): + if fftw: + ift.fft.enable_fftw() tol = _get_rtol(dtype) a = ift.RGSpace([dim1, dim2], distances=[d1, d2]) b = ift.RGSpace( @@ -82,10 +97,13 @@ def test_fft2D(dim1, dim2, d1, d2, dtype, op): domain=a, random_type='normal', std=7, mean=3, dtype=dtype) out = fft.inverse_times(fft.times(inp)) assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol) + ift.fft.disable_fftw() @pmp('index', [0, 1, 2]) -def test_composed_fft(index, dtype, op): +def test_composed_fft(index, dtype, op, fftw): + if fftw: + ift.fft.enable_fftw() tol = _get_rtol(dtype) a = [a1, a2, a3] = [ift.RGSpace((32,)), @@ -97,6 +115,7 @@ def test_composed_fft(index, dtype, op): domain=(a1, a2, a3), random_type='normal', std=7, mean=3, dtype=dtype) out = fft.inverse_times(fft.times(inp)) assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol) + ift.fft.disable_fftw() @pmp('space', [ @@ -104,7 +123,9 @@ def test_composed_fft(index, dtype, op): ift.RGSpace((15, 27), distances=(.7, .33), harmonic=True), ift.RGSpace(73, distances=0.5643) ]) -def test_normalisation(space, dtype, op): +def test_normalisation(space, dtype, op, fftw): + if fftw: + ift.fft.enable_fftw() tol = 10*_get_rtol(dtype) cospace = space.get_default_codomain() fft = op(space, cospace) @@ -117,3 +138,4 @@ def test_normalisation(space, dtype, op): assert_allclose( inp.to_global_data()[zero_idx], out.integrate(), rtol=tol, atol=tol) assert_allclose(out.local_data, out2.local_data, rtol=tol, atol=tol) + ift.fft.disable_fftw()