energy_adapter.py 2.54 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
from __future__ import absolute_import, division, print_function

from ..compat import *
from ..minimization.energy import Energy
from ..linearization import Linearization
Martin Reinecke's avatar
Martin Reinecke committed
6
from ..multi_field import MultiField
Martin Reinecke's avatar
Martin Reinecke committed
7
import numpy as np
Martin Reinecke's avatar
Martin Reinecke committed
8
9
10


class EnergyAdapter(Energy):
    """Expose an operator as an :class:`Energy` evaluated at `position`.

    The scalar output of ``op(position)`` is the energy value; the gradient
    and metric are taken from the Linearization obtained by applying ``op``
    to a linearization of the position.

    Parameters
    ----------
    position : Field or MultiField
        Point at which the energy is evaluated. Must provide ``items()``
        (i.e. be a MultiField) when ``constants`` is non-empty.
    op : Operator
        Operator mapping ``position`` to a scalar. It is applied to the plain
        position (for :attr:`value`) and to a Linearization (for
        :attr:`gradient` / :attr:`metric`).
    controller : optional
        If not None, the metric is wrapped in an ``InversionEnabler`` using
        this controller (presumably an iteration controller -- confirm).
    preconditioner : None, LinearOperator or Energy, optional
        Preconditioner handed to the ``InversionEnabler``. An Energy is
        evaluated at the current position and its metric is used. Only
        consulted when ``controller`` is not None.
    constants : iterable of keys, optional
        Keys of the position that are treated as constants, i.e. no
        derivative is taken with respect to them. Default: no constants.
    """

    def __init__(self, position, op, controller=None, preconditioner=None,
                 constants=()):
        # NOTE: the default is an immutable tuple instead of a shared list;
        # the value is only read (membership tests), so callers passing a
        # list keep working unchanged.
        super(EnergyAdapter, self).__init__(position)
        self._op = op
        # Lazily computed caches, filled by value / _fill_all.
        self._val = self._grad = self._metric = None
        self._controller = controller
        self._preconditioner = preconditioner
        self._constants = constants

    def at(self, position):
        """Return a new EnergyAdapter with identical settings at `position`."""
        return EnergyAdapter(position, self._op, self._controller,
                             self._preconditioner, self._constants)

    def _get_preconditioner(self):
        """Resolve the configured preconditioner to an operator (or None).

        Raises
        ------
        TypeError
            If the preconditioner is neither None, a LinearOperator nor an
            Energy. (Previously this case surfaced as an opaque
            UnboundLocalError.)
        """
        from ..operators.linear_operator import LinearOperator

        if self._preconditioner is None:
            return None
        if isinstance(self._preconditioner, LinearOperator):
            return self._preconditioner
        if isinstance(self._preconditioner, Energy):
            # Use the metric of the preconditioning energy at our position.
            return self._preconditioner.at(self._position).metric
        raise TypeError(
            "preconditioner must be None, a LinearOperator or an Energy, "
            "got {}".format(type(self._preconditioner).__name__))

    def _fill_all(self):
        """Evaluate `op` on a Linearization and cache value, gradient, metric."""
        if not self._constants:
            tmp = self._op(Linearization.make_var(self._position))
        else:
            # Split the position into the part held constant (no derivative
            # tracked) and the variable part, then recombine them so `op`
            # only differentiates with respect to the variable keys.
            ctmp = MultiField.from_dict({key: val
                                        for key, val in self._position.items()
                                        if key in self._constants})
            vtmp = MultiField.from_dict({key: val
                                        for key, val in self._position.items()
                                        if key not in self._constants})
            lin = Linearization.make_var(vtmp) + Linearization.make_const(ctmp)
            tmp = self._op(lin)
        self._val = tmp.val.local_data[()]
        self._grad = tmp.gradient

        if self._controller is None:
            self._metric = tmp._metric
        else:
            from ..operators.inversion_enabler import InversionEnabler

            # Wrap the metric so it can be inverted with the given controller.
            self._metric = InversionEnabler(tmp._metric, self._controller,
                                            self._get_preconditioner())

    @property
    def value(self):
        """Scalar energy value; evaluates `op` on the plain position once."""
        if self._val is None:
            self._val = self._op(self._position).local_data[()]
        return self._val

    @property
    def gradient(self):
        """Gradient of the energy (w.r.t. non-constant keys); cached."""
        if self._grad is None:
            self._fill_all()
        return self._grad

    @property
    def metric(self):
        """Metric of the energy, wrapped in an InversionEnabler if a
        controller was given; cached."""
        if self._metric is None:
            self._fill_all()
        return self._metric