Commit 8ff10b70 authored by Philipp Arras's avatar Philipp Arras

Merge branch 'nan_to_inf' into 'NIFTy_6'

Add possibility to cast NaNs to Infs in KL

See merge request !488
parents c0bedf7c 64fad7d8
Pipeline #75331 passed with stages
in 8 minutes and 52 seconds
...@@ -114,6 +114,11 @@ class MetricGaussianKL(Energy): ...@@ -114,6 +114,11 @@ class MetricGaussianKL(Energy):
If not None, samples will be distributed as evenly as possible If not None, samples will be distributed as evenly as possible
across this communicator. If `mirror_samples` is set, then a sample and across this communicator. If `mirror_samples` is set, then a sample and
its mirror image will always reside on the same task. its mirror image will always reside on the same task.
nanisinf : bool
If true, nan energies which can happen due to overflows in the forward
model are interpreted as inf. Thereby, the code does not crash on
these occaisions but rather the minimizer is told that the position it
has tried is not sensible.
_local_samples : None _local_samples : None
Only a parameter for internal uses. Typically not to be set by users. Only a parameter for internal uses. Typically not to be set by users.
...@@ -131,7 +136,8 @@ class MetricGaussianKL(Energy): ...@@ -131,7 +136,8 @@ class MetricGaussianKL(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[], def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False, point_estimates=[], mirror_samples=False,
napprox=0, comm=None, _local_samples=None): napprox=0, comm=None, _local_samples=None,
nanisinf=False):
super(MetricGaussianKL, self).__init__(mean) super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian): if not isinstance(hamiltonian, StandardHamiltonian):
...@@ -142,6 +148,7 @@ class MetricGaussianKL(Energy): ...@@ -142,6 +148,7 @@ class MetricGaussianKL(Energy):
raise TypeError raise TypeError
self._constants = tuple(constants) self._constants = tuple(constants)
self._point_estimates = tuple(point_estimates) self._point_estimates = tuple(point_estimates)
self._mitigate_nans = nanisinf
if not isinstance(mirror_samples, bool): if not isinstance(mirror_samples, bool):
raise TypeError raise TypeError
...@@ -195,6 +202,8 @@ class MetricGaussianKL(Energy): ...@@ -195,6 +202,8 @@ class MetricGaussianKL(Energy):
v += tmp.val.val v += tmp.val.val
g = g + tmp.gradient g = g + tmp.gradient
self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
if np.isnan(self._val) and self._mitigate_nans:
self._val = np.inf
self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples
self._metric = None self._metric = None
...@@ -202,7 +211,7 @@ class MetricGaussianKL(Energy): ...@@ -202,7 +211,7 @@ class MetricGaussianKL(Energy):
return MetricGaussianKL( return MetricGaussianKL(
position, self._hamiltonian, self._n_samples, self._constants, position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, self._mirror_samples, comm=self._comm, self._point_estimates, self._mirror_samples, comm=self._comm,
_local_samples=self._local_samples) _local_samples=self._local_samples, nanisinf=self._mitigate_nans)
@property @property
def value(self): def value(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment