Commit 6b18bbdf authored by Martin Reinecke's avatar Martin Reinecke
Browse files

replace InvertibleOperatorMixin with InversionEnabler; move...

replace InvertibleOperatorMixin with InversionEnabler; move PowerSpectrum-related functionality out of Field
parent 42b1e9e8
Pipeline #20029 passed with stage
in 4 minutes and 22 seconds
...@@ -19,13 +19,13 @@ ...@@ -19,13 +19,13 @@
from __future__ import division, print_function from __future__ import division, print_function
from builtins import range from builtins import range
import numpy as np import numpy as np
from .spaces.power_space import PowerSpace
from . import nifty_utilities as utilities from . import nifty_utilities as utilities
from .random import Random from .random import Random
from .domain_tuple import DomainTuple from .domain_tuple import DomainTuple
from functools import reduce from functools import reduce
from . import dobj from . import dobj
class Field(object): class Field(object):
""" The discrete representation of a continuous field over multiple spaces. """ The discrete representation of a continuous field over multiple spaces.
...@@ -80,7 +80,8 @@ class Field(object): ...@@ -80,7 +80,8 @@ class Field(object):
raise ValueError("Domain mismatch") raise ValueError("Domain mismatch")
self._val = dobj.from_object(val.val, dtype=dtype, copy=copy) self._val = dobj.from_object(val.val, dtype=dtype, copy=copy)
elif (np.isscalar(val)): elif (np.isscalar(val)):
self._val = dobj.full(self.domain.shape, dtype=dtype, fill_value=val) self._val = dobj.full(self.domain.shape, dtype=dtype,
fill_value=val)
elif isinstance(val, dobj.data_object): elif isinstance(val, dobj.data_object):
if self.domain.shape == val.shape: if self.domain.shape == val.shape:
self._val = dobj.from_object(val, dtype=dtype, copy=copy) self._val = dobj.from_object(val, dtype=dtype, copy=copy)
...@@ -180,10 +181,6 @@ class Field(object): ...@@ -180,10 +181,6 @@ class Field(object):
------- -------
out : Field out : Field
The output object. The output object.
See Also
--------
power_synthesize
""" """
domain = DomainTuple.make(domain) domain = DomainTuple.make(domain)
...@@ -191,182 +188,6 @@ class Field(object): ...@@ -191,182 +188,6 @@ class Field(object):
return Field(domain=domain, val=generator_function(dtype=dtype, return Field(domain=domain, val=generator_function(dtype=dtype,
shape=domain.shape, **kwargs)) shape=domain.shape, **kwargs))
# ---Powerspectral methods---
def power_analyze(self, spaces=None, binbounds=None,
keep_phase_information=False):
""" Computes the square root power spectrum for a subspace of `self`.
Creates a PowerSpace for the space addressed by `spaces` with the given
binning and computes the power spectrum as a Field over this
PowerSpace. This can only be done if the subspace to be analyzed is a
harmonic space. The resulting field has the same units as the initial
field, corresponding to the square root of the power spectrum.
Parameters
----------
spaces : int *optional*
The subspace for which the powerspectrum shall be computed.
(default : None).
binbounds : array-like *optional*
Inner bounds of the bins (default : None).
if binbounds==None : bins are inferred.
keep_phase_information : boolean, *optional*
If False, return a real-valued result containing the power spectrum
of the input Field.
If True, return a complex-valued result whose real component
contains the power spectrum computed from the real part of the
input Field, and whose imaginary component contains the power
spectrum computed from the imaginary part of the input Field.
The absolute value of this result should be identical to the output
of power_analyze with keep_phase_information=False.
(default : False).
Raise
-----
TypeError
Raised if any of the input field's domains is not harmonic
Returns
-------
out : Field
The output object. Its domain is a PowerSpace and it contains
the power spectrum of 'self's field.
See Also
--------
power_synthesize, PowerSpace
"""
# check if all spaces in `self.domain` are either harmonic or
# power_space instances
for sp in self.domain:
if not sp.harmonic and not isinstance(sp, PowerSpace):
print("WARNING: Field has a space in `domain` which is "
"neither harmonic nor a PowerSpace.")
# check if the `spaces` input is valid
if spaces is None:
spaces = range(len(self.domain))
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
if len(spaces) == 0:
raise ValueError("No space for analysis specified.")
if keep_phase_information:
parts = [self.real*self.real, self.imag*self.imag]
else:
parts = [self.real*self.real + self.imag*self.imag]
parts = [ part.weight(1,spaces) for part in parts ]
for space_index in spaces:
parts = [self._single_power_analyze(field=part,
idx=space_index,
binbounds=binbounds)
for part in parts]
return parts[0] + 1j*parts[1] if keep_phase_information else parts[0]
@staticmethod
def _single_power_analyze(field, idx, binbounds):
from .operators.power_projection_operator import PowerProjectionOperator
power_domain = PowerSpace(field.domain[idx], binbounds)
ppo = PowerProjectionOperator(field.domain, power_domain, idx)
return ppo(field.weight(-1))
def _compute_spec(self, spaces):
from .operators.power_projection_operator import PowerProjectionOperator
from .basic_arithmetics import sqrt
if spaces is None:
spaces = range(len(self.domain))
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
# create the result domain
result_domain = list(self.domain)
spec = sqrt(self)
for i in spaces:
result_domain[i] = self.domain[i].harmonic_partner
ppo = PowerProjectionOperator(result_domain, self.domain[i], i)
spec = ppo.adjoint_times(spec)
return spec
def power_synthesize(self, spaces=None, real_power=True, real_signal=True):
""" Yields a sampled field with `self`**2 as its power spectrum.
This method draws a Gaussian random field in the harmonic partner
domain of this field's domains, using this field as power spectrum.
Parameters
----------
spaces : {tuple, int, None} *optional*
Specifies the subspace containing all the PowerSpaces which
should be converted (default : None).
if spaces==None : Tries to convert the whole domain.
real_power : boolean *optional*
Determines whether the power spectrum is treated as intrinsically
real or complex (default : True).
real_signal : boolean *optional*
True will result in a purely real signal-space field
(default : True).
Returns
-------
out : Field
The output object. A random field created with the power spectrum
stored in the `spaces` in `self`.
Notes
-----
For this the spaces specified by `spaces` must be a PowerSpace.
This expects this field to be the square root of a power spectrum, i.e.
to have the unit of the field to be sampled.
See Also
--------
power_analyze
Raises
------
ValueError : If domain specified by `spaces` is not a PowerSpace.
"""
spec = self._compute_spec(spaces)
# create random samples: one or two, depending on whether the
# power spectrum is real or complex
result = [self.from_random('normal', mean=0., std=1.,
domain=spec.domain,
dtype=np.float if real_signal
else np.complex)
for x in range(1 if real_power else 2)]
# MR: dummy call - will be removed soon
if real_signal:
self.from_random('normal', mean=0., std=1.,
domain=spec.domain, dtype=np.float)
# apply the rescaler to the random fields
result[0] *= spec.real
if not real_power:
result[1] *= spec.imag
return result[0] if real_power else result[0] + 1j*result[1]
def power_synthesize_special(self, spaces=None):
spec = self._compute_spec(spaces)
# MR: dummy call - will be removed soon
self.from_random('normal', mean=0., std=1.,
domain=spec.domain, dtype=np.complex)
return spec.real
# ---Properties--- # ---Properties---
@property @property
......
from ...operators.endomorphic_operator import EndomorphicOperator from ...operators.endomorphic_operator import EndomorphicOperator
from ...operators.invertible_operator_mixin import InvertibleOperatorMixin
from ...operators.diagonal_operator import DiagonalOperator from ...operators.diagonal_operator import DiagonalOperator
class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator): class CriticalPowerCurvature(EndomorphicOperator):
"""The curvature of the CriticalPowerEnergy. """The curvature of the CriticalPowerEnergy.
This operator implements the second derivative of the This operator implements the second derivative of the
CriticalPowerEnergy used in some minimization algorithms or CriticalPowerEnergy used in some minimization algorithms or
for error estimates of the powerspectrum. for error estimates of the power spectrum.
Parameters Parameters
...@@ -21,15 +20,14 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator): ...@@ -21,15 +20,14 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, theta, T, inverter, preconditioner=None, **kwargs): def __init__(self, theta, T):
self.theta = DiagonalOperator(theta) self.theta = DiagonalOperator(theta)
self.T = T self.T = T
if preconditioner is None: super(CriticalPowerCurvature, self).__init__()
preconditioner = self.theta.inverse_times
super(CriticalPowerCurvature, self).__init__( @property
inverter=inverter, def preconditioner(self):
preconditioner=preconditioner, return self.theta.inverse_times
**kwargs)
def _times(self, x): def _times(self, x):
return self.T(x) + self.theta(x) return self.T(x) + self.theta(x)
......
from ...energies.energy import Energy from ...energies.energy import Energy
from ...operators.smoothness_operator import SmoothnessOperator from ...operators.smoothness_operator import SmoothnessOperator
from ...operators.inversion_enabler import InversionEnabler
from . import CriticalPowerCurvature from . import CriticalPowerCurvature
from ...memoization import memo from ...memoization import memo
from ...sugar import generate_posterior_sample from ...sugar import generate_posterior_sample, power_analyze
from ... import Field, exp from ... import Field, exp
...@@ -95,8 +96,9 @@ class CriticalPowerEnergy(Energy): ...@@ -95,8 +96,9 @@ class CriticalPowerEnergy(Energy):
@property @property
def curvature(self): def curvature(self):
curvature = CriticalPowerCurvature(theta=self._theta.weight(-1), curvature = InversionEnabler(CriticalPowerCurvature(
T=self.T, inverter=self._inverter) theta=self._theta.weight(-1),
T=self.T), inverter=self._inverter)
return curvature return curvature
# ---Added properties and methods--- # ---Added properties and methods---
...@@ -119,8 +121,9 @@ class CriticalPowerEnergy(Energy): ...@@ -119,8 +121,9 @@ class CriticalPowerEnergy(Energy):
# self.logger.info("Drawing sample %i" % i) # self.logger.info("Drawing sample %i" % i)
posterior_sample = generate_posterior_sample( posterior_sample = generate_posterior_sample(
self.m, self.D) self.m, self.D)
projected_sample = posterior_sample.power_analyze( projected_sample = power_analyze(
binbounds=self.position.domain[0].binbounds) posterior_sample,
binbounds=self.position.domain[0].binbounds)
w += (projected_sample) * self.rho w += (projected_sample) * self.rho
w /= float(self.samples) w /= float(self.samples)
else: else:
......
from ...operators import EndomorphicOperator,\ from ...operators import EndomorphicOperator
InvertibleOperatorMixin
from ...memoization import memo from ...memoization import memo
from ...basic_arithmetics import exp from ...basic_arithmetics import exp
from ...sugar import create_composed_fft_operator from ...sugar import create_composed_fft_operator
class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, class LogNormalWienerFilterCurvature(EndomorphicOperator):
EndomorphicOperator):
"""The curvature of the LogNormalWienerFilterEnergy. """The curvature of the LogNormalWienerFilterEnergy.
This operator implements the second derivative of the This operator implements the second derivative of the
...@@ -24,7 +22,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, ...@@ -24,7 +22,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
The prior signal covariance The prior signal covariance
""" """
def __init__(self, R, N, S, d, position, inverter, fft4exp=None, **kwargs): def __init__(self, R, N, S, d, position, fft4exp=None):
self.R = R self.R = R
self.N = N self.N = N
self.S = S self.S = S
...@@ -37,9 +35,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, ...@@ -37,9 +35,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
else: else:
self._fft = fft4exp self._fft = fft4exp
super(LogNormalWienerFilterCurvature, self).__init__( super(LogNormalWienerFilterCurvature, self).__init__()
inverter=inverter,
**kwargs)
@property @property
def domain(self): def domain(self):
......
...@@ -2,7 +2,7 @@ from ...energies.energy import Energy ...@@ -2,7 +2,7 @@ from ...energies.energy import Energy
from ...memoization import memo from ...memoization import memo
from . import LogNormalWienerFilterCurvature from . import LogNormalWienerFilterCurvature
from ...sugar import create_composed_fft_operator from ...sugar import create_composed_fft_operator
from ...operators.inversion_enabler import InversionEnabler
class LogNormalWienerFilterEnergy(Energy): class LogNormalWienerFilterEnergy(Energy):
"""The Energy for the log-normal Wiener filter. """The Energy for the log-normal Wiener filter.
...@@ -47,20 +47,21 @@ class LogNormalWienerFilterEnergy(Energy): ...@@ -47,20 +47,21 @@ class LogNormalWienerFilterEnergy(Energy):
@memo @memo
def value(self): def value(self):
return 0.5*(self.position.vdot(self._Sp) + return 0.5*(self.position.vdot(self._Sp) +
self.curvature._Rexppd.vdot(self.curvature._NRexppd)) self.curvature.op._Rexppd.vdot(self.curvature.op._NRexppd))
@property @property
@memo @memo
def gradient(self): def gradient(self):
return self._Sp + self.curvature._exppRNRexppd return self._Sp + self.curvature.op._exppRNRexppd
@property @property
@memo @memo
def curvature(self): def curvature(self):
return LogNormalWienerFilterCurvature(R=self.R, N=self.N, S=self.S, return InversionEnabler(
d=self.d, position=self.position, LogNormalWienerFilterCurvature(R=self.R, N=self.N, S=self.S,
fft4exp=self._fft, d=self.d, position=self.position,
inverter=self._inverter) fft4exp=self._fft),
inverter=self._inverter)
@property @property
@memo @memo
......
from ...operators import EndomorphicOperator,\ from ...operators import EndomorphicOperator
InvertibleOperatorMixin
class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator): class WienerFilterCurvature(EndomorphicOperator):
"""The curvature of the WienerFilterEnergy. """The curvature of the WienerFilterEnergy.
This operator implements the second derivative of the This operator implements the second derivative of the
...@@ -20,16 +19,15 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator): ...@@ -20,16 +19,15 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
The prior signal covariance The prior signal covariance
""" """
def __init__(self, R, N, S, inverter, preconditioner=None, **kwargs): def __init__(self, R, N, S):
self.R = R self.R = R
self.N = N self.N = N
self.S = S self.S = S
if preconditioner is None: super(WienerFilterCurvature, self).__init__()
preconditioner = self.S.times
super(WienerFilterCurvature, self).__init__( @property
inverter=inverter, def preconditioner(self):
preconditioner=preconditioner, return self.S.times
**kwargs)
@property @property
def domain(self): def domain(self):
......
from ...energies.energy import Energy from ...energies.energy import Energy
from ...memoization import memo from ...memoization import memo
from ...operators.inversion_enabler import InversionEnabler
from . import WienerFilterCurvature from . import WienerFilterCurvature
...@@ -49,8 +50,9 @@ class WienerFilterEnergy(Energy): ...@@ -49,8 +50,9 @@ class WienerFilterEnergy(Energy):
@property @property
@memo @memo
def curvature(self): def curvature(self):
return WienerFilterCurvature(R=self.R, N=self.N, S=self.S, return InversionEnabler(WienerFilterCurvature(R=self.R, N=self.N,
inverter=self._inverter) S=self.S),
inverter=self._inverter)
@property @property
@memo @memo
......
...@@ -11,7 +11,7 @@ from .direct_smoothing_operator import DirectSmoothingOperator ...@@ -11,7 +11,7 @@ from .direct_smoothing_operator import DirectSmoothingOperator
from .fft_operator import FFTOperator from .fft_operator import FFTOperator
from .invertible_operator_mixin import InvertibleOperatorMixin from .inversion_enabler import InversionEnabler
from .composed_operator import ComposedOperator from .composed_operator import ComposedOperator
......
...@@ -16,60 +16,82 @@ ...@@ -16,60 +16,82 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import object
from ..energies import QuadraticEnergy from ..energies import QuadraticEnergy
from ..field import Field from ..field import Field
from .linear_operator import LinearOperator
class InvertibleOperatorMixin(object): class InversionEnabler(LinearOperator):
""" Mixin class to invert implicit defined operators.
This class provides the functionality necessary to invert the application def __init__(self, op, inverter, preconditioner=None):
of a given implicitly defined operator on a field. Inheriting self._op = op
functionality from this class provides the derived class with the self._inverter = inverter
operations inverse to the defined operator applications if preconditioner is None and hasattr(op, "preconditioner"):
(e.g. .inverse_times if .times is defined and self._preconditioner = op.preconditioner
.adjoint_times if .adjoint_inverse_times is defined) else:
self._preconditioner = preconditioner
super(InversionEnabler, self).__init__()
Parameters @property
---------- def domain(self):
inverter : Inverter return self._op.domain
An instance of an Inverter class.
"""
def __init__(self, inverter, preconditioner=None, *args, **kwargs): @property
self.__inverter = inverter def target(self):
self._preconditioner = preconditioner return self._op.target
super(InvertibleOperatorMixin, self).__init__(*args, **kwargs)
@property
def unitary(self):
return self._op.unitary
@property
def op(self):
return self._op
def _times(self, x): def _times(self, x):
x0 = Field.zeros(self.target, dtype=x.dtype) try:
(result, convergence) = self.__inverter(QuadraticEnergy( res = self._op._times(x)
A=self.inverse_times, except NotImplementedError:
x0 = Field.zeros(self.target, dtype=x.dtype)
(result, convergence) = self._inverter(QuadraticEnergy(
A=self._op.inverse_times,