Skip to content
Snippets Groups Projects
Commit dd72d8de authored by Philipp Arras's avatar Philipp Arras
Browse files

Add fancier MPI operators

parent b7c322ac
No related branches found
No related tags found
1 merge request!13More mpi functionality
Pipeline #94163 passed
...@@ -27,11 +27,9 @@ class AllreduceSum(ift.Operator): ...@@ -27,11 +27,9 @@ class AllreduceSum(ift.Operator):
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
if not ift.is_linearization(x):
return ift.utilities.allreduce_sum(
[op(x) for op in self._oplist], self._comm
)
opx = [op(x) for op in self._oplist] opx = [op(x) for op in self._oplist]
if not ift.is_linearization(x):
return ift.utilities.allreduce_sum(opx, self._comm)
val = ift.utilities.allreduce_sum([lin.val for lin in opx], self._comm) val = ift.utilities.allreduce_sum([lin.val for lin in opx], self._comm)
jac = AllreduceSumLinear([lin.jac for lin in opx], self._comm) jac = AllreduceSumLinear([lin.jac for lin in opx], self._comm)
if _get_global_unique(opx, lambda op: op.metric is None, self._comm): if _get_global_unique(opx, lambda op: op.metric is None, self._comm):
...@@ -40,6 +38,110 @@ class AllreduceSum(ift.Operator): ...@@ -40,6 +38,110 @@ class AllreduceSum(ift.Operator):
return x.new(val, jac, met) return x.new(val, jac, met)
class SliceSum(ift.Operator):
def __init__(self, oplist, index_low, parallel_space, comm):
self._oplist, self._comm = oplist, comm
self._lo = int(index_low)
assert len(parallel_space.shape) == 1
if len(oplist) > 0:
assert index_low < parallel_space.shape[0]
else:
assert index_low == parallel_space.shape[0]
doms = _get_global_unique(oplist, lambda op: op.domain, comm)
self._domain = ift.makeDomain((parallel_space,) + tuple(dd for dd in doms))
self._target = ift.makeDomain(
_get_global_unique(oplist, lambda op: op.target, comm)
)
def apply(self, x):
self._check_input(x)
if not ift.is_linearization(x):
opx = [
op(ift.makeField(op.domain, x.val[self._lo + ii]))
for ii, op in enumerate(self._oplist)
]
return ift.utilities.allreduce_sum(opx, self._comm)
oplin = [
op(
ift.Linearization.make_var(
ift.makeField(op.domain, x.val.val[self._lo + ii]), x.want_metric
)
)
for ii, op in enumerate(self._oplist)
]
val = ift.utilities.allreduce_sum([lin.val for lin in oplin], self._comm)
args = self._lo, self.domain[0], self._comm
jac = SliceSumLinear([lin.jac for lin in oplin], *args)
if _get_global_unique(oplin, lambda op: op.metric is None, self._comm):
return x.new(val, jac)
met = SliceLinear([lin.metric for lin in oplin], *args)
return x.new(val, jac, met)
class SliceSumLinear(ift.LinearOperator):
def __init__(self, oplist, index_low, parallel_space, comm):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
doms = _get_global_unique(oplist, lambda op: op.domain, comm)
self._domain = ift.makeDomain((parallel_space,) + tuple(dd for dd in doms))
self._target = ift.makeDomain(
_get_global_unique(oplist, lambda op: op.target, comm)
)
cap = _get_global_unique(oplist, lambda op: op._capability, comm)
self._capability = (self.TIMES | self.ADJOINT_TIMES) & cap
self._oplist = oplist
self._comm = comm
self._lo = index_low
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return ift.utilities.allreduce_sum(
[
op(ift.makeField(op.domain, x.val[self._lo + ii]))
for ii, op in enumerate(self._oplist)
],
self._comm,
)
else:
arr = _allgather([op.adjoint(x).val for op in self._oplist], self._comm)
return ift.makeField(self.domain, arr)
class SliceLinear(ift.EndomorphicOperator):
def __init__(self, oplist, index_low, parallel_space, comm):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
doms = _get_global_unique(oplist, lambda op: op.domain, comm)
self._domain = ift.makeDomain((parallel_space,) + tuple(dd for dd in doms))
cap = _get_global_unique(oplist, lambda op: op._capability, comm)
self._capability = (self.TIMES | self.ADJOINT_TIMES) & cap
self._oplist = oplist
self._comm = comm
self._lo = index_low
local_nwork = [len(oplist)] if comm is None else comm.allgather(len(oplist))
size, rank, _ = ift.utilities.get_MPI_params_from_comm(comm)
self._nwork = sum(local_nwork)
def apply(self, x, mode):
self._check_input(x, mode)
res = _allgather(
[
op.apply(ift.makeField(op.domain, x.val[self._lo + ii]), mode).val
for ii, op in enumerate(self._oplist)
],
self._comm,
)
return ift.makeField(self._domain, res)
def draw_sample(self, from_inverse=False):
sseq = ift.random.spawn_sseq(self._nwork)
local_samples = []
for ii, op in enumerate(self._oplist):
with ift.random.Context(sseq[self._lo + ii]):
local_samples.append(op.draw_sample(from_inverse).val)
res = _allgather(local_samples, self._comm)
return ift.makeField(self._domain, res)
class AllreduceSumLinear(ift.LinearOperator): class AllreduceSumLinear(ift.LinearOperator):
def __init__(self, oplist, comm=None): def __init__(self, oplist, comm=None):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist) assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
...@@ -80,3 +182,11 @@ def _get_global_unique(lst, f, comm): ...@@ -80,3 +182,11 @@ def _get_global_unique(lst, f, comm):
cap = caps[0] cap = caps[0]
assert all(cc == cap for cc in caps) assert all(cc == cap for cc in caps)
return cap return cap
def _allgather(arrs, comm):
if comm is None:
fulllst = [arrs]
else:
fulllst = comm.allgather(arrs)
return np.array([aa for cc in fulllst for aa in cc])
...@@ -12,7 +12,7 @@ import nifty7 as ift ...@@ -12,7 +12,7 @@ import nifty7 as ift
import resolve as rve import resolve as rve
def getop(comm): def getop(comm, typ):
"""Return energy operator that maps the full multi-frequency sky onto """Return energy operator that maps the full multi-frequency sky onto
the log-likelihood value for a frequency slice.""" the log-likelihood value for a frequency slice."""
...@@ -40,12 +40,14 @@ def getop(comm): ...@@ -40,12 +40,14 @@ def getop(comm):
ddom = ift.UnstructuredDomain(d[ii].shape) ddom = ift.UnstructuredDomain(d[ii].shape)
dd = ift.makeField(ddom, d[ii]) dd = ift.makeField(ddom, d[ii])
iicc = ift.makeOp(ift.makeField(ddom, invcov[ii])) iicc = ift.makeOp(ift.makeField(ddom, invcov[ii]))
ee = ( ee = ift.GaussianEnergy(dd, iicc)
ift.GaussianEnergy(dd, iicc) if typ == 0:
@ ift.DomainTupleFieldInserter(skydom, 0, (ii,)).adjoint ee = ee @ ift.DomainTupleFieldInserter(skydom, 0, (ii,)).adjoint
)
lst.append(ee) lst.append(ee)
op = rve.AllreduceSum(lst, comm) if typ == 0:
op = rve.AllreduceSum(lst, comm)
else:
op = rve.SliceSum(lst, lo, skydom[0], comm)
ift.extra.check_operator(op, ift.from_random(op.domain)) ift.extra.check_operator(op, ift.from_random(op.domain))
sky = ift.FieldAdapter(skydom, "sky") sky = ift.FieldAdapter(skydom, "sky")
return op @ sky.exp() return op @ sky.exp()
...@@ -69,13 +71,20 @@ def test_mpi_adder(): ...@@ -69,13 +71,20 @@ def test_mpi_adder():
if comm is not None: if comm is not None:
comm.Barrier() comm.Barrier()
lhs = getop(-1), getop(None), getop(comm) lhs = (
getop(-1, 0),
getop(-1, 1),
getop(None, 0),
getop(None, 1),
getop(comm, 0),
getop(comm, 1),
)
hams = tuple( hams = tuple(
ift.StandardHamiltonian(lh, ift.GradientNormController(iteration_limit=10)) ift.StandardHamiltonian(lh, ift.GradientNormController(iteration_limit=10))
for lh in lhs for lh in lhs
) )
lhs_for_sampling = lhs[1:] lhs_for_sampling = lhs[2:]
hams_for_sampling = hams[1:] hams_for_sampling = hams[2:]
# Evaluate Field # Evaluate Field
dom, tgt = lhs[0].domain, lhs[0].target dom, tgt = lhs[0].domain, lhs[0].target
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment