Commit 90d6e2f7 authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'fix_fftw_transform' into 'master'

Fix fftw transform



See merge request !35
parents 72fa347b 2b070f69
......@@ -7,7 +7,7 @@ from nifty import LineEnergy
class LineSearch(object, Loggable):
"""
Class for finding a step size.
Class for finding a step size.
"""
__metaclass__ = abc.ABCMeta
......
......@@ -164,24 +164,26 @@ class FFTW(Transform):
self.centering_mask_dict[temp_id] = centering_mask
return self.centering_mask_dict[temp_id]
def _get_transform_info(self, domain, codomain, local_shape,
def _get_transform_info(self, domain, codomain, axes, local_shape,
local_offset_Q, is_local, transform_shape=None,
**kwargs):
# generate a id-tuple which identifies the domain-codomain setting
temp_id = (domain.__hash__() ^
(101 * codomain.__hash__()) ^
(211 * transform_shape.__hash__()))
(211 * transform_shape.__hash__()) ^
(131 * is_local.__hash__())
)
# generate the plan_and_info object if not already there
if temp_id not in self.info_dict:
if is_local:
self.info_dict[temp_id] = FFTWLocalTransformInfo(
domain, codomain, local_shape,
domain, codomain, axes, local_shape,
local_offset_Q, self, **kwargs
)
else:
self.info_dict[temp_id] = FFTWMPITransfromInfo(
domain, codomain, local_shape,
domain, codomain, axes, local_shape,
local_offset_Q, self, transform_shape, **kwargs
)
......@@ -248,17 +250,16 @@ class FFTW(Transform):
# val must be numpy array or d2o with slicing distributor
###
local_offset_Q = False
try:
local_val = val.get_local_data(copy=False)
if axes is None or 0 in axes:
local_offset_Q = val.distributor.local_shape[0] % 2
except(AttributeError):
local_val = val
current_info = self._get_transform_info(self.domain,
self.codomain,
axes,
local_shape=local_val.shape,
local_offset_Q=local_offset_Q,
local_offset_Q=False,
is_local=True,
**kwargs)
......@@ -309,14 +310,10 @@ class FFTW(Transform):
def _mpi_transform(self, val, axes, **kwargs):
if axes is None or 0 in axes:
local_offset_list = np.cumsum(
np.concatenate([[0, ], val.distributor.all_local_slices[:, 2]])
)
local_offset_Q = bool(
local_offset_list[val.distributor.comm.rank] % 2)
else:
local_offset_Q = False
local_offset_list = np.cumsum(
np.concatenate([[0, ], val.distributor.all_local_slices[:, 2]])
)
local_offset_Q = bool(local_offset_list[val.distributor.comm.rank] % 2)
return_val = val.copy_empty(global_shape=val.shape,
dtype=self.codomain.dtype)
......@@ -362,6 +359,7 @@ class FFTW(Transform):
current_info = self._get_transform_info(
self.domain,
self.codomain,
axes,
local_shape=val.local_shape,
local_offset_Q=local_offset_Q,
is_local=False,
......@@ -446,20 +444,22 @@ class FFTW(Transform):
class FFTWTransformInfo(object):
def __init__(self, domain, codomain, local_shape,
def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs):
if pyfftw is None:
raise ImportError("The module pyfftw is needed but not available.")
self.cmask_domain = fftw_context.get_centering_mask(
domain.zerocenter,
local_shape,
local_offset_Q)
shape = (local_shape if axes is None else
[y for x, y in enumerate(local_shape) if x in axes])
self.cmask_domain = fftw_context.get_centering_mask(domain.zerocenter,
shape,
local_offset_Q)
self.cmask_codomain = fftw_context.get_centering_mask(
codomain.zerocenter,
local_shape,
local_offset_Q)
codomain.zerocenter,
shape,
local_offset_Q)
# If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it.
......@@ -493,10 +493,11 @@ class FFTWTransformInfo(object):
class FFTWLocalTransformInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, local_shape,
def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs):
super(FFTWLocalTransformInfo, self).__init__(domain,
codomain,
axes,
local_shape,
local_offset_Q,
fftw_context,
......@@ -512,10 +513,11 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
class FFTWMPITransfromInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, local_shape,
def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, transform_shape, **kwargs):
super(FFTWMPITransfromInfo, self).__init__(domain,
codomain,
axes,
local_shape,
local_offset_Q,
fftw_context,
......
......@@ -7,29 +7,29 @@ from nifty import RGSpace, nifty_configuration
class RGRGTransformation(Transformation):
def __init__(self, domain, codomain=None, module=None):
super(RGRGTransformation, self).__init__(domain, codomain, module)
super(RGRGTransformation, self).__init__(domain, codomain)
if module is None:
if nifty_configuration['fft_module'] == 'pyfftw':
self._transform = FFTW(domain, codomain)
self._transform = FFTW(self.domain, self.codomain)
elif (nifty_configuration['fft_module'] == 'gfft' or
nifty_configuration['fft_module'] == 'gfft_dummy'):
self._transform = \
GFFT(domain,
codomain,
GFFT(self.domain,
self.codomain,
gdi.get(nifty_configuration['fft_module']))
else:
raise ValueError('ERROR: unknow default FFT module:' +
nifty_configuration['fft_module'])
else:
if module == 'pyfftw':
self._transform = FFTW(domain, codomain)
self._transform = FFTW(self.domain, self.codomain)
elif module == 'gfft':
self._transform = \
GFFT(domain, codomain, gdi.get('gfft'))
GFFT(self.domain, self.codomain, gdi.get('gfft'))
elif module == 'gfft_dummy':
self._transform = \
GFFT(domain, codomain, gdi.get('gfft_dummy'))
GFFT(self.domain, self.codomain, gdi.get('gfft_dummy'))
else:
raise ValueError('ERROR: unknow FFT module:' + module)
......
......@@ -11,7 +11,7 @@ class Transformation(object, Loggable):
"""
__metaclass__ = abc.ABCMeta
def __init__(self, domain, codomain=None, module=None):
def __init__(self, domain, codomain):
if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment