Commit dbe83a62 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

simplify lognormal curvature

parent 7c749218
Pipeline #23452 passed with stage
in 4 minutes and 36 seconds
from ..operators import EndomorphicOperator, InversionEnabler from ..operators import InversionEnabler
from ..utilities import memo
from ..field import exp def LogNormalWienerFilterCurvature(R, N, S, fft, expp_sspace, inverter):
part1 = S.inverse
part3 = (fft.adjoint * expp_sspace * fft *
class LogNormalWienerFilterCurvature(EndomorphicOperator): R. adjoint * N.inverse * R *
"""The curvature of the LogNormalWienerFilterEnergy. fft.adjoint * expp_sspace * fft)
return InversionEnabler(part1 + part3, inverter)
This operator implements the second derivative of the
LogNormalWienerFilterEnergy used in some minimization algorithms or for
error estimates of the posterior maps. It is the inverse of the propagator
operator.
Parameters
----------
R: LinearOperator,
The response operator of the Wiener filter measurement.
N: EndomorphicOperator
The noise covariance.
S: DiagonalOperator,
The prior signal covariance
"""
class _Helper(EndomorphicOperator):
def __init__(self, R, N, S, position, fft4exp):
super(LogNormalWienerFilterCurvature._Helper, self).__init__()
self.R = R
self.N = N
self.S = S
self.position = position
self._fft = fft4exp
@property
def domain(self):
return self.S.domain
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
part1 = self.S.inverse_times(x)
part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x))
part3 = self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(
self.N.inverse_times(self.R(part3)))))
return part1 + part3
@property
@memo
def _expp_sspace(self):
return exp(self._fft(self.position))
def __init__(self, R, N, S, position, fft4exp, inverter):
super(LogNormalWienerFilterCurvature, self).__init__()
self._op = self._Helper(R, N, S, position, fft4exp)
self._op = InversionEnabler(self._op, inverter)
@property
def domain(self):
return self._op.domain
@property
def capability(self):
return self._op.capability
@property
def _expp_sspace(self):
return self._op._op._expp_sspace
def apply(self, x, mode):
return self._op.apply(x, mode)
...@@ -2,6 +2,7 @@ from ..minimization.energy import Energy ...@@ -2,6 +2,7 @@ from ..minimization.energy import Energy
from ..utilities import memo from ..utilities import memo
from .log_normal_wiener_filter_curvature import LogNormalWienerFilterCurvature from .log_normal_wiener_filter_curvature import LogNormalWienerFilterCurvature
from ..sugar import create_composed_fft_operator from ..sugar import create_composed_fft_operator
from ..field import exp
class LogNormalWienerFilterEnergy(Energy): class LogNormalWienerFilterEnergy(Energy):
...@@ -58,18 +59,23 @@ class LogNormalWienerFilterEnergy(Energy): ...@@ -58,18 +59,23 @@ class LogNormalWienerFilterEnergy(Energy):
@memo @memo
def curvature(self): def curvature(self):
return LogNormalWienerFilterCurvature( return LogNormalWienerFilterCurvature(
R=self.R, N=self.N, S=self.S, position=self.position, R=self.R, N=self.N, S=self.S, fft=self._fft,
fft4exp=self._fft, inverter=self._inverter) expp_sspace = self._expp_sspace, inverter=self._inverter)
@property @property
@memo @memo
def _Sp(self): def _Sp(self):
return self.S.inverse_times(self.position) return self.S.inverse_times(self.position)
@property
@memo
def _expp_sspace(self):
return exp(self._fft(self.position))
@property @property
@memo @memo
def _Rexppd(self): def _Rexppd(self):
expp = self._fft.adjoint_times(self.curvature._expp_sspace) expp = self._fft.adjoint_times(self._expp_sspace)
return self.R(expp) - self.d return self.R(expp) - self.d
@property @property
...@@ -81,5 +87,5 @@ class LogNormalWienerFilterEnergy(Energy): ...@@ -81,5 +87,5 @@ class LogNormalWienerFilterEnergy(Energy):
@memo @memo
def _exppRNRexppd(self): def _exppRNRexppd(self):
return self._fft.adjoint_times( return self._fft.adjoint_times(
self.curvature._expp_sspace * self._expp_sspace *
self._fft(self.R.adjoint_times(self._NRexppd))) self._fft(self.R.adjoint_times(self._NRexppd)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment