Skip to content
Snippets Groups Projects
Commit f09f19e2 authored by Philipp Arras's avatar Philipp Arras
Browse files

Generalize to case where some tasks are empty

parent 656ea0f8
Branches
Tags
1 merge request!9Mpi adder
Pipeline #93670 passed
...@@ -8,8 +8,12 @@ import nifty7 as ift ...@@ -8,8 +8,12 @@ import nifty7 as ift
class AllreduceSum(ift.Operator): class AllreduceSum(ift.Operator):
def __init__(self, oplist, comm): def __init__(self, oplist, comm):
self._oplist, self._comm = oplist, comm self._oplist, self._comm = oplist, comm
self._domain = self._oplist[0].domain self._domain = ift.makeDomain(
self._target = self._oplist[0].target _get_global_unique(oplist, lambda op: op.domain, comm)
)
self._target = ift.makeDomain(
_get_global_unique(oplist, lambda op: op.target, comm)
)
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
...@@ -20,7 +24,7 @@ class AllreduceSum(ift.Operator): ...@@ -20,7 +24,7 @@ class AllreduceSum(ift.Operator):
opx = [op(x) for op in self._oplist] opx = [op(x) for op in self._oplist]
val = ift.utilities.allreduce_sum([lin.val for lin in opx], self._comm) val = ift.utilities.allreduce_sum([lin.val for lin in opx], self._comm)
jac = AllreduceSumLinear([lin.jac for lin in opx], self._comm) jac = AllreduceSumLinear([lin.jac for lin in opx], self._comm)
if opx[0].metric is None: if _get_global_unique(opx, lambda op: op.metric is None, self._comm):
return x.new(val, jac) return x.new(val, jac)
met = AllreduceSumLinear([lin.metric for lin in opx], self._comm) met = AllreduceSumLinear([lin.metric for lin in opx], self._comm)
return x.new(val, jac, met) return x.new(val, jac, met)
...@@ -29,12 +33,13 @@ class AllreduceSum(ift.Operator): ...@@ -29,12 +33,13 @@ class AllreduceSum(ift.Operator):
class AllreduceSumLinear(ift.LinearOperator): class AllreduceSumLinear(ift.LinearOperator):
def __init__(self, oplist, comm=None): def __init__(self, oplist, comm=None):
assert all(isinstance(oo, ift.LinearOperator) for oo in oplist) assert all(isinstance(oo, ift.LinearOperator) for oo in oplist)
self._domain = ift.makeDomain(oplist[0].domain) self._domain = ift.makeDomain(
self._target = ift.makeDomain(oplist[0].target) _get_global_unique(oplist, lambda op: op.domain, comm)
cap = oplist[0]._capability )
assert all(oo.domain == self._domain for oo in oplist) self._target = ift.makeDomain(
assert all(oo.target == self._target for oo in oplist) _get_global_unique(oplist, lambda op: op.target, comm)
assert all(oo._capability == cap for oo in oplist) )
cap = _get_global_unique(oplist, lambda op: op._capability, comm)
self._capability = (self.TIMES | self.ADJOINT_TIMES) & cap self._capability = (self.TIMES | self.ADJOINT_TIMES) & cap
self._oplist = oplist self._oplist = oplist
self._comm = comm self._comm = comm
...@@ -57,3 +62,13 @@ class AllreduceSumLinear(ift.LinearOperator): ...@@ -57,3 +62,13 @@ class AllreduceSumLinear(ift.LinearOperator):
with ift.random.Context(sseq[lo + ii]): with ift.random.Context(sseq[lo + ii]):
local_samples.append(op.draw_sample(from_inverse)) local_samples.append(op.draw_sample(from_inverse))
return ift.utilities.allreduce_sum(local_samples, self._comm) return ift.utilities.allreduce_sum(local_samples, self._comm)
def _get_global_unique(lst, f, comm):
caps = [f(oo) for oo in lst]
if comm is not None:
caps = comm.allgather(caps)
caps = [aa for cc in caps for aa in cc]
cap = caps[0]
assert all(cc == cap for cc in caps)
return cap
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment