Commit dbe83a62 by Martin Reinecke

### simplify lognormal curvature

parent 7c749218
Pipeline #23452 passed with stage
in 4 minutes and 36 seconds
 from ..operators import EndomorphicOperator, InversionEnabler from ..utilities import memo from ..field import exp class LogNormalWienerFilterCurvature(EndomorphicOperator): """The curvature of the LogNormalWienerFilterEnergy. This operator implements the second derivative of the LogNormalWienerFilterEnergy used in some minimization algorithms or for error estimates of the posterior maps. It is the inverse of the propagator operator. Parameters ---------- R: LinearOperator, The response operator of the Wiener filter measurement. N: EndomorphicOperator The noise covariance. S: DiagonalOperator, The prior signal covariance """ class _Helper(EndomorphicOperator): def __init__(self, R, N, S, position, fft4exp): super(LogNormalWienerFilterCurvature._Helper, self).__init__() self.R = R self.N = N self.S = S self.position = position self._fft = fft4exp @property def domain(self): return self.S.domain @property def capability(self): return self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): self._check_input(x, mode) part1 = self.S.inverse_times(x) part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x)) part3 = self._fft.adjoint_times( self._expp_sspace * self._fft(self.R.adjoint_times( self.N.inverse_times(self.R(part3))))) return part1 + part3 @property @memo def _expp_sspace(self): return exp(self._fft(self.position)) def __init__(self, R, N, S, position, fft4exp, inverter): super(LogNormalWienerFilterCurvature, self).__init__() self._op = self._Helper(R, N, S, position, fft4exp) self._op = InversionEnabler(self._op, inverter) @property def domain(self): return self._op.domain @property def capability(self): return self._op.capability @property def _expp_sspace(self): return self._op._op._expp_sspace def apply(self, x, mode): return self._op.apply(x, mode) from ..operators import InversionEnabler def LogNormalWienerFilterCurvature(R, N, S, fft, expp_sspace, inverter): part1 = S.inverse part3 = (fft.adjoint * expp_sspace * fft * R. adjoint * N.inverse * R * fft.adjoint * expp_sspace * fft) return InversionEnabler(part1 + part3, inverter)
 ... ... @@ -2,6 +2,7 @@ from ..minimization.energy import Energy from ..utilities import memo from .log_normal_wiener_filter_curvature import LogNormalWienerFilterCurvature from ..sugar import create_composed_fft_operator from ..field import exp class LogNormalWienerFilterEnergy(Energy): ... ... @@ -58,18 +59,23 @@ class LogNormalWienerFilterEnergy(Energy): @memo def curvature(self): return LogNormalWienerFilterCurvature( R=self.R, N=self.N, S=self.S, position=self.position, fft4exp=self._fft, inverter=self._inverter) R=self.R, N=self.N, S=self.S, fft=self._fft, expp_sspace = self._expp_sspace, inverter=self._inverter) @property @memo def _Sp(self): return self.S.inverse_times(self.position) @property @memo def _expp_sspace(self): return exp(self._fft(self.position)) @property @memo def _Rexppd(self): expp = self._fft.adjoint_times(self.curvature._expp_sspace) expp = self._fft.adjoint_times(self._expp_sspace) return self.R(expp) - self.d @property ... ... @@ -81,5 +87,5 @@ class LogNormalWienerFilterEnergy(Energy): @memo def _exppRNRexppd(self): return self._fft.adjoint_times( self.curvature._expp_sspace * self._expp_sspace * self._fft(self.R.adjoint_times(self._NRexppd)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!