Commit 157f4fea authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: Correct mask generation in FFTWTransformInfo

parent c9ae0203
Pipeline #5112 skipped
...@@ -2,9 +2,8 @@ ...@@ -2,9 +2,8 @@
import numpy as np import numpy as np
from mpi4py import MPI from mpi4py import MPI
from d2o import distributed_data_object, STRATEGIES from d2o import distributed_data_object, distributor_factory, STRATEGIES
from nifty.config import dependency_injector as gdi from nifty.config import about, dependency_injector as gdi
from nifty.config import about
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
pyfftw = gdi.get('pyfftw') pyfftw = gdi.get('pyfftw')
...@@ -206,7 +205,8 @@ class FFTW(FFT): ...@@ -206,7 +205,8 @@ class FFTW(FFT):
self.centering_mask_dict[temp_id] = centering_mask self.centering_mask_dict[temp_id] = centering_mask
return self.centering_mask_dict[temp_id] return self.centering_mask_dict[temp_id]
def _get_transform_info(self, domain, codomain, is_local=False, **kwargs): def _get_transform_info(self, domain, codomain, local_shape_info=None,
is_local=False, **kwargs):
# generate a id-tuple which identifies the domain-codomain setting # generate a id-tuple which identifies the domain-codomain setting
temp_id = domain.__hash__() ^ (101 * codomain.__hash__()) temp_id = domain.__hash__() ^ (101 * codomain.__hash__())
...@@ -214,12 +214,12 @@ class FFTW(FFT): ...@@ -214,12 +214,12 @@ class FFTW(FFT):
if temp_id not in self.info_dict: if temp_id not in self.info_dict:
if is_local: if is_local:
self.info_dict[temp_id] = FFTWLocalTransformInfo( self.info_dict[temp_id] = FFTWLocalTransformInfo(
domain, codomain, domain, codomain, local_shape_info,
self, **kwargs self, **kwargs
) )
else: else:
self.info_dict[temp_id] = FFTWMPITransfromInfo( self.info_dict[temp_id] = FFTWMPITransfromInfo(
domain, codomain, domain, codomain, local_shape_info,
self, **kwargs self, **kwargs
) )
...@@ -350,12 +350,25 @@ class FFTW(FFT): ...@@ -350,12 +350,25 @@ class FFTW(FFT):
return return_val return return_val
def _get_local_shape_info(self, comm, global_shape, distribution_strategy):
if distribution_strategy == 'equal':
local_slice = distributor_factory._equal_slicer(comm, global_shape)
local_shape = np.append((local_slice[1] - local_slice[0],),
global_shape[1:])
return (local_shape, local_slice[0])
def _slicing_fftw_mpi_transform(self, val, domain, codomain, def _slicing_fftw_mpi_transform(self, val, domain, codomain,
axes, **kwargs): axes, **kwargs):
current_info = self._get_transform_info(domain, codomain, **kwargs) current_info = self._get_transform_info(
domain, codomain,
local_shape_info=self._get_local_shape_info(
val.comm, val.shape, val.distribution_strategy
),
**kwargs
)
return_val = val.copy_empty(global_shape=val.shape, return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype) dtype=codomain.dtype)
# Extract local data # Extract local data
local_val = val.get_local_data(copy=False) local_val = val.get_local_data(copy=False)
...@@ -481,23 +494,41 @@ class FFTW(FFT): ...@@ -481,23 +494,41 @@ class FFTW(FFT):
class FFTWTransformInfo(object): class FFTWTransformInfo(object):
def __init__(self, domain, codomain, fftw_context, **kwargs): def __init__(self, domain, codomain, local_shape_info,
fftw_context, **kwargs):
if pyfftw is None: if pyfftw is None:
raise ImportError("The module pyfftw is needed but not available.") raise ImportError("The module pyfftw is needed but not available.")
# Create domain centering mask. The mask is the same size # When the domain being transformed is not split across ranks, the
# as the domain's shape. # mask will then have the same shape as the domain. The offset
self.cmask_domain = fftw_context.get_centering_mask( # is set to False since every node will have index starting from 0. In
domain.paradict['zerocenter'], # the other case, we use the supplied local_shape_info to get the
domain.get_shape(), # local_shape and offset
True if local_shape_info is None:
) self.cmask_domain = fftw_context.get_centering_mask(
domain.paradict['zerocenter'],
domain.get_shape(),
False
)
else:
self.cmask_domain = fftw_context.get_centering_mask(
domain.paradict['zerocenter'],
local_shape_info[0],
local_shape_info[1] % 2
)
self.cmask_codomain = fftw_context.get_centering_mask( if local_shape_info is None:
codomain.paradict['zerocenter'], self.cmask_codomain = fftw_context.get_centering_mask(
codomain.get_shape(), codomain.paradict['zerocenter'],
True codomain.get_shape(),
) False
)
else:
self.cmask_domain = fftw_context.get_centering_mask(
codomain.paradict['zerocenter'],
local_shape_info[0],
local_shape_info[1] % 2
)
# If both domain and codomain are zero-centered the result, # If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it. # will get a global minus. Store the sign to correct it.
...@@ -533,9 +564,11 @@ class FFTWTransformInfo(object): ...@@ -533,9 +564,11 @@ class FFTWTransformInfo(object):
class FFTWLocalTransformInfo(FFTWTransformInfo): class FFTWLocalTransformInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, fftw_context, **kwargs): def __init__(self, domain, codomain, local_shape_info,
super(FFTWLocalTransformInfo, self).__init__(domain, codomain, fftw_context, **kwargs):
fftw_context, **kwargs) super(FFTWLocalTransformInfo, self).__init__(
domain, codomain, local_shape_info, fftw_context, **kwargs
)
if codomain.harmonic: if codomain.harmonic:
self._fftw_interface = pyfftw.interfaces.numpy_fft.fftn self._fftw_interface = pyfftw.interfaces.numpy_fft.fftn
else: else:
...@@ -552,9 +585,11 @@ class FFTWLocalTransformInfo(FFTWTransformInfo): ...@@ -552,9 +585,11 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
class FFTWMPITransfromInfo(FFTWTransformInfo): class FFTWMPITransfromInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, fftw_context, **kwargs): def __init__(self, domain, codomain, local_shape_info,
super(FFTWMPITransfromInfo, self).__init__(domain, codomain, fftw_context, **kwargs):
fftw_context, **kwargs) super(FFTWMPITransfromInfo, self).__init__(
domain, codomain, local_shape_info, fftw_context, **kwargs
)
# When the domain is 1-dimensional, reshape it so that it can # When the domain is 1-dimensional, reshape it so that it can
# accept input which is also augmented by 1. # accept input which is also augmented by 1.
if len(domain.get_shape()) == 1: if len(domain.get_shape()) == 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment