Commit c435e42f authored by theos's avatar theos
Browse files

Fixed axes handling in nifty_fft.py

parent 1182e522
Pipeline #5115 skipped
......@@ -192,21 +192,24 @@ class FFTW(FFT):
return self.centering_mask_dict[temp_id]
def _get_transform_info(self, domain, codomain, local_shape,
local_offset_Q, is_local, **kwargs):
local_offset_Q, is_local, transform_shape=None,
**kwargs):
# generate a id-tuple which identifies the domain-codomain setting
temp_id = domain.__hash__() ^ (101 * codomain.__hash__())
temp_id = (domain.__hash__() ^
(101 * codomain.__hash__()) ^
(211 * transform_shape.__hash__()))
# generate the plan_and_info object if not already there
if temp_id not in self.info_dict:
if is_local:
self.info_dict[temp_id] = FFTWLocalTransformInfo(
domain, codomain, local_shape, local_offset_Q,
self, **kwargs
domain, codomain, local_shape,
local_offset_Q, self, **kwargs
)
else:
self.info_dict[temp_id] = FFTWMPITransfromInfo(
domain, codomain, local_shape, local_offset_Q,
self, **kwargs
domain, codomain, local_shape,
local_offset_Q, self, transform_shape, **kwargs
)
return self.info_dict[temp_id]
......@@ -274,12 +277,11 @@ class FFTW(FFT):
local_offset_Q = False
try:
local_val = val.get_local_data(copy=False),
local_val = val.get_local_data(copy=False)
if axes is None or 0 in axes:
local_offset_Q = val.distributor.local_shape[0] % 2
except(AttributeError):
local_val = val
current_info = self._get_transform_info(domain,
codomain,
local_shape=local_val.shape,
......@@ -343,12 +345,7 @@ class FFTW(FFT):
local_offset_list[val.distributor.comm.rank] % 2)
else:
local_offset_Q = False
current_info = self._get_transform_info(domain,
codomain,
local_shape=val.local_shape,
local_offset_Q=local_offset_Q,
is_local=False,
**kwargs)
return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
......@@ -363,6 +360,7 @@ class FFTW(FFT):
if set(axes) == set(range(len(val.shape))):
axes = None
current_info = None
for slice_list in utilities.get_slice_list(local_val.shape, axes):
if slice_list == [slice(None, None)]:
inp = local_val
......@@ -381,6 +379,16 @@ class FFTW(FFT):
original_shape = inp.shape
inp = inp.reshape(inp.shape[0], 1)
if current_info is None:
current_info = self._get_transform_info(
domain,
codomain,
local_shape=val.local_shape,
local_offset_Q=local_offset_Q,
is_local=False,
transform_shape=inp.shape,
**kwargs)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
result = self._atomic_mpi_transform(inp, current_info, axes,
......@@ -438,7 +446,10 @@ class FFTW(FFT):
# Cast to a np.ndarray
temp_val = np.asarray(val)
current_info = self._get_transform_info(domain, codomain,
current_info = self._get_transform_info(domain,
codomain,
local_shape=temp_val.shape,
local_offset_Q=False,
is_local=True,
**kwargs)
......@@ -475,8 +486,8 @@ class FFTW(FFT):
class FFTWTransformInfo(object):
def __init__(self, domain, codomain, local_shape, local_offset_Q,
fftw_context, **kwargs):
def __init__(self, domain, codomain, local_shape,
local_offset_Q, fftw_context, **kwargs):
if pyfftw is None:
raise ImportError("The module pyfftw is needed but not available.")
......@@ -522,8 +533,8 @@ class FFTWTransformInfo(object):
class FFTWLocalTransformInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, local_shape, local_offset_Q,
fftw_context, **kwargs):
def __init__(self, domain, codomain, local_shape,
local_offset_Q, fftw_context, **kwargs):
super(FFTWLocalTransformInfo, self).__init__(domain,
codomain,
local_shape,
......@@ -546,23 +557,16 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
class FFTWMPITransfromInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, local_shape, local_offset_Q,
fftw_context, **kwargs):
def __init__(self, domain, codomain, local_shape,
local_offset_Q, fftw_context, transform_shape, **kwargs):
super(FFTWMPITransfromInfo, self).__init__(domain,
codomain,
local_shape,
local_offset_Q,
fftw_context,
**kwargs)
# When the domain is 1-dimensional, reshape it so that it can
# accept input which is also augmented by 1.
if len(domain.get_shape()) == 1:
shape = (domain.get_shape()[0], 1)
else:
shape = domain.get_shape()
self._plan = pyfftw.create_mpi_plan(
input_shape=shape,
input_shape=transform_shape,
input_dtype='complex128',
output_dtype='complex128',
direction='FFTW_FORWARD' if codomain.harmonic else 'FFTW_BACKWARD',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment