Commit 2c897522 authored by Pumpe, Daniel (dpumpe)'s avatar Pumpe, Daniel (dpumpe)
Browse files

Merge branch 'master' of gitlab.mpcdf.mpg.de:ift/NIFTy into Docstring_Operators

parents 835b4730 25495b60
......@@ -107,7 +107,7 @@ if __name__ == "__main__":
# callback=distance_measure,
# max_history_length=3)
m0 = Field(s_space, val=1)
m0 = Field(s_space, val=1.)
energy = WienerFilterEnergy(position=m0, D=D, j=j)
......
......@@ -168,7 +168,7 @@ class Field(Loggable, Versionable, object):
# ---Powerspectral methods---
def power_analyze(self, spaces=None, log=False, nbin=None, binbounds=None,
real_signal=True):
decompose_power=False):
# check if all spaces in `self.domain` are either harmonic or
# power_space instances
for sp in self.domain:
......@@ -219,7 +219,7 @@ class Field(Loggable, Versionable, object):
pindex = power_domain.pindex
rho = power_domain.rho
if real_signal:
if decompose_power:
hermitian_part, anti_hermitian_part = \
harmonic_domain.hermitian_decomposition(
self.val,
......@@ -245,7 +245,7 @@ class Field(Loggable, Versionable, object):
result_domain = list(self.domain)
result_domain[space_index] = power_domain
if real_signal:
if decompose_power:
result_dtype = np.complex
else:
result_dtype = np.float
......@@ -303,8 +303,8 @@ class Field(Loggable, Versionable, object):
return result_obj
def power_synthesize(self, spaces=None, real_power=True, real_signal=True,
mean=None, std=None):
def power_synthesize(self, spaces=None, real_power=True,
real_signal=False, mean=None, std=None):
# check if the `spaces` input is valid
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
......
......@@ -19,6 +19,7 @@
import numpy as np
from itertools import product
def get_slice_list(shape, axes):
"""
Helper function which generates slice list(s) to traverse over all
......@@ -65,8 +66,7 @@ def get_slice_list(shape, axes):
return
def cast_axis_to_tuple(axis, length):
def cast_axis_to_tuple(axis, length=None):
if axis is None:
return None
try:
......@@ -78,16 +78,17 @@ def cast_axis_to_tuple(axis, length):
raise TypeError(
"Could not convert axis-input to tuple of ints")
# shift negative indices to positive ones
axis = tuple(item if (item >= 0) else (item + length) for item in axis)
if length is not None:
# shift negative indices to positive ones
axis = tuple(item if (item >= 0) else (item + length) for item in axis)
# Deactivated this, in order to allow for the ComposedOperator
# remove duplicate entries
# axis = tuple(set(axis))
# Deactivated this, in order to allow for the ComposedOperator
# remove duplicate entries
# axis = tuple(set(axis))
# assert that all entries are elements in [0, length]
for elem in axis:
assert (0 <= elem < length)
# assert that all entries are elements in [0, length]
for elem in axis:
assert (0 <= elem < length)
return axis
......
......@@ -72,7 +72,9 @@ class ComposedOperator(LinearOperator):
"""
# ---Overwritten properties and methods---
def __init__(self, operators):
def __init__(self, operators, default_spaces=None):
super(ComposedOperator, self).__init__(default_spaces)
self._operator_store = ()
for op in operators:
if not isinstance(op, LinearOperator):
......
......@@ -85,9 +85,10 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(),
diagonal=None, bare=False, copy=True,
distribution_strategy=None):
def __init__(self, domain=(), diagonal=None, bare=False, copy=True,
distribution_strategy=None, default_spaces=None):
super(DiagonalOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain)
if distribution_strategy is None:
......
......@@ -34,8 +34,9 @@ from transformations import RGRGTransformation,\
class FFTOperator(LinearOperator):
""" Transforms between a pair of position and harmonic domains.
Possible domain pairs are
"""Transforms between a pair of position and harmonic domains.
Built-in domain pairs are
- a harmonic and a non-harmonic RGSpace (with matching distances)
- a HPSpace and a LMSpace
- a GLSpace and a LMSpace
......@@ -48,40 +49,34 @@ class FFTOperator(LinearOperator):
Parameters
----------
domain: Space or single-element tuple of Spaces
The domain of the data that is input by "times" and output by
"adjoint_times".
target: Space or single-element tuple of Spaces (optional)
target: Space or single-element tuple of Spaces (optional)
The domain of the data that is output by "times" and input by
"adjoint_times".
If omitted, a co-domain will be chosen automatically.
Whenever "domain" is an RGSpace, the codomain (and its parameters) are
uniquely determined (except for "zerocenter").
For GLSpace, HPSpace, and LMSpace, a sensible (but not unique) co-domain
is chosen that should work satisfactorily in most situations,
For GLSpace, HPSpace, and LMSpace, a sensible (but not unique)
co-domain is chosen that should work satisfactorily in most situations,
but for full control, the user should explicitly specify a codomain.
module: String (optional)
Software module employed for carrying out the transform operations.
For RGSpace pairs this can be "numpy" or "fftw", where "numpy" is always
available, but "fftw" offers higher performance and parallelization.
For sphere-related domains, only "pyHealpix" is available.
If omitted, "fftw" is selected for RGSpaces if available, else "numpy";
on the sphere the default is (unsurprisingly) "pyHealpix".
For RGSpace pairs this can be "numpy" or "fftw", where "numpy" is
always available, but "fftw" offers higher performance and
parallelization. For sphere-related domains, only "pyHealpix" is
available. If omitted, "fftw" is selected for RGSpaces if available,
else "numpy"; on the sphere the default is "pyHealpix".
domain_dtype: data type (optional)
Data type of the fields that go into "times" and come out of
"adjoint_times". Default is "numpy.complex".
target_dtype: data type (optional)
Data type of the fields that go into "adjoint_times" and come out of
"times". Default is "numpy.complex".
(MR: Wouldn't it make sense to specify data types
only to "times" and "adjoint_times"? Does the operator itself really
need to know this, or only the individual call?)
Attributes
----------
domain: Tuple of Spaces (with one entry)
The domain of the data that is input by "times" and output by
"adjoint_times".
......@@ -94,14 +89,12 @@ class FFTOperator(LinearOperator):
Raises
------
ValueError:
if "domain" or "target" are not of the proper type.
"""
# ---Class attributes---
# Domains for which FFTOperator is unitary
unitary_list = (RGSpace,)
# ---Class attributes---
default_codomain_dictionary = {RGSpace: RGSpace,
HPSpace: LMSpace,
......@@ -119,7 +112,8 @@ class FFTOperator(LinearOperator):
# ---Overwritten properties and methods---
def __init__(self, domain, target=None, module=None,
domain_dtype=None, target_dtype=None):
domain_dtype=None, target_dtype=None, default_spaces=None):
super(FFTOperator, self).__init__(default_spaces)
# Initialize domain and target
......@@ -220,23 +214,22 @@ class FFTOperator(LinearOperator):
@property
def unitary(self):
return type(self.domain[0]) in self.unitary_list
return (self._forward_transformation.unitary and
self._backward_transformation.unitary)
# ---Added properties and methods---
@classmethod
def get_default_codomain(cls, domain):
""" Returns a codomain to the given domain.
"""Returns a codomain to the given domain.
Parameters
----------
domain: Space
An instance of RGSpace, HPSpace, GLSpace or LMSpace.
Returns
-------
target: Space
A (more or less perfect) counterpart to "domain" with respect
to a FFT operation.
......@@ -249,9 +242,9 @@ class FFTOperator(LinearOperator):
Raises
------
ValueError:
if no default codomain is defined for "domain".
"""
domain_class = domain.__class__
try:
......
......@@ -45,6 +45,10 @@ class GLLMTransformation(SlicingTransformation):
# ---Mandatory properties and methods---
@property
def unitary(self):
return False
@classmethod
def get_codomain(cls, domain):
"""
......@@ -85,13 +89,13 @@ class GLLMTransformation(SlicingTransformation):
mmax = codomain.mmax
if lmax != mmax:
cls.Logger.warn("Unrecommended: codomain has lmax != mmax.")
cls.logger.warn("Unrecommended: codomain has lmax != mmax.")
if lmax != nlat - 1:
cls.Logger.warn("Unrecommended: codomain has lmax != nlat - 1.")
cls.logger.warn("Unrecommended: codomain has lmax != nlat - 1.")
if nlon != 2*nlat - 1:
cls.Logger.warn("Unrecommended: domain has nlon != 2*nlat - 1.")
cls.logger.warn("Unrecommended: domain has nlon != 2*nlat - 1.")
super(GLLMTransformation, cls).check_codomain(domain, codomain)
......
......@@ -46,6 +46,10 @@ class HPLMTransformation(SlicingTransformation):
# ---Mandatory properties and methods---
@property
def unitary(self):
return False
@classmethod
def get_codomain(cls, domain):
"""
......@@ -83,7 +87,7 @@ class HPLMTransformation(SlicingTransformation):
nside = domain.nside
if lmax != 2*nside:
cls.Logger.warn("Unrecommended: lmax != 2*nside.")
cls.logger.warn("Unrecommended: lmax != 2*nside.")
super(HPLMTransformation, cls).check_codomain(domain, codomain)
......
......@@ -45,6 +45,10 @@ class LMGLTransformation(SlicingTransformation):
# ---Mandatory properties and methods---
@property
def unitary(self):
return False
@classmethod
def get_codomain(cls, domain):
"""
......@@ -91,13 +95,13 @@ class LMGLTransformation(SlicingTransformation):
mmax = domain.mmax
if lmax != mmax:
cls.Logger.warn("Unrecommended: codomain has lmax != mmax.")
cls.logger.warn("Unrecommended: codomain has lmax != mmax.")
if nlat != lmax + 1:
cls.Logger.warn("Unrecommended: codomain has nlat != lmax + 1.")
cls.logger.warn("Unrecommended: codomain has nlat != lmax + 1.")
if nlon != 2*lmax + 1:
cls.Logger.warn("Unrecommended: domain has nlon != 2*lmax + 1.")
cls.logger.warn("Unrecommended: domain has nlon != 2*lmax + 1.")
super(LMGLTransformation, cls).check_codomain(domain, codomain)
......
......@@ -44,6 +44,10 @@ class LMHPTransformation(SlicingTransformation):
# ---Mandatory properties and methods---
@property
def unitary(self):
return False
@classmethod
def get_codomain(cls, domain):
"""
......@@ -85,7 +89,7 @@ class LMHPTransformation(SlicingTransformation):
lmax = domain.lmax
if lmax != 2*nside:
cls.Logger.warn("Unrecommended: lmax != 2*nside.")
cls.logger.warn("Unrecommended: lmax != 2*nside.")
super(LMHPTransformation, cls).check_codomain(domain, codomain)
......
......@@ -23,6 +23,9 @@ from nifty import RGSpace, nifty_configuration
class RGRGTransformation(Transformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None):
super(RGRGTransformation, self).__init__(domain, codomain, module)
......@@ -42,6 +45,12 @@ class RGRGTransformation(Transformation):
else:
raise ValueError('Unsupported FFT module:' + module)
# ---Mandatory properties and methods---
@property
def unitary(self):
return True
@classmethod
def get_codomain(cls, domain, zerocenter=None):
"""
......
......@@ -37,6 +37,10 @@ class Transformation(Loggable, object):
self.domain = domain
self.codomain = codomain
@abc.abstractproperty
def unitary(self):
raise NotImplementedError
@classmethod
def get_codomain(cls, domain):
raise NotImplementedError
......
......@@ -68,6 +68,7 @@ class InvertibleOperatorMixin(object):
else:
self.__inverter = ConjugateGradient(
preconditioner=self.__preconditioner)
super(InvertibleOperatorMixin, self).__init__(*args, **kwargs)
def _times(self, x, spaces, x0=None):
if x0 is None:
......
......@@ -75,8 +75,8 @@ class LinearOperator(Loggable, object):
__metaclass__ = NiftyMeta
def __init__(self):
pass
def __init__(self, default_spaces=None):
self.default_spaces = default_spaces
def _parse_domain(self, domain):
return utilities.parse_domain(domain)
......@@ -124,6 +124,14 @@ class LinearOperator(Loggable, object):
"""
raise NotImplementedError
@property
def default_spaces(self):
return self._default_spaces
@default_spaces.setter
def default_spaces(self, spaces):
self._default_spaces = utilities.cast_axis_to_tuple(spaces)
def __call__(self, *args, **kwargs):
return self.times(*args, **kwargs)
......@@ -332,6 +340,9 @@ class LinearOperator(Loggable, object):
raise ValueError(
"supplied object is not a `nifty.Field`.")
if spaces is None:
spaces = self.default_spaces
# sanitize the `spaces` and `types` input
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
......
......@@ -66,7 +66,9 @@ class ProjectionOperator(EndomorphicOperator):
"""
# ---Overwritten properties and methods---
def __init__(self, projection_field):
def __init__(self, projection_field, default_spaces=None):
super(ProjectionOperator, self).__init__(default_spaces)
if not isinstance(projection_field, Field):
raise TypeError("The projection_field must be a NIFTy-Field"
"instance.")
......
......@@ -78,7 +78,7 @@ class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, S=None, M=None, R=None, N=None, inverter=None,
preconditioner=None):
preconditioner=None, default_spaces=None):
"""
Sets the standard operator properties and `codomain`, `_A1`, `_A2`,
and `RN` if required.
......@@ -118,7 +118,8 @@ class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator):
preconditioner = self._S_times
super(PropagatorOperator, self).__init__(inverter=inverter,
preconditioner=preconditioner)
preconditioner=preconditioner,
default_spaces=default_spaces)
# ---Mandatory properties and methods---
......
......@@ -78,7 +78,9 @@ class SmoothingOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), sigma=0, log_distances=False):
def __init__(self, domain=(), sigma=0, log_distances=False,
default_spaces=None):
super(SmoothingOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain)
......
......@@ -34,10 +34,10 @@ def create_power_operator(domain, power_spectrum, dtype=None,
distribution_strategy=distribution_strategy)
fp = Field(power_domain,
val=power_spectrum,dtype=dtype,
val=power_spectrum, dtype=dtype,
distribution_strategy=distribution_strategy)
f = fp.power_synthesize(mean=1, std=0, real_signal=False)
f = fp.power_synthesize(mean=1, std=0, decompose_power=False)
power_operator = DiagonalOperator(domain, diagonal=f, bare=True)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment