Commit 6edd0ce5 authored by Jait Dixit's avatar Jait Dixit
Browse files

Add pyfftw-mpi section

parent 5a7e9f94
Pipeline #3779 skipped
......@@ -238,22 +238,32 @@ class FFTW(FFT):
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# TODO +/- magic in both cases
# Create domain and codomain centering mask
domain_centering_mask = self.get_centering_mask(
domain.paradict['zerocenter'],
domain.get_shape()
)
codomain_centering_mask = self.get_centering_mask(
codomain.paradict['zerocenter'],
codomain.get_shape()
)
# The +-1 magic has a caveat: if a dimension was zero-centered
# in the harmonic as well as in position space, the result gets
# a global minus. This keeps track of the sign so that we can
# remedy that before returning.
overall_sign = (-1) ** np.sum(
np.array(domain.paradict['zerocenter']) *
np.array(codomain.paradict['zerocenter']) *
np.array(domain.get_shape() // 2 % 2)
)
# If the input is a numpy array we transform it locally
if not isinstance(val, distributed_data_object):
# Copy data for local manipulation
local_val = np.copy(val)
# Create domain and codomain centering mask
domain_centering_mask = self.get_centering_mask(
domain.paradict['zerocenter'],
domain.get_shape()
)
codomain_centering_mask = self.get_centering_mask(
codomain.paradict['zerocenter'],
codomain.get_shape()
)
# reshape the masks
if axes:
domain_centering_mask = domain_centering_mask.reshape(
......@@ -282,7 +292,8 @@ class FFTW(FFT):
# Apply domain centering mask
return_val *= domain_centering_mask
return return_val.astype(codomain.dtype)
# Correct overall sign
return_val *= overall_sign
else:
if val.distribution_strategy == 'not':
new_val = val.copy(distribution_strategy='fftw')
......@@ -292,7 +303,7 @@ class FFTW(FFT):
elif val.distribution_strategy in ('equal', 'fftw', 'freeform'):
if axes:
# We use pyfftw in this case
# Setup up the array which will be returned
# Setup the return object
return_val = val.copy_empty(
global_shape=domain.get_shape(),
dtype=codomain.type
......@@ -312,23 +323,15 @@ class FFTW(FFT):
local_keys=True
).get_local_data()
# Create domain and codomain centering mask
domain_centering_mask = self.get_centering_mask(
domain.paradict['zerocenter'],
domain.get_shape()
)
codomain_centering_mask = self.get_centering_mask(
codomain.paradict['zerocenter'],
codomain.get_shape()
)
# reshape domain centering mask
domain_centering_mask = domain_centering_mask.reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
for x, y in enumerate(local_val.shape)]
)
codomain_centering_mask = codomain_centering_mask.reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
for x, y in enumerate(local_val.shape)]
)
# Apply codomain centering mask
......@@ -361,9 +364,69 @@ class FFTW(FFT):
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
return_val *= overall_sign
elif not axes or axes == (0,):
# We use pyfftw-mpi in this case
pass
current_plan_and_info = self._get_plan_and_info(domain,
codomain,
**kwargs)
# Save the original distribution_strategy
original_strategy = None
# Change distribution strategy to fftw
if val.distribution_strategy != 'fftw':
original_strategy = val.distribution_strategy
val_fftw = val.copy(distribution_strategy='fftw')
# Create return object
return_val = val_fftw.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
# Extract local data
local_val = val_fftw.get_local_data()
domain_centering_mask = \
current_plan_and_info.get_domain_centering_mask().\
reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
)
codomain_centering_mask = \
current_plan_and_info.get_codomain_centering_mask().\
reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
)
local_val *= codomain_centering_mask
# Get the fftw plan
p = current_plan_and_info.get_plan()
# Load the value into the plan
if p.has_input:
p.input_array[:] = local_val
# execute the plan
p()
if p.has_output:
result = p.output_array * domain_centering_mask
else:
result = local_val
assert(result.shape[0] == 0)
return_val.set_local_data(data=result)
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
# reset to original strategy
return_val = return_val.copy(
distribution_strategy=original_strategy
)
# Correct sign using current_plan_and_info
return_val *= current_plan_and_info.ovrall_sign
else:
raise ValueError('ERROR: Unknown distribution strategy')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment