Commit 370fdd22 authored by Jait Dixit's avatar Jait Dixit
Browse files

Fix zero-centering part for issue #62

parent a59d90da
......@@ -164,7 +164,7 @@ class FFTW(Transform):
self.centering_mask_dict[temp_id] = centering_mask
return self.centering_mask_dict[temp_id]
def _get_transform_info(self, domain, codomain, local_shape,
def _get_transform_info(self, domain, codomain, axes, local_shape,
local_offset_Q, is_local, transform_shape=None,
**kwargs):
# generate a id-tuple which identifies the domain-codomain setting
......@@ -176,12 +176,12 @@ class FFTW(Transform):
if temp_id not in self.info_dict:
if is_local:
self.info_dict[temp_id] = FFTWLocalTransformInfo(
domain, codomain, local_shape,
domain, codomain, axes, local_shape,
local_offset_Q, self, **kwargs
)
else:
self.info_dict[temp_id] = FFTWMPITransfromInfo(
domain, codomain, local_shape,
domain, codomain, axes, local_shape,
local_offset_Q, self, transform_shape, **kwargs
)
......@@ -255,6 +255,7 @@ class FFTW(Transform):
current_info = self._get_transform_info(self.domain,
self.codomain,
axes,
local_shape=local_val.shape,
local_offset_Q=False,
is_local=True,
......@@ -356,6 +357,7 @@ class FFTW(Transform):
current_info = self._get_transform_info(
self.domain,
self.codomain,
axes,
local_shape=val.local_shape,
local_offset_Q=local_offset_Q,
is_local=False,
......@@ -440,19 +442,19 @@ class FFTW(Transform):
class FFTWTransformInfo(object):
def __init__(self, domain, codomain, local_shape,
def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs):
if pyfftw is None:
raise ImportError("The module pyfftw is needed but not available.")
self.cmask_domain = fftw_context.get_centering_mask(
domain.zerocenter,
local_shape,
[y for x, y in enumerate(local_shape) if x in axes],
local_offset_Q)
self.cmask_codomain = fftw_context.get_centering_mask(
codomain.zerocenter,
local_shape,
[y for x, y in enumerate(local_shape) if x in axes],
local_offset_Q)
# If both domain and codomain are zero-centered the result,
......@@ -487,10 +489,11 @@ class FFTWTransformInfo(object):
class FFTWLocalTransformInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, local_shape,
def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs):
super(FFTWLocalTransformInfo, self).__init__(domain,
codomain,
axes,
local_shape,
local_offset_Q,
fftw_context,
......@@ -506,10 +509,11 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
class FFTWMPITransfromInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, local_shape,
def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, transform_shape, **kwargs):
super(FFTWMPITransfromInfo, self).__init__(domain,
codomain,
axes,
local_shape,
local_offset_Q,
fftw_context,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment