diff --git a/nifty/operators/fft_operator/transformations/rg_transforms.py b/nifty/operators/fft_operator/transformations/rg_transforms.py
index 4ae16500976ed61b7414d1f8369966b2194c4c77..e56f309d3a18e62595f567ac7e46f7b335e7e092 100644
--- a/nifty/operators/fft_operator/transformations/rg_transforms.py
+++ b/nifty/operators/fft_operator/transformations/rg_transforms.py
@@ -373,10 +373,16 @@ class MPIFFT(Transform):
                 original_shape = inp.shape
                 inp = inp.reshape(inp.shape[0], 1)
                 axes = (0, )
-                if original_shape[0]%2!=0:
-                    raise AttributeError("MPI-FFTs of onedimensional arrays "
-                        "with odd length are currently not supported due to a "
-                        "bug in FFTW. Please use a grid with even length.")
+
+                if axes is None:
+                    global_shape = val.shape
+                else:
+                    global_shape = (val.shape[axes[0]], )
+                if global_shape[0] % 2 != 0:
+                    raise AttributeError(
+                        "MPI-FFTs of onedimensional arrays with odd length "
+                        "are currently not supported due to a bug in FFTW. "
+                        "Please use a grid with even length.")
 
             if current_info is None:
                 transform_shape = list(inp.shape)
diff --git a/test/test_operators/test_fft_operator.py b/test/test_operators/test_fft_operator.py
index 18bb5c7f52e94a35cbc76ab52b482db6e627b76f..edf9bcf99e6fe2521b20a8480ad1d8fe666b8a9b 100644
--- a/test/test_operators/test_fft_operator.py
+++ b/test/test_operators/test_fft_operator.py
@@ -63,7 +63,7 @@ class FFTOperatorTests(unittest.TestCase):
         assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)
 
     @expand(product(["numpy", "fftw", "fftw_mpi"],
-                    [12, ], [False, True], [False, True],
+                    [16, ], [False, True], [False, True],
                     [0.1, 1, 3.7],
                     [np.float64, np.complex128, np.float32, np.complex64]))
     def test_fft1D(self, module, dim1, zc1, zc2, d, itp):