Commit e66fac9a authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: Handle transforming along (0,) and when axes is None

parent 40ca6219
Pipeline #4312 skipped
......@@ -204,7 +204,7 @@ class FFTW(FFT):
self.centering_mask_dict[temp_id] = centering_mask
return self.centering_mask_dict[temp_id]
def _get_plan_and_info(self, domain, codomain, is_local=False, **kwargs):
def _get_transform_info(self, domain, codomain, is_local=False, **kwargs):
# generate a id-tuple which identifies the domain-codomain setting
temp_id = domain.__hash__() ^ (101 * codomain.__hash__())
# generate the plan_and_info object if not already there
......@@ -243,17 +243,16 @@ class FFTW(FFT):
return val * mask
def _local_transform(self, val, current_info, axes, domain, codomain):
def _local_transform(self, val, info, axes, domain, codomain):
# Apply codomain centering mask
if reduce(lambda x, y: x+y, codomain.paradict['zerocenter']):
temp_val = np.copy(val)
val = self._apply_mask(temp_val, current_info.cmask_codomain, axes)
val = self._apply_mask(temp_val, info.cmask_codomain, axes)
if current_info.local:
result = current_info.plan(val, axes=axes,
planner_effort='FFTW_ESTIMATE')
if info.local:
result = info.plan(val, axes=axes, planner_effort='FFTW_ESTIMATE')
else:
p = current_info.plan
p = info.plan
# Load the value into the plan
if p.has_input:
p.input_array[:] = val
......@@ -267,10 +266,10 @@ class FFTW(FFT):
# Apply domain centering mask
if reduce(lambda x, y: x+y, domain.paradict['zerocenter']):
result = self._apply_mask(result, current_info.cmask_domain, axes)
result = self._apply_mask(result, info.cmask_domain, axes)
# Correct the sign if needed
result *= current_info.sign
result *= info.sign
return result
......@@ -303,23 +302,22 @@ class FFTW(FFT):
"""
# If the input is a numpy array we transform it locally
if not isinstance(val, distributed_data_object):
current_info = self._get_plan_and_info(domain, codomain,
is_local=True,
**kwargs)
current_info = self._get_transform_info(domain, codomain,
is_local=True,
**kwargs)
# Copy data for local manipulation
return_val = self._local_transform(val, current_info, axes,
domain, codomain)
else:
if val.distribution_strategy is 'fftw':
# self._full_mpi_transformation(...)
current_info = self._get_plan_and_info(
# Create a local transform if the axes is anything other than
# None or (0,) explicitly.
current_info = self._get_transform_info(
domain, codomain,
is_local=(axes is None), **kwargs # TODO: axes is None is not suffcient. axes could be the full hand-crafted tuple.
is_local=(False if axes is None or axes == (0,) else True),
Please register or sign in to reply
**kwargs
)
# Create the object which will be returned
return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
......@@ -327,14 +325,35 @@ class FFTW(FFT):
# Extract local data
local_val = val.get_local_data(copy=False)
# Transfom the data and place result in place
result = self._local_transform(local_val, current_info, axes,
domain, codomain)
return_val.set_local_data(data=result, copy=False)
# Intermediate storage for individual slices, if transforming
# only along (0,)
temp_val = None
for slice_list in utilities.get_slice_list(local_val.shape,
axes):
if slice_list == [slice(None, None)]:
inp = local_val
else:
if temp_val is None:
temp_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
# Transfom the data
result = self._local_transform(inp, current_info,
axes, domain, codomain)
if slice_list == [slice(None, None)]:
return_val.set_local_data(data=result, copy=False)
else:
temp_val[slice_list] = result
if temp_val is not None:
return_val.set_local_data(data=result, copy=False)
# If the values living in domain are purely real, the
# resulting FFT is hermitian
# TODO: in case of multiple spaces using one d2o array, there isn't anymore 'the' hermitianity
# TODO: in case of multiple spaces using one d2o array, there
# isn't anymore 'the' hermitianity
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
......@@ -348,8 +367,9 @@ class FFTW(FFT):
**kwargs)
return_val = return_val.copy(
distribution_strategy=val.distribution_strategy,
copy=False)
distribution_strategy=val.distribution_strategy,
copy=False
)
return return_val
......@@ -520,8 +540,9 @@ class GFFT(FFT):
ftmachine='fft' if codomain.harmonic else 'ifft',
in_zero_center=map(bool, domain.paradict['zerocenter']),
out_zero_center=map(bool, codomain.paradict['zerocenter']),
enforce_hermitian_symmetry=
bool(codomain.paradict['complexity']),
enforce_hermitian_symmetry=bool(
codomain.paradict['complexity']
),
W=-1,
alpha=-1,
verbose=False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment