diff --git a/nifty6/minimization/metric_gaussian_kl.py b/nifty6/minimization/metric_gaussian_kl.py index 606f333d00fb9e9d1f04ab9dde9be9762ce90762..5e5ff971d4536ef75f69d70b8169134db553394f 100644 --- a/nifty6/minimization/metric_gaussian_kl.py +++ b/nifty6/minimization/metric_gaussian_kl.py @@ -231,8 +231,12 @@ class MetricGaussianKL(Energy): else: mymap = map(lambda v: self._hamiltonian(lin+v).metric, self._samples) - self.unscaled_metric = utilities.my_sum(mymap) - self._metric = self.unscaled_metric.scale(1./self._n_eff_samples) + unscaled_metric = utilities.my_sum(mymap) + if self._mirror_samples: + mymap = map(lambda v: self._hamiltonian(lin-v).metric, + self._samples) + unscaled_metric = unscaled_metric + utilities.my_sum(mymap) + self._metric = unscaled_metric.scale(1./self._n_eff_samples) def apply_metric(self, x): self._get_metric() @@ -253,7 +257,7 @@ class MetricGaussianKL(Energy): res = res + tuple(-item for item in res) return res - def _unscaled_metric_sample(self, from_inverse=False, dtype=np.float64): + def _metric_sample(self, from_inverse=False, dtype=np.float64): if from_inverse: raise NotImplementedError() lin = self._lin.with_want_metric() @@ -265,7 +269,4 @@ class MetricGaussianKL(Energy): if self._mirror_samples: samp = samp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False, dtype=dtype) random.pop_sseq() - return _allreduce_sum_field(self._comm, samp) - - def _metric_sample(self, from_inverse=False, dtype=np.float64): - return self._unscaled_metric_sample(from_inverse, dtype)/self._n_eff_samples + return _allreduce_sum_field(self._comm, samp)/self._n_eff_samples