Commit de97de56 authored by Martin Reinecke's avatar Martin Reinecke

revamp codomain handling

parent af646615
......@@ -121,12 +121,11 @@ class DiagonalOperator(EndomorphicOperator):
"""
if bare:
diagonal = self._diagonal.weight(power=-1)
return self._diagonal.weight(power=-1)
elif copy:
diagonal = self._diagonal.copy()
return self._diagonal.copy()
else:
diagonal = self._diagonal
return diagonal
return self._diagonal
def inverse_diagonal(self, bare=False):
""" Returns the inverse-diagonal of the operator.
......@@ -181,9 +180,7 @@ class DiagonalOperator(EndomorphicOperator):
"""
# use the casting functionality from Field to process `diagonal`
f = Field(domain=self.domain,
val=diagonal,
copy=copy)
f = Field(domain=self.domain, val=diagonal, copy=copy)
# weight if the given values were `bare` is True
# do inverse weightening if the other way around
......@@ -224,6 +221,4 @@ class DiagonalOperator(EndomorphicOperator):
# here the actual multiplication takes place
local_result = operation(reshaped_local_diagonal)(x.val)
result_field = x.copy_empty(dtype=local_result.dtype)
result_field.val=local_result
return result_field
return Field (x.domain,val=local_result)
......@@ -107,11 +107,13 @@ class FFTOperator(LinearOperator):
"space as input domain.")
if target is None:
target = (self.get_default_codomain(self.domain[0]), )
target = (self.domain[0].get_default_codomain(), )
self._target = self._parse_domain(target)
if len(self.target) != 1:
raise ValueError("TransformationOperator accepts only exactly one "
"space as output target.")
self.domain[0].check_codomain(self.target[0])
self.target[0].check_codomain(self.domain[0])
# Create transformation instances
forward_class = self.transformation_dictionary[
......@@ -190,47 +192,3 @@ class FFTOperator(LinearOperator):
def unitary(self):
return (self._forward_transformation.unitary and
self._backward_transformation.unitary)
# ---Added properties and methods---
@classmethod
def get_default_codomain(cls, domain):
"""Returns a codomain to the given domain.
Parameters
----------
domain: Space
An instance of RGSpace, HPSpace, GLSpace or LMSpace.
Returns
-------
target: Space
A (more or less perfect) counterpart to "domain" with respect
to a FFT operation.
Whenever "domain" is an RGSpace, the codomain (and its parameters)
are uniquely determined.
For GLSpace, HPSpace, and LMSpace, a sensible (but not unique)
co-domain is chosen that should work satisfactorily in most
situations. For full control however, the user should not rely on
this method.
Raises
------
ValueError:
if no default codomain is defined for "domain".
"""
domain_class = domain.__class__
try:
codomain_class = cls.default_codomain_dictionary[domain_class]
except KeyError:
raise ValueError("Unknown domain")
try:
transform_class = cls.transformation_dictionary[(domain_class,
codomain_class)]
except KeyError:
raise ValueError(
"No transformation for domain-codomain pair found.")
return transform_class.get_codomain(domain)
......@@ -34,36 +34,6 @@ class GLLMTransformation(SlicingTransformation):
def unitary(self):
return False
@classmethod
def get_codomain(cls, domain):
"""
Generates a compatible codomain to which transformations are
reasonable, i.e.\ an instance of the :py:class:`LMSpace` class.
Parameters
----------
domain: GLSpace
Space for which a codomain is to be generated
Returns
-------
codomain : LMSpace
A compatible codomain.
"""
if not isinstance(domain, GLSpace):
raise TypeError("domain needs to be a GLSpace")
return LMSpace(lmax=domain.nlat-1,mmax=(domain.nlon-1)//2)
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, GLSpace):
raise TypeError("domain is not a GLSpace")
if not isinstance(codomain, LMSpace):
raise TypeError("codomain must be a LMSpace.")
super(GLLMTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp):
nlat = self.domain.nlat
nlon = self.domain.nlon
......@@ -74,11 +44,11 @@ class GLLMTransformation(SlicingTransformation):
sjob.set_Gauss_geometry(nlat, nlon)
sjob.set_triangular_alm_info(lmax, mmax)
if issubclass(inp.dtype.type, np.complexfloating):
return self._combine_complex_result(
return \
lm_transformation_helper.buildIdx(sjob.map2alm(inp.real),
lmax=lmax),
lm_transformation_helper.buildIdx(sjob.map2alm(inp.imag),
lmax=lmax))
lmax=lmax)\
+1j*lm_transformation_helper.buildIdx(sjob.map2alm(inp.imag),
lmax=lmax)
else:
return lm_transformation_helper.buildIdx(sjob.map2alm(inp),
lmax=lmax)
......@@ -35,36 +35,6 @@ class HPLMTransformation(SlicingTransformation):
def unitary(self):
return False
@classmethod
def get_codomain(cls, domain):
"""
Generates a compatible codomain to which transformations are
reasonable, i.e.\ an instance of the :py:class:`lm_space` class.
Parameters
----------
domain: HPSpace
Space for which a codomain is to be generated
Returns
-------
codomain : LMSpace
A compatible codomain.
"""
if not isinstance(domain, HPSpace):
raise TypeError("domain needs to be a HPSpace")
return LMSpace(lmax=2*domain.nside)
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, HPSpace):
raise TypeError("domain is not a HPSpace")
if not isinstance(codomain, LMSpace):
raise TypeError("codomain must be a LMSpace.")
super(HPLMTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp):
lmax = self.codomain.lmax
mmax = lmax
......@@ -78,7 +48,7 @@ class HPLMTransformation(SlicingTransformation):
resultImag] = [lm_transformation_helper.buildIdx(x, lmax=lmax)
for x in [resultReal, resultImag]]
result = self._combine_complex_result(resultReal, resultImag)
result = resultReal +1j*resultImag
else:
result = pyHealpix.map2alm(inp, lmax, mmax)
......
......@@ -38,43 +38,6 @@ class LMGLTransformation(SlicingTransformation):
def unitary(self):
return False
@classmethod
def get_codomain(cls, domain):
"""
Generates a compatible codomain to which transformations are
reasonable, i.e.\ a pixelization of the two-sphere.
Parameters
----------
domain : LMSpace
Space for which a codomain is to be generated
Returns
-------
codomain : HPSpace
A compatible codomain.
References
----------
.. [#] M. Reinecke and D. Sverre Seljebotn, 2013,
"Libsharp - spherical
harmonic transforms revisited";
`arXiv:1303.4945 <http://www.arxiv.org/abs/1303.4945>`_
"""
if not isinstance(domain, LMSpace):
raise TypeError("domain needs to be a LMSpace")
return GLSpace(nlat=domain.lmax+1, nlon=domain.mmax*2+1)
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, LMSpace):
raise TypeError("domain is not a LMSpace")
if not isinstance(codomain, GLSpace):
raise TypeError("codomain must be a GLSpace.")
super(LMGLTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp):
nlat = self.codomain.nlat
nlon = self.codomain.nlon
......@@ -92,7 +55,7 @@ class LMGLTransformation(SlicingTransformation):
[resultReal, resultImag] = [sjob.alm2map(x)
for x in [resultReal, resultImag]]
result = self._combine_complex_result(resultReal, resultImag)
result = resultReal + 1j*resultImag
else:
result = lm_transformation_helper.buildLm(inp, lmax=lmax)
......
......@@ -33,44 +33,6 @@ class LMHPTransformation(SlicingTransformation):
def unitary(self):
return False
@classmethod
def get_codomain(cls, domain):
"""
Generates a compatible codomain to which transformations are
reasonable, i.e.\ a pixelization of the two-sphere.
Parameters
----------
domain : LMSpace
Space for which a codomain is to be generated
Returns
-------
codomain : HPSpace
A compatible codomain.
References
----------
.. [#] K.M. Gorski et al., 2005, "HEALPix: A Framework for
High-Resolution Discretization and Fast Analysis of Data
Distributed on the Sphere", *ApJ* 622..759G.
"""
if not isinstance(domain, LMSpace):
raise TypeError("domain needs to be a LMSpace.")
nside = max((domain.lmax + 1)//2, 1)
result = HPSpace(nside=nside)
return result
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, LMSpace):
raise TypeError("domain is not a LMSpace.")
if not isinstance(codomain, HPSpace):
raise TypeError("codomain must be a HPSpace.")
super(LMHPTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp):
nside = self.codomain.nside
lmax = self.domain.lmax
......@@ -84,7 +46,7 @@ class LMHPTransformation(SlicingTransformation):
[resultReal, resultImag] = [pyHealpix.alm2map(x, lmax, mmax, nside)
for x in [resultReal, resultImag]]
result = self._combine_complex_result(resultReal, resultImag)
result = resultReal + 1j*resultImag
else:
result = lm_transformation_helper.buildLm(inp, lmax=lmax)
......
......@@ -36,65 +36,6 @@ class RGRGTransformation(Transformation):
def unitary(self):
return True
@classmethod
def get_codomain(cls, domain):
"""
Generates a compatible codomain to which transformations are
reasonable, i.e.\ either a shifted grid or a Fourier conjugate
grid.
Parameters
----------
domain: RGSpace
Space for which a codomain is to be generated
Returns
-------
codomain : nifty.rg_space
A compatible codomain.
"""
if not isinstance(domain, RGSpace):
raise TypeError("domain needs to be a RGSpace")
# calculate the initialization parameters
distances = 1. / (np.array(domain.shape) *
np.array(domain.distances))
new_space = RGSpace(domain.shape,
distances=distances,
harmonic=(not domain.harmonic))
# better safe than sorry
cls.check_codomain(domain, new_space)
return new_space
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, RGSpace):
raise TypeError("domain is not a RGSpace")
if not isinstance(codomain, RGSpace):
raise TypeError("domain is not a RGSpace")
if not np.all(np.array(domain.shape) ==
np.array(codomain.shape)):
raise AttributeError("The shapes of domain and codomain must be "
"identical.")
if domain.harmonic == codomain.harmonic:
raise AttributeError("domain.harmonic and codomain.harmonic must "
"not be the same.")
# Check if the distances match, i.e. dist' = 1 / (num * dist)
if not np.all(
np.absolute(np.array(domain.shape) *
np.array(domain.distances) *
np.array(codomain.distances) - 1) <
1e-7):
raise AttributeError("The grid-distances of domain and codomain "
"do not match.")
super(RGRGTransformation, cls).check_codomain(domain, codomain)
def transform(self, val, axes=None):
"""
......
......@@ -39,9 +39,6 @@ class SlicingTransformation(Transformation):
val[slice_list])
return return_val
def _combine_complex_result(self, resultReal, resultImag):
return resultReal + 1j*resultImag
@abc.abstractmethod
def _transformation_of_slice(self, inp):
raise NotImplementedError
......@@ -22,31 +22,14 @@ from future.utils import with_metaclass
class Transformation(with_metaclass(abc.ABCMeta, type('NewBase', (object,), {}))):
"""
A generic transformation which defines a static check_codomain
method for all transforms.
"""
def __init__(self, domain, codomain):
if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
else:
self.check_codomain(domain, codomain)
self.domain = domain
self.codomain = codomain
self.domain = domain
self.codomain = codomain
@abc.abstractproperty
def unitary(self):
raise NotImplementedError
@classmethod
def get_codomain(cls, domain):
raise NotImplementedError
@classmethod
def check_codomain(cls, domain, codomain):
pass
def transform(self, val, axes=None):
raise NotImplementedError
......@@ -174,3 +174,12 @@ class GLSpace(Space):
if nlon < 1:
raise ValueError("nlon must be a positive number.")
return nlon
def get_default_codomain(self):
from .. import LMSpace
return LMSpace(lmax=self.nlat-1,mmax=(self.nlon-1)//2)
def check_codomain(self, codomain):
from .. import LMSpace
if not isinstance(codomain, LMSpace):
raise TypeError("codomain must be a LMSpace.")
......@@ -133,3 +133,12 @@ class HPSpace(Space):
if nside < 1:
raise ValueError("nside must be >=1.")
return nside
def get_default_codomain(self):
from .. import LMSpace
return LMSpace(lmax=2*domain.nside)
def check_codomain(self, codomain):
from .. import LMSpace
if not isinstance(codomain, LMSpace):
raise TypeError("codomain must be a LMSpace.")
......@@ -175,3 +175,12 @@ class LMSpace(Space):
if lmax < 0:
raise ValueError("lmax must be >=0.")
return lmax
def get_default_codomain(self):
from .. import GLSpace
return GLSpace(self.lmax+1, self.mmax*2+1)
def check_codomain(self, codomain):
from .. import GLSpace, HPSpace
if not isinstance(codomain, (GLSpace, HPSpace)):
raise TypeError("codomain must be a GLSpace.")
......@@ -206,6 +206,31 @@ class RGSpace(Space):
raise NotImplementedError
return lambda x: np.exp(-2. * np.pi*np.pi * x*x * sigma*sigma)
def get_default_codomain(self):
distances = 1. / (np.array(self.shape)*np.array(self.distances))
return RGSpace(self.shape, distances, not self.harmonic)
def check_codomain(self, codomain):
if not isinstance(codomain, RGSpace):
raise TypeError("domain is not a RGSpace")
if self.shape != codomain.shape:
raise AttributeError("The shapes of domain and codomain must be "
"identical.")
if self.harmonic == codomain.harmonic:
raise AttributeError("domain.harmonic and codomain.harmonic must "
"not be the same.")
# Check if the distances match, i.e. dist' = 1 / (num * dist)
if not np.all(
np.absolute(np.array(self.shape) *
np.array(self.distances) *
np.array(codomain.distances) - 1) < 1e-7):
raise AttributeError("The grid-distances of domain and codomain "
"do not match.")
# ---Added properties and methods---
@property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment