Skip to content
Snippets Groups Projects
Commit 3c5a2a01 authored by Philipp Arras's avatar Philipp Arras
Browse files

Rewrite NewtonCG

parent abe02b37
No related branches found
No related tags found
2 merge requests!349Fix mpi,!333Operator spectra
Pipeline #52374 passed
...@@ -18,8 +18,11 @@ ...@@ -18,8 +18,11 @@
import numpy as np import numpy as np
from ..logger import logger from ..logger import logger
from .conjugate_gradient import ConjugateGradient
from .iteration_controllers import GradientNormController
from .line_search import LineSearch from .line_search import LineSearch
from .minimizer import Minimizer from .minimizer import Minimizer
from .quadratic_energy import QuadraticEnergy
class DescentMinimizer(Minimizer): class DescentMinimizer(Minimizer):
...@@ -154,49 +157,22 @@ class NewtonCG(DescentMinimizer): ...@@ -154,49 +157,22 @@ class NewtonCG(DescentMinimizer):
Algorithm derived from SciPy sources. Algorithm derived from SciPy sources.
""" """
def __init__(self, controller, napprox=0, line_searcher=None): def __init__(self, controller, line_searcher=None):
if line_searcher is None: if line_searcher is None:
line_searcher = LineSearch(preferred_initial_step_size=1.) line_searcher = LineSearch(preferred_initial_step_size=1.)
super(NewtonCG, self).__init__(controller=controller, super(NewtonCG, self).__init__(controller=controller,
line_searcher=line_searcher) line_searcher=line_searcher)
self._napprox = int(napprox)
def get_descent_direction(self, energy): def get_descent_direction(self, energy):
# if self._napprox > 1: g = energy.gradient
# from ..probing import approximation2endo maggrad = abs(g).sum()
# sqdiag = approximation2endo(energy.metric, self._napprox).sqrt()
float64eps = np.finfo(np.float64).eps
r = energy.gradient
maggrad = abs(r).sum()
termcond = np.min([0.5, np.sqrt(maggrad)]) * maggrad termcond = np.min([0.5, np.sqrt(maggrad)]) * maggrad
pos = energy.position*0 ic = GradientNormController(tol_abs_gradnorm=termcond, p=1)
d = r e = QuadraticEnergy(0*energy.position, energy.metric, g)
previous_gamma = r.vdot(d) e, conv = ConjugateGradient(ic, nreset=np.inf)(e)
ii = 0 if conv == ic.ERROR:
while True: raise RuntimeError
if not ii % 10 and ii > 0: return -e.position
print(ii)
if abs(r).sum() <= termcond:
return pos
q = energy.apply_metric(d)
curv = d.vdot(q)
if 0 <= curv <= 3*float64eps:
return pos
if curv < 0:
return pos if ii > 0 else previous_gamma/curv*r
ii += 1
alpha = previous_gamma/curv
pos = pos - alpha*d
r = r - alpha*q
s = r
gamma = r.vdot(s)
d = d*(gamma/previous_gamma)+r
previous_gamma = gamma
# curvature keeps increasing, bail out
raise ValueError("Warning: CG iterations didn't converge. "
"The Hessian is not positive definite.")
class L_BFGS(DescentMinimizer): class L_BFGS(DescentMinimizer):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment