diff --git a/nifty5/minimization/metric_gaussian_kl_mpi.py b/nifty5/minimization/metric_gaussian_kl_mpi.py
index 1970bb2e00cbd10858823acfce04c389a98ea632..5a6b7f864edf7289072439718035a25adcb747c0 100644
--- a/nifty5/minimization/metric_gaussian_kl_mpi.py
+++ b/nifty5/minimization/metric_gaussian_kl_mpi.py
@@ -59,7 +59,7 @@ def allreduce_sum_field(fld):
 class MetricGaussianKL_MPI(Energy):
     def __init__(self, mean, hamiltonian, n_samples, constants=[],
                  point_estimates=[], mirror_samples=False,
-                 _samples=None):
+                 _samples=None, seed_offset=0):
         super(MetricGaussianKL_MPI, self).__init__(mean)
 
         if not isinstance(hamiltonian, StandardHamiltonian):
@@ -76,18 +76,30 @@ class MetricGaussianKL_MPI(Energy):
         self._hamiltonian = hamiltonian
 
         if _samples is None:
-            lo, hi = _shareRange(n_samples, ntask, rank)
+            if mirror_samples:
+                lo, hi = _shareRange(n_samples*2, ntask, rank)
+            else:
+                lo, hi = _shareRange(n_samples, ntask, rank)
             met = hamiltonian(Linearization.make_partial_var(
                 mean, point_estimates, True)).metric
             _samples = []
             for i in range(lo, hi):
-                np.random.seed(i)
-                _samples.append(met.draw_sample(from_inverse=True))
+                if mirror_samples:
+                    np.random.seed(i//2+seed_offset)
+                    if (i % 2) and (i-1 >= lo):
+                        _samples.append(-_samples[-1])
+
+                    else:
+                        _samples.append(((i % 2)*2-1) *
+                                        met.draw_sample(from_inverse=True))
+                else:
+                    np.random.seed(i)
+                    _samples.append(met.draw_sample(from_inverse=True))
+            _samples = tuple(_samples)
             if mirror_samples:
-                _samples += [-s for s in _samples]
                 n_samples *= 2
-            _samples = tuple(_samples)
         self._samples = _samples
+        self._seed_offset = seed_offset
         self._n_samples = n_samples
         self._lin = Linearization.make_partial_var(mean, constants)
         v, g = None, None
@@ -111,7 +123,8 @@
     def at(self, position):
         return MetricGaussianKL_MPI(
             position, self._hamiltonian, self._n_samples, self._constants,
-            self._point_estimates, _samples=self._samples)
+            self._point_estimates, _samples=self._samples,
+            seed_offset=self._seed_offset)
 
     @property
     def value(self):