Commit 662b29d7 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

rework operators

parent c83661d0
Pipeline #23250 passed with stage
in 4 minutes and 49 seconds
......@@ -2,36 +2,6 @@ import numpy as np
import nifty2go as ift
# Note that the constructor of PropagatorOperator takes as arguments the
# response R and noise covariance N operating on signal space and signal
# covariance operating on harmonic space.
class PropagatorOperator(ift.InversionEnabler, ift.EndomorphicOperator):
def __init__(self, R, N, Sh, inverter):
ift.InversionEnabler.__init__(self, inverter)
ift.EndomorphicOperator.__init__(self)
self.R = R
self.N = N
self.Sh = Sh
self.fft = ift.FFTOperator(R.domain, target=Sh.domain[0])
def _inverse_times(self, x):
return self.R.adjoint_times(self.N.inverse_times(self.R(x))) \
+ self.fft.adjoint_times(self.Sh.inverse_times(self.fft(x)))
@property
def domain(self):
return self.R.domain
@property
def unitary(self):
return False
@property
def self_adjoint(self):
return True
if __name__ == "__main__":
# Set up physical constants
# Total length of interval or volume the field lives on, e.g. in meters
......@@ -85,6 +55,6 @@ if __name__ == "__main__":
j = R.adjoint_times(N.inverse_times(d))
IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=IC)
D = PropagatorOperator(Sh=Sh, N=N, R=R, inverter=inverter)
D = (R.adjoint*N.inverse*R + fft.adjoint*Sh.inverse*fft).inverse
D = ift.InversionEnabler(D, inverter)
m = D(j)
......@@ -71,7 +71,7 @@ if __name__ == "__main__":
n_samples = 50
for i in range(n_samples):
sample = fft(D.generate_posterior_sample(m))
sample = fft(D.generate_posterior_sample() + m)
sample_variance += sample**2
sample_mean += sample
sample_mean /= n_samples
......
from ..operators import EndomorphicOperator, InversionEnabler, DiagonalOperator
class CriticalPowerCurvature(InversionEnabler, EndomorphicOperator):
class CriticalPowerCurvature(EndomorphicOperator):
"""The curvature of the CriticalPowerEnergy.
This operator implements the second derivative of the
......@@ -17,22 +17,17 @@ class CriticalPowerCurvature(InversionEnabler, EndomorphicOperator):
"""
def __init__(self, theta, T, inverter):
EndomorphicOperator.__init__(self)
self._theta = DiagonalOperator(theta)
InversionEnabler.__init__(self, inverter, self._theta.inverse_times)
self._T = T
def _times(self, x):
return self._T(x) + self._theta(x)
super(CriticalPowerCurvature, self).__init__()
theta = DiagonalOperator(theta)
self._op = InversionEnabler(T+theta, inverter, theta.inverse_times)
@property
def domain(self):
return self._theta.domain
def capability(self):
return self._op.capability
@property
def self_adjoint(self):
return True
def apply(self, x, mode):
return self._op.apply(x, mode)
@property
def unitary(self):
return False
def domain(self):
return self._op.domain
......@@ -70,7 +70,7 @@ class CriticalPowerEnergy(Energy):
if self.D is not None:
w = Field.zeros(self.position.domain, dtype=self.m.dtype)
for i in range(self.samples):
sample = self.D.generate_posterior_sample(self.m)
sample = self.D.generate_posterior_sample() + self.m
w += P(abs(sample)**2)
w *= 1./self.samples
......
......@@ -3,7 +3,7 @@ from ..utilities import memo
from ..field import exp
class LogNormalWienerFilterCurvature(InversionEnabler, EndomorphicOperator):
class LogNormalWienerFilterCurvature(EndomorphicOperator):
"""The curvature of the LogNormalWienerFilterEnergy.
This operator implements the second derivative of the
......@@ -21,33 +21,54 @@ class LogNormalWienerFilterCurvature(InversionEnabler, EndomorphicOperator):
The prior signal covariance
"""
class _Helper(EndomorphicOperator):
def __init__(self, R, N, S, position, fft4exp):
super(LogNormalWienerFilterCurvature._Helper, self).__init__()
self.R = R
self.N = N
self.S = S
self.position = position
self._fft = fft4exp
@property
def domain(self):
return self.S.domain
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
part1 = self.S.inverse_times(x)
part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x))
part3 = self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(
self.N.inverse_times(self.R(part3)))))
return part1 + part3
@property
@memo
def _expp_sspace(self):
return exp(self._fft(self.position))
def __init__(self, R, N, S, position, fft4exp, inverter):
InversionEnabler.__init__(self, inverter)
EndomorphicOperator.__init__(self)
self.R = R
self.N = N
self.S = S
self.position = position
self._fft = fft4exp
super(LogNormalWienerFilterCurvature, self).__init__()
self._op = self._Helper(R, N, S, position, fft4exp)
self._op = InversionEnabler(self._op, inverter)
@property
def domain(self):
return self.S.domain
return self._op.domain
@property
def self_adjoint(self):
return True
def _times(self, x):
part1 = self.S.inverse_times(x)
part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x))
part3 = self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(
self.N.inverse_times(self.R(part3)))))
return part1 + part3
def capability(self):
return self._op.capability
@property
@memo
def _expp_sspace(self):
return exp(self._fft(self.position))
return self._op._op._expp_sspace
def apply(self, x, mode):
return self._op.apply(x, mode)
......@@ -27,7 +27,7 @@ class NoiseEnergy(Energy):
if samples is None or samples == 0:
sample_list = [m]
else:
sample_list = [D.generate_posterior_sample(m)
sample_list = [D.generate_posterior_sample() + m
for _ in range(samples)]
self.sample_list = sample_list
self.inverter = inverter
......
......@@ -2,42 +2,60 @@ from ..operators import EndomorphicOperator, InversionEnabler
from .response_operators import LinearizedPowerResponse
class NonlinearPowerCurvature(InversionEnabler, EndomorphicOperator):
class NonlinearPowerCurvature(EndomorphicOperator):
class _Helper(EndomorphicOperator):
def __init__(self, position, FFT, Instrument, nonlinearity,
Projection, N, T, sample_list):
super(NonlinearPowerCurvature._Helper, self).__init__()
self.N = N
self.FFT = FFT
self.Instrument = Instrument
self.T = T
self.sample_list = sample_list
self.position = position
self.Projection = Projection
self.nonlinearity = nonlinearity
@property
def domain(self):
return self.position.domain
@property
def capability(self):
return self.TIMES
def apply(self, x, mode):
self._check_input(x, mode)
result = None
for sample in self.sample_list:
if result is None:
result = self._sample_times(x, sample)
else:
result += self._sample_times(x, sample)
result *= 1./len(self.sample_list)
return result + self.T(x)
def _sample_times(self, x, sample):
LinearizedResponse = LinearizedPowerResponse(
self.Instrument, self.nonlinearity, self.FFT, self.Projection,
self.position, sample)
return LinearizedResponse.adjoint_times(
self.N.inverse_times(LinearizedResponse(x)))
def __init__(self, position, FFT, Instrument, nonlinearity,
Projection, N, T, sample_list, inverter):
InversionEnabler.__init__(self, inverter)
EndomorphicOperator.__init__(self)
self.N = N
self.FFT = FFT
self.Instrument = Instrument
self.T = T
self.sample_list = sample_list
self.position = position
self.Projection = Projection
self.nonlinearity = nonlinearity
super(NonlinearPowerCurvature, self).__init__()
self._op = self._Helper(position, FFT, Instrument, nonlinearity,
Projection, N, T, sample_list)
self._op = InversionEnabler(self._op, inverter)
@property
def domain(self):
return self.position.domain
return self._op.domain
@property
def self_adjoint(self):
return True
def _times(self, x):
result = None
for sample in self.sample_list:
if result is None:
result = self._sample_times(x, sample)
else:
result += self._sample_times(x, sample)
result *= 1./len(self.sample_list)
return result + self.T(x)
def _sample_times(self, x, sample):
LinearizedResponse = LinearizedPowerResponse(
self.Instrument, self.nonlinearity, self.FFT, self.Projection,
self.position, sample)
return LinearizedResponse.adjoint_times(
self.N.inverse_times(LinearizedResponse(x)))
def capability(self):
return self._op.capability
def apply(self, x, mode):
return self._op.apply(x, mode)
......@@ -53,7 +53,7 @@ class NonlinearPowerEnergy(Energy):
if samples is None or samples == 0:
sample_list = [m]
else:
sample_list = [D.generate_posterior_sample(m)
sample_list = [D.generate_posterior_sample() + m
for _ in range(samples)]
self.sample_list = sample_list
self.inverter = inverter
......
......@@ -31,6 +31,14 @@ class LinearizedSignalResponse(LinearOperator):
def target(self):
return self.Instrument.target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
return self._times(x) if mode & self.TIMES else self._adjoint_times(x)
class LinearizedPowerResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m):
......@@ -70,3 +78,11 @@ class LinearizedPowerResponse(LinearOperator):
@property
def target(self):
return self.Instrument.target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
return self._times(x) if mode & self.TIMES else self._adjoint_times(x)
......@@ -3,7 +3,7 @@ from ..field import Field, sqrt
from ..sugar import power_analyze, power_synthesize
class WienerFilterCurvature(InversionEnabler, EndomorphicOperator):
class WienerFilterCurvature(EndomorphicOperator):
"""The curvature of the WienerFilterEnergy.
This operator implements the second derivative of the
......@@ -22,26 +22,25 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator):
"""
def __init__(self, R, N, S, inverter):
EndomorphicOperator.__init__(self)
InversionEnabler.__init__(self, inverter, S.times)
super(WienerFilterCurvature, self).__init__()
self.R = R
self.N = N
self.S = S
op = R.adjoint*N.inverse*R + S.inverse
self._op = InversionEnabler(op, inverter, S.times)
@property
def domain(self):
return self.S.domain
@property
def self_adjoint(self):
return True
def capability(self):
return self._op.capability
def _times(self, x):
res = self.R.adjoint_times(self.N.inverse_times(self.R(x)))
res += self.S.inverse_times(x)
return res
def apply(self, x, mode):
return self._op.apply(x, mode)
def generate_posterior_sample(self, mean):
def generate_posterior_sample(self):
""" Generates a posterior sample from a Gaussian distribution with
given mean and covariance.
......@@ -49,11 +48,6 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator):
reconstruction of a mock signal in order to obtain residuals of the
right correlation which are added to the given mean.
Parameters
----------
mean : Field
the mean of the posterior Gaussian distribution
Returns
-------
sample : Field
......@@ -74,5 +68,5 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator):
mock_j = self.R.adjoint_times(self.N.inverse_times(mock_data))
mock_m = self.inverse_times(mock_j)
sample = mock_signal - mock_m + mean
sample = mock_signal - mock_m
return sample
......@@ -27,7 +27,7 @@ class WienerFilterEnergy(Energy):
self.R = R
self.N = N
self.S = S
self._curvature = WienerFilterCurvature(R, N, S, inverter=inverter)
self._curvature = WienerFilterCurvature(R, N, S, inverter)
self._inverter = inverter
if _j is None:
_j = self.R.adjoint_times(self.N.inverse_times(d))
......
......@@ -12,3 +12,7 @@ from .laplace_operator import LaplaceOperator
from .smoothness_operator import SmoothnessOperator
from .power_projection_operator import PowerProjectionOperator
from .dof_projection_operator import DOFProjectionOperator
from .chain_operator import ChainOperator
from .sum_operator import SumOperator
from .inverse_operator import InverseOperator
from .adjoint_operator import AdjointOperator
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator
class AdjointOperator(LinearOperator):
def __init__(self, op):
super(AdjointOperator, self).__init__()
self._op = op
@property
def domain(self):
return self._op.target
@property
def target(self):
return self._op.domain
@property
def capability(self):
return self._adjointCapability[self._op.capability]
def apply(self, x, mode):
return self._op.apply(x, self._adjointMode[mode])
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator
class ChainOperator(LinearOperator):
def __init__(self, op1, op2):
super(ChainOperator, self).__init__()
if op2.target != op1.domain:
raise ValueError("domain mismatch")
self._op1 = op1
self._op2 = op2
@property
def domain(self):
return self._op2.domain
@property
def target(self):
return self._op1.target
@property
def capability(self):
return self._op1.capability & self._op2.capability
def apply(self, x, mode):
self._check_mode(mode)
if mode == self.TIMES or mode == self.ADJOINT_INVERSE_TIMES:
return self._op1.apply(self._op2.apply(x, mode), mode)
return self._op2.apply(self._op1.apply(x, mode), mode)
......@@ -16,7 +16,6 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import range
from .linear_operator import LinearOperator
......@@ -30,7 +29,6 @@ class ComposedOperator(LinearOperator):
operators : tuple of NIFTy Operators
The tuple of LinearOperators.
Attributes
----------
domain : DomainTuple
......@@ -44,6 +42,7 @@ class ComposedOperator(LinearOperator):
self._operator_store = ()
old_op = None
self._capability = operators[0].capability
for op in operators:
if not isinstance(op, LinearOperator):
raise TypeError("The elements of the operator list must be"
......@@ -51,7 +50,10 @@ class ComposedOperator(LinearOperator):
if old_op is not None and op.domain != old_op.target:
raise ValueError("incompatible domains")
self._operator_store += (op,)
self._capability &= op.capability
old_op = op
if self._capability == 0:
raise ValueError("composed operator does not support any mode")
@property
def domain(self):
......@@ -61,24 +63,16 @@ class ComposedOperator(LinearOperator):
def target(self):
return self._operator_store[-1].target
def _times(self, x):
return self._times_helper(x, func='times')
def _adjoint_times(self, x):
return self._inverse_times_helper(x, func='adjoint_times')
def _inverse_times(self, x):
return self._inverse_times_helper(x, func='inverse_times')
def _adjoint_inverse_times(self, x):
return self._times_helper(x, func='adjoint_inverse_times')
def _times_helper(self, x, func):
for op in self._operator_store:
x = getattr(op, func)(x)
return x
def _inverse_times_helper(self, x, func):
for op in reversed(self._operator_store):
x = getattr(op, func)(x)
@property
def capability(self):
return self._capability
def apply(self, x, mode):
self._check_mode(mode)
if mode == self.TIMES or mode == self.ADJOINT_INVERSE_TIMES:
for op in self._operator_store:
x = op.apply(x, mode)
else:
for op in reversed(self._operator_store):
x = op.apply(x, mode)
return x
......@@ -48,13 +48,6 @@ class DiagonalOperator(EndomorphicOperator):
----------
domain : DomainTuple
The domain on which the Operator's input Field lives.
target : DomainTuple
The domain in which the outcome of the operator lives. As the Operator
is endomorphic this is the same as its domain.
unitary : boolean
Indicates whether the Operator is unitary or not.
self_adjoint : boolean
Indicates whether the operator is self-adjoint or not.
NOTE: the fields given to __init__ and returned from .diagonal() are
considered to be non-bare, i.e. during operator application, no additional
......@@ -114,20 +107,23 @@ class DiagonalOperator(EndomorphicOperator):
else:
self._ldiag = dobj.local_data(self._diagonal.val)
self._self_adjoint = None
self._unitary = None
def apply(self, x, mode):
self._check_input(x, mode)
def _times(self, x):
return Field(x.domain, val=x