Commit f16ee8d4 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

rework preconditioner passing

parent 5f82b432
Pipeline #17939 passed with stage
in 3 minutes and 20 seconds
......@@ -47,7 +47,7 @@ if __name__ == "__main__":
m0 = ift.Field(harmonic_space, val=0.)
ctrl = ift.DefaultIterationController(verbose=False,tol_abs_gradnorm=1)
ctrl2 = ift.DefaultIterationController(verbose=True,tol_abs_gradnorm=0.1, name="outer")
inverter = ift.ConjugateGradient(controller=ctrl, preconditioner=S.times)
inverter = ift.ConjugateGradient(controller=ctrl)
energy = ift.library.LogNormalWienerFilterEnergy(m0, data, R_harmonic, N, S, inverter=inverter)
minimizer1 = ift.VL_BFGS(controller=ctrl2,max_history_length=20)
minimizer2 = ift.RelaxedNewton(controller=ctrl2)
......@@ -64,12 +64,12 @@ if __name__ == "__main__":
# Probing the variance
class Proby(ift.DiagonalProberMixin, ift.Prober): pass
proby = Proby(signal_space, probe_count=100)
proby(lambda z: fft(wiener_curvature.inverse_times(fft.inverse_times(z))))
#class Proby(ift.DiagonalProberMixin, ift.Prober): pass
#proby = Proby(signal_space, probe_count=100)
#proby(lambda z: fft(wiener_curvature.inverse_times(fft.inverse_times(z))))
sm = SmoothingOperator(signal_space, sigma=0.02)
variance = sm(proby.diagonal.weight(-1))
#sm = SmoothingOperator(signal_space, sigma=0.02)
#variance = sm(proby.diagonal.weight(-1))
#Plotting #|\label{code:wf_plotting}|
#plotter = plotting.RG2DPlotter(color_map=plotting.colormaps.PlankCmap())
......@@ -79,7 +79,7 @@ if __name__ == "__main__":
plotter.figure.yaxis = ift.plotting.Axis(label='Pixel Index')
plotter.plot.zmax = 5; plotter.plot.zmin = -5
plotter(variance, path = 'variance.html')
#plotter(variance, path = 'variance.html')
# #plotter.plot.zmin = exp(mock_signal.min());
# plotter(mock_signal.real, path='mock_signal.html')
# plotter(Field(signal_space, val=np.log(data.val.get_full_data().real).reshape(signal_space.shape)),
......
......@@ -93,7 +93,7 @@ if __name__ == "__main__":
# Wiener filter
j = R_harmonic.adjoint_times(N.inverse_times(data))
ctrl = ift.DefaultIterationController(verbose=True, tol_custom=1e-3, convergence_level=3)
inverter = ift.ConjugateGradient(controller=ctrl,preconditioner=S.times)
inverter = ift.ConjugateGradient(controller=ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter)
m_k = wiener_curvature.inverse_times(j) #|\label{code:wf_wiener_filter}|
......
......@@ -46,7 +46,7 @@ if __name__ == "__main__":
# Wiener filter
j = R_harmonic.adjoint_times(N.inverse_times(data))
ctrl = ift.DefaultIterationController(verbose=False,tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=ctrl,preconditioner=S.times)
inverter = ift.ConjugateGradient(controller=ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic,inverter=inverter)
m_k = wiener_curvature.inverse_times(j) #|\label{code:wf_wiener_filter}|
m = fft(m_k)
......
......@@ -75,7 +75,7 @@ if __name__ == "__main__":
# Choosing the minimization strategy
ctrl = ift.DefaultIterationController(verbose=True,tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=ctrl, preconditioner=S.times)
inverter = ift.ConjugateGradient(controller=ctrl)
# Setting starting position
m0 = ift.Field(h_space, val=.0)
......
......@@ -22,13 +22,17 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
"""
def __init__(self, R, N, S, inverter, **kwargs):
def __init__(self, R, N, S, inverter, preconditioner=None, **kwargs):
self.R = R
self.N = N
self.S = S
if preconditioner is None:
preconditioner = self.S.times
self._domain = self.S.domain
super(WienerFilterCurvature, self).__init__(inverter=inverter,
**kwargs)
super(WienerFilterCurvature, self).__init__(
inverter=inverter,
preconditioner=preconditioner,
**kwargs)
@property
def domain(self):
......
......@@ -46,7 +46,7 @@ class ConjugateGradient(Minimizer):
self._preconditioner = preconditioner
self._controller = controller
def __call__(self, energy):
def __call__(self, energy, preconditioner=None):
""" Runs the conjugate gradient minimization.
Parameters
......@@ -64,6 +64,9 @@ class ConjugateGradient(Minimizer):
"""
if preconditioner is None:
preconditioner = self._preconditioner
controller = self._controller
status = controller.start(energy)
if status != controller.CONTINUE:
......@@ -71,8 +74,8 @@ class ConjugateGradient(Minimizer):
norm_b = energy.norm_b
r = -energy.gradient
if self._preconditioner is not None:
d = self._preconditioner(r)
if preconditioner is not None:
d = preconditioner(r)
else:
d = r.copy()
previous_gamma = (r.vdot(d)).real
......@@ -98,8 +101,8 @@ class ConjugateGradient(Minimizer):
tpos += energy.position
energy = energy.at_with_grad(tpos, -r)
if self._preconditioner is not None:
s = self._preconditioner(r)
if preconditioner is not None:
s = preconditioner(r)
else:
s = r
......
......@@ -26,7 +26,7 @@ class Minimizer(with_metaclass(NiftyMeta, type('NewBase', (object,), {}))):
"""
@abc.abstractmethod
def __call__(self, energy):
def __call__(self, energy, preconditioner=None):
""" Performs the minimization of the provided Energy functional.
Parameters
......@@ -35,6 +35,9 @@ class Minimizer(with_metaclass(NiftyMeta, type('NewBase', (object,), {}))):
Energy object which provides value, gradient and curvature at a
specific position in parameter space.
preconditioner : LinearOperator, optional
Preconditioner to accelerate the minimization
Returns
-------
energy : Energy object
......
......@@ -38,10 +38,10 @@ class InvertibleOperatorMixin(object):
"""
def __init__(self, inverter,
def __init__(self, inverter, preconditioner=None,
forward_x0=None, backward_x0=None, *args, **kwargs):
self.__inverter = inverter
self._preconditioner = preconditioner
self.__forward_x0 = forward_x0
self.__backward_x0 = backward_x0
super(InvertibleOperatorMixin, self).__init__(*args, **kwargs)
......@@ -53,9 +53,9 @@ class InvertibleOperatorMixin(object):
x0 = Field(self.target, val=0., dtype=x.dtype)
(result, convergence) = self.__inverter(QuadraticEnergy(
A=self.inverse_times,
b=x,
position=x0))
A=self.inverse_times,
b=x, position=x0),
preconditioner=self._preconditioner)
return result.position
def _adjoint_times(self, x, spaces):
......@@ -65,9 +65,9 @@ class InvertibleOperatorMixin(object):
x0 = Field(self.domain, val=0., dtype=x.dtype)
(result, convergence) = self.__inverter(QuadraticEnergy(
A=self.adjoint_inverse_times,
b=x,
position=x0))
A=self.adjoint_inverse_times,
b=x, position=x0),
preconditioner=self._preconditioner)
return result.position
def _inverse_times(self, x, spaces):
......@@ -77,9 +77,9 @@ class InvertibleOperatorMixin(object):
x0 = Field(self.domain, val=0., dtype=x.dtype)
(result, convergence) = self.__inverter(QuadraticEnergy(
A=self.times,
b=x,
position=x0))
A=self.times,
b=x, position=x0),
preconditioner=self._preconditioner)
return result.position
def _adjoint_inverse_times(self, x, spaces):
......@@ -89,7 +89,7 @@ class InvertibleOperatorMixin(object):
x0 = Field(self.target, val=0., dtype=x.dtype)
(result, convergence) = self.__inverter(QuadraticEnergy(
A=self.adjoint_times,
b=x,
position=x0))
A=self.adjoint_times,
b=x, position=x0),
preconditioner=self._preconditioner)
return result.position
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment