Commit 1e23a189 authored by Philipp Arras's avatar Philipp Arras

Fixups regarding constant minimization, nanisinf for EnergyAdapter

parent 1200b36a
Pipeline #75815 passed with stages
in 13 minutes and 44 seconds
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..linearization import Linearization from ..linearization import Linearization
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..sugar import makeDomain from ..sugar import makeDomain
...@@ -39,33 +41,34 @@ class EnergyAdapter(Energy): ...@@ -39,33 +41,34 @@ class EnergyAdapter(Energy):
If True, the class will provide a `metric` property. This should only If True, the class will provide a `metric` property. This should only
be enabled if it is required, because it will most likely consume be enabled if it is required, because it will most likely consume
additional resources. Default: False. additional resources. Default: False.
nanisinf : bool
If true, nan energies which can happen due to overflows in the forward
model are interpreted as inf. Thereby, the code does not crash on
these occaisions but rather the minimizer is told that the position it
has tried is not sensible.
""" """
def __init__(self, position, op, constants=[], want_metric=False, def __init__(self, position, op, constants=[], want_metric=False,
_op4eval=None): nanisinf=False):
super(EnergyAdapter, self).__init__(position) super(EnergyAdapter, self).__init__(position)
self._op = op self._op = op
self._constants = constants if len(constants) > 0:
self._op4eval = _op4eval dom = makeDomain({kk: vv for kk, vv in position.domain.items()
if self._op4eval is None: if kk in constants})
if len(constants) > 0: _, self._op = op.simplify_for_constant_input(position.extract(dom))
dom = {kk: vv for kk, vv in position.domain.items()
if kk in constants}
dom = makeDomain(dom)
cstpos = position.extract(dom)
_, self._op4eval = op.simplify_for_constant_input(cstpos)
else:
self._op4eval = op
self._want_metric = want_metric self._want_metric = want_metric
lin = Linearization.make_var(position, want_metric) lin = Linearization.make_var(position, want_metric)
tmp = self._op(lin) tmp = self._op(lin)
self._val = tmp.val.val[()] self._val = tmp.val.val[()]
self._grad = tmp.gradient self._grad = tmp.gradient
self._metric = tmp._metric self._metric = tmp._metric
self._nanisinf = bool(nanisinf)
if self._nanisinf and np.isnan(self._val):
self._val = np.inf
def at(self, position): def at(self, position):
return EnergyAdapter(position, self._op, self._constants, return EnergyAdapter(position, self._op, want_metric=self._want_metric,
self._want_metric, self._op4eval) nanisinf=self._nanisinf)
@property @property
def value(self): def value(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment