Commit afa17f5a authored by Jait Dixit's avatar Jait Dixit
Browse files

Fix FFTW data distribution problem

parent b19d5d58
Pipeline #4807 skipped
......@@ -392,6 +392,17 @@ class FFTW(FFT):
temp_val = np.empty_like(local_val)
inp = local_val[slice_list]
# This is in order to make FFTW behave properly
# when slicing input over MPI ranks when the
# input is 1-dimensional. The default behaviour
# is to slice so that it's byte-aligned, which
# doesn't play well with multi-dimensional data
# sliced for FFTW.
original_shape = None
if len(inp.shape) == 1:
original_shape = inp.shape
inp = inp.reshape(inp.shape[0], 1)
result = self._mpi_transform(inp,
current_info,
axes, domain,
......@@ -400,6 +411,11 @@ class FFTW(FFT):
if slice_list == [slice(None, None)]:
temp_val = result
else:
# Reverting to the original shape i.e. before
# the input was augmented with 1 to make
# FFTW behave properly.
if original_shape is not None:
result = result.reshape(original_shape)
temp_val[slice_list] = result
return_val.set_local_data(data=temp_val, copy=False)
......@@ -522,8 +538,15 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, fftw_context, **kwargs):
super(FFTWMPITransfromInfo, self).__init__(domain, codomain,
fftw_context, **kwargs)
# When the domain is 1-dimensional, reshape it so that it can
# accept input which is also augmented by 1.
if len(domain.get_shape()) == 1:
shape = (domain.get_shape()[0], 1)
else:
shape = domain.get_shape()
self._plan = pyfftw.create_mpi_plan(
input_shape=domain.get_shape(),
input_shape=shape,
input_dtype='complex128',
output_dtype='complex128',
direction='FFTW_FORWARD' if codomain.harmonic else 'FFTW_BACKWARD',
......
......@@ -107,5 +107,27 @@ class TestFFTWTransform(unittest.TestCase):
), 'results do not match numpy.fft.fftn'
)
def test_mpi_zero_equal(self):
x = nt.rg_space(8, fft_module='pyfftw')
a = np.ones((8, 8)) + 1j*np.zeros((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='equal')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0,)),
np.fft.fftn(a, axes=(0,))
), 'results do not match numpy.fft.fftn'
)
def test_mpi_zero_not(self):
x = nt.rg_space(8, fft_module='pyfftw')
a = np.ones((8, 8)) + 1j*np.zeros((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='not')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0,)),
np.fft.fftn(a, axes=(0,))
), 'results do not match numpy.fft.fftn'
)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment