Commit dbe83a62 authored by Martin Reinecke's avatar Martin Reinecke

simplify lognormal curvature

parent 7c749218
Pipeline #23452 passed with stage
in 4 minutes and 36 seconds
from ..operators import EndomorphicOperator, InversionEnabler
from ..utilities import memo
from ..field import exp
class LogNormalWienerFilterCurvature(EndomorphicOperator):
"""The curvature of the LogNormalWienerFilterEnergy.
This operator implements the second derivative of the
LogNormalWienerFilterEnergy used in some minimization algorithms or for
error estimates of the posterior maps. It is the inverse of the propagator
operator.
Parameters
----------
R: LinearOperator,
The response operator of the Wiener filter measurement.
N: EndomorphicOperator
The noise covariance.
S: DiagonalOperator,
The prior signal covariance
"""
class _Helper(EndomorphicOperator):
def __init__(self, R, N, S, position, fft4exp):
super(LogNormalWienerFilterCurvature._Helper, self).__init__()
self.R = R
self.N = N
self.S = S
self.position = position
self._fft = fft4exp
@property
def domain(self):
return self.S.domain
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
part1 = self.S.inverse_times(x)
part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x))
part3 = self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(
self.N.inverse_times(self.R(part3)))))
return part1 + part3
@property
@memo
def _expp_sspace(self):
return exp(self._fft(self.position))
def __init__(self, R, N, S, position, fft4exp, inverter):
super(LogNormalWienerFilterCurvature, self).__init__()
self._op = self._Helper(R, N, S, position, fft4exp)
self._op = InversionEnabler(self._op, inverter)
@property
def domain(self):
return self._op.domain
@property
def capability(self):
return self._op.capability
@property
def _expp_sspace(self):
return self._op._op._expp_sspace
def apply(self, x, mode):
return self._op.apply(x, mode)
from ..operators import InversionEnabler
def LogNormalWienerFilterCurvature(R, N, S, fft, expp_sspace, inverter):
part1 = S.inverse
part3 = (fft.adjoint * expp_sspace * fft *
R. adjoint * N.inverse * R *
fft.adjoint * expp_sspace * fft)
return InversionEnabler(part1 + part3, inverter)
......@@ -2,6 +2,7 @@ from ..minimization.energy import Energy
from ..utilities import memo
from .log_normal_wiener_filter_curvature import LogNormalWienerFilterCurvature
from ..sugar import create_composed_fft_operator
from ..field import exp
class LogNormalWienerFilterEnergy(Energy):
......@@ -58,18 +59,23 @@ class LogNormalWienerFilterEnergy(Energy):
@memo
def curvature(self):
return LogNormalWienerFilterCurvature(
R=self.R, N=self.N, S=self.S, position=self.position,
fft4exp=self._fft, inverter=self._inverter)
R=self.R, N=self.N, S=self.S, fft=self._fft,
expp_sspace = self._expp_sspace, inverter=self._inverter)
@property
@memo
def _Sp(self):
return self.S.inverse_times(self.position)
@property
@memo
def _expp_sspace(self):
return exp(self._fft(self.position))
@property
@memo
def _Rexppd(self):
expp = self._fft.adjoint_times(self.curvature._expp_sspace)
expp = self._fft.adjoint_times(self._expp_sspace)
return self.R(expp) - self.d
@property
......@@ -81,5 +87,5 @@ class LogNormalWienerFilterEnergy(Energy):
@memo
def _exppRNRexppd(self):
return self._fft.adjoint_times(
self.curvature._expp_sspace *
self._expp_sspace *
self._fft(self.R.adjoint_times(self._NRexppd)))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment