Commit f0e7585f authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: Add axes keyword to FFTW

- Finished the case when input is np.ndarray
- Fix typos
parent 2ae2d181
Pipeline #3364 skipped
......@@ -109,7 +109,7 @@ class FFTW(FFT):
zero-centered.
dimensions_input : tuple, list, numpy.ndarray
A tuple containing the masks desired shape.
A tuple containing the mask's desired shape.
offset_input : int, boolean
Specifies whether the zero-th dimension starts with an odd
......@@ -232,81 +232,51 @@ class FFTW(FFT):
Returns
-------
result : np.ndarray
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
current_plan_and_info = self._get_plan_and_info(domain, codomain,
**kwargs)
# Prepare the environment variables
local_size = current_plan_and_info.fftw_local_size
local_start = local_size[2]
local_end = local_start + local_size[1]
# If the input is a numpy array we transform it locally
if not isinstance(val, distributed_data_object):
# Copy data for local manipulation
local_val = val[:]
# Create domain and codomain centering mask
domain_centering_mask = self.get_centering_mask(
domain.paradict['zerocenter'],
domain.get_shape()
)
codomain_centering_mask = self.get_centering_mask(
codomain.paradict['zerocenter'],
codomain.get_shape()
)
# Prepare the input data
# Case 1: val is a distributed_data_object
if isinstance(val, distributed_data_object):
return_val = val.copy_empty(
global_shape=current_plan_and_info.global_output_shape,
dtype=codomain.dtype)
# If the distribution strategy of the d2o is fftw, extract
# the data directly
if val.distribution_strategy == 'fftw':
local_val = val.get_local_data()
else:
local_val = val.get_data(slice(local_start, local_end),
local_keys=True).get_local_data()
# Case 2: val is a numpy array carrying the full data
# Apply codomain centering mask
for slice_list in utilities.get_slice_list(val.shape, axes):
if slice_list == [slice(None, None)]:
local_val *= codomain_centering_mask
else:
local_val = val[slice(local_start, local_end)]
local_val *= current_plan_and_info.get_codomain_centering_mask()
# Define a abbreviation for the fftw plan
p = current_plan_and_info.get_plan()
# load the field into the plan
if p.has_input:
p.input_array[:] = local_val
# execute the plan
p()
local_val[slice_list] *= codomain_centering_mask
if p.has_output:
result = p.output_array * current_plan_and_info. \
get_domain_centering_mask()
else:
result = local_val
assert (result.shape[0] == 0)
# build the return object according to the input val
# TODO: Check if comm is the same, too!
try:
if return_val.distribution_strategy == 'fftw':
return_val.set_local_data(data=result)
# We use pyfftw.interface.numpy_fft module to handle transformation
# in this case. The first call, with the given input combination,
# might be slow but the subsequent calls in the same session will
# be much faster.
if codomain.harmonic:
result = pyfftw.interfaces.numpy_fft.fftn(local_val, axes=axes)
else:
return_val.set_data(data=result,
to_key=slice(local_start, local_end),
local_keys=True)
result = pyfftw.interfaces.numpy_fft.ifftn(local_val, axes=axes)
Please register or sign in to reply
# If the values living in domain are purely real, the
# result of the fft is hermitian
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
# In case the input val was not a distributed data obect, the try
# will produce a NameError
except(NameError):
return_val = distributed_data_object(
global_shape=current_plan_and_info.global_output_shape,
dtype=codomain.dtype,
distribution_strategy='fftw')
return_val.set_local_data(data=result)
return_val = return_val.get_full_data()
# The +-1 magic has a caveat: if a dimension was zero-centered
# in the harmonic as well as in position space, the result gets
# a global minus. The following multiplitcation compensates that.
return_val *= current_plan_and_info.overall_sign
# Apply domain centering mask
for slice_list in utilities.get_slice_list(val.shape, axes):
if slice_list == [slice(None, None)]:
result *= domain_centering_mask
else:
result[slice_list] *= domain_centering_mask
return return_val
return result.astype(codomain.dtype)
else:
# val is a distributed_data_object
pass
# The instances of plan_and_info store the fftw plan and all
......@@ -430,7 +400,7 @@ class GFFT(FFT):
Returns
-------
result : np.ndarray
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# GFFT doesn't accept d2o objects as input. Consolidate data from
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment