Commit b3b3a72a authored by Philipp Arras's avatar Philipp Arras

Do not pass inverters around anymore where it is not necessary

parent 0ac48f64
Pipeline #31375 failed with stages
in 4 minutes and 14 seconds
...@@ -169,10 +169,9 @@ ...@@ -169,10 +169,9 @@
"def Curvature(R, N, Sh):\n", "def Curvature(R, N, Sh):\n",
" IC = ift.GradientNormController(iteration_limit=50000,\n", " IC = ift.GradientNormController(iteration_limit=50000,\n",
" tol_abs_gradnorm=0.1)\n", " tol_abs_gradnorm=0.1)\n",
" inverter = ift.ConjugateGradient(controller=IC)\n",
" # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n", " # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n",
" # helper methods.\n", " # helper methods.\n",
" return ift.library.WienerFilterCurvature(R,N,Sh,inverter, sampling_inverter=inverter)" " return ift.library.WienerFilterCurvature(R,N,Sh,iteration_controller=IC)"
] ]
}, },
{ {
......
...@@ -85,15 +85,13 @@ if __name__ == "__main__": ...@@ -85,15 +85,13 @@ if __name__ == "__main__":
LS = ift.LineSearchStrongWolfe(c2=0.02) LS = ift.LineSearchStrongWolfe(c2=0.02)
minimizer = ift.RelaxedNewton(IC1, line_searcher=LS) minimizer = ift.RelaxedNewton(IC1, line_searcher=LS)
ICI = ift.GradientNormController(iteration_limit=500, IC = ift.GradientNormController(iteration_limit=500,
tol_abs_gradnorm=1e-3) tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=ICI)
for i in range(20): for i in range(20):
power0 = Distributor(ift.exp(0.5*t0)) power0 = Distributor(ift.exp(0.5*t0))
map0_energy = ift.library.NonlinearWienerFilterEnergy( map0_energy = ift.library.NonlinearWienerFilterEnergy(
m0, d, MeasurementOperator, nonlinearity, HT, power0, N, S, m0, d, MeasurementOperator, nonlinearity, HT, power0, N, S, IC)
inverter=inverter)
# Minimization with chosen minimizer # Minimization with chosen minimizer
map0_energy, convergence = minimizer(map0_energy) map0_energy, convergence = minimizer(map0_energy)
...@@ -106,7 +104,8 @@ if __name__ == "__main__": ...@@ -106,7 +104,8 @@ if __name__ == "__main__":
power0_energy = ift.library.NonlinearPowerEnergy( power0_energy = ift.library.NonlinearPowerEnergy(
position=t0, d=d, N=N, xi=m0, D=D0, ht=HT, position=t0, d=d, N=N, xi=m0, D=D0, ht=HT,
Instrument=MeasurementOperator, nonlinearity=nonlinearity, Instrument=MeasurementOperator, nonlinearity=nonlinearity,
Distributor=Distributor, sigma=1., samples=2, inverter=inverter) Distributor=Distributor, sigma=1., samples=2,
iteration_controller=IC)
power0_energy = minimizer(power0_energy)[0] power0_energy = minimizer(power0_energy)[0]
......
...@@ -78,15 +78,13 @@ if __name__ == "__main__": ...@@ -78,15 +78,13 @@ if __name__ == "__main__":
LS = ift.LineSearchStrongWolfe(c2=0.02) LS = ift.LineSearchStrongWolfe(c2=0.02)
minimizer = ift.RelaxedNewton(IC1, line_searcher=LS) minimizer = ift.RelaxedNewton(IC1, line_searcher=LS)
ICI = ift.GradientNormController(iteration_limit=500, IC = ift.GradientNormController(iteration_limit=500,
tol_abs_gradnorm=1e-3) tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=ICI)
for i in range(20): for i in range(20):
power0 = Distributor(ift.exp(0.5*t0)) power0 = Distributor(ift.exp(0.5*t0))
map0_energy = ift.library.NonlinearWienerFilterEnergy( map0_energy = ift.library.NonlinearWienerFilterEnergy(
m0, d, MeasurementOperator, nonlinearity, HT, power0, N, S, m0, d, MeasurementOperator, nonlinearity, HT, power0, N, S, IC)
inverter=inverter)
# Minimization with chosen minimizer # Minimization with chosen minimizer
map0_energy, convergence = minimizer(map0_energy) map0_energy, convergence = minimizer(map0_energy)
...@@ -99,7 +97,7 @@ if __name__ == "__main__": ...@@ -99,7 +97,7 @@ if __name__ == "__main__":
power0_energy = ift.library.NonlinearPowerEnergy( power0_energy = ift.library.NonlinearPowerEnergy(
position=t0, d=d, N=N, xi=m0, D=D0, ht=HT, position=t0, d=d, N=N, xi=m0, D=D0, ht=HT,
Instrument=MeasurementOperator, nonlinearity=nonlinearity, Instrument=MeasurementOperator, nonlinearity=nonlinearity,
Distributor=Distributor, sigma=1., samples=2, inverter=inverter) Distributor=Distributor, sigma=1., samples=2, iteration_controller=IC)
power0_energy = minimizer(power0_energy)[0] power0_energy = minimizer(power0_energy)[0]
......
...@@ -52,14 +52,13 @@ if __name__ == "__main__": ...@@ -52,14 +52,13 @@ if __name__ == "__main__":
LS = ift.LineSearchStrongWolfe(c2=0.02) LS = ift.LineSearchStrongWolfe(c2=0.02)
minimizer = ift.RelaxedNewton(IC1, line_searcher=LS) minimizer = ift.RelaxedNewton(IC1, line_searcher=LS)
ICI = ift.GradientNormController(iteration_limit=2000, IC = ift.GradientNormController(iteration_limit=2000,
tol_abs_gradnorm=1e-3) tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=ICI)
# initial guess # initial guess
m = ift.full(h_space, 1e-7) m = ift.full(h_space, 1e-7)
map_energy = ift.library.NonlinearWienerFilterEnergy( map_energy = ift.library.NonlinearWienerFilterEnergy(
m, d, R, nonlinearity, HT, power, N, S, inverter=inverter) m, d, R, nonlinearity, HT, power, N, S, IC)
# Minimization with chosen minimizer # Minimization with chosen minimizer
map_energy, convergence = minimizer(map_energy) map_energy, convergence = minimizer(map_energy)
......
...@@ -76,10 +76,9 @@ if __name__ == "__main__": ...@@ -76,10 +76,9 @@ if __name__ == "__main__":
ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=0.1) ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=0.1)
sampling_ctrl = ift.GradientNormController(name="sampling", sampling_ctrl = ift.GradientNormController(name="sampling",
tol_abs_gradnorm=1e2) tol_abs_gradnorm=1e2)
inverter = ift.ConjugateGradient(controller=ctrl)
sampling_inverter = ift.ConjugateGradient(controller=sampling_ctrl)
wiener_curvature = ift.library.WienerFilterCurvature( wiener_curvature = ift.library.WienerFilterCurvature(
S=S, N=N, R=R, inverter=inverter, sampling_inverter=sampling_inverter) S=S, N=N, R=R, iteration_controller=ctrl,
iteration_controller_sampling=sampling_ctrl)
m_k = wiener_curvature.inverse_times(j) m_k = wiener_curvature.inverse_times(j)
m = ht(m_k) m = ht(m_k)
......
...@@ -50,10 +50,9 @@ if __name__ == "__main__": ...@@ -50,10 +50,9 @@ if __name__ == "__main__":
ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=1e-2) ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=1e-2)
sampling_ctrl = ift.GradientNormController(name="sampling", sampling_ctrl = ift.GradientNormController(name="sampling",
tol_abs_gradnorm=2e1) tol_abs_gradnorm=2e1)
inverter = ift.ConjugateGradient(controller=ctrl)
sampling_inverter = ift.ConjugateGradient(controller=sampling_ctrl)
wiener_curvature = ift.library.WienerFilterCurvature( wiener_curvature = ift.library.WienerFilterCurvature(
S=S, N=N, R=R, inverter=inverter, sampling_inverter=sampling_inverter) S=S, N=N, R=R, iteration_controller=ctrl,
iteration_controller_sampling=sampling_ctrl)
m_k = wiener_curvature.inverse_times(j) m_k = wiener_curvature.inverse_times(j)
m = ht(m_k) m = ht(m_k)
......
...@@ -81,9 +81,8 @@ if __name__ == "__main__": ...@@ -81,9 +81,8 @@ if __name__ == "__main__":
j = R.adjoint_times(N.inverse_times(d)) j = R.adjoint_times(N.inverse_times(d))
IC = ift.GradientNormController(name="inverter", iteration_limit=500, IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=1e-3) tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator.make(R, N.inverse) + Phi_h.inverse).inverse D = (ift.SandwichOperator.make(R, N.inverse) + Phi_h.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Phi_h) D = ift.InversionEnabler(D, IC, approximation=Phi_h)
m = HT(D(j)) m = HT(D(j))
# Uncertainty # Uncertainty
...@@ -116,8 +115,7 @@ if __name__ == "__main__": ...@@ -116,8 +115,7 @@ if __name__ == "__main__":
# initial guess # initial guess
psi0 = ift.full(h_domain, 1e-7) psi0 = ift.full(h_domain, 1e-7)
energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h, energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h, IC)
inverter)
IC1 = ift.GradientNormController(name="IC1", iteration_limit=200, IC1 = ift.GradientNormController(name="IC1", iteration_limit=200,
tol_abs_gradnorm=1e-4) tol_abs_gradnorm=1e-4)
minimizer = ift.RelaxedNewton(IC1) minimizer = ift.RelaxedNewton(IC1)
......
...@@ -39,17 +39,17 @@ N_iter = 100 ...@@ -39,17 +39,17 @@ N_iter = 100
tol = 1e-3 tol = 1e-3
IC = ift.GradientNormController(tol_abs_gradnorm=tol, iteration_limit=N_iter) IC = ift.GradientNormController(tol_abs_gradnorm=tol, iteration_limit=N_iter)
inverter = ift.ConjugateGradient(IC) curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p,
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter, iteration_controller=IC,
sampling_inverter=inverter) iteration_controller_sampling=IC)
m_xi = curv.inverse_times(j) m_xi = curv.inverse_times(j)
samps_long = [curv.draw_sample(from_inverse=True) for i in range(N_samps)] samps_long = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]
tol = 1e2 tol = 1e2
IC = ift.GradientNormController(tol_abs_gradnorm=tol, iteration_limit=N_iter) IC = ift.GradientNormController(tol_abs_gradnorm=tol, iteration_limit=N_iter)
inverter = ift.ConjugateGradient(IC) curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p,
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter, iteration_controller=IC,
sampling_inverter=inverter) iteration_controller_sampling=IC)
samps_short = [curv.draw_sample(from_inverse=True) for i in range(N_samps)] samps_short = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]
# Compute mean # Compute mean
......
...@@ -38,8 +38,8 @@ if __name__ == "__main__": ...@@ -38,8 +38,8 @@ if __name__ == "__main__":
j = Rh.adjoint_times(N.inverse_times(d)) j = Rh.adjoint_times(N.inverse_times(d))
ctrl = ift.GradientNormController(name="Iter", tol_abs_gradnorm=1e-10, ctrl = ift.GradientNormController(name="Iter", tol_abs_gradnorm=1e-10,
iteration_limit=300) iteration_limit=300)
inverter = ift.ConjugateGradient(controller=ctrl) Di = ift.library.WienerFilterCurvature(S=S, R=Rh, N=N,
Di = ift.library.WienerFilterCurvature(S=S, R=Rh, N=N, inverter=inverter) iteration_controller=ctrl)
mh = Di.inverse_times(j) mh = Di.inverse_times(j)
m = ht(mh) m = ht(mh)
......
...@@ -98,10 +98,9 @@ if __name__ == "__main__": ...@@ -98,10 +98,9 @@ if __name__ == "__main__":
IC = ift.GradientNormController(name="inverter", iteration_limit=1000, IC = ift.GradientNormController(name="inverter", iteration_limit=1000,
tol_abs_gradnorm=0.0001) tol_abs_gradnorm=0.0001)
inverter = ift.ConjugateGradient(controller=IC)
# setting up measurement precision matrix M # setting up measurement precision matrix M
M = (ift.SandwichOperator.make(R.adjoint, Sh) + N) M = (ift.SandwichOperator.make(R.adjoint, Sh) + N)
M = ift.InversionEnabler(M, inverter) M = ift.InversionEnabler(M, IC)
m = Sh(R.adjoint(M.inverse_times(d))) m = Sh(R.adjoint(M.inverse_times(d)))
# Plotting # Plotting
......
...@@ -52,9 +52,8 @@ if __name__ == "__main__": ...@@ -52,9 +52,8 @@ if __name__ == "__main__":
j = R.adjoint_times(N.inverse_times(d)) j = R.adjoint_times(N.inverse_times(d))
IC = ift.GradientNormController(name="inverter", iteration_limit=500, IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=0.1) tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator.make(R, N.inverse) + Sh.inverse).inverse D = (ift.SandwichOperator.make(R, N.inverse) + Sh.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Sh) D = ift.InversionEnabler(D, IC, approximation=Sh)
m = D(j) m = D(j)
# Plotting # Plotting
......
...@@ -78,9 +78,8 @@ if __name__ == "__main__": ...@@ -78,9 +78,8 @@ if __name__ == "__main__":
j = R.adjoint_times(N.inverse_times(data)) j = R.adjoint_times(N.inverse_times(data))
ctrl = ift.GradientNormController( ctrl = ift.GradientNormController(
name="inverter", tol_abs_gradnorm=1e-5/(nu.K*(nu.m**dimensionality))) name="inverter", tol_abs_gradnorm=1e-5/(nu.K*(nu.m**dimensionality)))
inverter = ift.ConjugateGradient(controller=ctrl) wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R,
wiener_curvature = ift.library.WienerFilterCurvature( iteration_controller=ctrl)
S=S, N=N, R=R, inverter=inverter)
m = wiener_curvature.inverse_times(j) m = wiener_curvature.inverse_times(j)
m_s = HT(m) m_s = HT(m)
......
...@@ -47,15 +47,14 @@ if __name__ == "__main__": ...@@ -47,15 +47,14 @@ if __name__ == "__main__":
# Choose minimization strategy # Choose minimization strategy
ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=0.1) ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=ctrl)
controller = ift.GradientNormController(name="min", tol_abs_gradnorm=0.1) controller = ift.GradientNormController(name="min", tol_abs_gradnorm=0.1)
minimizer = ift.RelaxedNewton(controller=controller) minimizer = ift.RelaxedNewton(controller=controller)
m0 = ift.full(h_space, 0.) m0 = ift.full(h_space, 0.)
# Initialize Wiener filter energy # Initialize Wiener filter energy
energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S, energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S,
inverter=inverter, iteration_controller=ctrl,
sampling_inverter=inverter) iteration_controller_sampling=ctrl)
energy, convergence = minimizer(energy) energy, convergence = minimizer(energy)
m = energy.position m = energy.position
......
...@@ -63,7 +63,7 @@ class NonlinearPowerEnergy(Energy): ...@@ -63,7 +63,7 @@ class NonlinearPowerEnergy(Energy):
# MR FIXME: docstring incomplete and outdated # MR FIXME: docstring incomplete and outdated
def __init__(self, position, d, N, xi, D, ht, Instrument, nonlinearity, def __init__(self, position, d, N, xi, D, ht, Instrument, nonlinearity,
Distributor, sigma=0., samples=3, xi_sample_list=None, Distributor, sigma=0., samples=3, xi_sample_list=None,
inverter=None): iteration_controller=None):
super(NonlinearPowerEnergy, self).__init__(position) super(NonlinearPowerEnergy, self).__init__(position)
self.xi = xi self.xi = xi
self.D = D self.D = D
...@@ -83,7 +83,7 @@ class NonlinearPowerEnergy(Energy): ...@@ -83,7 +83,7 @@ class NonlinearPowerEnergy(Energy):
xi_sample_list = [D.draw_sample(from_inverse=True) + xi xi_sample_list = [D.draw_sample(from_inverse=True) + xi
for _ in range(samples)] for _ in range(samples)]
self.xi_sample_list = xi_sample_list self.xi_sample_list = xi_sample_list
self.inverter = inverter self._ic = iteration_controller
A = Distributor(exp(.5 * position)) A = Distributor(exp(.5 * position))
...@@ -118,7 +118,7 @@ class NonlinearPowerEnergy(Energy): ...@@ -118,7 +118,7 @@ class NonlinearPowerEnergy(Energy):
self.Distributor, sigma=self.sigma, self.Distributor, sigma=self.sigma,
samples=len(self.xi_sample_list), samples=len(self.xi_sample_list),
xi_sample_list=self.xi_sample_list, xi_sample_list=self.xi_sample_list,
inverter=self.inverter) iteration_controller=self._ic)
@property @property
def value(self): def value(self):
...@@ -139,4 +139,4 @@ class NonlinearPowerEnergy(Energy): ...@@ -139,4 +139,4 @@ class NonlinearPowerEnergy(Energy):
op = LinearizedResponse.adjoint*self.N.inverse*LinearizedResponse op = LinearizedResponse.adjoint*self.N.inverse*LinearizedResponse
result = op if result is None else result + op result = op if result is None else result + op
result = result*(1./len(self.xi_sample_list)) + self.T result = result*(1./len(self.xi_sample_list)) + self.T
return InversionEnabler(result, self.inverter) return InversionEnabler(result, self._ic)
...@@ -24,8 +24,7 @@ from ..sugar import makeOp ...@@ -24,8 +24,7 @@ from ..sugar import makeOp
class NonlinearWienerFilterEnergy(Energy): class NonlinearWienerFilterEnergy(Energy):
def __init__(self, position, d, Instrument, nonlinearity, ht, power, N, S, def __init__(self, position, d, Instrument, nonlinearity, ht, power, N, S,
inverter=None, iteration_controller=None, iteration_controller_sampling=None):
sampling_inverter=None):
super(NonlinearWienerFilterEnergy, self).__init__(position=position) super(NonlinearWienerFilterEnergy, self).__init__(position=position)
self.d = d.lock() self.d = d.lock()
self.Instrument = Instrument self.Instrument = Instrument
...@@ -37,10 +36,10 @@ class NonlinearWienerFilterEnergy(Energy): ...@@ -37,10 +36,10 @@ class NonlinearWienerFilterEnergy(Energy):
residual = d - Instrument(nonlinearity(m)) residual = d - Instrument(nonlinearity(m))
self.N = N self.N = N
self.S = S self.S = S
self.inverter = inverter self._ic = iteration_controller
if sampling_inverter is None: if iteration_controller_sampling is None:
sampling_inverter = inverter iteration_controller_sampling = self._ic
self.sampling_inverter = sampling_inverter self._ic_samp = iteration_controller_sampling
t1 = S.inverse_times(position) t1 = S.inverse_times(position)
t2 = N.inverse_times(residual) t2 = N.inverse_times(residual)
self._value = 0.5 * (position.vdot(t1) + residual.vdot(t2)).real self._value = 0.5 * (position.vdot(t1) + residual.vdot(t2)).real
...@@ -51,7 +50,7 @@ class NonlinearWienerFilterEnergy(Energy): ...@@ -51,7 +50,7 @@ class NonlinearWienerFilterEnergy(Energy):
def at(self, position): def at(self, position):
return self.__class__(position, self.d, self.Instrument, return self.__class__(position, self.d, self.Instrument,
self.nonlinearity, self.ht, self.power, self.N, self.nonlinearity, self.ht, self.power, self.N,
self.S, self.inverter) self.S, self._ic, self._ic_samp)
@property @property
def value(self): def value(self):
...@@ -64,5 +63,5 @@ class NonlinearWienerFilterEnergy(Energy): ...@@ -64,5 +63,5 @@ class NonlinearWienerFilterEnergy(Energy):
@property @property
@memo @memo
def curvature(self): def curvature(self):
return WienerFilterCurvature(self.R, self.N, self.S, self.inverter, return WienerFilterCurvature(self.R, self.N, self.S, self._ic,
self.sampling_inverter) self._ic_samp)
...@@ -25,9 +25,9 @@ from ..sugar import log ...@@ -25,9 +25,9 @@ from ..sugar import log
class PoissonEnergy(Energy): class PoissonEnergy(Energy):
def __init__(self, position, d, Instrument, nonlinearity, ht, Phi_h, def __init__(self, position, d, Instrument, nonlinearity, ht, Phi_h,
inverter=None): iteration_controller=None):
super(PoissonEnergy, self).__init__(position=position) super(PoissonEnergy, self).__init__(position=position)
self._inverter = inverter self._ic = iteration_controller
self._d = d self._d = d
self._Instrument = Instrument self._Instrument = Instrument
self._nonlinearity = nonlinearity self._nonlinearity = nonlinearity
...@@ -51,7 +51,7 @@ class PoissonEnergy(Energy): ...@@ -51,7 +51,7 @@ class PoissonEnergy(Energy):
def at(self, position): def at(self, position):
return self.__class__(position, self._d, self._Instrument, return self.__class__(position, self._d, self._Instrument,
self._nonlinearity, self._ht, self._Phi_h, self._nonlinearity, self._ht, self._Phi_h,
self._inverter) self._ic)
@property @property
def value(self): def value(self):
...@@ -63,5 +63,4 @@ class PoissonEnergy(Energy): ...@@ -63,5 +63,4 @@ class PoissonEnergy(Energy):
@property @property
def curvature(self): def curvature(self):
return InversionEnabler(self._curv, self._inverter, return InversionEnabler(self._curv, self._ic, self._Phi_h.inverse)
approximation=self._Phi_h.inverse)
...@@ -21,7 +21,8 @@ from ..operators.inversion_enabler import InversionEnabler ...@@ -21,7 +21,8 @@ from ..operators.inversion_enabler import InversionEnabler
from ..operators.sampling_enabler import SamplingEnabler from ..operators.sampling_enabler import SamplingEnabler
def WienerFilterCurvature(R, N, S, inverter, sampling_inverter=None): def WienerFilterCurvature(R, N, S, iteration_controller=None,
iteration_controller_sampling=None):
"""The curvature of the WienerFilterEnergy. """The curvature of the WienerFilterEnergy.
This operator implements the second derivative of the This operator implements the second derivative of the
...@@ -37,16 +38,16 @@ def WienerFilterCurvature(R, N, S, inverter, sampling_inverter=None): ...@@ -37,16 +38,16 @@ def WienerFilterCurvature(R, N, S, inverter, sampling_inverter=None):
The noise covariance. The noise covariance.
S : DiagonalOperator S : DiagonalOperator
The prior signal covariance The prior signal covariance
inverter : Minimizer iteration_controller : IterationController
The minimizer to use during numerical inversion The iteration controller to use during numerical inversion via
sampling_inverter : Minimizer ConjugateGradient.
The minimizer to use during numerical sampling iteration_controller_sampling : IterationController
if None, it is not possible to draw inverse samples The iteration controller to use for sampling.
default: None
""" """
M = SandwichOperator.make(R, N.inverse) M = SandwichOperator.make(R, N.inverse)
if sampling_inverter is not None: if iteration_controller is not None:
op = SamplingEnabler(M, S.inverse, sampling_inverter, S.inverse) op = SamplingEnabler(M, S.inverse, iteration_controller_sampling,
S.inverse)
else: else:
op = M + S.inverse op = M + S.inverse
return InversionEnabler(op, inverter, S.inverse) return InversionEnabler(op, iteration_controller, S.inverse)
...@@ -20,8 +20,8 @@ from ..minimization.quadratic_energy import QuadraticEnergy ...@@ -20,8 +20,8 @@ from ..minimization.quadratic_energy import QuadraticEnergy
from .wiener_filter_curvature import WienerFilterCurvature from .wiener_filter_curvature import WienerFilterCurvature
def WienerFilterEnergy(position, d, R, N, S, inverter=None, def WienerFilterEnergy(position, d, R, N, S, iteration_controller=None,
sampling_inverter=None): iteration_controller_sampling=None):
"""The Energy for the Wiener filter. """The Energy for the Wiener filter.
It covers the case of linear measurement with It covers the case of linear measurement with
...@@ -48,6 +48,7 @@ def WienerFilterEnergy(position, d, R, N, S, inverter=None, ...@@ -48,6 +48,7 @@ def WienerFilterEnergy(position, d, R, N, S, inverter=None,
if None, it is not possible to draw inverse samples if None, it is not possible to draw inverse samples
default: None default: None
""" """
op = WienerFilterCurvature(R, N, S, inverter, sampling_inverter) op = WienerFilterCurvature(R, N, S, iteration_controller,
iteration_controller_sampling)
vec = R.adjoint_times(N.inverse_times(d)) vec = R.adjoint_times(N.inverse_times(d))
return QuadraticEnergy(position, op, vec) return QuadraticEnergy(position, op, vec)
...@@ -61,6 +61,4 @@ class EnergySum(Energy): ...@@ -61,6 +61,4 @@ class EnergySum(Energy):
if precon is None and self._precon_idx is not None: if precon is None and self._precon_idx is not None:
precon = self._energies[self._precon_idx].curvature precon = self._energies[self._precon_idx].curvature
from ..operators.inversion_enabler import InversionEnabler from ..operators.inversion_enabler import InversionEnabler
from .conjugate_gradient import ConjugateGradient return InversionEnabler(res, self._min_controller, precon)
return InversionEnabler(
res, ConjugateGradient(self._min_controller), precon)
...@@ -16,11 +16,13 @@ ...@@ -16,11 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from ..minimization.quadratic_energy import QuadraticEnergy import</