Commit 9d4cead3 authored by Theo Steininger's avatar Theo Steininger
Browse files

Removed gfft.

parent fd6d03be
Pipeline #11495 passed with stages
in 30 minutes and 34 seconds
......@@ -24,6 +24,6 @@ import keepers
d2o_configuration = keepers.get_Configuration(
name='D2O',
file_name='D2O.conf',
search_pathes=[os.path.expanduser('~') + "/.config/nifty/",
os.path.expanduser('~') + "/.config/",
'./'])
search_paths=[os.path.expanduser('~') + "/.config/nifty/",
os.path.expanduser('~') + "/.config/",
'./'])
......@@ -25,8 +25,6 @@ import keepers
# Setup the dependency injector
dependency_injector = keepers.DependencyInjector(
[('mpi4py.MPI', 'MPI'),
'gfft',
('nifty.dummys.gfft_dummy', 'gfft_dummy'),
'healpy',
'libsharp_wrapper_gl'])
......@@ -36,8 +34,9 @@ dependency_injector.register('pyfftw', lambda z: hasattr(z, 'FFTW_MPI'))
# Initialize the variables
variable_fft_module = keepers.Variable(
'fft_module',
['pyfftw', 'gfft', 'gfft_dummy'],
lambda z: z in dependency_injector)
['fftw', 'numpy'],
lambda z: (('pyfftw' in dependency_injector)
if z == 'fftw' else True))
def _healpy_validator(use_healpy):
......@@ -97,12 +96,11 @@ nifty_configuration = keepers.get_Configuration(
variable_default_field_dtype,
variable_default_distribution_strategy],
file_name='NIFTy.conf',
search_pathes=[os.path.expanduser('~') + "/.config/nifty/",
os.path.expanduser('~') + "/.config/",
'./'])
search_paths=[os.path.expanduser('~') + "/.config/nifty/",
os.path.expanduser('~') + "/.config/",
'./'])
########
########
try:
nifty_configuration.load()
......
......@@ -33,47 +33,10 @@ class Transform(Loggable, object):
A generic fft object without any implementation.
"""
def __init__(self, domain, codomain):
pass
def transform(self, val, axes, **kwargs):
"""
A generic ff-transform function.
Parameters
----------
field_val : distributed_data_object
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
"""
raise NotImplementedError
class FFTW(Transform):
"""
The pyfftw pendant of a fft object.
"""
def __init__(self, domain, codomain):
self.domain = domain
self.codomain = codomain
if 'pyfftw' not in gdi:
raise ImportError("The module pyfftw is needed but not available.")
# Enable caching for pyfftw.interfaces
pyfftw.interfaces.cache.enable()
# The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets.
self.info_dict = {}
# initialize the dictionary which stores the values from
# get_centering_mask
self.centering_mask_dict = {}
......@@ -130,7 +93,10 @@ class FFTW(Transform):
to_center = np.array(temp_to_center)
# cast the offset_input into the shape of to_center
offset = np.zeros(to_center.shape, dtype=int)
offset[0] = int(offset_input)
# if the first dimension has length 1 and has an offset, restore the
# global minus by hand
if not size_one_dimensions[0]:
offset[0] = int(offset_input)
# check for dimension match
if to_center.size != dimensions.size:
raise TypeError(
......@@ -179,34 +145,14 @@ class FFTW(Transform):
else:
temp_slice += (slice(None),)
centering_mask = centering_mask[temp_slice]
# if the first dimension has length 1 and has an offset, restore
# the global minus by hand
if size_one_dimensions[0] and offset_input:
centering_mask *= -1
self.centering_mask_dict[temp_id] = centering_mask
return self.centering_mask_dict[temp_id]
def _get_transform_info(self, domain, codomain, axes, local_shape,
local_offset_Q, is_local, transform_shape=None,
**kwargs):
# generate a id-tuple which identifies the domain-codomain setting
temp_id = (domain.__hash__() ^
(101 * codomain.__hash__()) ^
(211 * transform_shape.__hash__()) ^
(131 * is_local.__hash__())
)
# generate the plan_and_info object if not already there
if temp_id not in self.info_dict:
if is_local:
self.info_dict[temp_id] = FFTWLocalTransformInfo(
domain, codomain, axes, local_shape,
local_offset_Q, self, **kwargs
)
else:
self.info_dict[temp_id] = FFTWMPITransfromInfo(
domain, codomain, axes, local_shape,
local_offset_Q, self, transform_shape, **kwargs
)
return self.info_dict[temp_id]
def _apply_mask(self, val, mask, axes):
"""
Apply centering mask to an array.
......@@ -233,9 +179,71 @@ class FFTW(Transform):
[y if x in axes else 1
for x, y in enumerate(val.shape)]
)
return val * mask
def transform(self, val, axes, **kwargs):
"""
A generic ff-transform function.
Parameters
----------
field_val : distributed_data_object
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
"""
raise NotImplementedError
class FFTW(Transform):
"""
The pyfftw pendant of a fft object.
"""
def __init__(self, domain, codomain):
if 'pyfftw' not in gdi:
raise ImportError("The module pyfftw is needed but not available.")
super(FFTW, self).__init__(domain, codomain)
# Enable caching for pyfftw.interfaces
pyfftw.interfaces.cache.enable()
# The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets.
self.info_dict = {}
def _get_transform_info(self, domain, codomain, axes, local_shape,
local_offset_Q, is_local, transform_shape=None,
**kwargs):
# generate a id-tuple which identifies the domain-codomain setting
temp_id = (domain.__hash__() ^
(101 * codomain.__hash__()) ^
(211 * transform_shape.__hash__()) ^
(131 * is_local.__hash__())
)
# generate the plan_and_info object if not already there
if temp_id not in self.info_dict:
if is_local:
self.info_dict[temp_id] = FFTWLocalTransformInfo(
domain, codomain, axes, local_shape,
local_offset_Q, self, **kwargs
)
else:
self.info_dict[temp_id] = FFTWMPITransfromInfo(
domain, codomain, axes, local_shape,
local_offset_Q, self, transform_shape, **kwargs
)
return self.info_dict[temp_id]
def _atomic_mpi_transform(self, val, info, axes):
# Apply codomain centering mask
if reduce(lambda x, y: x + y, self.codomain.zerocenter):
......@@ -333,7 +341,6 @@ class FFTW(Transform):
np.concatenate([[0, ], val.distributor.all_local_slices[:, 2]])
)
local_offset_Q = bool(local_offset_list[val.distributor.comm.rank] % 2)
result_dtype = np.result_type(np.complex, self.codomain.dtype)
return_val = val.copy_empty(global_shape=val.shape,
dtype=result_dtype)
......@@ -472,45 +479,33 @@ class FFTWTransformInfo(object):
shape = (local_shape if axes is None else
[y for x, y in enumerate(local_shape) if x in axes])
self.cmask_domain = fftw_context.get_centering_mask(domain.zerocenter,
shape,
local_offset_Q)
self._cmask_domain = fftw_context.get_centering_mask(domain.zerocenter,
shape,
local_offset_Q)
self.cmask_codomain = fftw_context.get_centering_mask(
codomain.zerocenter,
shape,
local_offset_Q)
self._cmask_codomain = fftw_context.get_centering_mask(
codomain.zerocenter,
shape,
local_offset_Q)
# If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it.
self.sign = (-1) ** np.sum(np.array(domain.zerocenter) *
np.array(codomain.zerocenter) *
(np.array(domain.shape) // 2 % 2))
self._sign = (-1) ** np.sum(np.array(domain.zerocenter) *
np.array(codomain.zerocenter) *
(np.array(domain.shape) // 2 % 2))
@property
def cmask_domain(self):
return self._domain_centering_mask
@cmask_domain.setter
def cmask_domain(self, cmask):
self._domain_centering_mask = cmask
return self._cmask_domain
@property
def cmask_codomain(self):
return self._codomain_centering_mask
@cmask_codomain.setter
def cmask_codomain(self, cmask):
self._codomain_centering_mask = cmask
return self._cmask_codomain
@property
def sign(self):
return self._sign
@sign.setter
def sign(self, sign):
self._sign = sign
class FFTWLocalTransformInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, axes, local_shape,
......@@ -556,9 +551,9 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
return self._plan
class GFFT(Transform):
class NUMPYFFT(Transform):
"""
The gfft pendant of a fft object.
The numpy fft pendant of a fft object.
Parameters
----------
......@@ -567,33 +562,21 @@ class GFFT(Transform):
"""
def __init__(self, domain, codomain, fft_module=None):
if fft_module is None:
fft_module = gdi['gfft_dummy']
# GFFT only works for domains with an even number of pixels per axis
if np.any(np.array(domain.shape) % 2):
raise AttributeError("GFFT needs an even number of pixels per "
"axis of domain. Got: %s" % str(domain.shape))
self.domain = domain
self.codomain = codomain
self.fft_machine = fft_module
def transform(self, val, axes, **kwargs):
"""
The gfft transform function.
The pyfftw transform function.
Parameters
----------
val : numpy.ndarray or distributed_data_object
val : distributed_data_object or numpy.ndarray
The value-array of the field which is supposed to
be transformed.
axes : None or tuple
axes: tuple, None
The axes which should be transformed.
**kwargs : *optional*
Further kwargs are not processed.
Further kwargs are passed to the create_mpi_plan routine.
Returns
-------
......@@ -605,51 +588,70 @@ class GFFT(Transform):
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("Provided axes does not match array shape")
# GFFT doesn't accept d2o objects as input. Consolidate data from
# all nodes into numpy.ndarray before proceeding.
if isinstance(val, distributed_data_object):
temp_inp = val.get_full_data()
else:
temp_inp = val
# Array for storing the result
return_val = None
result_dtype = np.result_type(np.complex, self.codomain.dtype)
return_val = val.copy_empty(global_shape=val.shape,
dtype=result_dtype)
for slice_list in utilities.get_slice_list(temp_inp.shape, axes):
if (axes is None) or (0 in axes) or \
(val.distribution_strategy not in STRATEGIES['slicing']):
# don't copy the whole data array
if slice_list == [slice(None, None)]:
inp = temp_inp
if val.distribution_strategy == 'not':
local_val = val.get_local_data(copy=False)
else:
# initialize the return_val object if needed
if return_val is None:
return_val = np.empty_like(temp_inp, dtype=result_dtype)
inp = temp_inp[slice_list]
inp = self.fft_machine.gfft(
inp,
in_ax=[],
out_ax=[],
ftmachine='fft' if self.codomain.harmonic else 'ifft',
in_zero_center=map(bool, self.domain.zerocenter),
out_zero_center=map(bool, self.codomain.zerocenter),
# enforce_hermitian_symmetry=bool(self.codomain.complexity),
enforce_hermitian_symmetry=False,
W=-1,
alpha=-1,
verbose=False
)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
local_val = val.get_full_data()
result_data = self._atomic_transform(local_val=local_val,
axes=axes,
local_offset_Q=False)
return_val.set_full_data(result_data, copy=False)
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=result_dtype)
new_val.set_full_data(return_val, copy=False)
return_val = new_val
else:
return_val = return_val.astype(result_dtype, copy=False)
local_offset_list = np.cumsum(
np.concatenate([[0, ],
val.distributor.all_local_slices[:, 2]]))
local_offset_Q = \
bool(local_offset_list[val.distributor.comm.rank] % 2)
local_val = val.get_local_data()
result_data = self._atomic_transform(local_val=local_val,
axes=axes,
local_offset_Q=local_offset_Q)
return_val.set_local_data(result_data, copy=False)
return return_val
def _atomic_transform(self, local_val, axes, local_offset_Q):
# some auxiliaries for the mask computation
local_shape = local_val.shape
shape = (local_shape if axes is None else
[y for x, y in enumerate(local_shape) if x in axes])
# Apply codomain centering mask
if reduce(lambda x, y: x + y, self.codomain.zerocenter):
temp_val = np.copy(local_val)
mask = self.get_centering_mask(self.codomain.zerocenter,
shape,
local_offset_Q)
local_val = self._apply_mask(temp_val, mask, axes)
# perform the transformation
result_val = np.fft.fftn(local_val, axes=axes)
# Apply domain centering mask
if reduce(lambda x, y: x + y, self.domain.zerocenter):
mask = self.get_centering_mask(self.domain.zerocenter,
shape,
local_offset_Q)
result_val = self._apply_mask(result_val, mask, axes)
# If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it.
sign = (-1) ** np.sum(np.array(self.domain.zerocenter) *
np.array(self.codomain.zerocenter) *
(np.array(self.domain.shape) // 2 % 2))
if sign != 1:
result_val *= sign
return result_val
......@@ -18,8 +18,7 @@
import numpy as np
from transformation import Transformation
from rg_transforms import FFTW, GFFT
from nifty.config import dependency_injector as gdi
from rg_transforms import FFTW, NUMPYFFT
from nifty import RGSpace, nifty_configuration
......@@ -29,26 +28,18 @@ class RGRGTransformation(Transformation):
module=module)
if module is None:
if nifty_configuration['fft_module'] == 'pyfftw':
if nifty_configuration['fft_module'] == 'fftw':
self._transform = FFTW(self.domain, self.codomain)
elif (nifty_configuration['fft_module'] == 'gfft' or
nifty_configuration['fft_module'] == 'gfft_dummy'):
self._transform = \
GFFT(self.domain,
self.codomain,
gdi.get(nifty_configuration['fft_module']))
elif nifty_configuration['fft_module'] == 'numpy':
self._transform = NUMPYFFT(self.domain, self.codomain)
else:
raise ValueError('ERROR: unknow default FFT module:' +
nifty_configuration['fft_module'])
else:
if module == 'pyfftw':
if module == 'fftw':
self._transform = FFTW(self.domain, self.codomain)
elif module == 'gfft':
self._transform = \
GFFT(self.domain, self.codomain, gdi.get('gfft'))
elif module == 'gfft_dummy':
self._transform = \
GFFT(self.domain, self.codomain, gdi.get('gfft_dummy'))
elif module == 'numpy':
self._transform = NUMPYFFT(self.domain, self.codomain)
else:
raise ValueError('ERROR: unknow FFT module:' + module)
......
......@@ -42,9 +42,9 @@ setup(name="ift_nifty",
]),
include_dirs=[numpy.get_include()],
dependency_links=[
'git+https://gitlab.mpcdf.mpg.de/ift/keepers.git#egg=keepers-0.3.5',
'git+https://gitlab.mpcdf.mpg.de/ift/d2o.git#egg=d2o-1.0.7'],
install_requires=['keepers>=0.3.5', 'd2o>=1.0.7'],
'git+https://gitlab.mpcdf.mpg.de/ift/keepers.git#egg=keepers-0.3.7',
'git+https://gitlab.mpcdf.mpg.de/ift/d2o.git#egg=d2o-1.0.8'],
install_requires=['keepers>=0.3.7', 'd2o>=1.0.8'],
package_data={'nifty.demos': ['demo_faraday_map.npy'],
},
license="GPLv3",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment