energy_adapter.py 1.26 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
from __future__ import absolute_import, division, print_function

from ..compat import *
from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
5
from ..minimization.energy import Energy
6
from ..operators.block_diagonal_operator import BlockDiagonalOperator
Philipp Arras's avatar
Philipp Arras committed
7
from ..operators.scaling_operator import ScalingOperator
Martin Reinecke's avatar
Martin Reinecke committed
8
9
10


class EnergyAdapter(Energy):
    """Energy obtained by evaluating an operator on a linearization.

    `op` is applied to a Linearization of `position`. The value, gradient
    and metric of the result are computed once in the constructor and then
    served through the Energy interface.

    Parameters
    ----------
    position : object
        Point at which the energy is evaluated; handed to
        ``Energy.__init__``. When `constants` is non-empty,
        ``position.domain`` must support ``.items()`` yielding
        ``(key, domain)`` pairs (i.e. be a multi-domain).
    op : callable
        Called with a Linearization of `position`. Its result must expose
        ``val``, ``gradient`` and ``_metric``. Only
        ``val.local_data[()]`` is kept as the energy value, so `op` is
        presumably scalar-valued -- confirm against callers.
    constants : sequence of keys, optional
        Keys of ``position.domain`` whose Jacobian block is scaled by zero,
        so those entries are held fixed with respect to differentiation.
        ``None`` (the default) means no constants.
    """

    def __init__(self, position, op, constants=None):
        super(EnergyAdapter, self).__init__(position)
        self._op = op
        # A ``None`` default instead of a mutable ``[]`` default. Otherwise
        # one list object would be shared by every instance constructed
        # without `constants`, and propagated by `at()`.
        self._constants = [] if constants is None else constants
        if len(self._constants) == 0:
            lin = Linearization.make_var(self._position)
        else:
            # Block-diagonal Jacobian: identity scaling (1.) for free
            # entries, zero scaling (0.) for entries listed in `constants`,
            # so no derivative flows back into the constant entries.
            ops = [ScalingOperator(0. if key in self._constants else 1., dom)
                   for key, dom in self._position.domain.items()]
            bdop = BlockDiagonalOperator(self._position.domain, tuple(ops))
            lin = Linearization(self._position, bdop)
        tmp = self._op(lin)
        # ``[()]`` extracts the element from the (assumed 0-d) local data.
        self._val = tmp.val.local_data[()]
        self._grad = tmp.gradient
        self._metric = tmp._metric

    def at(self, position):
        """Return a new EnergyAdapter for `position` with the same `op`
        and the same constants."""
        return EnergyAdapter(position, self._op, self._constants)

    @property
    def value(self):
        """The energy value computed by `op` at the current position."""
        return self._val

    @property
    def gradient(self):
        """The gradient taken from the linearization of `op`'s result."""
        return self._grad

    def apply_metric(self, x):
        """Apply the metric taken from `op`'s linearization to `x`.

        NOTE(review): if the linearization carried no metric (``_metric``
        is None), this call fails with a TypeError -- confirm whether `op`
        always provides one.
        """
        return self._metric(x)