From af52cc3f0068994a05d1d57b5dec92d4f44d74a2 Mon Sep 17 00:00:00 2001 From: Gordian Edenhofer <gordian.edenhofer@gmail.com> Date: Sun, 2 Mar 2025 10:01:35 -0600 Subject: [PATCH] LikelihoodSum: Give users more flexibility to name keys --- src/re/likelihood.py | 57 ++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/src/re/likelihood.py b/src/re/likelihood.py index 96209114b..cb0478c5b 100644 --- a/src/re/likelihood.py +++ b/src/re/likelihood.py @@ -666,7 +666,7 @@ class LikelihoodSum(Likelihood): *likelihood_summands, domain=NoValue, init=NoValue, - _key_template="lh_{}", + _key_template="lh_{index}", ): for i, lh in enumerate(likelihood_summands): if not isinstance(lh, Likelihood): @@ -675,24 +675,24 @@ class LikelihoodSum(Likelihood): f" invalid type {type(lh)!r}" ) raise TypeError(te) - joined_tangents_shape = { - _key_template.format(i): lh._lsm_tan_shp - for i, lh in enumerate(likelihood_summands) - } - if any(isinstance(lh._lsm_tan_shp, Vector) for lh in likelihood_summands): + self.likelihood_summands = tuple(likelihood_summands) + self._key_template = _key_template + + joined_tangents_shape = {key: lh._lsm_tan_shp for key, lh in self._items()} + if any(isinstance(lh._lsm_tan_shp, Vector) for _, lh in self._items()): joined_tangents_shape = Vector(joined_tangents_shape) if domain is NoValue and all( - lh.domain is not NoValue for lh in likelihood_summands + lh.domain is not NoValue for _, lh in self._items() ): domain = reduce( operator.or_, ( lh.domain.tree if isinstance(lh.domain, Vector) else lh.domain - for lh in likelihood_summands + for _, lh in self._items() ), ) - isvec = any(isinstance(lh.domain, Vector) for lh in likelihood_summands) + isvec = any(isinstance(lh.domain, Vector) for _, lh in self._items()) domain = Vector(domain) if isvec else domain isswd = hasattr(domain, "shape") and hasattr(domain, "dtype") if not isswd and not has_arithmetics(domain): @@ -703,59 +703,54 @@ class LikelihoodSum(Likelihood): " in `Vector`s" ) raise ValueError(ve) - self.likelihood_summands = tuple(likelihood_summands) - self._key_template = _key_template super().__init__( domain=domain, init=init, lsm_tangents_shape=joined_tangents_shape ) + def _items(self): + for i, lh in enumerate(self.likelihood_summands): + # Allow the user to manipulate the keys to contain information about + # the likelihood + yield self._key_template.format(index=i, likelihood=lh), lh + def energy(self, primals, **kwargs): return reduce( - operator.add, - (lh.energy(primals, **kwargs) for lh in self.likelihood_summands), + operator.add, (lh.energy(primals, **kwargs) for _, lh in self._items()) ) def normalized_residual(self, primals, **kwargs): res = { - self._key_template.format(i): lh.normalized_residual(primals, **kwargs) - for i, lh in enumerate(self.likelihood_summands) + key: lh.normalized_residual(primals, **kwargs) for key, lh in self._items() } - isvec = any(isinstance(lh.domain, Vector) for lh in self.likelihood_summands) + isvec = any(isinstance(lh.domain, Vector) for _, lh in self._items()) return Vector(res) if isvec else res def metric(self, primals, tangents, **kwargs): return reduce( operator.add, - (lh.metric(primals, tangents, **kwargs) for lh in self.likelihood_summands), + (lh.metric(primals, tangents, **kwargs) for _, lh in self._items()), ) def transformation(self, primals, **kwargs): - res = { - self._key_template.format(i): lh.transformation(primals, **kwargs) - for i, lh in enumerate(self.likelihood_summands) - } - isvec = any(isinstance(lh.domain, Vector) for lh in self.likelihood_summands) + res = {key: lh.transformation(primals, **kwargs) for key, lh in self._items()} + isvec = any(isinstance(lh.domain, Vector) for _, lh in self._items()) return Vector(res) if isvec else res def left_sqrt_metric(self, primals, tangents, **kwargs): return reduce( operator.add, ( - lh.left_sqrt_metric( - primals, tangents[self._key_template.format(i)], **kwargs - ) - for i, lh in enumerate(self.likelihood_summands) + lh.left_sqrt_metric(primals, tangents[key], **kwargs) + for key, lh in self._items() ), ) def right_sqrt_metric(self, primals, tangents, **kwargs): res = { - self._key_template.format(i): lh.right_sqrt_metric( - primals, tangents, **kwargs - ) - for i, lh in enumerate(self.likelihood_summands) + key: lh.right_sqrt_metric(primals, tangents, **kwargs) + for key, lh in self._items() } - isvec = any(isinstance(lh.domain, Vector) for lh in self.likelihood_summands) + isvec = any(isinstance(lh.domain, Vector) for _, lh in self._items()) return Vector(res) if isvec else res def __add__(self, other): -- GitLab