Commit 40ca6219 authored by theos's avatar theos
Browse files

Detail changes during code review.

parent e1407dca
Pipeline #4231 skipped
......@@ -76,7 +76,7 @@ class FFT(object):
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
"""
return None
raise NotImplementedError
class FFTW(FFT):
......@@ -243,9 +243,11 @@ class FFTW(FFT):
return val * mask
def _transform(self, val, current_info, axes):
def _local_transform(self, val, current_info, axes, domain, codomain):
# Apply codomain centering mask
val = self._apply_mask(val, current_info.cmask_codomain, axes)
if reduce(lambda x, y: x+y, codomain.paradict['zerocenter']):
temp_val = np.copy(val)
val = self._apply_mask(temp_val, current_info.cmask_codomain, axes)
if current_info.local:
result = current_info.plan(val, axes=axes,
......@@ -264,14 +266,15 @@ class FFTW(FFT):
raise RuntimeError('ERROR: PyFFTW-MPI transform failed.')
# Apply domain centering mask
result = self._apply_mask(val, current_info.cmask_domain, axes)
if reduce(lambda x, y: x+y, domain.paradict['zerocenter']):
result = self._apply_mask(result, current_info.cmask_domain, axes)
# Correct the sign if needed
result *= current_info.sign
return result
def transform(self, val, domain, codomain, axes, **kwargs):
def transform(self, val, domain, codomain, axes=None, **kwargs):
"""
The pyfftw transform function.
......@@ -304,56 +307,49 @@ class FFTW(FFT):
is_local=True,
**kwargs)
# Copy data for local manipulation
local_val = np.copy(val)
return_val = self._transform(local_val, current_info, axes)
return_val = self._local_transform(val, current_info, axes,
domain, codomain)
else:
if val.distribution_strategy == 'not':
new_val = val.copy(distribution_strategy='fftw')
if val.distribution_strategy is 'fftw':
# Recursive call to take advantage of the fact that the data
# necessary is already present on the nodes.
return_val = self.transform(new_val, domain, codomain, axes,
**kwargs)
# self._full_mpi_transformation(...)
return_val = return_val.copy(distribution_strategy='not')
elif val.distribution_strategy in ('equal', 'fftw', 'freeform'):
current_info = self._get_plan_and_info(
domain, codomain,
is_local=False if not axes else True, **kwargs
is_local=(axes is None), **kwargs # TODO: axes is None is not suffcient. axes could be the full hand-crafted tuple.
)
# Repack input as 'fftw' to make life easier
original_strategy = None
if val.distribution_strategy != 'fftw':
original_strategy = val.distribution_strategy
val_temp = val.copy(distribution_strategy='fftw')
else:
val_temp = val
# Create the object which will be returned
return_val = val_temp.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
# Extract local data
local_val = val_temp.get_local_data()
local_val = val.get_local_data(copy=False)
# Transfom the data and place result in place
result = self._transform(local_val, current_info, axes)
result = self._local_transform(local_val, current_info, axes,
domain, codomain)
return_val.set_local_data(data=result, copy=False)
# If the values living in domain are purely real, the
# resulting FFT is purely real
# resulting FFT is hermitian
# TODO: in case of multiple spaces using one d2o array, there isn't anymore 'the' hermitianity
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
# Revert to original distribution strategy, if needed
if not original_strategy:
return_val = return_val.copy(
distribution_strategy=original_strategy
)
else:
raise ValueError('ERROR: Unknown distribution strategy')
new_val = val.copy(distribution_strategy='fftw',
copy=False)
# Recursive call to take advantage of the fact that the data
# necessary is already present on the nodes.
return_val = self.transform(new_val, domain, codomain, axes,
**kwargs)
return_val = return_val.copy(
distribution_strategy=val.distribution_strategy,
copy=False)
return return_val
......@@ -399,7 +395,7 @@ class FFTWTransformInfo(object):
self.plan = pyfftw.create_mpi_plan(
input_shape=domain.get_shape(),
input_dtype='complex128',
output_dtype='complext128',
output_dtype='complex128',
direction='FFTW_FORWARD'
if codomain.harmonic else 'FFTW_BACKWARD',
flags=["FFTW_ESTIMATE"],
......@@ -466,7 +462,7 @@ class GFFT(FFT):
raise ImportError(
"The gfft(_dummy)-module is needed but not available.")
def transform(self, val, domain, codomain, axes, **kwargs):
def transform(self, val, domain, codomain, axes=None, **kwargs):
"""
The gfft transform function.
......@@ -496,22 +492,26 @@ class GFFT(FFT):
# GFFT doesn't accept d2o objects as input. Consolidate data from
# all nodes into numpy.ndarray before proceeding.
if isinstance(val, distributed_data_object):
temp = val.get_full_data()
temp_inp = val.get_full_data()
else:
temp = val
temp_inp = val
# Cast input datatype to codomain's dtype
temp = temp.astype(codomain.dtype)
temp_inp = temp_inp.astype(np.complex128, copy=False)
# Array for storing the result
return_val = np.empty_like(temp)
return_val = None
for slice_list in utilities.get_slice_list(temp_inp.shape, axes):
for slice_list in utilities.get_slice_list(temp.shape, axes):
# don't copy the whole data array
if slice_list == [slice(None, None)]:
inp = temp
inp = temp_inp
else:
inp = temp[slice_list]
# initialize the return_val object if needed
if return_val is None:
return_val = np.empty_like(temp_inp)
inp = temp_inp[slice_list]
inp = self.fft_machine.gfft(
inp,
......@@ -526,18 +526,20 @@ class GFFT(FFT):
alpha=-1,
verbose=False
)
return_val[slice_list] = inp
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=codomain.dtype)
new_val.set_full_data(return_val)
new_val.set_full_data(return_val, copy=False)
# If the values living in domain are purely real, the result of
# the fft is hermitian
if domain.paradict['complexity'] == 0:
new_val.hermitian = True
return_val = new_val
else:
return_val = return_val.astype(codomain.dtype)
return_val = return_val.astype(codomain.dtype, copy=False)
return return_val
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment