diff --git a/src/re/likelihood.py b/src/re/likelihood.py index 58085c24be8d984bb485303b0bdf6230c651691a..96b32cb5c5948e2d73a13518ec346ec750db8c65 100644 --- a/src/re/likelihood.py +++ b/src/re/likelihood.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause +import operator from dataclasses import field -from typing import Any, Callable, TypeVar +from functools import reduce +from typing import Any, Callable, Tuple, TypeVar import jax from jax.tree_util import ( @@ -647,46 +649,41 @@ class LikelihoodWithModel(Likelihood): class LikelihoodSum(Likelihood): - left_likelihood: Likelihood = field(metadata=dict(static=False)) - right_likelihood: Likelihood = field(metadata=dict(static=False)) + likelihood_summands: Tuple[Likelihood] = field(metadata=dict(static=False)) def __init__( self, - left, - right, - /, + *likelihood_summands, domain=NoValue, init=NoValue, - _left_key="lh_left", - _right_key="lh_right", + _key_template="lh_{}", ): - if not (isinstance(left, Likelihood) and isinstance(right, Likelihood)): - te = ( - "object which to add to this instance is of invalid type" - f" {type(right)!r}" - ) - raise TypeError(te) - self._lkey, self._rkey = _left_key, _right_key + for i, lh in enumerate(likelihood_summands): + if not isinstance(lh, Likelihood): + te = ( + f"object at position {i} which to add to this instance is of" + f" invalid type {type(lh)!r}" + ) + raise TypeError(te) joined_tangents_shape = { - self._lkey: left._lsm_tan_shp, - self._rkey: right._lsm_tan_shp, + _key_template.format(i): lh._lsm_tan_shp + for i, lh in enumerate(likelihood_summands) } - if isinstance(left._lsm_tan_shp, Vector) or isinstance( - right._lsm_tan_shp, Vector - ): + if any(isinstance(lh._lsm_tan_shp, Vector) for lh in likelihood_summands): joined_tangents_shape = Vector(joined_tangents_shape) - if ( - domain is NoValue - and left.domain is not NoValue - and right.domain is not NoValue + if domain is NoValue and all( + lh.domain is not NoValue for lh in likelihood_summands ): - lvec = isinstance(left.domain, Vector) - rvec = isinstance(right.domain, Vector) - ldomain = left.domain.tree if lvec else left.domain - rdomain = right.domain.tree if rvec else right.domain - domain = ldomain | rdomain - domain = Vector(domain) if lvec or rvec else domain + domain = reduce( + operator.or_, + ( + lh.domain.tree if isinstance(lh.domain, Vector) else lh.domain + for lh in likelihood_summands + ), + ) + isvec = any(isinstance(lh.domain, Vector) for lh in likelihood_summands) + domain = Vector(domain) if isvec else domain isswd = hasattr(domain, "shape") and hasattr(domain, "dtype") if not isswd and not has_arithmetics(domain): ve = ( @@ -696,49 +693,60 @@ class LikelihoodSum(Likelihood): " in `Vector`s" ) raise ValueError(ve) - self.left_likelihood = left - self.right_likelihood = right + self.likelihood_summands = tuple(likelihood_summands) + self._key_template = _key_template super().__init__( domain=domain, init=init, lsm_tangents_shape=joined_tangents_shape ) def energy(self, primals, **kwargs): - return self.left_likelihood.energy( - primals, **kwargs - ) + self.right_likelihood.energy(primals, **kwargs) + return reduce( + operator.add, + (lh.energy(primals, **kwargs) for lh in self.likelihood_summands), + ) def normalized_residual(self, primals, **kwargs): - lres = self.left_likelihood.normalized_residual(primals, **kwargs) - rres = self.right_likelihood.normalized_residual(primals, **kwargs) - lvec, rvec = isinstance(lres, Vector), isinstance(rres, Vector) - res = {self._lkey: lres, self._rkey: rres} - res = Vector(res) if lvec or rvec else res - return res + res = { + self._key_template.format(i): lh.normalized_residual(primals, **kwargs) + for i, lh in enumerate(self.likelihood_summands) + } + isvec = any(isinstance(lh.domain, Vector) for lh in self.likelihood_summands) + return Vector(res) if isvec else res def metric(self, primals, tangents, **kwargs): - return self.left_likelihood.metric( - primals, tangents, **kwargs - ) + self.right_likelihood.metric(primals, tangents, **kwargs) + return reduce( + operator.add, + (lh.metric(primals, tangents, **kwargs) for lh in self.likelihood_summands), + ) def transformation(self, primals, **kwargs): - lres = self.left_likelihood.transformation(primals, **kwargs) - rres = self.right_likelihood.transformation(primals, **kwargs) - lvec, rvec = isinstance(lres, Vector), isinstance(rres, Vector) - res = {self._lkey: lres, self._rkey: rres} - res = Vector(res) if lvec or rvec else res - return res + res = { + self._key_template.format(i): lh.transformation(primals, **kwargs) + for i, lh in enumerate(self.likelihood_summands) + } + isvec = any(isinstance(lh.domain, Vector) for lh in self.likelihood_summands) + return Vector(res) if isvec else res def left_sqrt_metric(self, primals, tangents, **kwargs): - return self.left_likelihood.left_sqrt_metric( - primals, tangents[self._lkey], **kwargs - ) + self.right_likelihood.left_sqrt_metric( - primals, tangents[self._rkey], **kwargs + return reduce( + operator.add, + ( + lh.left_sqrt_metric( + primals, tangents[self._key_template.format(i)], **kwargs + ) + for i, lh in enumerate(self.likelihood_summands) + ), ) def right_sqrt_metric(self, primals, tangents, **kwargs): - lres = self.left_likelihood.right_sqrt_metric(primals, tangents, **kwargs) - rres = self.right_likelihood.right_sqrt_metric(primals, tangents, **kwargs) - lvec, rvec = isinstance(lres, Vector), isinstance(rres, Vector) - res = {self._lkey: lres, self._rkey: rres} - res = Vector(res) if lvec or rvec else res - return res + res = { + self._key_template.format(i): lh.right_sqrt_metric( + primals, tangents, **kwargs + ) + for i, lh in enumerate(self.likelihood_summands) + } + isvec = any(isinstance(lh.domain, Vector) for lh in self.likelihood_summands) + return Vector(res) if isvec else res + + def __add__(self, other): + return LikelihoodSum(*self.likelihood_summands, other) diff --git a/test/test_re/test_likelihood.py b/test/test_re/test_likelihood.py index ed4288220c8c04c80cd14ba3c4159304d2942fe7..3810681082f2395b7bbb07997b15e50834adb665 100644 --- a/test/test_re/test_likelihood.py +++ b/test/test_re/test_likelihood.py @@ -19,6 +19,8 @@ from nifty8.re.likelihood import partial_insert_and_remove as jpartial pmp = pytest.mark.parametrize +jax.config.update("jax_enable_x64", True) + def _identity(x): return x @@ -143,8 +145,8 @@ def test_nonvariable_likelihood_add(seed, likelihood, forward_a, forward_b): assert_allclose(jax.vmap(lh_orig)(p), jax.vmap(lh_ab)(p), equal_nan=False) rsm_orig = jax.vmap(lh_orig.right_sqrt_metric)(p, t) rsm_ab = jax.vmap(lh_ab.right_sqrt_metric)(p, t) - tree_assert_allclose(rsm_orig.tree[key_a], rsm_ab["lh_left"], equal_nan=False) - tree_assert_allclose(rsm_orig.tree[key_b], rsm_ab["lh_right"], equal_nan=False) + tree_assert_allclose(rsm_orig.tree[key_a], rsm_ab["lh_0"], equal_nan=False) + tree_assert_allclose(rsm_orig.tree[key_b], rsm_ab["lh_1"], equal_nan=False) tree_assert_allclose( jax.vmap( lambda p, t, q: lh_orig.left_sqrt_metric(p, lh_orig.right_sqrt_metric(t, q)) @@ -159,12 +161,12 @@ def test_nonvariable_likelihood_add(seed, likelihood, forward_a, forward_b): ) nresi_orig = jax.vmap(lh_orig.normalized_residual)(p) nresi_ab = jax.vmap(lh_ab.normalized_residual)(p) - tree_assert_allclose(nresi_orig.tree[key_a], nresi_ab["lh_left"], equal_nan=False) - tree_assert_allclose(nresi_orig.tree[key_b], nresi_ab["lh_right"], equal_nan=False) + tree_assert_allclose(nresi_orig.tree[key_a], nresi_ab["lh_0"], equal_nan=False) + tree_assert_allclose(nresi_orig.tree[key_b], nresi_ab["lh_1"], equal_nan=False) trafo_orig = jax.vmap(lh_orig.transformation)(p) trafo_ab = jax.vmap(lh_ab.transformation)(p) - tree_assert_allclose(trafo_orig.tree[key_a], trafo_ab["lh_left"], equal_nan=False) - tree_assert_allclose(trafo_orig.tree[key_b], trafo_ab["lh_right"], equal_nan=False) + tree_assert_allclose(trafo_orig.tree[key_a], trafo_ab["lh_0"], equal_nan=False) + tree_assert_allclose(trafo_orig.tree[key_b], trafo_ab["lh_1"], equal_nan=False) @pmp("seed", (33, 42, 43)) @@ -224,10 +226,10 @@ def test_variable_likelihood_add(seed, likelihood, forward_a, forward_b): rsm_orig = jax.vmap(lh_orig.right_sqrt_metric)(p, t) rsm_ab = jax.vmap(lh_ab.right_sqrt_metric)(p, t) tree_assert_allclose( - tuple(r.tree[key_a] for r in rsm_orig), rsm_ab["lh_left"], equal_nan=False + tuple(r.tree[key_a] for r in rsm_orig), rsm_ab["lh_0"], equal_nan=False ) tree_assert_allclose( - tuple(r.tree[key_b] for r in rsm_orig), rsm_ab["lh_right"], equal_nan=False + tuple(r.tree[key_b] for r in rsm_orig), rsm_ab["lh_1"], equal_nan=False ) tree_assert_allclose( jax.vmap( @@ -243,8 +245,8 @@ def test_variable_likelihood_add(seed, likelihood, forward_a, forward_b): ) nresi_orig = lh_orig.normalized_residual(p) nresi_ab = lh_ab.normalized_residual(p) - tree_assert_allclose(nresi_orig[key_a], nresi_ab["lh_left"], equal_nan=False) - tree_assert_allclose(nresi_orig[key_b], nresi_ab["lh_right"], equal_nan=False) + tree_assert_allclose(nresi_orig[key_a], nresi_ab["lh_0"], equal_nan=False) + tree_assert_allclose(nresi_orig[key_b], nresi_ab["lh_1"], equal_nan=False) try: jax.vmap(lh_orig.transformation)(p) @@ -253,10 +255,10 @@ def test_variable_likelihood_add(seed, likelihood, forward_a, forward_b): trafo_orig = jax.vmap(lh_orig.transformation)(p) trafo_ab = jax.vmap(lh_ab.transformation)(p) tree_assert_allclose( - tuple(t[key_a] for t in trafo_orig), trafo_ab["lh_left"], equal_nan=False + tuple(t[key_a] for t in trafo_orig), trafo_ab["lh_0"], equal_nan=False ) tree_assert_allclose( - tuple(t[key_b] for t in trafo_orig), trafo_ab["lh_right"], equal_nan=False + tuple(t[key_b] for t in trafo_orig), trafo_ab["lh_1"], equal_nan=False )