Commit f5d453d3 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'NIFTy_4' into yango_minimizer

parents 63eeb81b cb13f9db
...@@ -52,38 +52,51 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller): ...@@ -52,38 +52,51 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller):
D_inv(x) = j D_inv(x) = j
and the second entry are a list of samples from D_inv.inverse and the second entry are a list of samples from D_inv.inverse
""" """
# MR FIXME: this should be synchronized with the "official" Nifty CG
# RL FIXME: make consistent with complex numbers # RL FIXME: make consistent with complex numbers
j = S.draw_sample(from_inverse=True) if j is None else j j = S.draw_sample(from_inverse=True) if j is None else j
x = j*0. energy = QuadraticEnergy(j*0., D_inv, j)
energy = QuadraticEnergy(x, D_inv, j)
y = [S.draw_sample() for _ in range(N_samps)] y = [S.draw_sample() for _ in range(N_samps)]
status = controller.start(energy) status = controller.start(energy)
if status != controller.CONTINUE: if status != controller.CONTINUE:
return x, y return energy.position, y
r = energy.gradient
d = r.copy()
previous_gamma = r.vdot(r).real
if previous_gamma == 0:
return energy.position, y
r = j.copy()
p = r.copy()
Dip = D_inv(p)
d = p.vdot(Dip)
while True: while True:
gamma = r.vdot(r) / d q = energy.curvature(d)
if gamma == 0.: ddotq = d.vdot(q).real
break if ddotq == 0.:
x = x + gamma*p logger.error("Error: ConjugateGradient: ddotq==0.")
return energy.position, y
alpha = previous_gamma/ddotq
if alpha < 0:
logger.error("Error: ConjugateGradient: alpha<0.")
return energy.position, y
for samp in y: for samp in y:
samp += (np.random.randn() * np.sqrt(d) - samp.vdot(Dip)) / d * p samp += (np.random.randn()*np.sqrt(ddotq) - samp.vdot(q))/ddotq * d
energy = energy.at(x)
q *= -alpha
r = r + q
energy = energy.at_with_grad(energy.position - alpha*d, r)
gamma = r.vdot(r).real
if gamma == 0:
return energy.position, y
status = controller.check(energy) status = controller.check(energy)
if status != controller.CONTINUE: if status != controller.CONTINUE:
return x, y return energy.position, y
r_new = r - gamma * Dip
beta = r_new.vdot(r_new) / r.vdot(r) d *= max(0, gamma/previous_gamma)
r = r_new d += r
p = r + beta * p
Dip = D_inv(p) previous_gamma = gamma
d = p.vdot(Dip)
if d == 0.:
break
return x, y
...@@ -68,7 +68,7 @@ class ConjugateGradient(Minimizer): ...@@ -68,7 +68,7 @@ class ConjugateGradient(Minimizer):
r = energy.gradient r = energy.gradient
d = r.copy() if preconditioner is None else preconditioner(r) d = r.copy() if preconditioner is None else preconditioner(r)
previous_gamma = (r.vdot(d)).real previous_gamma = r.vdot(d).real
if previous_gamma == 0: if previous_gamma == 0:
return energy, controller.CONVERGED return energy, controller.CONVERGED
...@@ -99,7 +99,7 @@ class ConjugateGradient(Minimizer): ...@@ -99,7 +99,7 @@ class ConjugateGradient(Minimizer):
if gamma == 0: if gamma == 0:
return energy, controller.CONVERGED return energy, controller.CONVERGED
status = self._controller.check(energy) status = controller.check(energy)
if status != controller.CONTINUE: if status != controller.CONTINUE:
return energy, status return energy, status
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment