Commit 90d6e2f7 authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'fix_fftw_transform' into 'master'

Fix fftw transform



See merge request !35
parents 72fa347b 2b070f69
...@@ -7,7 +7,7 @@ from nifty import LineEnergy ...@@ -7,7 +7,7 @@ from nifty import LineEnergy
class LineSearch(object, Loggable): class LineSearch(object, Loggable):
""" """
Class for finding a step size. Class for finding a step size.
""" """
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
......
...@@ -164,24 +164,26 @@ class FFTW(Transform): ...@@ -164,24 +164,26 @@ class FFTW(Transform):
self.centering_mask_dict[temp_id] = centering_mask self.centering_mask_dict[temp_id] = centering_mask
return self.centering_mask_dict[temp_id] return self.centering_mask_dict[temp_id]
def _get_transform_info(self, domain, codomain, local_shape, def _get_transform_info(self, domain, codomain, axes, local_shape,
local_offset_Q, is_local, transform_shape=None, local_offset_Q, is_local, transform_shape=None,
**kwargs): **kwargs):
# generate a id-tuple which identifies the domain-codomain setting # generate a id-tuple which identifies the domain-codomain setting
temp_id = (domain.__hash__() ^ temp_id = (domain.__hash__() ^
(101 * codomain.__hash__()) ^ (101 * codomain.__hash__()) ^
(211 * transform_shape.__hash__())) (211 * transform_shape.__hash__()) ^
(131 * is_local.__hash__())
)
# generate the plan_and_info object if not already there # generate the plan_and_info object if not already there
if temp_id not in self.info_dict: if temp_id not in self.info_dict:
if is_local: if is_local:
self.info_dict[temp_id] = FFTWLocalTransformInfo( self.info_dict[temp_id] = FFTWLocalTransformInfo(
domain, codomain, local_shape, domain, codomain, axes, local_shape,
local_offset_Q, self, **kwargs local_offset_Q, self, **kwargs
) )
else: else:
self.info_dict[temp_id] = FFTWMPITransfromInfo( self.info_dict[temp_id] = FFTWMPITransfromInfo(
domain, codomain, local_shape, domain, codomain, axes, local_shape,
local_offset_Q, self, transform_shape, **kwargs local_offset_Q, self, transform_shape, **kwargs
) )
...@@ -248,17 +250,16 @@ class FFTW(Transform): ...@@ -248,17 +250,16 @@ class FFTW(Transform):
# val must be numpy array or d2o with slicing distributor # val must be numpy array or d2o with slicing distributor
### ###
local_offset_Q = False
try: try:
local_val = val.get_local_data(copy=False) local_val = val.get_local_data(copy=False)
if axes is None or 0 in axes:
local_offset_Q = val.distributor.local_shape[0] % 2
except(AttributeError): except(AttributeError):
local_val = val local_val = val
current_info = self._get_transform_info(self.domain, current_info = self._get_transform_info(self.domain,
self.codomain, self.codomain,
axes,
local_shape=local_val.shape, local_shape=local_val.shape,
local_offset_Q=local_offset_Q, local_offset_Q=False,
is_local=True, is_local=True,
**kwargs) **kwargs)
...@@ -309,14 +310,10 @@ class FFTW(Transform): ...@@ -309,14 +310,10 @@ class FFTW(Transform):
def _mpi_transform(self, val, axes, **kwargs): def _mpi_transform(self, val, axes, **kwargs):
if axes is None or 0 in axes: local_offset_list = np.cumsum(
local_offset_list = np.cumsum( np.concatenate([[0, ], val.distributor.all_local_slices[:, 2]])
np.concatenate([[0, ], val.distributor.all_local_slices[:, 2]]) )
) local_offset_Q = bool(local_offset_list[val.distributor.comm.rank] % 2)
local_offset_Q = bool(
local_offset_list[val.distributor.comm.rank] % 2)
else:
local_offset_Q = False
return_val = val.copy_empty(global_shape=val.shape, return_val = val.copy_empty(global_shape=val.shape,
dtype=self.codomain.dtype) dtype=self.codomain.dtype)
...@@ -362,6 +359,7 @@ class FFTW(Transform): ...@@ -362,6 +359,7 @@ class FFTW(Transform):
current_info = self._get_transform_info( current_info = self._get_transform_info(
self.domain, self.domain,
self.codomain, self.codomain,
axes,
local_shape=val.local_shape, local_shape=val.local_shape,
local_offset_Q=local_offset_Q, local_offset_Q=local_offset_Q,
is_local=False, is_local=False,
...@@ -446,20 +444,22 @@ class FFTW(Transform): ...@@ -446,20 +444,22 @@ class FFTW(Transform):
class FFTWTransformInfo(object): class FFTWTransformInfo(object):
def __init__(self, domain, codomain, local_shape, def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs): local_offset_Q, fftw_context, **kwargs):
if pyfftw is None: if pyfftw is None:
raise ImportError("The module pyfftw is needed but not available.") raise ImportError("The module pyfftw is needed but not available.")
self.cmask_domain = fftw_context.get_centering_mask( shape = (local_shape if axes is None else
domain.zerocenter, [y for x, y in enumerate(local_shape) if x in axes])
local_shape,
local_offset_Q) self.cmask_domain = fftw_context.get_centering_mask(domain.zerocenter,
shape,
local_offset_Q)
self.cmask_codomain = fftw_context.get_centering_mask( self.cmask_codomain = fftw_context.get_centering_mask(
codomain.zerocenter, codomain.zerocenter,
local_shape, shape,
local_offset_Q) local_offset_Q)
# If both domain and codomain are zero-centered the result, # If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it. # will get a global minus. Store the sign to correct it.
...@@ -493,10 +493,11 @@ class FFTWTransformInfo(object): ...@@ -493,10 +493,11 @@ class FFTWTransformInfo(object):
class FFTWLocalTransformInfo(FFTWTransformInfo): class FFTWLocalTransformInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, local_shape, def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs): local_offset_Q, fftw_context, **kwargs):
super(FFTWLocalTransformInfo, self).__init__(domain, super(FFTWLocalTransformInfo, self).__init__(domain,
codomain, codomain,
axes,
local_shape, local_shape,
local_offset_Q, local_offset_Q,
fftw_context, fftw_context,
...@@ -512,10 +513,11 @@ class FFTWLocalTransformInfo(FFTWTransformInfo): ...@@ -512,10 +513,11 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
class FFTWMPITransfromInfo(FFTWTransformInfo): class FFTWMPITransfromInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, local_shape, def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, transform_shape, **kwargs): local_offset_Q, fftw_context, transform_shape, **kwargs):
super(FFTWMPITransfromInfo, self).__init__(domain, super(FFTWMPITransfromInfo, self).__init__(domain,
codomain, codomain,
axes,
local_shape, local_shape,
local_offset_Q, local_offset_Q,
fftw_context, fftw_context,
......
...@@ -7,29 +7,29 @@ from nifty import RGSpace, nifty_configuration ...@@ -7,29 +7,29 @@ from nifty import RGSpace, nifty_configuration
class RGRGTransformation(Transformation): class RGRGTransformation(Transformation):
def __init__(self, domain, codomain=None, module=None): def __init__(self, domain, codomain=None, module=None):
super(RGRGTransformation, self).__init__(domain, codomain, module) super(RGRGTransformation, self).__init__(domain, codomain)
if module is None: if module is None:
if nifty_configuration['fft_module'] == 'pyfftw': if nifty_configuration['fft_module'] == 'pyfftw':
self._transform = FFTW(domain, codomain) self._transform = FFTW(self.domain, self.codomain)
elif (nifty_configuration['fft_module'] == 'gfft' or elif (nifty_configuration['fft_module'] == 'gfft' or
nifty_configuration['fft_module'] == 'gfft_dummy'): nifty_configuration['fft_module'] == 'gfft_dummy'):
self._transform = \ self._transform = \
GFFT(domain, GFFT(self.domain,
codomain, self.codomain,
gdi.get(nifty_configuration['fft_module'])) gdi.get(nifty_configuration['fft_module']))
else: else:
raise ValueError('ERROR: unknow default FFT module:' + raise ValueError('ERROR: unknow default FFT module:' +
nifty_configuration['fft_module']) nifty_configuration['fft_module'])
else: else:
if module == 'pyfftw': if module == 'pyfftw':
self._transform = FFTW(domain, codomain) self._transform = FFTW(self.domain, self.codomain)
elif module == 'gfft': elif module == 'gfft':
self._transform = \ self._transform = \
GFFT(domain, codomain, gdi.get('gfft')) GFFT(self.domain, self.codomain, gdi.get('gfft'))
elif module == 'gfft_dummy': elif module == 'gfft_dummy':
self._transform = \ self._transform = \
GFFT(domain, codomain, gdi.get('gfft_dummy')) GFFT(self.domain, self.codomain, gdi.get('gfft_dummy'))
else: else:
raise ValueError('ERROR: unknow FFT module:' + module) raise ValueError('ERROR: unknow FFT module:' + module)
......
...@@ -11,7 +11,7 @@ class Transformation(object, Loggable): ...@@ -11,7 +11,7 @@ class Transformation(object, Loggable):
""" """
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, domain, codomain=None, module=None): def __init__(self, domain, codomain):
if codomain is None: if codomain is None:
self.domain = domain self.domain = domain
self.codomain = self.get_codomain(domain) self.codomain = self.get_codomain(domain)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment