energy_adapter.py 1.36 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
from __future__ import absolute_import, division, print_function

from ..compat import *
from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
5
from ..minimization.energy import Energy
6
from ..operators.block_diagonal_operator import BlockDiagonalOperator
Philipp Arras's avatar
Philipp Arras committed
7
from ..operators.scaling_operator import ScalingOperator
Martin Reinecke's avatar
Martin Reinecke committed
8
9
10


class EnergyAdapter(Energy):
    """Energy obtained by evaluating an operator at a fixed position.

    On construction `op` is applied once to a `Linearization` built around
    `position`.  The value, gradient and metric of the result are cached on
    the instance and exposed through the `value`, `gradient` and
    `apply_metric` members.

    Parameters
    ----------
    position
        Point at which `op` is evaluated.  It is handed to
        `Energy.__init__` and read back as `self._position`.  When
        `constants` is non-empty it must provide `domain.items()` yielding
        `(key, domain)` pairs.
    op
        Callable applied to a `Linearization`.  Its result must provide
        `val`, `gradient` and `_metric`.
    constants
        Collection of keys of `position.domain` that receive a scaling
        factor of 0 instead of 1 in the operator handed to the
        `Linearization`.  Default: empty, i.e. no key is treated specially.
        The collection is stored by reference and never modified here.
    want_metric : bool
        Forwarded to the `Linearization`; also remembered so that energies
        created via `at` use the same setting.  Default: False.
    """

    def __init__(self, position, op, constants=[], want_metric=False):
        super(EnergyAdapter, self).__init__(position)
        self._op = op
        self._constants = constants
        # Remember the flag so that `at` can reproduce this configuration.
        self._want_metric = want_metric
        if len(self._constants) == 0:
            lin = Linearization.make_var(self._position, want_metric)
        else:
            # One scaling operator per sub-domain: factor 0 for keys listed
            # in `constants`, factor 1 for all others.
            # NOTE(review): presumably this makes derivatives with respect
            # to the constant keys vanish -- confirm against Linearization.
            ops = [ScalingOperator(0. if key in self._constants else 1., dom)
                   for key, dom in self._position.domain.items()]
            bdop = BlockDiagonalOperator(self._position.domain, tuple(ops))
            lin = Linearization(self._position, bdop,
                                want_metric=want_metric)
        tmp = self._op(lin)
        # Index with the empty tuple to pull the single entry out of the
        # local data (assumes the operator result is scalar-valued).
        self._val = tmp.val.local_data[()]
        self._grad = tmp.gradient
        self._metric = tmp._metric

    def at(self, position):
        """Return a new `EnergyAdapter` for the same operator at `position`.

        The operator, the constant keys and the `want_metric` setting of
        this instance are carried over unchanged.
        """
        return EnergyAdapter(position, self._op, self._constants,
                             self._want_metric)

    @property
    def value(self):
        """Value of the operator at this energy's position (cached)."""
        return self._val

    @property
    def gradient(self):
        """Gradient of the operator at this energy's position (cached)."""
        return self._grad

    def apply_metric(self, x):
        """Apply the cached metric to `x` and return the result.

        NOTE(review): the metric is whatever the linearized operator
        reported as `_metric`; whether it is usable when the instance was
        built with ``want_metric=False`` is not visible here -- confirm.
        """
        return self._metric(x)