There will be maintenance of MPCDF GitLab on Thursday, April 22nd 2020, 9:00 am CEST. Expect some service interruptions during this time.

energy_adapter.py 3.19 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2020 Max-Planck-Society
15 16
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
Martin Reinecke's avatar
Martin Reinecke committed
17

18 19
import numpy as np

Martin Reinecke's avatar
Martin Reinecke committed
20
from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
21
from ..minimization.energy import Energy
22
from ..sugar import makeDomain
Martin Reinecke's avatar
Martin Reinecke committed
23 24 25


class EnergyAdapter(Energy):
    """Helper class which provides the traditional Nifty Energy interface to
    Nifty operators with a scalar target domain.

    Parameters
    ----------
    position: Field or MultiField
        The position where the minimization process is started.
    op: EnergyOperator
        The expression computing the energy from the input data.
    constants: list of strings
        The component names of the operator's input domain which are assumed
        to be constant during the minimization process.
        If the operator's input domain is not a MultiField, this must be empty.
        Default: () (no constants).
    want_metric: bool
        If True, the class will provide a `metric` property. This should only
        be enabled if it is required, because it will most likely consume
        additional resources. Default: False.
    nanisinf : bool
        If true, nan energies which can happen due to overflows in the forward
        model are interpreted as inf. Thereby, the code does not crash on
        these occasions but rather the minimizer is told that the position it
        has tried is not sensible.

    Notes
    -----
    If `constants` is non-empty, the operator is simplified for the constant
    components once, in the constructor, and the stored `position` contains
    only the remaining (variable) components. Since `at` reuses the already
    simplified operator, positions passed to `at` are expected to contain
    only the variable components as well.
    """

    def __init__(self, position, op, constants=(), want_metric=False,
                 nanisinf=False):
        # Immutable default avoids the shared-mutable-default pitfall; any
        # sized container of key names works here.
        varop = op
        if len(constants) > 0:
            # Fold the constant components into the operator once ...
            cstpos = position.extract_by_keys(constants)
            _, varop = op.simplify_for_constant_input(cstpos)
            # ... and keep only the variable components of the position.
            varkeys = set(op.domain.keys()) - set(constants)
            position = position.extract_by_keys(varkeys)
        # The base class is initialized only now, with the (possibly reduced)
        # position, so that `self.position` lives on the same domain as the
        # one `Linearization.make_var` below differentiates with respect to.
        super(EnergyAdapter, self).__init__(position)
        self._op = varop
        self._want_metric = want_metric
        lin = Linearization.make_var(position, want_metric)
        tmp = self._op(lin)
        # The operator has a scalar target; `[()]` extracts the scalar value
        # from the underlying (0-d) array.
        self._val = tmp.val.val[()]
        self._grad = tmp.gradient
        self._metric = tmp._metric
        self._nanisinf = bool(nanisinf)
        if self._nanisinf and np.isnan(self._val):
            # Report failed forward evaluations as "infinitely bad" instead
            # of propagating NaN into the minimizer.
            self._val = np.inf

    def at(self, position):
        # `self._op` has already been simplified for any constants, so the
        # new instance is created without `constants`.
        return EnergyAdapter(position, self._op, want_metric=self._want_metric,
                             nanisinf=self._nanisinf)

    @property
    def value(self):
        return self._val

    @property
    def gradient(self):
        return self._grad

    @property
    def metric(self):
        # Only meaningful if the instance was created with want_metric=True.
        return self._metric

    def apply_metric(self, x):
        # Applies the stored metric to `x`; requires want_metric=True.
        return self._metric(x)