Commit 175a8a95 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

FFTOperator seems to be working

parent f099ab83
......@@ -67,41 +67,35 @@ class FFTOperator(LinearOperator):
utilities.fft_prep()
def apply(self, x, mode):
from pyfftw.interfaces.numpy_fft import fftn, ifftn
self._check_input(x, mode)
if np.issubdtype(x.dtype, np.complexfloating):
return (self._apply_cartesian(x.real, mode) +
1j*self._apply_cartesian(x.imag, mode))
ncells = x.domain[self._space].size
if x.domain[self._space].harmonic: # harmonic -> position
func = fftn
fct = 1.
else:
return self._apply_cartesian(x, mode)
def _apply_cartesian(self, x, mode):
func = ifftn
fct = ncells
axes = x.domain.axes[self._space]
tdom = self._tgt(mode)
oldax = dobj.distaxis(x.val)
if oldax not in axes: # straightforward, no redistribution needed
ldat = x.local_data
ldat = utilities.hartley(ldat, axes=axes)
ldat = func(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
elif len(axes) < len(x.shape) or len(axes) == 1:
# we can use one Hartley pass in between the redistributions
# we can use one FFT pass in between the redistributions
tmp = dobj.redistribute(x.val, nodist=axes)
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = utilities.hartley(ldat, axes=axes)
ldat = func(ldat, axes=axes)
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
else: # two separate, full FFTs needed
# ideal strategy for the moment would be:
# - do real-to-complex FFT on all local axes
# - fill up array
# - redistribute array
# - do complex-to-complex FFT on remaining axis
# - add re+im
# - redistribute back
else: # two separate FFTs needed
rem_axes = tuple(i for i in axes if i != oldax)
tmp = x.val
ldat = dobj.local_data(tmp)
ldat = utilities.my_fftn_r2c(ldat, axes=rem_axes)
ldat = func(ldat, axes=rem_axes)
if oldax != 0:
raise ValueError("bad distribution")
ldat2 = ldat.reshape((ldat.shape[0],
......@@ -110,17 +104,16 @@ class FFTOperator(LinearOperator):
tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp)
ldat2 = utilities.my_fftn(ldat2, axes=(1,))
ldat2 = ldat2.real+ldat2.imag
ldat2 = func(ldat2, axes=(1,))
tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp).reshape(ldat.shape)
tmp = dobj.from_local_data(x.val.shape, ldat2, distaxis=0)
Tval = Field(tdom, tmp)
if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES):
fct = self._domain[self._space].scalar_dvol
fct *= self._domain[self._space].scalar_dvol
else:
fct = self._target[self._space].scalar_dvol
fct *= self._target[self._space].scalar_dvol
return Tval if fct == 1 else Tval*fct
@property
......
......@@ -36,14 +36,15 @@ def _get_rtol(tp):
class FFTOperatorTests(unittest.TestCase):
@expand(product([16, ], [0.1, 1, 3.7],
[np.float64, np.float32, np.complex64, np.complex128]))
def test_fft1D(self, dim1, d, itp):
[np.float64, np.float32, np.complex64, np.complex128],
[ift.HartleyOperator, ift.FFTOperator]))
def test_fft1D(self, dim1, d, itp, op):
tol = _get_rtol(itp)
a = ift.RGSpace(dim1, distances=d)
b = ift.RGSpace(dim1, distances=1./(dim1*d), harmonic=True)
np.random.seed(16)
fft = ift.FFTOperator(domain=a, target=b)
fft = op(domain=a, target=b)
inp = ift.Field.from_random(domain=a, random_type='normal',
std=7, mean=3, dtype=itp)
out = fft.inverse_times(fft.times(inp))
......@@ -59,14 +60,15 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([12, 15], [9, 12], [0.1, 1, 3.7],
[0.4, 1, 2.7],
[np.float64, np.float32, np.complex64, np.complex128]))
def test_fft2D(self, dim1, dim2, d1, d2, itp):
[np.float64, np.float32, np.complex64, np.complex128],
[ift.HartleyOperator, ift.FFTOperator]))
def test_fft2D(self, dim1, dim2, d1, d2, itp, op):
tol = _get_rtol(itp)
a = ift.RGSpace([dim1, dim2], distances=[d1, d2])
b = ift.RGSpace([dim1, dim2],
distances=[1./(dim1*d1), 1./(dim2*d2)], harmonic=True)
fft = ift.FFTOperator(domain=a, target=b)
fft = op(domain=a, target=b)
inp = ift.Field.from_random(domain=a, random_type='normal',
std=7, mean=3, dtype=itp)
out = fft.inverse_times(fft.times(inp))
......@@ -81,12 +83,13 @@ class FFTOperatorTests(unittest.TestCase):
assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol)
@expand(product([0, 1, 2],
[np.float64, np.float32, np.complex64, np.complex128]))
def test_composed_fft(self, index, dtype):
[np.float64, np.float32, np.complex64, np.complex128],
[ift.HartleyOperator, ift.FFTOperator]))
def test_composed_fft(self, index, dtype, op):
tol = _get_rtol(dtype)
a = [a1, a2, a3] = [ift.RGSpace((32,)), ift.RGSpace((4, 4)),
ift.RGSpace((5, 6))]
fft = ift.FFTOperator(domain=a, space=index)
fft = op(domain=a, space=index)
inp = ift.Field.from_random(domain=(a1, a2, a3), random_type='normal',
std=7, mean=3, dtype=dtype)
......@@ -96,15 +99,16 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([ift.RGSpace(128, distances=3.76, harmonic=True),
ift.RGSpace((15, 27), distances=(.7, .33), harmonic=True),
ift.RGSpace(73, distances=0.5643)],
[np.float64, np.float32, np.complex64, np.complex128]))
def test_normalisation(self, space, tp):
[np.float64, np.float32, np.complex64, np.complex128],
[ift.HartleyOperator, ift.FFTOperator]))
def test_normalisation(self, space, tp, op):
tol = 10 * _get_rtol(tp)
cospace = space.get_default_codomain()
fft = ift.FFTOperator(space, cospace)
fft = op(space, cospace)
inp = ift.Field.from_random(domain=space, random_type='normal',
std=1, mean=2, dtype=tp)
out = fft.times(inp)
fft2 = ift.FFTOperator(cospace, space)
fft2 = op(cospace, space)
out2 = fft2.inverse_times(inp)
zero_idx = tuple([0]*len(space.shape))
assert_allclose(inp.to_global_data()[zero_idx], out.integrate(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment