Commit 71cc7162 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'isolate_fft' into 'NIFTy_5'

make switching between FFT libraries easier

See merge request ift/nifty-dev!114
parents b95589ff be1e365a
from __future__ import absolute_import, division, print_function
from .utilities import iscomplextype
import numpy as np
_use_fftw = True
if _use_fftw:
import pyfftw
from pyfftw.interfaces.numpy_fft import fftn, rfftn, ifftn
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(1000.)
# Optional extra arguments for the FFT calls
# _fft_extra_args = {}
# if exact reproducibility is needed, use this:
import os
nthreads = int(os.getenv("OMP_NUM_THREADS", "1"))
_fft_extra_args = dict(planner_effort='FFTW_ESTIMATE', threads=nthreads)
else:
from numpy.fft import fftn, rfftn, ifftn
_fft_extra_args={}
def hartley(a, axes=None):
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis < len(a.shape) for axis in axes):
raise ValueError("Provided axes do not match array shape")
if iscomplextype(a.dtype):
raise TypeError("Hartley transform requires real-valued arrays.")
tmp = rfftn(a, axes=axes, **_fft_extra_args)
def _fill_array(tmp, res, axes):
if axes is None:
axes = tuple(range(tmp.ndim))
lastaxis = axes[-1]
ntmplast = tmp.shape[lastaxis]
slice1 = (slice(None),)*lastaxis + (slice(0, ntmplast),)
np.add(tmp.real, tmp.imag, out=res[slice1])
def _fill_upper_half(tmp, res, axes):
lastaxis = axes[-1]
nlast = res.shape[lastaxis]
ntmplast = tmp.shape[lastaxis]
nrem = nlast - ntmplast
slice1 = [slice(None)]*lastaxis + [slice(ntmplast, None)]
slice2 = [slice(None)]*lastaxis + [slice(nrem, 0, -1)]
for i in axes[:-1]:
slice1[i] = slice(1, None)
slice2[i] = slice(None, 0, -1)
slice1 = tuple(slice1)
slice2 = tuple(slice2)
np.subtract(tmp[slice2].real, tmp[slice2].imag, out=res[slice1])
for i, ax in enumerate(axes[:-1]):
dim1 = (slice(None),)*ax + (slice(0, 1),)
axes2 = axes[:i] + axes[i+1:]
_fill_upper_half(tmp[dim1], res[dim1], axes2)
_fill_upper_half(tmp, res, axes)
return res
return _fill_array(tmp, np.empty_like(a), axes)
# Do a real-to-complex forward FFT and return the _full_ output array
def my_fftn_r2c(a, axes=None):
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis < len(a.shape) for axis in axes):
raise ValueError("Provided axes do not match array shape")
if iscomplextype(a.dtype):
raise TypeError("Transform requires real-valued input arrays.")
tmp = rfftn(a, axes=axes, **_fft_extra_args)
def _fill_complex_array(tmp, res, axes):
if axes is None:
axes = tuple(range(tmp.ndim))
lastaxis = axes[-1]
ntmplast = tmp.shape[lastaxis]
slice1 = [slice(None)]*lastaxis + [slice(0, ntmplast)]
res[tuple(slice1)] = tmp
def _fill_upper_half_complex(tmp, res, axes):
lastaxis = axes[-1]
nlast = res.shape[lastaxis]
ntmplast = tmp.shape[lastaxis]
nrem = nlast - ntmplast
slice1 = [slice(None)]*lastaxis + [slice(ntmplast, None)]
slice2 = [slice(None)]*lastaxis + [slice(nrem, 0, -1)]
for i in axes[:-1]:
slice1[i] = slice(1, None)
slice2[i] = slice(None, 0, -1)
# np.conjugate(tmp[slice2], out=res[slice1])
res[tuple(slice1)] = np.conjugate(tmp[tuple(slice2)])
for i, ax in enumerate(axes[:-1]):
dim1 = tuple([slice(None)]*ax + [slice(0, 1)])
axes2 = axes[:i] + axes[i+1:]
_fill_upper_half_complex(tmp[dim1], res[dim1], axes2)
_fill_upper_half_complex(tmp, res, axes)
return res
return _fill_complex_array(tmp, np.empty_like(a, dtype=tmp.dtype), axes)
def my_fftn(a, axes=None):
return fftn(a, axes=axes, **_fft_extra_args)
......@@ -20,7 +20,7 @@ from __future__ import absolute_import, division, print_function
import numpy as np
from .. import dobj, utilities
from .. import dobj, utilities, fft
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.gl_space import GLSpace
......@@ -74,8 +74,6 @@ class FFTOperator(LinearOperator):
adom.check_codomain(target)
target.check_codomain(adom)
utilities.fft_prep()
def apply(self, x, mode):
from pyfftw.interfaces.numpy_fft import fftn, ifftn
self._check_input(x, mode)
......@@ -174,8 +172,6 @@ class HartleyOperator(LinearOperator):
adom.check_codomain(target)
target.check_codomain(adom)
utilities.fft_prep()
def apply(self, x, mode):
self._check_input(x, mode)
if utilities.iscomplextype(x.dtype):
......@@ -190,14 +186,14 @@ class HartleyOperator(LinearOperator):
oldax = dobj.distaxis(x.val)
if oldax not in axes: # straightforward, no redistribution needed
ldat = x.local_data
ldat = utilities.hartley(ldat, axes=axes)
ldat = fft.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
elif len(axes) < len(x.shape) or len(axes) == 1:
# we can use one Hartley pass in between the redistributions
tmp = dobj.redistribute(x.val, nodist=axes)
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = utilities.hartley(ldat, axes=axes)
ldat = fft.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
else: # two separate, full FFTs needed
......@@ -211,7 +207,7 @@ class HartleyOperator(LinearOperator):
rem_axes = tuple(i for i in axes if i != oldax)
tmp = x.val
ldat = dobj.local_data(tmp)
ldat = utilities.my_fftn_r2c(ldat, axes=rem_axes)
ldat = fft.my_fftn_r2c(ldat, axes=rem_axes)
if oldax != 0:
raise ValueError("bad distribution")
ldat2 = ldat.reshape((ldat.shape[0],
......@@ -220,7 +216,7 @@ class HartleyOperator(LinearOperator):
tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp)
ldat2 = utilities.my_fftn(ldat2, axes=(1,))
ldat2 = fft.my_fftn(ldat2, axes=(1,))
ldat2 = ldat2.real+ldat2.imag
tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
......
......@@ -20,9 +20,10 @@ from __future__ import absolute_import, division, print_function
from .. import dobj
from ..compat import *
from .. import fft
from ..domain_tuple import DomainTuple
from ..field import Field
from ..utilities import hartley, infer_space
from ..utilities import infer_space
from .linear_operator import LinearOperator
......@@ -69,5 +70,5 @@ class QHTOperator(LinearOperator):
for i in rng:
sl = (slice(None),)*i + (slice(1, None),)
v, tmp = dobj.ensure_not_distributed(v, (i,))
tmp[sl] = hartley(tmp[sl], axes=(i,))
tmp[sl] = fft.hartley(tmp[sl], axes=(i,))
return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
......@@ -22,15 +22,13 @@ import collections
from itertools import product
import numpy as np
import pyfftw
from future.utils import with_metaclass
from pyfftw.interfaces.numpy_fft import fftn, rfftn
from .compat import *
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c",
"my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb", "indent",
"memo", "NiftyMetaBase", "my_sum", "my_lincomb_simple",
"my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype"]
......@@ -187,117 +185,6 @@ def NiftyMetaBase():
return with_metaclass(NiftyMeta, type('NewBase', (object,), {}))
def nthreads():
if nthreads._val is None:
import os
nthreads._val = int(os.getenv("OMP_NUM_THREADS", "1"))
return nthreads._val
nthreads._val = None
# Optional extra arguments for the FFT calls
# _fft_extra_args = {}
# if exact reproducibility is needed, use this:
_fft_extra_args = dict(planner_effort='FFTW_ESTIMATE')
def fft_prep():
if not fft_prep._initialized:
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(1000.)
fft_prep._initialized = True
fft_prep._initialized = False
def hartley(a, axes=None):
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis < len(a.shape) for axis in axes):
raise ValueError("Provided axes do not match array shape")
if iscomplextype(a.dtype):
raise TypeError("Hartley transform requires real-valued arrays.")
tmp = rfftn(a, axes=axes, threads=nthreads(), **_fft_extra_args)
def _fill_array(tmp, res, axes):
if axes is None:
axes = tuple(range(tmp.ndim))
lastaxis = axes[-1]
ntmplast = tmp.shape[lastaxis]
slice1 = (slice(None),)*lastaxis + (slice(0, ntmplast),)
np.add(tmp.real, tmp.imag, out=res[slice1])
def _fill_upper_half(tmp, res, axes):
lastaxis = axes[-1]
nlast = res.shape[lastaxis]
ntmplast = tmp.shape[lastaxis]
nrem = nlast - ntmplast
slice1 = [slice(None)]*lastaxis + [slice(ntmplast, None)]
slice2 = [slice(None)]*lastaxis + [slice(nrem, 0, -1)]
for i in axes[:-1]:
slice1[i] = slice(1, None)
slice2[i] = slice(None, 0, -1)
slice1 = tuple(slice1)
slice2 = tuple(slice2)
np.subtract(tmp[slice2].real, tmp[slice2].imag, out=res[slice1])
for i, ax in enumerate(axes[:-1]):
dim1 = (slice(None),)*ax + (slice(0, 1),)
axes2 = axes[:i] + axes[i+1:]
_fill_upper_half(tmp[dim1], res[dim1], axes2)
_fill_upper_half(tmp, res, axes)
return res
return _fill_array(tmp, np.empty_like(a), axes)
# Do a real-to-complex forward FFT and return the _full_ output array
def my_fftn_r2c(a, axes=None):
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis < len(a.shape) for axis in axes):
raise ValueError("Provided axes do not match array shape")
if iscomplextype(a.dtype):
raise TypeError("Transform requires real-valued input arrays.")
tmp = rfftn(a, axes=axes, threads=nthreads(), **_fft_extra_args)
def _fill_complex_array(tmp, res, axes):
if axes is None:
axes = tuple(range(tmp.ndim))
lastaxis = axes[-1]
ntmplast = tmp.shape[lastaxis]
slice1 = [slice(None)]*lastaxis + [slice(0, ntmplast)]
res[tuple(slice1)] = tmp
def _fill_upper_half_complex(tmp, res, axes):
lastaxis = axes[-1]
nlast = res.shape[lastaxis]
ntmplast = tmp.shape[lastaxis]
nrem = nlast - ntmplast
slice1 = [slice(None)]*lastaxis + [slice(ntmplast, None)]
slice2 = [slice(None)]*lastaxis + [slice(nrem, 0, -1)]
for i in axes[:-1]:
slice1[i] = slice(1, None)
slice2[i] = slice(None, 0, -1)
# np.conjugate(tmp[slice2], out=res[slice1])
res[tuple(slice1)] = np.conjugate(tmp[tuple(slice2)])
for i, ax in enumerate(axes[:-1]):
dim1 = tuple([slice(None)]*ax + [slice(0, 1)])
axes2 = axes[:i] + axes[i+1:]
_fill_upper_half_complex(tmp[dim1], res[dim1], axes2)
_fill_upper_half_complex(tmp, res, axes)
return res
return _fill_complex_array(tmp, np.empty_like(a, dtype=tmp.dtype), axes)
def my_fftn(a, axes=None):
return fftn(a, axes=axes, **_fft_extra_args)
class frozendict(collections.Mapping):
"""
An immutable wrapper around dictionaries that implements the complete
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment