Commit c01d10cc authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'fix_KL' into 'NIFTy_6'

Fix KL

See merge request !429
parents 93e2788f e66e8b5e
Pipeline #71428 passed with stages
in 17 minutes and 14 seconds
......@@ -231,8 +231,12 @@ class MetricGaussianKL(Energy):
else:
mymap = map(lambda v: self._hamiltonian(lin+v).metric,
self._samples)
self.unscaled_metric = utilities.my_sum(mymap)
self._metric = self.unscaled_metric.scale(1./self._n_eff_samples)
unscaled_metric = utilities.my_sum(mymap)
if self._mirror_samples:
mymap = map(lambda v: self._hamiltonian(lin-v).metric,
self._samples)
unscaled_metric = unscaled_metric + utilities.my_sum(mymap)
self._metric = unscaled_metric.scale(1./self._n_eff_samples)
def apply_metric(self, x):
self._get_metric()
......@@ -253,7 +257,7 @@ class MetricGaussianKL(Energy):
res = res + tuple(-item for item in res)
return res
def _unscaled_metric_sample(self, from_inverse=False, dtype=np.float64):
def _metric_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse:
raise NotImplementedError()
lin = self._lin.with_want_metric()
......@@ -265,7 +269,4 @@ class MetricGaussianKL(Energy):
if self._mirror_samples:
samp = samp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False, dtype=dtype)
random.pop_sseq()
return _allreduce_sum_field(self._comm, samp)
def _metric_sample(self, from_inverse=False, dtype=np.float64):
return self._unscaled_metric_sample(from_inverse, dtype)/self._n_eff_samples
return _allreduce_sum_field(self._comm, samp)/self._n_eff_samples
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment