Commit be9f5044 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'constantfixups' into 'NIFTy_7'

Fixups regarding constant minimization, nanisinf for EnergyAdapter

See merge request !523
parents 1200b36a 1e23a189
Pipeline #75982 passed with stages
in 13 minutes and 45 seconds
......@@ -15,6 +15,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..linearization import Linearization
from import Energy
from ..sugar import makeDomain
......@@ -39,33 +41,34 @@ class EnergyAdapter(Energy):
If True, the class will provide a `metric` property. This should only
be enabled if it is required, because it will most likely consume
additional resources. Default: False.
nanisinf : bool
If true, nan energies which can happen due to overflows in the forward
model are interpreted as inf. Thereby, the code does not crash on
these occaisions but rather the minimizer is told that the position it
has tried is not sensible.
def __init__(self, position, op, constants=[], want_metric=False,
super(EnergyAdapter, self).__init__(position)
self._op = op
self._constants = constants
self._op4eval = _op4eval
if self._op4eval is None:
if len(constants) > 0:
dom = {kk: vv for kk, vv in position.domain.items()
if kk in constants}
dom = makeDomain(dom)
cstpos = position.extract(dom)
_, self._op4eval = op.simplify_for_constant_input(cstpos)
self._op4eval = op
if len(constants) > 0:
dom = makeDomain({kk: vv for kk, vv in position.domain.items()
if kk in constants})
_, self._op = op.simplify_for_constant_input(position.extract(dom))
self._want_metric = want_metric
lin = Linearization.make_var(position, want_metric)
tmp = self._op(lin)
self._val = tmp.val.val[()]
self._grad = tmp.gradient
self._metric = tmp._metric
self._nanisinf = bool(nanisinf)
if self._nanisinf and np.isnan(self._val):
self._val = np.inf
def at(self, position):
return EnergyAdapter(position, self._op, self._constants,
self._want_metric, self._op4eval)
return EnergyAdapter(position, self._op, want_metric=self._want_metric,
def value(self):
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment