diff --git a/nifty5/operators/fft_operator.py b/nifty5/operators/fft_operator.py index c96687005c1c39f5dcca19066c37bfdd28616b7f..97be21f116f51832e3028c0b8d21fb4c39fe20d9 100644 --- a/nifty5/operators/fft_operator.py +++ b/nifty5/operators/fft_operator.py @@ -67,41 +67,35 @@ class FFTOperator(LinearOperator): utilities.fft_prep() def apply(self, x, mode): + from pyfftw.interfaces.numpy_fft import fftn, ifftn self._check_input(x, mode) - if np.issubdtype(x.dtype, np.complexfloating): - return (self._apply_cartesian(x.real, mode) + - 1j*self._apply_cartesian(x.imag, mode)) + ncells = x.domain[self._space].size + if x.domain[self._space].harmonic: # harmonic -> position + func = fftn + fct = 1. else: - return self._apply_cartesian(x, mode) - - def _apply_cartesian(self, x, mode): + func = ifftn + fct = ncells axes = x.domain.axes[self._space] tdom = self._tgt(mode) oldax = dobj.distaxis(x.val) if oldax not in axes: # straightforward, no redistribution needed ldat = x.local_data - ldat = utilities.hartley(ldat, axes=axes) + ldat = func(ldat, axes=axes) tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax) elif len(axes) < len(x.shape) or len(axes) == 1: - # we can use one Hartley pass in between the redistributions + # we can use one FFT pass in between the redistributions tmp = dobj.redistribute(x.val, nodist=axes) newax = dobj.distaxis(tmp) ldat = dobj.local_data(tmp) - ldat = utilities.hartley(ldat, axes=axes) + ldat = func(ldat, axes=axes) tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax) tmp = dobj.redistribute(tmp, dist=oldax) - else: # two separate, full FFTs needed - # ideal strategy for the moment would be: - # - do real-to-complex FFT on all local axes - # - fill up array - # - redistribute array - # - do complex-to-complex FFT on remaining axis - # - add re+im - # - redistribute back + else: # two separate FFTs needed rem_axes = tuple(i for i in axes if i != oldax) tmp = x.val ldat = dobj.local_data(tmp) - ldat = utilities.my_fftn_r2c(ldat, axes=rem_axes) + ldat = func(ldat, axes=rem_axes) if oldax != 0: raise ValueError("bad distribution") ldat2 = ldat.reshape((ldat.shape[0], @@ -110,17 +104,16 @@ class FFTOperator(LinearOperator): tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0) tmp = dobj.transpose(tmp) ldat2 = dobj.local_data(tmp) - ldat2 = utilities.my_fftn(ldat2, axes=(1,)) - ldat2 = ldat2.real+ldat2.imag + ldat2 = func(ldat2, axes=(1,)) tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0) tmp = dobj.transpose(tmp) ldat2 = dobj.local_data(tmp).reshape(ldat.shape) tmp = dobj.from_local_data(x.val.shape, ldat2, distaxis=0) Tval = Field(tdom, tmp) if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES): - fct = self._domain[self._space].scalar_dvol + fct *= self._domain[self._space].scalar_dvol else: - fct = self._target[self._space].scalar_dvol + fct *= self._target[self._space].scalar_dvol return Tval if fct == 1 else Tval*fct @property diff --git a/test/test_operators/test_fft_operator.py b/test/test_operators/test_fft_operator.py index 24ec9fe1f74246ed8710388c45d06bf8ed4ec322..783d0a7f9655771eca08d1460ea5970e65547da2 100644 --- a/test/test_operators/test_fft_operator.py +++ b/test/test_operators/test_fft_operator.py @@ -36,14 +36,15 @@ def _get_rtol(tp): class FFTOperatorTests(unittest.TestCase): @expand(product([16, ], [0.1, 1, 3.7], - [np.float64, np.float32, np.complex64, np.complex128])) - def test_fft1D(self, dim1, d, itp): + [np.float64, np.float32, np.complex64, np.complex128], + [ift.HartleyOperator, ift.FFTOperator])) + def test_fft1D(self, dim1, d, itp, op): tol = _get_rtol(itp) a = ift.RGSpace(dim1, distances=d) b = ift.RGSpace(dim1, distances=1./(dim1*d), harmonic=True) np.random.seed(16) - fft = ift.FFTOperator(domain=a, target=b) + fft = op(domain=a, target=b) inp = ift.Field.from_random(domain=a, random_type='normal', std=7, mean=3, dtype=itp) out = fft.inverse_times(fft.times(inp)) @@ -59,14 +60,15 @@ class FFTOperatorTests(unittest.TestCase): @expand(product([12, 15], [9, 12], [0.1, 1, 3.7], [0.4, 1, 2.7], - [np.float64, np.float32, np.complex64, np.complex128])) - def test_fft2D(self, dim1, dim2, d1, d2, itp): + [np.float64, np.float32, np.complex64, np.complex128], + [ift.HartleyOperator, ift.FFTOperator])) + def test_fft2D(self, dim1, dim2, d1, d2, itp, op): tol = _get_rtol(itp) a = ift.RGSpace([dim1, dim2], distances=[d1, d2]) b = ift.RGSpace([dim1, dim2], distances=[1./(dim1*d1), 1./(dim2*d2)], harmonic=True) - fft = ift.FFTOperator(domain=a, target=b) + fft = op(domain=a, target=b) inp = ift.Field.from_random(domain=a, random_type='normal', std=7, mean=3, dtype=itp) out = fft.inverse_times(fft.times(inp)) @@ -81,12 +83,13 @@ class FFTOperatorTests(unittest.TestCase): assert_allclose(inp.local_data, out.local_data, rtol=tol, atol=tol) @expand(product([0, 1, 2], - [np.float64, np.float32, np.complex64, np.complex128])) - def test_composed_fft(self, index, dtype): + [np.float64, np.float32, np.complex64, np.complex128], + [ift.HartleyOperator, ift.FFTOperator])) + def test_composed_fft(self, index, dtype, op): tol = _get_rtol(dtype) a = [a1, a2, a3] = [ift.RGSpace((32,)), ift.RGSpace((4, 4)), ift.RGSpace((5, 6))] - fft = ift.FFTOperator(domain=a, space=index) + fft = op(domain=a, space=index) inp = ift.Field.from_random(domain=(a1, a2, a3), random_type='normal', std=7, mean=3, dtype=dtype) @@ -96,15 +99,16 @@ class FFTOperatorTests(unittest.TestCase): @expand(product([ift.RGSpace(128, distances=3.76, harmonic=True), ift.RGSpace((15, 27), distances=(.7, .33), harmonic=True), ift.RGSpace(73, distances=0.5643)], - [np.float64, np.float32, np.complex64, np.complex128])) - def test_normalisation(self, space, tp): + [np.float64, np.float32, np.complex64, np.complex128], + [ift.HartleyOperator, ift.FFTOperator])) + def test_normalisation(self, space, tp, op): tol = 10 * _get_rtol(tp) cospace = space.get_default_codomain() - fft = ift.FFTOperator(space, cospace) + fft = op(space, cospace) inp = ift.Field.from_random(domain=space, random_type='normal', std=1, mean=2, dtype=tp) out = fft.times(inp) - fft2 = ift.FFTOperator(cospace, space) + fft2 = op(cospace, space) out2 = fft2.inverse_times(inp) zero_idx = tuple([0]*len(space.shape)) assert_allclose(inp.to_global_data()[zero_idx], out.integrate(),