Commit ea51586f authored by Jait Dixit's avatar Jait Dixit
Browse files

Add implementation for d2o when axes is not None

parent 6e4dbc02
Pipeline #3370 skipped
......@@ -235,6 +235,7 @@ class FFTW(FFT):
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# TODO +/- magic in both cases
# If the input is a numpy array we transform it locally
if not isinstance(val, distributed_data_object):
# Copy data for local manipulation
......@@ -251,7 +252,7 @@ class FFTW(FFT):
)
# Apply codomain centering mask
for slice_list in utilities.get_slice_list(val.shape, axes):
for slice_list in utilities.get_slice_list(local_val.shape, axes):
if slice_list == [slice(None, None)]:
local_val *= codomain_centering_mask
else:
......@@ -269,7 +270,7 @@ class FFTW(FFT):
axes=axes)
# Apply domain centering mask
for slice_list in utilities.get_slice_list(val.shape, axes):
for slice_list in utilities.get_slice_list(local_val.shape, axes):
if slice_list == [slice(None, None)]:
return_val *= domain_centering_mask
else:
......@@ -277,17 +278,84 @@ class FFTW(FFT):
return return_val.astype(codomain.dtype)
else:
# Setup the final result array
return_val = val.copy_empty(
global_shape=codomain.get_shape(),
dtype=codomain.type
)
if val.distribution_strategy == 'not':
pass
new_val = val.copy(distribution_strategy='fftw')
return_val = self.transform(
new_val,
domain,
codomain,
axes,
**kwargs).copy(distribution_strategy='not')
elif val.distribution_strategy in ('equal', 'fftw', 'freeform'):
# slicing distributor need to examine axes before proceeding
pass
if axes:
# We use pyfftw in this case
# Setup up the array which will be returned
return_val = val.copy_empty(
global_shape=domain.get_shape(),
dtype=codomain.type
)
# Find which part of the data resides on this node
local_size = pyfftw.local_size(val.shape)
local_start = local_size[2]
local_end = local_start + local_size[1]
# Extract the relevant data
if val.distribution_strategy == 'fftw':
local_val = val.get_local_data()
else:
local_val = val.get_data(
slice(local_start, local_end),
local_keys=True
).get_local_data()
# Create domain and codomain centering mask
domain_centering_mask = self.get_centering_mask(
domain.paradict['zerocenter'],
domain.get_shape()
)
codomain_centering_mask = self.get_centering_mask(
codomain.paradict['zerocenter'],
codomain.get_shape()
)
# Apply codomain centering mask
for slice_list in utilities.get_slice_list(local_val.shape,
axes):
local_val[slice_list] *= codomain_centering_mask
if codomain.harmonic:
result = pyfftw.interfaces.numpy_fft.fftn(
local_val,
axes=axes
)
else:
result = pyfftw.interfaces.numpy_fft.ifftn(
local_val,
axes=axes
)
# Apply domain centering mask
for slice_list in utilities.get_slice_list(local_val.shape,
axes):
result[slice_list] *= domain_centering_mask
# Push data in-place in the array to be returned
if return_val.distribution_strategy == 'fftw':
return_val.set_local_data(result, copy=False)
else:
return_val.set_data(
data=result,
to_key=slice(local_start, local_end),
local_keys=True
)
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
elif not axes or axes == (0,):
# We use pyfftw-mpi in this case
pass
else:
raise ValueError('ERROR: Unknown distribution strategy')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment