point_kl.py 3.17 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

Martin Reinecke's avatar
Martin Reinecke committed
20
from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
21
from ..minimization.energy import Energy
22
from ..sugar import makeDomain
Martin Reinecke's avatar
Martin Reinecke committed
23
24


25
class PointKL(Energy):
    """Helper class which provides the traditional NIFTy Energy interface to
    NIFTy operators with a scalar target domain.

    The operator is evaluated once, at construction time, on a
    `Linearization` of `position`; the resulting value, gradient and metric
    are stored and returned by the corresponding properties.

    Parameters
    ----------
    position : Field or MultiField
        The position where the minimization process is started.
    op : EnergyOperator
        The expression computing the energy from the input data.
    constants : list of strings
        The component names of the operator's input domain which are assumed
        to be constant during the minimization process.
        If the operator's input domain is not a MultiField, this must be empty.
        Default: empty (no constant components).
    want_metric : bool
        If True, the class will provide a `metric` property. This should only
        be enabled if it is required, because it will most likely consume
        additional resources. Default: False.
    nanisinf : bool
        If True, NaN energies, which can happen due to overflows in the
        forward model, are interpreted as inf. Thereby, the code does not
        crash on these occasions but rather the minimizer is told that the
        position it has tried is not sensible. Default: False.
    """

    # An immutable default avoids the shared-mutable-default pitfall; the
    # argument is only inspected via len()/set()/extract_by_keys(), so any
    # sequence of key names works.
    def __init__(self, position, op, constants=(), want_metric=False,
                 nanisinf=False):
        if len(constants) > 0:
            # Fold the constant components into the operator once, and keep
            # only the remaining (variable) components in the position that
            # the minimizer works on.
            cstpos = position.extract_by_keys(constants)
            _, op = op.simplify_for_constant_input(cstpos)
            varkeys = set(op.domain.keys()) - set(constants)
            position = position.extract_by_keys(varkeys)
        super().__init__(position)
        self._op = op
        self._want_metric = want_metric
        lin = Linearization.make_var(position, want_metric)
        tmp = self._op(lin)
        # `[()]` extracts the single scalar energy from the operator's value
        # (the operator is required to have a scalar target domain).
        self._val = tmp.val.val[()]
        self._grad = tmp.gradient
        self._metric = tmp._metric
        self._nanisinf = bool(nanisinf)
        if self._nanisinf and np.isnan(self._val):
            self._val = np.inf

    def at(self, position):
        """Return a new PointKL for the same operator at `position`.

        `self._op` has already been simplified for any constants given at
        construction, so `constants` is deliberately not passed again.
        NOTE(review): `position` is presumably expected to contain only the
        variable components - confirm against callers.
        """
        return PointKL(position, self._op, want_metric=self._want_metric,
                       nanisinf=self._nanisinf)

    @property
    def value(self):
        """The energy value at the current position (inf instead of NaN
        when `nanisinf` was requested)."""
        return self._val

    @property
    def gradient(self):
        """The gradient of the energy at the current position."""
        return self._grad

    @property
    def metric(self):
        """The metric at the current position, as provided by the operator's
        linearization. Only meaningful when constructed with
        `want_metric=True`."""
        return self._metric

    def apply_metric(self, x):
        """Apply the metric at the current position to `x`.

        Raises
        ------
        TypeError
            If no metric is available (e.g. `want_metric` was False).
        """
        if self._metric is None:
            # Same exception type as calling None, but with an actionable
            # message instead of "'NoneType' object is not callable".
            raise TypeError("PointKL: no metric available; construct it "
                            "with want_metric=True")
        return self._metric(x)