Commit a6a139e9 authored by Reimar Leike's avatar Reimar Leike
Browse files

made summing up consistent if MPI is not used

parent 2aceefe5
Pipeline #75634 passed with stages
in 13 minutes and 45 seconds
......@@ -233,10 +233,10 @@ class MetricGaussianKL(Energy):
# This is a deterministic implementation of MPI allreduce in the sense
# that it takes into account that floating point operations are not
# associative.
res = None
if self._comm is None:
for o in obj:
res = o if res is None else res + o
who = np.zeros(self._n_samples, dtype=np.int32)
rank = 0
vals = obj
else:
ntask = self._comm.Get_size()
rank = self._comm.Get_rank()
......@@ -252,29 +252,31 @@ class MetricGaussianKL(Energy):
who = np.zeros(len(vals), dtype=np.int32)
for t, (l,h) in enumerate(rank_lo_hi):
who[l:h] = t
def add2(v, w, rank):
#Note that communication only happens if rank in w
if len(v) == 1:
return v[0]
if rank == w[0]:
if w[0] == w[1]:
return v[0]+v[1]
self._comm.send(v[0], dest=w[1])
return None
if rank == w[1]:
return self._comm.recv(source=w[0]) + v[1]
def add2(v, w, rank):
#Note that communication only happens if rank in w
if len(v) == 1:
return v[0]
if rank == w[0]:
if w[0] == w[1]:
return v[0]+v[1]
self._comm.send(v[0], dest=w[1])
return None
if rank == w[1]:
return self._comm.recv(source=w[0]) + v[1]
while len(vals) > 1:
new_vals = []
new_who = []
for j in range((len(vals)+1)//2):
w = who[2*j:2*j+2]
nv = add2(vals[2*j:2*j+2],w, rank)
new_vals += [nv]
new_who += [w[-1]]
vals = new_vals
who = new_who
res = self._comm.bcast(vals[0], root=who[0])
while len(vals) > 1:
new_vals = []
new_who = []
for j in range((len(vals)+1)//2):
w = who[2*j:2*j+2]
nv = add2(vals[2*j:2*j+2],w, rank)
new_vals += [nv]
new_who += [w[-1]]
vals = new_vals
who = new_who
if self._comm is None:
return vals[0]
res = self._comm.bcast(vals[0], root=who[0])
return res
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment