Commit 2249183a authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'more_MPI_craziness' into 'NIFTy_6'

Simpler way of MPI summation

See merge request !512
parents b4c1d0da ac7ee7f8
Pipeline #75730 passed with stages
in 14 minutes and 56 seconds
......@@ -179,7 +179,6 @@ class MetricGaussianKL(Energy):
if np.isnan(self._val) and self._mitigate_nans:
self._val = np.inf
self._grad = self._sumup(g)/self._n_eff_samples
self._metric = None
def at(self, position):
return MetricGaussianKL(
......@@ -229,23 +228,61 @@ class MetricGaussianKL(Energy):
yield -s
def _sumup(self, obj):
# This is a deterministic implementation of MPI allreduce in the sense
# that it takes into account that floating point operations are not
# associative.
res = None
""" This is a deterministic implementation of MPI allreduce
Numeric addition is not associative due to rounding errors.
Therefore we provide our own implementation that is consistent
no matter if MPI is used and how many tasks there are.
At the beginning, a list `who` is constructed, that states which obj can
be found on which MPI task.
Then elements are added pairwise, with increasing pair distance.
In the first round, the distance between pair members is 1:
v[0] := v[0] + v[1]
v[2] := v[2] + v[3]
v[4] := v[4] + v[5]
Entries whose summation partner lies beyond the end of the array
stay unchanged.
When both summation partners are not located on the same MPI task,
the second summand is sent to the task holding the first summand and
the operation is carried out there.
For the next round, the distance is doubled:
v[0] := v[0] + v[2]
v[4] := v[4] + v[6]
v[8] := v[8] + v[10]
This is repeated until the distance exceeds the length of the array.
At this point v[0] contains the sum of all entries, which is then
broadcast to all tasks.
"""
if self._comm is None:
for o in obj:
res = o if res is None else res + o
who = np.zeros(self._n_samples, dtype=np.int32)
rank = 0
vals = list(obj) # necessary since we don't want to modify `obj`
else:
ntask = self._comm.Get_size()
rank = self._comm.Get_rank()
rank_lo_hi = [_shareRange(self._n_samples, ntask, i) for i in range(ntask)]
for itask, (l, h) in enumerate(rank_lo_hi):
for i in range(l, h):
o = obj[i-self._lo] if rank == itask else None
o = self._comm.bcast(o, root=itask)
res = o if res is None else res + o
return res
lo, hi = rank_lo_hi[rank]
vals = [None]*lo + list(obj) + [None]*(self._n_samples-hi)
who = [t for t, (l, h) in enumerate(rank_lo_hi) for cnt in range(h-l)]
step = 1
while step < self._n_samples:
for j in range(0, self._n_samples, 2*step):
if j+step < self._n_samples: # summation partner found
if rank == who[j]:
if who[j] == who[j+step]: # no communication required
vals[j] = vals[j] + vals[j+step]
vals[j+step] = None
else:
vals[j] = vals[j] + self._comm.recv(source=who[j+step])
elif rank == who[j+step]:
self._comm.send(vals[j+step], dest=who[j])
vals[j+step] = None
step *= 2
if self._comm is None:
return vals[0]
return self._comm.bcast(vals[0], root=who[0])
def _metric_sample(self, from_inverse=False):
if from_inverse:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment