energy_adapter.py 1.02 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
from __future__ import absolute_import, division, print_function

from ..compat import *
from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
5
from ..minimization.energy import Energy
Martin Reinecke's avatar
Martin Reinecke committed
6
7
8


class EnergyAdapter(Energy):
    """Expose an operator applied to a linearization as an `Energy`.

    On construction, `op` is applied once to the linearization returned by
    `Linearization.make_partial_var(position, constants, want_metric)`.
    The resulting value, gradient and metric are then cached.

    Parameters
    ----------
    position
        Point at which `op` is evaluated. It is passed to `Energy.__init__`
        and to `Linearization.make_partial_var`.
    op
        Callable applied to the linearization. Its result must provide
        `.val.local_data[()]` (read as the energy value), `.gradient` and
        `._metric`. Presumably `op` is a scalar-valued operator. TODO
        confirm against callers.
    constants : sequence, optional
        Forwarded unchanged to `Linearization.make_partial_var`. Presumably
        it names parts of `position` that are held constant; verify against
        `make_partial_var`. Defaults to an empty list. A fresh list is
        created per call rather than sharing one mutable default object.
    want_metric : bool, optional
        Forwarded to `Linearization.make_partial_var`. It controls whether
        a metric is requested from the linearization. Default: False.
    """

    def __init__(self, position, op, constants=None, want_metric=False):
        super(EnergyAdapter, self).__init__(position)
        # Avoid a shared mutable default: each adapter gets its own list.
        if constants is None:
            constants = []
        self._op = op
        self._constants = constants
        self._want_metric = want_metric
        lin = Linearization.make_partial_var(position, constants, want_metric)
        tmp = self._op(lin)
        # `[()]` extracts the single element of a 0-dimensional array, so the
        # cached value is a plain scalar rather than an array.
        self._val = tmp.val.local_data[()]
        self._grad = tmp.gradient
        self._metric = tmp._metric

    def at(self, position):
        """Return a new EnergyAdapter for the same operator at `position`.

        The operator, constants and `want_metric` flag of this instance
        are reused unchanged.
        """
        return EnergyAdapter(position, self._op, self._constants,
                             self._want_metric)

    @property
    def value(self):
        """Scalar energy value cached at construction time."""
        return self._val

    @property
    def gradient(self):
        """Gradient of the energy, taken from the operator's linearization."""
        return self._grad

    @property
    def metric(self):
        """Metric taken from the operator's linearization.

        This may be None. Presumably that happens when `want_metric` was
        False; confirm against `Linearization`.
        """
        return self._metric

    def apply_metric(self, x):
        """Apply the cached metric to `x` and return the result.

        Raises
        ------
        TypeError
            If no metric is available (the cached metric is None).
        """
        if self._metric is None:
            # Same exception type the bare call `None(x)` would raise, but
            # with an actionable message.
            raise TypeError("EnergyAdapter has no metric; was it constructed "
                            "with want_metric=True?")
        return self._metric(x)