Commit c9ae0203 authored by Jait Dixit
Browse files

Move individual cases in FFTW's transform to separate methods

parent afa17f5a
Pipeline #5049 skipped
......@@ -299,6 +299,109 @@ class FFTW(FFT):
return result
def _not_slicing_transform(self, val, domain, codomain, axes, **kwargs):
    """
    Transform `val` after repacking it to the 'fftw' distribution strategy.

    The content of `val` is copied into a temporary object that uses the
    'fftw' distribution strategy, that object is handed to
    `self.transform`, and the outcome is copied into a new object that
    uses the distribution strategy `val` had on entry.

    Parameters
    ----------
    val : distributed data object (d2o)
        Input data; must provide `copy_empty`, `set_full_data` and
        `distribution_strategy`.
    domain, codomain
        Forwarded unchanged to `self.transform`.
    axes
        Forwarded unchanged to `self.transform`.
    **kwargs
        Forwarded unchanged to `self.transform`.

    Returns
    -------
    The transformed data, distributed with `val.distribution_strategy`.
    """
    temp_val = val.copy_empty(distribution_strategy='fftw')
    # Adjacent literals are concatenated at compile time, so the message
    # no longer depends on how deeply this line is indented (a backslash
    # continuation inside the literal would embed that indentation).
    about.warnings.cprint('WARNING: Repacking d2o to fftw '
                          'distribution strategy')
    temp_val.set_full_data(val, copy=False)

    # Recursive call to take advantage of the fact that the data
    # necessary is already present on the nodes.
    result = self.transform(temp_val, domain, codomain, axes,
                            **kwargs)

    # Derive the output container from `result` rather than from `val`,
    # exactly as `_slicing_not_fftw_mpi_transform` does, so that the
    # container follows the transformed data instead of the input.
    # NOTE(review): presumably `copy_empty` inherits the dtype of the
    # object it is called on; if so, using `val` here would store a
    # complex result in a container of the input's dtype -- confirm.
    return_val = result.copy_empty(
        distribution_strategy=val.distribution_strategy
    )
    return_val.set_full_data(result, copy=False)
    return return_val
def _slicing_local_transform(self, val, domain, codomain, axes, **kwargs):
    """
    Transform the local data of `val` without any repacking.

    The transform information is requested with `is_local=True`, the
    local data of `val` is passed through `self._local_transform`, and
    the outcome is stored as the local data of a fresh object that has
    the global shape of `val` and the dtype of `codomain`.

    Parameters
    ----------
    val
        Input data; must provide `get_local_data`, `copy_empty` and
        `shape`.
    domain, codomain
        Forwarded to `self._get_transform_info` and
        `self._local_transform`; `codomain.dtype` is the dtype of the
        returned object.
    axes
        Forwarded to `self._local_transform`.
    **kwargs
        Forwarded to `self._get_transform_info`.

    Returns
    -------
    A new object created by `val.copy_empty` that holds the transformed
    local data.
    """
    info = self._get_transform_info(
        domain, codomain, is_local=True, **kwargs
    )

    # Work directly on the local buffer; no copy is requested.
    local_data = val.get_local_data(copy=False)
    transformed = self._local_transform(
        local_data, info, axes, domain, codomain
    )

    # Allocate the output and attach the transformed data to it.
    output = val.copy_empty(global_shape=val.shape, dtype=codomain.dtype)
    output.set_local_data(data=transformed, copy=False)
    return output
def _slicing_not_fftw_mpi_transform(self, val, domain, codomain,
                                    axes, **kwargs):
    """
    Transform `val` via a temporary copy in the 'fftw' distribution
    strategy.

    `val` is copied into an object that uses the 'fftw' distribution
    strategy, that object is passed to `self.transform`, and the outcome
    is copied into a new object that uses the distribution strategy of
    `val`.

    Parameters
    ----------
    val
        Input data; must provide `copy_empty` and
        `distribution_strategy`.
    domain, codomain, axes
        Forwarded unchanged to `self.transform`.
    **kwargs
        Forwarded unchanged to `self.transform`.

    Returns
    -------
    A new object, derived from the transform result via `copy_empty`,
    distributed with `val.distribution_strategy`.
    """
    # Stage the input in the 'fftw' distribution strategy.
    repacked = val.copy_empty(distribution_strategy='fftw')
    repacked.set_full_data(val, copy=False)

    # Recursive call to transform, now on the repacked data.
    transformed = self.transform(repacked, domain, codomain, axes, **kwargs)

    # Move the outcome back to the distribution strategy of the input.
    target_strategy = val.distribution_strategy
    output = transformed.copy_empty(distribution_strategy=target_strategy)
    output.set_full_data(data=transformed, copy=False)
    return output
def _slicing_fftw_mpi_transform(self, val, domain, codomain,
                                axes, **kwargs):
    """
    Transform the local data of `val` slice by slice with
    `self._mpi_transform`.

    The local data is split into the slices produced by
    `utilities.get_slice_list(local_val.shape, axes)`. Each slice is
    passed to `self._mpi_transform`; the outcomes are gathered in a
    scratch array which finally becomes the local data of the returned
    object. When the only slice is the full-range slice, the outcome of
    the single transform is used directly instead of a scratch array.

    Parameters
    ----------
    val
        Input data; must provide `copy_empty`, `get_local_data` and
        `shape`.
    domain, codomain
        Forwarded to `self._get_transform_info` and
        `self._mpi_transform`; `codomain.dtype` is the dtype of the
        returned object and of the scratch array.
    axes : sequence of int or None
        Axes to transform. A sequence naming every axis of `val` is
        treated like None.
    **kwargs
        Forwarded to `self._get_transform_info`.

    Returns
    -------
    A new object created by `val.copy_empty` whose local data is the
    transformed local data of `val`.
    """
    current_info = self._get_transform_info(domain, codomain, **kwargs)
    return_val = val.copy_empty(global_shape=val.shape,
                                dtype=codomain.dtype)

    # Extract local data
    local_val = val.get_local_data(copy=False)

    # Temporary storage for the transformed slices, allocated lazily
    temp_val = None

    # If the axes tuple includes all axes, set it to None
    if axes is not None and set(axes) == set(range(len(val.shape))):
        axes = None

    for slice_list in utilities.get_slice_list(local_val.shape, axes):
        is_full_slice = (slice_list == [slice(None, None)])
        # NumPy only accepts a tuple (not a list) of slices for
        # multidimensional indexing.
        index = tuple(slice_list)

        if is_full_slice:
            inp = local_val
        else:
            if temp_val is None:
                # Use the dtype of the returned object, so that e.g. a
                # complex outcome is not truncated when it is written
                # into a buffer that inherited a real input dtype.
                temp_val = np.empty_like(local_val, dtype=codomain.dtype)
            inp = local_val[index]

        # This is in order to make FFTW behave properly when slicing input
        # over MPI ranks when the input is 1-dimensional. The default
        # behaviour is to optimize to take advantage of byte-alignment,
        # which doesn't match the slicing strategy for multi-dimensional
        # data.
        original_shape = None
        if len(inp.shape) == 1:
            original_shape = inp.shape
            inp = inp.reshape(inp.shape[0], 1)

        result = self._mpi_transform(inp, current_info, axes,
                                     domain, codomain)

        if is_full_slice:
            temp_val = result
        else:
            # Reverting to the original shape i.e. before the input was
            # augmented with 1 to make FFTW behave properly.
            if original_shape is not None:
                result = result.reshape(original_shape)
            temp_val[index] = result

    return_val.set_local_data(data=temp_val, copy=False)
    return return_val
def transform(self, val, domain, codomain, axes=None, **kwargs):
"""
The pyfftw transform function.
......@@ -352,108 +455,22 @@ class FFTW(FFT):
if axes is None or set(axes) == set(range(len(val.shape))) \
or 0 in axes:
if val.distribution_strategy != 'fftw':
temp_val = val.copy_empty(distribution_strategy='fftw')
temp_val.set_full_data(val, copy=False)
# Recursive call to transform
result = self.transform(temp_val, domain, codomain,
axes, **kwargs)
return_val = result.copy_empty(
distribution_strategy=val.distribution_strategy
)
return_val.set_full_data(data=result, copy=False)
return_val = \
self._slicing_not_fftw_mpi_transform(
val, domain, codomain, axes, **kwargs
)
else:
current_info = self._get_transform_info(domain,
codomain,
**kwargs)
return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
# Extract local data
local_val = val.get_local_data(copy=False)
# Create temporary storage for slices
temp_val = None
# If axes tuple includes all axes, set it to None
if axes is not None:
if set(axes) == set(range(len(val.shape))):
axes = None
for slice_list in\
utilities.get_slice_list(local_val.shape,
axes):
if slice_list == [slice(None, None)]:
inp = local_val
else:
if temp_val is None:
temp_val = np.empty_like(local_val)
inp = local_val[slice_list]
# This is in order to make FFTW behave properly
# when slicing input over MPI ranks when the
# input is 1-dimensional. The default behaviour
# is to slice so that it's byte-aligned, which
# doesn't play well with multi-dimensional data
# sliced for FFTW.
original_shape = None
if len(inp.shape) == 1:
original_shape = inp.shape
inp = inp.reshape(inp.shape[0], 1)
result = self._mpi_transform(inp,
current_info,
axes, domain,
codomain)
if slice_list == [slice(None, None)]:
temp_val = result
else:
# Reverting to the original shape i.e. before
# the input was augmented with 1 to make
# FFTW behave properly.
if original_shape is not None:
result = result.reshape(original_shape)
temp_val[slice_list] = result
return_val.set_local_data(data=temp_val, copy=False)
return_val = self._slicing_fftw_mpi_transform(
val, domain, codomain, axes, **kwargs
)
else:
current_info = self._get_transform_info(domain, codomain,
is_local=True,
**kwargs)
# Compute transform for the local data
result = self._local_transform(
val.get_local_data(copy=False),
current_info, axes,
domain, codomain
return_val = self._slicing_local_transform(
val, domain, codomain, axes, **kwargs
)
# Create return object and insert results inplace
return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
return_val.set_local_data(data=result, copy=False)
# If domain is purely real, the result of the FFT is hermitian
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
else:
temp_val = val.copy_empty(distribution_strategy='fftw')
about.warnings.cprint('WARNING: Repacking d2o to fftw \
distribution strategy')
temp_val.set_full_data(val, copy=False)
# Recursive call to take advantage of the fact that the data
# necessary is already present on the nodes.
result = self.transform(temp_val, domain, codomain, axes,
**kwargs)
return_val = val.copy_empty(
distribution_strategy=val.distribution_strategy
return_val = self._not_slicing_transform(
val, domain, codomain, axes, **kwargs
)
return_val.set_full_data(result, copy=False)
# If domain is purely real, the result of the FFT is hermitian
if domain.paradict['complexity'] == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment