energy_adapter.py 1.51 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
from __future__ import absolute_import, division, print_function

from ..compat import *
from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
5
from ..minimization.energy import Energy
6
from ..operators.block_diagonal_operator import BlockDiagonalOperator
Philipp Arras's avatar
Philipp Arras committed
7
from ..operators.scaling_operator import ScalingOperator
Martin Reinecke's avatar
Martin Reinecke committed
8
9
10


class EnergyAdapter(Energy):
    """Energy whose value, gradient and metric come from applying an operator.

    On construction, `op` is applied to a `Linearization` built at `position`
    and the resulting value, gradient and metric are stored, so the
    properties below are plain attribute reads.

    Parameters
    ----------
    position
        Location at which `op` is evaluated. It is handed to
        `Energy.__init__` and to `Linearization`. When `constants` is
        non-empty, `position.domain.items()` must yield `(key, domain)`
        pairs.
    op : callable
        Called with a `Linearization`; the result must provide `val`,
        `gradient` and `_metric`.
    constants : collection of keys, optional
        Keys of `position.domain` whose Jacobian block is replaced by a
        zero-scaling operator. Defaults to no constants.
    want_metric : bool, optional
        Forwarded to the `Linearization`. Default: False.
    """

    def __init__(self, position, op, constants=None, want_metric=False):
        super(EnergyAdapter, self).__init__(position)
        self._op = op
        # A fresh list per instance: a mutable default in the signature would
        # be one shared object stored on (and passed along by) every instance.
        self._constants = [] if constants is None else constants
        self._want_metric = want_metric

        if len(self._constants) == 0:
            # No constants: linearize with respect to the full position.
            lin = Linearization.make_var(self._position, want_metric)
        else:
            # Build a block-diagonal Jacobian that scales the blocks of the
            # constant keys by 0 and all other blocks by 1.
            ops = [ScalingOperator(0. if key in self._constants else 1., dom)
                   for key, dom in self._position.domain.items()]
            bdop = BlockDiagonalOperator(self._position.domain, tuple(ops))
            lin = Linearization(self._position, bdop,
                                want_metric=want_metric)

        tmp = self._op(lin)
        # Indexing with `()` extracts the single entry of `local_data`
        # (presumably a 0-d array -- confirm against the Field type).
        self._val = tmp.val.local_data[()]
        self._grad = tmp.gradient
        self._metric = tmp._metric

    def at(self, position):
        """Return a new `EnergyAdapter` at `position` with the same settings.

        The operator, the constants and the `want_metric` flag are reused.
        """
        return EnergyAdapter(position, self._op, self._constants,
                             self._want_metric)

    @property
    def value(self):
        """Value of the operator result, stored at construction."""
        return self._val

    @property
    def gradient(self):
        """Gradient of the operator result, stored at construction."""
        return self._grad

    @property
    def metric(self):
        """Metric taken from the operator result at construction."""
        return self._metric

    def apply_metric(self, x):
        """Return the stored metric applied to `x`."""
        return self._metric(x)