Commit 8e377ee6 authored by Gordian Edenhofer's avatar Gordian Edenhofer
Browse files

metric_gaussian_kl.py: Simplify MPI sample prop

Use the same logic for returning the samples from the various MPI tasks
as for working with them in each task.
parent 7211a0a0
Pipeline #71459 passed with stages
in 17 minutes and 4 seconds
......@@ -36,18 +36,6 @@ def _shareRange(nwork, nshares, myshare):
return lo, hi
def _getTask(iwork, nwork, nshares):
nbase = nwork//nshares
additional = nwork % nshares
# FIXME: this is crappy code!
for ishare in range(nshares):
lo = ishare*nbase + min(ishare, additional)
hi = lo + nbase + int(ishare < additional)
if hi>iwork:
return ishare
raise RunTimeError("must not arrive here")
def _np_allreduce_sum(comm, arr):
if comm is None:
return arr
......@@ -268,13 +256,14 @@ class MetricGaussianKL(Energy):
else:
ntask = self._comm.Get_size()
rank = self._comm.Get_rank()
for i in range(self._n_samples):
itask = _getTask(i, self._n_samples, ntask)
data = self._local_samples[i-self._lo] if rank == itask else None
s = self._comm.bcast(data, root=itask)
yield s
if self._mirror_samples:
yield -s
rank_lo_hi = [_shareRange(self._n_samples, ntask, i) for i in range(ntask)]
for itask, (l, h) in enumerate(rank_lo_hi):
for i in range(l, h):
data = self._local_samples[i-self._lo] if rank == itask else None
s = self._comm.bcast(data, root=itask)
yield s
if self._mirror_samples:
yield -s
def _metric_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment