Commit e1407dca authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: Changes to FFTW transform()

- Move the transform logic to its own method
- Refactor the _fftw_plan_and_info class into FFTWTransformInfo
parent f1c84043
Pipeline #4201 skipped
......@@ -52,6 +52,7 @@ def fft_factory(fft_module_name):
class FFT(object):
"""
A generic fft object without any implementation.
"""
......@@ -79,6 +80,7 @@ class FFT(object):
class FFTW(FFT):
"""
The pyfftw pendant of a fft object.
"""
......@@ -88,9 +90,9 @@ class FFTW(FFT):
raise ImportError("The module pyfftw is needed but not available.")
self.name = 'pyfftw'
# The plan_dict stores the plan_and_info objects which correspond
# The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets.
self.plan_dict = {}
self.info_dict = {}
# initialize the dictionary which stores the values from
# get_centering_mask
......@@ -202,14 +204,72 @@ class FFTW(FFT):
self.centering_mask_dict[temp_id] = centering_mask
return self.centering_mask_dict[temp_id]
def _get_plan_and_info(self, domain, codomain, **kwargs):
def _get_plan_and_info(self, domain, codomain, is_local=False, **kwargs):
# generate a id-tuple which identifies the domain-codomain setting
temp_id = domain.__hash__() ^ (101 * codomain.__hash__())
# generate the plan_and_info object if not already there
if temp_id not in self.plan_dict:
self.plan_dict[temp_id] = _fftw_plan_and_info(domain, codomain,
self, **kwargs)
return self.plan_dict[temp_id]
if temp_id not in self.info_dict:
self.info_dict[temp_id] = FFTWTransformInfo(domain, codomain,
is_local, self,
**kwargs)
return self.info_dict[temp_id]
def _apply_mask(self, val, mask, axes):
"""
Apply centering mask to an array.
Parameters
----------
val: distributed_data_object or numpy.ndarray
The value-array on which the mask should be applied.
mask: numpy.ndarray
The mask to be applied.
axes: tuple
The axes which are to be transformed.
Returns
-------
distributed_data_object or np.nd_array
Mask input array by multiplying it with the mask.
"""
# reshape mask if necessary
if axes:
mask = mask.reshape(
[y if x in axes else 1
for x, y in enumerate(val.shape)]
)
return val * mask
def _transform(self, val, current_info, axes):
# Apply codomain centering mask
val = self._apply_mask(val, current_info.cmask_codomain, axes)
if current_info.local:
result = current_info.plan(val, axes=axes,
planner_effort='FFTW_ESTIMATE')
else:
p = current_info.plan
# Load the value into the plan
if p.has_input:
p.input_array[:] = val
# Execute the plan
p()
if p.has_output:
result = p.output_array
else:
raise RuntimeError('ERROR: PyFFTW-MPI transform failed.')
# Apply domain centering mask
result = self._apply_mask(val, current_info.cmask_domain, axes)
# Correct the sign if needed
result *= current_info.sign
return result
def transform(self, val, domain, codomain, axes, **kwargs):
"""
......@@ -238,281 +298,157 @@ class FFTW(FFT):
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Create domain and codomain centering mask
domain_centering_mask = self.get_centering_mask(
domain.paradict['zerocenter'],
domain.get_shape()
)
codomain_centering_mask = self.get_centering_mask(
codomain.paradict['zerocenter'],
codomain.get_shape()
)
# The +-1 magic has a caveat: if a dimension was zero-centered
# in the harmonic as well as in position space, the result gets
# a global minus. This keeps track of the sign so that we can
# remedy that before returning.
overall_sign = (-1) ** np.sum(
np.array(domain.paradict['zerocenter']) *
np.array(codomain.paradict['zerocenter']) *
np.array(domain.get_shape() // 2 % 2)
)
# If the input is a numpy array we transform it locally
if not isinstance(val, distributed_data_object):
current_info = self._get_plan_and_info(domain, codomain,
is_local=True,
**kwargs)
# Copy data for local manipulation
local_val = np.copy(val)
# reshape the masks
if axes:
domain_centering_mask = domain_centering_mask.reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
)
codomain_centering_mask = codomain_centering_mask.reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
)
# Apply codomain centering mask
local_val *= codomain_centering_mask
# We use pyfftw.interface.numpy_fft module to handle transformation
# in this case. The first call, with the given input combination,
# might be slow but the subsequent calls in the same session will
# be much faster as caching is enabled.
if codomain.harmonic:
return_val = pyfftw.interfaces.numpy_fft.fftn(local_val,
axes=axes)
else:
return_val = pyfftw.interfaces.numpy_fft.ifftn(local_val,
axes=axes)
# Apply domain centering mask
return_val *= domain_centering_mask
# Correct overall sign
return_val *= overall_sign
return_val = self._transform(local_val, current_info, axes)
else:
if val.distribution_strategy == 'not':
new_val = val.copy(distribution_strategy='fftw')
# Recursive call to take advantage of the fact that the data
# necessary is already present on the nodes.
return_val = self.transform(new_val, domain, codomain, axes,
**kwargs)
return_val = return_val.copy(distribution_strategy='not')
elif val.distribution_strategy in ('equal', 'fftw', 'freeform'):
if axes:
# We use pyfftw in this case
# Setup the return object
return_val = val.copy_empty(
global_shape=domain.get_shape(),
dtype=codomain.type
)
current_info = self._get_plan_and_info(
domain, codomain,
is_local=False if not axes else True, **kwargs
)
# Find which part of the data resides on this node
local_size = pyfftw.local_size(val.shape)
local_start = local_size[2]
local_end = local_start + local_size[1]
# Extract the relevant data
if val.distribution_strategy == 'fftw':
local_val = val.get_local_data()
else:
local_val = val.get_data(
slice(local_start, local_end),
local_keys=True
).get_local_data()
# reshape domain centering mask
domain_centering_mask = domain_centering_mask.reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
)
# Repack input as 'fftw' to make life easier
original_strategy = None
if val.distribution_strategy != 'fftw':
original_strategy = val.distribution_strategy
val_temp = val.copy(distribution_strategy='fftw')
else:
val_temp = val
codomain_centering_mask = codomain_centering_mask.reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
)
# Create the object which will be returned
return_val = val_temp.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
# Extract local data
local_val = val_temp.get_local_data()
# Transfom the data and place result in place
result = self._transform(local_val, current_info, axes)
return_val.set_local_data(data=result, copy=False)
# Apply codomain centering mask
local_val *= codomain_centering_mask
if codomain.harmonic:
result = pyfftw.interfaces.numpy_fft.fftn(
local_val,
axes=axes
)
else:
result = pyfftw.interfaces.numpy_fft.ifftn(
local_val,
axes=axes
)
# Apply domain centering mask
local_val *= domain_centering_mask
# Push data in-place in the array to be returned
if return_val.distribution_strategy == 'fftw':
return_val.set_local_data(result, copy=False)
else:
return_val.set_data(
data=result,
to_key=slice(local_start, local_end),
local_keys=True
)
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
return_val *= overall_sign
elif not axes or axes == (0,):
# We use pyfftw-mpi in this case
current_plan_and_info = self._get_plan_and_info(domain,
codomain,
**kwargs)
# Save the original distribution_strategy
original_strategy = None
# Change distribution strategy to fftw
if val.distribution_strategy != 'fftw':
original_strategy = val.distribution_strategy
val_fftw = val.copy(distribution_strategy='fftw')
# Create return object
return_val = val_fftw.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
# Extract local data
local_val = val_fftw.get_local_data()
domain_centering_mask = \
current_plan_and_info.get_domain_centering_mask().\
reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
)
codomain_centering_mask = \
current_plan_and_info.get_codomain_centering_mask().\
reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
)
local_val *= codomain_centering_mask
# Get the fftw plan
p = current_plan_and_info.get_plan()
# Load the value into the plan
if p.has_input:
p.input_array[:] = local_val
# execute the plan
p()
if p.has_output:
result = p.output_array * domain_centering_mask
else:
result = local_val
assert(result.shape[0] == 0)
return_val.set_local_data(data=result)
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
# reset to original strategy
# If the values living in domain are purely real, the
# resulting FFT is purely real
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
# Revert to original distribution strategy, if needed
if not original_strategy:
return_val = return_val.copy(
distribution_strategy=original_strategy
)
# Correct sign using current_plan_and_info
return_val *= current_plan_and_info.ovrall_sign
else:
raise ValueError('ERROR: Unknown distribution strategy')
return return_val
# The instances of plan_and_info store the fftw plan and all
# other information needed in order to perform a mpi-fftw transformation
class _fftw_plan_and_info(object):
def __init__(self, domain, codomain, fft_fftw_context, **kwargs):
class FFTWTransformInfo(object):
def __init__(self, domain, codomain, is_local, fftw_context, **kwargs):
if pyfftw is None:
raise ImportError("The module pyfftw is needed but not available.")
self.compute_plan_and_info(domain, codomain, fft_fftw_context,
**kwargs)
def set_plan(self, x):
self.plan = x
# Whether it's a local or an MPI transformation
self.local = is_local
# Create domain centering mask. The mask is the same size
# as the domain's shape.
self.cmask_domain = fftw_context.get_centering_mask(
domain.paradict['zerocenter'],
domain.get_shape(),
True
)
def get_plan(self):
return self.plan
self.cmask_codomain = fftw_context.get_centering_mask(
codomain.paradict['zerocenter'],
codomain.get_shape(),
True
)
def set_domain_centering_mask(self, x):
self.domain_centering_mask = x
# If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it.
self.sign = (-1) ** np.sum(
np.array(domain.paradict['zerocenter']) *
np.array(codomain.paradict['zerocenter']) *
(np.array(domain.get_shape()) // 2 % 2)
)
# Create the plan. For a local transform the plan is the matching
# pyfftw.interfaces.numpy_fft function (fftn towards a harmonic codomain,
# ifftn otherwise); for a non-local transform it is an MPI plan.
if self.local:
    if codomain.harmonic:
        self.plan = pyfftw.interfaces.numpy_fft.fftn
    else:
        # The module is `numpy_fft`; `numpy_fftn` does not exist and
        # would raise an AttributeError on every inverse transform.
        self.plan = pyfftw.interfaces.numpy_fft.ifftn
else:
    self.plan = pyfftw.create_mpi_plan(
        input_shape=domain.get_shape(),
        input_dtype='complex128',
        # Was misspelled 'complext128', which is not a valid dtype name.
        output_dtype='complex128',
        direction='FFTW_FORWARD'
        if codomain.harmonic else 'FFTW_BACKWARD',
        flags=["FFTW_ESTIMATE"],
        **kwargs
    )
def get_domain_centering_mask(self):
return self.domain_centering_mask
@property
def cmask_domain(self):
return self._domain_centering_mask
def set_codomain_centering_mask(self, x):
self.codomain_centering_mask = x
@cmask_domain.setter
def cmask_domain(self, cmask):
self._domain_centering_mask = cmask
def get_codomain_centering_mask(self):
return self.codomain_centering_mask
@property
def cmask_codomain(self):
return self._codomain_centering_mask
def compute_plan_and_info(self, domain, codomain, fft_fftw_context,
**kwargs):
@cmask_codomain.setter
def cmask_codomain(self, cmask):
self._codomain_centering_mask = cmask
self.input_dtype = 'complex128'
self.output_dtype = 'complex128'
@property
def local(self):
return self._local
self.global_input_shape = domain.get_shape()
self.global_output_shape = codomain.get_shape()
self.fftw_local_size = pyfftw.local_size(self.global_input_shape)
@local.setter
def local(self, local):
self._local = local
self.in_zero_centered_dimensions = domain.paradict['zerocenter']
self.out_zero_centered_dimensions = codomain.paradict['zerocenter']
@property
def plan(self):
return self._plan
self.overall_sign = (-1) ** np.sum(
np.array(self.in_zero_centered_dimensions) *
np.array(self.out_zero_centered_dimensions) *
(np.array(self.global_input_shape) // 2 % 2)
)
@plan.setter
def plan(self, plan):
self._plan = plan
self.local_node_dimensions = np.append((self.fftw_local_size[1],),
self.global_input_shape[1:])
self.offsetQ = self.fftw_local_size[2] % 2
@property
def sign(self):
return self._sign
if codomain.harmonic:
self.direction = 'FFTW_FORWARD'
else:
self.direction = 'FFTW_BACKWARD'
# compute the centering masks
self.set_domain_centering_mask(
fft_fftw_context.get_centering_mask(
self.in_zero_centered_dimensions,
self.local_node_dimensions,
self.offsetQ))
self.set_codomain_centering_mask(
fft_fftw_context.get_centering_mask(
self.out_zero_centered_dimensions,
self.local_node_dimensions,
self.offsetQ))
self.set_plan(
pyfftw.create_mpi_plan(
input_shape=self.global_input_shape,
input_dtype=self.input_dtype,
output_dtype=self.output_dtype,
direction=self.direction,
flags=["FFTW_ESTIMATE"],
**kwargs)
)
@sign.setter
def sign(self, sign):
self._sign = sign
class GFFT(FFT):
"""
The gfft pendant of a fft object.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment