Commit f0e7585f authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: Add axes keyword to FFTW

- Finished the case when input is np.ndarray
- Fix typos
parent 2ae2d181
Pipeline #3364 skipped
...@@ -109,7 +109,7 @@ class FFTW(FFT): ...@@ -109,7 +109,7 @@ class FFTW(FFT):
zero-centered. zero-centered.
dimensions_input : tuple, list, numpy.ndarray dimensions_input : tuple, list, numpy.ndarray
A tuple containing the masks desired shape. A tuple containing the mask's desired shape.
offset_input : int, boolean offset_input : int, boolean
Specifies whether the zero-th dimension starts with an odd Specifies whether the zero-th dimension starts with an odd
...@@ -232,81 +232,51 @@ class FFTW(FFT): ...@@ -232,81 +232,51 @@ class FFTW(FFT):
Returns Returns
------- -------
result : np.ndarray result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field. Fourier-transformed pendant of the input field.
""" """
current_plan_and_info = self._get_plan_and_info(domain, codomain, # If the input is a numpy array we transform it locally
**kwargs) if not isinstance(val, distributed_data_object):
# Prepare the environment variables # Copy data for local manipulation
local_size = current_plan_and_info.fftw_local_size local_val = val[:]
local_start = local_size[2]
local_end = local_start + local_size[1] # Create domain and codomain centering mask
domain_centering_mask = self.get_centering_mask(
# Prepare the input data domain.paradict['zerocenter'],
# Case 1: val is a distributed_data_object domain.get_shape()
if isinstance(val, distributed_data_object): )
return_val = val.copy_empty( codomain_centering_mask = self.get_centering_mask(
global_shape=current_plan_and_info.global_output_shape, codomain.paradict['zerocenter'],
dtype=codomain.dtype) codomain.get_shape()
# If the distribution strategy of the d2o is fftw, extract )
# the data directly
if val.distribution_strategy == 'fftw':
local_val = val.get_local_data()
else:
local_val = val.get_data(slice(local_start, local_end),
local_keys=True).get_local_data()
# Case 2: val is a numpy array carrying the full data
else:
local_val = val[slice(local_start, local_end)]
local_val *= current_plan_and_info.get_codomain_centering_mask()
# Define a abbreviation for the fftw plan
p = current_plan_and_info.get_plan()
# load the field into the plan
if p.has_input:
p.input_array[:] = local_val
# execute the plan
p()
if p.has_output: # Apply codomain centering mask
result = p.output_array * current_plan_and_info. \ for slice_list in utilities.get_slice_list(val.shape, axes):
get_domain_centering_mask() if slice_list == [slice(None, None)]:
else: local_val *= codomain_centering_mask
result = local_val else:
assert (result.shape[0] == 0) local_val[slice_list] *= codomain_centering_mask
# build the return object according to the input val # We use pyfftw.interface.numpy_fft module to handle transformation
# TODO: Check if comm is the same, too! # in this case. The first call, with the given input combination,
try: # might be slow but the subsequent calls in the same session will
if return_val.distribution_strategy == 'fftw': # be much faster.
return_val.set_local_data(data=result) if codomain.harmonic:
result = pyfftw.interfaces.numpy_fft.fftn(local_val, axes=axes)
else: else:
return_val.set_data(data=result, result = pyfftw.interfaces.numpy_fft.ifftn(local_val, axes=axes)
Please register or sign in to reply
to_key=slice(local_start, local_end),
local_keys=True)
# If the values living in domain are purely real, the # Apply domain centering mask
# result of the fft is hermitian for slice_list in utilities.get_slice_list(val.shape, axes):
if domain.paradict['complexity'] == 0: if slice_list == [slice(None, None)]:
return_val.hermitian = True result *= domain_centering_mask
else:
# In case the input val was not a distributed data obect, the try result[slice_list] *= domain_centering_mask
# will produce a NameError
except(NameError):
return_val = distributed_data_object(
global_shape=current_plan_and_info.global_output_shape,
dtype=codomain.dtype,
distribution_strategy='fftw')
return_val.set_local_data(data=result)
return_val = return_val.get_full_data()
# The +-1 magic has a caveat: if a dimension was zero-centered
# in the harmonic as well as in position space, the result gets
# a global minus. The following multiplitcation compensates that.
return_val *= current_plan_and_info.overall_sign
return return_val return result.astype(codomain.dtype)
else:
# val is a distributed_data_object
pass
# The instances of plan_and_info store the fftw plan and all # The instances of plan_and_info store the fftw plan and all
...@@ -430,7 +400,7 @@ class GFFT(FFT): ...@@ -430,7 +400,7 @@ class GFFT(FFT):
Returns Returns
------- -------
result : np.ndarray result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field. Fourier-transformed pendant of the input field.
""" """
# GFFT doesn't accept d2o objects as input. Consolidate data from # GFFT doesn't accept d2o objects as input. Consolidate data from
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment