Commit a6a139e9 by Reimar Leike

### made summing up consistent if MPI is not used

parent 2aceefe5
Pipeline #75634 passed with stages
in 13 minutes and 45 seconds
 ... ... @@ -233,10 +233,10 @@ class MetricGaussianKL(Energy): # This is a deterministic implementation of MPI allreduce in the sense # that it takes into account that floating point operations are not # associative. res = None if self._comm is None: for o in obj: res = o if res is None else res + o who = np.zeros(self._n_samples, dtype=np.int32) rank = 0 vals = obj else: ntask = self._comm.Get_size() rank = self._comm.Get_rank() ... ... @@ -252,29 +252,31 @@ class MetricGaussianKL(Energy): who = np.zeros(len(vals), dtype=np.int32) for t, (l,h) in enumerate(rank_lo_hi): who[l:h] = t def add2(v, w, rank): #Note that communication only happens if rank in w if len(v) == 1: return v[0] if rank == w[0]: if w[0] == w[1]: return v[0]+v[1] self._comm.send(v[0], dest=w[1]) return None if rank == w[1]: return self._comm.recv(source=w[0]) + v[1] def add2(v, w, rank): #Note that communication only happens if rank in w if len(v) == 1: return v[0] if rank == w[0]: if w[0] == w[1]: return v[0]+v[1] self._comm.send(v[0], dest=w[1]) return None if rank == w[1]: return self._comm.recv(source=w[0]) + v[1] while len(vals) > 1: new_vals = [] new_who = [] for j in range((len(vals)+1)//2): w = who[2*j:2*j+2] nv = add2(vals[2*j:2*j+2],w, rank) new_vals += [nv] new_who += [w[-1]] vals = new_vals who = new_who res = self._comm.bcast(vals[0], root=who[0]) while len(vals) > 1: new_vals = [] new_who = [] for j in range((len(vals)+1)//2): w = who[2*j:2*j+2] nv = add2(vals[2*j:2*j+2],w, rank) new_vals += [nv] new_who += [w[-1]] vals = new_vals who = new_who if self._comm is None: return vals[0] res = self._comm.bcast(vals[0], root=who[0]) return res ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!