Commit 157f4fea authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: Correct mask generation in FFTWTransformInfo

parent c9ae0203
Pipeline #5112 skipped
......@@ -2,9 +2,8 @@
import numpy as np
from mpi4py import MPI
from d2o import distributed_data_object, STRATEGIES
from nifty.config import dependency_injector as gdi
from nifty.config import about
from d2o import distributed_data_object, distributor_factory, STRATEGIES
from nifty.config import about, dependency_injector as gdi
import nifty.nifty_utilities as utilities
pyfftw = gdi.get('pyfftw')
......@@ -206,7 +205,8 @@ class FFTW(FFT):
self.centering_mask_dict[temp_id] = centering_mask
return self.centering_mask_dict[temp_id]
def _get_transform_info(self, domain, codomain, is_local=False, **kwargs):
def _get_transform_info(self, domain, codomain, local_shape_info=None,
is_local=False, **kwargs):
# generate a id-tuple which identifies the domain-codomain setting
temp_id = domain.__hash__() ^ (101 * codomain.__hash__())
......@@ -214,12 +214,12 @@ class FFTW(FFT):
if temp_id not in self.info_dict:
if is_local:
self.info_dict[temp_id] = FFTWLocalTransformInfo(
domain, codomain,
domain, codomain, local_shape_info,
self, **kwargs
)
else:
self.info_dict[temp_id] = FFTWMPITransfromInfo(
domain, codomain,
domain, codomain, local_shape_info,
self, **kwargs
)
......@@ -350,12 +350,25 @@ class FFTW(FFT):
return return_val
def _get_local_shape_info(self, comm, global_shape, distribution_strategy):
if distribution_strategy == 'equal':
local_slice = distributor_factory._equal_slicer(comm, global_shape)
local_shape = np.append((local_slice[1] - local_slice[0],),
global_shape[1:])
return (local_shape, local_slice[0])
def _slicing_fftw_mpi_transform(self, val, domain, codomain,
axes, **kwargs):
current_info = self._get_transform_info(domain, codomain, **kwargs)
current_info = self._get_transform_info(
domain, codomain,
local_shape_info=self._get_local_shape_info(
val.comm, val.shape, val.distribution_strategy
),
**kwargs
)
return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
dtype=codomain.dtype)
# Extract local data
local_val = val.get_local_data(copy=False)
......@@ -481,23 +494,41 @@ class FFTW(FFT):
class FFTWTransformInfo(object):
def __init__(self, domain, codomain, fftw_context, **kwargs):
def __init__(self, domain, codomain, local_shape_info,
fftw_context, **kwargs):
if pyfftw is None:
raise ImportError("The module pyfftw is needed but not available.")
# Create domain centering mask. The mask is the same size
# as the domain's shape.
self.cmask_domain = fftw_context.get_centering_mask(
domain.paradict['zerocenter'],
domain.get_shape(),
True
)
# When the domain being transformed is not split across ranks, the
# mask will then have the same shape as the domain. The offset
# is set to False since every node will have index starting from 0. In
# the other case, we use the supplied local_shape_info to get the
# local_shape and offset
if local_shape_info is None:
self.cmask_domain = fftw_context.get_centering_mask(
domain.paradict['zerocenter'],
domain.get_shape(),
False
)
else:
self.cmask_domain = fftw_context.get_centering_mask(
domain.paradict['zerocenter'],
local_shape_info[0],
local_shape_info[1] % 2
)
self.cmask_codomain = fftw_context.get_centering_mask(
codomain.paradict['zerocenter'],
codomain.get_shape(),
True
)
if local_shape_info is None:
self.cmask_codomain = fftw_context.get_centering_mask(
codomain.paradict['zerocenter'],
codomain.get_shape(),
False
)
else:
self.cmask_domain = fftw_context.get_centering_mask(
codomain.paradict['zerocenter'],
local_shape_info[0],
local_shape_info[1] % 2
)
# If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it.
......@@ -533,9 +564,11 @@ class FFTWTransformInfo(object):
class FFTWLocalTransformInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, fftw_context, **kwargs):
super(FFTWLocalTransformInfo, self).__init__(domain, codomain,
fftw_context, **kwargs)
def __init__(self, domain, codomain, local_shape_info,
fftw_context, **kwargs):
super(FFTWLocalTransformInfo, self).__init__(
domain, codomain, local_shape_info, fftw_context, **kwargs
)
if codomain.harmonic:
self._fftw_interface = pyfftw.interfaces.numpy_fft.fftn
else:
......@@ -552,9 +585,11 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
class FFTWMPITransfromInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, fftw_context, **kwargs):
super(FFTWMPITransfromInfo, self).__init__(domain, codomain,
fftw_context, **kwargs)
def __init__(self, domain, codomain, local_shape_info,
fftw_context, **kwargs):
super(FFTWMPITransfromInfo, self).__init__(
domain, codomain, local_shape_info, fftw_context, **kwargs
)
# When the domain is 1-dimensional, reshape it so that it can
# accept input which is also augmented by 1.
if len(domain.get_shape()) == 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment