Commit 5ded9f44 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'parallizing_mirrored_samples' into 'NIFTy_5'

parallelization for mirrored KL

See merge request !301
parents 4f68d243 b590d545
Pipeline #60253 passed with stages
in 23 minutes and 37 seconds
......@@ -59,7 +59,7 @@ def allreduce_sum_field(fld):
class MetricGaussianKL_MPI(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
_samples=None, seed_offset=0):
super(MetricGaussianKL_MPI, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
......@@ -76,18 +76,30 @@ class MetricGaussianKL_MPI(Energy):
self._hamiltonian = hamiltonian
if _samples is None:
if mirror_samples:
lo, hi = _shareRange(n_samples*2, ntask, rank)
lo, hi = _shareRange(n_samples, ntask, rank)
met = hamiltonian(Linearization.make_partial_var(
mean, point_estimates, True)).metric
_samples = []
for i in range(lo, hi):
if mirror_samples:
if (i % 2) and (i-1 >= lo):
_samples.append(((i % 2)*2-1) *
_samples = tuple(_samples)
if mirror_samples:
_samples += [-s for s in _samples]
n_samples *= 2
_samples = tuple(_samples)
self._samples = _samples
self._seed_offset = seed_offset
self._n_samples = n_samples
self._lin = Linearization.make_partial_var(mean, constants)
v, g = None, None
......@@ -111,7 +123,8 @@ class MetricGaussianKL_MPI(Energy):
def at(self, position):
return MetricGaussianKL_MPI(
position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, _samples=self._samples)
self._point_estimates, _samples=self._samples,
def value(self):
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment