diff --git a/demo/mpi_demo.py b/demo/mpi_demo.py
index c00fe3a2b4db813c2de4fd9471ac0ce6a90d3946..301e11f40e801ecbfcafdb1b923468c7ba885702 100644
--- a/demo/mpi_demo.py
+++ b/demo/mpi_demo.py
@@ -54,18 +54,17 @@ class AllreduceSum(ift.Operator):
 
     def apply(self, x):
         self._check_input(x)
-        if ift.is_linearization(x):
-            res = [op(x) for op in self._oplist]
-            val = ift.utilities.allreduce_sum([lin.val for lin in res], self._comm)
-            jac = AllreduceSumLinear([lin.jac for lin in res], self._comm)
-            if res[0].want_metric and res[0].metric is not None:
-                met = AllreduceSumLinear(
-                    [lin.metric for lin in res], self._comm, self._nwork
-                )
-                return x.new(val, jac, met)
-            res = x.new(val, jac)
-            return res
-        return ift.utilities.allreduce_sum([op(x) for op in self._oplist], self._comm)
+        if not ift.is_linearization(x):
+            return ift.utilities.allreduce_sum(
+                [op(x) for op in self._oplist], self._comm
+            )
+        opx = [op(x) for op in self._oplist]
+        val = ift.utilities.allreduce_sum([lin.val for lin in opx], self._comm)
+        jac = AllreduceSumLinear([lin.jac for lin in opx], self._comm)
+        if opx[0].metric is None:
+            return x.new(val, jac)
+        met = AllreduceSumLinear([lin.metric for lin in opx], self._comm, self._nwork)
+        return x.new(val, jac, met)
 
 
 class AllreduceSumLinear(ift.LinearOperator):