Commit 9a934ab1 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

merge

parents 74a81e8c 193ffe36
Pipeline #75649 passed with stages
in 13 minutes and 43 seconds
......@@ -228,15 +228,37 @@ class MetricGaussianKL(Energy):
if self._mirror_samples:
yield -s
def _sumup(self, obj):
# This is a deterministic implementation of MPI allreduce in the sense
# that it takes into account that floating point operations are not
# associative.
""" This is a deterministic implementation of MPI allreduce
Numeric addition is not associative due to rounding errors.
Therefore we provide our own implementation that is consistent
no matter if MPI is used and how many tasks there are.
At the beginning, a list `who` is constructed, that states which obj can
be found on which MPI task.
Then elements are added pairwise, with increasing pair distance.
In the first round, the distance between pair members is 1:
v[0] := v[0] + v[1]
v[2] := v[2] + v[3]
v[4] := v[4] + v[5]
Entries whose summation partner lies beyond the end of the array
stay unchanged.
When both summation partners are not located on the same MPI task,
the second summand is sent to the task holding the first summand and
the operation is carried out there.
For the next round, the distance is doubled:
v[0] := v[0] + v[2]
v[4] := v[4] + v[6]
v[8] := v[8] + v[10]
This is repeated until the distance exceeds the length of the array.
At this point v[0] contains the sum of all entries, which is then
broadcast to all tasks.
"""
if self._comm is None:
who = np.zeros(self._n_samples, dtype=np.int32)
rank = 0
vals = list(obj)
vals = list(obj) # necessary since we don't want to modify `obj`
else:
ntask = self._comm.Get_size()
rank = self._comm.Get_rank()
......@@ -248,29 +270,23 @@ class MetricGaussianKL(Energy):
who[l:h] = t
step = 1
# `step` doubles with every iteration
# first round: add entries 0 and 1, store result in 0
# add entries 2 and 3, store result in 2
# ...
# second round: add entries 0 and 2, store result in 0
# add entries 4 and 6, store result in 4
# ...
while step < self._n_samples:
for j in range(0, self._n_samples, 2*step):
if j+step < self._n_samples: # summation partner found
if rank == who[j]:
if who[j] == who[j+step]: # no communication required
vals[j] = vals[j] + vals[j+step]
vals[j+step] = None
else:
vals[j] = vals[j] + self._comm.recv(source=who[j+step])
elif rank == who[j+step]:
self._comm.send(vals[j+step], dest=who[j])
vals[j+step] = None
step *= 2
if self._comm is None:
return vals[0]
return self._comm.bcast(vals[0], root=who[0])
def _metric_sample(self, from_inverse=False):
if from_inverse:
raise NotImplementedError()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment