Commit bf190bbd authored by Philipp Arras's avatar Philipp Arras
Browse files

Unify notation in conjugate gradient and NewtonCG

parent b2fa7769
Pipeline #52268 passed with stages
in 8 minutes and 32 seconds
...@@ -74,27 +74,27 @@ class ConjugateGradient(Minimizer): ...@@ -74,27 +74,27 @@ class ConjugateGradient(Minimizer):
if previous_gamma == 0: if previous_gamma == 0:
return energy, controller.CONVERGED return energy, controller.CONVERGED
iter = 0 ii = 0
while True: while True:
q = energy.apply_metric(d) q = energy.apply_metric(d)
ddotq = d.vdot(q).real curv = d.vdot(q).real
if ddotq == 0.: if curv == 0.:
logger.error("Error: ConjugateGradient: ddotq==0.") logger.error("Error: ConjugateGradient: curv==0.")
return energy, controller.ERROR return energy, controller.ERROR
alpha = previous_gamma/ddotq alpha = previous_gamma/curv
if alpha < 0: if alpha < 0:
logger.error("Error: ConjugateGradient: alpha<0.") logger.error("Error: ConjugateGradient: alpha<0.")
return energy, controller.ERROR return energy, controller.ERROR
iter += 1 ii += 1
if iter < self._nreset: if ii < self._nreset:
r = r - q*alpha r = r - q*alpha
energy = energy.at_with_grad(energy.position - alpha*d, r) energy = energy.at_with_grad(energy.position - alpha*d, r)
else: else:
energy = energy.at(energy.position - alpha*d) energy = energy.at(energy.position - alpha*d)
r = energy.gradient r = energy.gradient
iter = 0 ii = 0
s = r if preconditioner is None else preconditioner(r) s = r if preconditioner is None else preconditioner(r)
......
...@@ -162,32 +162,29 @@ class NewtonCG(DescentMinimizer): ...@@ -162,32 +162,29 @@ class NewtonCG(DescentMinimizer):
def get_descent_direction(self, energy): def get_descent_direction(self, energy):
float64eps = np.finfo(np.float64).eps float64eps = np.finfo(np.float64).eps
grad = energy.gradient r = energy.gradient
maggrad = abs(grad).sum() maggrad = abs(r).sum()
termcond = np.min([0.5, np.sqrt(maggrad)]) * maggrad termcond = np.min([0.5, np.sqrt(maggrad)]) * maggrad
xsupi = energy.position*0 pos = energy.position*0
ri = grad d = -r
psupi = -ri previous_gamma = r.vdot(r)
dri0 = ri.vdot(ri) ii = 0
i = 0
while True: while True:
if abs(ri).sum() <= termcond: if abs(r).sum() <= termcond:
return xsupi return pos
Ap = energy.apply_metric(psupi) q = energy.apply_metric(d)
# check curvature curv = d.vdot(q)
curv = psupi.vdot(Ap)
if 0 <= curv <= 3*float64eps: if 0 <= curv <= 3*float64eps:
return xsupi return pos
elif curv < 0: if curv < 0:
return xsupi if i > 0 else (dri0/curv) * grad return pos if ii > 0 else previous_gamma/curv*r
alphai = dri0/curv alpha = previous_gamma/curv
xsupi = xsupi + alphai*psupi pos = pos + alpha*d
ri = ri + alphai*Ap r = r + alpha*q
dri1 = ri.vdot(ri) gamma = r.vdot(r)
psupi = (dri1/dri0)*psupi - ri d = (gamma/previous_gamma)*d - r
i += 1 ii += 1
dri0 = dri1 # update numpy.dot(ri,ri) for next time. previous_gamma = gamma
# curvature keeps increasing, bail out # curvature keeps increasing, bail out
raise ValueError("Warning: CG iterations didn't converge. " raise ValueError("Warning: CG iterations didn't converge. "
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment