Sum of likelihoods on operators with named domains fails

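When summing two likelihoods that are applied to operators with named (dictionary-valued) domains, minimization fails because the metric of the summed likelihood does not know how to join multiple named inputs. As the traceback below shows, `joined_left_sqrt_metric` combines the two partial results with `+`, which is not defined for plain Python dicts. Minimal example reproducing the failure: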

# Minimal reproducer: summing two Gaussian likelihoods whose domain is a
# dict-valued ("named") domain fails inside `jft.optimize_kl`. As reported,
# the failure is a TypeError in `joined_left_sqrt_metric`, which combines the
# two partial results with `+` -- an operation plain Python dicts lack.
from jax import random
import nifty8.re as jft
import numpy as np

import jax

# Pin JAX to the CPU backend.
jax.config.update("jax_platform_name", "cpu")


seed = 42
key = random.PRNGKey(seed)
# Seeded NumPy generator so the synthetic data is identical on every run;
# the global `np.random` state was never seeded.
rng = np.random.default_rng(seed)
shape = (10, 10)

# Correlated-field prior: zero-mode and fluctuation settings.
cf_zm = {"offset_mean": 0., "offset_std": (1e-3, 1e-4)}
cf_fl = {
    "fluctuations": (1e-1, 5e-3),
    "loglogavgslope": (-3., 1e-2),
    "flexibility": (1e+0, 5e-1),
    "asperity": (5e-1, 5e-2),
}
cfm = jft.CorrelatedFieldMaker("jcf_")
cfm.set_amplitude_total_offset(**cf_zm)
cfm.add_fluctuations(
    shape,
    distances=1. / shape[0],
    **cf_fl,
    prefix="",
    non_parametric_kind="power",
)

jcf = cfm.finalize()

# Synthetic data: the field evaluated at one random latent position plus
# white Gaussian noise of standard deviation `noise_level`.
key, subkey = random.split(key)
pos = jft.random_like(subkey, jcf.domain)
noise_level = 0.2
datar = jcf(pos) + rng.normal(0, 1, shape) * noise_level


# Wrap the field in a model with a dict-valued target {dom_key: field};
# this named domain is what exposes the failure.
dom_key = 'a'
m = jft.Model(
    lambda x: {dom_key: jcf(x)},
    domain=jcf.domain)


# R extracts the named entry, so both likelihoods below are defined on the
# dict-valued domain `m.target`.
R = jft.Model(lambda x: x[dom_key], domain=m.target)
like1 = jft.Gaussian(datar, lambda x: 1/noise_level**2 * x) @ R
like2 = jft.Gaussian(datar, lambda x: 2/noise_level**2 * x) @ R

# Sum of the two likelihoods, pulled back to the latent parameters of jcf.
ll = (like1 + like2) @ m

# Use fresh subkeys: a JAX key must not be reused, otherwise the starting
# position and the KL sampling would share the same random stream.
key, subkey = random.split(key)
pos = jft.random_like(subkey, jcf.domain)
key, subkey = random.split(key)

# Deliberately short run; the reported failure happens during minimization.
n_iterations = 2
n_samples = 2
delta = 1e-3
absdelta = 1e-4

samples, _ = jft.optimize_kl(
    ll,
    jft.Vector(pos),
    n_total_iterations=n_iterations,
    n_samples=n_samples,
    # Source for the stochasticity for sampling
    key=subkey,
    draw_linear_kwargs=dict(cg_name="SL",
                            cg_kwargs=dict(absdelta=absdelta / 10., maxiter=100)),
    nonlinearly_update_kwargs=dict(
        minimize_kwargs=dict(
            name="SN",
            xtol=delta,
            cg_kwargs=dict(name=None),
            maxiter=5,
        )
    ),
    kl_kwargs=dict(
        minimize_kwargs=dict(
            name="M", absdelta=absdelta, cg_kwargs=dict(name="MCG"), maxiter=35
        )
    ),
    sample_mode="nonlinear_resample",
    resume=False)

Running this script fails with the following error:

~/pro/python/nifty/nifty8/re/likelihood.py in joined_left_sqrt_metric(p, t, **pkw)
    641         def joined_left_sqrt_metric(p, t, **pkw):
    642             return (
--> 643                 self.left_sqrt_metric(p, t[lkey], **pkw) +
    644                 other.left_sqrt_metric(p, t[rkey], **pkw)
    645             )

TypeError: unsupported operand type(s) for +: 'dict' and 'dict'
Edited Jan 05, 2024 by Matteo Guardiani
Assignee Loading
Time tracking Loading