Skip to content
Snippets Groups Projects
Commit 41b3caae authored by Philipp Arras's avatar Philipp Arras
Browse files

Simplify

parent b7dea061
No related branches found
No related tags found
1 merge request!9Mpi adder
Pipeline #93657 passed
......@@ -54,18 +54,17 @@ class AllreduceSum(ift.Operator):
def apply(self, x):
self._check_input(x)
if ift.is_linearization(x):
res = [op(x) for op in self._oplist]
val = ift.utilities.allreduce_sum([lin.val for lin in res], self._comm)
jac = AllreduceSumLinear([lin.jac for lin in res], self._comm)
if res[0].want_metric and res[0].metric is not None:
met = AllreduceSumLinear(
[lin.metric for lin in res], self._comm, self._nwork
if not ift.is_linearization(x):
return ift.utilities.allreduce_sum(
[op(x) for op in self._oplist], self._comm
)
opx = [op(x) for op in self._oplist]
val = ift.utilities.allreduce_sum([lin.val for lin in opx], self._comm)
jac = AllreduceSumLinear([lin.jac for lin in opx], self._comm)
if opx[0].metric is None:
return x.new(val, jac)
met = AllreduceSumLinear([lin.metric for lin in opx], self._comm, self._nwork)
return x.new(val, jac, met)
res = x.new(val, jac)
return res
return ift.utilities.allreduce_sum([op(x) for op in self._oplist], self._comm)
class AllreduceSumLinear(ift.LinearOperator):
......
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment