# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..linearization import Linearization
from ..minimization.energy import Energy
class EnergyAdapter(Energy):
"""Helper class which provides the traditional Nifty Energy interface to
Nifty operators with a scalar target domain.
Parameters
-----------
position: Field or MultiField living on the operator's input domain.
The position where the minimization process is started
op: Operator with a scalar target domain
The expression computing the energy from the input data
constants: list of strings (default: [])
The component names of the operator's input domain which are assumed
to be constant during the minimization process.
If the operator's input domain is not a MultiField, this must be empty.
want_metric: bool (default: False)
if True, the class will provide a `metric` property. This should only
be enabled if it is required, because it will most likely consume
additional resources.
"""
def __init__(self, position, op, constants=[], want_metric=False):
super(EnergyAdapter, self).__init__(position)
self._op = op
self._constants = constants
self._want_metric = want_metric
lin = Linearization.make_partial_var(position, constants, want_metric)
tmp = self._op(lin)
self._val = tmp.val.local_data[()]
self._grad = tmp.gradient
self._metric = tmp._metric
def at(self, position):
return EnergyAdapter(position, self._op, self._constants,
self._want_metric)
@property
def value(self):
return self._val
@property
def gradient(self):
return self._grad
@property
def metric(self):
return self._metric
def apply_metric(self, x):
return self._metric(x)