Commit abd52d28 authored by Martin Reinecke's avatar Martin Reinecke

allow custom measurements in n=minimizers

parent fe3bc980
Pipeline #17912 passed with stage
in 3 minutes and 19 seconds
......@@ -58,6 +58,7 @@ if __name__ == "__main__":
diagonal = mock_power.power_synthesize(spaces=(0, 1), mean=1, std=0,
real_signal=False)**2
#diagonal = diagonal.real
S = ift.DiagonalOperator(domain=(harmonic_space_1, harmonic_space_2),
diagonal=diagonal)
......@@ -91,9 +92,7 @@ if __name__ == "__main__":
# Wiener filter
j = R_harmonic.adjoint_times(N.inverse_times(data))
ctrl = ift.DefaultIterationController(verbose=True,
tol_abs_gradnorm=1.0,
tol_rel_gradnorm=1e-4)
ctrl = ift.DefaultIterationController(verbose=True, tol_custom=1e-3, convergence_level=3)
inverter = ift.ConjugateGradient(controller=ctrl,preconditioner=S.times)
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter)
......
......@@ -37,3 +37,7 @@ class QuadraticEnergy(Energy):
@property
def curvature(self):
return self._A
@property
def norm_b(self):
return self._b.norm()
......@@ -18,6 +18,7 @@
from __future__ import division
from .minimizer import Minimizer
import numpy as np
class ConjugateGradient(Minimizer):
......@@ -68,6 +69,7 @@ class ConjugateGradient(Minimizer):
if status != controller.CONTINUE:
return energy, status
norm_b = energy.norm_b
r = -energy.gradient
if self._preconditioner is not None:
d = self._preconditioner(r)
......@@ -90,10 +92,6 @@ class ConjugateGradient(Minimizer):
r -= q * alpha
energy = energy.at_with_grad(energy.position+d*alpha, -r)
status = self._controller.check(energy)
if status != controller.CONTINUE:
return energy, status
if self._preconditioner is not None:
s = self._preconditioner(r)
else:
......@@ -106,6 +104,12 @@ class ConjugateGradient(Minimizer):
if gamma == 0:
return energy, controller.CONVERGED
status = self._controller.check(energy,
custom_measure=np.sqrt(gamma) /
norm_b)
if status != controller.CONTINUE:
return energy, status
d = s + d * max(0, gamma/previous_gamma)
previous_gamma = gamma
......@@ -22,33 +22,42 @@ from .iteration_controller import IterationController
class DefaultIterationController(IterationController):
def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
convergence_level=1, iteration_limit=None, name=None,
verbose=None):
tol_custom=None, convergence_level=1, iteration_limit=None,
name=None, verbose=None):
super(DefaultIterationController, self).__init__()
self._tol_abs_gradnorm = tol_abs_gradnorm
self._tol_rel_gradnorm = tol_rel_gradnorm
self._tol_custom = tol_custom
self._convergence_level = convergence_level
self._iteration_limit = iteration_limit
self._name = name
self._verbose = verbose
def start(self, energy):
def start(self, energy, custom_measure=None):
self._itcount = -1
self._ccount = 0
if self._tol_rel_gradnorm is not None:
self._tol_rel_gradnorm_now = self._tol_rel_gradnorm \
* energy.gradient_norm
return self.check(energy)
return self.check(energy, custom_measure)
def check(self, energy):
def check(self, energy, custom_measure=None):
self._itcount += 1
inclvl = False
if self._tol_abs_gradnorm is not None:
if energy.gradient_norm <= self._tol_abs_gradnorm:
self._ccount += 1
inclvl = True
if self._tol_rel_gradnorm is not None:
if energy.gradient_norm <= self._tol_rel_gradnorm_now:
self._ccount += 1
inclvl = True
if self._tol_custom is not None and custom_measure is not None:
if custom_measure <= self._tol_custom:
inclvl = True
if inclvl:
self._ccount += 1
else:
self._ccount = max(0, self._ccount-1)
# report
if self._verbose:
......@@ -58,6 +67,8 @@ class DefaultIterationController(IterationController):
msg += " Iteration #" + str(self._itcount)
msg += " energy=" + str(energy.value)
msg += " gradnorm=" + str(energy.gradient_norm)
if custom_measure is not None:
msg += " custom=" + str(custom_measure)
msg += " clvl=" + str(self._ccount)
print(msg)
# self.logger.info(msg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment