Commit 370fdd22 authored by Jait Dixit's avatar Jait Dixit
Browse files

Fix zero-centering part for issue #62

parent a59d90da
...@@ -164,7 +164,7 @@ class FFTW(Transform): ...@@ -164,7 +164,7 @@ class FFTW(Transform):
self.centering_mask_dict[temp_id] = centering_mask self.centering_mask_dict[temp_id] = centering_mask
return self.centering_mask_dict[temp_id] return self.centering_mask_dict[temp_id]
def _get_transform_info(self, domain, codomain, local_shape, def _get_transform_info(self, domain, codomain, axes, local_shape,
local_offset_Q, is_local, transform_shape=None, local_offset_Q, is_local, transform_shape=None,
**kwargs): **kwargs):
# generate a id-tuple which identifies the domain-codomain setting # generate a id-tuple which identifies the domain-codomain setting
...@@ -176,12 +176,12 @@ class FFTW(Transform): ...@@ -176,12 +176,12 @@ class FFTW(Transform):
if temp_id not in self.info_dict: if temp_id not in self.info_dict:
if is_local: if is_local:
self.info_dict[temp_id] = FFTWLocalTransformInfo( self.info_dict[temp_id] = FFTWLocalTransformInfo(
domain, codomain, local_shape, domain, codomain, axes, local_shape,
local_offset_Q, self, **kwargs local_offset_Q, self, **kwargs
) )
else: else:
self.info_dict[temp_id] = FFTWMPITransfromInfo( self.info_dict[temp_id] = FFTWMPITransfromInfo(
domain, codomain, local_shape, domain, codomain, axes, local_shape,
local_offset_Q, self, transform_shape, **kwargs local_offset_Q, self, transform_shape, **kwargs
) )
...@@ -255,6 +255,7 @@ class FFTW(Transform): ...@@ -255,6 +255,7 @@ class FFTW(Transform):
current_info = self._get_transform_info(self.domain, current_info = self._get_transform_info(self.domain,
self.codomain, self.codomain,
axes,
local_shape=local_val.shape, local_shape=local_val.shape,
local_offset_Q=False, local_offset_Q=False,
is_local=True, is_local=True,
...@@ -356,6 +357,7 @@ class FFTW(Transform): ...@@ -356,6 +357,7 @@ class FFTW(Transform):
current_info = self._get_transform_info( current_info = self._get_transform_info(
self.domain, self.domain,
self.codomain, self.codomain,
axes,
local_shape=val.local_shape, local_shape=val.local_shape,
local_offset_Q=local_offset_Q, local_offset_Q=local_offset_Q,
is_local=False, is_local=False,
...@@ -440,19 +442,19 @@ class FFTW(Transform): ...@@ -440,19 +442,19 @@ class FFTW(Transform):
class FFTWTransformInfo(object): class FFTWTransformInfo(object):
def __init__(self, domain, codomain, local_shape, def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs): local_offset_Q, fftw_context, **kwargs):
if pyfftw is None: if pyfftw is None:
raise ImportError("The module pyfftw is needed but not available.") raise ImportError("The module pyfftw is needed but not available.")
self.cmask_domain = fftw_context.get_centering_mask( self.cmask_domain = fftw_context.get_centering_mask(
domain.zerocenter, domain.zerocenter,
local_shape, [y for x, y in enumerate(local_shape) if x in axes],
local_offset_Q) local_offset_Q)
self.cmask_codomain = fftw_context.get_centering_mask( self.cmask_codomain = fftw_context.get_centering_mask(
codomain.zerocenter, codomain.zerocenter,
local_shape, [y for x, y in enumerate(local_shape) if x in axes],
local_offset_Q) local_offset_Q)
# If both domain and codomain are zero-centered the result, # If both domain and codomain are zero-centered the result,
...@@ -487,10 +489,11 @@ class FFTWTransformInfo(object): ...@@ -487,10 +489,11 @@ class FFTWTransformInfo(object):
class FFTWLocalTransformInfo(FFTWTransformInfo): class FFTWLocalTransformInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, local_shape, def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs): local_offset_Q, fftw_context, **kwargs):
super(FFTWLocalTransformInfo, self).__init__(domain, super(FFTWLocalTransformInfo, self).__init__(domain,
codomain, codomain,
axes,
local_shape, local_shape,
local_offset_Q, local_offset_Q,
fftw_context, fftw_context,
...@@ -506,10 +509,11 @@ class FFTWLocalTransformInfo(FFTWTransformInfo): ...@@ -506,10 +509,11 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
class FFTWMPITransfromInfo(FFTWTransformInfo): class FFTWMPITransfromInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, local_shape, def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, transform_shape, **kwargs): local_offset_Q, fftw_context, transform_shape, **kwargs):
super(FFTWMPITransfromInfo, self).__init__(domain, super(FFTWMPITransfromInfo, self).__init__(domain,
codomain, codomain,
axes,
local_shape, local_shape,
local_offset_Q, local_offset_Q,
fftw_context, fftw_context,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment