energy_adapter.py 2.48 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
Martin Reinecke's avatar
Martin Reinecke committed
17 18

from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
19
from ..minimization.energy import Energy
Martin Reinecke's avatar
Martin Reinecke committed
20 21 22


class EnergyAdapter(Energy):
    """Helper class which provides the traditional Nifty Energy interface to
    Nifty operators with a scalar target domain.

    The operator is evaluated once, at construction time, on a linearization
    of `position`; the resulting value, gradient and metric are cached on the
    instance and exposed through the properties below.

    Parameters
    ----------
    position : Field or MultiField
        The position where the minimization process is started.
    op : EnergyOperator
        The expression computing the energy from the input data.
    constants : list of strings, optional
        The component names of the operator's input domain which are assumed
        to be constant during the minimization process.
        If the operator's input domain is not a MultiField, this must be empty.
        Default: None, which is treated as an empty list.
    want_metric : bool
        If True, the class will provide a `metric` property. This should only
        be enabled if it is required, because it will most likely consume
        additional resources. Default: False.
    """

    def __init__(self, position, op, constants=None, want_metric=False):
        super(EnergyAdapter, self).__init__(position)
        # Use a None sentinel instead of a literal `[]` default: a list
        # default is created once and would be shared (via self._constants)
        # by every instance constructed without an explicit argument.
        if constants is None:
            constants = []
        self._op = op
        self._constants = constants
        self._want_metric = want_metric
        lin = Linearization.make_partial_var(position, constants, want_metric)
        tmp = self._op(lin)
        # `[()]` pulls the single entry out of the result's value; this
        # assumes tmp.val.val is a 0-d array-like (scalar target domain).
        self._val = tmp.val.val[()]
        self._grad = tmp.gradient
        # NOTE(review): reads the linearization's private attribute directly;
        # presumably this is None when want_metric is False -- confirm
        # against Linearization.
        self._metric = tmp._metric

    def at(self, position):
        """Return a new EnergyAdapter for the same operator at `position`.

        The operator, the constant components and the `want_metric` setting
        of this instance are carried over unchanged.
        """
        return EnergyAdapter(position, self._op, self._constants,
                             self._want_metric)

    @property
    def value(self):
        """The energy value computed at construction time."""
        return self._val

    @property
    def gradient(self):
        """The gradient computed at construction time."""
        return self._grad

    @property
    def metric(self):
        """The metric stored by the linearization at construction time."""
        return self._metric

    def apply_metric(self, x):
        """Apply the stored metric to `x` and return the result.

        NOTE(review): this calls the stored metric directly, so it presumably
        requires the instance to have been built with ``want_metric=True``.
        """
        return self._metric(x)