Skip to content
Snippets Groups Projects
Commit 3ee09d18 authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge branch 'more_mpi_functionality' into 'master'

More mpi functionality

See merge request !13
parents c8d398b3 dd72d8de
No related branches found
No related tags found
1 merge request!13More mpi functionality
Pipeline #94212 passed
......@@ -27,11 +27,9 @@ class AllreduceSum(ift.Operator):
def apply(self, x):
self._check_input(x)
if not ift.is_linearization(x):
return ift.utilities.allreduce_sum(
[op(x) for op in self._oplist], self._comm
)
opx = [op(x) for op in self._oplist]
if not ift.is_linearization(x):
return ift.utilities.allreduce_sum(opx, self._comm)
val = ift.utilities.allreduce_sum([lin.val for lin in opx], self._comm)
jac = AllreduceSumLinear([lin.jac for lin in opx], self._comm)
if _get_global_unique(opx, lambda op: op.metric is None, self._comm):
......@@ -40,6 +38,110 @@ class AllreduceSum(ift.Operator):
return x.new(val, jac, met)
class SliceSum(ift.Operator):
def __init__(self, oplist, index_low, parallel_space, comm):
self._oplist, self._comm = oplist, comm
self._lo = int(index_low)
assert len(parallel_space.shape) == 1
if len(oplist) > 0:
assert index_low < parallel_space.shape[0]
else:
assert index_low == parallel_space.shape[0]
doms = _get_global_unique(oplist, lambda op: op.domain, comm)
self._domain = ift.makeDomain((parallel_space,) + tuple(dd for dd in doms))
self._target = ift.makeDomain(
_get_global_unique(oplist, lambda op: op.target, comm)
)
def apply(self, x):
self._check_input(x)
if not ift.is_linearization(x):
opx = [
op(ift.makeField(op.domain, x.val[self._lo + ii]))
for ii, op in enumerate(self._oplist)
]
return ift.utilities.allreduce_sum(opx, self._comm)
oplin = [
op(
ift.Linearization.make_var(
ift.makeField(op.domain, x.val.val[self._lo + ii]), x.want_metric
)
)
for ii, op in enumerate(self._oplist)
]
val = ift.utilities.allreduce_sum([lin.val for lin in oplin], self._comm)
args = self._lo, self.domain[0], self._comm
jac = SliceSumLinear([lin.jac for lin in oplin], *args)
if _get_global_unique(oplin, lambda op: op.metric is None, self._comm):
return x.new(val, jac)
met = SliceLinear([lin.metric for lin in oplin], *args)
return x.new(val, jac, met)
class SliceSumLinear(ift.LinearOperator):
def __init__(self, oplist, index_low, parallel_space, comm):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
doms = _get_global_unique(oplist, lambda op: op.domain, comm)
self._domain = ift.makeDomain((parallel_space,) + tuple(dd for dd in doms))
self._target = ift.makeDomain(
_get_global_unique(oplist, lambda op: op.target, comm)
)
cap = _get_global_unique(oplist, lambda op: op._capability, comm)
self._capability = (self.TIMES | self.ADJOINT_TIMES) & cap
self._oplist = oplist
self._comm = comm
self._lo = index_low
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return ift.utilities.allreduce_sum(
[
op(ift.makeField(op.domain, x.val[self._lo + ii]))
for ii, op in enumerate(self._oplist)
],
self._comm,
)
else:
arr = _allgather([op.adjoint(x).val for op in self._oplist], self._comm)
return ift.makeField(self.domain, arr)
class SliceLinear(ift.EndomorphicOperator):
def __init__(self, oplist, index_low, parallel_space, comm):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
doms = _get_global_unique(oplist, lambda op: op.domain, comm)
self._domain = ift.makeDomain((parallel_space,) + tuple(dd for dd in doms))
cap = _get_global_unique(oplist, lambda op: op._capability, comm)
self._capability = (self.TIMES | self.ADJOINT_TIMES) & cap
self._oplist = oplist
self._comm = comm
self._lo = index_low
local_nwork = [len(oplist)] if comm is None else comm.allgather(len(oplist))
size, rank, _ = ift.utilities.get_MPI_params_from_comm(comm)
self._nwork = sum(local_nwork)
def apply(self, x, mode):
self._check_input(x, mode)
res = _allgather(
[
op.apply(ift.makeField(op.domain, x.val[self._lo + ii]), mode).val
for ii, op in enumerate(self._oplist)
],
self._comm,
)
return ift.makeField(self._domain, res)
def draw_sample(self, from_inverse=False):
sseq = ift.random.spawn_sseq(self._nwork)
local_samples = []
for ii, op in enumerate(self._oplist):
with ift.random.Context(sseq[self._lo + ii]):
local_samples.append(op.draw_sample(from_inverse).val)
res = _allgather(local_samples, self._comm)
return ift.makeField(self._domain, res)
class AllreduceSumLinear(ift.LinearOperator):
def __init__(self, oplist, comm=None):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
......@@ -80,3 +182,11 @@ def _get_global_unique(lst, f, comm):
cap = caps[0]
assert all(cc == cap for cc in caps)
return cap
def _allgather(arrs, comm):
if comm is None:
fulllst = [arrs]
else:
fulllst = comm.allgather(arrs)
return np.array([aa for cc in fulllst for aa in cc])
......@@ -12,12 +12,13 @@ import nifty7 as ift
import resolve as rve
def getop(comm):
def getop(comm, typ):
"""Return energy operator that maps the full multi-frequency sky onto
the log-likelihood value for a frequency slice."""
d = np.load("data.npy")
invcov = np.load("invcov.npy")
skydom = ift.UnstructuredDomain(d.shape[0]), ift.UnstructuredDomain(d.shape[1:])
if comm == -1:
nwork = d.shape[0]
ddom = ift.UnstructuredDomain(d[0].shape)
......@@ -25,6 +26,7 @@ def getop(comm):
ift.GaussianEnergy(
ift.makeField(ddom, d[ii]), ift.makeOp(ift.makeField(ddom, invcov[ii]))
)
@ ift.DomainTupleFieldInserter(skydom, 0, (ii,)).adjoint
for ii in range(nwork)
]
op = reduce(add, ops)
......@@ -39,10 +41,15 @@ def getop(comm):
dd = ift.makeField(ddom, d[ii])
iicc = ift.makeOp(ift.makeField(ddom, invcov[ii]))
ee = ift.GaussianEnergy(dd, iicc)
if typ == 0:
ee = ee @ ift.DomainTupleFieldInserter(skydom, 0, (ii,)).adjoint
lst.append(ee)
op = rve.AllreduceSum(lst, comm)
if typ == 0:
op = rve.AllreduceSum(lst, comm)
else:
op = rve.SliceSum(lst, lo, skydom[0], comm)
ift.extra.check_operator(op, ift.from_random(op.domain))
sky = ift.FieldAdapter(op.domain, "sky")
sky = ift.FieldAdapter(skydom, "sky")
return op @ sky.exp()
......@@ -64,13 +71,20 @@ def test_mpi_adder():
if comm is not None:
comm.Barrier()
lhs = getop(-1), getop(None), getop(comm)
lhs = (
getop(-1, 0),
getop(-1, 1),
getop(None, 0),
getop(None, 1),
getop(comm, 0),
getop(comm, 1),
)
hams = tuple(
ift.StandardHamiltonian(lh, ift.GradientNormController(iteration_limit=10))
for lh in lhs
)
lhs_for_sampling = lhs[1:]
hams_for_sampling = hams[1:]
lhs_for_sampling = lhs[2:]
hams_for_sampling = hams[2:]
# Evaluate Field
dom, tgt = lhs[0].domain, lhs[0].target
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment