quadratic_energy.py 1.27 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
from .energy import Energy
from .memoization import memo
Martin Reinecke's avatar
Martin Reinecke committed
3
4
5
6


class QuadraticEnergy(Energy):
    """Energy of a quadratic form.

    The value at position ``x`` is ``0.5 * x.vdot(A(x)) - b.vdot(x)`` and the
    gradient is ``A(x) - b``. The curvature is simply the operator ``A``, so
    this energy is only meaningful when the curvature does not depend on the
    position.

    Parameters
    ----------
    position : field-like
        Point at which the energy is evaluated.
    A : callable operator
        The (position-independent) operator of the quadratic form.
    b : field-like
        The linear term of the quadratic form.
    _grad : field-like or None
        Internal: a gradient already known at ``position``. When given,
        ``A(position)`` is recovered as ``_grad + b`` instead of applying ``A``.
    _bnorm : scalar or None
        Internal: a precomputed ``b.norm()``, forwarded between instances so
        that it is evaluated at most once.
    """

    def __init__(self, position, A, b, _grad=None, _bnorm=None):
        super(QuadraticEnergy, self).__init__(position=position)
        self._A, self._b = A, b
        self._bnorm = _bnorm
        # Since gradient == A(x) - b, a caller-supplied gradient lets us
        # reconstruct A(x) cheaply and skip an operator application.
        self._Ax = (self._A(self.position) if _grad is None
                    else _grad + self._b)

    def at(self, position):
        """Return a new energy with the same ``A`` and ``b`` at ``position``."""
        known_bnorm = self.norm_b
        return self.__class__(position=position, A=self._A, b=self._b,
                              _bnorm=known_bnorm)

    def at_with_grad(self, position, grad):
        """Like ``at``, but reuse ``grad`` (the gradient at ``position``)."""
        known_bnorm = self.norm_b
        return self.__class__(position=position, A=self._A, b=self._b,
                              _grad=grad, _bnorm=known_bnorm)

    @property
    @memo
    def value(self):
        """``0.5 * x.vdot(A(x)) - b.vdot(x)`` at the current position ``x``."""
        x = self.position
        return 0.5*x.vdot(self._Ax) - self._b.vdot(x)

    @property
    @memo
    def gradient(self):
        """``A(x) - b`` at the current position ``x``."""
        return self._Ax - self._b

    @property
    def curvature(self):
        """The operator ``A``; identical at every position."""
        return self._A

    @property
    def norm_b(self):
        """``b.norm()``, computed lazily on first access and then cached."""
        cached = self._bnorm
        if cached is None:
            cached = self._bnorm = self._b.norm()
        return cached