diff --git a/nifty/operators/fft_operator/transformations/rg_transforms.py b/nifty/operators/fft_operator/transformations/rg_transforms.py index 9b4866420223f52cbe82549e7894596bfa2f1049..4ae16500976ed61b7414d1f8369966b2194c4c77 100644 --- a/nifty/operators/fft_operator/transformations/rg_transforms.py +++ b/nifty/operators/fft_operator/transformations/rg_transforms.py @@ -373,6 +373,10 @@ class MPIFFT(Transform): original_shape = inp.shape inp = inp.reshape(inp.shape[0], 1) axes = (0, ) + if original_shape[0]%2!=0: + raise AttributeError("MPI-FFTs of onedimensional arrays " + "with odd length are currently not supported due to a " + "bug in FFTW. Please use a grid with even length.") if current_info is None: transform_shape = list(inp.shape) diff --git a/test/test_operators/test_fft_operator.py b/test/test_operators/test_fft_operator.py index 2f31c3272e230f967f995dfa7de3dc8df85546d0..18bb5c7f52e94a35cbc76ab52b482db6e627b76f 100644 --- a/test/test_operators/test_fft_operator.py +++ b/test/test_operators/test_fft_operator.py @@ -63,7 +63,7 @@ class FFTOperatorTests(unittest.TestCase): assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.) @expand(product(["numpy", "fftw", "fftw_mpi"], - [10, 11], [False, True], [False, True], + [12, ], [False, True], [False, True], [0.1, 1, 3.7], [np.float64, np.complex128, np.float32, np.complex64])) def test_fft1D(self, module, dim1, zc1, zc2, d, itp): @@ -86,7 +86,7 @@ class FFTOperatorTests(unittest.TestCase): rtol=tol, atol=tol) @expand(product(["numpy", "fftw", "fftw_mpi"], - [10, 11], [9, 12], [False, True], + [12, 15], [9, 12], [False, True], [False, True], [False, True], [False, True], [0.1, 1, 3.7], [0.4, 1, 2.7], [np.float64, np.complex128, np.float32, np.complex64])) diff --git a/test/test_serialization.py b/test/test_serialization.py index 27be6fbc243819534a5c4330efdaa29ece76b55c..48fa1bc12f4f9366ce70bc043293ddac8567ea1a 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -39,4 +39,7 @@ class SpaceSerializationTests(unittest.TestCase): repo.commit() assert_equal(space, repo.get('space')) assert_equal(field, repo.get('field')) - os.remove('test.h5') + try: + os.remove('test.h5') + except OSError: + pass