Commit 1e23a189 authored by Philipp Arras's avatar Philipp Arras

Fixups regarding constant minimization, nanisinf for EnergyAdapter

parent 1200b36a
Pipeline #75815 passed with stages
in 13 minutes and 44 seconds
......@@ -15,6 +15,8 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..linearization import Linearization
from ..minimization.energy import Energy
from ..sugar import makeDomain
......@@ -39,33 +41,34 @@ class EnergyAdapter(Energy):
If True, the class will provide a `metric` property. This should only
be enabled if it is required, because it will most likely consume
additional resources. Default: False.
nanisinf : bool
If true, nan energies which can happen due to overflows in the forward
model are interpreted as inf. Thereby, the code does not crash on
these occaisions but rather the minimizer is told that the position it
has tried is not sensible.
"""
def __init__(self, position, op, constants=[], want_metric=False,
_op4eval=None):
nanisinf=False):
super(EnergyAdapter, self).__init__(position)
self._op = op
self._constants = constants
self._op4eval = _op4eval
if self._op4eval is None:
if len(constants) > 0:
dom = {kk: vv for kk, vv in position.domain.items()
if kk in constants}
dom = makeDomain(dom)
cstpos = position.extract(dom)
_, self._op4eval = op.simplify_for_constant_input(cstpos)
else:
self._op4eval = op
if len(constants) > 0:
dom = makeDomain({kk: vv for kk, vv in position.domain.items()
if kk in constants})
_, self._op = op.simplify_for_constant_input(position.extract(dom))
self._want_metric = want_metric
lin = Linearization.make_var(position, want_metric)
tmp = self._op(lin)
self._val = tmp.val.val[()]
self._grad = tmp.gradient
self._metric = tmp._metric
self._nanisinf = bool(nanisinf)
if self._nanisinf and np.isnan(self._val):
self._val = np.inf
def at(self, position):
return EnergyAdapter(position, self._op, self._constants,
self._want_metric, self._op4eval)
return EnergyAdapter(position, self._op, want_metric=self._want_metric,
nanisinf=self._nanisinf)
@property
def value(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment