Skip to content
Snippets Groups Projects
Commit 6edd673d authored by Philipp Arras's avatar Philipp Arras
Browse files

Generalize MPI Operators to MultiDomains

parent 8dd3c01f
Branches
Tags
1 merge request!24Mpi magic
Pipeline #107098 passed with warnings
......@@ -94,6 +94,7 @@ class SliceSum(ift.Operator):
class SliceSumLinear(ift.LinearOperator):
"""Special case of AllreduceSumLinear @ Slicer"""
def __init__(self, oplist, index_low, parallel_space, comm):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
doms = _get_global_unique(oplist, lambda op: op.domain, comm)
......@@ -126,23 +127,22 @@ class SliceSumLinear(ift.LinearOperator):
opx.append(op(ift.makeField(op.domain, foo)))
return ift.utilities.allreduce_sum(opx, self._comm)
else:
if isinstance(self.domain, ift.MultiDomain):
res = {}
lst = [op.adjoint(x).val for op in self._oplist]
for kk in self.domain.keys():
res[kk] = array_allgather([xx[kk] for xx in lst], self._comm, self._nwork)
else:
lst = [op.adjoint(x).val for op in self._oplist]
res = array_allgather(lst, self._comm, self._nwork)
lst = [op.adjoint(x).val for op in self._oplist]
res = allgather_dispatch(lst, self._comm, self._nwork)
return ift.makeField(self.domain, res)
class SliceLinear(ift.EndomorphicOperator):
# FIXME Generalize to ift.LinearOperator
def __init__(self, oplist, index_low, parallel_space, comm):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
doms = _get_global_unique(oplist, lambda op: op.domain, comm)
self._domain = ift.makeDomain((parallel_space,) + tuple(dd for dd in doms))
if isinstance(doms, ift.MultiDomain):
dom = {}
for kk in doms.keys():
dom[kk] = (parallel_space,) + tuple(dd for dd in doms[kk])
else:
dom = (parallel_space,) + tuple(dd for dd in doms)
self._domain = ift.makeDomain(dom)
cap = _get_global_unique(oplist, lambda op: op._capability, comm)
self._capability = (self.TIMES | self.ADJOINT_TIMES) & cap
self._oplist = oplist
......@@ -153,6 +153,13 @@ class SliceLinear(ift.EndomorphicOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if isinstance(self.domain, ift.MultiDomain):
out_list = []
for ii, op in enumerate(self._oplist):
inp = ift.makeField(op.domain, {kk: x.val[kk][self._lo + ii] for kk in self.domain.keys()})
out_list.append(op.apply(inp, mode).val)
res = dict_allgather(out_list, self._comm, self._nwork)
return ift.makeField(self.domain, res)
res = array_allgather(
[
op.apply(ift.makeField(op.domain, x.val[self._lo + ii]), mode).val
......@@ -169,7 +176,7 @@ class SliceLinear(ift.EndomorphicOperator):
for ii, op in enumerate(self._oplist):
with ift.random.Context(sseq[self._lo + ii]):
local_samples.append(op.draw_sample(from_inverse).val)
res = array_allgather(local_samples, self._comm, self._nwork)
res = allgather_dispatch(local_samples, self._comm, self._nwork)
return ift.makeField(self._domain, res)
......@@ -215,6 +222,13 @@ def _get_global_unique(lst, f, comm):
return cap
def allgather_dispatch(obj, comm, nwork):
if isinstance(obj[0], np.ndarray):
return array_allgather(obj, comm, nwork)
if isinstance(obj[0], dict):
return dict_allgather(obj, comm, nwork)
def array_allgather(arrs, comm, nwork):
if comm is None:
full_lst = np.array(arrs)
......@@ -240,3 +254,19 @@ def array_allgather(arrs, comm, nwork):
comm.Allgatherv([send_buf, MPI.DOUBLE], [full_lst, tuple(send_count), tuple(displacement), MPI.DOUBLE])
return full_lst
def dict_allgather(lst, comm, nwork):
"""Apply array_allgather for each key of a dictionary.
Parameters
----------
lst: list
List of dictionaries. All dictionaries need to have the same keys.
"""
keys = lst[0].keys()
for aa in lst:
assert isinstance(aa, dict)
assert aa.keys() == keys
return {kk: array_allgather([ll[kk] for ll in lst], comm, nwork) for kk in keys}
......@@ -36,6 +36,7 @@ def test_slice_sum():
op = rve.SliceSum(oplist, lo, parallel_space, rve.mpi.comm)
ift.extra.check_operator(op, ift.from_random(op.domain))
oplist = [ift.GaussianEnergy(domain=dom["a"]) @ Extract(dom, "a") @ ift.ScalingOperator(dom, 2.).exp() for ii in range(hi-lo)]
oplist = [ift.GaussianEnergy(domain=dom["a"], sampling_dtype=np.float64) @ Extract(dom, "a") @ ift.ScalingOperator(dom, 2.).exp() for ii in range(hi-lo)]
op = rve.SliceSum(oplist, lo, parallel_space, rve.mpi.comm)
ift.extra.check_operator(op, ift.from_random(op.domain))
op(ift.Linearization.make_var(ift.from_random(op.domain), True)).metric.draw_sample()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment