Commit 0fcd7cdc authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'lm_space_complex_transformsV2' into 'feature/field_multiple_space'

Lm space complex transforms v2



See merge request !27
parents 3229c061 0f43abfc
...@@ -5,7 +5,7 @@ from d2o import distributed_data_object,\ ...@@ -5,7 +5,7 @@ from d2o import distributed_data_object,\
STRATEGIES as DISTRIBUTION_STRATEGIES STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.config import about,\ from nifty.config import about,\
nifty_configuration as gc,\ nifty_configuration as gc
from nifty.field_types import FieldType from nifty.field_types import FieldType
......
from nifty.config import about from nifty.config import about
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
from nifty.spaces import RGSpace,\
GLSpace,\
HPSpace,\
LMSpace
from nifty.operators.linear_operator import LinearOperator from nifty.operators.linear_operator import LinearOperator
from transformations import TransformationFactory from transformations import RGRGTransformation,\
LMGLTransformation,\
LMHPTransformation,\
GLLMTransformation,\
HPLMTransformation,\
TransformationCache
class FFTOperator(LinearOperator): class FFTOperator(LinearOperator):
# ---Class attributes---
default_codomain_dictionary = {RGSpace: RGSpace,
HPSpace: LMSpace,
GLSpace: LMSpace,
LMSpace: HPSpace,
}
transformation_dictionary = {(RGSpace, RGSpace): RGRGTransformation,
(HPSpace, LMSpace): HPLMTransformation,
(GLSpace, LMSpace): GLLMTransformation,
(LMSpace, HPSpace): LMHPTransformation,
(LMSpace, GLSpace): LMGLTransformation
}
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), target=None): def __init__(self, domain=(), field_type=(), target=None, module=None):
super(FFTOperator, self).__init__(domain=domain, super(FFTOperator, self).__init__(domain=domain,
field_type=field_type) field_type=field_type)
# Initialize domain and target
if len(self.domain) != 1: if len(self.domain) != 1:
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
'ERROR: TransformationOperator accepts only exactly one ' 'ERROR: TransformationOperator accepts only exactly one '
...@@ -24,17 +51,30 @@ class FFTOperator(LinearOperator): ...@@ -24,17 +51,30 @@ class FFTOperator(LinearOperator):
)) ))
if target is None: if target is None:
target = utilities.get_default_codomain(self.domain[0]) target = (self.get_default_codomain(self.domain[0]), )
self._target = self._parse_domain(target) self._target = self._parse_domain(target)
self._forward_transformation = TransformationFactory.create( # Create transformation instances
self.domain[0], self.target[0] try:
) forward_class = self.transformation_dictionary[
(self.domain[0].__class__, self.target[0].__class__)]
self._inverse_transformation = TransformationFactory.create( except KeyError:
self.target[0], self.domain[0] raise TypeError(about._errors.cstring(
) "ERROR: No forward transformation for domain-target pair "
"found."))
try:
backward_class = self.transformation_dictionary[
(self.target[0].__class__, self.domain[0].__class__)]
except KeyError:
raise TypeError(about._errors.cstring(
"ERROR: No backward transformation for domain-target pair "
"found."))
self._forward_transformation = TransformationCache.create(
forward_class, self.domain[0], self.target[0], module=module)
self._backward_transformation = TransformationCache.create(
backward_class, self.target[0], self.domain[0], module=module)
def _times(self, x, spaces, types): def _times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
...@@ -69,7 +109,7 @@ class FFTOperator(LinearOperator): ...@@ -69,7 +109,7 @@ class FFTOperator(LinearOperator):
else: else:
axes = x.domain_axes[spaces[0]] axes = x.domain_axes[spaces[0]]
new_val = self._inverse_transformation.transform(x.val, axes=axes) new_val = self._backward_transformation.transform(x.val, axes=axes)
if spaces is None: if spaces is None:
result_domain = self.domain result_domain = self.domain
...@@ -99,3 +139,22 @@ class FFTOperator(LinearOperator): ...@@ -99,3 +139,22 @@ class FFTOperator(LinearOperator):
@property @property
def unitary(self): def unitary(self):
return True return True
# ---Added properties and methods---
@classmethod
def get_default_codomain(cls, domain):
domain_class = domain.__class__
try:
codomain_class = cls.default_codomain_dictionary[domain_class]
except KeyError:
raise TypeError(about._errors.cstring("ERROR: unknown domain"))
try:
transform_class = cls.transformation_dictionary[(domain_class,
codomain_class)]
except KeyError:
raise TypeError(about._errors.cstring(
"ERROR: No transformation for domain-codomain pair found."))
return transform_class.get_codomain(domain)
...@@ -4,4 +4,4 @@ from hplmtransformation import HPLMTransformation ...@@ -4,4 +4,4 @@ from hplmtransformation import HPLMTransformation
from lmgltransformation import LMGLTransformation from lmgltransformation import LMGLTransformation
from lmhptransformation import LMHPTransformation from lmhptransformation import LMHPTransformation
from transformation_factory import TransformationFactory from transformation_cache import TransformationCache
\ No newline at end of file \ No newline at end of file
import numpy as np import numpy as np
from transformation import Transformation from nifty.config import dependency_injector as gdi,\
from d2o import distributed_data_object about
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty import GLSpace, LMSpace from nifty import GLSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory as ltf
gl = gdi.get('libsharp_wrapper_gl') libsharp = gdi.get('libsharp_wrapper_gl')
class GLLMTransformation(Transformation): class GLLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None): def __init__(self, domain, codomain=None, module=None):
if 'libsharp_wrapper_gl' not in gdi: if 'libsharp_wrapper_gl' not in gdi:
raise ImportError("The module libsharp is needed but not available") raise ImportError(about._errors.cstring(
"The module libsharp is needed but not available."))
if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
elif self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod super(GLLMTransformation, self).__init__(domain, codomain, module)
def get_codomain(domain):
# ---Mandatory properties and methods---
@classmethod
def get_codomain(cls, domain):
""" """
Generates a compatible codomain to which transformations are Generates a compatible codomain to which transformations are
reasonable, i.e.\ an instance of the :py:class:`lm_space` class. reasonable, i.e.\ an instance of the :py:class:`lm_space` class.
...@@ -38,96 +37,89 @@ class GLLMTransformation(Transformation): ...@@ -38,96 +37,89 @@ class GLLMTransformation(Transformation):
codomain : LMSpace codomain : LMSpace
A compatible codomain. A compatible codomain.
""" """
if domain is None:
raise ValueError('ERROR: cannot generate codomain for None')
if not isinstance(domain, GLSpace): if not isinstance(domain, GLSpace):
raise TypeError('ERROR: domain needs to be a GLSpace') raise TypeError(about._errors.cstring(
"ERROR: domain needs to be a GLSpace"))
nlat = domain.nlat nlat = domain.nlat
lmax = nlat - 1 lmax = nlat - 1
mmax = nlat - 1 mmax = nlat - 1
if domain.dtype == np.dtype('float32'): if domain.dtype == np.dtype('float32'):
return LMSpace(lmax=lmax, mmax=mmax, dtype=np.complex64) return_dtype = np.float32
else: else:
return LMSpace(lmax=lmax, mmax=mmax, dtype=np.complex128) return_dtype = np.float64
result = LMSpace(lmax=lmax, mmax=mmax, dtype=return_dtype)
cls.check_codomain(domain, result)
return result
@staticmethod @staticmethod
def check_codomain(domain, codomain): def check_codomain(domain, codomain):
if not isinstance(domain, GLSpace): if not isinstance(domain, GLSpace):
raise TypeError('ERROR: domain is not a GLSpace') raise TypeError(about._errors.cstring(
"ERROR: domain is not a GLSpace"))
if codomain is None:
return False
if not isinstance(codomain, LMSpace): if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.') raise TypeError(about._errors.cstring(
"ERROR: codomain must be a LMSpace."))
nlat = domain.nlat nlat = domain.nlat
nlon = domain.nlon nlon = domain.nlon
lmax = codomain.lmax lmax = codomain.lmax
mmax = codomain.mmax mmax = codomain.mmax
if (nlon != 2 * nlat - 1) or (lmax != nlat - 1) or (lmax != mmax): if lmax != mmax:
return False raise ValueError(about._errors.cstring(
'ERROR: codomain has lmax != mmax.'))
return True if lmax != nlat - 1:
raise ValueError(about._errors.cstring(
def transform(self, val, axes=None, **kwargs): 'ERROR: codomain has lmax != nlat - 1.'))
"""
GL -> LM transform method.
Parameters if nlon != 2 * nlat - 1:
---------- raise ValueError(about._errors.cstring(
val : np.ndarray or distributed_data_object 'ERROR: domain has nlon != 2 * nlat - 1.'))
The value array which is to be transformed
axes : None or tuple return None
The axes along which the transformation should take place
""" def _transformation_of_slice(self, inp, **kwargs):
if self.domain.discrete:
val = self.domain.weight(val, power=-0.5, axes=axes)
# shorthands for transform parameters
nlat = self.domain.nlat nlat = self.domain.nlat
nlon = self.domain.nlon nlon = self.domain.nlon
lmax = self.codomain.lmax lmax = self.codomain.lmax
mmax = self.codomain.mmax mmax = self.codomain.mmax
if isinstance(val, distributed_data_object): if issubclass(inp.dtype.type, np.complexfloating):
temp_val = val.get_full_data()
else: [resultReal, resultImag] = [self.libsharpMap2Alm(x,
temp_val = val nlat=nlat,
nlon=nlon,
return_val = None lmax=lmax,
mmax=mmax,
for slice_list in utilities.get_slice_list(temp_val.shape, axes): **kwargs)
if slice_list == [slice(None, None)]: for x in (inp.real, inp.imag)]
inp = temp_val
else: [resultReal, resultImag] = [ltf.buildIdx(x, lmax=lmax)
if return_val is None: for x in [resultReal, resultImag]]
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list] result = self._combine_complex_result(resultReal, resultImag)
if self.domain.dtype == np.dtype('float32'):
inp = gl.map2alm_f(inp,
nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax)
else:
inp = gl.map2alm(inp,
nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
else: else:
return_val = return_val.astype(self.codomain.dtype, copy=False) result = self.libsharpMap2Alm(inp, nlat=nlat, nlon=nlon, lmax=lmax,
mmax=mmax)
result = ltf.buildIdx(result, lmax=lmax)
return result
return return_val # ---Added properties and methods---
def libsharpMap2Alm(self, inp, **kwargs):
if inp.dtype == np.dtype('float32'):
return libsharp.map2alm_f(inp, **kwargs)
elif inp.dtype == np.dtype('float64'):
return libsharp.map2alm(inp, **kwargs)
else:
about.warnings.cprint("WARNING: performing dtype conversion for "
"libsharp compatibility.")
casted_inp = inp.astype(np.dtype('float64'), copy=False)
result = libsharp.map2alm(casted_inp, **kwargs)
return result
\ No newline at end of file
import numpy as np import numpy as np
from transformation import Transformation from nifty.config import dependency_injector as gdi,\
from d2o import distributed_data_object about
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty import HPSpace, LMSpace from nifty import HPSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory as ltf
hp = gdi.get('healpy') hp = gdi.get('healpy')
class HPLMTransformation(Transformation): class HPLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None): def __init__(self, domain, codomain=None, module=None):
if 'healpy' not in gdi: if 'healpy' not in gdi:
raise ImportError("The module healpy is needed but not available") raise ImportError(about._errors.cstring(
"The module healpy is needed but not available"))
if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
elif self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod super(HPLMTransformation, self).__init__(domain, codomain, module)
def get_codomain(domain):
# ---Mandatory properties and methods---
@classmethod
def get_codomain(cls, domain):
""" """
Generates a compatible codomain to which transformations are Generates a compatible codomain to which transformations are
reasonable, i.e.\ an instance of the :py:class:`lm_space` class. reasonable, i.e.\ an instance of the :py:class:`lm_space` class.
...@@ -38,87 +38,65 @@ class HPLMTransformation(Transformation): ...@@ -38,87 +38,65 @@ class HPLMTransformation(Transformation):
codomain : LMSpace codomain : LMSpace
A compatible codomain. A compatible codomain.
""" """
if domain is None:
raise ValueError('ERROR: cannot generate codomain for None')
if not isinstance(domain, HPSpace): if not isinstance(domain, HPSpace):
raise TypeError('ERROR: domain needs to be a HPSpace') raise TypeError(about._errors.cstring(
"ERROR: domain needs to be a HPSpace"))
lmax = 3 * domain.nside - 1 lmax = 3 * domain.nside - 1
mmax = lmax mmax = lmax
return LMSpace(lmax=lmax, mmax=mmax, dtype=np.dtype('complex128'))
result = LMSpace(lmax=lmax, mmax=mmax, dtype=np.dtype('float64'))
cls.check_codomain(domain, result)
return result
@staticmethod @staticmethod
def check_codomain(domain, codomain): def check_codomain(domain, codomain):
if not isinstance(domain, HPSpace): if not isinstance(domain, HPSpace):
raise TypeError('ERROR: domain is not a HPSpace') raise TypeError(about._errors.cstring(
'ERROR: domain is not a HPSpace'))
if codomain is None:
return False
if not isinstance(codomain, LMSpace): if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.') raise TypeError(about._errors.cstring(
'ERROR: codomain must be a LMSpace.'))
nside = domain.nside nside = domain.nside
lmax = codomain.lmax lmax = codomain.lmax
mmax = codomain.mmax mmax = codomain.mmax
if (3 * nside - 1 != lmax) or (lmax != mmax): if 3 * nside - 1 != lmax:
return False raise ValueError(about._errors.cstring(
'ERROR: codomain has 3*nside-1 != lmax.'))
return True if lmax != mmax:
raise ValueError(about._errors.cstring(
'ERROR: codomain has lmax != mmax.'))
def transform(self, val, axes=None, **kwargs): return None
"""
HP -> LM transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple def _transformation_of_slice(self, inp, **kwargs):
The axes along which the transformation should take place lmax = self.codomain.lmax
mmax = self.codomain.mmax
""" if issubclass(inp.dtype.type, np.complexfloating):
# get by number of iterations from kwargs [resultReal, resultImag] = [hp.map2alm(x.astype(np.float64,
niter = kwargs['niter'] if 'niter' in kwargs else 0 copy=False),
lmax=lmax,
mmax=mmax,
pol=True,
use_weights=False,
**kwargs)
for x in (inp.real, inp.imag)]
if self.domain.discrete: [resultReal, resultImag] = [ltf.buildIdx(x, lmax=lmax)
val = self.domain.weight(val, power=-0.5, axes=axes) for x in [resultReal, resultImag]]
# shorthands for transform parameters result = self._combine_complex_result(resultReal, resultImag)
lmax = self.codomain.lmax
mmax = self.codomain.mmax
if isinstance(val, distributed_data_object):
temp_val = val.get_full_data()
else:
temp_val = val
return_val = None
for slice_list in utilities.get_slice_list(temp_val.shape, axes):
if slice_list == [slice(None, None)]:
inp = temp_val
else:
if return_val is None:
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
inp = hp.map2alm(inp.astype(np.float64, copy=False),
lmax=lmax, mmax=mmax, iter=niter, pol=True,
use_weights=False, datapath=None)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
else: