Commit bf190bbd authored by Philipp Arras's avatar Philipp Arras
Browse files

Unify notation in conjugate gradient and NewtonCG

parent b2fa7769
Pipeline #52268 passed with stages
in 8 minutes and 32 seconds
......@@ -74,27 +74,27 @@ class ConjugateGradient(Minimizer):
if previous_gamma == 0:
return energy, controller.CONVERGED
iter = 0
ii = 0
while True:
q = energy.apply_metric(d)
ddotq = d.vdot(q).real
if ddotq == 0.:
logger.error("Error: ConjugateGradient: ddotq==0.")
curv = d.vdot(q).real
if curv == 0.:
logger.error("Error: ConjugateGradient: curv==0.")
return energy, controller.ERROR
alpha = previous_gamma/ddotq
alpha = previous_gamma/curv
if alpha < 0:
logger.error("Error: ConjugateGradient: alpha<0.")
return energy, controller.ERROR
iter += 1
if iter < self._nreset:
ii += 1
if ii < self._nreset:
r = r - q*alpha
energy = energy.at_with_grad(energy.position - alpha*d, r)
else:
energy = energy.at(energy.position - alpha*d)
r = energy.gradient
iter = 0
ii = 0
s = r if preconditioner is None else preconditioner(r)
......
......@@ -162,32 +162,29 @@ class NewtonCG(DescentMinimizer):
def get_descent_direction(self, energy):
float64eps = np.finfo(np.float64).eps
grad = energy.gradient
maggrad = abs(grad).sum()
r = energy.gradient
maggrad = abs(r).sum()
termcond = np.min([0.5, np.sqrt(maggrad)]) * maggrad
xsupi = energy.position*0
ri = grad
psupi = -ri
dri0 = ri.vdot(ri)
i = 0
pos = energy.position*0
d = -r
previous_gamma = r.vdot(r)
ii = 0
while True:
if abs(ri).sum() <= termcond:
return xsupi
Ap = energy.apply_metric(psupi)
# check curvature
curv = psupi.vdot(Ap)
if abs(r).sum() <= termcond:
return pos
q = energy.apply_metric(d)
curv = d.vdot(q)
if 0 <= curv <= 3*float64eps:
return xsupi
elif curv < 0:
return xsupi if i > 0 else (dri0/curv) * grad
alphai = dri0/curv
xsupi = xsupi + alphai*psupi
ri = ri + alphai*Ap
dri1 = ri.vdot(ri)
psupi = (dri1/dri0)*psupi - ri
i += 1
dri0 = dri1 # update numpy.dot(ri,ri) for next time.
return pos
if curv < 0:
return pos if ii > 0 else previous_gamma/curv*r
alpha = previous_gamma/curv
pos = pos + alpha*d
r = r + alpha*q
gamma = r.vdot(r)
d = (gamma/previous_gamma)*d - r
ii += 1
previous_gamma = gamma
# curvature keeps increasing, bail out
raise ValueError("Warning: CG iterations didn't converge. "
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment