diff --git a/demo/mpi_demo.py b/demo/mpi_demo.py index c00fe3a2b4db813c2de4fd9471ac0ce6a90d3946..301e11f40e801ecbfcafdb1b923468c7ba885702 100644 --- a/demo/mpi_demo.py +++ b/demo/mpi_demo.py @@ -54,18 +54,17 @@ class AllreduceSum(ift.Operator): def apply(self, x): self._check_input(x) - if ift.is_linearization(x): - res = [op(x) for op in self._oplist] - val = ift.utilities.allreduce_sum([lin.val for lin in res], self._comm) - jac = AllreduceSumLinear([lin.jac for lin in res], self._comm) - if res[0].want_metric and res[0].metric is not None: - met = AllreduceSumLinear( - [lin.metric for lin in res], self._comm, self._nwork - ) - return x.new(val, jac, met) - res = x.new(val, jac) - return res - return ift.utilities.allreduce_sum([op(x) for op in self._oplist], self._comm) + if not ift.is_linearization(x): + return ift.utilities.allreduce_sum( + [op(x) for op in self._oplist], self._comm + ) + opx = [op(x) for op in self._oplist] + val = ift.utilities.allreduce_sum([lin.val for lin in opx], self._comm) + jac = AllreduceSumLinear([lin.jac for lin in opx], self._comm) + if opx[0].metric is None: + return x.new(val, jac) + met = AllreduceSumLinear([lin.metric for lin in opx], self._comm, self._nwork) + return x.new(val, jac, met) class AllreduceSumLinear(ift.LinearOperator):