Commit dbe83a62 authored by Martin Reinecke's avatar Martin Reinecke

simplify lognormal curvature

parent 7c749218
Pipeline #23452 passed with stage
in 4 minutes and 36 seconds
from ..operators import EndomorphicOperator, InversionEnabler
from ..utilities import memo
from ..field import exp
class LogNormalWienerFilterCurvature(EndomorphicOperator):
"""The curvature of the LogNormalWienerFilterEnergy.
This operator implements the second derivative of the
LogNormalWienerFilterEnergy used in some minimization algorithms or for
error estimates of the posterior maps. It is the inverse of the propagator
R: LinearOperator,
The response operator of the Wiener filter measurement.
N: EndomorphicOperator
The noise covariance.
S: DiagonalOperator,
The prior signal covariance
class _Helper(EndomorphicOperator):
def __init__(self, R, N, S, position, fft4exp):
super(LogNormalWienerFilterCurvature._Helper, self).__init__()
self.R = R
self.N = N
self.S = S
self.position = position
self._fft = fft4exp
def domain(self):
return self.S.domain
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
part1 = self.S.inverse_times(x)
part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x))
part3 = self._fft.adjoint_times(
self._expp_sspace *
return part1 + part3
def _expp_sspace(self):
return exp(self._fft(self.position))
def __init__(self, R, N, S, position, fft4exp, inverter):
super(LogNormalWienerFilterCurvature, self).__init__()
self._op = self._Helper(R, N, S, position, fft4exp)
self._op = InversionEnabler(self._op, inverter)
def domain(self):
return self._op.domain
def capability(self):
return self._op.capability
def _expp_sspace(self):
return self._op._op._expp_sspace
def apply(self, x, mode):
return self._op.apply(x, mode)
from ..operators import InversionEnabler
def LogNormalWienerFilterCurvature(R, N, S, fft, expp_sspace, inverter):
part1 = S.inverse
part3 = (fft.adjoint * expp_sspace * fft *
R. adjoint * N.inverse * R *
fft.adjoint * expp_sspace * fft)
return InversionEnabler(part1 + part3, inverter)
......@@ -2,6 +2,7 @@ from import Energy
from ..utilities import memo
from .log_normal_wiener_filter_curvature import LogNormalWienerFilterCurvature
from ..sugar import create_composed_fft_operator
from ..field import exp
class LogNormalWienerFilterEnergy(Energy):
......@@ -58,18 +59,23 @@ class LogNormalWienerFilterEnergy(Energy):
def curvature(self):
return LogNormalWienerFilterCurvature(
R=self.R, N=self.N, S=self.S, position=self.position,
fft4exp=self._fft, inverter=self._inverter)
R=self.R, N=self.N, S=self.S, fft=self._fft,
expp_sspace = self._expp_sspace, inverter=self._inverter)
def _Sp(self):
return self.S.inverse_times(self.position)
def _expp_sspace(self):
return exp(self._fft(self.position))
def _Rexppd(self):
expp = self._fft.adjoint_times(self.curvature._expp_sspace)
expp = self._fft.adjoint_times(self._expp_sspace)
return self.R(expp) - self.d
......@@ -81,5 +87,5 @@ class LogNormalWienerFilterEnergy(Energy):
def _exppRNRexppd(self):
return self._fft.adjoint_times(
self.curvature._expp_sspace *
self._expp_sspace *
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment