diff --git a/src/re/likelihood.py b/src/re/likelihood.py
index 58085c24be8d984bb485303b0bdf6230c651691a..96b32cb5c5948e2d73a13518ec346ec750db8c65 100644
--- a/src/re/likelihood.py
+++ b/src/re/likelihood.py
@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 
+import operator
 from dataclasses import field
-from typing import Any, Callable, TypeVar
+from functools import reduce
+from typing import Any, Callable, Tuple, TypeVar
 
 import jax
 from jax.tree_util import (
@@ -647,46 +649,41 @@ class LikelihoodWithModel(Likelihood):
 
 
 class LikelihoodSum(Likelihood):
-    left_likelihood: Likelihood = field(metadata=dict(static=False))
-    right_likelihood: Likelihood = field(metadata=dict(static=False))
+    likelihood_summands: Tuple[Likelihood] = field(metadata=dict(static=False))
 
     def __init__(
         self,
-        left,
-        right,
-        /,
+        *likelihood_summands,
         domain=NoValue,
         init=NoValue,
-        _left_key="lh_left",
-        _right_key="lh_right",
+        _key_template="lh_{}",
     ):
-        if not (isinstance(left, Likelihood) and isinstance(right, Likelihood)):
-            te = (
-                "object which to add to this instance is of invalid type"
-                f" {type(right)!r}"
-            )
-            raise TypeError(te)
-        self._lkey, self._rkey = _left_key, _right_key
+        for i, lh in enumerate(likelihood_summands):
+            if not isinstance(lh, Likelihood):
+                te = (
+                    f"object at position {i} which to add to this instance is of"
+                    f" invalid type {type(lh)!r}"
+                )
+                raise TypeError(te)
         joined_tangents_shape = {
-            self._lkey: left._lsm_tan_shp,
-            self._rkey: right._lsm_tan_shp,
+            _key_template.format(i): lh._lsm_tan_shp
+            for i, lh in enumerate(likelihood_summands)
         }
-        if isinstance(left._lsm_tan_shp, Vector) or isinstance(
-            right._lsm_tan_shp, Vector
-        ):
+        if any(isinstance(lh._lsm_tan_shp, Vector) for lh in likelihood_summands):
             joined_tangents_shape = Vector(joined_tangents_shape)
 
-        if (
-            domain is NoValue
-            and left.domain is not NoValue
-            and right.domain is not NoValue
+        if domain is NoValue and all(
+            lh.domain is not NoValue for lh in likelihood_summands
         ):
-            lvec = isinstance(left.domain, Vector)
-            rvec = isinstance(right.domain, Vector)
-            ldomain = left.domain.tree if lvec else left.domain
-            rdomain = right.domain.tree if rvec else right.domain
-            domain = ldomain | rdomain
-            domain = Vector(domain) if lvec or rvec else domain
+            domain = reduce(
+                operator.or_,
+                (
+                    lh.domain.tree if isinstance(lh.domain, Vector) else lh.domain
+                    for lh in likelihood_summands
+                ),
+            )
+            isvec = any(isinstance(lh.domain, Vector) for lh in likelihood_summands)
+            domain = Vector(domain) if isvec else domain
             isswd = hasattr(domain, "shape") and hasattr(domain, "dtype")
             if not isswd and not has_arithmetics(domain):
                 ve = (
@@ -696,49 +693,60 @@ class LikelihoodSum(Likelihood):
                     " in `Vector`s"
                 )
                 raise ValueError(ve)
-        self.left_likelihood = left
-        self.right_likelihood = right
+        self.likelihood_summands = tuple(likelihood_summands)
+        self._key_template = _key_template
         super().__init__(
             domain=domain, init=init, lsm_tangents_shape=joined_tangents_shape
         )
 
     def energy(self, primals, **kwargs):
-        return self.left_likelihood.energy(
-            primals, **kwargs
-        ) + self.right_likelihood.energy(primals, **kwargs)
+        return reduce(
+            operator.add,
+            (lh.energy(primals, **kwargs) for lh in self.likelihood_summands),
+        )
 
     def normalized_residual(self, primals, **kwargs):
-        lres = self.left_likelihood.normalized_residual(primals, **kwargs)
-        rres = self.right_likelihood.normalized_residual(primals, **kwargs)
-        lvec, rvec = isinstance(lres, Vector), isinstance(rres, Vector)
-        res = {self._lkey: lres, self._rkey: rres}
-        res = Vector(res) if lvec or rvec else res
-        return res
+        res = {
+            self._key_template.format(i): lh.normalized_residual(primals, **kwargs)
+            for i, lh in enumerate(self.likelihood_summands)
+        }
+        isvec = any(isinstance(lh.domain, Vector) for lh in self.likelihood_summands)
+        return Vector(res) if isvec else res
 
     def metric(self, primals, tangents, **kwargs):
-        return self.left_likelihood.metric(
-            primals, tangents, **kwargs
-        ) + self.right_likelihood.metric(primals, tangents, **kwargs)
+        return reduce(
+            operator.add,
+            (lh.metric(primals, tangents, **kwargs) for lh in self.likelihood_summands),
+        )
 
     def transformation(self, primals, **kwargs):
-        lres = self.left_likelihood.transformation(primals, **kwargs)
-        rres = self.right_likelihood.transformation(primals, **kwargs)
-        lvec, rvec = isinstance(lres, Vector), isinstance(rres, Vector)
-        res = {self._lkey: lres, self._rkey: rres}
-        res = Vector(res) if lvec or rvec else res
-        return res
+        res = {
+            self._key_template.format(i): lh.transformation(primals, **kwargs)
+            for i, lh in enumerate(self.likelihood_summands)
+        }
+        isvec = any(isinstance(lh.domain, Vector) for lh in self.likelihood_summands)
+        return Vector(res) if isvec else res
 
     def left_sqrt_metric(self, primals, tangents, **kwargs):
-        return self.left_likelihood.left_sqrt_metric(
-            primals, tangents[self._lkey], **kwargs
-        ) + self.right_likelihood.left_sqrt_metric(
-            primals, tangents[self._rkey], **kwargs
+        return reduce(
+            operator.add,
+            (
+                lh.left_sqrt_metric(
+                    primals, tangents[self._key_template.format(i)], **kwargs
+                )
+                for i, lh in enumerate(self.likelihood_summands)
+            ),
         )
 
     def right_sqrt_metric(self, primals, tangents, **kwargs):
-        lres = self.left_likelihood.right_sqrt_metric(primals, tangents, **kwargs)
-        rres = self.right_likelihood.right_sqrt_metric(primals, tangents, **kwargs)
-        lvec, rvec = isinstance(lres, Vector), isinstance(rres, Vector)
-        res = {self._lkey: lres, self._rkey: rres}
-        res = Vector(res) if lvec or rvec else res
-        return res
+        res = {
+            self._key_template.format(i): lh.right_sqrt_metric(
+                primals, tangents, **kwargs
+            )
+            for i, lh in enumerate(self.likelihood_summands)
+        }
+        isvec = any(isinstance(lh.domain, Vector) for lh in self.likelihood_summands)
+        return Vector(res) if isvec else res
+
+    def __add__(self, other):
+        return LikelihoodSum(*self.likelihood_summands, other)
diff --git a/test/test_re/test_likelihood.py b/test/test_re/test_likelihood.py
index ed4288220c8c04c80cd14ba3c4159304d2942fe7..3810681082f2395b7bbb07997b15e50834adb665 100644
--- a/test/test_re/test_likelihood.py
+++ b/test/test_re/test_likelihood.py
@@ -19,6 +19,8 @@ from nifty8.re.likelihood import partial_insert_and_remove as jpartial
 
 pmp = pytest.mark.parametrize
 
+jax.config.update("jax_enable_x64", True)
+
 
 def _identity(x):
     return x
@@ -143,8 +145,8 @@ def test_nonvariable_likelihood_add(seed, likelihood, forward_a, forward_b):
     assert_allclose(jax.vmap(lh_orig)(p), jax.vmap(lh_ab)(p), equal_nan=False)
     rsm_orig = jax.vmap(lh_orig.right_sqrt_metric)(p, t)
     rsm_ab = jax.vmap(lh_ab.right_sqrt_metric)(p, t)
-    tree_assert_allclose(rsm_orig.tree[key_a], rsm_ab["lh_left"], equal_nan=False)
-    tree_assert_allclose(rsm_orig.tree[key_b], rsm_ab["lh_right"], equal_nan=False)
+    tree_assert_allclose(rsm_orig.tree[key_a], rsm_ab["lh_0"], equal_nan=False)
+    tree_assert_allclose(rsm_orig.tree[key_b], rsm_ab["lh_1"], equal_nan=False)
     tree_assert_allclose(
         jax.vmap(
             lambda p, t, q: lh_orig.left_sqrt_metric(p, lh_orig.right_sqrt_metric(t, q))
@@ -159,12 +161,12 @@ def test_nonvariable_likelihood_add(seed, likelihood, forward_a, forward_b):
     )
     nresi_orig = jax.vmap(lh_orig.normalized_residual)(p)
     nresi_ab = jax.vmap(lh_ab.normalized_residual)(p)
-    tree_assert_allclose(nresi_orig.tree[key_a], nresi_ab["lh_left"], equal_nan=False)
-    tree_assert_allclose(nresi_orig.tree[key_b], nresi_ab["lh_right"], equal_nan=False)
+    tree_assert_allclose(nresi_orig.tree[key_a], nresi_ab["lh_0"], equal_nan=False)
+    tree_assert_allclose(nresi_orig.tree[key_b], nresi_ab["lh_1"], equal_nan=False)
     trafo_orig = jax.vmap(lh_orig.transformation)(p)
     trafo_ab = jax.vmap(lh_ab.transformation)(p)
-    tree_assert_allclose(trafo_orig.tree[key_a], trafo_ab["lh_left"], equal_nan=False)
-    tree_assert_allclose(trafo_orig.tree[key_b], trafo_ab["lh_right"], equal_nan=False)
+    tree_assert_allclose(trafo_orig.tree[key_a], trafo_ab["lh_0"], equal_nan=False)
+    tree_assert_allclose(trafo_orig.tree[key_b], trafo_ab["lh_1"], equal_nan=False)
 
 
 @pmp("seed", (33, 42, 43))
@@ -224,10 +226,10 @@ def test_variable_likelihood_add(seed, likelihood, forward_a, forward_b):
     rsm_orig = jax.vmap(lh_orig.right_sqrt_metric)(p, t)
     rsm_ab = jax.vmap(lh_ab.right_sqrt_metric)(p, t)
     tree_assert_allclose(
-        tuple(r.tree[key_a] for r in rsm_orig), rsm_ab["lh_left"], equal_nan=False
+        tuple(r.tree[key_a] for r in rsm_orig), rsm_ab["lh_0"], equal_nan=False
     )
     tree_assert_allclose(
-        tuple(r.tree[key_b] for r in rsm_orig), rsm_ab["lh_right"], equal_nan=False
+        tuple(r.tree[key_b] for r in rsm_orig), rsm_ab["lh_1"], equal_nan=False
     )
     tree_assert_allclose(
         jax.vmap(
@@ -243,8 +245,8 @@ def test_variable_likelihood_add(seed, likelihood, forward_a, forward_b):
     )
     nresi_orig = lh_orig.normalized_residual(p)
     nresi_ab = lh_ab.normalized_residual(p)
-    tree_assert_allclose(nresi_orig[key_a], nresi_ab["lh_left"], equal_nan=False)
-    tree_assert_allclose(nresi_orig[key_b], nresi_ab["lh_right"], equal_nan=False)
+    tree_assert_allclose(nresi_orig[key_a], nresi_ab["lh_0"], equal_nan=False)
+    tree_assert_allclose(nresi_orig[key_b], nresi_ab["lh_1"], equal_nan=False)
 
     try:
         jax.vmap(lh_orig.transformation)(p)
@@ -253,10 +255,10 @@ def test_variable_likelihood_add(seed, likelihood, forward_a, forward_b):
     trafo_orig = jax.vmap(lh_orig.transformation)(p)
     trafo_ab = jax.vmap(lh_ab.transformation)(p)
     tree_assert_allclose(
-        tuple(t[key_a] for t in trafo_orig), trafo_ab["lh_left"], equal_nan=False
+        tuple(t[key_a] for t in trafo_orig), trafo_ab["lh_0"], equal_nan=False
     )
     tree_assert_allclose(
-        tuple(t[key_b] for t in trafo_orig), trafo_ab["lh_right"], equal_nan=False
+        tuple(t[key_b] for t in trafo_orig), trafo_ab["lh_1"], equal_nan=False
     )