From af52cc3f0068994a05d1d57b5dec92d4f44d74a2 Mon Sep 17 00:00:00 2001
From: Gordian Edenhofer <gordian.edenhofer@gmail.com>
Date: Sun, 2 Mar 2025 10:01:35 -0600
Subject: [PATCH] LikelihoodSum: Give users more flexibility to name keys

---
 src/re/likelihood.py | 57 ++++++++++++++++++++------------------------
 1 file changed, 26 insertions(+), 31 deletions(-)

diff --git a/src/re/likelihood.py b/src/re/likelihood.py
index 96209114b..cb0478c5b 100644
--- a/src/re/likelihood.py
+++ b/src/re/likelihood.py
@@ -666,7 +666,7 @@ class LikelihoodSum(Likelihood):
         *likelihood_summands,
         domain=NoValue,
         init=NoValue,
-        _key_template="lh_{}",
+        _key_template="lh_{index}",
     ):
         for i, lh in enumerate(likelihood_summands):
             if not isinstance(lh, Likelihood):
@@ -675,24 +675,24 @@ class LikelihoodSum(Likelihood):
                     f" invalid type {type(lh)!r}"
                 )
                 raise TypeError(te)
-        joined_tangents_shape = {
-            _key_template.format(i): lh._lsm_tan_shp
-            for i, lh in enumerate(likelihood_summands)
-        }
-        if any(isinstance(lh._lsm_tan_shp, Vector) for lh in likelihood_summands):
+        self.likelihood_summands = tuple(likelihood_summands)
+        self._key_template = _key_template
+
+        joined_tangents_shape = {key: lh._lsm_tan_shp for key, lh in self._items()}
+        if any(isinstance(lh._lsm_tan_shp, Vector) for _, lh in self._items()):
             joined_tangents_shape = Vector(joined_tangents_shape)
 
         if domain is NoValue and all(
-            lh.domain is not NoValue for lh in likelihood_summands
+            lh.domain is not NoValue for _, lh in self._items()
         ):
             domain = reduce(
                 operator.or_,
                 (
                     lh.domain.tree if isinstance(lh.domain, Vector) else lh.domain
-                    for lh in likelihood_summands
+                    for _, lh in self._items()
                 ),
             )
-            isvec = any(isinstance(lh.domain, Vector) for lh in likelihood_summands)
+            isvec = any(isinstance(lh.domain, Vector) for _, lh in self._items())
             domain = Vector(domain) if isvec else domain
             isswd = hasattr(domain, "shape") and hasattr(domain, "dtype")
             if not isswd and not has_arithmetics(domain):
@@ -703,59 +703,54 @@ class LikelihoodSum(Likelihood):
                     " in `Vector`s"
                 )
                 raise ValueError(ve)
-        self.likelihood_summands = tuple(likelihood_summands)
-        self._key_template = _key_template
         super().__init__(
             domain=domain, init=init, lsm_tangents_shape=joined_tangents_shape
         )
 
+    def _items(self):
+        for i, lh in enumerate(self.likelihood_summands):
+            # Allow the user to manipulate the keys to contain information about
+            # the likelihood
+            yield self._key_template.format(index=i, likelihood=lh), lh
+
     def energy(self, primals, **kwargs):
         return reduce(
-            operator.add,
-            (lh.energy(primals, **kwargs) for lh in self.likelihood_summands),
+            operator.add, (lh.energy(primals, **kwargs) for _, lh in self._items())
         )
 
     def normalized_residual(self, primals, **kwargs):
         res = {
-            self._key_template.format(i): lh.normalized_residual(primals, **kwargs)
-            for i, lh in enumerate(self.likelihood_summands)
+            key: lh.normalized_residual(primals, **kwargs) for key, lh in self._items()
         }
-        isvec = any(isinstance(lh.domain, Vector) for lh in self.likelihood_summands)
+        isvec = any(isinstance(lh.domain, Vector) for _, lh in self._items())
         return Vector(res) if isvec else res
 
     def metric(self, primals, tangents, **kwargs):
         return reduce(
             operator.add,
-            (lh.metric(primals, tangents, **kwargs) for lh in self.likelihood_summands),
+            (lh.metric(primals, tangents, **kwargs) for _, lh in self._items()),
         )
 
     def transformation(self, primals, **kwargs):
-        res = {
-            self._key_template.format(i): lh.transformation(primals, **kwargs)
-            for i, lh in enumerate(self.likelihood_summands)
-        }
-        isvec = any(isinstance(lh.domain, Vector) for lh in self.likelihood_summands)
+        res = {key: lh.transformation(primals, **kwargs) for key, lh in self._items()}
+        isvec = any(isinstance(lh.domain, Vector) for _, lh in self._items())
         return Vector(res) if isvec else res
 
     def left_sqrt_metric(self, primals, tangents, **kwargs):
         return reduce(
             operator.add,
             (
-                lh.left_sqrt_metric(
-                    primals, tangents[self._key_template.format(i)], **kwargs
-                )
-                for i, lh in enumerate(self.likelihood_summands)
+                lh.left_sqrt_metric(primals, tangents[key], **kwargs)
+                for key, lh in self._items()
             ),
         )
 
     def right_sqrt_metric(self, primals, tangents, **kwargs):
         res = {
-            self._key_template.format(i): lh.right_sqrt_metric(
-                primals, tangents, **kwargs
-            )
-            for i, lh in enumerate(self.likelihood_summands)
+            key: lh.right_sqrt_metric(primals, tangents, **kwargs)
+            for key, lh in self._items()
         }
-        isvec = any(isinstance(lh.domain, Vector) for lh in self.likelihood_summands)
+        isvec = any(isinstance(lh.domain, Vector) for _, lh in self._items())
         return Vector(res) if isvec else res
 
     def __add__(self, other):
-- 
GitLab