Commit a2f72e63 authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: Add LMGLTransformation and LMHPTransformation

parent 08bc6e76
......@@ -54,9 +54,6 @@ class GFFT(Transform):
else:
temp_inp = val
# Cast input datatype to codomain's dtype
temp_inp = temp_inp.astype(np.complex128, copy=False)
# Array for storing the result
return_val = None
......
......@@ -4,9 +4,9 @@ from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
gl = gdi.get('libsharp_wrapper_gl')
class GLTransform(Transform):
"""
GLTransform wrapper for libsharp's transform functions
......@@ -65,4 +65,4 @@ class GLTransform(Transform):
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
return return_val
\ No newline at end of file
return return_val
import numpy as np
from nifty import GLSpace, HPSpace
from nifty.config import about
import nifty.nifty_utilities as utilities
from transform import Transform
from d2o import distributed_data_object
class LMTransform(Transform):
"""
LMTransform for transforming to GL/HP space
"""
def __init__(self, domain, codomain, module):
self.domain = domain
self.codomain = codomain
self.module = module
def _transform(self, val):
if isinstance(self.codomain, GLSpace):
# shorthand for transform parameters
nlat = self.codomain.paradict['nlat']
nlon = self.codomain.paradict['nlon']
lmax = self.domain.paradict['lmax']
mmax = self.paradict['mmax']
if self.domain.dtype == np.dtype('complex64')
val = self.module.alm2map_f(val, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
else:
val = self.module.alm2map(val, nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax, cl=False)
elif isinstance(self.codomain, HPSpace):
# shorthand for transform parameters
nside = self.codomain.paradict['nside']
lmax = self.domain.paradict['lmax']
mmax = self.domain.paradict['mmax']
val = val.astype(np.complex128, copy=False)
val = self.module.alm2map(val, nside, lmax=lmax, mmax=mmax,
pixwin=False, fwhm=0.0, sigma=None,
pol=True, inplace=False)
else:
raise ValueError("ERROR: Unsupported transformation.")
return val
def transform(self, val, axes, **kwargs):
if isinstance(val, distributed_data_object):
temp_val = val.get_full_data()
else:
temp_val = val
return_val = None
for slice_list in utilities.get_slice_list(temp_val.shape, axes):
if slice_list == [slice(None, None)]:
inp = temp_val
else:
if return_val is None:
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
inp = self._transform(inp)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
# re-weight if discrete
if self.codomain.discrete:
val = self.codomain.calc_weight(val, power=0.5)
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
return return_val
\ No newline at end of file
......@@ -2,6 +2,7 @@ from fftw import FFTW
from gfft import GFFT
from gltransform import GLTransform
from hptransform import HPTransform
from lmtransform import LMTransform
from nifty.config import about, dependency_injector as gdi
from nifty import RGSpace, GLSpace, HPSpace, LMSpace
......@@ -96,6 +97,26 @@ class Transformation(object):
if (3 * nside - 1 != lmax) or (lmax != mmax):
return False
elif isinstance(domain, LMSpace):
if isinstance(codomain, GLSpace):
nlat = codomain.paradict['nlat']
nlon = codomain.paradict['nlon']
lmax = domain.paradict['lmax']
mmax = domain.paradict['mmax']
if (lmax != mmax) or (nlat != lmax + 1) or \
(nlon != 2 * lmax + 1):
return False
elif isinstance(codomain, HPSpace):
nside = codomain.paradict['nside']
lmax = domain.paradict['lmax']
mmax = domain.paradict['mmax']
if (lmax != mmax) or (3 * nside - 1 != lmax):
return False
else:
raise ValueError('ERROT: LMSpace codomain should be ' +
'HPSpace or GLSpace')
else:
return False
......@@ -157,7 +178,7 @@ class RGRGTransformation(Transformation):
class GLLMTransformation(Transformation):
def __init__(self, domain, codomain, module=None):
def __init__(self, domain, codomain):
if Transformation.check_codomain(domain, codomain):
self._transform = GLTransform(domain, codomain)
else:
......@@ -168,11 +189,41 @@ class GLLMTransformation(Transformation):
class HPLMTransformation(Transformation):
def __init__(self, domain, codomain, module=None):
def __init__(self, domain, codomain):
if Transformation.check_codomain(domain, codomain):
self._transform = HPTransform(domain, codomain)
else:
raise ValueError("ERROR: Incompatible codomain!")
def transform(self, val, axes=None, **kwargs):
return self._transform.transform(val, axes, **kwargs)
class LMGLTransformation(Transformation):
def __init__(self, domain, codomain):
if Transformation.check_codomain(domain, codomain):
if gdi.get('libsharp_wrapper_gl') is None:
raise ImportError("The module libsharp is " +
"needed but not available.")
self._transform = LMTransform(domain, codomain,
gdi.get('libsharp_wrapper_gl'))
else:
raise ValueError("ERROR: Incompatible codomain!")
def transform(self, val, axes=None, **kwargs):
return self._transform.transform(val, axes, **kwargs)
class LMHPTransformation(Transformation):
def __init__(self, domain, codomain):
if Transformation.check_codomain(domain, codomain):
if gdi.get('healpy') is None:
raise ImportError("The module healpy is needed" +
"but not available.")
self._transform = LMTransform(domain, codomain,
gdi.get('healpy'))
else:
raise ValueError("ERROR: Incompatible codomain!")
def transform(self, val, axes=None, **kwargs):
return self._transform.transform(val, axes, **kwargs)
\ No newline at end of file
......@@ -22,6 +22,15 @@ class TransformationFactory(object):
return transformation.GLLMTransformation(domain, codomain, module)
elif isinstance(domain, HPSpace):
return transformation.HPLMTransformation(domain, codomain, module)
elif isinstance(domain, LMSpace):
if isinstance(codomain, GLSpace):
return transformation.LMGLTransformation(domain, codomain,
module)
elif isinstance(codomain, HPSpace):
return transformation.LMHPTransformation(domain, codomain,
module)
else:
raise ValueError('ERROR: incompatible codomain')
else:
raise ValueError('ERROR: unknown domain')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment