Skip to content
Snippets Groups Projects
Commit dd72d8de authored by Philipp Arras's avatar Philipp Arras
Browse files

Add fancier MPI operators

parent b7c322ac
No related branches found
No related tags found
1 merge request!13More mpi functionality
Pipeline #94163 passed
......@@ -27,11 +27,9 @@ class AllreduceSum(ift.Operator):
def apply(self, x):
self._check_input(x)
if not ift.is_linearization(x):
return ift.utilities.allreduce_sum(
[op(x) for op in self._oplist], self._comm
)
opx = [op(x) for op in self._oplist]
if not ift.is_linearization(x):
return ift.utilities.allreduce_sum(opx, self._comm)
val = ift.utilities.allreduce_sum([lin.val for lin in opx], self._comm)
jac = AllreduceSumLinear([lin.jac for lin in opx], self._comm)
if _get_global_unique(opx, lambda op: op.metric is None, self._comm):
......@@ -40,6 +38,110 @@ class AllreduceSum(ift.Operator):
return x.new(val, jac, met)
class SliceSum(ift.Operator):
def __init__(self, oplist, index_low, parallel_space, comm):
self._oplist, self._comm = oplist, comm
self._lo = int(index_low)
assert len(parallel_space.shape) == 1
if len(oplist) > 0:
assert index_low < parallel_space.shape[0]
else:
assert index_low == parallel_space.shape[0]
doms = _get_global_unique(oplist, lambda op: op.domain, comm)
self._domain = ift.makeDomain((parallel_space,) + tuple(dd for dd in doms))
self._target = ift.makeDomain(
_get_global_unique(oplist, lambda op: op.target, comm)
)
def apply(self, x):
self._check_input(x)
if not ift.is_linearization(x):
opx = [
op(ift.makeField(op.domain, x.val[self._lo + ii]))
for ii, op in enumerate(self._oplist)
]
return ift.utilities.allreduce_sum(opx, self._comm)
oplin = [
op(
ift.Linearization.make_var(
ift.makeField(op.domain, x.val.val[self._lo + ii]), x.want_metric
)
)
for ii, op in enumerate(self._oplist)
]
val = ift.utilities.allreduce_sum([lin.val for lin in oplin], self._comm)
args = self._lo, self.domain[0], self._comm
jac = SliceSumLinear([lin.jac for lin in oplin], *args)
if _get_global_unique(oplin, lambda op: op.metric is None, self._comm):
return x.new(val, jac)
met = SliceLinear([lin.metric for lin in oplin], *args)
return x.new(val, jac, met)
class SliceSumLinear(ift.LinearOperator):
def __init__(self, oplist, index_low, parallel_space, comm):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
doms = _get_global_unique(oplist, lambda op: op.domain, comm)
self._domain = ift.makeDomain((parallel_space,) + tuple(dd for dd in doms))
self._target = ift.makeDomain(
_get_global_unique(oplist, lambda op: op.target, comm)
)
cap = _get_global_unique(oplist, lambda op: op._capability, comm)
self._capability = (self.TIMES | self.ADJOINT_TIMES) & cap
self._oplist = oplist
self._comm = comm
self._lo = index_low
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return ift.utilities.allreduce_sum(
[
op(ift.makeField(op.domain, x.val[self._lo + ii]))
for ii, op in enumerate(self._oplist)
],
self._comm,
)
else:
arr = _allgather([op.adjoint(x).val for op in self._oplist], self._comm)
return ift.makeField(self.domain, arr)
class SliceLinear(ift.EndomorphicOperator):
def __init__(self, oplist, index_low, parallel_space, comm):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
doms = _get_global_unique(oplist, lambda op: op.domain, comm)
self._domain = ift.makeDomain((parallel_space,) + tuple(dd for dd in doms))
cap = _get_global_unique(oplist, lambda op: op._capability, comm)
self._capability = (self.TIMES | self.ADJOINT_TIMES) & cap
self._oplist = oplist
self._comm = comm
self._lo = index_low
local_nwork = [len(oplist)] if comm is None else comm.allgather(len(oplist))
size, rank, _ = ift.utilities.get_MPI_params_from_comm(comm)
self._nwork = sum(local_nwork)
def apply(self, x, mode):
self._check_input(x, mode)
res = _allgather(
[
op.apply(ift.makeField(op.domain, x.val[self._lo + ii]), mode).val
for ii, op in enumerate(self._oplist)
],
self._comm,
)
return ift.makeField(self._domain, res)
def draw_sample(self, from_inverse=False):
sseq = ift.random.spawn_sseq(self._nwork)
local_samples = []
for ii, op in enumerate(self._oplist):
with ift.random.Context(sseq[self._lo + ii]):
local_samples.append(op.draw_sample(from_inverse).val)
res = _allgather(local_samples, self._comm)
return ift.makeField(self._domain, res)
class AllreduceSumLinear(ift.LinearOperator):
def __init__(self, oplist, comm=None):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
......@@ -80,3 +182,11 @@ def _get_global_unique(lst, f, comm):
cap = caps[0]
assert all(cc == cap for cc in caps)
return cap
def _allgather(arrs, comm):
if comm is None:
fulllst = [arrs]
else:
fulllst = comm.allgather(arrs)
return np.array([aa for cc in fulllst for aa in cc])
......@@ -12,7 +12,7 @@ import nifty7 as ift
import resolve as rve
def getop(comm):
def getop(comm, typ):
"""Return energy operator that maps the full multi-frequency sky onto
the log-likelihood value for a frequency slice."""
......@@ -40,12 +40,14 @@ def getop(comm):
ddom = ift.UnstructuredDomain(d[ii].shape)
dd = ift.makeField(ddom, d[ii])
iicc = ift.makeOp(ift.makeField(ddom, invcov[ii]))
ee = (
ift.GaussianEnergy(dd, iicc)
@ ift.DomainTupleFieldInserter(skydom, 0, (ii,)).adjoint
)
ee = ift.GaussianEnergy(dd, iicc)
if typ == 0:
ee = ee @ ift.DomainTupleFieldInserter(skydom, 0, (ii,)).adjoint
lst.append(ee)
op = rve.AllreduceSum(lst, comm)
if typ == 0:
op = rve.AllreduceSum(lst, comm)
else:
op = rve.SliceSum(lst, lo, skydom[0], comm)
ift.extra.check_operator(op, ift.from_random(op.domain))
sky = ift.FieldAdapter(skydom, "sky")
return op @ sky.exp()
......@@ -69,13 +71,20 @@ def test_mpi_adder():
if comm is not None:
comm.Barrier()
lhs = getop(-1), getop(None), getop(comm)
lhs = (
getop(-1, 0),
getop(-1, 1),
getop(None, 0),
getop(None, 1),
getop(comm, 0),
getop(comm, 1),
)
hams = tuple(
ift.StandardHamiltonian(lh, ift.GradientNormController(iteration_limit=10))
for lh in lhs
)
lhs_for_sampling = lhs[1:]
hams_for_sampling = hams[1:]
lhs_for_sampling = lhs[2:]
hams_for_sampling = hams[2:]
# Evaluate Field
dom, tgt = lhs[0].domain, lhs[0].target
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment