Commit 537a2d3d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

restructure iteration controllers to match those in nightly branch

parent 0d71c455
Pipeline #19122 canceled with stage
...@@ -88,11 +88,11 @@ if __name__ == "__main__": ...@@ -88,11 +88,11 @@ if __name__ == "__main__":
d_data = d.val.real d_data = d.val.real
ift.plotting.plot(d.real, name="data.pdf") ift.plotting.plot(d.real, name="data.pdf")
IC1 = ift.DefaultIterationController(verbose=True,iteration_limit=100,tol_abs_gradnorm=0.1) IC1 = ift.GradientNormController(verbose=True,iteration_limit=100,tol_abs_gradnorm=0.1)
minimizer1 = ift.RelaxedNewton(IC1) minimizer1 = ift.RelaxedNewton(IC1)
IC2 = ift.DefaultIterationController(verbose=True,iteration_limit=100,tol_abs_gradnorm=0.1) IC2 = ift.GradientNormController(verbose=True,iteration_limit=100,tol_abs_gradnorm=0.1)
minimizer2 = ift.VL_BFGS(IC2, max_history_length=20) minimizer2 = ift.VL_BFGS(IC2, max_history_length=20)
IC3 = ift.DefaultIterationController(verbose=True,iteration_limit=100,tol_abs_gradnorm=0.1) IC3 = ift.GradientNormController(verbose=True,iteration_limit=100,tol_abs_gradnorm=0.1)
minimizer3 = ift.SteepestDescent(IC3) minimizer3 = ift.SteepestDescent(IC3)
# Set starting position # Set starting position
...@@ -107,14 +107,14 @@ if __name__ == "__main__": ...@@ -107,14 +107,14 @@ if __name__ == "__main__":
S0 = ift.create_power_operator(h_space, power_spectrum=ps0) S0 = ift.create_power_operator(h_space, power_spectrum=ps0)
# Initialize non-linear Wiener Filter energy # Initialize non-linear Wiener Filter energy
ICI = ift.DefaultIterationController(verbose=False,iteration_limit=500,tol_abs_gradnorm=0.1) ICI = ift.GradientNormController(verbose=False,iteration_limit=500,tol_abs_gradnorm=0.1)
map_inverter = ift.ConjugateGradient(controller=ICI) map_inverter = ift.ConjugateGradient(controller=ICI)
map_energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S0, inverter=map_inverter) map_energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S0, inverter=map_inverter)
# Solve the Wiener Filter analytically # Solve the Wiener Filter analytically
D0 = map_energy.curvature D0 = map_energy.curvature
m0 = D0.inverse_times(j) m0 = D0.inverse_times(j)
# Initialize power energy with updated parameters # Initialize power energy with updated parameters
ICI2 = ift.DefaultIterationController(name="powI",verbose=True,iteration_limit=200,tol_abs_gradnorm=1e-5) ICI2 = ift.GradientNormController(name="powI",verbose=True,iteration_limit=200,tol_abs_gradnorm=1e-5)
power_inverter = ift.ConjugateGradient(controller=ICI2) power_inverter = ift.ConjugateGradient(controller=ICI2)
power_energy = ift.library.CriticalPowerEnergy(position=t0, m=m0, D=D0, power_energy = ift.library.CriticalPowerEnergy(position=t0, m=m0, D=D0,
smoothness_prior=10., samples=3, inverter=power_inverter) smoothness_prior=10., samples=3, inverter=power_inverter)
......
...@@ -46,8 +46,8 @@ if __name__ == "__main__": ...@@ -46,8 +46,8 @@ if __name__ == "__main__":
# Wiener filter # Wiener filter
m0 = ift.Field(harmonic_space, val=0.) m0 = ift.Field(harmonic_space, val=0.)
ctrl = ift.DefaultIterationController(verbose=False,tol_abs_gradnorm=1) ctrl = ift.GradientNormController(verbose=False,tol_abs_gradnorm=1)
ctrl2 = ift.DefaultIterationController(verbose=True,tol_abs_gradnorm=0.1, name="outer") ctrl2 = ift.GradientNormController(verbose=True,tol_abs_gradnorm=0.1, name="outer")
inverter = ift.ConjugateGradient(controller=ctrl) inverter = ift.ConjugateGradient(controller=ctrl)
energy = ift.library.LogNormalWienerFilterEnergy(m0, data, R_harmonic, N, S, inverter=inverter) energy = ift.library.LogNormalWienerFilterEnergy(m0, data, R_harmonic, N, S, inverter=inverter)
minimizer1 = ift.VL_BFGS(controller=ctrl2,max_history_length=20) minimizer1 = ift.VL_BFGS(controller=ctrl2,max_history_length=20)
......
...@@ -90,7 +90,7 @@ if __name__ == "__main__": ...@@ -90,7 +90,7 @@ if __name__ == "__main__":
# Wiener filter # Wiener filter
j = R_harmonic.adjoint_times(N.inverse_times(data)) j = R_harmonic.adjoint_times(N.inverse_times(data))
ctrl = ift.DefaultIterationController(verbose=True, iteration_limit=100) ctrl = ift.GradientNormController(verbose=True, iteration_limit=100)
inverter = ift.ConjugateGradient(controller=ctrl) inverter = ift.ConjugateGradient(controller=ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter) wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter)
......
...@@ -45,7 +45,7 @@ if __name__ == "__main__": ...@@ -45,7 +45,7 @@ if __name__ == "__main__":
# Wiener filter # Wiener filter
j = R_harmonic.adjoint_times(N.inverse_times(data)) j = R_harmonic.adjoint_times(N.inverse_times(data))
ctrl = ift.DefaultIterationController(verbose=True,tol_abs_gradnorm=0.1,iteration_limit=10) ctrl = ift.GradientNormController(verbose=True,tol_abs_gradnorm=0.1,iteration_limit=10)
inverter = ift.ConjugateGradient(controller=ctrl) inverter = ift.ConjugateGradient(controller=ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic,inverter=inverter) wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic,inverter=inverter)
m_k = wiener_curvature.inverse_times(j) #|\label{code:wf_wiener_filter}| m_k = wiener_curvature.inverse_times(j) #|\label{code:wf_wiener_filter}|
......
...@@ -58,7 +58,7 @@ if __name__ == "__main__": ...@@ -58,7 +58,7 @@ if __name__ == "__main__":
# Wiener filter # Wiener filter
j = R_harmonic.adjoint_times(N.inverse_times(data)) j = R_harmonic.adjoint_times(N.inverse_times(data))
ctrl = ift.DefaultIterationController(verbose=True,tol_abs_gradnorm=1e-2) ctrl = ift.GradientNormController(verbose=True,tol_abs_gradnorm=1e-2)
inverter = ift.ConjugateGradient(controller=ctrl) inverter = ift.ConjugateGradient(controller=ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter) wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter)
......
...@@ -75,7 +75,7 @@ if __name__ == "__main__": ...@@ -75,7 +75,7 @@ if __name__ == "__main__":
# Choosing the minimization strategy # Choosing the minimization strategy
ctrl = ift.DefaultIterationController(verbose=True,tol_abs_gradnorm=0.1) ctrl = ift.GradientNormController(verbose=True,tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=ctrl) inverter = ift.ConjugateGradient(controller=ctrl)
# Setting starting position # Setting starting position
m0 = ift.Field(h_space, val=.0) m0 = ift.Field(h_space, val=.0)
......
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from .line_searching import * from .line_searching import *
from .iteration_controller import IterationController from .iteration_controlling import *
from .default_iteration_controller import DefaultIterationController
from .minimizer import Minimizer from .minimizer import Minimizer
from .conjugate_gradient import ConjugateGradient from .conjugate_gradient import ConjugateGradient
from .nonlinear_cg import NonlinearCG from .nonlinear_cg import NonlinearCG
......
from .iteration_controller import IterationController
from .gradient_norm_controller import GradientNormController
...@@ -20,11 +20,11 @@ from __future__ import print_function ...@@ -20,11 +20,11 @@ from __future__ import print_function
from .iteration_controller import IterationController from .iteration_controller import IterationController
class DefaultIterationController(IterationController): class GradientNormController(IterationController):
def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None, def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
convergence_level=1, iteration_limit=None, convergence_level=1, iteration_limit=None,
name=None, verbose=None): name=None, verbose=None):
super(DefaultIterationController, self).__init__() super(GradientNormController, self).__init__()
self._tol_abs_gradnorm = tol_abs_gradnorm self._tol_abs_gradnorm = tol_abs_gradnorm
self._tol_rel_gradnorm = tol_rel_gradnorm self._tol_rel_gradnorm = tol_rel_gradnorm
self._convergence_level = convergence_level self._convergence_level = convergence_level
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from builtins import range from builtins import range
import abc import abc
from ..nifty_meta import NiftyMeta from ...nifty_meta import NiftyMeta
from future.utils import with_metaclass from future.utils import with_metaclass
......
...@@ -24,7 +24,7 @@ class Test_Minimizers(unittest.TestCase): ...@@ -24,7 +24,7 @@ class Test_Minimizers(unittest.TestCase):
covariance = ift.DiagonalOperator(space, diagonal=covariance_diagonal) covariance = ift.DiagonalOperator(space, diagonal=covariance_diagonal)
required_result = ift.Field(space, val=1.) required_result = ift.Field(space, val=1.)
IC = ift.DefaultIterationController(tol_abs_gradnorm=1e-5) IC = ift.GradientNormController(tol_abs_gradnorm=1e-5)
minimizer = minimizer_class(controller=IC) minimizer = minimizer_class(controller=IC)
energy = ift.QuadraticEnergy(A=covariance, b=required_result, energy = ift.QuadraticEnergy(A=covariance, b=required_result,
position=starting_point) position=starting_point)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment