Commit 648587ef authored by Martin Reinecke's avatar Martin Reinecke
Browse files

first try

parent e1843d4d
Pipeline #16876 canceled with stage
in 9 minutes and 56 seconds
......@@ -17,5 +17,6 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from energy import Energy
from quadratic_energy import QuadraticEnergy
from line_energy import LineEnergy
from memoization import memo
......@@ -17,6 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from nifty.nifty_meta import NiftyMeta
from nifty.energies.memoization import memo
from keepers import Loggable
......@@ -40,7 +41,7 @@ class Energy(Loggable, object):
value : np.float
The value of the energy functional at given `position`.
gradient : Field
The gradient at given `position` in parameter direction.
The gradient at given `position`.
curvature : LinearOperator, callable
A positive semi-definite operator or function describing the curvature
of the potential at the given `position`.
......@@ -109,12 +110,32 @@ class Energy(Loggable, object):
def gradient(self):
The gradient at given `position` in parameter direction.
The gradient at given `position`.
raise NotImplementedError
def gradient_norm(self):
The length of the gradient at given `position`.
return self.gradient.norm()
def gradient_infnorm(self):
The infinity norm of the gradient at given `position`.
return abs(self.gradient).max()
def curvature(self):
from import Energy
from nifty.energies.memoization import memo
class QuadraticEnergy(Energy):
"""The Energy for a quadratic form.
def __init__(self, position, A, b):
super(QuadraticEnergy, self).__init__(position=position)
self._A = A
self._b = b
def at(self, position):
return self.__class__(position=position, A=self._A, b=self._b)
def value(self):
return 0.5*self.position.vdot(self._Ax) - self._b.vdot(self.position)
def gradient(self):
return self._Ax - self._b
def curvature(self):
return self._A
def _Ax(self):
return self.curvature(self.position)
......@@ -90,60 +90,54 @@ class ConjugateGradient(Loggable, object):
reset_count = int(reset_count)
self.reset_count = reset_count
if preconditioner is None:
preconditioner = lambda z: z
self.preconditioner = preconditioner
self.callback = callback
def __call__(self, A, b, x0):
def __call__(self, E):
""" Runs the conjugate gradient minimization.
For `Ax = b` the variable `x` is infered.
A : Operator
Operator `A` applicable to a Field.
b : Field
Result of the operation `A(x)`.
x0 : Field
Starting guess for the minimization.
E : Energy object at the starting point of the iteration.
E's curvature operator must be independent of position, otherwise
linear conjugate gradient minimization will fail.
x : Field
Latest `x` of the minimization.
E : QuadraticEnergy at last point of the iteration
convergence : integer
Latest convergence level indicating whether the minimization
has converged or not.
r = b - A(x0)
d = self.preconditioner(r)
r = -E.gradient
if self.preconditioner is not None:
d = self.preconditioner(r)
d = r.copy()
previous_gamma = (r.vdot(d)).real
if previous_gamma == 0:"The starting guess is already perfect solution "
"for the inverse problem.")
return x0, self.convergence_level+1
norm_b = np.sqrt((b.vdot(b)).real)
x = x0.copy()
return E, self.convergence_level+1
convergence = 0
iteration_number = 1"Starting conjugate gradient.")
while True:
if self.callback is not None:
self.callback(x, iteration_number)
self.callback(E, iteration_number)
q = A(d)
alpha = previous_gamma/d.vdot(q).real
q = E.curvature(d)
alpha = previous_gamma/(d.vdot(q).real)
if not np.isfinite(alpha):
self.logger.error("Alpha became infinite! Stopping.")
return x0, 0
return E, 0
x += d * alpha
E =*alpha)
reset = False
if alpha < 0:
......@@ -153,20 +147,23 @@ class ConjugateGradient(Loggable, object):
reset += (iteration_number % self.reset_count == 0)
if reset:"Resetting conjugate directions.")
r = b - A(x)
r = -E.gradient
r -= q * alpha
s = self.preconditioner(r)
if self.preconditioner is not None:
s = self.preconditioner(r)
s = r.copy()
gamma = r.vdot(s).real
if gamma < 0:
self.logger.warn("Positive definitness of preconditioner "
self.logger.warn("Positive definiteness of preconditioner "
beta = max(0, gamma/previous_gamma)
delta = np.sqrt(gamma)/norm_b
delta = r.norm()
self.logger.debug("Iteration : %08u alpha = %3.1E "
"beta = %3.1E delta = %3.1E" %
......@@ -196,4 +193,4 @@ class ConjugateGradient(Loggable, object):
iteration_number += 1
previous_gamma = gamma
return x, convergence
return E, convergence
import unittest
import numpy as np
from numpy.testing import assert_equal, assert_almost_equal
from numpy.testing import assert_equal, assert_allclose
from nifty import Field, DiagonalOperator, RGSpace, HPSpace
from nifty import ConjugateGradient
from nifty import ConjugateGradient, QuadraticEnergy
from test.common import expand
......@@ -38,10 +38,11 @@ class Test_ConjugateGradient(unittest.TestCase):
required_result = Field(space, val=1.)
minimizer = ConjugateGradient()
energy = QuadraticEnergy(A=covariance, b=required_result,
(position, convergence) = minimizer(A=covariance, x0=starting_point,
(energy, convergence) = minimizer(energy)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment