Commit c2f2b24e authored by Reimar Leike's avatar Reimar Leike Committed by Philipp Arras
Browse files

Add possibility to cast NaNs to Infs in KL

parent fd208746
Pipeline #75328 passed with stages
in 8 minutes and 53 seconds
......@@ -131,7 +131,8 @@ class MetricGaussianKL(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
napprox=0, comm=None, _local_samples=None):
napprox=0, comm=None, _local_samples=None,
nanisinf=False):
super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
......@@ -142,6 +143,7 @@ class MetricGaussianKL(Energy):
raise TypeError
self._constants = tuple(constants)
self._point_estimates = tuple(point_estimates)
self._mitigate_nans = nanisinf
if not isinstance(mirror_samples, bool):
raise TypeError
......@@ -195,6 +197,8 @@ class MetricGaussianKL(Energy):
v += tmp.val.val
g = g + tmp.gradient
self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
if np.isnan(self._val) and self._mitigate_nans:
self._val = np.inf
self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples
self._metric = None
......@@ -202,7 +206,7 @@ class MetricGaussianKL(Energy):
return MetricGaussianKL(
position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, self._mirror_samples, comm=self._comm,
_local_samples=self._local_samples)
_local_samples=self._local_samples, nanisinf=self._mitigate_nans)
@property
def value(self):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment