Commit c1828a0b authored by theos's avatar theos
Browse files

Fixed lmgltransformation.py

parent be3d197c
...@@ -39,7 +39,7 @@ class GLLMTransformation(SlicingTransformation): ...@@ -39,7 +39,7 @@ class GLLMTransformation(SlicingTransformation):
""" """
if domain is None: if domain is None:
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
"ERROR: cannot generate codomain for None")) "ERROR: cannot generate codomain for None-input"))
if not isinstance(domain, GLSpace): if not isinstance(domain, GLSpace):
raise TypeError(about._errors.cstring( raise TypeError(about._errors.cstring(
...@@ -49,9 +49,11 @@ class GLLMTransformation(SlicingTransformation): ...@@ -49,9 +49,11 @@ class GLLMTransformation(SlicingTransformation):
lmax = nlat - 1 lmax = nlat - 1
mmax = nlat - 1 mmax = nlat - 1
if domain.dtype == np.dtype('float32'): if domain.dtype == np.dtype('float32'):
return LMSpace(lmax=lmax, mmax=mmax, dtype=np.complex64) return_dtype = np.float32
else: else:
return LMSpace(lmax=lmax, mmax=mmax, dtype=np.complex128) return_dtype = np.float64
return LMSpace(lmax=lmax, mmax=mmax, dtype=return_dtype)
@classmethod @classmethod
def check_codomain(cls, domain, codomain): def check_codomain(cls, domain, codomain):
...@@ -59,9 +61,6 @@ class GLLMTransformation(SlicingTransformation): ...@@ -59,9 +61,6 @@ class GLLMTransformation(SlicingTransformation):
raise TypeError(about._errors.cstring( raise TypeError(about._errors.cstring(
"ERROR: domain is not a GLSpace")) "ERROR: domain is not a GLSpace"))
if codomain is None:
return False
if not isinstance(codomain, LMSpace): if not isinstance(codomain, LMSpace):
raise TypeError(about._errors.cstring( raise TypeError(about._errors.cstring(
"ERROR: codomain must be a LMSpace.")) "ERROR: codomain must be a LMSpace."))
...@@ -71,15 +70,18 @@ class GLLMTransformation(SlicingTransformation): ...@@ -71,15 +70,18 @@ class GLLMTransformation(SlicingTransformation):
lmax = codomain.lmax lmax = codomain.lmax
mmax = codomain.mmax mmax = codomain.mmax
if (nlon != 2 * nlat - 1) or (lmax != nlat - 1) or (lmax != mmax): if lmax != mmax:
return False raise ValueError('ERROR: codomain has lmax != mmax.')
return True if lmax != nlat - 1:
raise ValueError('ERROR: codomain has lmax != nlat - 1.')
# ---Added properties and methods--- if nlon != 2 * nlat - 1:
raise ValueError('ERROR: domain has nlon != 2 * nlat - 1.')
return None
def _transformation_of_slice(self, inp): def _transformation_of_slice(self, inp):
# shorthands for transform parameters
nlat = self.domain.nlat nlat = self.domain.nlat
nlon = self.domain.nlon nlon = self.domain.nlon
lmax = self.codomain.lmax lmax = self.codomain.lmax
...@@ -94,8 +96,9 @@ class GLLMTransformation(SlicingTransformation): ...@@ -94,8 +96,9 @@ class GLLMTransformation(SlicingTransformation):
mmax=mmax) mmax=mmax)
for x in (inp.real, inp.imag)] for x in (inp.real, inp.imag)]
resultReal = ltf.buildIdx(resultReal, lmax=lmax) [resultReal, resultImag] = [ltf.buildIdx(x, lmax=lmax)
resultImag = ltf.buildIdx(resultImag, lmax=lmax) for x in [resultReal, resultImag]]
# construct correct complex dtype # construct correct complex dtype
one = resultReal.dtype.type(1) one = resultReal.dtype.type(1)
result_dtype = np.dtype(type(one + 1j)) result_dtype = np.dtype(type(one + 1j))
...@@ -110,6 +113,8 @@ class GLLMTransformation(SlicingTransformation): ...@@ -110,6 +113,8 @@ class GLLMTransformation(SlicingTransformation):
return result return result
# ---Added properties and methods---
def libsharpMap2Alm(self, inp, **kwargs): def libsharpMap2Alm(self, inp, **kwargs):
if inp.dtype == np.dtype('float32'): if inp.dtype == np.dtype('float32'):
return libsharp.map2alm_f(inp, **kwargs) return libsharp.map2alm_f(inp, **kwargs)
...@@ -118,3 +123,6 @@ class GLLMTransformation(SlicingTransformation): ...@@ -118,3 +123,6 @@ class GLLMTransformation(SlicingTransformation):
else: else:
about.warnings.cprint("WARNING: performing dtype conversion for " about.warnings.cprint("WARNING: performing dtype conversion for "
"libsharp compatibility.") "libsharp compatibility.")
casted_inp = inp.astype(np.dtype('float64'), copy=False)
result = libsharp.map2alm(casted_inp, **kwargs)
return result
\ No newline at end of file
import numpy as np import numpy as np
from transformation import Transformation from nifty.config import dependency_injector as gdi,\
from d2o import distributed_data_object about
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty import GLSpace, LMSpace from nifty import GLSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory as ltf import lm_transformation_factory as ltf
gl = gdi.get('libsharp_wrapper_gl')
libsharp = gdi.get('libsharp_wrapper_gl')
class LMGLTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
class LMGLTransformation(Transformation):
def __init__(self, domain, codomain=None, module=None): def __init__(self, domain, codomain=None, module=None):
if gdi.get('libsharp_wrapper_gl') is None: if 'libsharp_wrapper_gl' not in gdi:
raise ImportError( raise ImportError(about._errors.cstring(
"The module libsharp is needed but not available.") "The module libsharp is needed but not available."))
if codomain is None: super(LMGLTransformation, self).__init__(domain, codomain, module)
self.domain = domain
self.codomain = self.get_codomain(domain) # ---Mandatory properties and methods---
elif self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod @staticmethod
def get_codomain(domain): def get_codomain(domain):
...@@ -48,20 +46,21 @@ class LMGLTransformation(Transformation): ...@@ -48,20 +46,21 @@ class LMGLTransformation(Transformation):
`arXiv:1303.4945 <http://www.arxiv.org/abs/1303.4945>`_ `arXiv:1303.4945 <http://www.arxiv.org/abs/1303.4945>`_
""" """
if domain is None: if domain is None:
raise ValueError('ERROR: cannot generate codomain for None') raise ValueError(about._errors.cstring(
'ERROR: cannot generate codomain for None-input'))
if not isinstance(domain, LMSpace): if not isinstance(domain, LMSpace):
raise TypeError('ERROR: domain needs to be a LMSpace') raise TypeError(about._errors.cstring(
'ERROR: domain needs to be a LMSpace'))
if domain.dtype == np.dtype('complex64'): if domain.dtype is np.dtype('float32'):
new_dtype = np.float32 new_dtype = np.float32
elif domain.dtype == np.dtype('complex128'):
new_dtype = np.float64
else: else:
raise ValueError('ERROR: unsupported domain dtype') new_dtype = np.float64
nlat = domain.lmax + 1 nlat = domain.lmax + 1
nlon = domain.lmax * 2 + 1 nlon = domain.lmax * 2 + 1
return GLSpace(nlat=nlat, nlon=nlon, dtype=new_dtype) return GLSpace(nlat=nlat, nlon=nlon, dtype=new_dtype)
@staticmethod @staticmethod
...@@ -69,9 +68,6 @@ class LMGLTransformation(Transformation): ...@@ -69,9 +68,6 @@ class LMGLTransformation(Transformation):
if not isinstance(domain, LMSpace): if not isinstance(domain, LMSpace):
raise TypeError('ERROR: domain is not a LMSpace') raise TypeError('ERROR: domain is not a LMSpace')
if codomain is None:
return False
if not isinstance(codomain, GLSpace): if not isinstance(codomain, GLSpace):
raise TypeError('ERROR: codomain must be a GLSpace.') raise TypeError('ERROR: codomain must be a GLSpace.')
...@@ -80,78 +76,59 @@ class LMGLTransformation(Transformation): ...@@ -80,78 +76,59 @@ class LMGLTransformation(Transformation):
lmax = domain.lmax lmax = domain.lmax
mmax = domain.mmax mmax = domain.mmax
if (lmax != mmax) or (nlat != lmax + 1) or (nlon != 2 * lmax + 1): if lmax != mmax:
return False raise ValueError('ERROR: domain has lmax != mmax.')
return True if nlat != lmax + 1:
raise ValueError('ERROR: codomain has nlat != lmax + 1.')
def transform(self, val, axes=None, **kwargs): if nlon != 2 * lmax + 1:
""" raise ValueError('ERROR: domain has nlon != 2 * lmax + 1.')
LM -> GL transform method.
Parameters return None
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple def _transformation_of_slice(self, inp):
The axes along which the transformation should take place nlat = self.codomain.nlat
nlon = self.codomain.nlon
lmax = self.domain.lmax
mmax = self.domain.mmax
""" if issubclass(inp.dtype.type, np.complexfloating):
if isinstance(val, distributed_data_object): [resultReal, resultImag] = [ltf.buildLm(x, lmax=lmax)
temp_val = val.get_full_data() for x in (inp.real, inp.imag)]
else:
temp_val = val [resultReal, resultImag] = [self.libsharpAlm2Map(x,
nlat=nlat,
return_val = None nlon=nlon,
lmax=lmax,
for slice_list in utilities.get_slice_list(temp_val.shape, axes): mmax=mmax,
if slice_list == [slice(None, None)]: cl=False)
inp = temp_val for x in [resultReal, resultImag]]
else:
if return_val is None: # construct correct complex dtype
return_val = np.empty_like(temp_val) one = resultReal.dtype.type(1)
inp = temp_val[slice_list] result_dtype = np.dtype(type(one + 1j))
nlat = self.codomain.nlat result = np.empty_like(resultReal, dtype=result_dtype)
nlon = self.codomain.nlon result.real = resultReal
lmax = self.domain.lmax result.imag = resultImag
mmax = self.mmax
if inp.dtype >= np.dtype('complex64'):
inpReal = np.real(inp)
inpImag = np.imag(inp)
inpReal = ltf.buildLm(inpReal,lmax=lmax)
inpImag = ltf.buildLm(inpImag,lmax=lmax)
inpReal = self.GlAlm2Map(inpReal, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
inpImag = self.GlAlm2Map(inpImag, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
inp = inpReal+inpImag*(1j)
else:
inp = ltf.buildLm(inp, lmax=lmax)
inp = self.GlAlm2Map(inp, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
# re-weight if discrete
if self.codomain.discrete:
val = self.codomain.weight(val, power=0.5, axes=axes)
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
else: else:
return_val = return_val.astype(self.codomain.dtype, copy=False) result = ltf.buildLm(inp, lmax=lmax)
result = self.libsharpAlm2Map(result, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
return result
return return_val # ---Added properties and methods---
def GlAlm2Map(self, inp, **kwargs): def libsharpAlm2Map(self, inp, **kwargs):
if inp.dtype == np.dtype('complex64'): if inp.dtype == np.dtype('complex64'):
return gl.alm2map_f(inp, kwargs) return libsharp.alm2map_f(inp, **kwargs)
elif inp.dtype == np.dtype('complex128'):
return libsharp.alm2map(inp, **kwargs)
else: else:
return gl.alm2map(inp, kwargs) about.warnings.cprint("WARNING: performing dtype conversion for "
"libsharp compatibility.")
casted_inp = inp.astype(np.dtype('complex128'), copy=False)
result = libsharp.alm2map(casted_inp, **kwargs)
return result
...@@ -34,7 +34,7 @@ class RGRGTransformation(Transformation): ...@@ -34,7 +34,7 @@ class RGRGTransformation(Transformation):
raise ValueError('ERROR: unknow FFT module:' + module) raise ValueError('ERROR: unknow FFT module:' + module)
@staticmethod @staticmethod
def get_codomain(domain, dtype=None, zerocenter=None, **kwargs): def get_codomain(domain, dtype=None, zerocenter=None):
""" """
Generates a compatible codomain to which transformations are Generates a compatible codomain to which transformations are
reasonable, i.e.\ either a shifted grid or a Fourier conjugate reasonable, i.e.\ either a shifted grid or a Fourier conjugate
......
...@@ -22,12 +22,12 @@ class SlicingTransformation(Transformation): ...@@ -22,12 +22,12 @@ class SlicingTransformation(Transformation):
return_val = val.copy_empty(dtype=self.codomain.dtype, return_val = val.copy_empty(dtype=self.codomain.dtype,
global_shape=return_shape) global_shape=return_shape)
data = val[slice_list] data = val.get_data(slice_list, copy=False)
data = data.get_full_data() data = data.get_full_data()
data = self._transformation_of_slice(data) data = self._transformation_of_slice(data)
return_val[slice_list] = data return_val.set_data(data=data, to_key=slice_list, copy=False)
return return_val return return_val
......
...@@ -13,14 +13,13 @@ class Transformation(object): ...@@ -13,14 +13,13 @@ class Transformation(object):
if codomain is None: if codomain is None:
self.domain = domain self.domain = domain
self.codomain = self.get_codomain(domain) self.codomain = self.get_codomain(domain)
elif self.check_codomain(domain, codomain): else:
self.check_codomain(domain, codomain)
self.domain = domain self.domain = domain
self.codomain = codomain self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod @staticmethod
def get_codomain(domain, dtype=None, zerocenter=None, **kwargs): def get_codomain(domain, dtype=None, zerocenter=None):
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
......
...@@ -129,10 +129,10 @@ class LMSpace(Space): ...@@ -129,10 +129,10 @@ class LMSpace(Space):
def dim(self): def dim(self):
l = self.lmax l = self.lmax
m = self.mmax m = self.mmax
# the LMSpace consist of the full triangle, minus two little triangles # the LMSpace consist of the full triangle (including -m's!),
# if mmax < lmax # minus two little triangles if mmax < lmax
# dim = l(l+1)/2 - 2 * (l-m)(l-m+1)/2 # dim = (((2*(l+1)-1)+1)**2/4 - 2 * (l-m)(l-m+1)/2
return np.int(l*(l+1.)/2. - 2.*(l-m)*(l-m+1.)/2.) return np.int((l+1)**2 - (l-m)*(l-m+1.))
@property @property
def total_volume(self): def total_volume(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment