Commit 8ff10b70 authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge branch 'nan_to_inf' into 'NIFTy_6'

Add possibility to cast NaNs to Infs in KL

See merge request !488
parents c0bedf7c 64fad7d8
Pipeline #75331 passed with stages
in 8 minutes and 52 seconds
......@@ -114,6 +114,11 @@ class MetricGaussianKL(Energy):
If not None, samples will be distributed as evenly as possible
across this communicator. If `mirror_samples` is set, then a sample and
its mirror image will always reside on the same task.
nanisinf : bool
If true, nan energies which can happen due to overflows in the forward
model are interpreted as inf. Thereby, the code does not crash on
these occaisions but rather the minimizer is told that the position it
has tried is not sensible.
_local_samples : None
Only a parameter for internal uses. Typically not to be set by users.
......@@ -131,7 +136,8 @@ class MetricGaussianKL(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
napprox=0, comm=None, _local_samples=None):
napprox=0, comm=None, _local_samples=None,
super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
......@@ -142,6 +148,7 @@ class MetricGaussianKL(Energy):
raise TypeError
self._constants = tuple(constants)
self._point_estimates = tuple(point_estimates)
self._mitigate_nans = nanisinf
if not isinstance(mirror_samples, bool):
raise TypeError
......@@ -195,6 +202,8 @@ class MetricGaussianKL(Energy):
v += tmp.val.val
g = g + tmp.gradient
self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
if np.isnan(self._val) and self._mitigate_nans:
self._val = np.inf
self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples
self._metric = None
......@@ -202,7 +211,7 @@ class MetricGaussianKL(Energy):
return MetricGaussianKL(
position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, self._mirror_samples, comm=self._comm,
_local_samples=self._local_samples, nanisinf=self._mitigate_nans)
def value(self):
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment