quadratic_energy.py 1.05 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
from .energy import Energy
Martin Reinecke's avatar
Martin Reinecke committed
2
3
4
5


class QuadraticEnergy(Energy):
    """The Energy for a quadratic form.

    Represents

        E(x) = 0.5 * x.vdot(A(x)) - b.vdot(x)

    with gradient ``A(x) - b`` and curvature ``A``.

    The most important aspect of this energy is that its curvature must be
    position-independent.

    Parameters
    ----------
    position : field-like
        Point at which the energy is evaluated; must support ``vdot``.
    A : linear operator
        Curvature of the quadratic form, applied as ``A(x)``. ``A`` should be
        self-adjoint: otherwise the gradient of ``0.5 * <x, A x>`` is
        ``0.5 * (A + A^dagger) x``, and ``gradient`` (``A x - b``) would not
        be the true gradient of ``value``.
    b : field-like or None
        Linear term of the energy. ``None`` is treated as zero, so the
        energy reduces to ``0.5 * x.vdot(A(x))``.
    _grad : field-like, optional
        Private shortcut: the already-known gradient ``A(position) - b`` at
        ``position``. When supplied, ``A`` is not applied again. The value is
        trusted and not checked.
    """

    def __init__(self, position, A, b, _grad=None):
        super(QuadraticEnergy, self).__init__(position=position)
        self._A = A
        self._b = b
        if _grad is not None:
            self._grad = _grad
            # Recover A(x) from the supplied gradient (grad = Ax - b) instead
            # of paying for another application of A.
            Ax = _grad if b is None else _grad + b
        else:
            Ax = self._A(self.position)
            self._grad = Ax if b is None else Ax - b
        value = 0.5*self.position.vdot(Ax)
        if b is not None:
            value = value - b.vdot(self.position)
        self._value = value

    def at(self, position):
        """Return a new QuadraticEnergy with the same ``A`` and ``b``,
        evaluated at ``position``."""
        return QuadraticEnergy(position=position, A=self._A, b=self._b)

    def at_with_grad(self, position, grad):
        """Return a new QuadraticEnergy at ``position``, reusing ``grad``.

        ``grad`` must equal ``A(position) - b`` (or ``A(position)`` when
        ``b`` is None). It is passed on unchecked, which saves one
        application of ``A``.
        """
        return QuadraticEnergy(position=position, A=self._A, b=self._b,
                               _grad=grad)

    @property
    def value(self):
        """The energy ``0.5 * x.vdot(A(x)) - b.vdot(x)`` at the position."""
        return self._value

    @property
    def gradient(self):
        """The gradient ``A(x) - b`` at the position."""
        return self._grad

    @property
    def curvature(self):
        """The (position-independent) curvature operator ``A``."""
        return self._A