Commit 83ae9e94 authored by Jait Dixit's avatar Jait Dixit
Browse files

Fix issues #50 and #52

parent 69fa111d
...@@ -7,7 +7,7 @@ import itertools ...@@ -7,7 +7,7 @@ import itertools
from nifty import RGSpace, LMSpace, HPSpace, GLSpace from nifty import RGSpace, LMSpace, HPSpace, GLSpace
from nifty import transformator from nifty import transformator
import nifty.transformations.transformation as transformation from nifty.transformations.rgrgtransformation import RGRGTransformation
from nifty.rg.rg_space import gc as RG_GC from nifty.rg.rg_space import gc as RG_GC
import d2o import d2o
...@@ -72,8 +72,8 @@ class TestRGRGTransformation(unittest.TestCase): ...@@ -72,8 +72,8 @@ class TestRGRGTransformation(unittest.TestCase):
def test_check_codomain_rgspecific(self, complexity, distances, harmonic): def test_check_codomain_rgspecific(self, complexity, distances, harmonic):
x = RGSpace((8, 8), complexity=complexity, x = RGSpace((8, 8), complexity=complexity,
distances=distances, harmonic=harmonic) distances=distances, harmonic=harmonic)
assert (transformation.RGRGTransformation.check_codomain(x, x.get_codomain())) assert (RGRGTransformation.check_codomain(x, x.get_codomain()))
assert (transformation.RGRGTransformation.check_codomain(x, x.get_codomain())) assert (RGRGTransformation.check_codomain(x, x.get_codomain()))
@parameterized.expand(rg_rg_fft_modules, testcase_func_name=custom_name_func) @parameterized.expand(rg_rg_fft_modules, testcase_func_name=custom_name_func)
def test_shapemismatch(self, module): def test_shapemismatch(self, module):
......
import numpy as np
from transform import Transform
from d2o import distributed_data_object
import nifty.nifty_utilities as utilities
class GFFT(Transform):
"""
The gfft pendant of a fft object.
Parameters
----------
fft_module_name : String
Switch between the gfft module used: 'gfft' and 'gfft_dummy'
"""
def __init__(self, domain, codomain, fft_module):
self.domain = domain
self.codomain = codomain
self.fft_machine = fft_module
def transform(self, val, axes, **kwargs):
"""
The gfft transform function.
Parameters
----------
val : numpy.ndarray or distributed_data_object
The value-array of the field which is supposed to
be transformed.
axes : None or tuple
The axes which should be transformed.
**kwargs : *optional*
Further kwargs are not processed.
Returns
-------
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("ERROR: Provided axes does not match array shape")
# GFFT doesn't accept d2o objects as input. Consolidate data from
# all nodes into numpy.ndarray before proceeding.
if isinstance(val, distributed_data_object):
temp_inp = val.get_full_data()
else:
temp_inp = val
# Array for storing the result
return_val = None
for slice_list in utilities.get_slice_list(temp_inp.shape, axes):
# don't copy the whole data array
if slice_list == [slice(None, None)]:
inp = temp_inp
else:
# initialize the return_val object if needed
if return_val is None:
return_val = np.empty_like(temp_inp)
inp = temp_inp[slice_list]
inp = self.fft_machine.gfft(
inp,
in_ax=[],
out_ax=[],
ftmachine='fft' if self.codomain.harmonic else 'ifft',
in_zero_center=map(
bool, self.domain.paradict['zerocenter']
),
out_zero_center=map(
bool, self.codomain.paradict['zerocenter']
),
enforce_hermitian_symmetry=bool(
self.codomain.paradict['complexity']
),
W=-1,
alpha=-1,
verbose=False
)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
# If the values living in domain are purely real, the result of
# the fft is hermitian
if self.domain.paradict['complexity'] == 0:
new_val.hermitian = True
return_val = new_val
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
return return_val
import numpy as np import numpy as np
from transform import Transform from transformation import Transformation
from d2o import distributed_data_object from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
from nifty import GLSpace, LMSpace
gl = gdi.get('libsharp_wrapper_gl') gl = gdi.get('libsharp_wrapper_gl')
class GLTransform(Transform): class GLLMTransformation(Transformation):
""" def __init__(self, domain, codomain, module=None):
GLTransform wrapper for libsharp's transform functions if 'libsharp_wrapper_gl' not in gdi:
""" raise ImportError("The module libsharp is needed but not available")
def __init__(self, domain, codomain): if self.check_codomain(domain, codomain):
self.domain = domain self.domain = domain
self.codomain = codomain self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
if 'libsharp_wrapper_gl' not in gdi: @staticmethod
raise ImportError("The module libsharp_wrapper_gl " + def check_codomain(domain, codomain):
"is needed but not available") if not isinstance(domain, GLSpace):
raise TypeError('ERROR: domain is not a GLSpace')
if codomain is None:
return False
if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.')
nlat = domain.paradict['nlat']
nlon = domain.paradict['nlon']
lmax = codomain.paradict['lmax']
mmax = codomain.paradict['mmax']
if (nlon != 2 * nlat - 1) or (lmax != nlat - 1) or (lmax != mmax):
return False
return True
def transform(self, val, axes=None, **kwargs):
"""
GL -> LM transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
def transform(self, val, axes, **kwargs): """
if self.domain.discrete: if self.domain.discrete:
val = self.domain.calc_weight(val, power=-0.5) val = self.domain.calc_weight(val, power=-0.5)
......
import numpy as np import numpy as np
from transform import Transform from transformation import Transformation
from d2o import distributed_data_object from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
from nifty import HPSpace, LMSpace
hp = gdi.get('healpy') hp = gdi.get('healpy')
class HPTransform(Transform):
"""
GLTransform wrapper for libsharp's transform functions
"""
def __init__(self, domain, codomain):
self.domain = domain
self.codomain = codomain
class HPLMTransformation(Transformation):
def __init__(self, domain, codomain, module=None):
if 'healpy' not in gdi: if 'healpy' not in gdi:
raise ImportError("The module healpy is needed but not available") raise ImportError("The module healpy is needed but not available")
def transform(self, val, axes, **kwargs): if self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def check_codomain(domain, codomain):
if not isinstance(domain, HPSpace):
raise TypeError('ERROR: domain is not a HPSpace')
if codomain is None:
return False
if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.')
nside = domain.paradict['nside']
lmax = codomain.paradict['lmax']
mmax = codomain.paradict['mmax']
if (3 * nside - 1 != lmax) or (lmax != mmax):
return False
return True
def transform(self, val, axes=None, **kwargs):
"""
HP -> LM transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
"""
# get by number of iterations from kwargs # get by number of iterations from kwargs
niter = kwargs['niter'] if 'niter' in kwargs else 0 niter = kwargs['niter'] if 'niter' in kwargs else 0
...@@ -60,4 +92,4 @@ class HPTransform(Transform): ...@@ -60,4 +92,4 @@ class HPTransform(Transform):
else: else:
return_val = return_val.astype(self.codomain.dtype, copy=False) return_val = return_val.astype(self.codomain.dtype, copy=False)
return return_val return return_val
\ No newline at end of file
import numpy as np import numpy as np
from nifty import GLSpace, HPSpace from transformation import Transformation
from nifty.config import about
import nifty.nifty_utilities as utilities
from transform import Transform
from d2o import distributed_data_object from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty import GLSpace, LMSpace
class LMTransform(Transform): gl = gdi.get('libsharp_wrapper_gl')
"""
LMTransform for transforming to GL/HP space
"""
def __init__(self, domain, codomain, module):
self.domain = domain
self.codomain = codomain
self.module = module
def _transform(self, val): class LMGLTransformation(Transformation):
if isinstance(self.codomain, GLSpace): def __init__(self, domain, codomain, module=None):
# shorthand for transform parameters if gdi.get('libsharp_wrapper_gl') is None:
nlat = self.codomain.paradict['nlat'] raise ImportError(
nlon = self.codomain.paradict['nlon'] "The module libsharp is needed but not available.")
lmax = self.domain.paradict['lmax']
mmax = self.paradict['mmax']
if self.domain.dtype == np.dtype('complex64'): if self.check_codomain(domain, codomain):
val = self.module.alm2map_f(val, nlat=nlat, nlon=nlon, self.domain = domain
lmax=lmax, mmax=mmax, cl=False) self.codomain = codomain
else:
val = self.module.alm2map(val, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
elif isinstance(self.codomain, HPSpace):
# shorthand for transform parameters
nside = self.codomain.paradict['nside']
lmax = self.domain.paradict['lmax']
mmax = self.domain.paradict['mmax']
val = val.astype(np.complex128, copy=False)
val = self.module.alm2map(val, nside, lmax=lmax, mmax=mmax,
pixwin=False, fwhm=0.0, sigma=None,
pol=True, inplace=False)
else: else:
raise ValueError("ERROR: Unsupported transformation.") raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def check_codomain(domain, codomain):
if not isinstance(domain, LMSpace):
raise TypeError('ERROR: domain is not a LMSpace')
if codomain is None:
return False
if not isinstance(codomain, GLSpace):
raise TypeError('ERROR: codomain must be a GLSpace.')
return val nlat = codomain.paradict['nlat']
nlon = codomain.paradict['nlon']
lmax = domain.paradict['lmax']
mmax = domain.paradict['mmax']
def transform(self, val, axes, **kwargs): if (lmax != mmax) or (nlat != lmax + 1) or (nlon != 2 * lmax + 1):
return False
return True
def transform(self, val, axes=None, **kwargs):
"""
LM -> GL transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
"""
if isinstance(val, distributed_data_object): if isinstance(val, distributed_data_object):
temp_val = val.get_full_data() temp_val = val.get_full_data()
else: else:
...@@ -60,7 +69,17 @@ class LMTransform(Transform): ...@@ -60,7 +69,17 @@ class LMTransform(Transform):
return_val = np.empty_like(temp_val) return_val = np.empty_like(temp_val)
inp = temp_val[slice_list] inp = temp_val[slice_list]
inp = self._transform(inp) nlat = self.codomain.paradict['nlat']
nlon = self.codomain.paradict['nlon']
lmax = self.domain.paradict['lmax']
mmax = self.paradict['mmax']
if self.domain.dtype == np.dtype('complex64'):
inp = gl.alm2map_f(inp, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
else:
inp = gl.alm2map(inp, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
if slice_list == [slice(None, None)]: if slice_list == [slice(None, None)]:
return_val = inp return_val = inp
...@@ -77,4 +96,4 @@ class LMTransform(Transform): ...@@ -77,4 +96,4 @@ class LMTransform(Transform):
else: else:
return_val = return_val.astype(self.codomain.dtype, copy=False) return_val = return_val.astype(self.codomain.dtype, copy=False)
return return_val return return_val
\ No newline at end of file
import numpy as np
from transformation import Transformation
from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty import HPSpace, LMSpace
hp = gdi.get('healpy')
class LMHPTransformation(Transformation):
def __init__(self, domain, codomain, module=None):
if gdi.get('healpy') is None:
raise ImportError(
"The module libsharp is needed but not available.")
if self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def check_codomain(domain, codomain):
if not isinstance(domain, LMSpace):
raise TypeError('ERROR: domain is not a LMSpace')
if codomain is None:
return False
if not isinstance(codomain, HPSpace):
raise TypeError('ERROR: codomain must be a HPSpace.')
nside = codomain.paradict['nside']
lmax = domain.paradict['lmax']
mmax = domain.paradict['mmax']
if (lmax != mmax) or (3 * nside - 1 != lmax):
return False
return True
def transform(self, val, axes=None, **kwargs):
"""
LM -> HP transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
"""
if isinstance(val, distributed_data_object):
temp_val = val.get_full_data()
else:
temp_val = val
return_val = None
for slice_list in utilities.get_slice_list(temp_val.shape, axes):
if slice_list == [slice(None, None)]:
inp = temp_val
else:
if return_val is None:
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
nside = self.codomain.paradict['nside']
lmax = self.domain.paradict['lmax']
mmax = self.domain.paradict['mmax']
inp = inp.astype(np.complex128, copy=False)
inp = hp.alm2map(inp, nside, lmax=lmax, mmax=mmax,
pixwin=False, fwhm=0.0, sigma=None,
pol=True, inplace=False)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
# re-weight if discrete
if self.codomain.discrete:
val = self.codomain.calc_weight(val, power=0.5)
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
return return_val
...@@ -4,15 +4,39 @@ import numpy as np ...@@ -4,15 +4,39 @@ import numpy as np
from d2o import distributed_data_object, STRATEGIES from d2o import distributed_data_object, STRATEGIES
from nifty.config import about, dependency_injector as gdi from nifty.config import about, dependency_injector as gdi
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
from transform import Transform from nifty import nifty_configuration
from mpi4py import MPI
pyfftw = gdi.get('pyfftw') pyfftw = gdi.get('pyfftw')
class FFTW(Transform): class Transform(object):
"""
A generic fft object without any implementation.
"""
def __init__(self, domain, codomain):
pass
def transform(self, val, axes, **kwargs):
"""
A generic ff-transform function.
Parameters
----------
field_val : distributed_data_object
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
"""
raise NotImplementedError
class FFTW(Transform):
""" """
The pyfftw pendant of a fft object. The pyfftw pendant of a fft object.
""" """
...@@ -102,12 +126,13 @@ class FFTW(Transform): ...@@ -102,12 +126,13 @@ class FFTW(Transform):
# until the desired format is constructed. # until the desired format is constructed.
core = np.fromfunction( core = np.fromfunction(
lambda *args: (-1) ** lambda *args: (-1) **
(np.tensordot(to_center, (np.tensordot(to_center,
args + args +
offset.reshape(offset.shape + offset.reshape(offset.shape +
(1,) * (1,) *
(np.array(args).ndim - 1)), (np.array(
1)), args).ndim - 1)),
1)),
(2,) * to_center.size) (2,) * to_center.size)
# Cast the core to the smallest integers we can get # Cast the core to the smallest integers we can get
core = core.astype(np.int8) core = core.astype(np.int8)
...@@ -185,7 +210,7 @@ class FFTW(Transform): ...@@ -185,7 +210,7 @@ class FFTW(Transform):
if axes: if axes: