Commit 2aceefe5 authored by Reimar Leike's avatar Reimar Leike
Browse files

Made _sumup of KL faster by summing up in parallel when possible

parent a8c78eb9
Pipeline #75626 passed with stages
in 13 minutes and 16 seconds
......@@ -228,6 +228,7 @@ class MetricGaussianKL(Energy):
if self._mirror_samples:
yield -s
def _sumup(self, obj):
# This is a deterministic implementation of MPI allreduce in the sense
# that it takes into account that floating point operations are not
......@@ -240,13 +241,44 @@ class MetricGaussianKL(Energy):
ntask = self._comm.Get_size()
rank = self._comm.Get_rank()
rank_lo_hi = [_shareRange(self._n_samples, ntask, i) for i in range(ntask)]
for itask, (l, h) in enumerate(rank_lo_hi):
for i in range(l, h):
o = obj[i-self._lo] if rank == itask else None
o = self._comm.bcast(o, root=itask)
res = o if res is None else res + o
lo = rank_lo_hi[rank][0]
hi = rank_lo_hi[rank][1]
vals = []
for i in range(self._n_samples):
if (i>=lo) and (i<hi):
vals += [obj[i-lo]]
vals += [None]
who = np.zeros(len(vals), dtype=np.int32)
for t, (l,h) in enumerate(rank_lo_hi):
who[l:h] = t
# Combine one pair of neighbouring entries for a single reduction step.
# As called below, v is a slice of at most two entries of vals, w is the
# matching slice of who (the task index recorded for each entry), and rank
# is the rank of the calling task.
# NOTE(review): indentation is lost in this view; the comments below assume
# the nesting under which every statement is reachable - confirm against
# the repository.
def add2(v, w, rank):
#Note that communication only happens if rank in w
# Unpaired trailing entry: there is nothing to add, so return it as is.
if len(v) == 1:
return v[0]
# The calling task is the one recorded for the left operand.
if rank == w[0]:
# Both operands are recorded for the same task: add locally, no messages.
if w[0] == w[1]:
return v[0]+v[1]
# Otherwise send the left operand to the task of the right operand and
# return None as this task's result for the pair.
self._comm.send(v[0], dest=w[1])
return None
# The calling task is the one recorded for the right operand: receive the
# left operand and add in left + right order.
if rank == w[1]:
return self._comm.recv(source=w[0]) + v[1]
# A task matching neither entry of w falls through and returns None
# implicitly.
while len(vals) > 1:
new_vals = []
new_who = []
for j in range((len(vals)+1)//2):
w = who[2*j:2*j+2]
nv = add2(vals[2*j:2*j+2],w, rank)
new_vals += [nv]
new_who += [w[-1]]
vals = new_vals
who = new_who
res = self._comm.bcast(vals[0], root=who[0])
return res
def _metric_sample(self, from_inverse=False):
if from_inverse:
raise NotImplementedError()
