Skip to content
Snippets Groups Projects
Commit 6edd673d authored by Philipp Arras's avatar Philipp Arras
Browse files

Generalize MPI Operators to MultiDomains

parent 8dd3c01f
No related branches found
No related tags found
1 merge request!24Mpi magic
Pipeline #107098 passed with warnings
...@@ -94,6 +94,7 @@ class SliceSum(ift.Operator): ...@@ -94,6 +94,7 @@ class SliceSum(ift.Operator):
class SliceSumLinear(ift.LinearOperator): class SliceSumLinear(ift.LinearOperator):
"""Special case of AllreduceSumLinear @ Slicer"""
def __init__(self, oplist, index_low, parallel_space, comm): def __init__(self, oplist, index_low, parallel_space, comm):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist) assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
doms = _get_global_unique(oplist, lambda op: op.domain, comm) doms = _get_global_unique(oplist, lambda op: op.domain, comm)
...@@ -126,23 +127,22 @@ class SliceSumLinear(ift.LinearOperator): ...@@ -126,23 +127,22 @@ class SliceSumLinear(ift.LinearOperator):
opx.append(op(ift.makeField(op.domain, foo))) opx.append(op(ift.makeField(op.domain, foo)))
return ift.utilities.allreduce_sum(opx, self._comm) return ift.utilities.allreduce_sum(opx, self._comm)
else: else:
if isinstance(self.domain, ift.MultiDomain): lst = [op.adjoint(x).val for op in self._oplist]
res = {} res = allgather_dispatch(lst, self._comm, self._nwork)
lst = [op.adjoint(x).val for op in self._oplist]
for kk in self.domain.keys():
res[kk] = array_allgather([xx[kk] for xx in lst], self._comm, self._nwork)
else:
lst = [op.adjoint(x).val for op in self._oplist]
res = array_allgather(lst, self._comm, self._nwork)
return ift.makeField(self.domain, res) return ift.makeField(self.domain, res)
class SliceLinear(ift.EndomorphicOperator): class SliceLinear(ift.EndomorphicOperator):
# FIXME Generalize to ift.LinearOperator
def __init__(self, oplist, index_low, parallel_space, comm): def __init__(self, oplist, index_low, parallel_space, comm):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist) assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
doms = _get_global_unique(oplist, lambda op: op.domain, comm) doms = _get_global_unique(oplist, lambda op: op.domain, comm)
self._domain = ift.makeDomain((parallel_space,) + tuple(dd for dd in doms)) if isinstance(doms, ift.MultiDomain):
dom = {}
for kk in doms.keys():
dom[kk] = (parallel_space,) + tuple(dd for dd in doms[kk])
else:
dom = (parallel_space,) + tuple(dd for dd in doms)
self._domain = ift.makeDomain(dom)
cap = _get_global_unique(oplist, lambda op: op._capability, comm) cap = _get_global_unique(oplist, lambda op: op._capability, comm)
self._capability = (self.TIMES | self.ADJOINT_TIMES) & cap self._capability = (self.TIMES | self.ADJOINT_TIMES) & cap
self._oplist = oplist self._oplist = oplist
...@@ -153,6 +153,13 @@ class SliceLinear(ift.EndomorphicOperator): ...@@ -153,6 +153,13 @@ class SliceLinear(ift.EndomorphicOperator):
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
if isinstance(self.domain, ift.MultiDomain):
out_list = []
for ii, op in enumerate(self._oplist):
inp = ift.makeField(op.domain, {kk: x.val[kk][self._lo + ii] for kk in self.domain.keys()})
out_list.append(op.apply(inp, mode).val)
res = dict_allgather(out_list, self._comm, self._nwork)
return ift.makeField(self.domain, res)
res = array_allgather( res = array_allgather(
[ [
op.apply(ift.makeField(op.domain, x.val[self._lo + ii]), mode).val op.apply(ift.makeField(op.domain, x.val[self._lo + ii]), mode).val
...@@ -169,7 +176,7 @@ class SliceLinear(ift.EndomorphicOperator): ...@@ -169,7 +176,7 @@ class SliceLinear(ift.EndomorphicOperator):
for ii, op in enumerate(self._oplist): for ii, op in enumerate(self._oplist):
with ift.random.Context(sseq[self._lo + ii]): with ift.random.Context(sseq[self._lo + ii]):
local_samples.append(op.draw_sample(from_inverse).val) local_samples.append(op.draw_sample(from_inverse).val)
res = array_allgather(local_samples, self._comm, self._nwork) res = allgather_dispatch(local_samples, self._comm, self._nwork)
return ift.makeField(self._domain, res) return ift.makeField(self._domain, res)
...@@ -215,6 +222,13 @@ def _get_global_unique(lst, f, comm): ...@@ -215,6 +222,13 @@ def _get_global_unique(lst, f, comm):
return cap return cap
def allgather_dispatch(obj, comm, nwork):
if isinstance(obj[0], np.ndarray):
return array_allgather(obj, comm, nwork)
if isinstance(obj[0], dict):
return dict_allgather(obj, comm, nwork)
def array_allgather(arrs, comm, nwork): def array_allgather(arrs, comm, nwork):
if comm is None: if comm is None:
full_lst = np.array(arrs) full_lst = np.array(arrs)
...@@ -240,3 +254,19 @@ def array_allgather(arrs, comm, nwork): ...@@ -240,3 +254,19 @@ def array_allgather(arrs, comm, nwork):
comm.Allgatherv([send_buf, MPI.DOUBLE], [full_lst, tuple(send_count), tuple(displacement), MPI.DOUBLE]) comm.Allgatherv([send_buf, MPI.DOUBLE], [full_lst, tuple(send_count), tuple(displacement), MPI.DOUBLE])
return full_lst return full_lst
def dict_allgather(lst, comm, nwork):
"""Apply array_allgather for each key of a dictionary.
Parameters
----------
lst: list
List of dictionaries. All dictionaries need to have the same keys.
"""
keys = lst[0].keys()
for aa in lst:
assert isinstance(aa, dict)
assert aa.keys() == keys
return {kk: array_allgather([ll[kk] for ll in lst], comm, nwork) for kk in keys}
...@@ -36,6 +36,7 @@ def test_slice_sum(): ...@@ -36,6 +36,7 @@ def test_slice_sum():
op = rve.SliceSum(oplist, lo, parallel_space, rve.mpi.comm) op = rve.SliceSum(oplist, lo, parallel_space, rve.mpi.comm)
ift.extra.check_operator(op, ift.from_random(op.domain)) ift.extra.check_operator(op, ift.from_random(op.domain))
oplist = [ift.GaussianEnergy(domain=dom["a"]) @ Extract(dom, "a") @ ift.ScalingOperator(dom, 2.).exp() for ii in range(hi-lo)] oplist = [ift.GaussianEnergy(domain=dom["a"], sampling_dtype=np.float64) @ Extract(dom, "a") @ ift.ScalingOperator(dom, 2.).exp() for ii in range(hi-lo)]
op = rve.SliceSum(oplist, lo, parallel_space, rve.mpi.comm) op = rve.SliceSum(oplist, lo, parallel_space, rve.mpi.comm)
ift.extra.check_operator(op, ift.from_random(op.domain)) ift.extra.check_operator(op, ift.from_random(op.domain))
op(ift.Linearization.make_var(ift.from_random(op.domain), True)).metric.draw_sample()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment