Commit 193ffe36 authored by Reimar Leike's avatar Reimar Leike
Browse files

add extended documentation

parent 52213268
Pipeline #75642 passed with stages
in 14 minutes and 45 seconds
......@@ -230,9 +230,38 @@ class MetricGaussianKL(Energy):
def _sumup(self, obj):
# This is a deterministic implementation of MPI allreduce in the sense
# that it takes into account that floating point operations are not
# associative.
""" This is a deterministic implementation of MPI allreduce
Numeric addition is not associative due to rounding errors.
Therefore we provide our own implementation that is consistent
no matter if MPI is used and how many tasks there are.
At the beginning, a list `who` is constructed, that states which obj can
be found on which MPI task.
Then iteratively, pairs of objects are summed. If the two objects
are on different tasks, the tasks that holds the second element of the
pair sums the values after a communication.
Example of how values are summed:
```
3 7 11
| / |
10 11
\ /
21
```
Note that sometimes a value is not paired, if the length of the list is
odd. In this case, nothing has to be done.
If the three initial values are distributed to rank 0,1,1, respectively,
then the this distribution evolves as follows:
```
0 1 1
|/ |
1 1
\ /
1
```
In the end, task 1 broadcasts the result
"""
if self._comm is None:
who = np.zeros(self._n_samples, dtype=np.int32)
rank = 0
......@@ -274,8 +303,6 @@ class MetricGaussianKL(Energy):
res = self._comm.bcast(vals[0], root=who[0])
return res
def _metric_sample(self, from_inverse=False):
if from_inverse:
raise NotImplementedError()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment