Commit 3c5a2a01 authored by Philipp Arras's avatar Philipp Arras
Browse files

Rewrite NewtonCG

parent abe02b37
Pipeline #52374 passed with stages
in 8 minutes and 38 seconds
......@@ -18,8 +18,11 @@
import numpy as np
from ..logger import logger
from .conjugate_gradient import ConjugateGradient
from .iteration_controllers import GradientNormController
from .line_search import LineSearch
from .minimizer import Minimizer
from .quadratic_energy import QuadraticEnergy
class DescentMinimizer(Minimizer):
......@@ -154,49 +157,22 @@ class NewtonCG(DescentMinimizer):
Algorithm derived from SciPy sources.
"""
def __init__(self, controller, napprox=0, line_searcher=None):
def __init__(self, controller, line_searcher=None):
if line_searcher is None:
line_searcher = LineSearch(preferred_initial_step_size=1.)
super(NewtonCG, self).__init__(controller=controller,
line_searcher=line_searcher)
self._napprox = int(napprox)
def get_descent_direction(self, energy):
# if self._napprox > 1:
# from ..probing import approximation2endo
# sqdiag = approximation2endo(energy.metric, self._napprox).sqrt()
float64eps = np.finfo(np.float64).eps
r = energy.gradient
maggrad = abs(r).sum()
g = energy.gradient
maggrad = abs(g).sum()
termcond = np.min([0.5, np.sqrt(maggrad)]) * maggrad
pos = energy.position*0
d = r
previous_gamma = r.vdot(d)
ii = 0
while True:
if not ii % 10 and ii > 0:
print(ii)
if abs(r).sum() <= termcond:
return pos
q = energy.apply_metric(d)
curv = d.vdot(q)
if 0 <= curv <= 3*float64eps:
return pos
if curv < 0:
return pos if ii > 0 else previous_gamma/curv*r
ii += 1
alpha = previous_gamma/curv
pos = pos - alpha*d
r = r - alpha*q
s = r
gamma = r.vdot(s)
d = d*(gamma/previous_gamma)+r
previous_gamma = gamma
# curvature keeps increasing, bail out
raise ValueError("Warning: CG iterations didn't converge. "
"The Hessian is not positive definite.")
ic = GradientNormController(tol_abs_gradnorm=termcond, p=1)
e = QuadraticEnergy(0*energy.position, energy.metric, g)
e, conv = ConjugateGradient(ic, nreset=np.inf)(e)
if conv == ic.ERROR:
raise RuntimeError
return -e.position
class L_BFGS(DescentMinimizer):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment