Commit 2c897522 authored by Pumpe, Daniel (dpumpe)'s avatar Pumpe, Daniel (dpumpe)
Browse files

Merge branch 'master' of gitlab.mpcdf.mpg.de:ift/NIFTy into Docstring_Operators

parents 835b4730 25495b60
...@@ -107,7 +107,7 @@ if __name__ == "__main__": ...@@ -107,7 +107,7 @@ if __name__ == "__main__":
# callback=distance_measure, # callback=distance_measure,
# max_history_length=3) # max_history_length=3)
m0 = Field(s_space, val=1) m0 = Field(s_space, val=1.)
energy = WienerFilterEnergy(position=m0, D=D, j=j) energy = WienerFilterEnergy(position=m0, D=D, j=j)
......
...@@ -168,7 +168,7 @@ class Field(Loggable, Versionable, object): ...@@ -168,7 +168,7 @@ class Field(Loggable, Versionable, object):
# ---Powerspectral methods--- # ---Powerspectral methods---
def power_analyze(self, spaces=None, log=False, nbin=None, binbounds=None, def power_analyze(self, spaces=None, log=False, nbin=None, binbounds=None,
real_signal=True): decompose_power=False):
# check if all spaces in `self.domain` are either harmonic or # check if all spaces in `self.domain` are either harmonic or
# power_space instances # power_space instances
for sp in self.domain: for sp in self.domain:
...@@ -219,7 +219,7 @@ class Field(Loggable, Versionable, object): ...@@ -219,7 +219,7 @@ class Field(Loggable, Versionable, object):
pindex = power_domain.pindex pindex = power_domain.pindex
rho = power_domain.rho rho = power_domain.rho
if real_signal: if decompose_power:
hermitian_part, anti_hermitian_part = \ hermitian_part, anti_hermitian_part = \
harmonic_domain.hermitian_decomposition( harmonic_domain.hermitian_decomposition(
self.val, self.val,
...@@ -245,7 +245,7 @@ class Field(Loggable, Versionable, object): ...@@ -245,7 +245,7 @@ class Field(Loggable, Versionable, object):
result_domain = list(self.domain) result_domain = list(self.domain)
result_domain[space_index] = power_domain result_domain[space_index] = power_domain
if real_signal: if decompose_power:
result_dtype = np.complex result_dtype = np.complex
else: else:
result_dtype = np.float result_dtype = np.float
...@@ -303,8 +303,8 @@ class Field(Loggable, Versionable, object): ...@@ -303,8 +303,8 @@ class Field(Loggable, Versionable, object):
return result_obj return result_obj
def power_synthesize(self, spaces=None, real_power=True, real_signal=True, def power_synthesize(self, spaces=None, real_power=True,
mean=None, std=None): real_signal=False, mean=None, std=None):
# check if the `spaces` input is valid # check if the `spaces` input is valid
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
import numpy as np import numpy as np
from itertools import product from itertools import product
def get_slice_list(shape, axes): def get_slice_list(shape, axes):
""" """
Helper function which generates slice list(s) to traverse over all Helper function which generates slice list(s) to traverse over all
...@@ -65,8 +66,7 @@ def get_slice_list(shape, axes): ...@@ -65,8 +66,7 @@ def get_slice_list(shape, axes):
return return
def cast_axis_to_tuple(axis, length=None):
def cast_axis_to_tuple(axis, length):
if axis is None: if axis is None:
return None return None
try: try:
...@@ -78,16 +78,17 @@ def cast_axis_to_tuple(axis, length): ...@@ -78,16 +78,17 @@ def cast_axis_to_tuple(axis, length):
raise TypeError( raise TypeError(
"Could not convert axis-input to tuple of ints") "Could not convert axis-input to tuple of ints")
# shift negative indices to positive ones if length is not None:
axis = tuple(item if (item >= 0) else (item + length) for item in axis) # shift negative indices to positive ones
axis = tuple(item if (item >= 0) else (item + length) for item in axis)
# Deactivated this, in order to allow for the ComposedOperator # Deactivated this, in order to allow for the ComposedOperator
# remove duplicate entries # remove duplicate entries
# axis = tuple(set(axis)) # axis = tuple(set(axis))
# assert that all entries are elements in [0, length] # assert that all entries are elements in [0, length]
for elem in axis: for elem in axis:
assert (0 <= elem < length) assert (0 <= elem < length)
return axis return axis
......
...@@ -72,7 +72,9 @@ class ComposedOperator(LinearOperator): ...@@ -72,7 +72,9 @@ class ComposedOperator(LinearOperator):
""" """
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, operators): def __init__(self, operators, default_spaces=None):
super(ComposedOperator, self).__init__(default_spaces)
self._operator_store = () self._operator_store = ()
for op in operators: for op in operators:
if not isinstance(op, LinearOperator): if not isinstance(op, LinearOperator):
......
...@@ -85,9 +85,10 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -85,9 +85,10 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain=(), def __init__(self, domain=(), diagonal=None, bare=False, copy=True,
diagonal=None, bare=False, copy=True, distribution_strategy=None, default_spaces=None):
distribution_strategy=None): super(DiagonalOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain) self._domain = self._parse_domain(domain)
if distribution_strategy is None: if distribution_strategy is None:
......
...@@ -34,8 +34,9 @@ from transformations import RGRGTransformation,\ ...@@ -34,8 +34,9 @@ from transformations import RGRGTransformation,\
class FFTOperator(LinearOperator): class FFTOperator(LinearOperator):
""" Transforms between a pair of position and harmonic domains. """Transforms between a pair of position and harmonic domains.
Possible domain pairs are
Built-in domain pairs are
- a harmonic and a non-harmonic RGSpace (with matching distances) - a harmonic and a non-harmonic RGSpace (with matching distances)
- a HPSpace and a LMSpace - a HPSpace and a LMSpace
- a GLSpace and a LMSpace - a GLSpace and a LMSpace
...@@ -48,40 +49,34 @@ class FFTOperator(LinearOperator): ...@@ -48,40 +49,34 @@ class FFTOperator(LinearOperator):
Parameters Parameters
---------- ----------
domain: Space or single-element tuple of Spaces domain: Space or single-element tuple of Spaces
The domain of the data that is input by "times" and output by The domain of the data that is input by "times" and output by
"adjoint_times". "adjoint_times".
target: Space or single-element tuple of Spaces (optional)
target: Space or single-element tuple of Spaces (optional)
The domain of the data that is output by "times" and input by The domain of the data that is output by "times" and input by
"adjoint_times". "adjoint_times".
If omitted, a co-domain will be chosen automatically. If omitted, a co-domain will be chosen automatically.
Whenever "domain" is an RGSpace, the codomain (and its parameters) are Whenever "domain" is an RGSpace, the codomain (and its parameters) are
uniquely determined (except for "zerocenter"). uniquely determined (except for "zerocenter").
For GLSpace, HPSpace, and LMSpace, a sensible (but not unique) co-domain For GLSpace, HPSpace, and LMSpace, a sensible (but not unique)
is chosen that should work satisfactorily in most situations, co-domain is chosen that should work satisfactorily in most situations,
but for full control, the user should explicitly specify a codomain. but for full control, the user should explicitly specify a codomain.
module: String (optional) module: String (optional)
Software module employed for carrying out the transform operations. Software module employed for carrying out the transform operations.
For RGSpace pairs this can be "numpy" or "fftw", where "numpy" is always For RGSpace pairs this can be "numpy" or "fftw", where "numpy" is
available, but "fftw" offers higher performance and parallelization. always available, but "fftw" offers higher performance and
For sphere-related domains, only "pyHealpix" is available. parallelization. For sphere-related domains, only "pyHealpix" is
If omitted, "fftw" is selected for RGSpaces if available, else "numpy"; available. If omitted, "fftw" is selected for RGSpaces if available,
on the sphere the default is (unsurprisingly) "pyHealpix". else "numpy"; on the sphere the default is "pyHealpix".
domain_dtype: data type (optional) domain_dtype: data type (optional)
Data type of the fields that go into "times" and come out of Data type of the fields that go into "times" and come out of
"adjoint_times". Default is "numpy.complex". "adjoint_times". Default is "numpy.complex".
target_dtype: data type (optional) target_dtype: data type (optional)
Data type of the fields that go into "adjoint_times" and come out of Data type of the fields that go into "adjoint_times" and come out of
"times". Default is "numpy.complex". "times". Default is "numpy.complex".
(MR: Wouldn't it make sense to specify data types
only to "times" and "adjoint_times"? Does the operator itself really
need to know this, or only the individual call?)
Attributes Attributes
---------- ----------
domain: Tuple of Spaces (with one entry) domain: Tuple of Spaces (with one entry)
The domain of the data that is input by "times" and output by The domain of the data that is input by "times" and output by
"adjoint_times". "adjoint_times".
...@@ -94,14 +89,12 @@ class FFTOperator(LinearOperator): ...@@ -94,14 +89,12 @@ class FFTOperator(LinearOperator):
Raises Raises
------ ------
ValueError: ValueError:
if "domain" or "target" are not of the proper type. if "domain" or "target" are not of the proper type.
""" """
# ---Class attributes---
# Domains for which FFTOperator is unitary # ---Class attributes---
unitary_list = (RGSpace,)
default_codomain_dictionary = {RGSpace: RGSpace, default_codomain_dictionary = {RGSpace: RGSpace,
HPSpace: LMSpace, HPSpace: LMSpace,
...@@ -119,7 +112,8 @@ class FFTOperator(LinearOperator): ...@@ -119,7 +112,8 @@ class FFTOperator(LinearOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain, target=None, module=None, def __init__(self, domain, target=None, module=None,
domain_dtype=None, target_dtype=None): domain_dtype=None, target_dtype=None, default_spaces=None):
super(FFTOperator, self).__init__(default_spaces)
# Initialize domain and target # Initialize domain and target
...@@ -220,23 +214,22 @@ class FFTOperator(LinearOperator): ...@@ -220,23 +214,22 @@ class FFTOperator(LinearOperator):
@property @property
def unitary(self): def unitary(self):
return type(self.domain[0]) in self.unitary_list return (self._forward_transformation.unitary and
self._backward_transformation.unitary)
# ---Added properties and methods--- # ---Added properties and methods---
@classmethod @classmethod
def get_default_codomain(cls, domain): def get_default_codomain(cls, domain):
""" Returns a codomain to the given domain. """Returns a codomain to the given domain.
Parameters Parameters
---------- ----------
domain: Space domain: Space
An instance of RGSpace, HPSpace, GLSpace or LMSpace. An instance of RGSpace, HPSpace, GLSpace or LMSpace.
Returns Returns
------- -------
target: Space target: Space
A (more or less perfect) counterpart to "domain" with respect A (more or less perfect) counterpart to "domain" with respect
to a FFT operation. to a FFT operation.
...@@ -249,9 +242,9 @@ class FFTOperator(LinearOperator): ...@@ -249,9 +242,9 @@ class FFTOperator(LinearOperator):
Raises Raises
------ ------
ValueError: ValueError:
if no default codomain is defined for "domain". if no default codomain is defined for "domain".
""" """
domain_class = domain.__class__ domain_class = domain.__class__
try: try:
......
...@@ -45,6 +45,10 @@ class GLLMTransformation(SlicingTransformation): ...@@ -45,6 +45,10 @@ class GLLMTransformation(SlicingTransformation):
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
@property
def unitary(self):
return False
@classmethod @classmethod
def get_codomain(cls, domain): def get_codomain(cls, domain):
""" """
...@@ -85,13 +89,13 @@ class GLLMTransformation(SlicingTransformation): ...@@ -85,13 +89,13 @@ class GLLMTransformation(SlicingTransformation):
mmax = codomain.mmax mmax = codomain.mmax
if lmax != mmax: if lmax != mmax:
cls.Logger.warn("Unrecommended: codomain has lmax != mmax.") cls.logger.warn("Unrecommended: codomain has lmax != mmax.")
if lmax != nlat - 1: if lmax != nlat - 1:
cls.Logger.warn("Unrecommended: codomain has lmax != nlat - 1.") cls.logger.warn("Unrecommended: codomain has lmax != nlat - 1.")
if nlon != 2*nlat - 1: if nlon != 2*nlat - 1:
cls.Logger.warn("Unrecommended: domain has nlon != 2*nlat - 1.") cls.logger.warn("Unrecommended: domain has nlon != 2*nlat - 1.")
super(GLLMTransformation, cls).check_codomain(domain, codomain) super(GLLMTransformation, cls).check_codomain(domain, codomain)
......
...@@ -46,6 +46,10 @@ class HPLMTransformation(SlicingTransformation): ...@@ -46,6 +46,10 @@ class HPLMTransformation(SlicingTransformation):
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
@property
def unitary(self):
return False
@classmethod @classmethod
def get_codomain(cls, domain): def get_codomain(cls, domain):
""" """
...@@ -83,7 +87,7 @@ class HPLMTransformation(SlicingTransformation): ...@@ -83,7 +87,7 @@ class HPLMTransformation(SlicingTransformation):
nside = domain.nside nside = domain.nside
if lmax != 2*nside: if lmax != 2*nside:
cls.Logger.warn("Unrecommended: lmax != 2*nside.") cls.logger.warn("Unrecommended: lmax != 2*nside.")
super(HPLMTransformation, cls).check_codomain(domain, codomain) super(HPLMTransformation, cls).check_codomain(domain, codomain)
......
...@@ -45,6 +45,10 @@ class LMGLTransformation(SlicingTransformation): ...@@ -45,6 +45,10 @@ class LMGLTransformation(SlicingTransformation):
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
@property
def unitary(self):
return False
@classmethod @classmethod
def get_codomain(cls, domain): def get_codomain(cls, domain):
""" """
...@@ -91,13 +95,13 @@ class LMGLTransformation(SlicingTransformation): ...@@ -91,13 +95,13 @@ class LMGLTransformation(SlicingTransformation):
mmax = domain.mmax mmax = domain.mmax
if lmax != mmax: if lmax != mmax:
cls.Logger.warn("Unrecommended: codomain has lmax != mmax.") cls.logger.warn("Unrecommended: codomain has lmax != mmax.")
if nlat != lmax + 1: if nlat != lmax + 1:
cls.Logger.warn("Unrecommended: codomain has nlat != lmax + 1.") cls.logger.warn("Unrecommended: codomain has nlat != lmax + 1.")
if nlon != 2*lmax + 1: if nlon != 2*lmax + 1:
cls.Logger.warn("Unrecommended: domain has nlon != 2*lmax + 1.") cls.logger.warn("Unrecommended: domain has nlon != 2*lmax + 1.")
super(LMGLTransformation, cls).check_codomain(domain, codomain) super(LMGLTransformation, cls).check_codomain(domain, codomain)
......
...@@ -44,6 +44,10 @@ class LMHPTransformation(SlicingTransformation): ...@@ -44,6 +44,10 @@ class LMHPTransformation(SlicingTransformation):
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
@property
def unitary(self):
return False
@classmethod @classmethod
def get_codomain(cls, domain): def get_codomain(cls, domain):
""" """
...@@ -85,7 +89,7 @@ class LMHPTransformation(SlicingTransformation): ...@@ -85,7 +89,7 @@ class LMHPTransformation(SlicingTransformation):
lmax = domain.lmax lmax = domain.lmax
if lmax != 2*nside: if lmax != 2*nside:
cls.Logger.warn("Unrecommended: lmax != 2*nside.") cls.logger.warn("Unrecommended: lmax != 2*nside.")
super(LMHPTransformation, cls).check_codomain(domain, codomain) super(LMHPTransformation, cls).check_codomain(domain, codomain)
......
...@@ -23,6 +23,9 @@ from nifty import RGSpace, nifty_configuration ...@@ -23,6 +23,9 @@ from nifty import RGSpace, nifty_configuration
class RGRGTransformation(Transformation): class RGRGTransformation(Transformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None): def __init__(self, domain, codomain=None, module=None):
super(RGRGTransformation, self).__init__(domain, codomain, module) super(RGRGTransformation, self).__init__(domain, codomain, module)
...@@ -42,6 +45,12 @@ class RGRGTransformation(Transformation): ...@@ -42,6 +45,12 @@ class RGRGTransformation(Transformation):
else: else:
raise ValueError('Unsupported FFT module:' + module) raise ValueError('Unsupported FFT module:' + module)
# ---Mandatory properties and methods---
@property
def unitary(self):
return True
@classmethod @classmethod
def get_codomain(cls, domain, zerocenter=None): def get_codomain(cls, domain, zerocenter=None):
""" """
......
...@@ -37,6 +37,10 @@ class Transformation(Loggable, object): ...@@ -37,6 +37,10 @@ class Transformation(Loggable, object):
self.domain = domain self.domain = domain
self.codomain = codomain self.codomain = codomain
@abc.abstractproperty
def unitary(self):
raise NotImplementedError
@classmethod @classmethod
def get_codomain(cls, domain): def get_codomain(cls, domain):
raise NotImplementedError raise NotImplementedError
......
...@@ -68,6 +68,7 @@ class InvertibleOperatorMixin(object): ...@@ -68,6 +68,7 @@ class InvertibleOperatorMixin(object):
else: else:
self.__inverter = ConjugateGradient( self.__inverter = ConjugateGradient(
preconditioner=self.__preconditioner) preconditioner=self.__preconditioner)
super(InvertibleOperatorMixin, self).__init__(*args, **kwargs)
def _times(self, x, spaces, x0=None): def _times(self, x, spaces, x0=None):
if x0 is None: if x0 is None:
......
...@@ -75,8 +75,8 @@ class LinearOperator(Loggable, object): ...@@ -75,8 +75,8 @@ class LinearOperator(Loggable, object):
__metaclass__ = NiftyMeta __metaclass__ = NiftyMeta
def __init__(self): def __init__(self, default_spaces=None):
pass self.default_spaces = default_spaces
def _parse_domain(self, domain): def _parse_domain(self, domain):
return utilities.parse_domain(domain) return utilities.parse_domain(domain)
...@@ -124,6 +124,14 @@ class LinearOperator(Loggable, object): ...@@ -124,6 +124,14 @@ class LinearOperator(Loggable, object):
""" """
raise NotImplementedError raise NotImplementedError
@property
def default_spaces(self):
return self._default_spaces
@default_spaces.setter
def default_spaces(self, spaces):
self._default_spaces = utilities.cast_axis_to_tuple(spaces)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.times(*args, **kwargs) return self.times(*args, **kwargs)
...@@ -332,6 +340,9 @@ class LinearOperator(Loggable, object): ...@@ -332,6 +340,9 @@ class LinearOperator(Loggable, object):
raise ValueError( raise ValueError(
"supplied object is not a `nifty.Field`.") "supplied object is not a `nifty.Field`.")
if spaces is None:
spaces = self.default_spaces
# sanitize the `spaces` and `types` input # sanitize the `spaces` and `types` input
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
......
...@@ -66,7 +66,9 @@ class ProjectionOperator(EndomorphicOperator): ...@@ -66,7 +66,9 @@ class ProjectionOperator(EndomorphicOperator):
""" """