Commit c61a1a40 authored by Martin Reinecke's avatar Martin Reinecke

cleanup

parent b9e3ff80
Pipeline #16897 passed with stage
in 10 minutes and 15 seconds
......@@ -25,7 +25,6 @@ class QuadraticEnergy(Energy):
return self._Ax - self._b
@property
@memo
def curvature(self):
    # For a quadratic energy 0.5*x^T A x - x^T b the curvature is the
    # operator A itself, independent of position; @memo caches the
    # (constant) result so repeated accesses are free.
    return self._A
......
......@@ -30,6 +30,8 @@ class ConjugateGradient(Minimizer):
Parameters
----------
controller : IterationController
Object that decides when to terminate the minimization.
reset_count : integer *optional*
Number of iterations after which to restart; i.e., forget previous
conjugated directions (default: None).
......@@ -37,16 +39,6 @@ class ConjugateGradient(Minimizer):
This operator can be provided which transforms the variables of the
system to improve the conditioning (default: None).
Attributes
----------
reset_count : integer
Number of iterations after which to restart; i.e., forget previous
conjugated directions.
preconditioner : function
This operator can be provided which transforms the variables of the
system to improve the conditioning (default: None).
controller : IterationController
References
----------
Jorge Nocedal & Stephen Wright, "Numerical Optimization", Second Edition,
......@@ -55,81 +47,79 @@ class ConjugateGradient(Minimizer):
"""
def __init__(self, controller, reset_count=None, preconditioner=None):
    """Set up the conjugate-gradient minimizer.

    Parameters
    ----------
    controller : IterationController
        Object that decides when to terminate the minimization.
    reset_count : integer, optional
        Number of iterations after which to restart, i.e. forget the
        previously accumulated conjugated directions
        (default: None, meaning never reset).
    preconditioner : callable, optional
        Operator applied to the residual to improve the conditioning of
        the system (default: None).
    """
    # NOTE(review): the diff residue contained both the pre-commit body
    # (public `reset_count`/`preconditioner` attributes) and the
    # post-commit body; only the post-commit version is kept, storing the
    # validated settings under private names as the rest of the class
    # (`_reset_count`, `_preconditioner`, `_controller`) expects.
    self._reset_count = None if reset_count is None else int(reset_count)
    self._preconditioner = preconditioner
    self._controller = controller
def __call__(self, E):
def __call__(self, energy):
""" Runs the conjugate gradient minimization.
Parameters
----------
E : Energy object at the starting point of the iteration.
E's curvature operator must be independent of position, otherwise
energy : Energy object at the starting point of the iteration.
Its curvature operator must be independent of position, otherwise
linear conjugate gradient minimization will fail.
Returns
-------
E : QuadraticEnergy at last point of the iteration
convergence : integer
Latest convergence level indicating whether the minimization
has converged or not.
energy : QuadraticEnergy
state at last point of the iteration
status : integer
Can be controller.CONVERGED or controller.ERROR
"""
controller = self._controller
status = controller.start(E)
status = controller.start(energy)
if status != controller.CONTINUE:
return E, status
return energy, status
r = -E.gradient
if self.preconditioner is not None:
d = self.preconditioner(r)
r = -energy.gradient
if self._preconditioner is not None:
d = self._preconditioner(r)
else:
d = r.copy()
d = r
previous_gamma = (r.vdot(d)).real
if previous_gamma == 0:
return E, controller.CONVERGED
return energy, controller.CONVERGED
while True:
q = E.curvature(d)
alpha = previous_gamma/(d.vdot(q).real)
if not np.isfinite(alpha):
q = energy.curvature(d)
ddotq = d.vdot(q).real
if ddotq==0.:
self.logger.error("Alpha became infinite! Stopping.")
return E, controller.ERROR
return energy, controller.ERROR
alpha = previous_gamma/ddotq
E = E.at(E.position+d*alpha)
status = self._controller.check(E)
energy = energy.at(energy.position+d*alpha)
status = self._controller.check(energy)
if status != controller.CONTINUE:
return E, status
return energy, status
reset = False
if alpha < 0:
self.logger.warn("Positive definiteness of A violated!")
reset = True
if self.reset_count is not None:
reset += (iteration_number % self.reset_count == 0)
if self._reset_count is not None:
reset += (iteration_number % self._reset_count == 0)
if reset:
self.logger.info("Resetting conjugate directions.")
r = -E.gradient
r = -energy.gradient
else:
r -= q * alpha
if self.preconditioner is not None:
s = self.preconditioner(r)
if self._preconditioner is not None:
s = self._preconditioner(r)
else:
s = r.copy()
s = r
gamma = r.vdot(s).real
if gamma < 0:
self.logger.warn("Positive definiteness of preconditioner "
"violated!")
self.logger.warn(
"Positive definiteness of preconditioner violated!")
if gamma == 0:
return E, controller.CONVERGED
return energy, controller.CONVERGED
d = s + d * max(0, gamma/previous_gamma)
......
......@@ -33,38 +33,18 @@ class DescentMinimizer(Minimizer):
Parameters
----------
controller : IterationController
Object that decides when to terminate the minimization.
line_searcher : callable *optional*
Function which infers the step size in the descent direction
(default : LineSearchStrongWolfe()).
callback : callable *optional*
Function f(energy, iteration_number) supplied by the user to perform
in-situ analysis at every iteration step. When being called the
current energy and iteration_number are passed. (default: None)
Attributes
----------
line_searcher : LineSearch
Function which infers the optimal step size for functional minization
given a descent direction.
callback : function
Function f(energy, iteration_number) supplied by the user to perform
in-situ analysis at every iteration step. When being called the
current energy and iteration_number are passed.
Notes
------
The callback function can be used to externally stop the minimization by
raising a `StopIteration` exception.
Check `get_descent_direction` of a derived class for information on the
concrete minization scheme.
"""
def __init__(self, controller, line_searcher=LineSearchStrongWolfe()):
    """Set up the descent minimizer.

    Parameters
    ----------
    controller : IterationController
        Object that decides when to terminate the minimization.
    line_searcher : callable, optional
        Function that infers the step size in the descent direction
        (default: LineSearchStrongWolfe()).

    Notes
    -----
    The default `LineSearchStrongWolfe()` instance is created once at
    function-definition time and therefore shared by all instances that
    rely on the default — presumably the line searcher is stateless
    between calls; TODO confirm against LineSearchStrongWolfe.
    """
    # Diff residue kept both the pre- and post-commit assignments of
    # `line_searcher`; only the post-commit body is retained.
    super(DescentMinimizer, self).__init__()
    self._controller = controller
    self.line_searcher = line_searcher
def __call__(self, energy):
""" Performs the minimization of the provided Energy functional.
......@@ -79,18 +59,15 @@ class DescentMinimizer(Minimizer):
-------
energy : Energy object
Latest `energy` of the minimization.
convergence : integer
Latest convergence level indicating whether the minimization
has converged or not.
status : integer
Can be controller.CONVERGED or controller.ERROR
Note
----
The minimization is stopped if
* the callback function raises a `StopIteration` exception,
* the controller returns controller.CONVERGED or controller.ERROR,
* a perfectly flat point is reached,
* according to the line-search the minimum is found,
* the target convergence level is reached,
* the iteration limit is reached.
"""
......
......@@ -26,9 +26,8 @@ class VL_BFGS(DescentMinimizer):
def __init__(self, controller, line_searcher=LineSearchStrongWolfe(),
             max_history_length=5):
    """Set up the VL-BFGS minimizer.

    Parameters
    ----------
    controller : IterationController
        Object that decides when to terminate the minimization.
    line_searcher : callable, optional
        Function that infers the step size in the descent direction
        (default: LineSearchStrongWolfe()).
    max_history_length : int, optional
        Number of past position/gradient pairs retained for the
        limited-memory update (default: 5).
    """
    # The diff residue duplicated the super() call in its pre- and
    # post-commit formatting; a single call is kept.
    super(VL_BFGS, self).__init__(controller=controller,
                                  line_searcher=line_searcher)
    self.max_history_length = max_history_length
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment