Commit f5d453d3 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'NIFTy_4' into yango_minimizer

parents 63eeb81b cb13f9db
......@@ -52,38 +52,51 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller):
D_inv(x) = j
and the second entry are a list of samples from D_inv.inverse
"""
# MR FIXME: this should be synchronized with the "official" Nifty CG
# RL FIXME: make consistent with complex numbers
j = S.draw_sample(from_inverse=True) if j is None else j
x = j*0.
energy = QuadraticEnergy(x, D_inv, j)
energy = QuadraticEnergy(j*0., D_inv, j)
y = [S.draw_sample() for _ in range(N_samps)]
status = controller.start(energy)
if status != controller.CONTINUE:
return x, y
return energy.position, y
r = energy.gradient
d = r.copy()
previous_gamma = r.vdot(r).real
if previous_gamma == 0:
return energy.position, y
r = j.copy()
p = r.copy()
Dip = D_inv(p)
d = p.vdot(Dip)
while True:
gamma = r.vdot(r) / d
if gamma == 0.:
break
x = x + gamma*p
q = energy.curvature(d)
ddotq = d.vdot(q).real
if ddotq == 0.:
logger.error("Error: ConjugateGradient: ddotq==0.")
return energy.position, y
alpha = previous_gamma/ddotq
if alpha < 0:
logger.error("Error: ConjugateGradient: alpha<0.")
return energy.position, y
for samp in y:
samp += (np.random.randn() * np.sqrt(d) - samp.vdot(Dip)) / d * p
energy = energy.at(x)
samp += (np.random.randn()*np.sqrt(ddotq) - samp.vdot(q))/ddotq * d
q *= -alpha
r = r + q
energy = energy.at_with_grad(energy.position - alpha*d, r)
gamma = r.vdot(r).real
if gamma == 0:
return energy.position, y
status = controller.check(energy)
if status != controller.CONTINUE:
return x, y
r_new = r - gamma * Dip
beta = r_new.vdot(r_new) / r.vdot(r)
r = r_new
p = r + beta * p
Dip = D_inv(p)
d = p.vdot(Dip)
if d == 0.:
break
return x, y
return energy.position, y
d *= max(0, gamma/previous_gamma)
d += r
previous_gamma = gamma
......@@ -68,7 +68,7 @@ class ConjugateGradient(Minimizer):
r = energy.gradient
d = r.copy() if preconditioner is None else preconditioner(r)
previous_gamma = (r.vdot(d)).real
previous_gamma = r.vdot(d).real
if previous_gamma == 0:
return energy, controller.CONVERGED
......@@ -99,7 +99,7 @@ class ConjugateGradient(Minimizer):
if gamma == 0:
return energy, controller.CONVERGED
status = self._controller.check(energy)
status = controller.check(energy)
if status != controller.CONTINUE:
return energy, status
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment