# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..linearization import Linearization
from ..minimization.energy import Energy
from ..sugar import makeDomain
class PointKL(Energy):
    """Helper class which provides the traditional Nifty Energy interface to
    Nifty operators with a scalar target domain.

    Parameters
    ----------
    position: Field or MultiField
        The position where the minimization process is started.
    op: EnergyOperator
        The expression computing the energy from the input data.
    constants: list of strings, optional
        The component names of the operator's input domain which are assumed
        to be constant during the minimization process.
        If the operator's input domain is not a MultiField, this must be empty.
        Default: None (treated as an empty list).
    want_metric: bool
        If True, the class will provide a `metric` property. This should only
        be enabled if it is required, because it will most likely consume
        additional resources. Default: False.
    nanisinf : bool
        If true, nan energies which can happen due to overflows in the forward
        model are interpreted as inf. Thereby, the code does not crash on
        these occasions but rather the minimizer is told that the position it
        has tried is not sensible.
    """

    def __init__(self, position, op, constants=None, want_metric=False,
                 nanisinf=False):
        # A None sentinel avoids a shared mutable list as the default value.
        constants = [] if constants is None else list(constants)
        if constants:
            # Fold the constant components into the operator, then keep only
            # the remaining (variable) components in the position.
            cstpos = position.extract_by_keys(constants)
            _, op = op.simplify_for_constant_input(cstpos)
            varkeys = set(op.domain.keys()) - set(constants)
            position = position.extract_by_keys(varkeys)
        super().__init__(position)
        self._op = op
        self._want_metric = want_metric
        self._nanisinf = bool(nanisinf)

        # Evaluate the operator on a linearization at `position`; the metric
        # is only requested when `want_metric` is set.
        lin = Linearization.make_var(position, want_metric)
        tmp = self._op(lin)
        # The operator's target is scalar, so `[()]` pulls out the single
        # energy value from the underlying array.
        self._val = tmp.val.val[()]
        self._grad = tmp.gradient
        self._metric = tmp._metric
        if self._nanisinf and np.isnan(self._val):
            # Report NaN as +inf so a minimizer rejects this step instead of
            # failing on it.
            self._val = np.inf

    def at(self, position):
        """Return a new PointKL for the same operator evaluated at `position`.

        `constants` are not forwarded: when they were given, they have already
        been folded into `self._op` by `__init__`.
        """
        return PointKL(position, self._op, want_metric=self._want_metric,
                       nanisinf=self._nanisinf)

    @property
    def value(self):
        """The scalar energy at the current position (inf for NaN if
        `nanisinf` was set)."""
        return self._val

    @property
    def gradient(self):
        """The gradient of the energy at the current position."""
        return self._grad

    @property
    def metric(self):
        """The metric at the current position.

        May be None when the metric was not requested (`want_metric=False`).
        """
        return self._metric

    def apply_metric(self, x):
        """Apply the metric at the current position to `x`.

        Raises
        ------
        ValueError
            If no metric is available, presumably because the instance was
            constructed with `want_metric=False`.
        """
        if self._metric is None:
            raise ValueError("metric is not available; construct with "
                             "want_metric=True")
        return self._metric(x)