Commit b3de9ca2 authored by Martin Reinecke's avatar Martin Reinecke

reintroduce FFTW

parent f2d9b197
Pipeline #47254 passed with stages
in 16 minutes and 48 seconds
......@@ -10,8 +10,9 @@ RUN apt-get update && apt-get install -y \
# Testing dependencies
python3-pytest-cov jupyter \
# Optional NIFTy dependencies
python3-mpi4py python3-matplotlib python3-pynfft \
libfftw3-dev python3-mpi4py python3-matplotlib python3-pynfft \
# more optional NIFTy dependencies
&& pip3 install pyfftw \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/mtr/pypocketfft.git \
......
......@@ -50,6 +50,7 @@ Installation
- [pypocketfft](https://gitlab.mpcdf.mpg.de/mtr/pypocketfft)
Optional dependencies:
- [pyFFTW](https://pypi.python.org/pypi/pyFFTW) for faster Fourier transforms
- [pyHealpix](https://gitlab.mpcdf.mpg.de/ift/pyHealpix) (for harmonic
transforms involving domains on the sphere)
- [nifty_gridder](https://gitlab.mpcdf.mpg.de/ift/nifty_gridder) (for radio
......@@ -79,6 +80,23 @@ Plotting support is added via:
sudo apt-get install python3-matplotlib
NIFTy uses pypocketfft by default. For large problems FFTW may be
used because of its higher performance. It can be installed via:
sudo apt-get install libfftw3-dev
pip3 install --user pyfftw
To enable FFTW usage in NIFTy, call
nifty5.fft.enable_fftw()
at the beginning of your code.
(Note: If you encounter problems related to `pyFFTW`, make sure that you are
using a pip-installed `pyFFTW` package. Unfortunately, some distributions are
shipping an incorrectly configured `pyFFTW` package, which does not cooperate
with the installed `FFTW3` libraries.)
Support for spherical harmonic transforms is added via:
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
......
......@@ -31,7 +31,7 @@ import numpy as np
import nifty5 as ift
ift.fft.enable_fftw()
def random_los(n_los):
starts = list(np.random.uniform(0, 1, (n_los, 2)).T)
ends = list(np.random.uniform(0, 1, (n_los, 2)).T)
......
......@@ -15,6 +15,23 @@ Plotting support is added via::
sudo apt-get install python3-matplotlib
NIFTy uses pypocketfft by default. For large problems FFTW may be
used because of its higher performance. It can be installed via::
sudo apt-get install libfftw3-dev
pip3 install --user pyfftw
To enable FFTW usage in NIFTy, call::
nifty5.fft.enable_fftw()
at the beginning of your code.
(Note: If you encounter problems related to `pyFFTW`, make sure that you are
using a pip-installed `pyFFTW` package. Unfortunately, some distributions are
shipping an incorrectly configured `pyFFTW` package, which does not cooperate
with the installed `FFTW3` libraries.)
Support for spherical harmonic transforms is added via::
pip3 install --user git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
......
......@@ -20,6 +20,37 @@ import numpy as np
import pypocketfft
_use_fftw = False
_fftw_prepped = False
_fft_extra_args = {}
def enable_fftw():
global _use_fftw
_use_fftw = True
def disable_fftw():
global _use_fftw
_use_fftw = False
def _init_pyfftw():
global _fft_extra_args, _fftw_prepped
if not _fftw_prepped:
import pyfftw
from pyfftw.interfaces.numpy_fft import fftn, rfftn, ifftn
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(1000.)
# Optional extra arguments for the FFT calls
# if exact reproducibility is needed,
# set "planner_effort" to "FFTW_ESTIMATE"
import os
nthreads = int(os.getenv("OMP_NUM_THREADS", "1"))
_fft_extra_args = dict(planner_effort='FFTW_ESTIMATE',
threads=nthreads)
_fftw_prepped = True
# FIXME this should not be necessary ... no one should call a complex FFT
# with a float array.
def _make_complex(a):
......@@ -33,14 +64,26 @@ def _make_complex(a):
def fftn(a, axes=None):
if _use_fftw:
from pyfftw.interfaces.numpy_fft import fftn
_init_pyfftw()
return fftn(a, axes=axes, **_fft_extra_args)
return pypocketfft.fftn(_make_complex(a), axes=axes)
def rfftn(a, axes=None):
if _use_fftw:
from pyfftw.interfaces.numpy_fft import rfftn
_init_pyfftw()
return rfftn(a, axes=axes, **_fft_extra_args)
return pypocketfft.rfftn(a, axes=axes)
def ifftn(a, axes=None):
if _use_fftw:
from pyfftw.interfaces.numpy_fft import ifftn
_init_pyfftw()
return ifftn(a, axes=axes, **_fft_extra_args)
# FIXME this is a temporary fix and can be done more elegantly
if axes is None:
fct = 1./a.size
......@@ -50,6 +93,11 @@ def ifftn(a, axes=None):
def hartley(a, axes=None):
if _use_fftw:
from pyfftw.interfaces.numpy_fft import rfftn
_init_pyfftw()
tmp = rfftn(a, axes=axes, **_fft_extra_args)
return pypocketfft.complex2hartley(a, tmp, axes)
return pypocketfft.hartley2(a, axes=axes)
......
......@@ -34,9 +34,22 @@ def _get_rtol(tp):
pmp = pytest.mark.parametrize
dtype = list2fixture([np.float64, np.float32, np.complex64, np.complex128])
op = list2fixture([ift.HartleyOperator, ift.FFTOperator])
fftw = list2fixture([False, True])
def test_switch():
ift.fft.enable_fftw()
assert_(ift.fft._use_fftw is True)
ift.fft.disable_fftw()
assert_(ift.fft._use_fftw is False)
ift.fft.enable_fftw()
assert_(ift.fft._use_fftw is True)
@pmp('d', [0.1, 1, 3.7])
def test_fft1D(d, dtype, op):
def test_fft1D(d, dtype, op, fftw):
if fftw:
ift.fft.enable_fftw()
dim1 = 16
tol = _get_rtol(dtype)
a = ift.RGSpace(dim1, distances=d)
......@@ -56,13 +69,16 @@ def test_fft1D(d, dtype, op):
domain=a, random_type='normal', std=7, mean=3, dtype=dtype)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
@pmp('dim1', [12, 15])
@pmp('dim2', [9, 12])
@pmp('d1', [0.1, 1, 3.7])
@pmp('d2', [0.4, 1, 2.7])
def test_fft2D(dim1, dim2, d1, d2, dtype, op):
def test_fft2D(dim1, dim2, d1, d2, dtype, op, fftw):
if fftw:
ift.fft.enable_fftw()
tol = _get_rtol(dtype)
a = ift.RGSpace([dim1, dim2], distances=[d1, d2])
b = ift.RGSpace(
......@@ -81,10 +97,13 @@ def test_fft2D(dim1, dim2, d1, d2, dtype, op):
domain=a, random_type='normal', std=7, mean=3, dtype=dtype)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
@pmp('index', [0, 1, 2])
def test_composed_fft(index, dtype, op):
def test_composed_fft(index, dtype, op, fftw):
if fftw:
ift.fft.enable_fftw()
tol = _get_rtol(dtype)
a = [a1, a2,
a3] = [ift.RGSpace((32,)),
......@@ -96,6 +115,7 @@ def test_composed_fft(index, dtype, op):
domain=(a1, a2, a3), random_type='normal', std=7, mean=3, dtype=dtype)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
@pmp('space', [
......@@ -103,7 +123,9 @@ def test_composed_fft(index, dtype, op):
ift.RGSpace((15, 27), distances=(.7, .33), harmonic=True),
ift.RGSpace(73, distances=0.5643)
])
def test_normalisation(space, dtype, op):
def test_normalisation(space, dtype, op, fftw):
if fftw:
ift.fft.enable_fftw()
tol = 10*_get_rtol(dtype)
cospace = space.get_default_codomain()
fft = op(space, cospace)
......@@ -116,3 +138,4 @@ def test_normalisation(space, dtype, op):
assert_allclose(
inp.to_global_data()[zero_idx], out.integrate(), rtol=tol, atol=tol)
assert_allclose(out.local_data, out2.local_data, rtol=tol, atol=tol)
ift.fft.disable_fftw()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment