Commit 9b934fba authored by Martin Reinecke

Merge branch 'improve_minimization' into 'NIFTy_4'

Improve minimization

See merge request ift/NIFTy!230
parents d516b4cd 650634af
Pipeline #25837 passed with stages in 15 minutes and 33 seconds
...@@ -81,11 +81,13 @@ class DescentMinimizer(Minimizer): ...@@ -81,11 +81,13 @@ class DescentMinimizer(Minimizer):
# compute a step length that reduces energy.value sufficiently # compute a step length that reduces energy.value sufficiently
try: try:
new_energy = self.line_searcher.perform_line_search( new_energy, success = self.line_searcher.perform_line_search(
energy=energy, pk=self.get_descent_direction(energy), energy=energy, pk=self.get_descent_direction(energy),
f_k_minus_1=f_k_minus_1) f_k_minus_1=f_k_minus_1)
except ValueError: except ValueError:
return energy, controller.ERROR return energy, controller.ERROR
if not success:
self.reset()
f_k_minus_1 = energy.value f_k_minus_1 = energy.value
...@@ -103,6 +105,9 @@ class DescentMinimizer(Minimizer): ...@@ -103,6 +105,9 @@ class DescentMinimizer(Minimizer):
if status != controller.CONTINUE: if status != controller.CONTINUE:
return energy, status return energy, status
def reset(self):
pass
@abc.abstractmethod @abc.abstractmethod
def get_descent_direction(self, energy): def get_descent_direction(self, energy):
""" Calculates the next descent direction. """ Calculates the next descent direction.
......
...@@ -109,3 +109,19 @@ class Energy(NiftyMetaBase()): ...@@ -109,3 +109,19 @@ class Energy(NiftyMetaBase()):
curvature of the potential at the given `position`. curvature of the potential at the given `position`.
""" """
raise NotImplementedError raise NotImplementedError
def longest_step(self, dir):
"""Returns the longest allowed step size along `dir`
Parameters
----------
dir : Field
the search direction
Returns
-------
float or None
the longest allowed step when starting from `self.position` along
`dir`. If None, the step size is not limited.
"""
return None
...@@ -54,5 +54,7 @@ class LineSearch(NiftyMetaBase()): ...@@ -54,5 +54,7 @@ class LineSearch(NiftyMetaBase()):
------- -------
Energy Energy
The new Energy object on the new position. The new Energy object on the new position.
bool
whether the line search was considered successful or not
""" """
raise NotImplementedError raise NotImplementedError
...@@ -41,7 +41,7 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -41,7 +41,7 @@ class LineSearchStrongWolfe(LineSearch):
Parameter for curvature condition rule. (Default: 0.9) Parameter for curvature condition rule. (Default: 0.9)
max_step_size : float max_step_size : float
Maximum step allowed to be made in the descent direction. Maximum step allowed to be made in the descent direction.
(Default: 1000000000) (Default: 1e30)
max_iterations : int, optional max_iterations : int, optional
Maximum number of iterations performed by the line search algorithm. Maximum number of iterations performed by the line search algorithm.
(Default: 100) (Default: 100)
...@@ -51,7 +51,7 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -51,7 +51,7 @@ class LineSearchStrongWolfe(LineSearch):
""" """
def __init__(self, preferred_initial_step_size=None, c1=1e-4, c2=0.9, def __init__(self, preferred_initial_step_size=None, c1=1e-4, c2=0.9,
max_step_size=1000000000, max_iterations=100, max_step_size=1e30, max_iterations=100,
max_zoom_iterations=100): max_zoom_iterations=100):
super(LineSearchStrongWolfe, self).__init__( super(LineSearchStrongWolfe, self).__init__(
...@@ -85,19 +85,26 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -85,19 +85,26 @@ class LineSearchStrongWolfe(LineSearch):
------- -------
Energy Energy
The new Energy object on the new position. The new Energy object on the new position.
bool
whether the line search was considered successful or not
""" """
le_0 = LineEnergy(0., energy, pk, 0.) le_0 = LineEnergy(0., energy, pk, 0.)
maxstepsize = energy.longest_step(pk)
if maxstepsize is None:
maxstepsize = self.max_step_size
maxstepsize = min(maxstepsize, self.max_step_size)
# initialize the zero phis # initialize the zero phis
old_phi_0 = f_k_minus_1 old_phi_0 = f_k_minus_1
phi_0 = le_0.value phi_0 = le_0.value
phiprime_0 = le_0.directional_derivative phiprime_0 = le_0.directional_derivative
if phiprime_0 == 0: if phiprime_0 == 0:
dobj.mprint("Directional derivative is zero; assuming convergence") dobj.mprint("Directional derivative is zero; assuming convergence")
return energy return energy, False
if phiprime_0 > 0: if phiprime_0 > 0:
dobj.mprint("Error: search direction is not a descent direction") dobj.mprint("Error: search direction is not a descent direction")
raise ValueError("search direction must be a descent direction") return energy, False
# set alphas # set alphas
alpha0 = 0. alpha0 = 0.
...@@ -112,49 +119,44 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -112,49 +119,44 @@ class LineSearchStrongWolfe(LineSearch):
alpha1 = 1.0 alpha1 = 1.0
else: else:
alpha1 = 1.0/pk.norm() alpha1 = 1.0/pk.norm()
alpha1 = min(alpha1, 0.99*maxstepsize)
# start the minimization loop # start the minimization loop
iteration_number = 0 iteration_number = 0
while iteration_number < self.max_iterations: while iteration_number < self.max_iterations:
iteration_number += 1 iteration_number += 1
if alpha1 == 0: if alpha1 == 0:
result_energy = le_0.energy return le_0.energy, False
break
le_alpha1 = le_0.at(alpha1) le_alpha1 = le_0.at(alpha1)
phi_alpha1 = le_alpha1.value phi_alpha1 = le_alpha1.value
if (phi_alpha1 > phi_0 + self.c1*alpha1*phiprime_0) or \ if (phi_alpha1 > phi_0 + self.c1*alpha1*phiprime_0) or \
((phi_alpha1 >= phi_alpha0) and (iteration_number > 1)): ((phi_alpha1 >= phi_alpha0) and (iteration_number > 1)):
le_star = self._zoom(alpha0, alpha1, phi_0, phiprime_0, return self._zoom(alpha0, alpha1, phi_0, phiprime_0,
phi_alpha0, phiprime_alpha0, phi_alpha1, phi_alpha0, phiprime_alpha0, phi_alpha1,
le_0) le_0)
result_energy = le_star.energy
break
phiprime_alpha1 = le_alpha1.directional_derivative phiprime_alpha1 = le_alpha1.directional_derivative
if abs(phiprime_alpha1) <= -self.c2*phiprime_0: if abs(phiprime_alpha1) <= -self.c2*phiprime_0:
result_energy = le_alpha1.energy return le_alpha1.energy, True
break
if phiprime_alpha1 >= 0: if phiprime_alpha1 >= 0:
le_star = self._zoom(alpha1, alpha0, phi_0, phiprime_0, return self._zoom(alpha1, alpha0, phi_0, phiprime_0,
phi_alpha1, phiprime_alpha1, phi_alpha0, phi_alpha1, phiprime_alpha1, phi_alpha0,
le_0) le_0)
result_energy = le_star.energy
break
# update alphas # update alphas
alpha0, alpha1 = alpha1, min(2*alpha1, self.max_step_size) alpha0, alpha1 = alpha1, min(2*alpha1, maxstepsize)
if alpha1 == self.max_step_size: if alpha1 == maxstepsize:
return le_alpha1.energy dobj.mprint("max step size reached")
return le_alpha1.energy, False
phi_alpha0 = phi_alpha1 phi_alpha0 = phi_alpha1
phiprime_alpha0 = phiprime_alpha1 phiprime_alpha0 = phiprime_alpha1
else:
dobj.mprint("max iterations reached") dobj.mprint("max iterations reached")
return le_alpha1.energy return le_alpha1.energy, False
return result_energy
def _zoom(self, alpha_lo, alpha_hi, phi_0, phiprime_0, def _zoom(self, alpha_lo, alpha_hi, phi_0, phiprime_0,
phi_lo, phiprime_lo, phi_hi, le_0): phi_lo, phiprime_lo, phi_hi, le_0):
...@@ -238,7 +240,7 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -238,7 +240,7 @@ class LineSearchStrongWolfe(LineSearch):
phiprime_alphaj = le_alphaj.directional_derivative phiprime_alphaj = le_alphaj.directional_derivative
# If the second Wolfe condition is met, return the result # If the second Wolfe condition is met, return the result
if abs(phiprime_alphaj) <= -self.c2*phiprime_0: if abs(phiprime_alphaj) <= -self.c2*phiprime_0:
return le_alphaj return le_alphaj.energy, True
# If not, check the sign of the slope # If not, check the sign of the slope
if phiprime_alphaj*delta_alpha >= 0: if phiprime_alphaj*delta_alpha >= 0:
alpha_recent, phi_recent = alpha_hi, phi_hi alpha_recent, phi_recent = alpha_hi, phi_hi
...@@ -251,7 +253,7 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -251,7 +253,7 @@ class LineSearchStrongWolfe(LineSearch):
else: else:
dobj.mprint("The line search algorithm (zoom) did not converge.") dobj.mprint("The line search algorithm (zoom) did not converge.")
return le_alphaj return le_alphaj.energy, False
def _cubicmin(self, a, fa, fpa, b, fb, c, fc): def _cubicmin(self, a, fa, fpa, b, fb, c, fc):
"""Estimating the minimum with cubic interpolation. """Estimating the minimum with cubic interpolation.
......
...@@ -56,8 +56,10 @@ class NonlinearCG(Minimizer): ...@@ -56,8 +56,10 @@ class NonlinearCG(Minimizer):
while True: while True:
grad_old = energy.gradient grad_old = energy.gradient
f_k = energy.value f_k = energy.value
energy = self._line_searcher.perform_line_search(energy, p, energy, success = self._line_searcher.perform_line_search(
f_k_minus_1) energy, p, f_k_minus_1)
if not success:
return energy, controller.ERROR
f_k_minus_1 = f_k f_k_minus_1 = f_k
status = self._controller.check(energy) status = self._controller.check(energy)
if status != controller.CONTINUE: if status != controller.CONTINUE:
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division, print_function from __future__ import division
from .minimizer import Minimizer from .minimizer import Minimizer
from ..field import Field from ..field import Field
from .. import dobj from .. import dobj
...@@ -97,9 +97,9 @@ class ScipyMinimizer(Minimizer): ...@@ -97,9 +97,9 @@ class ScipyMinimizer(Minimizer):
status = self._controller.check(hlp._energy) status = self._controller.check(hlp._energy)
return hlp._energy, self._controller.check(hlp._energy) return hlp._energy, self._controller.check(hlp._energy)
if not r.success: if not r.success:
print("Problem in Scipy minimization:", r.message) dobj.mprint("Problem in Scipy minimization:", r.message)
else: else:
print("Problem in Scipy minimization") dobj.mprint("Problem in Scipy minimization")
return hlp._energy, self._controller.ERROR return hlp._energy, self._controller.ERROR
......
...@@ -49,6 +49,9 @@ class VL_BFGS(DescentMinimizer): ...@@ -49,6 +49,9 @@ class VL_BFGS(DescentMinimizer):
self._information_store = None self._information_store = None
return super(VL_BFGS, self).__call__(energy) return super(VL_BFGS, self).__call__(energy)
def reset(self):
self._information_store = None
def get_descent_direction(self, energy): def get_descent_direction(self, energy):
x = energy.position x = energy.position
gradient = energy.gradient gradient = energy.gradient
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment