There is a maintenance of MPCDF Gitlab on Thursday, April 22st 2020, 9:00 am CEST - Expect some service interruptions during this time

Commit c1828a0b authored by theos's avatar theos

Fixed lmgltransformation.py

parent be3d197c
......@@ -39,7 +39,7 @@ class GLLMTransformation(SlicingTransformation):
"""
if domain is None:
raise ValueError(about._errors.cstring(
"ERROR: cannot generate codomain for None"))
"ERROR: cannot generate codomain for None-input"))
if not isinstance(domain, GLSpace):
raise TypeError(about._errors.cstring(
......@@ -49,9 +49,11 @@ class GLLMTransformation(SlicingTransformation):
lmax = nlat - 1
mmax = nlat - 1
if domain.dtype == np.dtype('float32'):
return LMSpace(lmax=lmax, mmax=mmax, dtype=np.complex64)
return_dtype = np.float32
else:
return LMSpace(lmax=lmax, mmax=mmax, dtype=np.complex128)
return_dtype = np.float64
return LMSpace(lmax=lmax, mmax=mmax, dtype=return_dtype)
@classmethod
def check_codomain(cls, domain, codomain):
......@@ -59,9 +61,6 @@ class GLLMTransformation(SlicingTransformation):
raise TypeError(about._errors.cstring(
"ERROR: domain is not a GLSpace"))
if codomain is None:
return False
if not isinstance(codomain, LMSpace):
raise TypeError(about._errors.cstring(
"ERROR: codomain must be a LMSpace."))
......@@ -71,15 +70,18 @@ class GLLMTransformation(SlicingTransformation):
lmax = codomain.lmax
mmax = codomain.mmax
if (nlon != 2 * nlat - 1) or (lmax != nlat - 1) or (lmax != mmax):
return False
if lmax != mmax:
raise ValueError('ERROR: codomain has lmax != mmax.')
return True
if lmax != nlat - 1:
raise ValueError('ERROR: codomain has lmax != nlat - 1.')
# ---Added properties and methods---
if nlon != 2 * nlat - 1:
raise ValueError('ERROR: domain has nlon != 2 * nlat - 1.')
return None
def _transformation_of_slice(self, inp):
# shorthands for transform parameters
nlat = self.domain.nlat
nlon = self.domain.nlon
lmax = self.codomain.lmax
......@@ -94,8 +96,9 @@ class GLLMTransformation(SlicingTransformation):
mmax=mmax)
for x in (inp.real, inp.imag)]
resultReal = ltf.buildIdx(resultReal, lmax=lmax)
resultImag = ltf.buildIdx(resultImag, lmax=lmax)
[resultReal, resultImag] = [ltf.buildIdx(x, lmax=lmax)
for x in [resultReal, resultImag]]
# construct correct complex dtype
one = resultReal.dtype.type(1)
result_dtype = np.dtype(type(one + 1j))
......@@ -110,6 +113,8 @@ class GLLMTransformation(SlicingTransformation):
return result
# ---Added properties and methods---
def libsharpMap2Alm(self, inp, **kwargs):
if inp.dtype == np.dtype('float32'):
return libsharp.map2alm_f(inp, **kwargs)
......@@ -118,3 +123,6 @@ class GLLMTransformation(SlicingTransformation):
else:
about.warnings.cprint("WARNING: performing dtype conversion for "
"libsharp compatibility.")
casted_inp = inp.astype(np.dtype('float64'), copy=False)
result = libsharp.map2alm(casted_inp, **kwargs)
return result
\ No newline at end of file
import numpy as np
from transformation import Transformation
from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty.config import dependency_injector as gdi,\
about
from nifty import GLSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory as ltf
gl = gdi.get('libsharp_wrapper_gl')
libsharp = gdi.get('libsharp_wrapper_gl')
class LMGLTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
class LMGLTransformation(Transformation):
def __init__(self, domain, codomain=None, module=None):
if gdi.get('libsharp_wrapper_gl') is None:
raise ImportError(
"The module libsharp is needed but not available.")
if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
elif self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
if 'libsharp_wrapper_gl' not in gdi:
raise ImportError(about._errors.cstring(
"The module libsharp is needed but not available."))
super(LMGLTransformation, self).__init__(domain, codomain, module)
# ---Mandatory properties and methods---
@staticmethod
def get_codomain(domain):
......@@ -48,20 +46,21 @@ class LMGLTransformation(Transformation):
`arXiv:1303.4945 <http://www.arxiv.org/abs/1303.4945>`_
"""
if domain is None:
raise ValueError('ERROR: cannot generate codomain for None')
raise ValueError(about._errors.cstring(
'ERROR: cannot generate codomain for None-input'))
if not isinstance(domain, LMSpace):
raise TypeError('ERROR: domain needs to be a LMSpace')
raise TypeError(about._errors.cstring(
'ERROR: domain needs to be a LMSpace'))
if domain.dtype == np.dtype('complex64'):
if domain.dtype is np.dtype('float32'):
new_dtype = np.float32
elif domain.dtype == np.dtype('complex128'):
new_dtype = np.float64
else:
raise ValueError('ERROR: unsupported domain dtype')
new_dtype = np.float64
nlat = domain.lmax + 1
nlon = domain.lmax * 2 + 1
return GLSpace(nlat=nlat, nlon=nlon, dtype=new_dtype)
@staticmethod
......@@ -69,9 +68,6 @@ class LMGLTransformation(Transformation):
if not isinstance(domain, LMSpace):
raise TypeError('ERROR: domain is not a LMSpace')
if codomain is None:
return False
if not isinstance(codomain, GLSpace):
raise TypeError('ERROR: codomain must be a GLSpace.')
......@@ -80,78 +76,59 @@ class LMGLTransformation(Transformation):
lmax = domain.lmax
mmax = domain.mmax
if (lmax != mmax) or (nlat != lmax + 1) or (nlon != 2 * lmax + 1):
return False
if lmax != mmax:
raise ValueError('ERROR: domain has lmax != mmax.')
return True
if nlat != lmax + 1:
raise ValueError('ERROR: codomain has nlat != lmax + 1.')
def transform(self, val, axes=None, **kwargs):
"""
LM -> GL transform method.
if nlon != 2 * lmax + 1:
raise ValueError('ERROR: domain has nlon != 2 * lmax + 1.')
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
return None
axes : None or tuple
The axes along which the transformation should take place
def _transformation_of_slice(self, inp):
nlat = self.codomain.nlat
nlon = self.codomain.nlon
lmax = self.domain.lmax
mmax = self.domain.mmax
"""
if isinstance(val, distributed_data_object):
temp_val = val.get_full_data()
else:
temp_val = val
return_val = None
for slice_list in utilities.get_slice_list(temp_val.shape, axes):
if slice_list == [slice(None, None)]:
inp = temp_val
else:
if return_val is None:
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
nlat = self.codomain.nlat
nlon = self.codomain.nlon
lmax = self.domain.lmax
mmax = self.mmax
if inp.dtype >= np.dtype('complex64'):
inpReal = np.real(inp)
inpImag = np.imag(inp)
inpReal = ltf.buildLm(inpReal,lmax=lmax)
inpImag = ltf.buildLm(inpImag,lmax=lmax)
inpReal = self.GlAlm2Map(inpReal, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
inpImag = self.GlAlm2Map(inpImag, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
inp = inpReal+inpImag*(1j)
else:
inp = ltf.buildLm(inp, lmax=lmax)
inp = self.GlAlm2Map(inp, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
# re-weight if discrete
if self.codomain.discrete:
val = self.codomain.weight(val, power=0.5, axes=axes)
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
if issubclass(inp.dtype.type, np.complexfloating):
[resultReal, resultImag] = [ltf.buildLm(x, lmax=lmax)
for x in (inp.real, inp.imag)]
[resultReal, resultImag] = [self.libsharpAlm2Map(x,
nlat=nlat,
nlon=nlon,
lmax=lmax,
mmax=mmax,
cl=False)
for x in [resultReal, resultImag]]
# construct correct complex dtype
one = resultReal.dtype.type(1)
result_dtype = np.dtype(type(one + 1j))
result = np.empty_like(resultReal, dtype=result_dtype)
result.real = resultReal
result.imag = resultImag
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
result = ltf.buildLm(inp, lmax=lmax)
result = self.libsharpAlm2Map(result, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
return result
return return_val
# ---Added properties and methods---
def GlAlm2Map(self, inp, **kwargs):
def libsharpAlm2Map(self, inp, **kwargs):
if inp.dtype == np.dtype('complex64'):
return gl.alm2map_f(inp, kwargs)
return libsharp.alm2map_f(inp, **kwargs)
elif inp.dtype == np.dtype('complex128'):
return libsharp.alm2map(inp, **kwargs)
else:
return gl.alm2map(inp, kwargs)
about.warnings.cprint("WARNING: performing dtype conversion for "
"libsharp compatibility.")
casted_inp = inp.astype(np.dtype('complex128'), copy=False)
result = libsharp.alm2map(casted_inp, **kwargs)
return result
......@@ -34,7 +34,7 @@ class RGRGTransformation(Transformation):
raise ValueError('ERROR: unknow FFT module:' + module)
@staticmethod
def get_codomain(domain, dtype=None, zerocenter=None, **kwargs):
def get_codomain(domain, dtype=None, zerocenter=None):
"""
Generates a compatible codomain to which transformations are
reasonable, i.e.\ either a shifted grid or a Fourier conjugate
......
......@@ -22,12 +22,12 @@ class SlicingTransformation(Transformation):
return_val = val.copy_empty(dtype=self.codomain.dtype,
global_shape=return_shape)
data = val[slice_list]
data = val.get_data(slice_list, copy=False)
data = data.get_full_data()
data = self._transformation_of_slice(data)
return_val[slice_list] = data
return_val.set_data(data=data, to_key=slice_list, copy=False)
return return_val
......
......@@ -13,14 +13,13 @@ class Transformation(object):
if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
elif self.check_codomain(domain, codomain):
else:
self.check_codomain(domain, codomain)
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def get_codomain(domain, dtype=None, zerocenter=None, **kwargs):
def get_codomain(domain, dtype=None, zerocenter=None):
raise NotImplementedError
@staticmethod
......
......@@ -129,10 +129,10 @@ class LMSpace(Space):
def dim(self):
l = self.lmax
m = self.mmax
# the LMSpace consist of the full triangle, minus two little triangles
# if mmax < lmax
# dim = l(l+1)/2 - 2 * (l-m)(l-m+1)/2
return np.int(l*(l+1.)/2. - 2.*(l-m)*(l-m+1.)/2.)
# the LMSpace consist of the full triangle (including -m's!),
# minus two little triangles if mmax < lmax
# dim = (((2*(l+1)-1)+1)**2/4 - 2 * (l-m)(l-m+1)/2
return np.int((l+1)**2 - (l-m)*(l-m+1.))
@property
def total_volume(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment