Commit 109b7f8f authored by Theo Steininger's avatar Theo Steininger
Browse files

Refactored **transformation.py classes and FFTOperator. Especially fixed...

Refactored **transformation.py classes and FFTOperator. Especially fixed RGRGTransformation.check_codomain.
parent 9a008092
......@@ -55,33 +55,25 @@ class FFTOperator(LinearOperator):
def __init__(self, domain=(), target=None, module=None,
domain_dtype=None, target_dtype=None):
self._domain = self._parse_domain(domain)
# Initialize domain and target
self._domain = self._parse_domain(domain)
if len(self.domain) != 1:
raise ValueError(
'ERROR: TransformationOperator accepts only exactly one '
'space as input domain.')
raise ValueError("TransformationOperator accepts only exactly one "
"space as input domain.")
if target is None:
target = (self.get_default_codomain(self.domain[0]), )
self._target = self._parse_domain(target)
if len(self.target) != 1:
raise ValueError("TransformationOperator accepts only exactly one "
"space as output target.")
# Create transformation instances
try:
forward_class = self.transformation_dictionary[
forward_class = self.transformation_dictionary[
(self.domain[0].__class__, self.target[0].__class__)]
except KeyError:
raise ValueError(
"No forward transformation for domain-target pair "
"found.")
try:
backward_class = self.transformation_dictionary[
backward_class = self.transformation_dictionary[
(self.target[0].__class__, self.domain[0].__class__)]
except KeyError:
raise ValueError(
"No backward transformation for domain-target pair "
"found.")
self._forward_transformation = TransformationCache.create(
forward_class, self.domain[0], self.target[0], module=module)
......
......@@ -21,7 +21,7 @@ import numpy as np
from nifty.config import dependency_injector as gdi
from nifty import GLSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory as ltf
import lm_transformation_factory
pyHealpix = gdi.get('pyHealpix')
......@@ -58,30 +58,21 @@ class GLLMTransformation(SlicingTransformation):
"""
if not isinstance(domain, GLSpace):
raise TypeError(
"domain needs to be a GLSpace")
raise TypeError("domain needs to be a GLSpace")
nlat = domain.nlat
lmax = nlat - 1
mmax = nlat - 1
if domain.dtype == np.dtype('float32'):
return_dtype = np.float32
else:
return_dtype = np.float64
result = LMSpace(lmax=lmax, mmax=mmax, dtype=return_dtype)
cls.check_codomain(domain, result)
result = LMSpace(lmax=lmax, dtype=domain.dtype)
return result
@staticmethod
def check_codomain(domain, codomain):
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, GLSpace):
raise TypeError(
"domain is not a GLSpace")
raise TypeError("domain is not a GLSpace")
if not isinstance(codomain, LMSpace):
raise TypeError(
"codomain must be a LMSpace.")
raise TypeError("codomain must be a LMSpace.")
nlat = domain.nlat
nlon = domain.nlon
......@@ -89,18 +80,15 @@ class GLLMTransformation(SlicingTransformation):
mmax = codomain.mmax
if lmax != mmax:
raise ValueError(
'ERROR: codomain has lmax != mmax.')
cls.Logger.warn("Unrecommended: codomain has lmax != mmax.")
if lmax != nlat - 1:
raise ValueError(
'ERROR: codomain has lmax != nlat - 1.')
cls.Logger.warn("Unrecommended: codomain has lmax != nlat - 1.")
if nlon != 2 * nlat - 1:
raise ValueError(
'ERROR: domain has nlon != 2 * nlat - 1.')
if nlon != 2*nlat - 1:
cls.Logger.warn("Unrecommended: domain has nlon != 2*nlat - 1.")
return None
super(GLLMTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
nlat = self.domain.nlat
......@@ -108,20 +96,20 @@ class GLLMTransformation(SlicingTransformation):
lmax = self.codomain.lmax
mmax = self.codomain.mmax
sjob=pyHealpix.sharpjob_d()
sjob.set_Gauss_geometry(nlat,nlot)
sjob.set_triangular_alm_info(lmax,mmax)
sjob = pyHealpix.sharpjob_d()
sjob.set_Gauss_geometry(nlat, nlon)
sjob.set_triangular_alm_info(lmax, mmax)
if issubclass(inp.dtype.type, np.complexfloating):
[resultReal, resultImag] = [sjob.map2alm(x)
for x in (inp.real, inp.imag)]
[resultReal, resultImag] = [ltf.buildIdx(x, lmax=lmax)
for x in [resultReal, resultImag]]
[resultReal,
resultImag] = [lm_transformation_factory.buildIdx(x, lmax=lmax)
for x in [resultReal, resultImag]]
result = self._combine_complex_result(resultReal, resultImag)
else:
result = sjob.map2alm(inp)
result = ltf.buildIdx(result, lmax=lmax)
result = lm_transformation_factory.buildIdx(result, lmax=lmax)
return result
......@@ -22,7 +22,7 @@ from nifty.config import dependency_injector as gdi
from nifty import HPSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory as ltf
import lm_transformation_factory
pyHealpix = gdi.get('pyHealpix')
......@@ -59,48 +59,49 @@ class HPLMTransformation(SlicingTransformation):
"""
if not isinstance(domain, HPSpace):
raise TypeError(
"domain needs to be a HPSpace")
raise TypeError("domain needs to be a HPSpace")
lmax = 2 * domain.nside
lmax = 2*domain.nside
result = LMSpace(lmax=lmax, dtype=np.dtype('float64'))
cls.check_codomain(domain, result)
result = LMSpace(lmax=lmax, dtype=domain.dtype)
return result
@staticmethod
def check_codomain(domain, codomain):
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, HPSpace):
raise TypeError(
'ERROR: domain is not a HPSpace')
raise TypeError("domain is not a HPSpace")
if not isinstance(codomain, LMSpace):
raise TypeError(
'ERROR: codomain must be a LMSpace.')
raise TypeError("codomain must be a LMSpace.")
nside = domain.nside
lmax = codomain.lmax
nside = domain.nside
if lmax != 2*nside:
cls.Logger.warn("Unrecommended: lmax != 2*nside.")
return None
super(HPLMTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
nside = self.domain.nside
lmax = self.codomain.lmax
mmax = lmax
sjob=pyHealpix.sharpjob_d()
sjob = pyHealpix.sharpjob_d()
sjob.set_Healpix_geometry(nside)
sjob.set_triangular_alm_info(lmax,mmax)
sjob.set_triangular_alm_info(lmax, mmax)
if issubclass(inp.dtype.type, np.complexfloating):
[resultReal, resultImag] = [sjob.map2alm(x)
for x in (inp.real, inp.imag)]
[resultReal, resultImag] = [ltf.buildIdx(x, lmax=lmax)
for x in [resultReal, resultImag]]
[resultReal,
resultImag] = [lm_transformation_factory.buildIdx(x, lmax=lmax)
for x in [resultReal, resultImag]]
result = self._combine_complex_result(resultReal, resultImag)
else:
result = sjob.map2alm(inp)
result = ltf.buildIdx(result, lmax=lmax)
result = lm_transformation_factory.buildIdx(result, lmax=lmax)
return result
......@@ -21,7 +21,7 @@ from nifty.config import dependency_injector as gdi
from nifty import GLSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory as ltf
import lm_transformation_factory
pyHealpix = gdi.get('pyHealpix')
......@@ -64,30 +64,21 @@ class LMGLTransformation(SlicingTransformation):
`arXiv:1303.4945 <http://www.arxiv.org/abs/1303.4945>`_
"""
if not isinstance(domain, LMSpace):
raise TypeError(
'ERROR: domain needs to be a LMSpace')
if domain.dtype is np.dtype('float32'):
new_dtype = np.float32
else:
new_dtype = np.float64
raise TypeError("domain needs to be a LMSpace")
nlat = domain.lmax + 1
nlon = domain.lmax * 2 + 1
nlon = domain.lmax*2 + 1
result = GLSpace(nlat=nlat, nlon=nlon, dtype=new_dtype)
cls.check_codomain(domain, result)
result = GLSpace(nlat=nlat, nlon=nlon, dtype=domain.dtype)
return result
@staticmethod
def check_codomain(domain, codomain):
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, LMSpace):
raise TypeError(
'ERROR: domain is not a LMSpace')
raise TypeError("domain is not a LMSpace")
if not isinstance(codomain, GLSpace):
raise TypeError(
'ERROR: codomain must be a GLSpace.')
raise TypeError("codomain must be a GLSpace.")
nlat = codomain.nlat
nlon = codomain.nlon
......@@ -95,18 +86,15 @@ class LMGLTransformation(SlicingTransformation):
mmax = domain.mmax
if lmax != mmax:
raise ValueError(
'ERROR: domain has lmax != mmax.')
cls.Logger.warn("Unrecommended: codomain has lmax != mmax.")
if nlat != lmax + 1:
raise ValueError(
'ERROR: codomain has nlat != lmax + 1.')
cls.Logger.warn("Unrecommended: codomain has nlat != lmax + 1.")
if nlon != 2 * lmax + 1:
raise ValueError(
'ERROR: domain has nlon != 2 * lmax + 1.')
if nlon != 2*lmax + 1:
cls.Logger.warn("Unrecommended: domain has nlon != 2*lmax + 1.")
return None
super(LMGLTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
nlat = self.codomain.nlat
......@@ -114,12 +102,13 @@ class LMGLTransformation(SlicingTransformation):
lmax = self.domain.lmax
mmax = self.domain.mmax
sjob=pyHealpix.sharpjob_d()
sjob.set_Gauss_geometry(nlat,nlot)
sjob.set_triangular_alm_info(lmax,mmax)
sjob = pyHealpix.sharpjob_d()
sjob.set_Gauss_geometry(nlat, nlon)
sjob.set_triangular_alm_info(lmax, mmax)
if issubclass(inp.dtype.type, np.complexfloating):
[resultReal, resultImag] = [ltf.buildLm(x, lmax=lmax)
for x in (inp.real, inp.imag)]
[resultReal,
resultImag] = [lm_transformation_factory.buildLm(x, lmax=lmax)
for x in (inp.real, inp.imag)]
[resultReal, resultImag] = [sjob.alm2map(x)
for x in [resultReal, resultImag]]
......@@ -127,7 +116,7 @@ class LMGLTransformation(SlicingTransformation):
result = self._combine_complex_result(resultReal, resultImag)
else:
result = ltf.buildLm(inp, lmax=lmax)
result = lm_transformation_factory.buildLm(inp, lmax=lmax)
result = sjob.alm2map(result)
return result
......@@ -20,7 +20,7 @@ import numpy as np
from nifty.config import dependency_injector as gdi
from nifty import HPSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory as ltf
import lm_transformation_factory
pyHealpix = gdi.get('pyHealpix')
......@@ -62,40 +62,40 @@ class LMHPTransformation(SlicingTransformation):
Distributed on the Sphere", *ApJ* 622..759G.
"""
if not isinstance(domain, LMSpace):
raise TypeError(
'ERROR: domain needs to be a LMSpace')
raise TypeError("domain needs to be a LMSpace.")
nside = np.max(domain.lmax // 2,1)
result = HPSpace(nside=nside)
cls.check_codomain(domain, result)
nside = np.max(domain.lmax//2, 1)
result = HPSpace(nside=nside, dtype=domain.dtype)
return result
@staticmethod
def check_codomain(domain, codomain):
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, LMSpace):
raise TypeError(
'ERROR: domain is not a LMSpace')
raise TypeError("domain is not a LMSpace.")
if not isinstance(codomain, HPSpace):
raise TypeError(
'ERROR: codomain must be a HPSpace.')
raise TypeError("codomain must be a HPSpace.")
nside = codomain.nside
lmax = domain.lmax
return None
if lmax != 2*nside:
cls.Logger.warn("Unrecommended: lmax != 2*nside.")
super(LMHPTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
nside = self.codomain.nside
lmax = self.domain.lmax
mmax = lmax
sjob=pyHealpix.sharpjob_d()
sjob = pyHealpix.sharpjob_d()
sjob.set_Healpix_geometry(nside)
sjob.set_triangular_alm_info(lmax,mmax)
sjob.set_triangular_alm_info(lmax, mmax)
if issubclass(inp.dtype.type, np.complexfloating):
[resultReal, resultImag] = [ltf.buildLm(x, lmax=lmax)
for x in (inp.real, inp.imag)]
[resultReal,
resultImag] = [lm_transformation_factory.buildLm(x, lmax=lmax)
for x in (inp.real, inp.imag)]
[resultReal, resultImag] = [sjob.alm2map(x)
for x in [resultReal, resultImag]]
......@@ -103,7 +103,7 @@ class LMHPTransformation(SlicingTransformation):
result = self._combine_complex_result(resultReal, resultImag)
else:
result = ltf.buildLm(inp, lmax=lmax)
result = lm_transformation_factory.buildLm(inp, lmax=lmax)
result = sjob.alm2map(result)
return result
......@@ -64,7 +64,7 @@ class RGRGTransformation(Transformation):
A compatible codomain.
"""
if not isinstance(domain, RGSpace):
raise TypeError('ERROR: domain needs to be a RGSpace')
raise TypeError("domain needs to be a RGSpace")
# parse the zerocenter input
if zerocenter is None:
......@@ -78,40 +78,33 @@ class RGRGTransformation(Transformation):
# calculate the initialization parameters
distances = 1 / (np.array(domain.shape) *
np.array(domain.distances))
if dtype is None:
# create a definitely complex dtype from the dtype of domain
one = domain.dtype.type(1)
dtype = np.dtype(type(one + 1j))
new_space = RGSpace(domain.shape,
zerocenter=zerocenter,
distances=distances,
harmonic=(not domain.harmonic),
dtype=dtype)
dtype=domain.dtype)
# better safe than sorry
cls.check_codomain(domain, new_space)
return new_space
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, RGSpace):
raise TypeError('ERROR: domain is not a RGSpace')
if codomain is None:
return False
raise TypeError("domain is not a RGSpace")
if not isinstance(codomain, RGSpace):
return False
raise TypeError("domain is not a RGSpace")
if not np.all(np.array(domain.shape) ==
np.array(codomain.shape)):
return False
raise AttributeError("The shapes of domain and codomain must be "
"identical.")
if domain.harmonic == codomain.harmonic:
return False
if codomain.harmonic and not issubclass(codomain.dtype.type,
np.complexfloating):
cls.logger.warn("Codomain is harmonic but dtype is real.")
raise AttributeError("domain.harmonic and codomain.harmonic must "
"not be the same.")
# Check if the distances match, i.e. dist' = 1 / (num * dist)
if not np.all(
......@@ -119,9 +112,10 @@ class RGRGTransformation(Transformation):
np.array(domain.distances) *
np.array(codomain.distances) - 1) <
10**-7):
return False
raise AttributeError("The grid-distances of domain and codomain "
"do not match.")
return True
super(RGRGTransformation, cls).check_codomain(domain, codomain)
def transform(self, val, axes=None, **kwargs):
"""
......
......@@ -18,6 +18,8 @@
import abc
import numpy as np
from keepers import Loggable
......@@ -41,9 +43,11 @@ class Transformation(Loggable, object):
def get_codomain(cls, domain, dtype=None, zerocenter=None):
raise NotImplementedError
@staticmethod
def check_codomain(domain, codomain):
raise NotImplementedError
@classmethod
def check_codomain(cls, domain, codomain):
if np.dtype(domain.dtype) != np.dtype(codomain.dtype):
cls.Logger.warn("Unrecommended: domain and codomain don't have "
"the same dtype.")
def transform(self, val, axes=None, **kwargs):
raise NotImplementedError
......@@ -16,6 +16,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
class _TransformationCache(object):
def __init__(self):
self.cache = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment