Commit 694ca42a authored by Reimar H Leike's avatar Reimar H Leike

fixed tests, now WienerFilterCUrvature automatically enables sampling

parent a4b29dd0
Pipeline #30711 passed with stages
in 1 minute and 26 seconds
......@@ -24,7 +24,8 @@ from ..sugar import makeOp
class NonlinearWienerFilterEnergy(Energy):
def __init__(self, position, d, Instrument, nonlinearity, ht, power, N, S,
inverter=None):
inverter=None,
sampling_inverter=None):
super(NonlinearWienerFilterEnergy, self).__init__(position=position)
self.d = d.lock()
self.Instrument = Instrument
......@@ -37,6 +38,9 @@ class NonlinearWienerFilterEnergy(Energy):
self.N = N
self.S = S
self.inverter = inverter
if sampling_inverter==None:
sampling_inverter = inverter
self.sampling_inverter = sampling_inverter
t1 = S.inverse_times(position)
t2 = N.inverse_times(residual)
self._value = 0.5 * (position.vdot(t1) + residual.vdot(t2)).real
......@@ -60,4 +64,4 @@ class NonlinearWienerFilterEnergy(Energy):
@property
@memo
def curvature(self):
return WienerFilterCurvature(self.R, self.N, self.S, self.inverter)
return WienerFilterCurvature(self.R, self.N, self.S, self.inverter, self.sampling_inverter)
......@@ -18,9 +18,10 @@
from ..operators.sandwich_operator import SandwichOperator
from ..operators.inversion_enabler import InversionEnabler
from ..operators.sampling_enabler import SamplingEnabler
def WienerFilterCurvature(R, N, S, inverter):
def WienerFilterCurvature(R, N, S, inverter, sampling_inverter=None):
"""The curvature of the WienerFilterEnergy.
This operator implements the second derivative of the
......@@ -38,6 +39,15 @@ def WienerFilterCurvature(R, N, S, inverter):
The prior signal covariance
inverter : Minimizer
The minimizer to use during numerical inversion
sampling_inverter : Minimizer
The minimizer to use during numerical sampling
if None, it is not possible to draw inverse samples
default: None
"""
op = SandwichOperator.make(R, N.inverse) + S.inverse
M = SandwichOperator.make(R, N.inverse)
if sampling_inverter != None:
op = SamplingEnabler(M, S.inverse, sampling_inverter)
else:
op = M + S.inverse
return InversionEnabler(op, inverter, S.inverse)
......@@ -20,7 +20,7 @@ from ..minimization.quadratic_energy import QuadraticEnergy
from .wiener_filter_curvature import WienerFilterCurvature
def WienerFilterEnergy(position, d, R, N, S, inverter=None):
def WienerFilterEnergy(position, d, R, N, S, inverter=None, sampling_inverter=None):
"""The Energy for the Wiener filter.
It covers the case of linear measurement with
......@@ -42,7 +42,11 @@ def WienerFilterEnergy(position, d, R, N, S, inverter=None):
inverter : Minimizer, optional
the minimization strategy to use for operator inversion
If None, the energy object will not support curvature computation.
sampling_inverter : Minimizer, optional
The minimizer to use during numerical sampling
if None, it is not possible to draw inverse samples
default: None
"""
op = WienerFilterCurvature(R, N, S, inverter)
op = WienerFilterCurvature(R, N, S, inverter, sampling_inverter)
vec = R.adjoint_times(N.inverse_times(d))
return QuadraticEnergy(position, op, vec)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment