Commit 83ae9e94 authored by Jait Dixit's avatar Jait Dixit
Browse files

Fix issues #50 and #52

parent 69fa111d
......@@ -7,7 +7,7 @@ import itertools
from nifty import RGSpace, LMSpace, HPSpace, GLSpace
from nifty import transformator
import nifty.transformations.transformation as transformation
from nifty.transformations.rgrgtransformation import RGRGTransformation
from nifty.rg.rg_space import gc as RG_GC
import d2o
......@@ -72,8 +72,8 @@ class TestRGRGTransformation(unittest.TestCase):
def test_check_codomain_rgspecific(self, complexity, distances, harmonic):
x = RGSpace((8, 8), complexity=complexity,
distances=distances, harmonic=harmonic)
assert (transformation.RGRGTransformation.check_codomain(x, x.get_codomain()))
assert (transformation.RGRGTransformation.check_codomain(x, x.get_codomain()))
assert (RGRGTransformation.check_codomain(x, x.get_codomain()))
assert (RGRGTransformation.check_codomain(x, x.get_codomain()))
@parameterized.expand(rg_rg_fft_modules, testcase_func_name=custom_name_func)
def test_shapemismatch(self, module):
......
import numpy as np
from transform import Transform
from d2o import distributed_data_object
import nifty.nifty_utilities as utilities
class GFFT(Transform):
"""
The gfft pendant of a fft object.
Parameters
----------
fft_module_name : String
Switch between the gfft module used: 'gfft' and 'gfft_dummy'
"""
def __init__(self, domain, codomain, fft_module):
self.domain = domain
self.codomain = codomain
self.fft_machine = fft_module
def transform(self, val, axes, **kwargs):
"""
The gfft transform function.
Parameters
----------
val : numpy.ndarray or distributed_data_object
The value-array of the field which is supposed to
be transformed.
axes : None or tuple
The axes which should be transformed.
**kwargs : *optional*
Further kwargs are not processed.
Returns
-------
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("ERROR: Provided axes does not match array shape")
# GFFT doesn't accept d2o objects as input. Consolidate data from
# all nodes into numpy.ndarray before proceeding.
if isinstance(val, distributed_data_object):
temp_inp = val.get_full_data()
else:
temp_inp = val
# Array for storing the result
return_val = None
for slice_list in utilities.get_slice_list(temp_inp.shape, axes):
# don't copy the whole data array
if slice_list == [slice(None, None)]:
inp = temp_inp
else:
# initialize the return_val object if needed
if return_val is None:
return_val = np.empty_like(temp_inp)
inp = temp_inp[slice_list]
inp = self.fft_machine.gfft(
inp,
in_ax=[],
out_ax=[],
ftmachine='fft' if self.codomain.harmonic else 'ifft',
in_zero_center=map(
bool, self.domain.paradict['zerocenter']
),
out_zero_center=map(
bool, self.codomain.paradict['zerocenter']
),
enforce_hermitian_symmetry=bool(
self.codomain.paradict['complexity']
),
W=-1,
alpha=-1,
verbose=False
)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
# If the values living in domain are purely real, the result of
# the fft is hermitian
if self.domain.paradict['complexity'] == 0:
new_val.hermitian = True
return_val = new_val
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
return return_val
import numpy as np
from transform import Transform
from transformation import Transformation
from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty import GLSpace, LMSpace
gl = gdi.get('libsharp_wrapper_gl')
class GLTransform(Transform):
"""
GLTransform wrapper for libsharp's transform functions
"""
class GLLMTransformation(Transformation):
def __init__(self, domain, codomain, module=None):
if 'libsharp_wrapper_gl' not in gdi:
raise ImportError("The module libsharp is needed but not available")
def __init__(self, domain, codomain):
self.domain = domain
self.codomain = codomain
if self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
if 'libsharp_wrapper_gl' not in gdi:
raise ImportError("The module libsharp_wrapper_gl " +
"is needed but not available")
@staticmethod
def check_codomain(domain, codomain):
if not isinstance(domain, GLSpace):
raise TypeError('ERROR: domain is not a GLSpace')
if codomain is None:
return False
if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.')
nlat = domain.paradict['nlat']
nlon = domain.paradict['nlon']
lmax = codomain.paradict['lmax']
mmax = codomain.paradict['mmax']
if (nlon != 2 * nlat - 1) or (lmax != nlat - 1) or (lmax != mmax):
return False
return True
def transform(self, val, axes=None, **kwargs):
"""
GL -> LM transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
def transform(self, val, axes, **kwargs):
"""
if self.domain.discrete:
val = self.domain.calc_weight(val, power=-0.5)
......
import numpy as np
from transform import Transform
from transformation import Transformation
from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty import HPSpace, LMSpace
hp = gdi.get('healpy')
class HPTransform(Transform):
"""
GLTransform wrapper for libsharp's transform functions
"""
def __init__(self, domain, codomain):
self.domain = domain
self.codomain = codomain
class HPLMTransformation(Transformation):
def __init__(self, domain, codomain, module=None):
if 'healpy' not in gdi:
raise ImportError("The module healpy is needed but not available")
def transform(self, val, axes, **kwargs):
if self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def check_codomain(domain, codomain):
if not isinstance(domain, HPSpace):
raise TypeError('ERROR: domain is not a HPSpace')
if codomain is None:
return False
if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.')
nside = domain.paradict['nside']
lmax = codomain.paradict['lmax']
mmax = codomain.paradict['mmax']
if (3 * nside - 1 != lmax) or (lmax != mmax):
return False
return True
def transform(self, val, axes=None, **kwargs):
"""
HP -> LM transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
"""
# get by number of iterations from kwargs
niter = kwargs['niter'] if 'niter' in kwargs else 0
......@@ -60,4 +92,4 @@ class HPTransform(Transform):
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
return return_val
\ No newline at end of file
return return_val
import numpy as np
from nifty import GLSpace, HPSpace
from nifty.config import about
import nifty.nifty_utilities as utilities
from transform import Transform
from transformation import Transformation
from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty import GLSpace, LMSpace
class LMTransform(Transform):
"""
LMTransform for transforming to GL/HP space
"""
gl = gdi.get('libsharp_wrapper_gl')
def __init__(self, domain, codomain, module):
self.domain = domain
self.codomain = codomain
self.module = module
def _transform(self, val):
if isinstance(self.codomain, GLSpace):
# shorthand for transform parameters
nlat = self.codomain.paradict['nlat']
nlon = self.codomain.paradict['nlon']
lmax = self.domain.paradict['lmax']
mmax = self.paradict['mmax']
class LMGLTransformation(Transformation):
def __init__(self, domain, codomain, module=None):
if gdi.get('libsharp_wrapper_gl') is None:
raise ImportError(
"The module libsharp is needed but not available.")
if self.domain.dtype == np.dtype('complex64'):
val = self.module.alm2map_f(val, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
else:
val = self.module.alm2map(val, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
elif isinstance(self.codomain, HPSpace):
# shorthand for transform parameters
nside = self.codomain.paradict['nside']
lmax = self.domain.paradict['lmax']
mmax = self.domain.paradict['mmax']
val = val.astype(np.complex128, copy=False)
val = self.module.alm2map(val, nside, lmax=lmax, mmax=mmax,
pixwin=False, fwhm=0.0, sigma=None,
pol=True, inplace=False)
if self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Unsupported transformation.")
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def check_codomain(domain, codomain):
if not isinstance(domain, LMSpace):
raise TypeError('ERROR: domain is not a LMSpace')
if codomain is None:
return False
if not isinstance(codomain, GLSpace):
raise TypeError('ERROR: codomain must be a GLSpace.')
return val
nlat = codomain.paradict['nlat']
nlon = codomain.paradict['nlon']
lmax = domain.paradict['lmax']
mmax = domain.paradict['mmax']
def transform(self, val, axes, **kwargs):
if (lmax != mmax) or (nlat != lmax + 1) or (nlon != 2 * lmax + 1):
return False
return True
def transform(self, val, axes=None, **kwargs):
"""
LM -> GL transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
"""
if isinstance(val, distributed_data_object):
temp_val = val.get_full_data()
else:
......@@ -60,7 +69,17 @@ class LMTransform(Transform):
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
inp = self._transform(inp)
nlat = self.codomain.paradict['nlat']
nlon = self.codomain.paradict['nlon']
lmax = self.domain.paradict['lmax']
mmax = self.paradict['mmax']
if self.domain.dtype == np.dtype('complex64'):
inp = gl.alm2map_f(inp, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
else:
inp = gl.alm2map(inp, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
if slice_list == [slice(None, None)]:
return_val = inp
......@@ -77,4 +96,4 @@ class LMTransform(Transform):
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
return return_val
\ No newline at end of file
return return_val
import numpy as np
from transformation import Transformation
from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty import HPSpace, LMSpace
hp = gdi.get('healpy')
class LMHPTransformation(Transformation):
def __init__(self, domain, codomain, module=None):
if gdi.get('healpy') is None:
raise ImportError(
"The module libsharp is needed but not available.")
if self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def check_codomain(domain, codomain):
if not isinstance(domain, LMSpace):
raise TypeError('ERROR: domain is not a LMSpace')
if codomain is None:
return False
if not isinstance(codomain, HPSpace):
raise TypeError('ERROR: codomain must be a HPSpace.')
nside = codomain.paradict['nside']
lmax = domain.paradict['lmax']
mmax = domain.paradict['mmax']
if (lmax != mmax) or (3 * nside - 1 != lmax):
return False
return True
def transform(self, val, axes=None, **kwargs):
"""
LM -> HP transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
"""
if isinstance(val, distributed_data_object):
temp_val = val.get_full_data()
else:
temp_val = val
return_val = None
for slice_list in utilities.get_slice_list(temp_val.shape, axes):
if slice_list == [slice(None, None)]:
inp = temp_val
else:
if return_val is None:
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
nside = self.codomain.paradict['nside']
lmax = self.domain.paradict['lmax']
mmax = self.domain.paradict['mmax']
inp = inp.astype(np.complex128, copy=False)
inp = hp.alm2map(inp, nside, lmax=lmax, mmax=mmax,
pixwin=False, fwhm=0.0, sigma=None,
pol=True, inplace=False)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
# re-weight if discrete
if self.codomain.discrete:
val = self.codomain.calc_weight(val, power=0.5)
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
return return_val
......@@ -4,15 +4,39 @@ import numpy as np
from d2o import distributed_data_object, STRATEGIES
from nifty.config import about, dependency_injector as gdi
import nifty.nifty_utilities as utilities
from transform import Transform
from mpi4py import MPI
from nifty import nifty_configuration
pyfftw = gdi.get('pyfftw')
class FFTW(Transform):
class Transform(object):
"""
A generic fft object without any implementation.
"""
def __init__(self, domain, codomain):
pass
def transform(self, val, axes, **kwargs):
"""
A generic ff-transform function.
Parameters
----------
field_val : distributed_data_object
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
"""
raise NotImplementedError
class FFTW(Transform):
"""
The pyfftw pendant of a fft object.
"""
......@@ -102,12 +126,13 @@ class FFTW(Transform):
# until the desired format is constructed.
core = np.fromfunction(
lambda *args: (-1) **
(np.tensordot(to_center,
args +
offset.reshape(offset.shape +
(1,) *
(np.array(args).ndim - 1)),
1)),
(np.tensordot(to_center,
args +
offset.reshape(offset.shape +
(1,) *
(np.array(
args).ndim - 1)),
1)),
(2,) * to_center.size)
# Cast the core to the smallest integers we can get
core = core.astype(np.int8)
......@@ -185,7 +210,7 @@ class FFTW(Transform):
if axes:
mask = mask.reshape(
[y if x in axes else 1
for x, y in enumerate(val.shape)]
for x, y in enumerate(val.shape)]
)
return val * mask
......@@ -193,7 +218,7 @@ class FFTW(Transform):
def _atomic_mpi_transform(self, val, info, axes):
# Apply codomain centering mask
if reduce(lambda x, y: x+y, self.codomain.paradict['zerocenter']):
if reduce(lambda x, y: x + y, self.codomain.paradict['zerocenter']):
temp_val = np.copy(val)
val = self._apply_mask(temp_val, info.cmask_codomain, axes)
......@@ -210,7 +235,7 @@ class FFTW(Transform):
return None
# Apply domain centering mask
if reduce(lambda x, y: x+y, self.domain.paradict['zerocenter']):
if reduce(lambda x, y: x + y, self.domain.paradict['zerocenter']):
result = self._apply_mask(result, info.cmask_domain, axes)
# Correct the sign if needed
......@@ -238,7 +263,7 @@ class FFTW(Transform):
**kwargs)
# Apply codomain centering mask
if reduce(lambda x, y: x+y, self.codomain.paradict['zerocenter']):
if reduce(lambda x, y: x + y, self.codomain.paradict['zerocenter']):
temp_val = np.copy(local_val)
local_val = self._apply_mask(temp_val,
current_info.cmask_codomain, axes)
......@@ -250,7 +275,7 @@ class FFTW(Transform):
)
# Apply domain centering mask
if reduce(lambda x, y: x+y, self.domain.paradict['zerocenter']):
if reduce(lambda x, y: x + y, self.domain.paradict['zerocenter']):
local_result = self._apply_mask(local_result,
current_info.cmask_domain, axes)
......@@ -297,7 +322,6 @@ class FFTW(Transform):
return_val = val.copy_empty(global_shape=val.shape,
dtype=self.codomain.dtype)
# Extract local data
local_val = val.get_local_data(copy=False)
......@@ -335,7 +359,8 @@ class FFTW(Transform):
local_shape=val.local_shape,
local_offset_Q=local_offset_Q,
is_local=False,
transform_shape=val.shape, # TODO: check why inp.shape doesn't work
transform_shape=val.shape,
# TODO: check why inp.shape doesn't work
**kwargs