From 1f9be3efbee2902d71c780bf7b48fae5c92c78ee Mon Sep 17 00:00:00 2001
From: Jakob Roth <roth@mpa-garching.mpg.de>
Date: Tue, 2 Jul 2024 13:46:26 +0200
Subject: [PATCH 01/88] fast-resolve: initial version

---
 demo/cygnusa_2ghz_fast_resolve.cfg |  34 ++++
 demo/demo_fast_resolve.py          | 187 +++++++++++++++++++
 resolve/re/__init__.py             |   4 +-
 resolve/re/optimize.py             | 290 +++++++++++++++++++++++++++++
 resolve/re/radio_response.py       | 151 +++++++++++++++
 5 files changed, 665 insertions(+), 1 deletion(-)
 create mode 100644 demo/cygnusa_2ghz_fast_resolve.cfg
 create mode 100644 demo/demo_fast_resolve.py
 create mode 100644 resolve/re/optimize.py
 create mode 100644 resolve/re/radio_response.py

diff --git a/demo/cygnusa_2ghz_fast_resolve.cfg b/demo/cygnusa_2ghz_fast_resolve.cfg
new file mode 100644
index 00000000..9e4f0f24
--- /dev/null
+++ b/demo/cygnusa_2ghz_fast_resolve.cfg
@@ -0,0 +1,34 @@
+[setup]
+data=CYG-ALL-2052-2MHZ_RESOLVE_float64.npz
+psf pixels=300
+cache_noise_kernel=noise_kernel_cygnusa_2ghz_sm
+cache_response_kernel=response_kernel_cygnusa_2ghz_sm
+noise_scaling=True
+varcov=False
+
+
+
+[sky]
+freq mode = single
+polarization=I
+space npix x = 1024
+space npix y = 512
+space fov x = 0.05deg
+space fov y = 0.025deg
+
+stokesI diffuse space i0 zero mode offset = 18
+stokesI diffuse space i0 zero mode mean = 1
+stokesI diffuse space i0 zero mode stddev = 0.1
+stokesI diffuse space i0 fluctuations mean = 5
+stokesI diffuse space i0 fluctuations stddev = 1
+stokesI diffuse space i0 loglogavgslope mean = -2.0
+stokesI diffuse space i0 loglogavgslope stddev = 0.2
+stokesI diffuse space i0 flexibility mean =   1.2
+stokesI diffuse space i0 flexibility stddev = 0.4
+stokesI diffuse space i0 asperity mean =  0.2
+stokesI diffuse space i0 asperity stddev = 0.2
+
+point sources mode = single
+point sources locations = 0deg$0deg,0.35as$-0.22as
+point sources alpha = 0.5
+point sources q = 0.2
diff --git a/demo/demo_fast_resolve.py b/demo/demo_fast_resolve.py
new file mode 100644
index 00000000..c24d7ec7
--- /dev/null
+++ b/demo/demo_fast_resolve.py
@@ -0,0 +1,187 @@
+# %%
+import nifty8 as ift
+import nifty8.re as jft
+import resolve as rve
+import resolve.re as jrve
+import numpy as np
+import configparser
+import sys
+import matplotlib.pyplot as plt
+from matplotlib.colors import LogNorm
+import jax
+import pickle
+
+from jax import random
+import jax.numpy as jnp
+
+
+jax.config.update("jax_enable_x64", True)
+
+seed = 42
+
+key = random.PRNGKey(seed)
+conf_name = "cygnusa_2ghz.cfg"
+out_dir = "demo_fast-resolve"
+
+
+cfg = configparser.ConfigParser()
+cfg.read(conf_name)
+
+data = cfg["setup"]["data"]
+cnk = cfg["setup"]["cache_noise_kernel"]
+crk = cfg["setup"]["cache_response_kernel"]
+noise_scal = cfg["setup"].getboolean("noise_scaling")
+varcov = cfg["setup"].getboolean("varcov")
+if noise_scal and varcov:
+    raise ValueError()
+obs = rve.Observation.load(data)
+obs = obs.restrict_to_stokesi()
+N_inv = ift.makeOp(obs.weight)
+
+sky_model, model_dict = jrve.sky_model(cfg["sky"])
+R, R_l, RNR, RNR_l = jrve.build_exact_r(obs, cfg["sky"], cfg["setup"])
+# %%
+my_sky_model_func = lambda x: sky_model(x)[0, 0, 0, :, :]
+my_sky_model = jft.Model(my_sky_model_func, domain=sky_model.domain)
+invgamma_op = ift.InverseGammaOperator(R.domain, mode=1.0, mean=1.1)
+noise_scaling = lambda x: 1 / (invgamma_op.jax_expr(x["noise"]))
+noise_scaling = jft.Model(noise_scaling, domain={"noise": np.empty(R.domain.shape)})
+
+if noise_scal:
+    RNR_approx, N_inv_approx = jrve.build_approximations(
+        RNR,
+        RNR_l,
+        cache_noise_kernel=cnk,
+        cache_response_kernel=crk,
+        noise_scaling=noise_scaling,
+    )
+else:
+    RNR_approx, N_inv_approx = jrve.build_approximations(
+        RNR, RNR_l, cache_noise_kernel=cnk, cache_response_kernel=crk, varcov=varcov
+    )
+
+
+key, subkey = random.split(key)
+if noise_scaling or varcov:
+    init_pos = jft.random_like(subkey, {**my_sky_model.domain, **noise_scaling.domain})
+else:
+    init_pos = jft.random_like(subkey, my_sky_model.domain)
+
+
+def callback(pos, samp_at_pos, iter, n_major):
+    post_mean = jft.mean(tuple(my_sky_model(s) for s in samp_at_pos))
+    plt.imshow(post_mean.T, origin="lower", norm=LogNorm())
+    plt.colorbar()
+    plt.savefig(f"{out_dir}/major_{n_major}/sky_mean_{iter}.png", dpi=300)
+    plt.close()
+
+
+# inference
+d_new = R.adjoint(N_inv(obs.vis))
+
+d_new = jnp.array(d_new.val)
+init_pos = 1e-2 * jft.Vector(init_pos.copy())
+
+
+absdelta = 1e-10
+
+
+def get_draw_linear_kwargs(i):
+    kwargs = dict(
+        cg_name="cg",
+        cg_kwargs=dict(
+            absdelta=absdelta / 10.0,
+            maxiter=nstep_sampling(i),
+            miniter=nstep_sampling(i),
+        ),
+    )
+    return kwargs
+
+
+def get_nonlinearly_update_kwargs(i):
+    kwargs = dict(
+        minimize_kwargs=dict(
+            name=None,
+            xtol=1e-4,
+            cg_kwargs=dict(name=None),
+            maxiter=nstep_nl_sampling(i),
+            miniter=nstep_nl_sampling(i),
+        )
+    )
+    return kwargs
+
+
+def get_kl_kwargs(i):
+    if i < 6:
+        min_cg = 6
+    else:
+        min_cg = 20
+    kwargs = dict(
+        minimize_kwargs=dict(
+            name="newton",
+            absdelta=absdelta,
+            cg_kwargs=dict(name="ncg", miniter=min_cg),
+            maxiter=nstep_newton(i),
+            miniter=nstep_newton(i),
+            energy_reduction_factor=1e-3,
+        )
+    )
+    return kwargs
+
+
+def method(i):
+    return "linear_resample"
+
+
+def nstep_sampling(ii):
+    if ii < 4:
+        return 100
+    elif ii < 8:
+        return 500
+    elif ii < 12:
+        return 1000
+
+
+def nstep_nl_sampling(ii):
+    return 0 if ii < 20 else 10
+
+
+def nstep_newton(ii):
+    if ii < 4:
+        return 5
+    elif ii < 8:
+        return 10
+    elif ii < 12:
+        return 20
+    elif ii < 14:
+        return 30
+
+draw_linear_kwargs = get_draw_linear_kwargs
+nonlinearly_update_kwargs = get_nonlinearly_update_kwargs
+kl_kwargs = get_kl_kwargs
+
+optimize_kwargs = dict(
+    R=RNR,
+    R_approx=RNR_approx,
+    sky=my_sky_model,
+    N_inv_sqrt=N_inv_approx,
+    data=d_new,
+    pos=init_pos,
+    draw_linear_kwargs=draw_linear_kwargs,
+    nonlinearly_update_kwargs=nonlinearly_update_kwargs,
+    kl_kwargs=kl_kwargs,
+    n_samples=2,
+    n_iter=2,
+    n_major_step=20,
+    key=key,
+    callback=callback,
+    out_dir=out_dir,
+    resume=True,
+    init_samples=None,
+    noise_scaling=noise_scal,
+    varcov=varcov,
+    varcov_op=noise_scaling,
+    method=method,
+)
+pos, samples, key = jrve.optimize(**optimize_kwargs)
+# %%
diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py
index ce2ba46e..f095f226 100644
--- a/resolve/re/__init__.py
+++ b/resolve/re/__init__.py
@@ -1,3 +1,5 @@
 
 from .sky_model import sky_model_diffuse, sky_model_points, sky_model
-from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc
\ No newline at end of file
+from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc
+from .radio_response import build_exact_r, build_approximations
+from .optimize import optimize
\ No newline at end of file
diff --git a/resolve/re/optimize.py b/resolve/re/optimize.py
new file mode 100644
index 00000000..8ea815f1
--- /dev/null
+++ b/resolve/re/optimize.py
@@ -0,0 +1,290 @@
+import jax
+import numpy as np
+import pickle
+import jax.numpy as jnp
+import nifty8.re as jft
+import nifty8 as ift
+
+from functools import partial
+from os import makedirs
+from os.path import isfile
+
+
+from nifty8.re.optimize_kl import _kl_vg
+from nifty8.re.optimize_kl import _kl_met
+from nifty8.re.optimize_kl import draw_linear_residual
+from nifty8.re.optimize_kl import nonlinearly_update_residual
+from nifty8.re.optimize_kl import get_status_message
+from nifty8.re.evi import _nonlinearly_update_residual_functions
+
+
+def optimize(
+    R,
+    R_approx,
+    sky,
+    N_inv_sqrt,
+    data,
+    pos,
+    draw_linear_kwargs,
+    nonlinearly_update_kwargs,
+    kl_kwargs,
+    n_samples,
+    n_iter,
+    n_major_step,
+    key,
+    callback=None,
+    out_dir=None,
+    resume=False,
+    init_samples=None,
+    noise_scaling=False,
+    varcov=False,
+    varcov_op=None,
+    method="linear_sample",
+    save_all=True,
+):
+    if varcov:
+        jax_0_data = jnp.broadcast_to(0.0 + 0j, data.shape)
+    else:
+        jax_0_data = jnp.broadcast_to(0.0, data.shape)
+    if not out_dir == None:
+        makedirs(out_dir, exist_ok=True)
+    lfile = f"{out_dir}/last_started_major"
+    last_started_major = 0
+    last_finished_index = -1
+    if resume and isfile(lfile):
+        with open(lfile) as f:
+            last_started_major = int(f.read())
+        lfile = f"{out_dir}/major_{last_started_major}/last_finished_iteration"
+        if resume and isfile(lfile):
+            with open(lfile) as f:
+                last_finished_index = int(f.read())
+
+    def residual_signal_response(x, old_reconstruction):
+        return R_approx(sky(x) - old_reconstruction)
+
+    def noise_weighted_residual(x, old_reconstruction, residual_data):
+        if noise_scaling:
+            return N_inv_sqrt(
+                {
+                    "sky": residual_signal_response(x, old_reconstruction)
+                    - residual_data,
+                    "noise": x["noise"],
+                }
+            )
+        else:
+            return N_inv_sqrt(
+                {"sky": residual_signal_response(x, old_reconstruction) - residual_data}
+            )
+
+    def noise_weighted_residual_varcov(x, old_reconstruction, residual_data):
+        return [
+            N_inv_sqrt(
+                {"sky": residual_signal_response(x, old_reconstruction) - residual_data}
+            ),
+            varcov_op(x),
+        ]
+
+    kl_map = jax.vmap
+    kl_reduce = partial(jax.tree_map, partial(jnp.mean, axis=0))
+
+    def get_lh(old_reconstruction, residual_data):
+        if varcov:
+            lh = jft.VariableCovarianceGaussian(jax_0_data, iscomplex=True).amend(
+                partial(
+                    noise_weighted_residual_varcov,
+                    old_reconstruction=old_reconstruction,
+                    residual_data=residual_data,
+                )
+            )
+        else:
+            lh = jft.Gaussian(jax_0_data).amend(
+                partial(
+                    noise_weighted_residual,
+                    old_reconstruction=old_reconstruction,
+                    residual_data=residual_data,
+                )
+            )
+        return lh
+
+    @jax.jit
+    def my_kl_vg(primals, primals_samples, *, old_reconstruction, residual_data):
+        lh = get_lh(old_reconstruction, residual_data)
+        return _kl_vg(lh, primals, primals_samples, map=kl_map, reduce=kl_reduce)
+
+    @jax.jit
+    def my_kl_metric(
+        primals, tangents, primals_samples, *, old_reconstruction, residual_data
+    ):
+        lh = get_lh(old_reconstruction, residual_data)
+        return _kl_met(
+            lh, primals, tangents, primals_samples, map=kl_map, reduce=kl_reduce
+        )
+
+    @jax.jit
+    def my_draw_linear_residual(
+        pos, key, *, old_reconstruction, residual_data, **kwargs
+    ):
+        lh = get_lh(old_reconstruction, residual_data)
+        return draw_linear_residual(lh, pos, key, **kwargs)
+
+    def my_nonlinearly_update_residual(
+        pos,
+        residual_sample,
+        metric_sample_key,
+        metric_sample_sign,
+        *,
+        old_reconstruction,
+        residual_data,
+        **kwargs,
+    ):
+        lh = get_lh(old_reconstruction, residual_data)
+        _nonlin_funcs = _nonlinearly_update_residual_functions(
+            likelihood=lh,
+            jit=jax.jit,
+        )
+        return nonlinearly_update_residual(
+            lh,
+            pos,
+            residual_sample,
+            metric_sample_key,
+            metric_sample_sign,
+            _nonlinear_update_funcs=_nonlin_funcs,
+            **kwargs,
+        )
+
+    def my_stat_mes(samples, state, *, old_reconstruction, residual_data, **kwargs):
+        lh = get_lh(old_reconstruction, residual_data)
+        return get_status_message(
+            samples, state, lh.normalized_residual, name="optVI test", **kwargs
+        )
+
+    if last_finished_index > -1:
+        if save_all:
+            opt_vi_state = pickle.load(
+                open(
+                    f"{out_dir}/major_{last_started_major}/optVIstate_it_{last_finished_index}.p",
+                    "rb",
+                )
+            )
+            samples = pickle.load(
+                open(
+                    f"{out_dir}/major_{last_started_major}/samples_{last_finished_index}.p",
+                    "rb",
+                )
+            )
+        else:
+            opt_vi_state = pickle.load(
+                open(
+                    f"{out_dir}/last_optVIstate.p",
+                    "rb",
+                )
+            )
+            samples = pickle.load(
+                open(
+                    f"{out_dir}/last_samples.p",
+                    "rb",
+                )
+            )
+
+        sub_val = jft.mean(tuple(sky(s) for s in samples))
+        post_mean = np.array(sub_val)
+        post_mean = ift.makeField(R.domain, post_mean)
+        residual_data = data - R(post_mean).val
+    else:
+        opt_vi_state = None
+        samples = None
+        if not init_samples == None:
+            sub_val = jft.mean(tuple(sky(s) for s in init_samples))
+            post_mean = np.array(sub_val)
+            post_mean = ift.makeField(R.domain, post_mean)
+            residual_data = data - R(post_mean).val
+        else:
+            residual_data = data
+            sub_val = jnp.zeros(data.shape, dtype=data.dtype)
+
+    # init optVI
+    opt_vi = jft.OptimizeVI(
+        None,
+        n_iter,
+        _kl_value_and_grad=my_kl_vg,
+        _kl_metric=my_kl_metric,
+        _draw_linear_residual=my_draw_linear_residual,
+        _nonlinearly_update_residual=my_nonlinearly_update_residual,
+        _get_status_message=my_stat_mes,
+        residual_map="lmap",
+    )
+
+    update_kwargs = dict(
+        old_reconstruction=sub_val,
+        residual_data=residual_data,
+    )
+
+    # call optVI.init
+    if opt_vi_state is None:
+        print("init opt vi")
+        opt_vi_state = opt_vi.init_state(
+            key,
+            nit=0,
+            n_samples=n_samples,
+            draw_linear_kwargs=draw_linear_kwargs,
+            nonlinearly_update_kwargs=nonlinearly_update_kwargs,
+            kl_kwargs=kl_kwargs,
+            sample_mode=method,
+        )
+    if samples is None:
+        samples = jft.Samples(pos=pos, samples=None, keys=None)
+
+    for n_major in range(last_started_major, n_major_step):
+        if not out_dir == None:
+            makedirs(out_dir + f"/major_{n_major}", exist_ok=True)
+        minor_start = opt_vi_state.nit - n_iter * n_major
+        print("minor_start: ", minor_start)
+        print("n_iter: ", n_iter)
+        print("n_major:", n_major)
+        print("opt_vi_state.nit: ", opt_vi_state.nit)
+        for i in range(minor_start, n_iter):
+            # do opt_vi update
+            samples, opt_vi_state = opt_vi.update(
+                samples, opt_vi_state, **update_kwargs
+            )
+            msg = opt_vi.get_status_message(
+                samples,
+                opt_vi_state,
+                old_reconstruction=sub_val,
+                residual_data=residual_data,
+            )
+            jft.logger.info(msg)
+
+            if not callback == None:
+                callback(opt_vi_state.minimization_state.x, samples, i, n_major)
+            if not out_dir == None:
+                if save_all:
+                    pickle.dump(
+                        opt_vi_state,
+                        open(f"{out_dir}/major_{n_major}/optVIstate_it_{i}.p", "wb"),
+                    )
+                    pickle.dump(
+                        samples, open(f"{out_dir}/major_{n_major}/samples_{i}.p", "wb")
+                    )
+                else:
+                    pickle.dump(
+                        opt_vi_state,
+                        open(f"{out_dir}/last_optVIstate.p", "wb"),
+                    )
+                    pickle.dump(samples, open(f"{out_dir}/last_samples.p", "wb"))
+                with open(
+                    f"{out_dir}/major_{n_major}/last_finished_iteration", "w"
+                ) as f:
+                    f.write(str(i))
+                with open(f"{out_dir}/last_started_major", "w") as f:
+                    f.write(str(n_major))
+
+        sub_val = jft.mean(tuple(sky(s) for s in samples))
+        post_mean = np.array(sub_val)
+        post_mean = ift.makeField(R.domain, post_mean)
+        residual_data = data - R(post_mean).val
+
+        update_kwargs["old_reconstruction"] = sub_val
+        update_kwargs["residual_data"] = residual_data
+
+    return samples.pos, samples, key
diff --git a/resolve/re/radio_response.py b/resolve/re/radio_response.py
new file mode 100644
index 00000000..b0b70b75
--- /dev/null
+++ b/resolve/re/radio_response.py
@@ -0,0 +1,151 @@
+import resolve as rve
+import nifty8 as ift
+import numpy as np
+import pickle
+import jax.numpy as jnp
+from jax.lax import slice as jax_slice
+
+def build_exact_r(obs, conf_sky, conf_setup):
+    sp_sky_dom =rve.sky_model._spatial_dom(conf_sky)
+    sky_dom = rve.default_sky_domain(sdom=sp_sky_dom)
+    R = rve.InterferometryResponse(obs, sky_dom, True, 1e-9, verbosity=1, nthreads=8)
+
+    psf_pixels = conf_setup.getfloat("psf pixels")
+    full_psf0 = min(2*psf_pixels, sp_sky_dom.shape[0])
+    full_psf1 = min(2*psf_pixels, sp_sky_dom.shape[1])
+    sp_sky_dom_l = (sp_sky_dom.shape[0] + full_psf0, sp_sky_dom.shape[1] + full_psf1)
+    sp_sky_dom_l = ift.RGSpace(sp_sky_dom_l, distances=sp_sky_dom.distances)
+    sky_dom_l = rve.default_sky_domain(sdom=sp_sky_dom_l)
+    R_l = rve.InterferometryResponse(obs, sky_dom_l, True, 1e-9, verbosity=1, nthreads=8)
+
+    dch_l = ift.DomainChangerAndReshaper(R_l.domain[3], R_l.domain)
+    R_l = R_l @ dch_l
+    dch = ift.DomainChangerAndReshaper(R.domain[3], R.domain)
+    R = R @ dch
+
+    N_inv = ift.DiagonalOperator(obs.weight)
+    RNR = R.adjoint @ N_inv @ R
+    RNR_l = R_l.adjoint @ N_inv @ R_l
+
+    return R, R_l, RNR, RNR_l
+
+
+def compute_PSF(new_R, n_pix0, n_pix1):
+    dom = new_R.domain
+    shp = dom.shape
+    FFT = ift.FFTOperator(new_R.domain)
+
+    delta = np.zeros(shp)
+    delta[shp[0]//2, shp[1]//2] = 1 / dom.scalar_weight()
+    delta = ift.makeField(dom, delta)
+    kernel = new_R(delta)
+
+    # zero kernel
+    sh0 = shp[0]//2
+    sh1 = shp[1]//2
+    z_kern = np.zeros_like(kernel.val)
+    z_kern[sh0-n_pix0:sh0+n_pix0,sh1-n_pix1:sh1+n_pix1] = kernel.val[sh0 - n_pix0:sh0+n_pix0,sh1-n_pix1:sh1+n_pix1]
+
+    pr_kern = np.roll(z_kern, -shp[0]//2, axis=0)
+    pr_kern = np.roll(pr_kern, -shp[1]//2, axis=1)
+    pr_kern = ift.makeField(FFT.domain, pr_kern)
+    from matplotlib.colors import LogNorm
+    ift.single_plot(pr_kern, norm=LogNorm())
+    fourier_kern = FFT(pr_kern)
+    return fourier_kern.val
+
+def compute_approx_noise_kern(new_R, relativ_min_val=0.):
+    dom = new_R.domain
+    shp = dom.shape
+    FFT = ift.FFTOperator(new_R.domain)
+
+    delta = np.zeros(shp)
+    delta[shp[0]//2, shp[1]//2] = 1 / dom.scalar_weight()
+    delta = ift.makeField(dom, delta)
+    kernel = new_R(delta).val
+    kernel = np.roll(kernel, -shp[0]//2, axis=0)
+    kernel = np.roll(kernel, -shp[1]//2, axis=1)
+    kernel = ift.makeField(new_R.target, kernel)
+    FFT = ift.FFTOperator(new_R.domain)
+    max_val = np.max(FFT(kernel).abs().val)
+    min_val = relativ_min_val * max_val
+    min_val = ift.full(FFT.target, min_val)
+    min_val_adder = ift.Adder(min_val)
+
+    pos_eig_val = ift.Operator.identity_operator(FFT.target).exp()
+    pos_eig_val = min_val_adder @ pos_eig_val
+    rls1 = ift.Realizer(pos_eig_val.target)
+    rls2 = ift.Realizer(FFT.domain)
+
+    kernel_pos = rls2 @ FFT.inverse @ rls1.adjoint @ pos_eig_val
+
+    cov = ift.ScalingOperator(kernel_pos.target, 1e-2*max_val)
+    lh = ift.GaussianEnergy(data=kernel, inverse_covariance=cov.inverse) @ kernel_pos
+    init_pos = (FFT(kernel) - min_val).abs().log()
+    energy = ift.EnergyAdapter(position=init_pos, op=lh, want_metric=True)
+
+    ic_newton = ift.DeltaEnergyController(name='Newton', iteration_limit=80, tol_rel_deltaE=0)
+    #minimizer = ift.NewtonCG(ic_newton, max_cg_iterations=400, energy_reduction_factor=1e-3)
+    minimizer = ift.NewtonCG(ic_newton)
+    res = minimizer(energy)[0].position
+    return pos_eig_val(res).val
+
+def build_approximations(RNR, RNR_l, noise_scaling=None, varcov=False, cache_noise_kernel='None', cache_response_kernel='None'):
+    shp = RNR.domain.shape
+    shp_l = RNR_l.domain.shape
+    # assert(shp[0] == shp[1])
+    # assert(shp_l[0] == shp_l[1])
+    FFT = ift.FFTOperator(RNR_l.domain)
+    FFT_s = ift.FFTOperator(RNR.domain)
+
+    # build approximate response
+    n_psf_pix0 = (shp_l[0] - shp[0])
+    n_psf_pix1 = (shp_l[1] - shp[1])
+    n_padding0 = n_psf_pix0 // 2
+    n_padding1 = n_psf_pix1 // 2
+    if not cache_response_kernel == 'None':
+        try:
+            psf_kernel = pickle.load(open(f"{cache_response_kernel}.p", "rb"))
+        except:
+            psf_kernel = compute_PSF(RNR_l, n_psf_pix0, n_psf_pix1)
+            pickle.dump(psf_kernel, open(f"{cache_response_kernel}.p", "wb"))
+    else:
+            psf_kernel = compute_PSF(RNR_l, n_psf_pix, n_psf_pix1)
+    psf_kernel = jnp.array(psf_kernel)
+    apply_psf_kern = lambda x: FFT.inverse.jax_expr(psf_kernel * FFT.jax_expr(x)).real
+
+    slicer = lambda x: jax_slice(
+        x, (n_padding0, n_padding1), (n_padding0+ shp[0], n_padding1+ shp[1])
+    )
+    padder = lambda x: jnp.pad(x, ((n_padding0, n_padding0), (n_padding1, n_padding1)))
+
+    RNR_approx = lambda x: slicer(apply_psf_kern(padder(x)))
+
+    # build approximate N inverse
+    if not cache_noise_kernel == 'None':
+        try:
+            noise_kernel = pickle.load(open(f"{cache_noise_kernel}.p", "rb"))
+        except:
+            noise_kernel = compute_approx_noise_kern(RNR, 1e-3)
+            pickle.dump(noise_kernel, open(f"{cache_noise_kernel}.p", "wb"))
+    else:
+        noise_kernel = compute_approx_noise_kern(RNR, 1e-3)
+    noise_kernel_inv_sqrt = 1. / np.sqrt(noise_kernel)
+    noise_kernel_inv_sqrt = jnp.array(noise_kernel_inv_sqrt)
+    if varcov:
+        fl = ift.full(FFT_s.target, 1.)
+        vol = FFT_s(FFT_s.adjoint(fl)).real.mean().val
+        fac = np.sqrt(1/vol)
+        apply_n_sqinv_kern = lambda x: fac*noise_kernel_inv_sqrt * FFT_s.jax_expr(x['sky'])
+    elif not noise_scaling is None:
+        apply_n_sqinv_kern = lambda x: FFT_s.inverse.jax_expr(
+            noise_scaling(x) * noise_kernel_inv_sqrt * FFT_s.jax_expr(x['sky'])
+        ).real
+    else:
+        apply_n_sqinv_kern = lambda x: FFT_s.inverse.jax_expr(
+            noise_kernel_inv_sqrt * FFT_s.jax_expr(x['sky'])
+        ).real
+
+
+    return RNR_approx, apply_n_sqinv_kern
+
-- 
GitLab


From 718cb9975aa3af2851296d6ae69da3bf6b50e307 Mon Sep 17 00:00:00 2001
From: Jakob Roth <roth@mpa-garching.mpg.de>
Date: Fri, 17 Jan 2025 15:08:28 +0100
Subject: [PATCH 02/88] re.sky_model: rework point sources model

---
 resolve/re/sky_model.py | 99 ++++++++++++++++++++++-------------------
 1 file changed, 52 insertions(+), 47 deletions(-)

diff --git a/resolve/re/sky_model.py b/resolve/re/sky_model.py
index 790bb958..f07b58b7 100644
--- a/resolve/re/sky_model.py
+++ b/resolve/re/sky_model.py
@@ -1,14 +1,15 @@
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 # Author: Jakob Roth
 
-from numpy import full
 import nifty8.re as jft
 import jax
 from jax import numpy as jnp
+import numpy as np
 
 
 from ..sky_model import _spatial_dom
-from ..sky_model import sky_model_points as rve_sky_pts
+from ..constants import str2rad
+
 
 def build_cf(prefix, conf, shape, dist):
     zmo = conf.getfloat(f"{prefix} zero mode offset")
@@ -34,7 +35,7 @@ def build_cf(prefix, conf, shape, dist):
     cfm = jft.CorrelatedFieldMaker(prefix)
     cfm.set_amplitude_total_offset(**cf_zm)
     cfm.add_fluctuations(
-        shape, distances=dist, **cf_fl, prefix="", non_parametric_kind='power'
+        shape, distances=dist, **cf_fl, prefix="", non_parametric_kind="power"
     )
     amps = cfm.get_normalized_amplitudes()
     cfm = cfm.finalize()
@@ -42,8 +43,6 @@ def build_cf(prefix, conf, shape, dist):
     return cfm, additional
 
 
-
-
 def sky_model_diffuse(cfg):
     if not cfg["freq mode"] == "single":
         raise NotImplementedError("FIXME: only implemented for single frequency")
@@ -52,63 +51,69 @@ def sky_model_diffuse(cfg):
     sky_dom = _spatial_dom(cfg)
     bg_shape = sky_dom.shape
     bg_distances = sky_dom.distances
-    bg_log_diff, additional = build_cf('stokesI diffuse space i0', cfg, bg_shape, bg_distances)
-    full_shape = (1,1,1)+bg_shape
-
+    bg_log_diff, additional = build_cf(
+        "stokesI diffuse space i0", cfg, bg_shape, bg_distances
+    )
+    full_shape = (1, 1, 1) + bg_shape
 
     def bg_diffuse(x):
         return jnp.broadcast_to(
-            jnp.exp(bg_log_diff(x['stokesI diffuse space i0'])), full_shape
+            jnp.exp(bg_log_diff(x["stokesI diffuse space i0"])), full_shape
         )
 
-
     bg_diffuse_model = jft.Model(
-        bg_diffuse,
-        domain={'stokesI diffuse space i0': bg_log_diff.domain}
-        )
+        bg_diffuse, domain={"stokesI diffuse space i0": bg_log_diff.domain}
+    )
 
     return bg_diffuse_model, additional
 
-def get_pts_postions(cfg):
-    rsp, _ = rve_sky_pts(cfg)
-    return rsp._ops[0]._inds[0], rsp._ops[0]._inds[1]
-
-def get_inv_gamma(cfg):
-    rsp, _ = rve_sky_pts(cfg)
-    return rsp._ops[1].jax_expr
 
 def jax_insert(x, ptsx, ptsy, bg):
     bg = bg.at[:, :, :, ptsx, ptsy].set(x)
     return bg
 
-def sky_model_points(cfg, bg=None):
-    if not cfg["freq mode"] == "single":
-        raise NotImplementedError("FIXME: only implemented for single frequency")
-    if not cfg["polarization"] == "I":
-        raise NotImplementedError("FIXME: only implemented for stokes I")
-
-    if not cfg["point sources mode"] == "single":
-        raise NotImplementedError(
-            "FIXME: point sources only implemented for mode single"
-            )
-    else:
-        sky_dom = _spatial_dom(cfg)
-        shp = sky_dom.shape
-        full_shp = (1,1,1)+shp
-        ptsx, ptsy = get_pts_postions(cfg)
-        inv_gamma = get_inv_gamma(cfg)
-        pts_shp = (len(ptsx),)
-        dom = {'points': jax.ShapeDtypeStruct(pts_shp, dtype=jnp.float64)}
-        if bg is None:
-            bg = jnp.zeros(full_shp)
-            pts_func = lambda x: jax_insert(inv_gamma(x['points']), ptsx=ptsx, ptsy=ptsy, bg=bg)
-        else:
-            pts_func = lambda x: jax_insert(inv_gamma(x['points']), ptsx=ptsx, ptsy=ptsy, bg=bg(x))
-            dom = {**dom, **bg.domain}
-    pts_model = jft.Model(pts_func, domain=dom)
-    additional = {}
-    return pts_model, additional
 
+def sky_model_points(cfg, bg=None):
+    if cfg["freq mode"] == "single":
+        if cfg["polarization"] == "I":
+            if cfg["point sources mode"] == "single":
+                ppos = []
+                sky_dom = _spatial_dom(cfg)
+                s = cfg["point sources locations"]
+                for xy in s.split(","):
+                    x, y = xy.split("$")
+                    ppos.append((str2rad(x), str2rad(y)))
+                ppos = np.array(ppos)
+                dx = np.array(sky_dom.distances)
+                center = np.array(sky_dom.shape) // 2
+                inds = np.unique(np.round(ppos / dx + center).astype(int).T, axis=1)
+                indsx, indsy = inds
+                alpha = cfg.getfloat("point sources alpha")
+                q = cfg.getfloat("point sources q")
+                sky_dom = _spatial_dom(cfg)
+                shp = sky_dom.shape
+                full_shp = (1, 1, 1) + shp
+                inv_gamma = jft.InvGammaPrior(
+                    a=alpha,
+                    scale=q,
+                    name="points",
+                    shape=jax.ShapeDtypeStruct((len(indsx),), float),
+                )
+                if bg is None:
+                    bg = jnp.zeros(full_shp)
+                    pts_func = lambda x: jax_insert(
+                        inv_gamma(x), ptsx=indsx, ptsy=indsy, bg=bg
+                    )
+                    dom = inv_gamma.domain
+                else:
+                    pts_func = lambda x: jax_insert(
+                        inv_gamma(x), ptsx=indsx, ptsy=indsy, bg=bg(x)
+                    )
+                    dom = {**inv_gamma.domain, **bg.domain}
+                pts_model = jft.Model(pts_func, domain=dom)
+                additional = {}
+                return pts_model, additional
+    raise NotImplementedError("FIXME: Not Implemented for selected mode.")
 
 
 def sky_model(cfg):
-- 
GitLab


From de9ff4fb2e66c319fbf2ee557dee54e468d6cd14 Mon Sep 17 00:00:00 2001
From: Jakob Roth <roth@mpa-garching.mpg.de>
Date: Tue, 4 Feb 2025 13:59:10 +0100
Subject: [PATCH 03/88] fast-resolve: update to new NIFTy version

---
 demo/demo_fast_resolve.py    | 17 +++++++++++----
 resolve/re/optimize.py       |  8 ++-----
 resolve/re/radio_response.py | 42 ++++++++++++++++++++++++++++++------
 3 files changed, 51 insertions(+), 16 deletions(-)

diff --git a/demo/demo_fast_resolve.py b/demo/demo_fast_resolve.py
index c24d7ec7..6017aa83 100644
--- a/demo/demo_fast_resolve.py
+++ b/demo/demo_fast_resolve.py
@@ -20,7 +20,7 @@ jax.config.update("jax_enable_x64", True)
 seed = 42
 
 key = random.PRNGKey(seed)
-conf_name = "cygnusa_2ghz.cfg"
+conf_name = "cygnusa_2ghz_fast_resolve.cfg"
 out_dir = "demo_fast-resolve"
 
 
@@ -43,9 +43,18 @@ R, R_l, RNR, RNR_l = jrve.build_exact_r(obs, cfg["sky"], cfg["setup"])
 # %%
 my_sky_model_func = lambda x: sky_model(x)[0, 0, 0, :, :]
 my_sky_model = jft.Model(my_sky_model_func, domain=sky_model.domain)
-invgamma_op = ift.InverseGammaOperator(R.domain, mode=1.0, mean=1.1)
-noise_scaling = lambda x: 1 / (invgamma_op.jax_expr(x["noise"]))
-noise_scaling = jft.Model(noise_scaling, domain={"noise": np.empty(R.domain.shape)})
+mode = 1
+mean = 1.1
+a = 2 / (mean / mode - 1) + 1
+scale = mode * (a + 1)
+inv_gamma = jft.InvGammaPrior(
+                    a=a,
+                    scale=scale,
+                    name="noise",
+                    shape=jax.ShapeDtypeStruct(R.domain.shape, float),
+                )
+noise_scaling = lambda x: 1 / (inv_gamma(x))
+noise_scaling = jft.Model(noise_scaling, domain=inv_gamma.domain)
 
 if noise_scal:
     RNR_approx, N_inv_approx = jrve.build_approximations(
diff --git a/resolve/re/optimize.py b/resolve/re/optimize.py
index 8ea815f1..0835e96e 100644
--- a/resolve/re/optimize.py
+++ b/resolve/re/optimize.py
@@ -15,7 +15,7 @@ from nifty8.re.optimize_kl import _kl_met
 from nifty8.re.optimize_kl import draw_linear_residual
 from nifty8.re.optimize_kl import nonlinearly_update_residual
 from nifty8.re.optimize_kl import get_status_message
-from nifty8.re.evi import _nonlinearly_update_residual_functions
+# from nifty8.re.evi import _nonlinearly_update_residual_functions
 
 
 def optimize(
@@ -138,17 +138,13 @@ def optimize(
         **kwargs,
     ):
         lh = get_lh(old_reconstruction, residual_data)
-        _nonlin_funcs = _nonlinearly_update_residual_functions(
-            likelihood=lh,
-            jit=jax.jit,
-        )
         return nonlinearly_update_residual(
             lh,
             pos,
             residual_sample,
             metric_sample_key,
             metric_sample_sign,
-            _nonlinear_update_funcs=_nonlin_funcs,
+            # _nonlinear_update_funcs=_nonlin_funcs,
             **kwargs,
         )
 
diff --git a/resolve/re/radio_response.py b/resolve/re/radio_response.py
index b0b70b75..7ec7728a 100644
--- a/resolve/re/radio_response.py
+++ b/resolve/re/radio_response.py
@@ -3,7 +3,32 @@ import nifty8 as ift
 import numpy as np
 import pickle
 import jax.numpy as jnp
+
 from jax.lax import slice as jax_slice
+from functools import partial
+
+def get_jax_fft(domain, target, inverse):
+    if inverse:
+        if domain.harmonic:
+            func = jnp.fft.fftn
+            fct = 1.
+        else:
+            func = jnp.fft.ifftn
+            fct = domain.size
+        fct *= target.scalar_dvol
+    else:
+        if domain.harmonic:
+            func = jnp.fft.ifftn
+            fct = domain.size
+        else:
+            func = jnp.fft.fftn
+            fct = 1.
+        fct *= domain.scalar_dvol
+
+    def fft(x, fct, func):
+        return fct * func(x) if fct != 1 else func(x)
+
+    return partial(fft, fct=fct, func=func)
 
 def build_exact_r(obs, conf_sky, conf_setup):
     sp_sky_dom =rve.sky_model._spatial_dom(conf_sky)
@@ -96,7 +121,12 @@ def build_approximations(RNR, RNR_l, noise_scaling=None, varcov=False, cache_noi
     # assert(shp[0] == shp[1])
     # assert(shp_l[0] == shp_l[1])
     FFT = ift.FFTOperator(RNR_l.domain)
+    fft_jax = get_jax_fft(FFT.domain[0], FFT.target[0], False)
+    fft_jax_inv = get_jax_fft(FFT.domain[0], FFT.target[0], True)
     FFT_s = ift.FFTOperator(RNR.domain)
+    fft_jax_s = get_jax_fft(FFT_s.domain[0], FFT_s.target[0], False)
+    fft_jax_inv_s = get_jax_fft(FFT_s.domain[0], FFT_s.target[0], True)
+
 
     # build approximate response
     n_psf_pix0 = (shp_l[0] - shp[0])
@@ -112,7 +142,7 @@ def build_approximations(RNR, RNR_l, noise_scaling=None, varcov=False, cache_noi
     else:
             psf_kernel = compute_PSF(RNR_l, n_psf_pix, n_psf_pix1)
     psf_kernel = jnp.array(psf_kernel)
-    apply_psf_kern = lambda x: FFT.inverse.jax_expr(psf_kernel * FFT.jax_expr(x)).real
+    apply_psf_kern = lambda x: fft_jax_inv(psf_kernel * fft_jax(x)).real
 
     slicer = lambda x: jax_slice(
         x, (n_padding0, n_padding1), (n_padding0+ shp[0], n_padding1+ shp[1])
@@ -136,14 +166,14 @@ def build_approximations(RNR, RNR_l, noise_scaling=None, varcov=False, cache_noi
         fl = ift.full(FFT_s.target, 1.)
         vol = FFT_s(FFT_s.adjoint(fl)).real.mean().val
         fac = np.sqrt(1/vol)
-        apply_n_sqinv_kern = lambda x: fac*noise_kernel_inv_sqrt * FFT_s.jax_expr(x['sky'])
+        apply_n_sqinv_kern = lambda x: fac*noise_kernel_inv_sqrt * fft_jax_s(x['sky'])
     elif not noise_scaling is None:
-        apply_n_sqinv_kern = lambda x: FFT_s.inverse.jax_expr(
-            noise_scaling(x) * noise_kernel_inv_sqrt * FFT_s.jax_expr(x['sky'])
+        apply_n_sqinv_kern = lambda x: fft_jax_inv_s(
+            noise_scaling(x) * noise_kernel_inv_sqrt * fft_jax_s(x['sky'])
         ).real
     else:
-        apply_n_sqinv_kern = lambda x: FFT_s.inverse.jax_expr(
-            noise_kernel_inv_sqrt * FFT_s.jax_expr(x['sky'])
+        apply_n_sqinv_kern = lambda x: fft_jax_inv_s(
+            noise_kernel_inv_sqrt * fft_jax_s(x['sky'])
         ).real
 
 
-- 
GitLab


From 3ae0f516c39031416d3b84bb9a681e7067cabb00 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 22 Jan 2025 14:41:48 +0100
Subject: [PATCH 04/88] added calibration file to resolve.re

---
 resolve/re/calibration.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 resolve/re/calibration.py

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
new file mode 100644
index 00000000..e69de29b
-- 
GitLab


From fb2fce4ea434a55a408329b2d05570df090b6b3f Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Fri, 24 Jan 2025 19:20:35 +0100
Subject: [PATCH 05/88] Added grid interpolator and model for calibration

---
 resolve/re/calibration.py | 49 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index e69de29b..58e9f49f 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -0,0 +1,49 @@
+import jax
+import nifty8.re as jft
+import jax.scipy as jsc
+import jax.numpy as jnp
+from functools import partial
+
+class CalibrationDistribution(jft.Model):
+    def __init__(self,observation, phase_fields, log_amplitude_fields,dt):
+        ap = observation.antenna_positions
+        self._cop1 = CalibrationInterpolator(dt, ap.ant1, ap.time)
+        self._cop2 = CalibrationInterpolator(dt, ap.ant2, ap.time)
+
+        self._phases = phase_fields
+        self._logamps = log_amplitude_fields
+
+        super().__init__(init=self._phases.init | self._logamps.init)
+
+    def __call__(self,x):
+        res_logamp = jnp.real(self._cop1(self._logamps(x)) + self._cop2(self._logamps(x)))
+        res_phase = jnp.real(self._cop1(self._phases(x)) - self._cop2(self._phases(x)))*1j
+
+        return jnp.exp(res_logamp + res_phase)
+
+class CalibrationInterpolator(jft.Model):
+    def __init__(self,ant_col, time_col,dt,cfs):
+        # Input shape follows (n_pol,n_antenna,n_timesteps,n_freq)
+        # Output shape follows (n_pol,n_visibilities,n_freq)
+        # The model assumes that you have grid with constant width in antenna and times
+
+        coords = [ant_col,time_col/dt]
+
+        n_pol, _, _, n_freq = cfs.shape
+
+        self._li = partial(jsc.ndimage.map_coordinates,coordinates=coords,order=1)
+        n_vis = jax.eval_shape(self._li(cfs[0, :, :, 0])).shape[0]
+
+        self._output_shape = (n_pol,n_vis,n_freq)
+        
+        super().__init__(init=self._cf)
+
+    def __call__(self,x):
+        res = jnp.empty(self._output_shape)
+        n_pol, _, n_freq = self._output_shape
+        for pol in range(n_pol):
+            for freq in range(n_freq):
+                val = x[pol, :, :, freq]
+                res[pol, :, freq] = self._li(val)
+        
+        return res
\ No newline at end of file
-- 
GitLab


From bddd4de44ec7c85548291f26fd5002fac661c2e3 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 4 Feb 2025 17:17:05 +0100
Subject: [PATCH 06/88] Removed model evaluation in CalibrationInterpolator for
 output dimensionality as it is already encoded in the observation

---
 resolve/re/calibration.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index 58e9f49f..eec35ffa 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -7,8 +7,9 @@ from functools import partial
 class CalibrationDistribution(jft.Model):
     def __init__(self,observation, phase_fields, log_amplitude_fields,dt):
         ap = observation.antenna_positions
-        self._cop1 = CalibrationInterpolator(dt, ap.ant1, ap.time)
-        self._cop2 = CalibrationInterpolator(dt, ap.ant2, ap.time)
+        target_shape = observation.vis.shape
+        self._cop1 = CalibrationInterpolator(dt, ap.ant1, ap.time, target_shape)
+        self._cop2 = CalibrationInterpolator(dt, ap.ant2, ap.time, target_shape)
 
         self._phases = phase_fields
         self._logamps = log_amplitude_fields
@@ -22,19 +23,16 @@ class CalibrationDistribution(jft.Model):
         return jnp.exp(res_logamp + res_phase)
 
 class CalibrationInterpolator(jft.Model):
-    def __init__(self,ant_col, time_col,dt,cfs):
+    def __init__(self,ant_col, time_col,dt,target_shape):
         # Input shape follows (n_pol,n_antenna,n_timesteps,n_freq)
         # Output shape follows (n_pol,n_visibilities,n_freq)
         # The model assumes that you have grid with constant width in antenna and times
 
         coords = [ant_col,time_col/dt]
 
-        n_pol, _, _, n_freq = cfs.shape
-
         self._li = partial(jsc.ndimage.map_coordinates,coordinates=coords,order=1)
-        n_vis = jax.eval_shape(self._li(cfs[0, :, :, 0])).shape[0]
 
-        self._output_shape = (n_pol,n_vis,n_freq)
+        self._output_shape = target_shape
         
         super().__init__(init=self._cf)
 
-- 
GitLab


From 0735651dca0ea534e3943b038c50350bb75e1479 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 11 Feb 2025 17:38:32 +0100
Subject: [PATCH 07/88] CalibrationInterpolator: Removed redundant dependency
 in init method

---
 resolve/re/calibration.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index eec35ffa..9dc06751 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -33,8 +33,6 @@ class CalibrationInterpolator(jft.Model):
         self._li = partial(jsc.ndimage.map_coordinates,coordinates=coords,order=1)
 
         self._output_shape = target_shape
-        
-        super().__init__(init=self._cf)
 
     def __call__(self,x):
         res = jnp.empty(self._output_shape)
-- 
GitLab


From bdd839326a15d7f6971795fbcad479a62a55c5fd Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 11 Feb 2025 17:39:35 +0100
Subject: [PATCH 08/88] Calibration likelihood: Implemented creation of amended
 likelihood for no given inverse covariance operation

---
 resolve/re/likelihood.py | 37 +++++++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)
 create mode 100644 resolve/re/likelihood.py

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
new file mode 100644
index 00000000..a89a07eb
--- /dev/null
+++ b/resolve/re/likelihood.py
@@ -0,0 +1,37 @@
+import numpy as np
+import nifty8.re as jft
+import jax.scipy as jsc
+import jax.numpy as jnp
+from functools import partial
+
+class CalibrationModel(jft.Model):
+    def __init__(self,cop, model_visibilities,mask):
+        self._cop = cop
+        self._vis = model_visibilities
+        self._mask = mask.astype(int)
+
+        super().__init__(init=self._cop.init)
+
+    def __call__(self,x):
+        model_data = self._vis*self._cop(x)
+        flagged_model_data = model_data[self._mask]
+        return flagged_model_data
+
+
+def CalibrationLikelihood(
+    observation,
+    calibration_operator,
+    model_visibilities,
+    log_inverse_covariance_operator=None
+):
+    mask = observation.mask.val
+    model_d = CalibrationModel(calibration_operator,model_visibilities,mask)
+
+    flagged_data = observation.vis.val[mask]
+
+    if log_inverse_covariance_operator is None:
+        flagged_covariance = observation.weight.val[mask]
+        
+        lh = jft.Gaussian(data=flagged_data,noise_cov_inv=flagged_covariance)
+
+        return lh.amend(model_d)
\ No newline at end of file
-- 
GitLab


From 7c0d6b7d3eae164329e392906b5b2dd890ccd3b9 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 11 Feb 2025 17:50:45 +0100
Subject: [PATCH 09/88] Added CalibrationCovarianceModel for inference with
 variable covariance

---
 resolve/re/likelihood.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index a89a07eb..c3df6e2d 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -4,18 +4,29 @@ import jax.scipy as jsc
 import jax.numpy as jnp
 from functools import partial
 
-class CalibrationModel(jft.Model):
-    def __init__(self,cop, model_visibilities,mask):
+class CalibrationDataModel(jft.Model):
+    def __init__(self, cop, model_visibilities, mask):
         self._cop = cop
         self._vis = model_visibilities
         self._mask = mask.astype(int)
 
         super().__init__(init=self._cop.init)
 
-    def __call__(self,x):
+    def __call__(self, x):
         model_data = self._vis*self._cop(x)
         flagged_model_data = model_data[self._mask]
         return flagged_model_data
+    
+class CalibrationCovarianceModel(jft.Model):
+    def __init__(self, covariance_model, mask):
+        self._mask = mask
+        self._covariance = covariance_model
+
+        super().__init__(init=self._cov.init)
+    
+    def __call__(self, x):
+        covariance = self._covariance(x)
+        flagged_covariance = covariance[self._mask]
 
 
 def CalibrationLikelihood(
@@ -25,13 +36,13 @@ def CalibrationLikelihood(
     log_inverse_covariance_operator=None
 ):
     mask = observation.mask.val
-    model_d = CalibrationModel(calibration_operator,model_visibilities,mask)
+    model_d = CalibrationDataModel(calibration_operator, model_visibilities, mask)
 
     flagged_data = observation.vis.val[mask]
 
     if log_inverse_covariance_operator is None:
         flagged_covariance = observation.weight.val[mask]
         
-        lh = jft.Gaussian(data=flagged_data,noise_cov_inv=flagged_covariance)
+        lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_covariance)
 
         return lh.amend(model_d)
\ No newline at end of file
-- 
GitLab


From 269cfa48591c985b14ff9932d26fd43a848c9bf2 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 11 Feb 2025 20:38:16 +0100
Subject: [PATCH 10/88] Shifted CalibrationDataModel and
 CalibrationCovarianceModel from likelihood.py to calibration.py

---
 resolve/re/calibration.py | 25 +++++++++++++++++++++++++
 resolve/re/likelihood.py  | 26 ++------------------------
 2 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index 9dc06751..0c1d97c7 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -4,6 +4,31 @@ import jax.scipy as jsc
 import jax.numpy as jnp
 from functools import partial
 
+class CalibrationDataModel(jft.Model):
+    def __init__(self, cop, model_visibilities, mask):
+        self._cop = cop
+        self._vis = model_visibilities
+        self._mask = mask.astype(int)
+
+        super().__init__(init=self._cop.init)
+
+    def __call__(self, x):
+        model_data = self._vis*self._cop(x)
+        flagged_model_data = model_data[self._mask]
+        return flagged_model_data
+    
+class CalibrationCovarianceModel(jft.Model):
+    def __init__(self, covariance_model, mask):
+        self._mask = mask
+        self._covariance = covariance_model
+
+        super().__init__(init=self._cov.init)
+    
+    def __call__(self, x):
+        covariance = self._covariance(x)
+        flagged_covariance = covariance[self._mask]
+        return flagged_covariance
+
 class CalibrationDistribution(jft.Model):
     def __init__(self,observation, phase_fields, log_amplitude_fields,dt):
         ap = observation.antenna_positions
diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index c3df6e2d..22fd59d0 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -4,29 +4,7 @@ import jax.scipy as jsc
 import jax.numpy as jnp
 from functools import partial
 
-class CalibrationDataModel(jft.Model):
-    def __init__(self, cop, model_visibilities, mask):
-        self._cop = cop
-        self._vis = model_visibilities
-        self._mask = mask.astype(int)
-
-        super().__init__(init=self._cop.init)
-
-    def __call__(self, x):
-        model_data = self._vis*self._cop(x)
-        flagged_model_data = model_data[self._mask]
-        return flagged_model_data
-    
-class CalibrationCovarianceModel(jft.Model):
-    def __init__(self, covariance_model, mask):
-        self._mask = mask
-        self._covariance = covariance_model
-
-        super().__init__(init=self._cov.init)
-    
-    def __call__(self, x):
-        covariance = self._covariance(x)
-        flagged_covariance = covariance[self._mask]
+from .response import InterferometryResponse
 
 
 def CalibrationLikelihood(
@@ -45,4 +23,4 @@ def CalibrationLikelihood(
         
         lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_covariance)
 
-        return lh.amend(model_d)
\ No newline at end of file
+        return lh.amend(model_d)
-- 
GitLab


From dbb659106e3dbc779bdefdcfefbc3d5fe928a544 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 12 Feb 2025 18:13:13 +0100
Subject: [PATCH 11/88] changed CalibrationCovarianceModel, input to model for
 log_covariance; output is the flagged exponentiated field

---
 resolve/re/calibration.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index 0c1d97c7..7cc9cafd 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -1,4 +1,3 @@
-import jax
 import nifty8.re as jft
 import jax.scipy as jsc
 import jax.numpy as jnp
@@ -18,14 +17,14 @@ class CalibrationDataModel(jft.Model):
         return flagged_model_data
     
 class CalibrationCovarianceModel(jft.Model):
-    def __init__(self, covariance_model, mask):
+    def __init__(self, log_covariance_model, mask):
         self._mask = mask
-        self._covariance = covariance_model
+        self._log_covariance = log_covariance_model
 
-        super().__init__(init=self._cov.init)
+        super().__init__(init=self._log_covariance.init)
     
     def __call__(self, x):
-        covariance = self._covariance(x)
+        covariance = jnp.exp(self._log_covariance(x))
         flagged_covariance = covariance[self._mask]
         return flagged_covariance
 
-- 
GitLab


From ec226850ed046fcb874885aa9eb972103ca45caf Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 12 Feb 2025 18:16:56 +0100
Subject: [PATCH 12/88] Added functionality for calibration using a variable
 covariance gaussian

---
 resolve/re/likelihood.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index 22fd59d0..87151175 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -5,7 +5,7 @@ import jax.numpy as jnp
 from functools import partial
 
 from .response import InterferometryResponse
-
+from .calibration import CalibrationDataModel, CalibrationCovarianceModel
 
 def CalibrationLikelihood(
     observation,
@@ -22,5 +22,11 @@ def CalibrationLikelihood(
         flagged_covariance = observation.weight.val[mask]
         
         lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_covariance)
-
         return lh.amend(model_d)
+    
+    else:
+        flagged_cov_model = CalibrationCovarianceModel(log_inverse_covariance_operator, mask)
+
+        lh = jft.VariableCovarianceGaussian(data=flagged_data)
+        models = (model_d,flagged_cov_model)
+        return lh.amend(models)
-- 
GitLab


From c89e554bc82b35177e0961dcfd4f71d30a6c5862 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Thu, 13 Feb 2025 11:43:47 +0100
Subject: [PATCH 13/88] Renamed Calibration models to better fit functionality;
 Variable Covariance model prepares both data and covariance model

---
 resolve/re/calibration.py | 30 ++++++++++++++++++------------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index 7cc9cafd..1e525c49 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -3,30 +3,36 @@ import jax.scipy as jsc
 import jax.numpy as jnp
 from functools import partial
 
-class CalibrationDataModel(jft.Model):
+class CalibrationFixedCovarianceModel(jft.Model):
     def __init__(self, cop, model_visibilities, mask):
         self._cop = cop
         self._vis = model_visibilities
-        self._mask = mask.astype(int)
+        self._mask = mask
 
         super().__init__(init=self._cop.init)
 
     def __call__(self, x):
-        model_data = self._vis*self._cop(x)
-        flagged_model_data = model_data[self._mask]
-        return flagged_model_data
+        data_model = self._vis*self._cop(x)
+        flagged_data_model = data_model[self._mask]
+        return flagged_data_model
     
-class CalibrationCovarianceModel(jft.Model):
-    def __init__(self, log_covariance_model, mask):
+class CalibrationVariableCovarianceModel(jft.Model):
+    def __init__(self, cop, model_visibilities, log_inverse_covariance_model, mask):
+        self._cop = cop
+        self._vis = model_visibilities
         self._mask = mask
-        self._log_covariance = log_covariance_model
+        self._log_inv_cov = log_inverse_covariance_model
 
-        super().__init__(init=self._log_covariance.init)
+        super().__init__(init=self._cop.init | self._log_inv_cov.init)
     
     def __call__(self, x):
-        covariance = jnp.exp(self._log_covariance(x))
-        flagged_covariance = covariance[self._mask]
-        return flagged_covariance
+        data_model = self._vis*self._cop(x)
+        flagged_data_model = data_model[self._mask]
+
+        inv_cov = jnp.exp(self._log_inv_cov(x))
+        flagged_inv_cov = inv_cov[self._mask]
+        
+        return (flagged_data_model,flagged_inv_cov)
 
 class CalibrationDistribution(jft.Model):
     def __init__(self,observation, phase_fields, log_amplitude_fields,dt):
-- 
GitLab


From eaf7a55621e8d8fec477bfcd974e56f6949fd60f Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Thu, 13 Feb 2025 12:03:35 +0100
Subject: [PATCH 14/88] Needed models for calibration and imaging gathered in
 seperate file

---
 resolve/re/calibration.py       | 31 -------------------------------
 resolve/re/likelihood_models.py | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 31 deletions(-)
 create mode 100644 resolve/re/likelihood_models.py

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index 1e525c49..79999ebc 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -3,37 +3,6 @@ import jax.scipy as jsc
 import jax.numpy as jnp
 from functools import partial
 
-class CalibrationFixedCovarianceModel(jft.Model):
-    def __init__(self, cop, model_visibilities, mask):
-        self._cop = cop
-        self._vis = model_visibilities
-        self._mask = mask
-
-        super().__init__(init=self._cop.init)
-
-    def __call__(self, x):
-        data_model = self._vis*self._cop(x)
-        flagged_data_model = data_model[self._mask]
-        return flagged_data_model
-    
-class CalibrationVariableCovarianceModel(jft.Model):
-    def __init__(self, cop, model_visibilities, log_inverse_covariance_model, mask):
-        self._cop = cop
-        self._vis = model_visibilities
-        self._mask = mask
-        self._log_inv_cov = log_inverse_covariance_model
-
-        super().__init__(init=self._cop.init | self._log_inv_cov.init)
-    
-    def __call__(self, x):
-        data_model = self._vis*self._cop(x)
-        flagged_data_model = data_model[self._mask]
-
-        inv_cov = jnp.exp(self._log_inv_cov(x))
-        flagged_inv_cov = inv_cov[self._mask]
-        
-        return (flagged_data_model,flagged_inv_cov)
-
 class CalibrationDistribution(jft.Model):
     def __init__(self,observation, phase_fields, log_amplitude_fields,dt):
         ap = observation.antenna_positions
diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py
new file mode 100644
index 00000000..c54277e7
--- /dev/null
+++ b/resolve/re/likelihood_models.py
@@ -0,0 +1,33 @@
+import nifty8.re as jft
+import jax.numpy as jnp
+
+class CalibrationFixedCovarianceModel(jft.Model):
+    def __init__(self, cop, model_visibilities, mask):
+        self._cop = cop
+        self._vis = model_visibilities
+        self._mask = mask
+
+        super().__init__(init=self._cop.init)
+
+    def __call__(self, x):
+        data_model = self._vis*self._cop(x)
+        flagged_data_model = data_model[self._mask]
+        return flagged_data_model
+    
+class CalibrationVariableCovarianceModel(jft.Model):
+    def __init__(self, cop, model_visibilities, log_inverse_covariance_model, mask):
+        self._cop = cop
+        self._vis = model_visibilities
+        self._mask = mask
+        self._log_inv_cov = log_inverse_covariance_model
+
+        super().__init__(init=self._cop.init | self._log_inv_cov.init)
+    
+    def __call__(self, x):
+        data_model = self._vis*self._cop(x)
+        flagged_data_model = data_model[self._mask]
+
+        inv_cov = jnp.exp(self._log_inv_cov(x))
+        flagged_inv_cov = inv_cov[self._mask]
+        
+        return (flagged_data_model,flagged_inv_cov)
-- 
GitLab


From 82edc9766b1a337ef6bb6a244f21471a789113fb Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Thu, 13 Feb 2025 12:08:45 +0100
Subject: [PATCH 15/88] Renamed variable and rewrote structure for
 clarification

---
 resolve/re/likelihood.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index 87151175..11d9dc57 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -5,7 +5,7 @@ import jax.numpy as jnp
 from functools import partial
 
 from .response import InterferometryResponse
-from .calibration import CalibrationDataModel, CalibrationCovarianceModel
+from .likelihood_models import CalibrationFixedCovarianceModel, CalibrationVariableCovarianceModel
 
 def CalibrationLikelihood(
     observation,
@@ -14,19 +14,18 @@ def CalibrationLikelihood(
     log_inverse_covariance_operator=None
 ):
     mask = observation.mask.val
-    model_d = CalibrationDataModel(calibration_operator, model_visibilities, mask)
 
     flagged_data = observation.vis.val[mask]
 
     if log_inverse_covariance_operator is None:
-        flagged_covariance = observation.weight.val[mask]
+        model = CalibrationFixedCovarianceModel(calibration_operator,model_visibilities,mask)
+        flagged_inv_cov = observation.weight.val[mask]
         
-        lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_covariance)
-        return lh.amend(model_d)
+        lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
     
     else:
-        flagged_cov_model = CalibrationCovarianceModel(log_inverse_covariance_operator, mask)
+        model = CalibrationVariableCovarianceModel(calibration_operator,model_visibilities,log_inverse_covariance_operator,mask)
 
-        lh = jft.VariableCovarianceGaussian(data=flagged_data)
-        models = (model_d,flagged_cov_model)
-        return lh.amend(models)
+        lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=True)
+    
+    return lh.amend(model)
-- 
GitLab


From 6178125159282a3a2743020f58a8e20a4bb7888a Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Fri, 14 Feb 2025 00:17:05 +0100
Subject: [PATCH 16/88] Added fixed covariance model for classical imagin
 likelihood

---
 resolve/re/likelihood_models.py | 24 ++++++++++++++++++++++--
 1 file changed, 22 insertions(+), 2 deletions(-)

diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py
index c54277e7..2b8043b0 100644
--- a/resolve/re/likelihood_models.py
+++ b/resolve/re/likelihood_models.py
@@ -1,7 +1,7 @@
 import nifty8.re as jft
 import jax.numpy as jnp
 
-class CalibrationFixedCovarianceModel(jft.Model):
+class CalibrationLikelihoodFixedCovarianceModel(jft.Model):
     def __init__(self, cop, model_visibilities, mask):
         self._cop = cop
         self._vis = model_visibilities
@@ -14,7 +14,7 @@ class CalibrationFixedCovarianceModel(jft.Model):
         flagged_data_model = data_model[self._mask]
         return flagged_data_model
     
-class CalibrationVariableCovarianceModel(jft.Model):
+class CalibrationLikelihoodVariableCovarianceModel(jft.Model):
     def __init__(self, cop, model_visibilities, log_inverse_covariance_model, mask):
         self._cop = cop
         self._vis = model_visibilities
@@ -31,3 +31,23 @@ class CalibrationVariableCovarianceModel(jft.Model):
         flagged_inv_cov = inv_cov[self._mask]
         
         return (flagged_data_model,flagged_inv_cov)
+
+class ImagingLikelihoodFixedCovarianceModel(jft.Model):
+    def __init__(self, R, sky_operator, calibration_operator=None, calibration_field=None):
+        self._R = R
+        self._sky = sky_operator
+
+        if calibration_operator is not None:
+            self._cal_op = calibration_operator
+            super().__init__(init=self._R.init | self._cal_op.init)
+        elif calibration_field is not None:
+            self._cal_fld = calibration_field
+            super().__init__(init=self._R.init)
+
+    def __call__(self,x):
+        if self._cal_op is not None:
+            return self._cal_op(x)*self._R(self._sky(x))
+        elif self._cal_fld is not None:
+            return self._cal_fld*self._R(self._sky(x))
+        else:
+            return self._R(self._sky(x))
\ No newline at end of file
-- 
GitLab


From 38628fd15ff3f139d6887df09cb8d8d9ab4a87b2 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Fri, 14 Feb 2025 00:29:27 +0100
Subject: [PATCH 17/88] Added Masking operation to fixed covariance classical
 imaging model

---
 resolve/re/likelihood_models.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py
index 2b8043b0..0ea1403e 100644
--- a/resolve/re/likelihood_models.py
+++ b/resolve/re/likelihood_models.py
@@ -33,9 +33,10 @@ class CalibrationLikelihoodVariableCovarianceModel(jft.Model):
         return (flagged_data_model,flagged_inv_cov)
 
 class ImagingLikelihoodFixedCovarianceModel(jft.Model):
-    def __init__(self, R, sky_operator, calibration_operator=None, calibration_field=None):
+    def __init__(self, R, sky_operator, mask, calibration_operator=None, calibration_field=None):
         self._R = R
         self._sky = sky_operator
+        self._mask = mask
 
         if calibration_operator is not None:
             self._cal_op = calibration_operator
@@ -46,8 +47,12 @@ class ImagingLikelihoodFixedCovarianceModel(jft.Model):
 
     def __call__(self,x):
         if self._cal_op is not None:
-            return self._cal_op(x)*self._R(self._sky(x))
+            data_model = self._cal_op(x)*self._R(self._sky(x))
         elif self._cal_fld is not None:
-            return self._cal_fld*self._R(self._sky(x))
+            data_model = self._cal_fld*self._R(self._sky(x))
         else:
-            return self._R(self._sky(x))
\ No newline at end of file
+            data_model = self._R(self._sky(x))
+
+        flagged_data_model = data_model[self._mask]
+        return flagged_data_model
+        
-- 
GitLab


From 6a186d64b473347e150d470abc9978ff1cec86a6 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Fri, 14 Feb 2025 00:31:11 +0100
Subject: [PATCH 18/88] Added variable covariance model for classical imaging
 likelihood

---
 resolve/re/likelihood_models.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py
index 0ea1403e..0e27fe6b 100644
--- a/resolve/re/likelihood_models.py
+++ b/resolve/re/likelihood_models.py
@@ -56,3 +56,31 @@ class ImagingLikelihoodFixedCovarianceModel(jft.Model):
         flagged_data_model = data_model[self._mask]
         return flagged_data_model
         
+class ImagingLikelihoodVariableCovarianceModel(jft.Model):
+    def __init__(self, R, sky_operator, log_inverse_covariance_model, mask, calibration_operator=None, calibration_field=None):
+        self._R = R
+        self._sky = sky_operator
+        self._mask = mask
+        self._log_inv_cov = log_inverse_covariance_model
+
+        if calibration_operator is not None:
+            self._cal_op = calibration_operator
+            super().__init__(init=self._R.init | self._log_inv_cov.init | self._cal_op.init)
+        elif calibration_field is not None:
+            self._cal_fld = calibration_field
+            super().__init__(init=self._R.init | self._log_inv_cov.init)
+
+    def __call__(self,x):
+        if self._cal_op is not None:
+            data_model = self._cal_op(x)*self._R(self._sky(x))
+        elif self._cal_fld is not None:
+            data_model = self._cal_fld*self._R(self._sky(x))
+        else:
+            data_model = self._R(self._sky(x))
+
+        flagged_data_model = data_model[self._mask]
+
+        inv_cov = jnp.exp(self._log_inv_cov(x))
+        flagged_inv_cov = inv_cov[self._mask]
+        
+        return (flagged_data_model,flagged_inv_cov)
\ No newline at end of file
-- 
GitLab


From 1bbf5df04eebe4ee23ba3a72b42a9994e27af1dd Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Fri, 14 Feb 2025 00:44:32 +0100
Subject: [PATCH 19/88] Added imaging likelihood for classical resolve imaging

---
 resolve/re/likelihood.py | 34 +++++++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index 11d9dc57..8f4df0e9 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -5,7 +5,7 @@ import jax.numpy as jnp
 from functools import partial
 
 from .response import InterferometryResponse
-from .likelihood_models import CalibrationFixedCovarianceModel, CalibrationVariableCovarianceModel
+from .likelihood_models import CalibrationFixedCovarianceModel, CalibrationVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
 
 def CalibrationLikelihood(
     observation,
@@ -29,3 +29,35 @@ def CalibrationLikelihood(
         lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=True)
     
     return lh.amend(model)
+
+
+def ImagingLikelihood(
+    observation,
+    sky_operator,
+    sky_domain_dict,
+    epsilon,
+    do_wgridding,
+    log_inverse_covariance_operator=None,
+    calibration_operator=None,
+    calibration_field=None,
+    verbosity=0,
+    nthreads=1,
+    backend="ducc0",
+):
+    R = InterferometryResponse(observation,sky_domain_dict,do_wgridding,epsilon,nthreads,verbosity,backend)
+    mask = observation.mask.val
+
+    flagged_data = observation.vis.val[mask]
+
+    if log_inverse_covariance_operator is None:
+        model = ImagingLikelihoodFixedCovarianceModel(R,sky_operator,mask,calibration_operator,calibration_field)
+        flagged_inv_cov = observation.weight.val[mask]
+        
+        lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
+
+    else:
+        model = ImagingLikelihoodVariableCovarianceModel(R,sky_operator,log_inverse_covariance_operator,mask,calibration_operator,calibration_field)
+
+        lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=True)
+
+    return lh.amend(model)
\ No newline at end of file
-- 
GitLab


From ea21fca33a84decd2a44436e304a0416fa33b10e Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Fri, 14 Feb 2025 12:11:03 +0100
Subject: [PATCH 20/88] Added possibility for list of observations and
 operators as inputs of the likelihoods

---
 resolve/re/likelihood.py | 56 +++++++++++++++++++++-------------------
 1 file changed, 30 insertions(+), 26 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index 8f4df0e9..352360eb 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -1,8 +1,4 @@
-import numpy as np
 import nifty8.re as jft
-import jax.scipy as jsc
-import jax.numpy as jnp
-from functools import partial
 
 from .response import InterferometryResponse
 from .likelihood_models import CalibrationFixedCovarianceModel, CalibrationVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
@@ -13,22 +9,26 @@ def CalibrationLikelihood(
     model_visibilities,
     log_inverse_covariance_operator=None
 ):
-    mask = observation.mask.val
+    likelihoods = []
+    for ii, (obs, cop, model_vis, log_inv_cov) in enumerate(zip(observation,calibration_operator,model_visibilities,log_inverse_covariance_operator)):    
+        mask = obs.mask.val
 
-    flagged_data = observation.vis.val[mask]
+        flagged_data = obs.vis.val[mask]
 
-    if log_inverse_covariance_operator is None:
-        model = CalibrationFixedCovarianceModel(calibration_operator,model_visibilities,mask)
-        flagged_inv_cov = observation.weight.val[mask]
+        if log_inv_cov is None:
+            model = CalibrationFixedCovarianceModel(cop,model_vis,mask)
+            flagged_inv_cov = obs.weight.val[mask]
+            
+            lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
         
-        lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
-    
-    else:
-        model = CalibrationVariableCovarianceModel(calibration_operator,model_visibilities,log_inverse_covariance_operator,mask)
+        else:
+            model = CalibrationVariableCovarianceModel(cop,model_vis,log_inv_cov,mask)
 
-        lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=True)
+            lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=True)
+        
+        likelihoods.append(lh.amend(model))
     
-    return lh.amend(model)
+    return sum(likelihoods)
 
 
 def ImagingLikelihood(
@@ -44,20 +44,24 @@ def ImagingLikelihood(
     nthreads=1,
     backend="ducc0",
 ):
-    R = InterferometryResponse(observation,sky_domain_dict,do_wgridding,epsilon,nthreads,verbosity,backend)
-    mask = observation.mask.val
+    likelihoods = []
+    for ii, (obs, cop, cfld, log_inv_cov) in enumerate(zip(observation,calibration_operator,calibration_field,log_inverse_covariance_operator)):    
+        R = InterferometryResponse(obs,sky_domain_dict,do_wgridding,epsilon,nthreads,verbosity,backend)
+        mask = obs.mask.val
 
-    flagged_data = observation.vis.val[mask]
+        flagged_data = obs.vis.val[mask]
 
-    if log_inverse_covariance_operator is None:
-        model = ImagingLikelihoodFixedCovarianceModel(R,sky_operator,mask,calibration_operator,calibration_field)
-        flagged_inv_cov = observation.weight.val[mask]
+        if log_inv_cov is None:
+            model = ImagingLikelihoodFixedCovarianceModel(R,sky_operator,mask,cop,cfld)
+            flagged_inv_cov = obs.weight.val[mask]
         
-        lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
+            lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
 
-    else:
-        model = ImagingLikelihoodVariableCovarianceModel(R,sky_operator,log_inverse_covariance_operator,mask,calibration_operator,calibration_field)
+        else:
+            model = ImagingLikelihoodVariableCovarianceModel(R,sky_operator,log_inv_cov,cfld)
 
-        lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=True)
+            lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=True)
 
-    return lh.amend(model)
\ No newline at end of file
+        likelihoods.append(lh.amend(model))
+    
+    return sum(likelihoods)
\ No newline at end of file
-- 
GitLab


From 309abc1146bd699e3e5aa0929446beb133dd9e1c Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 18 Feb 2025 15:34:47 +0100
Subject: [PATCH 21/88] updated __init__

---
 resolve/re/__init__.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py
index f095f226..60b938b6 100644
--- a/resolve/re/__init__.py
+++ b/resolve/re/__init__.py
@@ -2,4 +2,7 @@
 from .sky_model import sky_model_diffuse, sky_model_points, sky_model
 from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc
 from .radio_response import build_exact_r, build_approximations
-from .optimize import optimize
\ No newline at end of file
+from .optimize import optimize
+from .calibration import CalibrationDistribution, CalibrationInterpolator
+from .likelihood_models import CalibrationLikelihoodFixedCovarianceModel, CalibrationLikelihoodVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
+from .likelihood import CalibrationLikelihood, ImagingLikelihood
\ No newline at end of file
-- 
GitLab


From c50258c7efba6c0ad3ed25a4c89b005b5f83046d Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 19 Feb 2025 17:33:50 +0100
Subject: [PATCH 22/88] Resolved bug

---
 resolve/re/likelihood.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index 352360eb..9e0e522c 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -1,7 +1,7 @@
 import nifty8.re as jft
 
 from .response import InterferometryResponse
-from .likelihood_models import CalibrationFixedCovarianceModel, CalibrationVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
+from .likelihood_models import CalibrationLikelihoodFixedCovarianceModel, CalibrationLikelihoodVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
 
 def CalibrationLikelihood(
     observation,
@@ -16,13 +16,13 @@ def CalibrationLikelihood(
         flagged_data = obs.vis.val[mask]
 
         if log_inv_cov is None:
-            model = CalibrationFixedCovarianceModel(cop,model_vis,mask)
+            model = CalibrationLikelihoodFixedCovarianceModel(cop,model_vis,mask)
             flagged_inv_cov = obs.weight.val[mask]
             
             lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
         
         else:
-            model = CalibrationVariableCovarianceModel(cop,model_vis,log_inv_cov,mask)
+            model = CalibrationLikelihoodVariableCovarianceModel(cop,model_vis,log_inv_cov,mask)
 
             lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=True)
         
-- 
GitLab


From a8081d590e075c53131f0782f8b9d3969af0da32 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 19 Feb 2025 17:34:35 +0100
Subject: [PATCH 23/88] Commented .optimize_out as it is rewritten

---
 resolve/re/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py
index 60b938b6..26f631ea 100644
--- a/resolve/re/__init__.py
+++ b/resolve/re/__init__.py
@@ -2,7 +2,7 @@
 from .sky_model import sky_model_diffuse, sky_model_points, sky_model
 from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc
 from .radio_response import build_exact_r, build_approximations
-from .optimize import optimize
-from .calibration import CalibrationDistribution, CalibrationInterpolator
+#from .optimize import optimize
+from .calibration import CalibrationInterpolator, CalibrationDistribution
 from .likelihood_models import CalibrationLikelihoodFixedCovarianceModel, CalibrationLikelihoodVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
 from .likelihood import CalibrationLikelihood, ImagingLikelihood
\ No newline at end of file
-- 
GitLab


From 28e995bb26dc3a9a04acedc40527e1a468ed5c17 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 24 Feb 2025 17:22:31 +0100
Subject: [PATCH 24/88] Added functionality to construct random observations
 for testing

---
 misc/observation_generator.py | 42 +++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)
 create mode 100644 misc/observation_generator.py

diff --git a/misc/observation_generator.py b/misc/observation_generator.py
new file mode 100644
index 00000000..2a4b5f05
--- /dev/null
+++ b/misc/observation_generator.py
@@ -0,0 +1,42 @@
+import numpy as np
+
+import resolve as rve
+
+def RandomAntennaPositions(n_row,uvw_max,n_antenna_max,time_max):
+    ant1,ant2 = [],[]
+
+    for k in range(n_row):
+        x = np.random.randint(0,n_antenna_max-1)
+        y = np.random.randint(1,n_antenna_max)
+
+        if(x==y):
+            while(x==y):
+                y = np.random.randint(1,n_antenna_max)
+
+        if (x < y):
+            ant1.append(x)
+            ant2.append(y)
+        else:
+            ant1.append(y)
+            ant2.append(x)
+    
+    ant1 = np.array(ant1)
+    ant2 = np.array(ant2)
+    time = np.random.uniform(0,time_max,n_row)
+    uvw = np.random.uniform(-uvw_max,uvw_max,(n_row,3))
+
+    return rve.data.antenna_positions.AntennaPositions(uvw,ant1,ant2,time)
+
+def RandomObservation(n_baselines,pol_indices,freq_channels,uvw_max,n_antenna_max,time_max,abs_vis_max,weight_min,weight_max):
+    antenna_pos = RandomAntennaPositions(n_baselines,uvw_max,n_antenna_max,time_max)
+
+    n_pol = len(pol_indices)
+    n_freq = freq_channels.size
+    vis_shape = (n_pol,n_baselines,n_freq)
+
+    pol = rve.data.polarization.Polarization(pol_indices)
+
+    vis = np.random.uniform(0,abs_vis_max,vis_shape)*np.exp(1.0j*np.random.uniform(0,2*np.pi,vis_shape))
+    weights = np.random.uniform(weight_min,weight_max,vis_shape)
+    
+    return rve.data.observation.Observation(antenna_pos,vis,weights,pol,freq_channels)
\ No newline at end of file
-- 
GitLab


From 6f90cc7aa45235fbc52183860a9ed8ac8f0f4186 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Thu, 27 Feb 2025 17:18:13 +0100
Subject: [PATCH 25/88] Fixed minor bug when calling the
 CalibrationInterpolator in CalibrationDistribution

---
 resolve/re/calibration.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index 79999ebc..0b16c914 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -7,8 +7,8 @@ class CalibrationDistribution(jft.Model):
     def __init__(self,observation, phase_fields, log_amplitude_fields,dt):
         ap = observation.antenna_positions
         target_shape = observation.vis.shape
-        self._cop1 = CalibrationInterpolator(dt, ap.ant1, ap.time, target_shape)
-        self._cop2 = CalibrationInterpolator(dt, ap.ant2, ap.time, target_shape)
+        self._cop1 = CalibrationInterpolator(ap.ant1, ap.time, dt, target_shape)
+        self._cop2 = CalibrationInterpolator(ap.ant2, ap.time, dt, target_shape)
 
         self._phases = phase_fields
         self._logamps = log_amplitude_fields
-- 
GitLab


From af0e0fbd446bb9f65fea80f395903f8e55b5d9e9 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 3 Mar 2025 09:13:37 +0100
Subject: [PATCH 26/88] Added rng generator input and set dtype to int32 of
 antennas

---
 misc/observation_generator.py | 25 ++++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/misc/observation_generator.py b/misc/observation_generator.py
index 2a4b5f05..8b34ec43 100644
--- a/misc/observation_generator.py
+++ b/misc/observation_generator.py
@@ -2,12 +2,12 @@ import numpy as np
 
 import resolve as rve
 
-def RandomAntennaPositions(n_row,uvw_max,n_antenna_max,time_max):
+def RandomAntennaPositions(n_row,uvw_max,n_antenna_max,time_max,rng_generator=np.random.default_rng(42)):
     ant1,ant2 = [],[]
 
     for k in range(n_row):
-        x = np.random.randint(0,n_antenna_max-1)
-        y = np.random.randint(1,n_antenna_max)
+        x = rng_generator.integers(0,n_antenna_max-1)
+        y = rng_generator.integers(1,n_antenna_max)
 
         if(x==y):
             while(x==y):
@@ -20,15 +20,15 @@ def RandomAntennaPositions(n_row,uvw_max,n_antenna_max,time_max):
             ant1.append(y)
             ant2.append(x)
     
-    ant1 = np.array(ant1)
-    ant2 = np.array(ant2)
-    time = np.random.uniform(0,time_max,n_row)
-    uvw = np.random.uniform(-uvw_max,uvw_max,(n_row,3))
+    ant1 = np.array(ant1).astype(np.int32)
+    ant2 = np.array(ant2).astype(np.int32)
+    time = rng_generator.uniform(0,time_max,n_row)
+    uvw = rng_generator.uniform(-uvw_max,uvw_max,(n_row,3))
 
     return rve.data.antenna_positions.AntennaPositions(uvw,ant1,ant2,time)
 
-def RandomObservation(n_baselines,pol_indices,freq_channels,uvw_max,n_antenna_max,time_max,abs_vis_max,weight_min,weight_max):
-    antenna_pos = RandomAntennaPositions(n_baselines,uvw_max,n_antenna_max,time_max)
+def RandomObservation(n_baselines,pol_indices,freq_channels,uvw_max,n_antenna_max,time_max,abs_vis_max,weight_min,weight_max,rng_generator=np.random.default_rng(42)):
+    antenna_pos = RandomAntennaPositions(n_baselines,uvw_max,n_antenna_max,time_max,rng_generator=rng_generator)
 
     n_pol = len(pol_indices)
     n_freq = freq_channels.size
@@ -36,7 +36,10 @@ def RandomObservation(n_baselines,pol_indices,freq_channels,uvw_max,n_antenna_ma
 
     pol = rve.data.polarization.Polarization(pol_indices)
 
-    vis = np.random.uniform(0,abs_vis_max,vis_shape)*np.exp(1.0j*np.random.uniform(0,2*np.pi,vis_shape))
-    weights = np.random.uniform(weight_min,weight_max,vis_shape)
+    vis_magnitude = rng_generator.uniform(0,abs_vis_max,vis_shape)
+    vis_phase = rng_generator.uniform(0,2*np.pi,vis_shape)
+    vis = vis_magnitude*np.exp(1.0j*vis_phase)
+
+    weights = rng_generator.uniform(weight_min,weight_max,vis_shape)
     
     return rve.data.observation.Observation(antenna_pos,vis,weights,pol,freq_channels)
\ No newline at end of file
-- 
GitLab


From ee732dd5f5ae6a024326a5e690b3d15a42bc8674 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 3 Mar 2025 10:11:49 +0100
Subject: [PATCH 27/88] Fixed minor bug in assigning of interpolated points in
 resulting array

---
 resolve/re/calibration.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index 0b16c914..68f3a52c 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -39,6 +39,6 @@ class CalibrationInterpolator(jft.Model):
         for pol in range(n_pol):
             for freq in range(n_freq):
                 val = x[pol, :, :, freq]
-                res[pol, :, freq] = self._li(val)
+                res = res.at[pol, :, freq].set(self._li(val))
         
         return res
\ No newline at end of file
-- 
GitLab


From 7a04f11400ab14afac2e05aea302de7a30007270 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 3 Mar 2025 13:13:10 +0100
Subject: [PATCH 28/88] Uncommented .optimze

---
 resolve/re/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py
index 26f631ea..d6a2a192 100644
--- a/resolve/re/__init__.py
+++ b/resolve/re/__init__.py
@@ -2,7 +2,7 @@
 from .sky_model import sky_model_diffuse, sky_model_points, sky_model
 from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc
 from .radio_response import build_exact_r, build_approximations
-#from .optimize import optimize
+from .optimize import optimize
 from .calibration import CalibrationInterpolator, CalibrationDistribution
 from .likelihood_models import CalibrationLikelihoodFixedCovarianceModel, CalibrationLikelihoodVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
 from .likelihood import CalibrationLikelihood, ImagingLikelihood
\ No newline at end of file
-- 
GitLab


From 873c588c47043f176dd4485eee2f6ea03e145bdf Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 3 Mar 2025 13:29:27 +0100
Subject: [PATCH 29/88] Remove sorting of antenna integers in
 RandomAntennaPositions

---
 misc/observation_generator.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/misc/observation_generator.py b/misc/observation_generator.py
index 8b34ec43..82e294a5 100644
--- a/misc/observation_generator.py
+++ b/misc/observation_generator.py
@@ -12,13 +12,6 @@ def RandomAntennaPositions(n_row,uvw_max,n_antenna_max,time_max,rng_generator=np
         if(x==y):
             while(x==y):
                 y = np.random.randint(1,n_antenna_max)
-
-        if (x < y):
-            ant1.append(x)
-            ant2.append(y)
-        else:
-            ant1.append(y)
-            ant2.append(x)
     
     ant1 = np.array(ant1).astype(np.int32)
     ant2 = np.array(ant2).astype(np.int32)
-- 
GitLab


From 459b6fd4d41049db5af603fa8b3679ec8c08e1c9 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 3 Mar 2025 14:13:59 +0100
Subject: [PATCH 30/88] Remove bug; integers for both antenna get now appended
 to antenna lists

---
 misc/observation_generator.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/misc/observation_generator.py b/misc/observation_generator.py
index 82e294a5..a1d7977a 100644
--- a/misc/observation_generator.py
+++ b/misc/observation_generator.py
@@ -12,7 +12,10 @@ def RandomAntennaPositions(n_row,uvw_max,n_antenna_max,time_max,rng_generator=np
         if(x==y):
             while(x==y):
                 y = np.random.randint(1,n_antenna_max)
-    
+
+            ant1.append(x)
+            ant2.append(y)
+            
     ant1 = np.array(ant1).astype(np.int32)
     ant2 = np.array(ant2).astype(np.int32)
     time = rng_generator.uniform(0,time_max,n_row)
-- 
GitLab


From a3fc91246701249beb954488b7f30466b0dfe9a8 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 3 Mar 2025 14:16:55 +0100
Subject: [PATCH 31/88] Removed indents when appending antenna indices to
 antenna lists

---
 misc/observation_generator.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/misc/observation_generator.py b/misc/observation_generator.py
index a1d7977a..7e7cd25a 100644
--- a/misc/observation_generator.py
+++ b/misc/observation_generator.py
@@ -13,9 +13,9 @@ def RandomAntennaPositions(n_row,uvw_max,n_antenna_max,time_max,rng_generator=np
             while(x==y):
                 y = np.random.randint(1,n_antenna_max)
 
-            ant1.append(x)
-            ant2.append(y)
-            
+        ant1.append(x)
+        ant2.append(y)
+
     ant1 = np.array(ant1).astype(np.int32)
     ant2 = np.array(ant2).astype(np.int32)
     time = rng_generator.uniform(0,time_max,n_row)
-- 
GitLab


From f74316083f05dd38adc1cfd1f50a1c90186df914 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Sun, 9 Mar 2025 16:52:37 +0100
Subject: [PATCH 32/88] remove jft.Model dependence from
 CalibrationInterpolator

---
 resolve/re/calibration.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index 68f3a52c..d2075697 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -21,7 +21,7 @@ class CalibrationDistribution(jft.Model):
 
         return jnp.exp(res_logamp + res_phase)
 
-class CalibrationInterpolator(jft.Model):
+class CalibrationInterpolator():
     def __init__(self,ant_col, time_col,dt,target_shape):
         # Input shape follows (n_pol,n_antenna,n_timesteps,n_freq)
         # Output shape follows (n_pol,n_visibilities,n_freq)
-- 
GitLab


From 72a2b0eaf5e9be229df99fd8c3f51b63b5ad2d32 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Sun, 9 Mar 2025 16:58:56 +0100
Subject: [PATCH 33/88] Added file for convinience models

---
 resolve/re/sugar.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 resolve/re/sugar.py

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
new file mode 100644
index 00000000..e69de29b
-- 
GitLab


From 7ecbcb4dbe4cece8db7c12b5ed5c091c26b93819 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Sun, 9 Mar 2025 17:00:15 +0100
Subject: [PATCH 34/88] Added model bulk correlated fields similar to nifty
 correlated field with N_total > 1

---
 resolve/re/sugar.py | 64 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index e69de29b..26d171e0 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -0,0 +1,64 @@
+import jax.numpy as jnp
+
+import nifty8.re as jft
+
+class Bulk_CF(jft.Model):
+    def __init__(self,dct_offset,dct_ps,name,polarizations,frequencies,antennas):
+
+        self._fields = {}
+        self._powerspectra = {}
+        bulk_cf_init = None
+
+        for p in polarizations:
+            self._fields[p] = {}
+            self._powerspectra[p] = {}
+
+            for a in antennas:
+                self._fields[p][a] = {}
+                self._powerspectra[p][a] = {}
+
+                for f in frequencies:
+                    cfm = jft.CorrelatedFieldMaker(f"{name}_{p}_ant{a}_freq{f}_")
+                    cfm.set_amplitude_total_offset(**dct_offset)
+                    cfm.add_fluctuations(**dct_ps)
+
+                    self._fields[p][a][f] = cfm.finalize()
+                    self._powerspectra[p][a][f] = cfm.power_spectrum
+
+                    if bulk_cf_init is not None:
+                        bulk_cf_init = bulk_cf_init | self._fields[p][a][f].init
+                    else:
+                        bulk_cf_init = self._fields[p][a][f].init
+
+        super().__init__(init=bulk_cf_init)
+
+    def __call__(self,x):
+        return jnp.swapaxes(jnp.array([[[self._fields[p][a][f](x) for f in self._fields[p][a].keys()]for a in self._fields[p].keys()]for p in self._fields.keys()]),2,3)
+    
+    def fields_to_dict(self,x):
+        field_dct = {}
+
+        for p in self._fields.keys():
+            field_dct[p] = {}
+
+            for a in self._fields[p].keys():
+                field_dct[p][a] = {}
+
+                for f in self._fields[p][a].keys():
+                    field_dct[p][a][f] = self._fields[p][a][f](x)
+        
+        return field_dct
+    
+    def powerspectra_to_dict(self,x):
+        powerspectra_dct = {}
+
+        for p in self._fields.keys():
+            powerspectra_dct[p] = {}
+
+            for a in self._fields[p].keys():
+                powerspectra_dct[p][a] = {}
+
+                for f in self._fields[p][a].keys():
+                    powerspectra_dct[p][a][f] = self._powerspectra[p][a][f](x)
+        
+        return powerspectra_dct
\ No newline at end of file
-- 
GitLab


From 176df999f82b4f7a53f74779c517a33a711d3f9c Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Sun, 9 Mar 2025 17:02:11 +0100
Subject: [PATCH 35/88] Update __init__.py to include sugar.py

---
 resolve/re/__init__.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py
index d6a2a192..9b9a6d6a 100644
--- a/resolve/re/__init__.py
+++ b/resolve/re/__init__.py
@@ -5,4 +5,5 @@ from .radio_response import build_exact_r, build_approximations
 from .optimize import optimize
 from .calibration import CalibrationInterpolator, CalibrationDistribution
 from .likelihood_models import CalibrationLikelihoodFixedCovarianceModel, CalibrationLikelihoodVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
-from .likelihood import CalibrationLikelihood, ImagingLikelihood
\ No newline at end of file
+from .likelihood import CalibrationLikelihood, ImagingLikelihood
+from .sugar import Bulk_CF
\ No newline at end of file
-- 
GitLab


From cb293d90aa738179d2a86b1603355a266b149fab Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 10 Mar 2025 01:00:26 +0100
Subject: [PATCH 36/88] Removed bugs for CalibrationLikelihood

---
 resolve/re/likelihood.py | 32 +++++++++++++++++++++++++-------
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index 9e0e522c..8795cf20 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -3,21 +3,33 @@ import nifty8.re as jft
 from .response import InterferometryResponse
 from .likelihood_models import CalibrationLikelihoodFixedCovarianceModel, CalibrationLikelihoodVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
 
+from ..util import _obj2list
+
+from ..data.observation import Observation
+from .calibration import CalibrationDistribution
+from jaxlib.xla_extension import ArrayImpl
+from nifty8.re.model import Model
+
 def CalibrationLikelihood(
     observation,
     calibration_operator,
     model_visibilities,
     log_inverse_covariance_operator=None
 ):
-    likelihoods = []
-    for ii, (obs, cop, model_vis, log_inv_cov) in enumerate(zip(observation,calibration_operator,model_visibilities,log_inverse_covariance_operator)):    
-        mask = obs.mask.val
+    obs = _obj2list(observation,Observation)
+    cops = _obj2list(calibration_operator,CalibrationDistribution)
+    model_d = _obj2list(model_visibilities,ArrayImpl)
+    log_inv_covs = _obj2list(log_inverse_covariance_operator,Model)
 
-        flagged_data = obs.vis.val[mask]
+    lh_sum = None
+    for oo, cop, model_vis, log_inv_cov in zip(obs,cops,model_d,log_inv_covs):    
+        mask = oo.mask.val
+
+        flagged_data = oo.vis.val[mask]
 
         if log_inv_cov is None:
             model = CalibrationLikelihoodFixedCovarianceModel(cop,model_vis,mask)
-            flagged_inv_cov = obs.weight.val[mask]
+            flagged_inv_cov = oo.weight.val[mask]
             
             lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
         
@@ -26,9 +38,15 @@ def CalibrationLikelihood(
 
             lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=True)
         
-        likelihoods.append(lh.amend(model))
+        lh_with_model = lh.amend(model)
+        lh_with_model._domain = jft.Vector(lh_with_model._domain)
+
+        if lh_sum is not None:
+            lh_sum += lh_with_model
+        else:
+            lh_sum = lh_with_model
     
-    return sum(likelihoods)
+    return lh_sum
 
 
 def ImagingLikelihood(
-- 
GitLab


From 38b271fcb1e1a3ae2f3ee53063a9287bbe1990a2 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 10 Mar 2025 09:50:30 +0100
Subject: [PATCH 37/88] Added Bulk Correlated Field model on Visibility Space

---
 resolve/re/sugar.py | 53 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 52 insertions(+), 1 deletion(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index 26d171e0..8344d504 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -2,7 +2,7 @@ import jax.numpy as jnp
 
 import nifty8.re as jft
 
-class Bulk_CF(jft.Model):
+class Bulk_CF_AntennaTimeDomain(jft.Model):
     def __init__(self,dct_offset,dct_ps,name,polarizations,frequencies,antennas):
 
         self._fields = {}
@@ -61,4 +61,55 @@ class Bulk_CF(jft.Model):
                 for f in self._fields[p][a].keys():
                     powerspectra_dct[p][a][f] = self._powerspectra[p][a][f](x)
         
+        return powerspectra_dct
+    
+class Bulk_CF_VisibilityDomain(jft.Model):
+    def __init__(self,dct_offset,dct_ps,name,polarizations,frequencies):
+
+        self._fields = {}
+        self._powerspectra = {}
+        bulk_cf_init = None
+
+        for p in polarizations:
+            self._fields[p] = {}
+            self._powerspectra[p] = {}
+
+            for f in frequencies:
+                cfm = jft.CorrelatedFieldMaker(f"{name}_{p}_freq{f}_")
+                cfm.set_amplitude_total_offset(**dct_offset)
+                cfm.add_fluctuations(**dct_ps)
+
+                self._fields[p][f] = cfm.finalize()
+                self._powerspectra[p][f] = cfm.power_spectrum
+
+                if bulk_cf_init is not None:
+                    bulk_cf_init = bulk_cf_init | self._fields[p][f].init
+                else:
+                    bulk_cf_init = self._fields[p][f].init
+
+        super().__init__(init=bulk_cf_init)
+
+    def __call__(self,x):
+        return jnp.swapaxes(jnp.array([[self._fields[p][f](x) for f in self._fields[p].keys()]for p in self._fields.keys()]),1,2)
+    
+    def fields_to_dict(self,x):
+        field_dct = {}
+
+        for p in self._fields.keys():
+            field_dct[p] = {}
+
+            for f in self._fields[p].keys():
+                field_dct[p][f] = self._fields[p][f](x)
+        
+        return field_dct
+    
+    def powerspectra_to_dict(self,x):
+        powerspectra_dct = {}
+
+        for p in self._fields.keys():
+            powerspectra_dct[p] = {}
+
+            for f in self._fields[p].keys():
+                powerspectra_dct[p][f] = self._powerspectra[p][f](x)
+        
         return powerspectra_dct
\ No newline at end of file
-- 
GitLab


From 91eaa8f5f7cac8cc67ed0779d032559616de7049 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 10 Mar 2025 10:01:35 +0100
Subject: [PATCH 38/88] Updated __init__.py

---
 resolve/re/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py
index 9b9a6d6a..2107961a 100644
--- a/resolve/re/__init__.py
+++ b/resolve/re/__init__.py
@@ -6,4 +6,4 @@ from .optimize import optimize
 from .calibration import CalibrationInterpolator, CalibrationDistribution
 from .likelihood_models import CalibrationLikelihoodFixedCovarianceModel, CalibrationLikelihoodVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
 from .likelihood import CalibrationLikelihood, ImagingLikelihood
-from .sugar import Bulk_CF
\ No newline at end of file
+from .sugar import Bulk_CF_AntennaTimeDomain, Bulk_CF_VisibilityDomain
\ No newline at end of file
-- 
GitLab


From cf6aa86218770409335f74fcd95991c42ded7cda Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 10 Mar 2025 10:24:56 +0100
Subject: [PATCH 39/88] Removed Bulk_CF_VisibilityDomain

---
 resolve/re/__init__.py |  2 +-
 resolve/re/sugar.py    | 51 ------------------------------------------
 2 files changed, 1 insertion(+), 52 deletions(-)

diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py
index 2107961a..9cb6e714 100644
--- a/resolve/re/__init__.py
+++ b/resolve/re/__init__.py
@@ -6,4 +6,4 @@ from .optimize import optimize
 from .calibration import CalibrationInterpolator, CalibrationDistribution
 from .likelihood_models import CalibrationLikelihoodFixedCovarianceModel, CalibrationLikelihoodVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
 from .likelihood import CalibrationLikelihood, ImagingLikelihood
-from .sugar import Bulk_CF_AntennaTimeDomain, Bulk_CF_VisibilityDomain
\ No newline at end of file
+from .sugar import Bulk_CF_AntennaTimeDomain
\ No newline at end of file
diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index 8344d504..950c9ec9 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -61,55 +61,4 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
                 for f in self._fields[p][a].keys():
                     powerspectra_dct[p][a][f] = self._powerspectra[p][a][f](x)
         
-        return powerspectra_dct
-    
-class Bulk_CF_VisibilityDomain(jft.Model):
-    def __init__(self,dct_offset,dct_ps,name,polarizations,frequencies):
-
-        self._fields = {}
-        self._powerspectra = {}
-        bulk_cf_init = None
-
-        for p in polarizations:
-            self._fields[p] = {}
-            self._powerspectra[p] = {}
-
-            for f in frequencies:
-                cfm = jft.CorrelatedFieldMaker(f"{name}_{p}_freq{f}_")
-                cfm.set_amplitude_total_offset(**dct_offset)
-                cfm.add_fluctuations(**dct_ps)
-
-                self._fields[p][f] = cfm.finalize()
-                self._powerspectra[p][f] = cfm.power_spectrum
-
-                if bulk_cf_init is not None:
-                    bulk_cf_init = bulk_cf_init | self._fields[p][f].init
-                else:
-                    bulk_cf_init = self._fields[p][f].init
-
-        super().__init__(init=bulk_cf_init)
-
-    def __call__(self,x):
-        return jnp.swapaxes(jnp.array([[self._fields[p][f](x) for f in self._fields[p].keys()]for p in self._fields.keys()]),1,2)
-    
-    def fields_to_dict(self,x):
-        field_dct = {}
-
-        for p in self._fields.keys():
-            field_dct[p] = {}
-
-            for f in self._fields[p].keys():
-                field_dct[p][f] = self._fields[p][f](x)
-        
-        return field_dct
-    
-    def powerspectra_to_dict(self,x):
-        powerspectra_dct = {}
-
-        for p in self._fields.keys():
-            powerspectra_dct[p] = {}
-
-            for f in self._fields[p].keys():
-                powerspectra_dct[p][f] = self._powerspectra[p][f](x)
-        
         return powerspectra_dct
\ No newline at end of file
-- 
GitLab


From 029e8b2d3a056d6eaed8df3c7ddcc04fac53fe2b Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 10 Mar 2025 11:26:51 +0100
Subject: [PATCH 40/88] Removed bugs from ImagingLikelihood

---
 resolve/re/likelihood.py | 41 +++++++++++++++++++++++++++++++---------
 1 file changed, 32 insertions(+), 9 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index 8795cf20..64c8eeb1 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -21,6 +21,9 @@ def CalibrationLikelihood(
     model_d = _obj2list(model_visibilities,ArrayImpl)
     log_inv_covs = _obj2list(log_inverse_covariance_operator,Model)
 
+    if len(set([len(obs),len(cops),len(model_d),len(log_inv_covs)])) != 1:
+        raise ValueError("observation, calibration_operator, model_visibilities and log_inverse_covariance_operator must have the same number of elements")
+
     lh_sum = None
     for oo, cop, model_vis, log_inv_cov in zip(obs,cops,model_d,log_inv_covs):    
         mask = oo.mask.val
@@ -62,24 +65,44 @@ def ImagingLikelihood(
     nthreads=1,
     backend="ducc0",
 ):
-    likelihoods = []
-    for ii, (obs, cop, cfld, log_inv_cov) in enumerate(zip(observation,calibration_operator,calibration_field,log_inverse_covariance_operator)):    
-        R = InterferometryResponse(obs,sky_domain_dict,do_wgridding,epsilon,nthreads,verbosity,backend)
-        mask = obs.mask.val
+    obs = _obj2list(observation,Observation)
+    cops = _obj2list(calibration_operator,CalibrationDistribution)
+    cflds = _obj2list(calibration_field,ArrayImpl)
+    log_inv_covs = _obj2list(log_inverse_covariance_operator,Model)
+
+    if len(set([len(obs),len(cops),len(cflds),len(log_inv_covs)])) != 1:
+        raise ValueError("observation, log_inverse_covariance_operator, calibration_operator and calibration_field must have the same number of elements")
 
-        flagged_data = obs.vis.val[mask]
+    lh_sum = None
+    
+    for ii, (oo, cop, cfld, log_inv_cov) in enumerate(zip(obs,cops,cflds,log_inv_covs)):
+        if cfld is not None and cop is not None:
+            raise ValueError(
+                f"Can't set calibration operator and calibration field simultaneously at index {ii}"
+            )
+
+        R = InterferometryResponse(oo,sky_domain_dict,do_wgridding,epsilon,nthreads,verbosity,backend)
+        mask = oo.mask.val
+
+        flagged_data = oo.vis.val[mask]
 
         if log_inv_cov is None:
             model = ImagingLikelihoodFixedCovarianceModel(R,sky_operator,mask,cop,cfld)
-            flagged_inv_cov = obs.weight.val[mask]
+            flagged_inv_cov = oo.weight.val[mask]
         
             lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
 
         else:
-            model = ImagingLikelihoodVariableCovarianceModel(R,sky_operator,log_inv_cov,cfld)
+            model = ImagingLikelihoodVariableCovarianceModel(R,sky_operator,log_inv_cov,mask,cop,cfld)
 
             lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=True)
 
-        likelihoods.append(lh.amend(model))
+        lh_with_model = lh.amend(model)
+        lh_with_model._domain = jft.Vector(lh_with_model._domain)
+
+        if lh_sum is not None:
+            lh_sum += lh_with_model
+        else:
+            lh_sum = lh_with_model
     
-    return sum(likelihoods)
\ No newline at end of file
+    return lh_sum
\ No newline at end of file
-- 
GitLab


From 68bc23cbc0d695c14ec0d8af0d27c44cdd73d69f Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 12 Mar 2025 10:28:09 +0100
Subject: [PATCH 41/88] Changed inverse covariance to sqrt of inverse
 covariance as nifty.re expects it

---
 resolve/re/likelihood_models.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py
index 0e27fe6b..3a074507 100644
--- a/resolve/re/likelihood_models.py
+++ b/resolve/re/likelihood_models.py
@@ -27,10 +27,10 @@ class CalibrationLikelihoodVariableCovarianceModel(jft.Model):
         data_model = self._vis*self._cop(x)
         flagged_data_model = data_model[self._mask]
 
-        inv_cov = jnp.exp(self._log_inv_cov(x))
-        flagged_inv_cov = inv_cov[self._mask]
+        inv_std = jnp.exp(0.5*self._log_inv_cov(x))
+        flagged_inv_std = inv_std[self._mask]
         
-        return (flagged_data_model,flagged_inv_cov)
+        return (flagged_data_model,flagged_inv_std)
 
 class ImagingLikelihoodFixedCovarianceModel(jft.Model):
     def __init__(self, R, sky_operator, mask, calibration_operator=None, calibration_field=None):
@@ -80,7 +80,7 @@ class ImagingLikelihoodVariableCovarianceModel(jft.Model):
 
         flagged_data_model = data_model[self._mask]
 
-        inv_cov = jnp.exp(self._log_inv_cov(x))
-        flagged_inv_cov = inv_cov[self._mask]
+        inv_std = jnp.exp(0.5*self._log_inv_cov(x))
+        flagged_inv_std = inv_std[self._mask]
         
-        return (flagged_data_model,flagged_inv_cov)
\ No newline at end of file
+        return (flagged_data_model,flagged_inv_std)
\ No newline at end of file
-- 
GitLab


From 79dbcff51e69b8286340002757876a317cdd7097 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 12 Mar 2025 10:30:24 +0100
Subject: [PATCH 42/88] Rewrote summation of likelihood energies for multiple
 inputs; Added dtype check for complex inputs in Variable Covariance Gaussian

---
 resolve/re/likelihood.py | 26 ++++++++++++--------------
 1 file changed, 12 insertions(+), 14 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index 64c8eeb1..42f6bb85 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -1,5 +1,9 @@
+import jax.numpy as jnp
 import nifty8.re as jft
 
+from functools import reduce
+from operator import add
+
 from .response import InterferometryResponse
 from .likelihood_models import CalibrationLikelihoodFixedCovarianceModel, CalibrationLikelihoodVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
 
@@ -24,7 +28,7 @@ def CalibrationLikelihood(
     if len(set([len(obs),len(cops),len(model_d),len(log_inv_covs)])) != 1:
         raise ValueError("observation, calibration_operator, model_visibilities and log_inverse_covariance_operator must have the same number of elements")
 
-    lh_sum = None
+    likelihoods = []
     for oo, cop, model_vis, log_inv_cov in zip(obs,cops,model_d,log_inv_covs):    
         mask = oo.mask.val
 
@@ -39,17 +43,14 @@ def CalibrationLikelihood(
         else:
             model = CalibrationLikelihoodVariableCovarianceModel(cop,model_vis,log_inv_cov,mask)
 
-            lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=True)
+            lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=jnp.iscomplexobj(oo.vis.val))
         
         lh_with_model = lh.amend(model)
         lh_with_model._domain = jft.Vector(lh_with_model._domain)
 
-        if lh_sum is not None:
-            lh_sum += lh_with_model
-        else:
-            lh_sum = lh_with_model
+        likelihoods.append(lh_with_model)
     
-    return lh_sum
+    return reduce(add,likelihoods)
 
 
 def ImagingLikelihood(
@@ -73,7 +74,7 @@ def ImagingLikelihood(
     if len(set([len(obs),len(cops),len(cflds),len(log_inv_covs)])) != 1:
         raise ValueError("observation, log_inverse_covariance_operator, calibration_operator and calibration_field must have the same number of elements")
 
-    lh_sum = None
+    likelihoods = []
     
     for ii, (oo, cop, cfld, log_inv_cov) in enumerate(zip(obs,cops,cflds,log_inv_covs)):
         if cfld is not None and cop is not None:
@@ -95,14 +96,11 @@ def ImagingLikelihood(
         else:
             model = ImagingLikelihoodVariableCovarianceModel(R,sky_operator,log_inv_cov,mask,cop,cfld)
 
-            lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=True)
+            lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=jnp.iscomplexobj(oo.vis.val))
 
         lh_with_model = lh.amend(model)
         lh_with_model._domain = jft.Vector(lh_with_model._domain)
 
-        if lh_sum is not None:
-            lh_sum += lh_with_model
-        else:
-            lh_sum = lh_with_model
+        likelihoods.append(lh_with_model)
     
-    return lh_sum
\ No newline at end of file
+    return reduce(add,likelihoods)
\ No newline at end of file
-- 
GitLab


From e36fe3da10b3f2175d9b5f4805c51d6111d8dcf6 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 12 Mar 2025 21:33:22 +0100
Subject: [PATCH 43/88] Removed bug in initalizer assignment

---
 resolve/re/likelihood_models.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py
index 3a074507..3a9cd96d 100644
--- a/resolve/re/likelihood_models.py
+++ b/resolve/re/likelihood_models.py
@@ -37,14 +37,18 @@ class ImagingLikelihoodFixedCovarianceModel(jft.Model):
         self._R = R
         self._sky = sky_operator
         self._mask = mask
+        
+        if calibration_field is not None:
+            self._cal_fld = calibration_field
 
         if calibration_operator is not None:
             self._cal_op = calibration_operator
             super().__init__(init=self._R.init | self._cal_op.init)
-        elif calibration_field is not None:
-            self._cal_fld = calibration_field
+        else:
             super().__init__(init=self._R.init)
 
+        
+
     def __call__(self,x):
         if self._cal_op is not None:
             data_model = self._cal_op(x)*self._R(self._sky(x))
@@ -63,11 +67,13 @@ class ImagingLikelihoodVariableCovarianceModel(jft.Model):
         self._mask = mask
         self._log_inv_cov = log_inverse_covariance_model
 
+        if calibration_field is not None:
+            self._cal_fld = calibration_field
+
         if calibration_operator is not None:
             self._cal_op = calibration_operator
             super().__init__(init=self._R.init | self._log_inv_cov.init | self._cal_op.init)
-        elif calibration_field is not None:
-            self._cal_fld = calibration_field
+        else:
             super().__init__(init=self._R.init | self._log_inv_cov.init)
 
     def __call__(self,x):
-- 
GitLab


From 42bf08971f3100790666e99d83f6ec5ca6772e5a Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 12 Mar 2025 21:38:36 +0100
Subject: [PATCH 44/88] Corrected false assigment of init method

---
 resolve/re/likelihood_models.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py
index 3a9cd96d..b05401e8 100644
--- a/resolve/re/likelihood_models.py
+++ b/resolve/re/likelihood_models.py
@@ -43,9 +43,9 @@ class ImagingLikelihoodFixedCovarianceModel(jft.Model):
 
         if calibration_operator is not None:
             self._cal_op = calibration_operator
-            super().__init__(init=self._R.init | self._cal_op.init)
+            super().__init__(init=self._sky.init | self._cal_op.init)
         else:
-            super().__init__(init=self._R.init)
+            super().__init__(init=self._sky.init)
 
         
 
@@ -72,9 +72,9 @@ class ImagingLikelihoodVariableCovarianceModel(jft.Model):
 
         if calibration_operator is not None:
             self._cal_op = calibration_operator
-            super().__init__(init=self._R.init | self._log_inv_cov.init | self._cal_op.init)
+            super().__init__(init=self._sky.init | self._log_inv_cov.init | self._cal_op.init)
         else:
-            super().__init__(init=self._R.init | self._log_inv_cov.init)
+            super().__init__(init=self._sky.init | self._log_inv_cov.init)
 
     def __call__(self,x):
         if self._cal_op is not None:
-- 
GitLab


From 39751d087eba774d23611ecd383a1a3fdc3d1c4f Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Thu, 13 Mar 2025 00:06:33 +0100
Subject: [PATCH 45/88] Cleaned up crowded assignment cases in both
 ImagingLikelihood models

---
 resolve/re/likelihood_models.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py
index b05401e8..11155474 100644
--- a/resolve/re/likelihood_models.py
+++ b/resolve/re/likelihood_models.py
@@ -34,15 +34,17 @@ class CalibrationLikelihoodVariableCovarianceModel(jft.Model):
 
 class ImagingLikelihoodFixedCovarianceModel(jft.Model):
     def __init__(self, R, sky_operator, mask, calibration_operator=None, calibration_field=None):
+        if (calibration_operator is not None) and (calibration_field is not None):
+            raise ValueError("You can either set a calibration operator or a calibration field")
+
         self._R = R
         self._sky = sky_operator
         self._mask = mask
-        
-        if calibration_field is not None:
-            self._cal_fld = calibration_field
 
-        if calibration_operator is not None:
-            self._cal_op = calibration_operator
+        self._cal_op = calibration_operator
+        self._cal_fld = calibration_field
+
+        if self._cal_op is not None:
             super().__init__(init=self._sky.init | self._cal_op.init)
         else:
             super().__init__(init=self._sky.init)
@@ -62,16 +64,18 @@ class ImagingLikelihoodFixedCovarianceModel(jft.Model):
         
 class ImagingLikelihoodVariableCovarianceModel(jft.Model):
     def __init__(self, R, sky_operator, log_inverse_covariance_model, mask, calibration_operator=None, calibration_field=None):
+        if (calibration_operator is not None) and (calibration_field is not None):
+            raise ValueError("You can either set a calibration operator or a calibration field")
+        
         self._R = R
         self._sky = sky_operator
         self._mask = mask
         self._log_inv_cov = log_inverse_covariance_model
+        
+        self._cal_op = calibration_operator
+        self._cal_fld = calibration_field
 
-        if calibration_field is not None:
-            self._cal_fld = calibration_field
-
-        if calibration_operator is not None:
-            self._cal_op = calibration_operator
+        if self._cal_op is not None:
             super().__init__(init=self._sky.init | self._log_inv_cov.init | self._cal_op.init)
         else:
             super().__init__(init=self._sky.init | self._log_inv_cov.init)
-- 
GitLab


From a802862d3e31b5bd31ccc4643b8eec1ca5a5652e Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Thu, 13 Mar 2025 14:36:22 +0100
Subject: [PATCH 46/88] Added docstrings and input typing

---
 resolve/re/calibration.py | 61 +++++++++++++++++++++++++++++++++++++--
 1 file changed, 59 insertions(+), 2 deletions(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index d2075697..b8952119 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -1,10 +1,41 @@
 import nifty8.re as jft
 import jax.scipy as jsc
 import jax.numpy as jnp
+
+from ..data.observation import Observation
+from sugar import Bulk_CF_AntennaTimeDomain
+from numpy import ndarray
+
 from functools import partial
 
 class CalibrationDistribution(jft.Model):
-    def __init__(self,observation, phase_fields, log_amplitude_fields,dt):
+    """
+    Computes the calibration operator from given observation data.
+
+    Parameters
+    ----------
+    observation: Observation
+        Observation object from which are the antenna and temporal information corresponding to 
+        the visibilites are extracted.
+    phase_fields: Bulk_CF_AntennaTimeDomain
+        Correlated fields on antenna-time space for phases of calibration solutions.
+    log_amplitude_fields: Bulk_CF_AntennaTimeDomain
+        Correlated fields on antenna-time space for log amplitude of calibration solutions.
+    dt: float
+        Distances between time points on time axis. Has to be the same distance of time points,
+        which is used for phase_fields and log_amplitude fields.
+
+    Note
+    ----
+    Currently, only uniformly spaced time axis are supported.
+    """
+    def __init__(
+            self,
+            observation: Observation, 
+            phase_fields: Bulk_CF_AntennaTimeDomain, 
+            log_amplitude_fields: Bulk_CF_AntennaTimeDomain,
+            dt: float
+            ):
         ap = observation.antenna_positions
         target_shape = observation.vis.shape
         self._cop1 = CalibrationInterpolator(ap.ant1, ap.time, dt, target_shape)
@@ -22,7 +53,33 @@ class CalibrationDistribution(jft.Model):
         return jnp.exp(res_logamp + res_phase)
 
 class CalibrationInterpolator():
-    def __init__(self,ant_col, time_col,dt,target_shape):
+    """
+    Interpolates visibilites for a specific sequence of antenna-time pairs given the visibilities
+    on an evenly spaced antenna-time grid.
+
+    Parameters
+    ----------
+    ant_col: ndarray
+        Antenna points to which one wants to interpolate
+    time_col: ndarry
+        Time points to which one wants to interpolate
+    dt: float
+        Distances between time points on time axis.
+    target_shape: tuple
+        Shape of output when calling this class. First element should encode number of 
+        polarization directions and last element should encode number of frequencies
+    
+    Note
+    ----
+    Currently, only uniformly spaced time axis are supported.
+    """
+    def __init__(
+            self,
+            ant_col: ndarray, 
+            time_col: ndarray,
+            dt: float,
+            target_shape: tuple
+            ):
         # Input shape follows (n_pol,n_antenna,n_timesteps,n_freq)
         # Output shape follows (n_pol,n_visibilities,n_freq)
         # The model assumes that you have grid with constant width in antenna and times
-- 
GitLab


From 6af245a149aba6d8102ce01b576e6c039def7bf2 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Thu, 13 Mar 2025 14:56:46 +0100
Subject: [PATCH 47/88] Added docstrings and input typings

---
 resolve/re/sugar.py | 41 ++++++++++++++++++++++++++++++++++++++---
 1 file changed, 38 insertions(+), 3 deletions(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index 950c9ec9..c654829e 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -1,9 +1,38 @@
 import jax.numpy as jnp
-
 import nifty8.re as jft
 
+from typing import Iterable
+
 class Bulk_CF_AntennaTimeDomain(jft.Model):
-    def __init__(self,dct_offset,dct_ps,name,polarizations,frequencies,antennas):
+    """
+    Creates multiple independant correlated fields from the same offset and powerspectra 
+    parameters. Number of correlated fields is the product of the number of polarization
+    directions, antennas and frequencies.
+
+    Parameters
+    ----------
+    dct_offset: dictionary
+        Dictionary containing information about the offset parameters.
+    dct_ps: dictionary
+        Dictionary containing information about the temporal powerspectrum parameters
+    prefix: string
+        Prefix to the names of the parameters of a correlated field.
+    polarizations: Iterable
+        Labels for polarization directions
+    frequencies: Iterable
+        Labels for frequencies:
+    antennas: Iterable
+        Labels for antennas
+    """
+    def __init__(
+            self,
+            dct_offset: dict,
+            dct_ps: dict,
+            prefix: str,
+            polarizations: Iterable,
+            frequencies: Iterable,
+            antennas: Iterable
+            ):
 
         self._fields = {}
         self._powerspectra = {}
@@ -18,7 +47,7 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
                 self._powerspectra[p][a] = {}
 
                 for f in frequencies:
-                    cfm = jft.CorrelatedFieldMaker(f"{name}_{p}_ant{a}_freq{f}_")
+                    cfm = jft.CorrelatedFieldMaker(f"{prefix}_{p}_ant{a}_freq{f}_")
                     cfm.set_amplitude_total_offset(**dct_offset)
                     cfm.add_fluctuations(**dct_ps)
 
@@ -36,6 +65,9 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
         return jnp.swapaxes(jnp.array([[[self._fields[p][a][f](x) for f in self._fields[p][a].keys()]for a in self._fields[p].keys()]for p in self._fields.keys()]),2,3)
     
     def fields_to_dict(self,x):
+        """
+        Gives dictionary of evaluated fields indexed by polarization, antenna and frequency labels
+        """
         field_dct = {}
 
         for p in self._fields.keys():
@@ -50,6 +82,9 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
         return field_dct
     
     def powerspectra_to_dict(self,x):
+        """
+        Gives dictionary of evaluated powerspectra indexed by polarization, antenna and frequency labels
+        """
         powerspectra_dct = {}
 
         for p in self._fields.keys():
-- 
GitLab


From 5d6b7c0828504e35e0894cd0dc0fe0546d52573b Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Thu, 13 Mar 2025 17:17:07 +0100
Subject: [PATCH 48/88] Updated docstring and typing of sugar.py

---
 resolve/re/sugar.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index c654829e..2108c9bb 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -1,7 +1,7 @@
 import jax.numpy as jnp
 import nifty8.re as jft
 
-from typing import Iterable
+from typing import Union, Iterable
 
 class Bulk_CF_AntennaTimeDomain(jft.Model):
     """
@@ -17,11 +17,11 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
         Dictionary containing information about the temporal powerspectrum parameters
     prefix: string
         Prefix to the names of the parameters of a correlated field.
-    polarizations: Iterable
+    polarizations: Iterable of int or str
         Labels for polarization directions
-    frequencies: Iterable
+    frequencies: Iterable of int or str
         Labels for frequencies:
-    antennas: Iterable
+    antennas: Iterable of int or str
         Labels for antennas
     """
     def __init__(
@@ -29,9 +29,9 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
             dct_offset: dict,
             dct_ps: dict,
             prefix: str,
-            polarizations: Iterable,
-            frequencies: Iterable,
-            antennas: Iterable
+            polarizations: Iterable[Union[int, str]],
+            frequencies: Iterable[Union[int, str]],
+            antennas: Iterable[Union[int, str]]
             ):
 
         self._fields = {}
-- 
GitLab


From 16fb7478fb84e61e492caed63b94386819b0ea6a Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Thu, 13 Mar 2025 17:36:18 +0100
Subject: [PATCH 49/88] Added docstring and typing to likelihood.py

---
 resolve/re/likelihood.py | 118 +++++++++++++++++++++++++++++++++------
 1 file changed, 101 insertions(+), 17 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index 42f6bb85..d819638c 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -3,23 +3,53 @@ import nifty8.re as jft
 
 from functools import reduce
 from operator import add
+from typing import Union, Iterable
 
 from .response import InterferometryResponse
 from .likelihood_models import CalibrationLikelihoodFixedCovarianceModel, CalibrationLikelihoodVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
+from .calibration import CalibrationDistribution
 
 from ..util import _obj2list
-
 from ..data.observation import Observation
-from .calibration import CalibrationDistribution
+
 from jaxlib.xla_extension import ArrayImpl
 from nifty8.re.model import Model
 
 def CalibrationLikelihood(
-    observation,
-    calibration_operator,
-    model_visibilities,
-    log_inverse_covariance_operator=None
+    observation: Union[Observation, Iterable[Observation]],
+    calibration_operator: Union[CalibrationDistribution, Iterable[CalibrationDistribution]],
+    model_visibilities: Union[jnp.ndarray, Iterable[jnp.ndarray]],
+    log_inverse_covariance_operator: Union[jft.Model,Iterable[jft.Model]] = None
 ):
+    """Versatile calibration likelihood class
+
+    It returns an operator that computes:
+
+    residual = calibration_operator * model_visibilities
+    likelihood = 0.5 * residual^dagger @ inverse_covariance @ residual
+
+    If an inverse_covariance_operator is passed, it is inserted into the above
+    formulae. If it is not passed, 1/observation.weights is used as inverse
+    covariance.
+
+    Parameters
+    ----------
+    observation : Observation or Iterable of Observations
+        Observation object from which observation.vis and potentially
+        observation.weight is used for computing the likelihood.
+
+    calibration_operator : CalibrationDistribution or Iterable of CalibrationDistribution
+        Target needs to be the same as observation.vis.
+
+    model_visibilities jnp.ndarray or Iterable of jnp.ndarray
+        Known model visiblities that are used for calibration. Needs to be
+        defined on the same domain as `observation.vis`.
+
+    log_inverse_covariance_operator : jft.Model or Iterable of jft.Model
+        Optional. Target needs to be the same space as observation.vis. If it is
+        not specified, observation.wgt is taken as covariance.
+    """
+    
     obs = _obj2list(observation,Observation)
     cops = _obj2list(calibration_operator,CalibrationDistribution)
     model_d = _obj2list(model_visibilities,ArrayImpl)
@@ -54,18 +84,72 @@ def CalibrationLikelihood(
 
 
 def ImagingLikelihood(
-    observation,
-    sky_operator,
-    sky_domain_dict,
-    epsilon,
-    do_wgridding,
-    log_inverse_covariance_operator=None,
-    calibration_operator=None,
-    calibration_field=None,
-    verbosity=0,
-    nthreads=1,
-    backend="ducc0",
+    observation: Union[Observation, Iterable[Observation]],
+    sky_operator: Union[jft.Model,Iterable[jft.Model]],
+    sky_domain_dict: dict,
+    epsilon: float,
+    do_wgridding: bool,
+    log_inverse_covariance_operator: Union[jft.Model,Iterable[jft.Model]] = None,
+    calibration_operator: Union[CalibrationDistribution, Iterable[CalibrationDistribution]] = None,
+    calibration_field: Union[jnp.ndarray, Iterable[jnp.ndarray]] = None,
+    verbosity: int = 0,
+    nthreads: int = 1,
+    backend: str = "ducc0",
 ):
+    """Versatile likelihood class.
+
+    If a calibration operator is passed, it returns an operator that computes:
+
+    residual = calibration_operator * (R @ sky_operator)
+    likelihood = 0.5 * residual^dagger @ inverse_covariance @ residual
+
+    Otherwise, it returns an operator that computes:
+
+    residual = R @ sky_operator
+    likelihood = 0.5 * residual^dagger @ inverse_covariance @ residual
+
+    If an inverse_covariance_operator is passed, it is inserted into the above
+    formulae. If it is not passed, 1/observation.weights is used as inverse
+    covariance.
+
+    Parameters
+    ----------
+    observation : Observation or Iterable of Observation
+        Observation objects from which vis, uvw, freq and potentially weight
+        are used for computing the likelihood.
+
+    sky_operator : jft.Model
+        Operator that generates sky.
+
+    sky_domain_dict: dict
+        A dictionary providing information about the discretization of the sky.
+
+    epsilon : float
+
+    do_wgridding : bool
+
+    log_inverse_covariance_operator : jft.Model or Iterable of jft.Model
+        Optional. Target needs to be the same space as observation.vis. If it
+        is not specified, observation.wgt is taken as covariance.
+
+    calibration_operator : CalibrationDistribution or Iterable of CalibrationDistribution
+        Optional. Target needs to be the same as observation.vis.
+
+    calibration_field: jnp.ndarray or Iterable of jnp.ndarray
+        Optional. Domain needs to be the same as observation.vis.
+
+    verbosity : int
+
+    nthreads : int
+
+    backend: string
+
+    Note
+    ----
+    For each observation only either calibration_operator or calibration_field
+    can be set.
+    """
+
     obs = _obj2list(observation,Observation)
     cops = _obj2list(calibration_operator,CalibrationDistribution)
     cflds = _obj2list(calibration_field,ArrayImpl)
-- 
GitLab


From 79da7008758cdf29e8141cf5050c81274036200d Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Fri, 14 Mar 2025 11:14:22 +0100
Subject: [PATCH 50/88] Rewritten evaluation methods for fields and
 powerspectra as mean and std over the input samples

---
 resolve/re/sugar.py | 27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index 2108c9bb..0ba8858d 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -64,9 +64,16 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
     def __call__(self,x):
         return jnp.swapaxes(jnp.array([[[self._fields[p][a][f](x) for f in self._fields[p][a].keys()]for a in self._fields[p].keys()]for p in self._fields.keys()]),2,3)
     
-    def fields_to_dict(self,x):
+    def fields_to_dict(self,samples: jft.Samples):
         """
-        Gives dictionary of evaluated fields indexed by polarization, antenna and frequency labels
+        Gives dictionary of evaluated fields indexed by polarization, antenna and 
+        frequency labels.
+        Each entry consists the (mean,std) of the respective field given the samples.
+
+        Parameters
+        ----------
+        samples: jft.Samples
+            Samples of the inference parameters
         """
         field_dct = {}
 
@@ -77,13 +84,21 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
                 field_dct[p][a] = {}
 
                 for f in self._fields[p][a].keys():
-                    field_dct[p][a][f] = self._fields[p][a][f](x)
+                    field_dct[p][a][f] = jft.mean_and_std(tuple(self._fields[p][a][f](s) for s in samples))
         
         return field_dct
     
-    def powerspectra_to_dict(self,x):
+    def powerspectra_to_dict(self,samples: jft.Samples):
         """
-        Gives dictionary of evaluated powerspectra indexed by polarization, antenna and frequency labels
+        Gives dictionary of evaluated powerspectra indexed by polarization, antenna and
+        frequency labels.
+        Each entry consists the (mean,std) of the respective powerspectrum given the 
+        samples.
+
+        Parameters
+        ----------
+        samples: jft.Samples
+            Samples of the inference parameters
         """
         powerspectra_dct = {}
 
@@ -94,6 +109,6 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
                 powerspectra_dct[p][a] = {}
 
                 for f in self._fields[p][a].keys():
-                    powerspectra_dct[p][a][f] = self._powerspectra[p][a][f](x)
+                    powerspectra_dct[p][a][f] = jft.mean_and_std(tuple(self._powerspectra[p][a][f](s) for s in samples))
         
         return powerspectra_dct
\ No newline at end of file
-- 
GitLab


From 7380ab8a6f4333bf67ebc4ba7ebeb121c52af9fe Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Fri, 14 Mar 2025 15:39:17 +0100
Subject: [PATCH 51/88] Removed unnecessary comments and fixed small file path
 error

---
 resolve/re/calibration.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index b8952119..699241b6 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -1,10 +1,13 @@
 import nifty8.re as jft
 import jax.scipy as jsc
 import jax.numpy as jnp
+import numpy as np
+
+from .sugar import Bulk_CF_AntennaTimeDomain
 
 from ..data.observation import Observation
-from sugar import Bulk_CF_AntennaTimeDomain
-from numpy import ndarray
+
+
 
 from functools import partial
 
@@ -59,9 +62,9 @@ class CalibrationInterpolator():
 
     Parameters
     ----------
-    ant_col: ndarray
+    ant_col: numpy.ndarray
         Antenna points to which one wants to interpolate
-    time_col: ndarry
+    time_col: numpy.ndarry
         Time points to which one wants to interpolate
     dt: float
         Distances between time points on time axis.
@@ -75,14 +78,11 @@ class CalibrationInterpolator():
     """
     def __init__(
             self,
-            ant_col: ndarray, 
-            time_col: ndarray,
+            ant_col: np.ndarray, 
+            time_col: np.ndarray,
             dt: float,
             target_shape: tuple
             ):
-        # Input shape follows (n_pol,n_antenna,n_timesteps,n_freq)
-        # Output shape follows (n_pol,n_visibilities,n_freq)
-        # The model assumes that you have grid with constant width in antenna and times
 
         coords = [ant_col,time_col/dt]
 
-- 
GitLab


From 60952fb5995553ba036b6680496b743f14f7476a Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Fri, 14 Mar 2025 16:05:29 +0100
Subject: [PATCH 52/88] Added docstring and typing for likelihood_models.py

---
 resolve/re/likelihood_models.py | 101 ++++++++++++++++++++++++++++++--
 1 file changed, 95 insertions(+), 6 deletions(-)

diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py
index 11155474..7d6de5fc 100644
--- a/resolve/re/likelihood_models.py
+++ b/resolve/re/likelihood_models.py
@@ -1,8 +1,30 @@
 import nifty8.re as jft
 import jax.numpy as jnp
+import numpy as np
+
+from .calibration import CalibrationDistribution
+
+from typing import Callable
 
 class CalibrationLikelihoodFixedCovarianceModel(jft.Model):
-    def __init__(self, cop, model_visibilities, mask):
+    """
+    Provides a flagged data model for calibration
+
+    Parameters
+    ----------
+    cop: CalibrationDistribution
+        Calibration operator
+    model_visibilities: jnp.ndarray
+        Assumed visibilities of the point source.
+    mask: np.array
+        Mask as boolean numpy array for good visibilites
+    """
+    def __init__(
+            self, 
+            cop: CalibrationDistribution, 
+            model_visibilities: jnp.ndarray, 
+            mask: np.ndarray
+            ):
         self._cop = cop
         self._vis = model_visibilities
         self._mask = mask
@@ -15,7 +37,27 @@ class CalibrationLikelihoodFixedCovarianceModel(jft.Model):
         return flagged_data_model
     
 class CalibrationLikelihoodVariableCovarianceModel(jft.Model):
-    def __init__(self, cop, model_visibilities, log_inverse_covariance_model, mask):
+    """
+    Provides a combined flagged data model and flagged inverse covariance model for calibration
+
+    Parameters
+    ----------
+    cop: CalibrationDistribution
+        Calibration operator
+    model_visibilities: jnp.ndarray
+        Assumed visibilities of the point source.
+    log_inverse_covariance_model: jft.Model
+        Model for log inverse covariance
+    mask: np.array
+        Mask as boolean numpy array for good visibilites
+    """
+    def __init__(
+            self, 
+            cop: CalibrationDistribution, 
+            model_visibilities: jnp.ndarray, 
+            log_inverse_covariance_model: jft.Model, 
+            mask: np.ndarray
+            ):
         self._cop = cop
         self._vis = model_visibilities
         self._mask = mask
@@ -33,7 +75,30 @@ class CalibrationLikelihoodVariableCovarianceModel(jft.Model):
         return (flagged_data_model,flagged_inv_std)
 
 class ImagingLikelihoodFixedCovarianceModel(jft.Model):
-    def __init__(self, R, sky_operator, mask, calibration_operator=None, calibration_field=None):
+    """
+    Provides a flagged data model for imaging
+
+    Parameters
+    ----------
+    R: Callable
+        Response operator function
+    sky: jft.Model
+        Model for sky
+    mask: np.array
+        Mask as boolean numpy array for good visibilites
+    calibration_operator: CalibrationDistribution
+        Optional. Calibration operator
+    calibration_field: jnp.ndarray
+        Optional. Calibration field
+    """
+    def __init__(
+            self, 
+            R: Callable, 
+            sky_operator: jft.Model, 
+            mask: np.ndarray, 
+            calibration_operator: CalibrationDistribution = None, 
+            calibration_field: jnp.ndarray = None
+            ):
         if (calibration_operator is not None) and (calibration_field is not None):
             raise ValueError("You can either set a calibration operator or a calibration field")
 
@@ -49,8 +114,6 @@ class ImagingLikelihoodFixedCovarianceModel(jft.Model):
         else:
             super().__init__(init=self._sky.init)
 
-        
-
     def __call__(self,x):
         if self._cal_op is not None:
             data_model = self._cal_op(x)*self._R(self._sky(x))
@@ -63,7 +126,33 @@ class ImagingLikelihoodFixedCovarianceModel(jft.Model):
         return flagged_data_model
         
 class ImagingLikelihoodVariableCovarianceModel(jft.Model):
-    def __init__(self, R, sky_operator, log_inverse_covariance_model, mask, calibration_operator=None, calibration_field=None):
+    """
+    Provides a combined flagged data model and flagged inverse covariance model for imaging
+
+    Parameters
+    ----------
+    R: Callable
+        Response operator function
+    sky: jft.Model
+        Model for sky
+    log_inverse_covariance_model: jft.Model
+        Model for log inverse covariance
+    mask: np.array
+        Mask as boolean numpy array for good visibilites
+    calibration_operator: CalibrationDistribution
+        Optional. Calibration operator
+    calibration_field: jnp.ndarray
+        Optional. Calibration field
+    """
+    def __init__(
+            self, 
+            R: Callable, 
+            sky_operator: jft.Model, 
+            log_inverse_covariance_model: jft.Model, 
+            mask: np.ndarray, 
+            calibration_operator: CalibrationDistribution = None, 
+            calibration_field: jnp.ndarray = None
+            ):
         if (calibration_operator is not None) and (calibration_field is not None):
             raise ValueError("You can either set a calibration operator or a calibration field")
         
-- 
GitLab


From 608192c5411b0cb2263c8b73dd48c7783be3e0c6 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 17 Mar 2025 09:47:26 +0100
Subject: [PATCH 53/88] Tidied up joining of init methods of Bulk_CF model

---
 resolve/re/sugar.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index 0ba8858d..dba867c5 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -1,6 +1,9 @@
 import jax.numpy as jnp
 import nifty8.re as jft
 
+from functools import reduce
+from operator import or_
+
 from typing import Union, Iterable
 
 class Bulk_CF_AntennaTimeDomain(jft.Model):
@@ -36,7 +39,7 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
 
         self._fields = {}
         self._powerspectra = {}
-        bulk_cf_init = None
+        bulk_cf_init = []
 
         for p in polarizations:
             self._fields[p] = {}
@@ -51,15 +54,14 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
                     cfm.set_amplitude_total_offset(**dct_offset)
                     cfm.add_fluctuations(**dct_ps)
 
-                    self._fields[p][a][f] = cfm.finalize()
+                    cf = cfm.finalize()
+
+                    self._fields[p][a][f] = cf
                     self._powerspectra[p][a][f] = cfm.power_spectrum
 
-                    if bulk_cf_init is not None:
-                        bulk_cf_init = bulk_cf_init | self._fields[p][a][f].init
-                    else:
-                        bulk_cf_init = self._fields[p][a][f].init
+                    bulk_cf_init.append(cf.init)
 
-        super().__init__(init=bulk_cf_init)
+        super().__init__(init=reduce(or_,bulk_cf_init))
 
     def __call__(self,x):
         return jnp.swapaxes(jnp.array([[[self._fields[p][a][f](x) for f in self._fields[p][a].keys()]for a in self._fields[p].keys()]for p in self._fields.keys()]),2,3)
-- 
GitLab


From 4e10c6d541369ab6acbd8beef6f18f3d67dd7f6d Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 17 Mar 2025 22:27:23 +0100
Subject: [PATCH 54/88] Restructed output of fields and powerspectra evaluation
 for clarification

---
 resolve/re/sugar.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index dba867c5..d74f2350 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -86,7 +86,11 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
                 field_dct[p][a] = {}
 
                 for f in self._fields[p][a].keys():
-                    field_dct[p][a][f] = jft.mean_and_std(tuple(self._fields[p][a][f](s) for s in samples))
+                    tmp_mean, tmp_std = jft.mean_and_std(tuple(self._fields[p][a][f](s) for s in samples))
+                    field_dct[p][a][f] = {
+                        "mean": tmp_mean,
+                        "std": tmp_std
+                    }
         
         return field_dct
     
@@ -111,6 +115,10 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
                 powerspectra_dct[p][a] = {}
 
                 for f in self._fields[p][a].keys():
-                    powerspectra_dct[p][a][f] = jft.mean_and_std(tuple(self._powerspectra[p][a][f](s) for s in samples))
+                    tmp_mean, tmp_std = jft.mean_and_std(tuple(self._powerspectra[p][a][f](s) for s in samples))
+                    powerspectra_dct[p][a][f] = {
+                        "mean": tmp_mean,
+                        "std": tmp_std
+                    }
         
         return powerspectra_dct
\ No newline at end of file
-- 
GitLab


From 98bd429da0d43618f82f3c830bbd9c7d5795e31a Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 18 Mar 2025 22:04:08 +0100
Subject: [PATCH 55/88] Replaced model functionality with optimized VModel

---
 resolve/re/sugar.py | 112 ++++++++++++++++----------------------------
 1 file changed, 40 insertions(+), 72 deletions(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index d74f2350..3fa4fc16 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -1,12 +1,9 @@
 import jax.numpy as jnp
 import nifty8.re as jft
 
-from functools import reduce
-from operator import or_
-
 from typing import Union, Iterable
 
-class Bulk_CF_AntennaTimeDomain(jft.Model):
+class Bulk_CF_AntennaTimeDomain(jft.model.LazyModel):
     """
     Creates multiple independant correlated fields from the same offset and powerspectra 
     parameters. Number of correlated fields is the product of the number of polarization
@@ -36,39 +33,29 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
             frequencies: Iterable[Union[int, str]],
             antennas: Iterable[Union[int, str]]
             ):
+        self._pol = list(polarizations)
+        self._ant = list(antennas)
+        self._freq = list(frequencies)
 
-        self._fields = {}
-        self._powerspectra = {}
-        bulk_cf_init = []
-
-        for p in polarizations:
-            self._fields[p] = {}
-            self._powerspectra[p] = {}
-
-            for a in antennas:
-                self._fields[p][a] = {}
-                self._powerspectra[p][a] = {}
-
-                for f in frequencies:
-                    cfm = jft.CorrelatedFieldMaker(f"{prefix}_{p}_ant{a}_freq{f}_")
-                    cfm.set_amplitude_total_offset(**dct_offset)
-                    cfm.add_fluctuations(**dct_ps)
-
-                    cf = cfm.finalize()
+        self._output_shape = (len(self._pol),len(self._ant),len(self._freq),dct_ps["shape"][0])
+        
+        cfm = jft.CorrelatedFieldMaker(prefix)
+        cfm.set_amplitude_total_offset(**dct_offset)
+        cfm.add_fluctuations(**dct_ps)
 
-                    self._fields[p][a][f] = cf
-                    self._powerspectra[p][a][f] = cfm.power_spectrum
+        n_total = len(self._pol)*len(self._ant)*len(self._freq)
 
-                    bulk_cf_init.append(cf.init)
+        self._fields = jft.VModel(cfm.finalize(), axis_size=n_total)
+        self._powerspectrum = cfm.power_spectrum
 
-        super().__init__(init=reduce(or_,bulk_cf_init))
+        super().__init__(init=self._fields.init)
 
     def __call__(self,x):
-        return jnp.swapaxes(jnp.array([[[self._fields[p][a][f](x) for f in self._fields[p][a].keys()]for a in self._fields[p].keys()]for p in self._fields.keys()]),2,3)
+        return jnp.swapaxes(jnp.reshape(self._fields(x),self._output_shape),2,3)
     
-    def fields_to_dict(self,samples: jft.Samples):
+    def results_to_dict(self, samples: jft.Samples, mode: str = "fields"):
         """
-        Gives dictionary of evaluated fields indexed by polarization, antenna and 
+        Gives dictionary of evaluated fields or powerspectra indexed by polarization, antenna and 
         frequency labels.
         Each entry consists the (mean,std) of the respective field given the samples.
 
@@ -76,49 +63,30 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
         ----------
         samples: jft.Samples
             Samples of the inference parameters
-        """
-        field_dct = {}
-
-        for p in self._fields.keys():
-            field_dct[p] = {}
-
-            for a in self._fields[p].keys():
-                field_dct[p][a] = {}
-
-                for f in self._fields[p][a].keys():
-                    tmp_mean, tmp_std = jft.mean_and_std(tuple(self._fields[p][a][f](s) for s in samples))
-                    field_dct[p][a][f] = {
-                        "mean": tmp_mean,
-                        "std": tmp_std
-                    }
-        
-        return field_dct
-    
-    def powerspectra_to_dict(self,samples: jft.Samples):
-        """
-        Gives dictionary of evaluated powerspectra indexed by polarization, antenna and
-        frequency labels.
-        Each entry consists the (mean,std) of the respective powerspectrum given the 
-        samples.
+        mode: string
+            Select if either model ("fields") or power spectra ("spectra") should be evaluated.
 
-        Parameters
-        ----------
-        samples: jft.Samples
-            Samples of the inference parameters
+        Note
+        ----
+        Currently only default value for mode is implemented.
         """
-        powerspectra_dct = {}
-
-        for p in self._fields.keys():
-            powerspectra_dct[p] = {}
-
-            for a in self._fields[p].keys():
-                powerspectra_dct[p][a] = {}
-
-                for f in self._fields[p][a].keys():
-                    tmp_mean, tmp_std = jft.mean_and_std(tuple(self._powerspectra[p][a][f](s) for s in samples))
-                    powerspectra_dct[p][a][f] = {
-                        "mean": tmp_mean,
-                        "std": tmp_std
-                    }
         
-        return powerspectra_dct
\ No newline at end of file
+        if mode == "fields":
+            mean, std = jft.mean_and_std(tuple(self._fields(s) for s in samples))
+
+            res = {
+                {
+                    {
+                        {
+                            "mean": mean[i,j,:,k],
+                            "std": std[i,j,:,k]
+                        } for k,freq in enumerate(self._freq)
+                    } for j,ant in enumerate(self._ant)
+                } for i,pol in enumerate(self._pol)
+            }
+        elif mode == "spectra":
+            raise NotImplementedError
+        else:
+            raise ValueError("Mode has to be either 'fields' or 'spectra'.")  
+        
+        return res
\ No newline at end of file
-- 
GitLab


From e3a82125765075b8911e643c9f3c1eb7ea1808c8 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 19 Mar 2025 07:48:05 +0100
Subject: [PATCH 56/88] Replaced for loops in CalibrationInterpolator with
 vmap.

---
 resolve/re/calibration.py | 36 ++++++++++++++++--------------------
 1 file changed, 16 insertions(+), 20 deletions(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index 699241b6..fd53369e 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -1,16 +1,12 @@
 import nifty8.re as jft
 import jax.scipy as jsc
 import jax.numpy as jnp
-import numpy as np
+import jax
 
 from .sugar import Bulk_CF_AntennaTimeDomain
 
 from ..data.observation import Observation
 
-
-
-from functools import partial
-
 class CalibrationDistribution(jft.Model):
     """
     Computes the calibration operator from given observation data.
@@ -41,8 +37,8 @@ class CalibrationDistribution(jft.Model):
             ):
         ap = observation.antenna_positions
         target_shape = observation.vis.shape
-        self._cop1 = CalibrationInterpolator(ap.ant1, ap.time, dt, target_shape)
-        self._cop2 = CalibrationInterpolator(ap.ant2, ap.time, dt, target_shape)
+        self._cop1 = CalibrationInterpolator(jnp.asarray(ap.ant1), jnp.asarray(ap.time), dt, target_shape)
+        self._cop2 = CalibrationInterpolator(jnp.asarray(ap.ant2), jnp.asarray(ap.time), dt, target_shape)
 
         self._phases = phase_fields
         self._logamps = log_amplitude_fields
@@ -62,9 +58,9 @@ class CalibrationInterpolator():
 
     Parameters
     ----------
-    ant_col: numpy.ndarray
+    ant_col: jnp.ndarray
         Antenna points to which one wants to interpolate
-    time_col: numpy.ndarry
+    time_col: jnp.ndarry
         Time points to which one wants to interpolate
     dt: float
         Distances between time points on time axis.
@@ -78,24 +74,24 @@ class CalibrationInterpolator():
     """
     def __init__(
             self,
-            ant_col: np.ndarray, 
-            time_col: np.ndarray,
+            ant_col: jnp.ndarray, 
+            time_col: jnp.ndarray,
             dt: float,
             target_shape: tuple
             ):
 
         coords = [ant_col,time_col/dt]
 
-        self._li = partial(jsc.ndimage.map_coordinates,coordinates=coords,order=1)
+        self._li = jax.partial(jsc.ndimage.map_coordinates,coordinates=coords,order=1)
 
-        self._output_shape = target_shape
+        self._n_pol,_,self._n_freq = target_shape
 
     def __call__(self,x):
-        res = jnp.empty(self._output_shape)
-        n_pol, _, n_freq = self._output_shape
-        for pol in range(n_pol):
-            for freq in range(n_freq):
-                val = x[pol, :, :, freq]
-                res = res.at[pol, :, freq].set(self._li(val))
+        res = jax.vmap(
+            jax.vmap(
+                lambda pol, freq: self._li(x[pol, :, :, freq])
+                ,in_axes=(None,0)
+            ),in_axes=(0,None)
+        )
         
-        return res
\ No newline at end of file
+        return res(jnp.arange(self._n_pol), jnp.arange(self._n_freq))
\ No newline at end of file
-- 
GitLab


From 60de179c8df18f6fa809efb7686b03950f5367b9 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 19 Mar 2025 09:15:08 +0100
Subject: [PATCH 57/88] Removed numpy typing

---
 resolve/re/likelihood_models.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py
index 7d6de5fc..7d0d2628 100644
--- a/resolve/re/likelihood_models.py
+++ b/resolve/re/likelihood_models.py
@@ -1,6 +1,5 @@
 import nifty8.re as jft
 import jax.numpy as jnp
-import numpy as np
 
 from .calibration import CalibrationDistribution
 
@@ -16,14 +15,14 @@ class CalibrationLikelihoodFixedCovarianceModel(jft.Model):
         Calibration operator
     model_visibilities: jnp.ndarray
         Assumed visibilities of the point source.
-    mask: np.array
+    mask: jnp.array
         Mask as boolean numpy array for good visibilites
     """
     def __init__(
             self, 
             cop: CalibrationDistribution, 
             model_visibilities: jnp.ndarray, 
-            mask: np.ndarray
+            mask: jnp.ndarray
             ):
         self._cop = cop
         self._vis = model_visibilities
@@ -48,7 +47,7 @@ class CalibrationLikelihoodVariableCovarianceModel(jft.Model):
         Assumed visibilities of the point source.
     log_inverse_covariance_model: jft.Model
         Model for log inverse covariance
-    mask: np.array
+    mask: jnp.array
         Mask as boolean numpy array for good visibilites
     """
     def __init__(
@@ -56,7 +55,7 @@ class CalibrationLikelihoodVariableCovarianceModel(jft.Model):
             cop: CalibrationDistribution, 
             model_visibilities: jnp.ndarray, 
             log_inverse_covariance_model: jft.Model, 
-            mask: np.ndarray
+            mask: jnp.ndarray
             ):
         self._cop = cop
         self._vis = model_visibilities
@@ -84,7 +83,7 @@ class ImagingLikelihoodFixedCovarianceModel(jft.Model):
         Response operator function
     sky: jft.Model
         Model for sky
-    mask: np.array
+    mask: jnp.array
         Mask as boolean numpy array for good visibilites
     calibration_operator: CalibrationDistribution
         Optional. Calibration operator
@@ -95,7 +94,7 @@ class ImagingLikelihoodFixedCovarianceModel(jft.Model):
             self, 
             R: Callable, 
             sky_operator: jft.Model, 
-            mask: np.ndarray, 
+            mask: jnp.ndarray, 
             calibration_operator: CalibrationDistribution = None, 
             calibration_field: jnp.ndarray = None
             ):
@@ -137,7 +136,7 @@ class ImagingLikelihoodVariableCovarianceModel(jft.Model):
         Model for sky
     log_inverse_covariance_model: jft.Model
         Model for log inverse covariance
-    mask: np.array
+    mask: jnp.array
         Mask as boolean numpy array for good visibilites
     calibration_operator: CalibrationDistribution
         Optional. Calibration operator
@@ -149,7 +148,7 @@ class ImagingLikelihoodVariableCovarianceModel(jft.Model):
             R: Callable, 
             sky_operator: jft.Model, 
             log_inverse_covariance_model: jft.Model, 
-            mask: np.ndarray, 
+            mask: jnp.ndarray, 
             calibration_operator: CalibrationDistribution = None, 
             calibration_field: jnp.ndarray = None
             ):
-- 
GitLab


From ebd0be1852bfed18be980e56f21e198f331139ba Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 19 Mar 2025 09:19:26 +0100
Subject: [PATCH 58/88] Converted np.arrays from observations to jnp.arrays

---
 resolve/re/likelihood.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index d819638c..96f117c3 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -60,13 +60,13 @@ def CalibrationLikelihood(
 
     likelihoods = []
     for oo, cop, model_vis, log_inv_cov in zip(obs,cops,model_d,log_inv_covs):    
-        mask = oo.mask.val
+        mask = jnp.asarray(oo.mask.val)
 
-        flagged_data = oo.vis.val[mask]
+        flagged_data = jnp.asarray(oo.vis.val)[mask]
 
         if log_inv_cov is None:
             model = CalibrationLikelihoodFixedCovarianceModel(cop,model_vis,mask)
-            flagged_inv_cov = oo.weight.val[mask]
+            flagged_inv_cov = jnp.asarray(oo.weight.val)[mask]
             
             lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
         
@@ -167,13 +167,13 @@ def ImagingLikelihood(
             )
 
         R = InterferometryResponse(oo,sky_domain_dict,do_wgridding,epsilon,nthreads,verbosity,backend)
-        mask = oo.mask.val
+        mask = jnp.asarray(oo.mask.val)
 
-        flagged_data = oo.vis.val[mask]
+        flagged_data = jnp.asarray(oo.vis.val)[mask]
 
         if log_inv_cov is None:
             model = ImagingLikelihoodFixedCovarianceModel(R,sky_operator,mask,cop,cfld)
-            flagged_inv_cov = oo.weight.val[mask]
+            flagged_inv_cov = jnp.asarray(oo.weight.val)[mask]
         
             lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
 
-- 
GitLab


From b64b80a010fa4c8455ffbc6687bdb67fca08ccbe Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 19 Mar 2025 09:37:10 +0100
Subject: [PATCH 59/88] Imported correct "Partial" function from jax.tree_utils

---
 resolve/re/calibration.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index fd53369e..74d44da0 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -1,7 +1,9 @@
 import nifty8.re as jft
 import jax.scipy as jsc
 import jax.numpy as jnp
-import jax
+
+from jax.tree_util import Partial
+from jax import vmap
 
 from .sugar import Bulk_CF_AntennaTimeDomain
 
@@ -82,13 +84,13 @@ class CalibrationInterpolator():
 
         coords = [ant_col,time_col/dt]
 
-        self._li = jax.partial(jsc.ndimage.map_coordinates,coordinates=coords,order=1)
+        self._li = Partial(jsc.ndimage.map_coordinates,coordinates=coords,order=1)
 
         self._n_pol,_,self._n_freq = target_shape
 
     def __call__(self,x):
-        res = jax.vmap(
-            jax.vmap(
+        res = vmap(
+            vmap(
                 lambda pol, freq: self._li(x[pol, :, :, freq])
                 ,in_axes=(None,0)
             ),in_axes=(0,None)
-- 
GitLab


From 881a7acadc3c019606b6be41e27cf0f3e36aba55 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 19 Mar 2025 10:19:17 +0100
Subject: [PATCH 60/88] Fixed output shape problem with vmap

---
 resolve/re/calibration.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index 74d44da0..290f3fa4 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -92,7 +92,7 @@ class CalibrationInterpolator():
         res = vmap(
             vmap(
                 lambda pol, freq: self._li(x[pol, :, :, freq])
-                ,in_axes=(None,0)
+                ,in_axes=(None,0), out_axes=1
             ),in_axes=(0,None)
         )
         
-- 
GitLab


From 0617fdff922491f4c47a6658faa72cc7706023d4 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 19 Mar 2025 11:32:36 +0100
Subject: [PATCH 61/88] Bulk_CF_AntennaTimeDomain is child of jft.Model now

---
 resolve/re/sugar.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index 3fa4fc16..61473cd8 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -3,7 +3,7 @@ import nifty8.re as jft
 
 from typing import Union, Iterable
 
-class Bulk_CF_AntennaTimeDomain(jft.model.LazyModel):
+class Bulk_CF_AntennaTimeDomain(jft.Model):
     """
     Creates multiple independant correlated fields from the same offset and powerspectra 
     parameters. Number of correlated fields is the product of the number of polarization
-- 
GitLab


From de687ffe5e9b84dc8de32a296db32c66f87f3fc9 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 19 Mar 2025 21:40:04 +0100
Subject: [PATCH 62/88] Removed evaluation functionality because of slicing
 problems

---
 resolve/re/sugar.py | 41 +++--------------------------------------
 1 file changed, 3 insertions(+), 38 deletions(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index 61473cd8..5ec9ad3c 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -1,6 +1,8 @@
 import jax.numpy as jnp
 import nifty8.re as jft
 
+from jax.lax import dynamic_slice_in_dim
+
 from typing import Union, Iterable
 
 class Bulk_CF_AntennaTimeDomain(jft.Model):
@@ -52,41 +54,4 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
 
     def __call__(self,x):
         return jnp.swapaxes(jnp.reshape(self._fields(x),self._output_shape),2,3)
-    
-    def results_to_dict(self, samples: jft.Samples, mode: str = "fields"):
-        """
-        Gives dictionary of evaluated fields or powerspectra indexed by polarization, antenna and 
-        frequency labels.
-        Each entry consists the (mean,std) of the respective field given the samples.
-
-        Parameters
-        ----------
-        samples: jft.Samples
-            Samples of the inference parameters
-        mode: string
-            Select if either model ("fields") or power spectra ("spectra") should be evaluated.
-
-        Note
-        ----
-        Currently only default value for mode is implemented.
-        """
-        
-        if mode == "fields":
-            mean, std = jft.mean_and_std(tuple(self._fields(s) for s in samples))
-
-            res = {
-                {
-                    {
-                        {
-                            "mean": mean[i,j,:,k],
-                            "std": std[i,j,:,k]
-                        } for k,freq in enumerate(self._freq)
-                    } for j,ant in enumerate(self._ant)
-                } for i,pol in enumerate(self._pol)
-            }
-        elif mode == "spectra":
-            raise NotImplementedError
-        else:
-            raise ValueError("Mode has to be either 'fields' or 'spectra'.")  
-        
-        return res
\ No newline at end of file
+    
\ No newline at end of file
-- 
GitLab


From c4fe9b9c80320dda32fe34536e166f92123a01ec Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 19 Mar 2025 21:42:10 +0100
Subject: [PATCH 63/88] Added method to extract correlated field maker

---
 resolve/re/sugar.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index 5ec9ad3c..9f0ee4b0 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -54,4 +54,6 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
 
     def __call__(self,x):
         return jnp.swapaxes(jnp.reshape(self._fields(x),self._output_shape),2,3)
-    
\ No newline at end of file
+    
+    def get_powerspectrum(self):
+        return self._powerspectrum
\ No newline at end of file
-- 
GitLab


From b2c0a8eff7b206759b23717638ad28ed7b739e94 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 25 Mar 2025 22:00:40 +0100
Subject: [PATCH 64/88] Introduced possibility of labels for likelihoods, when
 multiple dataset are used as input

---
 resolve/re/likelihood.py | 84 ++++++++++++++++++++++++++--------------
 1 file changed, 56 insertions(+), 28 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index 96f117c3..29cbdbd3 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -1,15 +1,13 @@
 import jax.numpy as jnp
 import nifty8.re as jft
 
-from functools import reduce
-from operator import add
 from typing import Union, Iterable
 
 from .response import InterferometryResponse
 from .likelihood_models import CalibrationLikelihoodFixedCovarianceModel, CalibrationLikelihoodVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
 from .calibration import CalibrationDistribution
 
-from ..util import _obj2list
+from ..util import _obj2list, _duplicate
 from ..data.observation import Observation
 
 from jaxlib.xla_extension import ArrayImpl
@@ -19,7 +17,8 @@ def CalibrationLikelihood(
     observation: Union[Observation, Iterable[Observation]],
     calibration_operator: Union[CalibrationDistribution, Iterable[CalibrationDistribution]],
     model_visibilities: Union[jnp.ndarray, Iterable[jnp.ndarray]],
-    log_inverse_covariance_operator: Union[jft.Model,Iterable[jft.Model]] = None
+    log_inverse_covariance_operator: Union[jft.Model,Iterable[jft.Model]] = None,
+    likelihood_labels: Union[str,Iterable[str]] = None
 ):
     """Versatile calibration likelihood class
 
@@ -39,27 +38,37 @@ def CalibrationLikelihood(
         observation.weight is used for computing the likelihood.
 
     calibration_operator : CalibrationDistribution or Iterable of CalibrationDistribution
-        Target needs to be the same as observation.vis.
+        Target needs to be the same as observation.vis. The same amount of elements as 
+        number of observations should be provided.
 
     model_visibilities jnp.ndarray or Iterable of jnp.ndarray
         Known model visiblities that are used for calibration. Needs to be
-        defined on the same domain as `observation.vis`.
+        defined on the same domain as `observation.vis`. The same amount of elements as 
+        number of observations should be provided.
 
     log_inverse_covariance_operator : jft.Model or Iterable of jft.Model
         Optional. Target needs to be the same space as observation.vis. If it is
-        not specified, observation.wgt is taken as covariance.
+        not specified, observation.wgt is taken as covariance. If used, the same
+        amount of elements as number of observations should be provided.
+
+    likelihood_labels: string or Iterable of string
+        Optional. Append labels to individual likelihoods which are shown in the minisanity
+        for overview. If used, the same amount of elements as number of observations 
+        should be provided.
     """
     
+    
     obs = _obj2list(observation,Observation)
     cops = _obj2list(calibration_operator,CalibrationDistribution)
     model_d = _obj2list(model_visibilities,ArrayImpl)
-    log_inv_covs = _obj2list(log_inverse_covariance_operator,Model)
+    log_inv_covs = _duplicate(_obj2list(log_inverse_covariance_operator,Model),len(obs))
+    labels = _duplicate(_obj2list(likelihood_labels,str),len(obs))
 
-    if len(set([len(obs),len(cops),len(model_d),len(log_inv_covs)])) != 1:
-        raise ValueError("observation, calibration_operator, model_visibilities and log_inverse_covariance_operator must have the same number of elements")
+    if len(set([len(obs),len(cops),len(model_d)])) != 1:
+        raise ValueError("observation, calibration_operator and model_visibilities must have the same number of elements")
 
-    likelihoods = []
-    for oo, cop, model_vis, log_inv_cov in zip(obs,cops,model_d,log_inv_covs):    
+    lhs = []
+    for oo, cop, model_vis, log_inv_cov,label in zip(obs,cops,model_d,log_inv_covs,labels):    
         mask = jnp.asarray(oo.mask.val)
 
         flagged_data = jnp.asarray(oo.vis.val)[mask]
@@ -78,9 +87,15 @@ def CalibrationLikelihood(
         lh_with_model = lh.amend(model)
         lh_with_model._domain = jft.Vector(lh_with_model._domain)
 
-        likelihoods.append(lh_with_model)
-    
-    return reduce(add,likelihoods)
+        if label is not None:
+            lh_with_model._name = label
+
+        lhs.append(lh_with_model)
+
+    if set(labels) == {None}:
+        return jft.likelihood.LikelihoodSum(*lhs)
+    else:
+        return jft.likelihood.LikelihoodSum(*lhs, _key_template="lh_{index}-{likelihood._name}")
 
 
 def ImagingLikelihood(
@@ -92,6 +107,7 @@ def ImagingLikelihood(
     log_inverse_covariance_operator: Union[jft.Model,Iterable[jft.Model]] = None,
     calibration_operator: Union[CalibrationDistribution, Iterable[CalibrationDistribution]] = None,
     calibration_field: Union[jnp.ndarray, Iterable[jnp.ndarray]] = None,
+    likelihood_labels: Union[str,Iterable[str]] = None,
     verbosity: int = 0,
     nthreads: int = 1,
     backend: str = "ducc0",
@@ -130,13 +146,21 @@ def ImagingLikelihood(
 
     log_inverse_covariance_operator : jft.Model or Iterable of jft.Model
         Optional. Target needs to be the same space as observation.vis. If it
-        is not specified, observation.wgt is taken as covariance.
+        is not specified, observation.wgt is taken as covariance. If used, the same
+        amount of elements as number of observations should be provided.
 
     calibration_operator : CalibrationDistribution or Iterable of CalibrationDistribution
-        Optional. Target needs to be the same as observation.vis.
+        Optional. Target needs to be the same as observation.vis. If used, the same
+        amount of elements as number of observations should be provided.
 
     calibration_field: jnp.ndarray or Iterable of jnp.ndarray
-        Optional. Domain needs to be the same as observation.vis.
+        Optional. Domain needs to be the same as observation.vis. If used, the same
+        amount of elements as number of observations should be provided.
+    
+    likelihood_labels: string or Iterable of string
+        Optional. Append labels to individual likelihoods which are shown in the minisanity
+        for overview. If used, the same amount of elements as number of observations 
+        should be provided.
 
     verbosity : int
 
@@ -151,16 +175,14 @@ def ImagingLikelihood(
     """
 
     obs = _obj2list(observation,Observation)
-    cops = _obj2list(calibration_operator,CalibrationDistribution)
-    cflds = _obj2list(calibration_field,ArrayImpl)
-    log_inv_covs = _obj2list(log_inverse_covariance_operator,Model)
+    cops = _duplicate(_obj2list(calibration_operator,CalibrationDistribution),len(obs))
+    cflds = _duplicate(_obj2list(calibration_field,ArrayImpl),len(obs))
+    log_inv_covs = _duplicate(_obj2list(log_inverse_covariance_operator,Model),len(obs))
+    labels = _duplicate(_obj2list(likelihood_labels,str),len(obs))
 
-    if len(set([len(obs),len(cops),len(cflds),len(log_inv_covs)])) != 1:
-        raise ValueError("observation, log_inverse_covariance_operator, calibration_operator and calibration_field must have the same number of elements")
-
-    likelihoods = []
+    lhs = []
     
-    for ii, (oo, cop, cfld, log_inv_cov) in enumerate(zip(obs,cops,cflds,log_inv_covs)):
+    for ii, (oo, cop, cfld, log_inv_cov,label) in enumerate(zip(obs,cops,cflds,log_inv_covs,labels)):
         if cfld is not None and cop is not None:
             raise ValueError(
                 f"Can't set calibration operator and calibration field simultaneously at index {ii}"
@@ -185,6 +207,12 @@ def ImagingLikelihood(
         lh_with_model = lh.amend(model)
         lh_with_model._domain = jft.Vector(lh_with_model._domain)
 
-        likelihoods.append(lh_with_model)
+        if label is not None:
+            lh_with_model._name = label
+
+        lhs.append(lh_with_model)
     
-    return reduce(add,likelihoods)
\ No newline at end of file
+    if set(labels) == {None}:
+        return jft.likelihood.LikelihoodSum(*lhs)
+    else:
+        return jft.likelihood.LikelihoodSum(*lhs, _key_template="lh_{index}-{likelihood._name}")
\ No newline at end of file
-- 
GitLab


From 34460b2c47c1891d30e11c619a629df0a652127d Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 25 Mar 2025 23:09:34 +0100
Subject: [PATCH 65/88] Removed likelihood models from direct access as they
 are not intended for independent use

---
 resolve/re/__init__.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py
index 9cb6e714..296b247c 100644
--- a/resolve/re/__init__.py
+++ b/resolve/re/__init__.py
@@ -4,6 +4,5 @@ from .response import InterferometryResponse, InterferometryResponseFinuFFT, Int
 from .radio_response import build_exact_r, build_approximations
 from .optimize import optimize
 from .calibration import CalibrationInterpolator, CalibrationDistribution
-from .likelihood_models import CalibrationLikelihoodFixedCovarianceModel, CalibrationLikelihoodVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
 from .likelihood import CalibrationLikelihood, ImagingLikelihood
 from .sugar import Bulk_CF_AntennaTimeDomain
\ No newline at end of file
-- 
GitLab


From 01ab84cc39c042e797df634a2bceae841829526c Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 25 Mar 2025 23:10:17 +0100
Subject: [PATCH 66/88] Seperated models to remove if statements in call method

---
 resolve/re/likelihood_models.py | 142 ++++++++++++++++++++++----------
 1 file changed, 100 insertions(+), 42 deletions(-)

diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py
index 7d0d2628..6111c952 100644
--- a/resolve/re/likelihood_models.py
+++ b/resolve/re/likelihood_models.py
@@ -5,7 +5,7 @@ from .calibration import CalibrationDistribution
 
 from typing import Callable
 
-class CalibrationLikelihoodFixedCovarianceModel(jft.Model):
+class ModelCalibrationLikelihoodFixedCovariance(jft.Model):
     """
     Provides a flagged data model for calibration
 
@@ -33,9 +33,10 @@ class CalibrationLikelihoodFixedCovarianceModel(jft.Model):
     def __call__(self, x):
         data_model = self._vis*self._cop(x)
         flagged_data_model = data_model[self._mask]
+
         return flagged_data_model
     
-class CalibrationLikelihoodVariableCovarianceModel(jft.Model):
+class ModelCalibrationLikelihoodVariableCovariance(jft.Model):
     """
     Provides a combined flagged data model and flagged inverse covariance model for calibration
 
@@ -73,9 +74,9 @@ class CalibrationLikelihoodVariableCovarianceModel(jft.Model):
         
         return (flagged_data_model,flagged_inv_std)
 
-class ImagingLikelihoodFixedCovarianceModel(jft.Model):
+class ModelImagingLikelihoodFixedCovarianceCalibrationField(jft.Model):
     """
-    Provides a flagged data model for imaging
+    Provides a flagged data model for imaging given an optional calibration field
 
     Parameters
     ----------
@@ -85,8 +86,6 @@ class ImagingLikelihoodFixedCovarianceModel(jft.Model):
         Model for sky
     mask: jnp.array
         Mask as boolean numpy array for good visibilites
-    calibration_operator: CalibrationDistribution
-        Optional. Calibration operator
     calibration_field: jnp.ndarray
         Optional. Calibration field
     """
@@ -94,39 +93,67 @@ class ImagingLikelihoodFixedCovarianceModel(jft.Model):
             self, 
             R: Callable, 
             sky_operator: jft.Model, 
-            mask: jnp.ndarray, 
-            calibration_operator: CalibrationDistribution = None, 
+            mask: jnp.ndarray,
             calibration_field: jnp.ndarray = None
             ):
-        if (calibration_operator is not None) and (calibration_field is not None):
-            raise ValueError("You can either set a calibration operator or a calibration field")
 
         self._R = R
         self._sky = sky_operator
         self._mask = mask
 
-        self._cal_op = calibration_operator
-        self._cal_fld = calibration_field
+        if calibration_field is None:
+            self._cal_fld = jnp.ones(mask.shape)
+        else: 
+            self._cal_fld = calibration_field
 
-        if self._cal_op is not None:
-            super().__init__(init=self._sky.init | self._cal_op.init)
-        else:
-            super().__init__(init=self._sky.init)
+        super().__init__(init=self._sky.init)
 
     def __call__(self,x):
-        if self._cal_op is not None:
-            data_model = self._cal_op(x)*self._R(self._sky(x))
-        elif self._cal_fld is not None:
-            data_model = self._cal_fld*self._R(self._sky(x))
-        else:
-            data_model = self._R(self._sky(x))
+        data_model = self._R(self._sky(x))
+        flagged_data_model = data_model[self._mask]
+
+        return flagged_data_model
+    
+class ModelImagingLikelihoodFixedCovarianceCalibrationOperator(jft.Model):
+    """
+    Provides a flagged data model for imaging given a calibration operator
 
+    Parameters
+    ----------
+    R: Callable
+        Response operator function
+    sky: jft.Model
+        Model for sky
+    mask: jnp.array
+        Mask as boolean numpy array for good visibilites
+    calibration_operator: CalibrationDistribution
+        Optional. Calibration operator
+    """
+    def __init__(
+            self, 
+            R: Callable, 
+            sky_operator: jft.Model, 
+            mask: jnp.ndarray, 
+            calibration_operator: CalibrationDistribution
+            ):
+        self._R = R
+        self._sky = sky_operator
+        self._mask = mask
+
+        self._cal_op = calibration_operator
+
+        super().__init__(init=self._sky.init | self._cal_op.init)
+
+    def __call__(self,x):
+        data_model = self._cal_op(x)*self._R(self._sky(x))
         flagged_data_model = data_model[self._mask]
+
         return flagged_data_model
         
-class ImagingLikelihoodVariableCovarianceModel(jft.Model):
+class ModelImagingLikelihoodVariableCovarianceCalibrationField(jft.Model):
     """
-    Provides a combined flagged data model and flagged inverse covariance model for imaging
+    Provides a combined flagged data model and flagged inverse covariance model for imaging 
+    given an optional calibration field
 
     Parameters
     ----------
@@ -138,8 +165,6 @@ class ImagingLikelihoodVariableCovarianceModel(jft.Model):
         Model for log inverse covariance
     mask: jnp.array
         Mask as boolean numpy array for good visibilites
-    calibration_operator: CalibrationDistribution
-        Optional. Calibration operator
     calibration_field: jnp.ndarray
         Optional. Calibration field
     """
@@ -148,34 +173,67 @@ class ImagingLikelihoodVariableCovarianceModel(jft.Model):
             R: Callable, 
             sky_operator: jft.Model, 
             log_inverse_covariance_model: jft.Model, 
-            mask: jnp.ndarray, 
-            calibration_operator: CalibrationDistribution = None, 
+            mask: jnp.ndarray,
             calibration_field: jnp.ndarray = None
             ):
-        if (calibration_operator is not None) and (calibration_field is not None):
-            raise ValueError("You can either set a calibration operator or a calibration field")
+        self._R = R
+        self._sky = sky_operator
+        self._mask = mask
+        self._log_inv_cov = log_inverse_covariance_model
+        
+        if calibration_field is None:
+            self._cal_fld = jnp.ones(mask.shape)
+        else: 
+            self._cal_fld = calibration_field
+
+        super().__init__(init=self._sky.init | self._log_inv_cov.init)
+
+    def __call__(self,x):
+        data_model = self._cal_fld*self._R(self._sky(x))
+        flagged_data_model = data_model[self._mask]
+
+        inv_std = jnp.exp(0.5*self._log_inv_cov(x))
+        flagged_inv_std = inv_std[self._mask]
         
+        return (flagged_data_model,flagged_inv_std)
+
+class ModelImagingLikelihoodVariableCovarianceCalibrationOperator(jft.Model):
+    """
+    Provides a combined flagged data model and flagged inverse covariance model for imaging
+    given a calibration operator
+
+    Parameters
+    ----------
+    R: Callable
+        Response operator function
+    sky: jft.Model
+        Model for sky
+    log_inverse_covariance_model: jft.Model
+        Model for log inverse covariance
+    mask: jnp.array
+        Mask as boolean numpy array for good visibilites
+    calibration_operator: CalibrationDistribution
+        Optional. Calibration operator
+    """
+    def __init__(
+            self, 
+            R: Callable, 
+            sky_operator: jft.Model, 
+            log_inverse_covariance_model: jft.Model, 
+            mask: jnp.ndarray, 
+            calibration_operator: CalibrationDistribution
+            ):
         self._R = R
         self._sky = sky_operator
         self._mask = mask
         self._log_inv_cov = log_inverse_covariance_model
         
         self._cal_op = calibration_operator
-        self._cal_fld = calibration_field
 
-        if self._cal_op is not None:
-            super().__init__(init=self._sky.init | self._log_inv_cov.init | self._cal_op.init)
-        else:
-            super().__init__(init=self._sky.init | self._log_inv_cov.init)
+        super().__init__(init=self._sky.init | self._log_inv_cov.init | self._cal_op.init)
 
     def __call__(self,x):
-        if self._cal_op is not None:
-            data_model = self._cal_op(x)*self._R(self._sky(x))
-        elif self._cal_fld is not None:
-            data_model = self._cal_fld*self._R(self._sky(x))
-        else:
-            data_model = self._R(self._sky(x))
-
+        data_model = self._cal_op(x)*self._R(self._sky(x))
         flagged_data_model = data_model[self._mask]
 
         inv_std = jnp.exp(0.5*self._log_inv_cov(x))
-- 
GitLab


From af866a523270bb5f35d4f93e6d391adbdbcd6b22 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 25 Mar 2025 23:11:18 +0100
Subject: [PATCH 67/88] Added call to specific models into creation function
 for the likelihoods

---
 resolve/re/likelihood.py | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index 29cbdbd3..d9a8ef09 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -4,8 +4,15 @@ import nifty8.re as jft
 from typing import Union, Iterable
 
 from .response import InterferometryResponse
-from .likelihood_models import CalibrationLikelihoodFixedCovarianceModel, CalibrationLikelihoodVariableCovarianceModel, ImagingLikelihoodFixedCovarianceModel, ImagingLikelihoodVariableCovarianceModel
 from .calibration import CalibrationDistribution
+from .likelihood_models import *
+# The classes from .likelihoods model are:
+#  - ModelCalibrationLikelihoodFixedCovariance
+#  - ModelCalibrationLikelihoodVariableCovariance
+#  - ModelImagingLikelihoodFixedCovarianceCalibrationField
+#  - ModelImagingLikelihoodFixedCovarianceCalibrationOperator
+#  - ModelImagingLikelihoodVariableCovarianceCalibrationField
+#  - ModelImagingLikelihoodVariableCovarianceCalibrationOperator
 
 from ..util import _obj2list, _duplicate
 from ..data.observation import Observation
@@ -74,13 +81,13 @@ def CalibrationLikelihood(
         flagged_data = jnp.asarray(oo.vis.val)[mask]
 
         if log_inv_cov is None:
-            model = CalibrationLikelihoodFixedCovarianceModel(cop,model_vis,mask)
+            model = ModelCalibrationLikelihoodFixedCovariance(cop,model_vis,mask)
             flagged_inv_cov = jnp.asarray(oo.weight.val)[mask]
             
             lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
         
         else:
-            model = CalibrationLikelihoodVariableCovarianceModel(cop,model_vis,log_inv_cov,mask)
+            model = ModelCalibrationLikelihoodVariableCovariance(cop,model_vis,log_inv_cov,mask)
 
             lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=jnp.iscomplexobj(oo.vis.val))
         
@@ -194,13 +201,20 @@ def ImagingLikelihood(
         flagged_data = jnp.asarray(oo.vis.val)[mask]
 
         if log_inv_cov is None:
-            model = ImagingLikelihoodFixedCovarianceModel(R,sky_operator,mask,cop,cfld)
+            if cop is None:
+                model = ModelImagingLikelihoodFixedCovarianceCalibrationField(R,sky_operator,mask,cfld)
+            else:
+                model = ModelImagingLikelihoodFixedCovarianceCalibrationOperator(R,sky_operator,mask,cop)
+            
             flagged_inv_cov = jnp.asarray(oo.weight.val)[mask]
         
             lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
 
         else:
-            model = ImagingLikelihoodVariableCovarianceModel(R,sky_operator,log_inv_cov,mask,cop,cfld)
+            if cop is None:
+                model = ModelImagingLikelihoodVariableCovarianceCalibrationField(R,sky_operator,log_inv_cov,mask,cfld)
+            else:
+                model = ModelImagingLikelihoodVariableCovarianceCalibrationOperator(R,sky_operator,log_inv_cov,mask,cop)
 
             lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=jnp.iscomplexobj(oo.vis.val))
 
-- 
GitLab


From c4d23e3f77fd5904e51e8feb7955ef9180eef42c Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Thu, 27 Mar 2025 21:14:23 +0100
Subject: [PATCH 68/88] Edited key template formatting

---
 resolve/re/likelihood.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index d9a8ef09..9936fd61 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -75,7 +75,7 @@ def CalibrationLikelihood(
         raise ValueError("observation, calibration_operator and model_visibilities must have the same number of elements")
 
     lhs = []
-    for oo, cop, model_vis, log_inv_cov,label in zip(obs,cops,model_d,log_inv_covs,labels):    
+    for ii, (oo, cop, model_vis, log_inv_cov,label) in enumerate(zip(obs,cops,model_d,log_inv_covs,labels)):    
         mask = jnp.asarray(oo.mask.val)
 
         flagged_data = jnp.asarray(oo.vis.val)[mask]
@@ -96,13 +96,14 @@ def CalibrationLikelihood(
 
         if label is not None:
             lh_with_model._name = label
+            lh_with_model._name_n_ws = " "*(len(str(len(obs)))-len(str(ii)))
 
         lhs.append(lh_with_model)
 
     if set(labels) == {None}:
         return jft.likelihood.LikelihoodSum(*lhs)
     else:
-        return jft.likelihood.LikelihoodSum(*lhs, _key_template="lh_{index}-{likelihood._name}")
+        return jft.likelihood.LikelihoodSum(*lhs, _key_template="lh {likelihood._name_n_ws}{index} | {likelihood._name}")
 
 
 def ImagingLikelihood(
@@ -223,10 +224,11 @@ def ImagingLikelihood(
 
         if label is not None:
             lh_with_model._name = label
+            lh_with_model._name_n_ws = " "*(len(str(len(obs)))-len(str(ii)))
 
         lhs.append(lh_with_model)
     
     if set(labels) == {None}:
         return jft.likelihood.LikelihoodSum(*lhs)
     else:
-        return jft.likelihood.LikelihoodSum(*lhs, _key_template="lh_{index}-{likelihood._name}")
\ No newline at end of file
+        return jft.likelihood.LikelihoodSum(*lhs, _key_template="lh {likelihood._name_n_ws}{index} | {likelihood._name}")
-- 
GitLab


From dad27608f2c86b635285ae53b5b3518d55b95878 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Thu, 3 Apr 2025 14:32:56 +0200
Subject: [PATCH 69/88] Contracted cases for given calibration solution to one
 model

---
 resolve/re/likelihood_models.py | 141 +++++++++-----------------------
 1 file changed, 37 insertions(+), 104 deletions(-)

diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py
index 6111c952..ee3b3ffd 100644
--- a/resolve/re/likelihood_models.py
+++ b/resolve/re/likelihood_models.py
@@ -73,10 +73,12 @@ class ModelCalibrationLikelihoodVariableCovariance(jft.Model):
         flagged_inv_std = inv_std[self._mask]
         
         return (flagged_data_model,flagged_inv_std)
+    
 
-class ModelImagingLikelihoodFixedCovarianceCalibrationField(jft.Model):
+class ModelImagingLikelihoodFixedCovariance(jft.Model):
     """
-    Provides a flagged data model for imaging given an optional calibration field
+    Provides a flagged data model for imaging given an optional calibration operator or
+    calibration field
 
     Parameters
     ----------
@@ -86,6 +88,8 @@ class ModelImagingLikelihoodFixedCovarianceCalibrationField(jft.Model):
         Model for sky
     mask: jnp.array
         Mask as boolean numpy array for good visibilites
+    calibration_operator: CalibrationDistribution
+        Optional. Calibration operator
     calibration_field: jnp.ndarray
         Optional. Calibration field
     """
@@ -93,67 +97,34 @@ class ModelImagingLikelihoodFixedCovarianceCalibrationField(jft.Model):
             self, 
             R: Callable, 
             sky_operator: jft.Model, 
-            mask: jnp.ndarray,
+            mask: jnp.ndarray, 
+            calibration_operator: CalibrationDistribution = None,
             calibration_field: jnp.ndarray = None
             ):
-
-        self._R = R
-        self._sky = sky_operator
+        
         self._mask = mask
+        inits = sky_operator.init 
 
-        if calibration_field is None:
-            self._cal_fld = jnp.ones(mask.shape)
-        else: 
-            self._cal_fld = calibration_field
+        if(calibration_operator is not None):
+            self._data_model = lambda x: calibration_operator(x)*R(sky_operator(x))
+            inits = inits | calibration_operator.init     
+        elif(calibration_field is not None):
+            self._data_model = lambda x: calibration_field*R(sky_operator(x))
+        else:
+            self._data_model = lambda x: R(sky_operator(x))
 
-        super().__init__(init=self._sky.init)
+        super().__init__(init=inits)
 
     def __call__(self,x):
-        data_model = self._R(self._sky(x))
+        data_model = self._data_model(x)
         flagged_data_model = data_model[self._mask]
 
         return flagged_data_model
-    
-class ModelImagingLikelihoodFixedCovarianceCalibrationOperator(jft.Model):
-    """
-    Provides a flagged data model for imaging given a calibration operator
 
-    Parameters
-    ----------
-    R: Callable
-        Response operator function
-    sky: jft.Model
-        Model for sky
-    mask: jnp.array
-        Mask as boolean numpy array for good visibilites
-    calibration_operator: CalibrationDistribution
-        Optional. Calibration operator
-    """
-    def __init__(
-            self, 
-            R: Callable, 
-            sky_operator: jft.Model, 
-            mask: jnp.ndarray, 
-            calibration_operator: CalibrationDistribution
-            ):
-        self._R = R
-        self._sky = sky_operator
-        self._mask = mask
-
-        self._cal_op = calibration_operator
-
-        super().__init__(init=self._sky.init | self._cal_op.init)
-
-    def __call__(self,x):
-        data_model = self._cal_op(x)*self._R(self._sky(x))
-        flagged_data_model = data_model[self._mask]
-
-        return flagged_data_model
-        
-class ModelImagingLikelihoodVariableCovarianceCalibrationField(jft.Model):
+class ModelImagingLikelihoodVariableCovariance(jft.Model):
     """
     Provides a combined flagged data model and flagged inverse covariance model for imaging 
-    given an optional calibration field
+    given an optional calibration operator or calibration field
 
     Parameters
     ----------
@@ -165,6 +136,8 @@ class ModelImagingLikelihoodVariableCovarianceCalibrationField(jft.Model):
         Model for log inverse covariance
     mask: jnp.array
         Mask as boolean numpy array for good visibilites
+    calibration_operator: CalibrationDistribution
+        Optional. Calibration operator
     calibration_field: jnp.ndarray
         Optional. Calibration field
     """
@@ -173,70 +146,30 @@ class ModelImagingLikelihoodVariableCovarianceCalibrationField(jft.Model):
             R: Callable, 
             sky_operator: jft.Model, 
             log_inverse_covariance_model: jft.Model, 
-            mask: jnp.ndarray,
+            mask: jnp.ndarray, 
+            calibration_operator: CalibrationDistribution = None,
             calibration_field: jnp.ndarray = None
             ):
-        self._R = R
-        self._sky = sky_operator
-        self._mask = mask
-        self._log_inv_cov = log_inverse_covariance_model
-        
-        if calibration_field is None:
-            self._cal_fld = jnp.ones(mask.shape)
-        else: 
-            self._cal_fld = calibration_field
-
-        super().__init__(init=self._sky.init | self._log_inv_cov.init)
-
-    def __call__(self,x):
-        data_model = self._cal_fld*self._R(self._sky(x))
-        flagged_data_model = data_model[self._mask]
-
-        inv_std = jnp.exp(0.5*self._log_inv_cov(x))
-        flagged_inv_std = inv_std[self._mask]
         
-        return (flagged_data_model,flagged_inv_std)
-
-class ModelImagingLikelihoodVariableCovarianceCalibrationOperator(jft.Model):
-    """
-    Provides a combined flagged data model and flagged inverse covariance model for imaging
-    given a calibration operator
-
-    Parameters
-    ----------
-    R: Callable
-        Response operator function
-    sky: jft.Model
-        Model for sky
-    log_inverse_covariance_model: jft.Model
-        Model for log inverse covariance
-    mask: jnp.array
-        Mask as boolean numpy array for good visibilites
-    calibration_operator: CalibrationDistribution
-        Optional. Calibration operator
-    """
-    def __init__(
-            self, 
-            R: Callable, 
-            sky_operator: jft.Model, 
-            log_inverse_covariance_model: jft.Model, 
-            mask: jnp.ndarray, 
-            calibration_operator: CalibrationDistribution
-            ):
-        self._R = R
-        self._sky = sky_operator
         self._mask = mask
         self._log_inv_cov = log_inverse_covariance_model
-        
-        self._cal_op = calibration_operator
+        inits = sky_operator.init | self._log_inv_cov.init
 
-        super().__init__(init=self._sky.init | self._log_inv_cov.init | self._cal_op.init)
+        if(calibration_operator is not None):
+            self._data_model = lambda x: calibration_operator(x)*R(sky_operator(x))
+            inits = inits | calibration_operator.init
+        elif(calibration_field is not None):
+            self._data_model = lambda x: calibration_field*R(sky_operator(x))
+        else:
+            self._data_model = lambda x: R(sky_operator(x))
+
+        super().__init__(init=inits)
 
     def __call__(self,x):
-        data_model = self._cal_op(x)*self._R(self._sky(x))
+        data_model = self._data_model(x)
         flagged_data_model = data_model[self._mask]
 
         inv_std = jnp.exp(0.5*self._log_inv_cov(x))
         flagged_inv_std = inv_std[self._mask]
         
-        return (flagged_data_model,flagged_inv_std)
\ No newline at end of file
+        return (flagged_data_model,flagged_inv_std)
-- 
GitLab


From 9e6b97fa344b6d6f1636e7bc8ea581b95691abaf Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Thu, 3 Apr 2025 14:35:00 +0200
Subject: [PATCH 70/88] Removed cases in accordance with used likelihood models

---
 resolve/re/likelihood.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index 9936fd61..f0c77ddc 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -202,20 +202,14 @@ def ImagingLikelihood(
         flagged_data = jnp.asarray(oo.vis.val)[mask]
 
         if log_inv_cov is None:
-            if cop is None:
-                model = ModelImagingLikelihoodFixedCovarianceCalibrationField(R,sky_operator,mask,cfld)
-            else:
-                model = ModelImagingLikelihoodFixedCovarianceCalibrationOperator(R,sky_operator,mask,cop)
-            
+            model = ModelImagingLikelihoodFixedCovariance(R,sky_operator,mask,cop,cfld)
+
             flagged_inv_cov = jnp.asarray(oo.weight.val)[mask]
         
             lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
 
         else:
-            if cop is None:
-                model = ModelImagingLikelihoodVariableCovarianceCalibrationField(R,sky_operator,log_inv_cov,mask,cfld)
-            else:
-                model = ModelImagingLikelihoodVariableCovarianceCalibrationOperator(R,sky_operator,log_inv_cov,mask,cop)
+            model = ModelImagingLikelihoodVariableCovariance(R,sky_operator,log_inv_cov,mask,cop,cfld)
 
             lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=jnp.iscomplexobj(oo.vis.val))
 
-- 
GitLab


From 5de011b056f9bfa782d190335cc635c247f97ec2 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 8 Apr 2025 18:16:05 +0200
Subject: [PATCH 71/88] Added missing zero to n_psf_pix in line 143

---
 resolve/re/radio_response.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/resolve/re/radio_response.py b/resolve/re/radio_response.py
index 7ec7728a..ccaf838a 100644
--- a/resolve/re/radio_response.py
+++ b/resolve/re/radio_response.py
@@ -140,7 +140,7 @@ def build_approximations(RNR, RNR_l, noise_scaling=None, varcov=False, cache_noi
             psf_kernel = compute_PSF(RNR_l, n_psf_pix0, n_psf_pix1)
             pickle.dump(psf_kernel, open(f"{cache_response_kernel}.p", "wb"))
     else:
-            psf_kernel = compute_PSF(RNR_l, n_psf_pix, n_psf_pix1)
+            psf_kernel = compute_PSF(RNR_l, n_psf_pix0, n_psf_pix1)
     psf_kernel = jnp.array(psf_kernel)
     apply_psf_kern = lambda x: fft_jax_inv(psf_kernel * fft_jax(x)).real
 
-- 
GitLab


From 5103f4dde25e9c4d584b3d97d3f8c9fe4d26a1d5 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Sat, 12 Apr 2025 15:33:20 +0200
Subject: [PATCH 72/88] Added model for CalibrationDistribution scaled by a
 inferable scalar operator

---
 resolve/re/__init__.py |  2 +-
 resolve/re/sugar.py    | 29 +++++++++++++++++++++++++++--
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py
index 296b247c..1011168e 100644
--- a/resolve/re/__init__.py
+++ b/resolve/re/__init__.py
@@ -5,4 +5,4 @@ from .radio_response import build_exact_r, build_approximations
 from .optimize import optimize
 from .calibration import CalibrationInterpolator, CalibrationDistribution
 from .likelihood import CalibrationLikelihood, ImagingLikelihood
-from .sugar import Bulk_CF_AntennaTimeDomain
\ No newline at end of file
+from .sugar import Bulk_CF_AntennaTimeDomain, ScaledCalibrationDistribution
\ No newline at end of file
diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index 9f0ee4b0..b8184290 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -1,7 +1,7 @@
 import jax.numpy as jnp
 import nifty8.re as jft
 
-from jax.lax import dynamic_slice_in_dim
+from .calibration import CalibrationDistribution
 
 from typing import Union, Iterable
 
@@ -56,4 +56,29 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
         return jnp.swapaxes(jnp.reshape(self._fields(x),self._output_shape),2,3)
     
     def get_powerspectrum(self):
-        return self._powerspectrum
\ No newline at end of file
+        return self._powerspectrum
+    
+class ScaledCalibrationDistribution(jft.Model):
+    """
+    Multiplies a calibration distribution with a scalar operator to correct flux of 
+    model visibilities used in calibration.
+
+    Parameters
+    ----------
+    calibration_opertor: CalibrationDistribution
+        Calibration distribution which should be scaled.
+    scaling_operator: jft.Model    
+        Scalar scaling operator
+    """
+    def __init__(
+            self, 
+            calibration_operator: CalibrationDistribution, 
+            scaling_operator: jft.Model
+            ):
+        self._cop = calibration_operator
+        self._scaling = scaling_operator
+
+        super().__init__(init=self._cop | self._scaling)
+
+    def __call__(self,x):
+        return self._scaling(x)*self._cop(x)
\ No newline at end of file
-- 
GitLab


From d729f23a4a332cd76c980c3b64f877337d26a4d2 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Sat, 12 Apr 2025 16:00:20 +0200
Subject: [PATCH 73/88] Remove circular import dependency

---
 resolve/re/calibration.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
index 290f3fa4..f7b3cb65 100644
--- a/resolve/re/calibration.py
+++ b/resolve/re/calibration.py
@@ -5,8 +5,6 @@ import jax.numpy as jnp
 from jax.tree_util import Partial
 from jax import vmap
 
-from .sugar import Bulk_CF_AntennaTimeDomain
-
 from ..data.observation import Observation
 
 class CalibrationDistribution(jft.Model):
@@ -18,9 +16,9 @@ class CalibrationDistribution(jft.Model):
     observation: Observation
         Observation object from which are the antenna and temporal information corresponding to 
         the visibilites are extracted.
-    phase_fields: Bulk_CF_AntennaTimeDomain
+    phase_fields: jft.Model or Bulk_CF_AntennaTimeDomain (preferred)
         Correlated fields on antenna-time space for phases of calibration solutions.
-    log_amplitude_fields: Bulk_CF_AntennaTimeDomain
+    log_amplitude_fields: jft.Model or Bulk_CF_AntennaTimeDomain (preferred)
         Correlated fields on antenna-time space for log amplitude of calibration solutions.
     dt: float
         Distances between time points on time axis. Has to be the same distance of time points,
@@ -33,8 +31,8 @@ class CalibrationDistribution(jft.Model):
     def __init__(
             self,
             observation: Observation, 
-            phase_fields: Bulk_CF_AntennaTimeDomain, 
-            log_amplitude_fields: Bulk_CF_AntennaTimeDomain,
+            phase_fields: jft.Model, 
+            log_amplitude_fields: jft.Model,
             dt: float
             ):
         ap = observation.antenna_positions
-- 
GitLab


From 968a0f296d7203d99f5485e3d7b54089aa5487f7 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Sat, 12 Apr 2025 16:33:12 +0200
Subject: [PATCH 74/88] Fixed init call of ScaledCalibrationOperator

---
 resolve/re/sugar.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index b8184290..d7a219d0 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -78,7 +78,7 @@ class ScaledCalibrationDistribution(jft.Model):
         self._cop = calibration_operator
         self._scaling = scaling_operator
 
-        super().__init__(init=self._cop | self._scaling)
+        super().__init__(init=self._cop.init | self._scaling.init)
 
     def __call__(self,x):
         return self._scaling(x)*self._cop(x)
\ No newline at end of file
-- 
GitLab


From 3128eff6952ebba787e2ec2a523a48b061d5d25e Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 14 Apr 2025 16:21:14 +0200
Subject: [PATCH 75/88] Changed semantics of input arguments for likelihoods

---
 resolve/re/likelihood.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index f0c77ddc..27d9fe57 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -25,7 +25,7 @@ def CalibrationLikelihood(
     calibration_operator: Union[CalibrationDistribution, Iterable[CalibrationDistribution]],
     model_visibilities: Union[jnp.ndarray, Iterable[jnp.ndarray]],
     log_inverse_covariance_operator: Union[jft.Model,Iterable[jft.Model]] = None,
-    likelihood_labels: Union[str,Iterable[str]] = None
+    likelihood_label: Union[str,Iterable[str]] = None
 ):
     """Versatile calibration likelihood class
 
@@ -58,8 +58,8 @@ def CalibrationLikelihood(
         not specified, observation.wgt is taken as covariance. If used, the same
         amount of elements as number of observations should be provided.
 
-    likelihood_labels: string or Iterable of string
-        Optional. Append labels to individual likelihoods which are shown in the minisanity
+    likelihood_label: string or Iterable of string
+        Optional. Append label to individual likelihood which is shown in the minisanity
         for overview. If used, the same amount of elements as number of observations 
         should be provided.
     """
@@ -69,7 +69,7 @@ def CalibrationLikelihood(
     cops = _obj2list(calibration_operator,CalibrationDistribution)
     model_d = _obj2list(model_visibilities,ArrayImpl)
     log_inv_covs = _duplicate(_obj2list(log_inverse_covariance_operator,Model),len(obs))
-    labels = _duplicate(_obj2list(likelihood_labels,str),len(obs))
+    labels = _duplicate(_obj2list(likelihood_label,str),len(obs))
 
     if len(set([len(obs),len(cops),len(model_d)])) != 1:
         raise ValueError("observation, calibration_operator and model_visibilities must have the same number of elements")
@@ -115,7 +115,7 @@ def ImagingLikelihood(
     log_inverse_covariance_operator: Union[jft.Model,Iterable[jft.Model]] = None,
     calibration_operator: Union[CalibrationDistribution, Iterable[CalibrationDistribution]] = None,
     calibration_field: Union[jnp.ndarray, Iterable[jnp.ndarray]] = None,
-    likelihood_labels: Union[str,Iterable[str]] = None,
+    likelihood_label: Union[str,Iterable[str]] = None,
     verbosity: int = 0,
     nthreads: int = 1,
     backend: str = "ducc0",
@@ -165,7 +165,7 @@ def ImagingLikelihood(
         Optional. Domain needs to be the same as observation.vis. If used, the same
         amount of elements as number of observations should be provided.
     
-    likelihood_labels: string or Iterable of string
+    likelihood_label: string or Iterable of string
         Optional. Append labels to individual likelihoods which are shown in the minisanity
         for overview. If used, the same amount of elements as number of observations 
         should be provided.
@@ -186,7 +186,7 @@ def ImagingLikelihood(
     cops = _duplicate(_obj2list(calibration_operator,CalibrationDistribution),len(obs))
     cflds = _duplicate(_obj2list(calibration_field,ArrayImpl),len(obs))
     log_inv_covs = _duplicate(_obj2list(log_inverse_covariance_operator,Model),len(obs))
-    labels = _duplicate(_obj2list(likelihood_labels,str),len(obs))
+    labels = _duplicate(_obj2list(likelihood_label,str),len(obs))
 
     lhs = []
     
-- 
GitLab


From 945ca7bba61d4ed5cbdb5f0055052e78e631dc35 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 16 Apr 2025 12:22:14 +0200
Subject: [PATCH 76/88] Added dataclass CalibrationAssembler

---
 resolve/re/sugar.py | 124 ++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 121 insertions(+), 3 deletions(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index d7a219d0..e5fc96f7 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -3,7 +3,10 @@ import nifty8.re as jft
 
 from .calibration import CalibrationDistribution
 
-from typing import Union, Iterable
+from ..data.observation import Observation
+
+from typing import Union, Iterable, Dict, Tuple, Optional
+from dataclasses import dataclass, field
 
 class Bulk_CF_AntennaTimeDomain(jft.Model):
     """
@@ -38,10 +41,11 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
         self._pol = list(polarizations)
         self._ant = list(antennas)
         self._freq = list(frequencies)
+        self._prefix = str(prefix)
 
         self._output_shape = (len(self._pol),len(self._ant),len(self._freq),dct_ps["shape"][0])
         
-        cfm = jft.CorrelatedFieldMaker(prefix)
+        cfm = jft.CorrelatedFieldMaker(self._prefix + "_")
         cfm.set_amplitude_total_offset(**dct_offset)
         cfm.add_fluctuations(**dct_ps)
 
@@ -55,6 +59,10 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
     def __call__(self,x):
         return jnp.swapaxes(jnp.reshape(self._fields(x),self._output_shape),2,3)
     
+    @property
+    def name(self):
+        return self._prefix
+    
     def get_powerspectrum(self):
         return self._powerspectrum
     
@@ -81,4 +89,114 @@ class ScaledCalibrationDistribution(jft.Model):
         super().__init__(init=self._cop.init | self._scaling.init)
 
     def __call__(self,x):
-        return self._scaling(x)*self._cop(x)
\ No newline at end of file
+        return self._scaling(x)*self._cop(x)
+
+@dataclass    
+class CalibrationAssembler:
+    """
+    Collect and automatically creates quantities necessary for reconstruction of the calibration solution
+    of radio-interferometric data.
+
+    Parameters
+    ----------
+    obs: Observation
+        Observation object providing multiple quantities for construction of claibration
+        operator and likelihood.
+    phase_field: jft.Model or Bulk_CF_AntennaTimeDomain (preferred)
+        Correlated fields on antenna-time space for phases of calibration solutions.
+    logflux_field: jft.Model or Bulk_CF_AntennaTimeDomain (preferred)
+        Correlated fields on antenna-time space for log amplitude of calibration solutions.
+    dt: float
+        Distances between time points on time axis. Has to be the same distance of time points,
+        which is used for phase_fields and log_amplitude fields.
+    model_vis: jnp.ndarray
+        Optional. Assumed visibilities of the the source.
+    scaling_op: jft.Model
+        Optional. Model for scalar scaling of model_vis.
+    log_inv_cov: jft.Model
+        Optional. Model for log inverse covariance.
+    lh_label: str
+        Optional. Label for likelihood. If not set the default is the name of 
+        the observation.
+    """
+    obs: Observation
+    phase_field: jft.Model
+    logflux_field: jft.Model
+    dt: float
+    model_vis: Optional[jnp.ndarray] = None
+    scaling_op: Optional[jft.WrappedCall] = None
+    log_inv_cov: Optional[jft.Model] = None
+    lh_label: Optional[str] = None
+    cop: CalibrationDistribution = field(init=False)
+    scaled_cop: Optional[ScaledCalibrationDistribution] = field(init=False)
+
+    def __post_init__(self):
+        self.lh_label = self.obs.source_name if self.lh_label is None else self.lh_label
+        self.cop = CalibrationDistribution(self.obs,self.phase_field,self.logflux_field,self.dt)
+        self.scaled_cop = None if self.scaling_op is None else ScaledCalibrationDistribution(self.cop,self.scaling_op)
+
+    @classmethod
+    def make(
+        cls,
+        observation: Observation, 
+        phase_field: jft.Model, 
+        logflux_field: jft.Model, 
+        dt: float, 
+        model_vis_flux_amplitude: Optional[complex] = None, 
+        scaling_parameters: Optional[Tuple[str,Dict]] = None, 
+        log_inv_cov: Optional[jft.Model] = None, 
+        lh_label: Optional[str] = None
+        ):
+            """
+            Constructs CalibrationAssembler with constant field for model_vis and an optional scaling_op.
+
+            obs: Observation
+                Observation object providing multiple quantities for construction of claibration
+                operator and likelihood.
+            phase_field: jft.Model or Bulk_CF_AntennaTimeDomain (preferred)
+                Correlated fields on antenna-time space for phases of calibration solutions.
+            logflux_field: jft.Model or Bulk_CF_AntennaTimeDomain (preferred)
+                Correlated fields on antenna-time space for log amplitude of calibration solutions.
+            dt: float
+                Distances between time points on time axis. Has to be the same distance of time points,
+                which is used for phase_fields and log_amplitude fields.
+            model_vis_flux_amplitude: complex
+                Optional. Assumed magnitude of field for model_vis.
+            scaling_parameters: Tuple[str,Dict]
+                Optional. Selects prior model for scaling operator through string and dictionary contains
+                parameters of corresponding model. Currently only available for the following prior models 
+                (see NIFTy.re package for arguments):
+                    - "inverse_gamma"
+                    - "log_normal"
+            log_inv_cov: jft.Model
+                Optional. Model for log inverse covariance.
+            lh_label: str
+                Optional. Label for likelihood. If not set the default is the name of 
+                the observation.
+            """
+            init_flux_field = None if model_vis_flux_amplitude is None else model_vis_flux_amplitude*jnp.ones(observation.vis.shape)
+            if scaling_parameters is not None:
+                match scaling_parameters[0]:
+                    case "inv_gamma":
+                        scaling_op = jft.InvGammaPrior(**scaling_parameters[1])
+                    case "log_normal":
+                        scaling_op = jft.LogNormalPrior(**scaling_parameters[1])
+                    case _:
+                        raise NotImplementedError("Constructor currently only supports 'log_normal' and 'inv_gamma' for scaling operator construction")
+            else:
+                scaling_op = None
+
+            return cls(observation,phase_field,logflux_field,dt,init_flux_field,scaling_op,log_inv_cov,lh_label)
+        
+
+    def __repr__(self):
+        return(
+            f"Obervation: {self.obs.source_name}\n"
+            f"Phase field: {self.phase_field.name}\n"
+            f"Logflux field: {self.logflux_field.name}\n"
+            f"dt: {self.dt}\n"
+            f"Model visibilities field set: {False if self.model_vis is None else True}\n"
+            f"Scaling operator set: {False if self.scaling_op is None else True}\n"
+            f"Log inverse covariance set: {False if self.log_inv_cov is None else True}\n"
+            f"Likelihood label: {self.lh_label}\n"
+        )
\ No newline at end of file
-- 
GitLab


From 2c619640f6168bd8fbbc06f929e3a38e916693a3 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 16 Apr 2025 17:58:40 +0200
Subject: [PATCH 77/88] Expanded build_exact_r to include calibration solution
 and InterferometryResponse parameters

---
 resolve/re/radio_response.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/resolve/re/radio_response.py b/resolve/re/radio_response.py
index ccaf838a..5f3478b4 100644
--- a/resolve/re/radio_response.py
+++ b/resolve/re/radio_response.py
@@ -30,10 +30,10 @@ def get_jax_fft(domain, target, inverse):
 
     return partial(fft, fct=fct, func=func)
 
-def build_exact_r(obs, conf_sky, conf_setup):
+def build_exact_r(obs, conf_sky, conf_setup, calibration_field=None, do_wgridding=True, epsilon=1e-9, verbosity=1, nthreads=8):
     sp_sky_dom =rve.sky_model._spatial_dom(conf_sky)
     sky_dom = rve.default_sky_domain(sdom=sp_sky_dom)
-    R = rve.InterferometryResponse(obs, sky_dom, True, 1e-9, verbosity=1, nthreads=8)
+    R = rve.InterferometryResponse(obs, sky_dom, do_wgridding, epsilon, verbosity, nthreads)
 
     psf_pixels = conf_setup.getfloat("psf pixels")
     full_psf0 = min(2*psf_pixels, sp_sky_dom.shape[0])
@@ -41,13 +41,19 @@ def build_exact_r(obs, conf_sky, conf_setup):
     sp_sky_dom_l = (sp_sky_dom.shape[0] + full_psf0, sp_sky_dom.shape[1] + full_psf1)
     sp_sky_dom_l = ift.RGSpace(sp_sky_dom_l, distances=sp_sky_dom.distances)
     sky_dom_l = rve.default_sky_domain(sdom=sp_sky_dom_l)
-    R_l = rve.InterferometryResponse(obs, sky_dom_l, True, 1e-9, verbosity=1, nthreads=8)
+    R_l = rve.InterferometryResponse(obs, sky_dom_l, do_wgridding, epsilon, verbosity, nthreads)
 
     dch_l = ift.DomainChangerAndReshaper(R_l.domain[3], R_l.domain)
     R_l = R_l @ dch_l
     dch = ift.DomainChangerAndReshaper(R.domain[3], R.domain)
     R = R @ dch
 
+    if calibration_field is not None:
+        C = ift.makeOp(calibration_field)
+
+        R = C @ R
+        R_l = C @ R_l
+
     N_inv = ift.DiagonalOperator(obs.weight)
     RNR = R.adjoint @ N_inv @ R
     RNR_l = R_l.adjoint @ N_inv @ R_l
-- 
GitLab


From 86e722fa610227af2af5a9406a0c5b15455294df Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 29 Apr 2025 12:14:43 +0200
Subject: [PATCH 78/88] Removed ScaledCop and integrate it into
 CalibrationAssembler, Added SkyAssembler

---
 resolve/re/sugar.py | 68 +++++++++++++++++++++++++++------------------
 1 file changed, 41 insertions(+), 27 deletions(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index e5fc96f7..bd2d59e0 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -65,31 +65,6 @@ class Bulk_CF_AntennaTimeDomain(jft.Model):
     
     def get_powerspectrum(self):
         return self._powerspectrum
-    
-class ScaledCalibrationDistribution(jft.Model):
-    """
-    Multiplies a calibration distribution with a scalar operator to correct flux of 
-    model visibilities used in calibration.
-
-    Parameters
-    ----------
-    calibration_opertor: CalibrationDistribution
-        Calibration distribution which should be scaled.
-    scaling_operator: jft.Model    
-        Scalar scaling operator
-    """
-    def __init__(
-            self, 
-            calibration_operator: CalibrationDistribution, 
-            scaling_operator: jft.Model
-            ):
-        self._cop = calibration_operator
-        self._scaling = scaling_operator
-
-        super().__init__(init=self._cop.init | self._scaling.init)
-
-    def __call__(self,x):
-        return self._scaling(x)*self._cop(x)
 
 @dataclass    
 class CalibrationAssembler:
@@ -128,12 +103,16 @@ class CalibrationAssembler:
     log_inv_cov: Optional[jft.Model] = None
     lh_label: Optional[str] = None
     cop: CalibrationDistribution = field(init=False)
-    scaled_cop: Optional[ScaledCalibrationDistribution] = field(init=False)
+    scaled_cop: Optional[jft.Model] = field(init=False)
 
     def __post_init__(self):
         self.lh_label = self.obs.source_name if self.lh_label is None else self.lh_label
         self.cop = CalibrationDistribution(self.obs,self.phase_field,self.logflux_field,self.dt)
-        self.scaled_cop = None if self.scaling_op is None else ScaledCalibrationDistribution(self.cop,self.scaling_op)
+        self.scaled_cop = None if self.scaling_op is None else jft.Model(
+            call = lambda x: self.scaling_op(x)*self.cop(x),
+            domain = {**self.cop.domain,**self.scaling_op.domain},
+            init =  self.cop.init | self.scaling_op.init
+        )
 
     @classmethod
     def make(
@@ -150,6 +129,8 @@ class CalibrationAssembler:
             """
             Constructs CalibrationAssembler with constant field for model_vis and an optional scaling_op.
 
+            Parameters
+            ----------
             obs: Observation
                 Observation object providing multiple quantities for construction of claibration
                 operator and likelihood.
@@ -199,4 +180,37 @@ class CalibrationAssembler:
             f"Scaling operator set: {False if self.scaling_op is None else True}\n"
             f"Log inverse covariance set: {False if self.log_inv_cov is None else True}\n"
             f"Likelihood label: {self.lh_label}\n"
+        )
+
+@dataclass
+class SkyAssembler:
+    """
+    Data class for storage of quantities for the sky reconstruction with fast resolve
+
+    Parameters
+    ----------
+    obs: Observation
+        Observation object providing multiple quantities for construction of claibration
+        operator and likelihood.
+    sky: jft.Model
+        Sky model for reconstruction
+    noise_scaling: jft.Model
+        Optional. Model for noise scaling.
+    varcov: jft.Model
+        Optional. Model for variable covariance
+    """
+    obs: Observation
+    sky: jft.Model
+    noise_scaling: Optional[jft.Model] = None
+    varcov: Optional[jft.Model] = None
+
+    def __post_init__(self):
+        if (self.noise_scaling is not None) and (self.varcov is not None):
+            raise ValueError("Either noise_scaling or varcov can be set.")
+        
+    def __repr__(self):
+        return(
+            f"Obervation: {self.obs.source_name}\n"
+            f"Noise scaling model set: {False if self.noise_scaling is None else True}\n"
+            f"Variable covariance model set: {False if self.varcov is None else True}\n"
         )
\ No newline at end of file
-- 
GitLab


From b9e592b5d061c2920f915f7469e3dbcf9296f0a1 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 29 Apr 2025 12:15:27 +0200
Subject: [PATCH 79/88] Added tests for jax ports

---
 test/test_jax_ports/test_01_BulkCF.py         |  94 +++++++
 .../test_02_CalibrationInterpolator.py        |  64 +++++
 .../test_03_CalibrationDistributor.py         | 104 ++++++++
 .../test_04_CalibrationLikelihood.py          | 151 +++++++++++
 .../test_05_ImagingLikelihood.py              | 240 ++++++++++++++++++
 test/test_jax_ports/to_re_helpers.py          |  63 +++++
 6 files changed, 716 insertions(+)
 create mode 100644 test/test_jax_ports/test_01_BulkCF.py
 create mode 100644 test/test_jax_ports/test_02_CalibrationInterpolator.py
 create mode 100644 test/test_jax_ports/test_03_CalibrationDistributor.py
 create mode 100644 test/test_jax_ports/test_04_CalibrationLikelihood.py
 create mode 100644 test/test_jax_ports/test_05_ImagingLikelihood.py
 create mode 100644 test/test_jax_ports/to_re_helpers.py

diff --git a/test/test_jax_ports/test_01_BulkCF.py b/test/test_jax_ports/test_01_BulkCF.py
new file mode 100644
index 00000000..aac1716f
--- /dev/null
+++ b/test/test_jax_ports/test_01_BulkCF.py
@@ -0,0 +1,94 @@
+import os
+os.environ["JAX_PLATFORM_NAME"] = "cpu"
+
+import numpy as np
+import nifty8 as ift
+import ducc0
+import resolve as rve
+import resolve.re as jrve
+
+from numpy.testing import assert_allclose
+
+import jax
+jax.config.update("jax_enable_x64", True)
+
+np.random.seed(42)
+
+from observation_generator import RandomObservation
+from to_re_helpers import rve_cf_dict_to_jrve_cf_dict
+
+def prepare_operators():
+    param = {
+    "n_baselines": 10000,
+    "pol_indices": (8,5),
+    "freq_channels": np.array([1.5e9]),
+    "uvw_max": 200,
+    "n_antenna_max": 10,
+    "time_max": 5000,
+    "abs_vis_max": 0.5,
+    "weight_min":0,
+    "weight_max": 10000,
+    }
+
+    obs = RandomObservation(**param)
+
+    uantennas = rve.unique_antennas(obs)
+    antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)}
+    total_N = len(param["pol_indices"]) * len(uantennas)
+
+    dt = 20  # s
+    zero_padding_factor = 2
+    tmin, tmax = rve.tmin_tmax(obs)
+    time_domain = ift.RGSpace(
+        ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt
+    )
+
+    polarizations = obs.polarization.to_str_list()
+    frequencies = [1]
+    antennas = uantennas
+
+    dofdex = np.arange(total_N)
+    rve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5), "dofdex": dofdex}
+    rve_dct_ps = {
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "dofdex": dofdex
+    }
+
+    jrve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5),}
+    jrve_dct_ps = {
+        "shape": time_domain.shape,
+        "distances": time_domain.distances[0],
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "non_parametric_kind": "power"
+    }
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='bulk_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_bulkCF = cfmph.finalize(0)
+
+    pdom, _, fdom = obs.vis.domain
+    reshape = ift.DomainChangerAndReshaper(rve_bulkCF.target, [pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom])
+
+    rve_op = reshape @ rve_bulkCF
+
+    jrve_op = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"bulk",polarizations,frequencies,antennas)
+
+    return rve_op, jrve_op
+
+def test_consistency_BulkCF():
+    rve_op, jrve_op = prepare_operators()
+
+    rve_inp = ift.from_random(rve_op.domain)
+    jrve_inp = rve_cf_dict_to_jrve_cf_dict(rve_inp.val)
+
+    rve_field = rve_op(rve_inp).val
+    jrve_field = jrve_op(jrve_inp)
+
+    assert_allclose(jrve_field,rve_field)
diff --git a/test/test_jax_ports/test_02_CalibrationInterpolator.py b/test/test_jax_ports/test_02_CalibrationInterpolator.py
new file mode 100644
index 00000000..63fc7bee
--- /dev/null
+++ b/test/test_jax_ports/test_02_CalibrationInterpolator.py
@@ -0,0 +1,64 @@
+import os
+os.environ["JAX_PLATFORM_NAME"] = "cpu"
+
+import numpy as np
+import jax.numpy as jnp
+import nifty8 as ift
+import ducc0
+import resolve as rve
+import resolve.re as jrve
+
+from numpy.testing import assert_allclose
+
+import jax
+jax.config.update("jax_enable_x64", True)
+
+np.random.seed(42)
+
+from observation_generator import RandomObservation
+from to_re_helpers import rve_cf_dict_to_jrve_cf_dict
+
+def prepare_operators():
+    param = {
+    "n_baselines": 10000,
+    "pol_indices": (8,5),
+    "freq_channels": np.array([1.5e9]),
+    "uvw_max": 200,
+    "n_antenna_max": 10,
+    "time_max": 5000,
+    "abs_vis_max": 0.5,
+    "weight_min":0,
+    "weight_max": 10000,
+    }
+
+    obs = RandomObservation(**param)
+
+    uantennas = rve.unique_antennas(obs)
+    antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)}
+    total_N = len(param["pol_indices"]) * len(uantennas)
+
+    dt = 20  # s
+    zero_padding_factor = 2
+    tmin, tmax = rve.tmin_tmax(obs)
+    time_domain = ift.RGSpace(
+        ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt
+    )
+
+    pdom, _, fdom = obs.vis.domain
+    dom = ift.DomainTuple.make([pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom])
+    
+    jrve_op = jrve.CalibrationInterpolator(obs.ant1,obs.time,time_domain.distances[0],obs.vis.val.shape)
+
+    rve_op = rve.CalibrationDistributor(dom,obs.vis.domain,obs.ant1,obs.time,antenna_dct,None)
+
+    return rve_op, jrve_op
+
+def test_consistency_CalibrationInterpolator():
+    rve_op, jrve_op = prepare_operators()
+
+    inp = ift.from_random(rve_op.domain)
+
+    rve_field = rve_op(inp).val
+    jrve_field = jrve_op(jnp.asarray(inp.val))
+
+    assert_allclose(jrve_field,rve_field)
\ No newline at end of file
diff --git a/test/test_jax_ports/test_03_CalibrationDistributor.py b/test/test_jax_ports/test_03_CalibrationDistributor.py
new file mode 100644
index 00000000..b755e699
--- /dev/null
+++ b/test/test_jax_ports/test_03_CalibrationDistributor.py
@@ -0,0 +1,104 @@
+import os
+os.environ["JAX_PLATFORM_NAME"] = "cpu"
+
+import numpy as np
+import nifty8 as ift
+import ducc0
+import resolve as rve
+import resolve.re as jrve
+
+from numpy.testing import assert_allclose
+
+import jax
+jax.config.update("jax_enable_x64", True)
+
+np.random.seed(42)
+
+from observation_generator import RandomObservation
+from to_re_helpers import rve_cf_dict_to_jrve_cf_dict
+
+def prepare_operators():
+    param = {
+    "n_baselines": 10000,
+    "pol_indices": (8,5),
+    "freq_channels": np.array([1.5e9]),
+    "uvw_max": 200,
+    "n_antenna_max": 10,
+    "time_max": 5000,
+    "abs_vis_max": 0.5,
+    "weight_min":0,
+    "weight_max": 10000,
+    }
+
+    obs = RandomObservation(**param)
+
+    uantennas = rve.unique_antennas(obs)
+    antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)}
+    total_N = len(param["pol_indices"]) * len(uantennas)
+
+    dt = 20  # s
+    zero_padding_factor = 2
+    tmin, tmax = rve.tmin_tmax(obs)
+    time_domain = ift.RGSpace(
+        ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt
+    )
+
+    polarizations = obs.polarization.to_str_list()
+    frequencies = [1]
+    antennas = uantennas
+
+    dofdex = np.arange(total_N)
+    rve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5), "dofdex": dofdex}
+    rve_dct_ps = {
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "dofdex": dofdex
+    }
+
+    jrve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5),}
+    jrve_dct_ps = {
+        "shape": time_domain.shape,
+        "distances": time_domain.distances[0],
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "non_parametric_kind": "power"
+    }
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='phase_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_phase = cfmph.finalize(0)
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='logamp_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_logamp = cfmph.finalize(0)
+
+    pdom, _, fdom = obs.vis.domain
+    reshape = ift.DomainChangerAndReshaper(rve_phase.target, [pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom])
+
+    rve_phase = reshape @ rve_phase
+    rve_logamp = reshape @ rve_logamp
+
+    jrve_phase = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"phase",polarizations,frequencies,antennas)
+    jrve_logamp = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"logamp",polarizations,frequencies,antennas)
+
+    rve_op = rve.calibration_distribution(obs,rve_phase,rve_logamp,antenna_dct,None)
+    jrve_op = jrve.CalibrationDistribution(obs,jrve_phase,jrve_logamp,time_domain.distances[0])
+
+    return rve_op, jrve_op
+
+def test_consistency_CalibrationDistributor():
+    rve_op, jrve_op = prepare_operators()
+
+    rve_inp = ift.from_random(rve_op.domain)
+    jrve_inp = rve_cf_dict_to_jrve_cf_dict(rve_inp.val)
+
+    rve_field = rve_op(rve_inp).val
+    jrve_field = jrve_op(jrve_inp)
+
+    assert_allclose(jrve_field,rve_field)
\ No newline at end of file
diff --git a/test/test_jax_ports/test_04_CalibrationLikelihood.py b/test/test_jax_ports/test_04_CalibrationLikelihood.py
new file mode 100644
index 00000000..0a469d82
--- /dev/null
+++ b/test/test_jax_ports/test_04_CalibrationLikelihood.py
@@ -0,0 +1,151 @@
+import os
+os.environ["JAX_PLATFORM_NAME"] = "cpu"
+
+import numpy as np
+import jax.numpy as jnp
+import nifty8 as ift
+import nifty8.re as jft
+import ducc0
+import resolve as rve
+import resolve.re as jrve
+from functools import partial
+
+import pytest
+from numpy.testing import assert_allclose
+
+import jax
+jax.config.update("jax_enable_x64", True)
+
+np.random.seed(42)
+
+from observation_generator import RandomObservation
+from to_re_helpers import rve_cf_dict_to_jrve_cf_dict, mixed_rve_dct_to_mixed_jrve
+
+pmp = pytest.mark.parametrize
+
+def prepare_partial_operators():
+    param = {
+    "n_baselines": 10000,
+    "pol_indices": (8,5),
+    "freq_channels": np.array([1.5e9]),
+    "uvw_max": 200,
+    "n_antenna_max": 10,
+    "time_max": 5000,
+    "abs_vis_max": 0.5,
+    "weight_min":0,
+    "weight_max": 10000,
+    }
+
+    obs = RandomObservation(**param)
+
+    uantennas = rve.unique_antennas(obs)
+    antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)}
+    total_N = len(param["pol_indices"]) * len(uantennas)
+
+    dt = 20  # s
+    zero_padding_factor = 2
+    tmin, tmax = rve.tmin_tmax(obs)
+    time_domain = ift.RGSpace(
+        ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt
+    )
+
+    polarizations = obs.polarization.to_str_list()
+    frequencies = [1]
+    antennas = uantennas
+
+    dofdex = np.arange(total_N)
+    rve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5), "dofdex": dofdex}
+    rve_dct_ps = {
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "dofdex": dofdex
+    }
+
+    jrve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5),}
+    jrve_dct_ps = {
+        "shape": time_domain.shape,
+        "distances": time_domain.distances[0],
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "non_parametric_kind": "power"
+    }
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='phase_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_phase = cfmph.finalize(0)
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='logamp_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_logamp = cfmph.finalize(0)
+
+    pdom, _, fdom = obs.vis.domain
+    reshape = ift.DomainChangerAndReshaper(rve_phase.target, [pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom])
+
+    rve_phase = reshape @ rve_phase
+    rve_logamp = reshape @ rve_logamp
+
+    jrve_phase = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"phase",polarizations,frequencies,antennas)
+    jrve_logamp = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"logamp",polarizations,frequencies,antennas)
+
+    rve_cop = rve.calibration_distribution(obs,rve_phase,rve_logamp,antenna_dct,None)
+    jrve_cop = jrve.CalibrationDistribution(obs,jrve_phase,jrve_logamp,time_domain.distances[0])
+
+    rve_model_vis = ift.full(obs.vis.domain,1 + 0.0j)
+    jrve_model_vis = jnp.array(rve_model_vis.val)
+
+    rve_op = partial(
+        rve.CalibrationLikelihood, 
+        observation = obs, 
+        calibration_operator = rve_cop,
+        model_visibilities = rve_model_vis
+        )
+    
+    jrve_op = partial(
+        jrve.CalibrationLikelihood,
+        observation = obs, 
+        calibration_operator = jrve_cop,
+        model_visibilities = jrve_model_vis
+    )
+
+    return rve_op, jrve_op, obs
+
+def prepare_operators(mode):
+    rve_op_partial, jrve_op_partial, obs = prepare_partial_operators()
+
+    if mode == "model":
+        rve_log_inv_cov = ift.NormalTransform(0,1,"log_inv_cov",obs.weight.val.size)
+        reshape = ift.DomainChangerAndReshaper(rve_log_inv_cov.target,obs.weight.domain)
+        rve_log_inv_cov = reshape @ rve_log_inv_cov
+
+        jrve_log_inv_cov = jft.NormalPrior(0,1,name="log_inv_cov",shape=obs.weight.val.shape)
+    else:
+        rve_log_inv_cov = None
+        jrve_log_inv_cov = None
+
+    rve_op = rve_op_partial(log_inverse_covariance_operator=rve_log_inv_cov)
+    jrve_op = jrve_op_partial(log_inverse_covariance_operator=jrve_log_inv_cov)
+
+    return rve_op, jrve_op, obs
+
+
+@pmp("log_inv_cov_mode",[None, "model"])
+def test_consistency_CalibrationLikelihood(log_inv_cov_mode):
+    rve_op, jrve_op, obs = prepare_operators(log_inv_cov_mode)
+
+    rve_inp = ift.from_random(rve_op.domain)
+
+    if log_inv_cov_mode == "model":
+        jrve_inp = mixed_rve_dct_to_mixed_jrve(rve_inp.val,("log_inv_cov",),(obs.weight.val.shape,))
+    else:
+        jrve_inp = rve_cf_dict_to_jrve_cf_dict(rve_inp.val)
+
+    rve_field = rve_op(rve_inp).val
+    jrve_field = jrve_op(jrve_inp)
+
+    assert_allclose(jrve_field,rve_field)
\ No newline at end of file
diff --git a/test/test_jax_ports/test_05_ImagingLikelihood.py b/test/test_jax_ports/test_05_ImagingLikelihood.py
new file mode 100644
index 00000000..57169fda
--- /dev/null
+++ b/test/test_jax_ports/test_05_ImagingLikelihood.py
@@ -0,0 +1,240 @@
+import os
+os.environ["JAX_PLATFORM_NAME"] = "cpu"
+
+import numpy as np
+import jax.numpy as jnp
+import nifty8 as ift
+import nifty8.re as jft
+import ducc0
+import resolve as rve
+import resolve.re as jrve
+from functools import partial
+
+import pytest
+from numpy.testing import assert_allclose
+
+import jax
+jax.config.update("jax_enable_x64", True)
+
+np.random.seed(42)
+
+from observation_generator import RandomObservation
+from to_re_helpers import rve_cf_dict_to_jrve_cf_dict_simple, mixed_rve_dct_to_mixed_jrve
+
+pmp = pytest.mark.parametrize
+
+class JRVE_Sky(jft.Model):
+    def __init__(self,field):
+        self._field = field
+
+        super().__init__(init=self._field.init)
+
+    def __call__(self,x):
+        return jnp.array([[[jnp.exp(self._field(x))]]])
+
+def prepare_partial_operators():
+    param = {
+    "n_baselines": 10000,
+    "pol_indices": (8,5),
+    "freq_channels": np.array([1.5e9]),
+    "uvw_max": 200,
+    "n_antenna_max": 10,
+    "time_max": 5000,
+    "abs_vis_max": 0.5,
+    "weight_min":0,
+    "weight_max": 10000,
+    }
+
+    obs = RandomObservation(**param)
+
+    fov = np.array([1, 1]) * rve.DEG2RAD
+    npix = np.array([64, 64])
+    skydom = ift.RGSpace(npix, fov / npix)
+
+    sky_domain_dict = dict(npix_x=skydom.shape[0],
+                        npix_y=skydom.shape[1],
+                        pixsize_x=skydom.distances[0],
+                        pixsize_y=skydom.distances[1],
+                        pol_labels=['I'],
+                        times=[0.],
+                        freqs=[0.])
+
+    rve_sky_dct = {
+        "target": skydom,
+        "offset_mean": 0,
+        "offset_std": (1, 0.5),
+        "prefix": "logdiffuse_",
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+    }
+    rve_logsky = ift.SimpleCorrelatedField(**rve_sky_dct)
+    rve_sky = rve_logsky.exp()
+
+    rve_sky_dom = rve.default_sky_domain(sdom = skydom)
+    reshape_sky = ift.DomainChangerAndReshaper(rve_sky.target, rve_sky_dom)
+    rve_sky = reshape_sky @ rve_sky
+
+    jrve_dct_sky_offset = {"offset_mean": 0, "offset_std": (1, 0.5),}
+    jrve_dct_sky_ps = {
+        "shape": skydom.shape,
+        "distances": skydom.distances[0],
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "non_parametric_kind": "power"
+    }
+    cfm = jft.CorrelatedFieldMaker("logdiffuse_")
+    cfm.set_amplitude_total_offset(**jrve_dct_sky_offset)
+    cfm.add_fluctuations(**jrve_dct_sky_ps)
+    jrve_logsky = cfm.finalize()
+    jrve_sky = JRVE_Sky(jrve_logsky)
+
+    rve_op = partial(
+        rve.ImagingLikelihood,
+        observation = obs,
+        sky_operator = rve_sky,
+        epsilon = 1e-5,
+        do_wgridding = True
+        )
+    
+    jrve_op = partial(
+        jrve.ImagingLikelihood,
+        observation = obs,
+        sky_operator = jrve_sky,
+        sky_domain_dict = sky_domain_dict,
+        epsilon = 1e-5,
+        do_wgridding = True
+    )
+
+    return rve_op, jrve_op, obs, rve_sky.domain
+
+def prepare_cop(obs):
+    uantennas = rve.unique_antennas(obs)
+    antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)}
+    total_N = len(obs.polarization) * len(uantennas)
+
+    dt = 20  # s
+    zero_padding_factor = 2
+    tmin, tmax = rve.tmin_tmax(obs)
+    time_domain = ift.RGSpace(
+        ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt
+    )
+
+    polarizations = obs.polarization.to_str_list()
+    frequencies = [1]
+    antennas = uantennas
+
+    dofdex = np.arange(total_N)
+    rve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5), "dofdex": dofdex}
+    rve_dct_ps = {
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "dofdex": dofdex
+    }
+
+    jrve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5),}
+    jrve_dct_ps = {
+        "shape": time_domain.shape,
+        "distances": time_domain.distances[0],
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "non_parametric_kind": "power"
+    }
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='phase_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_phase = cfmph.finalize(0)
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='logamp_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_logamp = cfmph.finalize(0)
+
+    pdom, _, fdom = obs.vis.domain
+    reshape = ift.DomainChangerAndReshaper(rve_phase.target, [pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom])
+
+    rve_phase = reshape @ rve_phase
+    rve_logamp = reshape @ rve_logamp
+
+    jrve_phase = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"phase",polarizations,frequencies,antennas)
+    jrve_logamp = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"logamp",polarizations,frequencies,antennas)
+
+    rve_cop = rve.calibration_distribution(obs,rve_phase,rve_logamp,antenna_dct,None)
+    jrve_cop = jrve.CalibrationDistribution(obs,jrve_phase,jrve_logamp,time_domain.distances[0])
+
+    return rve_cop, jrve_cop
+
+
+def prepare_operators(cov_mode, cal_mode):
+    rve_op_partial, jrve_op_partial, obs, rve_sky_dom = prepare_partial_operators()
+
+    if cov_mode == "model":
+        rve_log_inv_cov = ift.NormalTransform(0,1,"log_inv_cov",obs.weight.val.size)
+        reshape = ift.DomainChangerAndReshaper(rve_log_inv_cov.target,obs.weight.domain)
+        rve_log_inv_cov = reshape @ rve_log_inv_cov
+
+        jrve_log_inv_cov = jft.NormalPrior(0,1,name="log_inv_cov",shape=obs.weight.val.shape)
+    else:
+        rve_log_inv_cov = None
+        jrve_log_inv_cov = None
+
+    rve_op_partial = partial(rve_op_partial,log_inverse_covariance_operator=rve_log_inv_cov)
+    jrve_op_partial = partial(jrve_op_partial,log_inverse_covariance_operator=jrve_log_inv_cov)
+
+    if cal_mode is not None:
+        rve_cop, jrve_cop = prepare_cop(obs)
+
+    if cal_mode == "cop":
+        rve_op = rve_op_partial(calibration_operator=rve_cop,calibration_field=None)
+        jrve_op = jrve_op_partial(calibration_operator=jrve_cop,calibration_field=None)
+    elif cal_mode == "cfld":
+        rve_cfld = ift.from_random(rve_cop.target,dtype="complex128")
+        jrve_cfld = jnp.array(rve_cfld.val)
+
+        rve_op = rve_op_partial(calibration_operator=None,calibration_field=rve_cfld)
+        jrve_op = jrve_op_partial(calibration_operator=None,calibration_field=jrve_cfld)
+    else:
+        rve_op = rve_op_partial(calibration_operator=None,calibration_field=None)
+        jrve_op = jrve_op_partial(calibration_operator=None,calibration_field=None)
+
+    return rve_op, jrve_op, obs, rve_sky_dom
+
+
+@pmp("log_inv_cov_mode",[None, "model"])
+@pmp("cal_op_mode", [None, "cop", "cfld"])
+def test_consistency_CalibrationLikelihood(log_inv_cov_mode,cal_op_mode):
+    rve_op, jrve_op, obs, rve_sky_dom = prepare_operators(log_inv_cov_mode, cal_op_mode)
+
+    rve_inp = ift.from_random(rve_op.domain)
+
+    if log_inv_cov_mode == "model":
+        jrve_inp = mixed_rve_dct_to_mixed_jrve(rve_inp.val,("log_inv_cov",),(obs.weight.val.shape,),rve_sky_dom.keys())
+    else:
+        if (cal_op_mode == "cop"):
+            jrve_inp = mixed_rve_dct_to_mixed_jrve(rve_inp.val,simple_cf_keys=rve_sky_dom.keys())
+        else:
+            jrve_inp = rve_cf_dict_to_jrve_cf_dict_simple(rve_inp.val)
+
+    rve_field = rve_op(rve_inp).val
+    jrve_field = jrve_op(jrve_inp)
+
+    assert_allclose(jrve_field,rve_field)
+
+
+
+
+    #s_jrve_fc_img_lh_dct_translated = rve_cf_dict_to_jrve_cf_dict_simple(s_rve_fc_img_lh_dct.val)
+    #s_jrve_fc_img_cfld_lh_dct_translated = rve_cf_dict_to_jrve_cf_dict_simple(s_rve_fc_img_cfld_lh_dct.val)
+
+    #s_jrve_fc_img_cop_lh_dct_translated = mixed_rve_dct_to_mixed_jrve(s_rve_fc_img_cop_lh_dct.val,simple_cf_keys=ift.from_random(rve_sky.domain).val.keys())
+    #s_jrve_vc_img_lh_dct_translated = mixed_rve_dct_to_mixed_jrve(s_rve_vc_img_lh_dct.val,("log_inv_cov",),(obs.weight.val.shape,),simple_cf_keys=ift.from_random(rve_sky.domain).val.keys())
+    #s_jrve_vc_img_cfld_lh_dct_translated = mixed_rve_dct_to_mixed_jrve(s_rve_vc_img_cfld_lh_dct.val,("log_inv_cov",),(obs.weight.val.shape,),simple_cf_keys=ift.from_random(rve_sky.domain).val.keys())
+    #s_jrve_vc_img_cop_lh_dct_translated = mixed_rve_dct_to_mixed_jrve(s_rve_vc_img_cop_lh_dct.val,("log_inv_cov",),(obs.weight.val.shape,),simple_cf_keys=ift.from_random(rve_sky.domain).val.keys())
\ No newline at end of file
diff --git a/test/test_jax_ports/to_re_helpers.py b/test/test_jax_ports/to_re_helpers.py
new file mode 100644
index 00000000..6fcae115
--- /dev/null
+++ b/test/test_jax_ports/to_re_helpers.py
@@ -0,0 +1,63 @@
+import numpy as np
+import jax.numpy as jnp
+
+def rve_cf_dict_to_jrve_cf_dict(rve_cf_dct):
+    res = rve_cf_dct
+    for h in res.keys():
+        h_split = h.split("_")
+
+        if(h_split[1] == "spectrum"):
+            res[h] = np.swapaxes(res[h],1,2)
+    return res
+
+def rve_cf_dict_to_jrve_cf_dict_simple(rve_cf_dct):
+    res = rve_cf_dct
+
+    for h in rve_cf_dct.keys():
+        h_split = h.split("_")
+        
+        if (h_split[1] == "spectrum"):
+            res[h] = res[h].T
+    return res
+
+def single_rve_dct_to_single_jrve_dct(rve_dct,shape):
+    if len(rve_dct) != 1:
+        raise ValueError("Can only convert dict with single entry/element")
+    
+    key, rve_value = rve_dct.popitem()
+
+    jrve_value = np.reshape(rve_value,shape)
+
+    return {key: jnp.array(jrve_value)}
+
+def mixed_rve_dct_to_mixed_jrve(rve_dct,non_cf_keys=None,non_cf_shapes=None,simple_cf_keys=None):
+    if (non_cf_keys is not None) and (non_cf_shapes is not None):
+        if (len(non_cf_keys) != len(non_cf_shapes)) and (non_cf_keys is not None):
+            raise ValueError("Each non cf dict key should have an shape for the associate dict value")  
+    
+    if ((non_cf_keys is not None) and (simple_cf_keys is None)):
+        non_cf_rve_dcts = [({key: rve_dct[key]},non_cf_shapes[k]) for k,key in enumerate(non_cf_keys)]
+        cf_rve_dct = {key: rve_dct[key] for key in rve_dct.keys() if (key not in non_cf_keys)}
+        simple_cf_rve_dcts = {}
+    elif ((non_cf_keys is None) and (simple_cf_keys is not None)):
+        simple_cf_rve_dcts = {key: rve_dct[key] for key in simple_cf_keys}
+        cf_rve_dct = {key: rve_dct[key] for key in rve_dct.keys() if (key not in simple_cf_keys)}
+        non_cf_rve_dcts = []
+    else:
+        non_cf_rve_dcts = [({key: rve_dct[key]},non_cf_shapes[k]) for k,key in enumerate(non_cf_keys)]
+        simple_cf_rve_dcts = {key: rve_dct[key] for key in simple_cf_keys}
+        cf_rve_dct = {key: rve_dct[key] for key in rve_dct.keys() if ((key not in non_cf_keys) and (key not in simple_cf_keys))}
+
+    jrve_dct = {}
+
+    if len(cf_rve_dct) > 0:
+        jrve_dct = rve_cf_dict_to_jrve_cf_dict(cf_rve_dct)
+
+    if len(non_cf_rve_dcts) > 0:
+        for x in non_cf_rve_dcts:
+            jrve_dct.update(single_rve_dct_to_single_jrve_dct(x[0],x[1]))
+
+    if len(simple_cf_rve_dcts) > 0:
+        jrve_dct.update(rve_cf_dict_to_jrve_cf_dict_simple(simple_cf_rve_dcts))
+
+    return jrve_dct
\ No newline at end of file
-- 
GitLab


From db8d90246000478343cdbd53fb1fee51d9f784d0 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 29 Apr 2025 12:15:55 +0200
Subject: [PATCH 80/88] Moved observation generator to jax port test directory

---
 test/test_jax_ports/observation_generator.py | 41 ++++++++++++++++++++
 1 file changed, 41 insertions(+)
 create mode 100644 test/test_jax_ports/observation_generator.py

diff --git a/test/test_jax_ports/observation_generator.py b/test/test_jax_ports/observation_generator.py
new file mode 100644
index 00000000..7e7cd25a
--- /dev/null
+++ b/test/test_jax_ports/observation_generator.py
@@ -0,0 +1,41 @@
+import numpy as np
+
+import resolve as rve
+
+def RandomAntennaPositions(n_row,uvw_max,n_antenna_max,time_max,rng_generator=np.random.default_rng(42)):
+    ant1,ant2 = [],[]
+
+    for k in range(n_row):
+        x = rng_generator.integers(0,n_antenna_max-1)
+        y = rng_generator.integers(1,n_antenna_max)
+
+        if(x==y):
+            while(x==y):
+                y = np.random.randint(1,n_antenna_max)
+
+        ant1.append(x)
+        ant2.append(y)
+
+    ant1 = np.array(ant1).astype(np.int32)
+    ant2 = np.array(ant2).astype(np.int32)
+    time = rng_generator.uniform(0,time_max,n_row)
+    uvw = rng_generator.uniform(-uvw_max,uvw_max,(n_row,3))
+
+    return rve.data.antenna_positions.AntennaPositions(uvw,ant1,ant2,time)
+
+def RandomObservation(n_baselines,pol_indices,freq_channels,uvw_max,n_antenna_max,time_max,abs_vis_max,weight_min,weight_max,rng_generator=np.random.default_rng(42)):
+    antenna_pos = RandomAntennaPositions(n_baselines,uvw_max,n_antenna_max,time_max,rng_generator=rng_generator)
+
+    n_pol = len(pol_indices)
+    n_freq = freq_channels.size
+    vis_shape = (n_pol,n_baselines,n_freq)
+
+    pol = rve.data.polarization.Polarization(pol_indices)
+
+    vis_magnitude = rng_generator.uniform(0,abs_vis_max,vis_shape)
+    vis_phase = rng_generator.uniform(0,2*np.pi,vis_shape)
+    vis = vis_magnitude*np.exp(1.0j*vis_phase)
+
+    weights = rng_generator.uniform(weight_min,weight_max,vis_shape)
+    
+    return rve.data.observation.Observation(antenna_pos,vis,weights,pol,freq_channels)
\ No newline at end of file
-- 
GitLab


From 7c6616ef4e0d4840fb1f284776b549295577ddbf Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 29 Apr 2025 12:17:11 +0200
Subject: [PATCH 81/88] Added fast-resolve routine with calibration step before
 imaging

---
 resolve/re/fast-resolve_cal.py | 244 +++++++++++++++++++++++++++++++++
 1 file changed, 244 insertions(+)
 create mode 100644 resolve/re/fast-resolve_cal.py

diff --git a/resolve/re/fast-resolve_cal.py b/resolve/re/fast-resolve_cal.py
new file mode 100644
index 00000000..31104a2d
--- /dev/null
+++ b/resolve/re/fast-resolve_cal.py
@@ -0,0 +1,244 @@
+import nifty8 as ift
+import nifty8.re as jft
+
+import resolve.re as jrve
+
+import jax.numpy as jnp
+import jax.random as jr
+
+def calibration_step(
+        cal_assemblers,
+        config_minimizer,
+        key,
+        assembler_keys = None,
+        resume = False,
+        odir = None
+        ):
+    
+    k_i, k_0 = jr.split(key,2)
+    
+    lh_param_keys = ("observation","calibration_operator","model_visibilities","log_inverse_covariance_operator","likelihood_label")
+    cal_assembler_attr_keys = ("obs","cop","model_vis","log_inv_cov","lh_label")
+    assembler_keys = cal_assemblers.keys() if assembler_keys is None else assembler_keys
+
+    lh_components = {i: [getattr(cal_assemblers[k],j) for k in assembler_keys] for i,j in zip(lh_param_keys,cal_assembler_attr_keys)}
+    lh = jrve.CalibrationLikelihood(**lh_components)
+
+    optimize_kl_kwargs = dict(
+        likelihood = lh,
+        position_or_samples = jft.random_like(k_0,lh.domain),
+        key = k_i,
+        resume = resume,
+        odir = odir,
+        **config_minimizer
+    )
+    
+    return jft.optimize_kl(**optimize_kl_kwargs)
+
+def imaging_step(
+        sky_assembler,
+        calibration_field,
+        config_response,
+        config_minimizer,
+        key,
+        init_samples = None,
+        cache_noise_kernel="None",
+        cache_response_kernel="None",
+        resume = False,
+        odir = None
+        ):
+    
+    obs = sky_assembler.obs
+    sky_model = sky_assembler.sky
+    noise_scaling_op = sky_assembler.noise_scaling
+    varcov_op = sky_assembler.varcov
+    
+    noise_scaling = False if noise_scaling_op is None else True
+    varcov = False if varcov_op is None else True
+
+    k_i, k_0 = jr.split(key,2)
+
+    R, _, RNR, RNR_l = jrve.build_exact_r(obs,calibration_field=calibration_field,**config_response)
+
+    N_inv = ift.makeOp(obs.weight)
+    d_0 = jnp.array(R.adjoint(N_inv(obs.vis)).val)
+    
+
+    if noise_scaling:
+        RNR_approx, N_inv_approx = jrve.build_approximations(RNR,RNR_l,cache_noise_kernel=cache_noise_kernel,cache_response_kernel=cache_response_kernel,noise_scaling=noise_scaling_op)
+        init_pos = jft.random_like(k_0,{**sky_model.domain, **noise_scaling_op.domain})
+    else:
+        RNR_approx, N_inv_approx = jrve.build_approximations(RNR,RNR_l,cache_noise_kernel=cache_noise_kernel,cache_response_kernel=cache_response_kernel,varcov=varcov)
+        init_pos = jft.random_like(k_0,{**sky_model.domain, **varcov_op.domain}) if varcov else jft.random_like(k_0,sky_model.domain)
+
+    optimize_kl_kwargs = dict(
+        R = RNR,
+        R_approx = RNR_approx,
+        sky = sky_model,
+        N_inv_sqrt = N_inv_approx,
+        data = d_0,
+        pos = 1e-2*jft.Vector(init_pos.copy()),
+        resume = resume,
+        init_samples = init_samples,
+        noise_scaling = noise_scaling,
+        varcov = varcov,
+        varcov_op = varcov_op,
+        key = k_i,
+        out_dir = odir,
+        **config_minimizer
+    )
+
+    pos, samples, _ = jrve.optimize(**optimize_kl_kwargs)
+
+    post_sky_mean = ift.makeField(R.domain, np.array(jft.mean(tuple(sky_model(s) for s in samples))))
+    new_sci_model_vis = jnp.array(R(post_sky_mean).val)
+
+    return samples, pos, new_sci_model_vis
+
+def resume_preparation(
+    obs,
+    sky_model,
+    odir,
+    config 
+):
+    config_response = config["response"]
+    save_all = config["imaging"]["save_all"]
+
+    lfile = f"{odir}/last_started_major"
+    if path.isfile(lfile):
+        with open(lfile) as f:
+            last_started_major = int(f.read())
+        lfile = f"{odir}/major_{last_started_major}/last_finished_iteration"
+        if path.isfile(lfile):
+            with open(lfile) as f:
+                last_finished_index = int(f.read())
+        else:
+            raise FileNotFoundError(f"{lfile} could not be found")
+    else:
+        raise FileNotFoundError(f"{lfile} could not be found")
+    
+    if save_all:
+        samples = pickle.load(
+            open(
+                f"{odir}/major_{last_started_major}/samples_{last_finished_index}.p",
+                "rb",
+            )
+        )
+    else:
+        samples = pickle.load(
+            open(
+                f"{odir}/last_samples.p",
+                "rb",
+            )
+        )
+
+    sp_sky_dom =rve.sky_model._spatial_dom(config_response["conf_sky"])
+    sky_dom = rve.default_sky_domain(sdom=sp_sky_dom)
+    R = rve.InterferometryResponse(obs, sky_dom, config_response["do_wgridding"], config_response["epsilon"], config_response["verbosity"], config_response["nthreads"])
+    dch = ift.DomainChangerAndReshaper(R.domain[3], R.domain)
+    R = R @ dch
+    
+    post_sky_mean = ift.makeField(R.domain, np.array(jft.mean(tuple(sky_model(s) for s in samples))))
+    return jnp.array(R(post_sky_mean).val)
+    
+
+def single_cal_fr_run(
+        cal_assemblers,
+        sky_assembler,
+        config,
+        key,
+        n_vi_runs=(1,1),
+        resume = False,
+        odirs=(None,None)
+        ):
+    key_cal, key_img = jr.split(key,2)
+    config_response = config["response"]
+    config_cal = config["calibration"]
+    config_img = config["imaging"]
+    assembler_keys = list(cal_assemblers.keys())
+
+    if cal_assemblers["sci"].model_vis is None:
+        assembler_keys.remove("sci")
+
+    config_cal["n_total_iterations"] += n_vi_runs[0]
+
+    print("-"*5,"Calibration step","-"*57)
+    
+    cal_samples, cal_state = calibration_step(cal_assemblers,config_cal,key_cal,assembler_keys,resume,odirs[0])
+
+    antenna_gains = jft.mean(tuple(cal_assemblers["sci"].cop(s) for s in cal_samples))
+    cfld = ift.makeField(cal_assemblers["sci"].obs.vis.domain,antenna_gains)
+
+    config_img["n_major_step"] += n_vi_runs[1]
+    print("-"*5,"Imaging step","-"*61)
+
+    img_samples, img_state, cal_assemblers["sci"].model_vis = imaging_step(sky_assembler,cfld,config_response,config_img,key_img,resume=resume,odir=odirs[1])
+
+
+    return dict(response=config_response,calibration=config_cal,imaging=config_img),dict(cal=cal_samples,img=img_samples), dict(cal=cal_state,img=img_state)
+
+def fastresolve_with_calibration(
+        sky_assembler,
+        sci_cal_assembler,
+        config,
+        key,
+        phase_cal_assembler = None,
+        flux_cal_assembler = None,
+        n_iterations=1,
+        n_vi_runs_per_iteration=(1,1),
+        resume = False,
+        odir = None,
+        return_latest_samples=False,
+        return_latest_states=False,
+        return_latest_config_update=False
+        ):
+    if resume and (odir is None):
+        raise ValueError("Set folder name from where to resume the inference runs")
+    
+    odirs = (None,None) if odir is None else (f"{odir}/calibration",f"{odir}/imaging")
+    
+    new_config = deepcopy(config)
+    sca = dataclasses.replace(sci_cal_assembler)
+
+    if resume:
+        sca.model_vis = resume_preparation(sca.obs,sky_assembler.sky,odirs[1],config)
+
+        with open(f"{odir}/run_info.json","r") as f:
+            loaded_run_info = json.load(f)
+
+        n_cal_vi_steps_done = loaded_run_info["cal_vi_iterations"]
+        n_img_major_steps_done = loaded_run_info["img_n_major_steps"]
+
+        print("Resume: Overwrite given settings for n_total_iterations and n_major_step")
+
+        config["calibration"]["n_total_iterations"] = n_cal_vi_steps_done
+        config["imaging"]["n_major_step"] = n_img_major_steps_done
+
+        print(f"Loaded data of previous {n_cal_vi_steps_done} calibration steps and {n_img_major_steps_done} imaging major steps")
+
+    
+    cal_assemblers = dict(sci=sca)
+    cal_assemblers.update(dict(phase=phase_cal_assembler) if phase_cal_assembler is not None else {})
+    cal_assemblers.update(dict(flux=flux_cal_assembler) if flux_cal_assembler is not None else {})
+
+
+    for k in range(n_iterations):
+        print("-"*80)
+        print(" "*4,f"{k+1}-th calibration and fast-resolve iteration")
+        print("-"*80)
+
+        new_config, samples_dct, state_dct = single_cal_fr_run(cal_assemblers,sky_assembler,new_config,key,n_vi_runs_per_iteration,resume,odirs)
+        resume = True
+
+        with open(f"{odir}/run_info.json","w") as f:
+            json.dump(dict(
+                cal_vi_iterations = new_config["calibration"]["n_total_iterations"],
+                img_n_major_steps = new_config["imaging"]["n_major_step"]
+                ),f)
+    
+    returns = {}
+    returns.update(dict(config=new_config) if return_latest_config_update else {})
+    returns.update(dict(samples=samples_dct) if return_latest_samples else {})
+    returns.update(dict(states=state_dct) if return_latest_states else {})
+
+    return returns
\ No newline at end of file
-- 
GitLab


From 295a6656142e9ed9e521ac7253202c1da787b76e Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 30 Apr 2025 11:39:04 +0200
Subject: [PATCH 82/88] Direct input of psf_pixel instead over readout of
 configparser object

---
 resolve/re/radio_response.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/resolve/re/radio_response.py b/resolve/re/radio_response.py
index 5f3478b4..52e57056 100644
--- a/resolve/re/radio_response.py
+++ b/resolve/re/radio_response.py
@@ -30,12 +30,11 @@ def get_jax_fft(domain, target, inverse):
 
     return partial(fft, fct=fct, func=func)
 
-def build_exact_r(obs, conf_sky, conf_setup, calibration_field=None, do_wgridding=True, epsilon=1e-9, verbosity=1, nthreads=8):
+def build_exact_r(obs, conf_sky, psf_pixels, calibration_field=None, do_wgridding=True, epsilon=1e-9, verbosity=1, nthreads=8):
     sp_sky_dom =rve.sky_model._spatial_dom(conf_sky)
     sky_dom = rve.default_sky_domain(sdom=sp_sky_dom)
     R = rve.InterferometryResponse(obs, sky_dom, do_wgridding, epsilon, verbosity, nthreads)
 
-    psf_pixels = conf_setup.getfloat("psf pixels")
     full_psf0 = min(2*psf_pixels, sp_sky_dom.shape[0])
     full_psf1 = min(2*psf_pixels, sp_sky_dom.shape[1])
     sp_sky_dom_l = (sp_sky_dom.shape[0] + full_psf0, sp_sky_dom.shape[1] + full_psf1)
-- 
GitLab


From b628212ab26a4fd365642947d8e2d8bc1cbb01cf Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 30 Apr 2025 11:39:36 +0200
Subject: [PATCH 83/88] removed observation generator from misc

---
 misc/observation_generator.py | 41 -----------------------------------
 1 file changed, 41 deletions(-)
 delete mode 100644 misc/observation_generator.py

diff --git a/misc/observation_generator.py b/misc/observation_generator.py
deleted file mode 100644
index 7e7cd25a..00000000
--- a/misc/observation_generator.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import numpy as np
-
-import resolve as rve
-
-def RandomAntennaPositions(n_row,uvw_max,n_antenna_max,time_max,rng_generator=np.random.default_rng(42)):
-    ant1,ant2 = [],[]
-
-    for k in range(n_row):
-        x = rng_generator.integers(0,n_antenna_max-1)
-        y = rng_generator.integers(1,n_antenna_max)
-
-        if(x==y):
-            while(x==y):
-                y = np.random.randint(1,n_antenna_max)
-
-        ant1.append(x)
-        ant2.append(y)
-
-    ant1 = np.array(ant1).astype(np.int32)
-    ant2 = np.array(ant2).astype(np.int32)
-    time = rng_generator.uniform(0,time_max,n_row)
-    uvw = rng_generator.uniform(-uvw_max,uvw_max,(n_row,3))
-
-    return rve.data.antenna_positions.AntennaPositions(uvw,ant1,ant2,time)
-
-def RandomObservation(n_baselines,pol_indices,freq_channels,uvw_max,n_antenna_max,time_max,abs_vis_max,weight_min,weight_max,rng_generator=np.random.default_rng(42)):
-    antenna_pos = RandomAntennaPositions(n_baselines,uvw_max,n_antenna_max,time_max,rng_generator=rng_generator)
-
-    n_pol = len(pol_indices)
-    n_freq = freq_channels.size
-    vis_shape = (n_pol,n_baselines,n_freq)
-
-    pol = rve.data.polarization.Polarization(pol_indices)
-
-    vis_magnitude = rng_generator.uniform(0,abs_vis_max,vis_shape)
-    vis_phase = rng_generator.uniform(0,2*np.pi,vis_shape)
-    vis = vis_magnitude*np.exp(1.0j*vis_phase)
-
-    weights = rng_generator.uniform(weight_min,weight_max,vis_shape)
-    
-    return rve.data.observation.Observation(antenna_pos,vis,weights,pol,freq_channels)
\ No newline at end of file
-- 
GitLab


From b9c62ec070e5a6bce3444e743c3b732abc7634ab Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Wed, 30 Apr 2025 11:40:12 +0200
Subject: [PATCH 84/88] Restructure to include new dataclasses in sugar.py

---
 resolve/re/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py
index 1011168e..c96632e6 100644
--- a/resolve/re/__init__.py
+++ b/resolve/re/__init__.py
@@ -5,4 +5,4 @@ from .radio_response import build_exact_r, build_approximations
 from .optimize import optimize
 from .calibration import CalibrationInterpolator, CalibrationDistribution
 from .likelihood import CalibrationLikelihood, ImagingLikelihood
-from .sugar import Bulk_CF_AntennaTimeDomain, ScaledCalibrationDistribution
\ No newline at end of file
+from .sugar import Bulk_CF_AntennaTimeDomain, CalibrationAssembler, SkyAssembler
\ No newline at end of file
-- 
GitLab


From 21663d74def9ea18f3f9e2da9578dd637ec836e0 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Thu, 1 May 2025 19:12:04 +0200
Subject: [PATCH 85/88] Scaled operator is now used to set the actual
 calibration operator in the CalibrationAssembler class

---
 resolve/re/fast-resolve_cal.py | 2 +-
 resolve/re/sugar.py            | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/resolve/re/fast-resolve_cal.py b/resolve/re/fast-resolve_cal.py
index 31104a2d..3cac6b2e 100644
--- a/resolve/re/fast-resolve_cal.py
+++ b/resolve/re/fast-resolve_cal.py
@@ -18,7 +18,7 @@ def calibration_step(
     k_i, k_0 = jr.split(key,2)
     
     lh_param_keys = ("observation","calibration_operator","model_visibilities","log_inverse_covariance_operator","likelihood_label")
-    cal_assembler_attr_keys = ("obs","cop","model_vis","log_inv_cov","lh_label")
+    cal_assembler_attr_keys = ("obs","scaled_cop","model_vis","log_inv_cov","lh_label")
     assembler_keys = cal_assemblers.keys() if assembler_keys is None else assembler_keys
 
     lh_components = {i: [getattr(cal_assemblers[k],j) for k in assembler_keys] for i,j in zip(lh_param_keys,cal_assembler_attr_keys)}
diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index bd2d59e0..0921a3ea 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -108,7 +108,7 @@ class CalibrationAssembler:
     def __post_init__(self):
         self.lh_label = self.obs.source_name if self.lh_label is None else self.lh_label
         self.cop = CalibrationDistribution(self.obs,self.phase_field,self.logflux_field,self.dt)
-        self.scaled_cop = None if self.scaling_op is None else jft.Model(
+        self.scaled_cop = self.cop if self.scaling_op is None else jft.Model(
             call = lambda x: self.scaling_op(x)*self.cop(x),
             domain = {**self.cop.domain,**self.scaling_op.domain},
             init =  self.cop.init | self.scaling_op.init
-- 
GitLab


From 2ab975836b426af07fd52dc3b44ed9b857efdab8 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Tue, 20 May 2025 09:44:49 +0200
Subject: [PATCH 86/88] Added further functionality to CalibrationAssembler

---
 resolve/re/sugar.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
index 0921a3ea..dfb54952 100644
--- a/resolve/re/sugar.py
+++ b/resolve/re/sugar.py
@@ -1,7 +1,9 @@
 import jax.numpy as jnp
 import nifty8.re as jft
+import matplotlib.pyplot as plt
 
 from .calibration import CalibrationDistribution
+from .likelihood_models import ModelCalibrationLikelihoodFixedCovariance
 
 from ..data.observation import Observation
 
@@ -168,7 +170,19 @@ class CalibrationAssembler:
                 scaling_op = None
 
             return cls(observation,phase_field,logflux_field,dt,init_flux_field,scaling_op,log_inv_cov,lh_label)
-        
+
+    def prior_realization(self,rng_key):
+        data_model = ModelCalibrationLikelihoodFixedCovariance(self.scaled_cop,self.model_vis,jnp.asarray(self.obs.mask.val))
+        data_model_realization = data_model(data_model.init(rng_key))
+
+        plt.scatter(self.obs.vis.val.imag,self.obs.vis.val.real,label="Data")
+        plt.scatter(data_model_realization.imag,data_model_realization.real, alpha=0.01, label="Prior sample")
+        plt.legend()
+        plt.show()
+
+    def get_lh_domain(self):
+        dom = self.scaled_cop.domain if self.log_inv_cov is None else self.scaled_cop.domain | self.log_inv_cov.domain
+        return dom
 
     def __repr__(self):
         return(
-- 
GitLab


From cc5e158026c69eed4600ddc4c53c83c6e94c87d6 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 26 May 2025 09:24:31 +0200
Subject: [PATCH 87/88] Restructured for clear overview; Adding of random drawn
 latent parameters to initial sample in calibration step

---
 resolve/re/fast-resolve_cal.py | 95 ++++++++++++++++++++++++----------
 1 file changed, 69 insertions(+), 26 deletions(-)

diff --git a/resolve/re/fast-resolve_cal.py b/resolve/re/fast-resolve_cal.py
index 3cac6b2e..d01a2fff 100644
--- a/resolve/re/fast-resolve_cal.py
+++ b/resolve/re/fast-resolve_cal.py
@@ -1,21 +1,28 @@
 import nifty8 as ift
 import nifty8.re as jft
 
+import resolve as rve
 import resolve.re as jrve
 
 import jax.numpy as jnp
-import jax.random as jr
+from jax import random
+
+import pickle
+import json
+import dataclasses
+from copy import deepcopy
 
 def calibration_step(
         cal_assemblers,
         config_minimizer,
-        key,
+        random_key,
         assembler_keys = None,
+        init_sample = None,
         resume = False,
         odir = None
         ):
     
-    k_i, k_0 = jr.split(key,2)
+    k_i, k_0 = random.split(random_key,2)
     
     lh_param_keys = ("observation","calibration_operator","model_visibilities","log_inverse_covariance_operator","likelihood_label")
     cal_assembler_attr_keys = ("obs","scaled_cop","model_vis","log_inv_cov","lh_label")
@@ -23,10 +30,17 @@ def calibration_step(
 
     lh_components = {i: [getattr(cal_assemblers[k],j) for k in assembler_keys] for i,j in zip(lh_param_keys,cal_assembler_attr_keys)}
     lh = jrve.CalibrationLikelihood(**lh_components)
+    
+    cal_init = jft.random_like(k_0,lh.domain)
+    
+    if init_sample is not None:
+        if not(set(init_sample.tree.keys()).issubset(set(cal_init.tree.keys()))):
+            raise ValueError("init_sample should at most contain the parameters of the likelihoods")
+        cal_init = jft.Vector(cal_init.tree | init_sample.tree)  
 
     optimize_kl_kwargs = dict(
         likelihood = lh,
-        position_or_samples = jft.random_like(k_0,lh.domain),
+        position_or_samples = cal_init,
         key = k_i,
         resume = resume,
         odir = odir,
@@ -40,7 +54,7 @@ def imaging_step(
         calibration_field,
         config_response,
         config_minimizer,
-        key,
+        random_key,
         init_samples = None,
         cache_noise_kernel="None",
         cache_response_kernel="None",
@@ -56,7 +70,7 @@ def imaging_step(
     noise_scaling = False if noise_scaling_op is None else True
     varcov = False if varcov_op is None else True
 
-    k_i, k_0 = jr.split(key,2)
+    k_i, k_0 = random.split(random_key,2)
 
     R, _, RNR, RNR_l = jrve.build_exact_r(obs,calibration_field=calibration_field,**config_response)
 
@@ -139,19 +153,19 @@ def resume_preparation(
     R = R @ dch
     
     post_sky_mean = ift.makeField(R.domain, np.array(jft.mean(tuple(sky_model(s) for s in samples))))
-    return jnp.array(R(post_sky_mean).val)
-    
+    return jnp.array(R(post_sky_mean).val)  
 
 def single_cal_fr_run(
         cal_assemblers,
         sky_assembler,
         config,
-        key,
+        random_key,
         n_vi_runs=(1,1),
+        init_samples = (None, None),
         resume = False,
-        odirs=(None,None)
+        odirs = (None, None)
         ):
-    key_cal, key_img = jr.split(key,2)
+    key_cal, key_img = random.split(random_key,2)
     config_response = config["response"]
     config_cal = config["calibration"]
     config_img = config["imaging"]
@@ -162,18 +176,34 @@ def single_cal_fr_run(
 
     config_cal["n_total_iterations"] += n_vi_runs[0]
 
-    print("-"*5,"Calibration step","-"*57)
+    print("-"*4+"|","Calibration step","|"+"-"*56)
     
-    cal_samples, cal_state = calibration_step(cal_assemblers,config_cal,key_cal,assembler_keys,resume,odirs[0])
-
+    cal_samples, cal_state = calibration_step(
+        cal_assemblers = cal_assemblers,
+        config_minimizer = config_cal,
+        random_key = key_cal,
+        assembler_keys = assembler_keys,
+        init_sample = init_samples[0],
+        resume = resume,
+        odir = odirs[0]
+        )
+    
+    config_img["n_major_step"] += n_vi_runs[1]
+    print("-"*4+"|","Imaging step","|"+"-"*60)
+      
     antenna_gains = jft.mean(tuple(cal_assemblers["sci"].cop(s) for s in cal_samples))
     cfld = ift.makeField(cal_assemblers["sci"].obs.vis.domain,antenna_gains)
 
-    config_img["n_major_step"] += n_vi_runs[1]
-    print("-"*5,"Imaging step","-"*61)
-
-    img_samples, img_state, cal_assemblers["sci"].model_vis = imaging_step(sky_assembler,cfld,config_response,config_img,key_img,resume=resume,odir=odirs[1])
-
+    img_samples, img_state, cal_assemblers["sci"].model_vis = imaging_step(
+        sky_assembler = sky_assembler,
+        calibration_field = cfld,
+        config_response = config_response,
+        config_minimizer = config_img,
+        random_key = key_img,
+        init_samples = init_samples[1],
+        resume = resume,
+        odir = odirs[1]
+        )
 
     return dict(response=config_response,calibration=config_cal,imaging=config_img),dict(cal=cal_samples,img=img_samples), dict(cal=cal_state,img=img_state)
 
@@ -181,11 +211,12 @@ def fastresolve_with_calibration(
         sky_assembler,
         sci_cal_assembler,
         config,
-        key,
+        random_key,
         phase_cal_assembler = None,
         flux_cal_assembler = None,
         n_iterations=1,
         n_vi_runs_per_iteration=(1,1),
+        init_samples = None,
         resume = False,
         odir = None,
         return_latest_samples=False,
@@ -196,23 +227,26 @@ def fastresolve_with_calibration(
         raise ValueError("Set folder name from where to resume the inference runs")
     
     odirs = (None,None) if odir is None else (f"{odir}/calibration",f"{odir}/imaging")
+    init_samples = (None,None) if init_samples is None else init_samples
     
     new_config = deepcopy(config)
     sca = dataclasses.replace(sci_cal_assembler)
 
     if resume:
-        sca.model_vis = resume_preparation(sca.obs,sky_assembler.sky,odirs[1],config)
-
         with open(f"{odir}/run_info.json","r") as f:
             loaded_run_info = json.load(f)
 
         n_cal_vi_steps_done = loaded_run_info["cal_vi_iterations"]
         n_img_major_steps_done = loaded_run_info["img_n_major_steps"]
-
+    
+        if n_img_major_steps_done > 0:
+            sca.model_vis = resume_preparation(sca.obs,sky_assembler.sky,odirs[1],new_config)
+        
+        
         print("Resume: Overwrite given settings for n_total_iterations and n_major_step")
 
-        config["calibration"]["n_total_iterations"] = n_cal_vi_steps_done
-        config["imaging"]["n_major_step"] = n_img_major_steps_done
+        new_config["calibration"]["n_total_iterations"] = n_cal_vi_steps_done
+        new_config["imaging"]["n_major_step"] = n_img_major_steps_done
 
         print(f"Loaded data of previous {n_cal_vi_steps_done} calibration steps and {n_img_major_steps_done} imaging major steps")
 
@@ -227,7 +261,16 @@ def fastresolve_with_calibration(
         print(" "*4,f"{k+1}-th calibration and fast-resolve iteration")
         print("-"*80)
 
-        new_config, samples_dct, state_dct = single_cal_fr_run(cal_assemblers,sky_assembler,new_config,key,n_vi_runs_per_iteration,resume,odirs)
+        new_config, samples_dct, state_dct = single_cal_fr_run(
+            cal_assemblers = cal_assemblers,
+            sky_assembler = sky_assembler,
+            config = new_config,
+            random_key = random_key,
+            n_vi_runs = n_vi_runs_per_iteration,
+            init_samples = init_samples,
+            resume = resume,
+            odirs = odirs
+            )
         resume = True
 
         with open(f"{odir}/run_info.json","w") as f:
-- 
GitLab


From 73b38d5dcd59e26e29ba50c009cdba548d4e5e12 Mon Sep 17 00:00:00 2001
From: Andreas Popp <apopp@mpa-garching.mpg.de>
Date: Mon, 26 May 2025 09:25:29 +0200
Subject: [PATCH 88/88] Added overview print statements

---
 resolve/re/likelihood.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
index 27d9fe57..782c9d87 100644
--- a/resolve/re/likelihood.py
+++ b/resolve/re/likelihood.py
@@ -9,10 +9,8 @@ from .likelihood_models import *
 # The classes from .likelihoods model are:
 #  - ModelCalibrationLikelihoodFixedCovariance
 #  - ModelCalibrationLikelihoodVariableCovariance
-#  - ModelImagingLikelihoodFixedCovarianceCalibrationField
-#  - ModelImagingLikelihoodFixedCovarianceCalibrationOperator
-#  - ModelImagingLikelihoodVariableCovarianceCalibrationField
-#  - ModelImagingLikelihoodVariableCovarianceCalibrationOperator
+#  - ModelImagingLikelihoodFixedCovariance
+#  - ModelImagingLikelihoodVariableCovariance
 
 from ..util import _obj2list, _duplicate
 from ..data.observation import Observation
@@ -81,16 +79,20 @@ def CalibrationLikelihood(
         flagged_data = jnp.asarray(oo.vis.val)[mask]
 
         if log_inv_cov is None:
+            if label is None:
+                print(f"| Imaging Likelihood {ii} |")
+            else:
+                print(f"| Imaging Likelihood {ii} | {label} |")
+
             model = ModelCalibrationLikelihoodFixedCovariance(cop,model_vis,mask)
             flagged_inv_cov = jnp.asarray(oo.weight.val)[mask]
             
-            lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
-        
+            lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov) 
         else:
             model = ModelCalibrationLikelihoodVariableCovariance(cop,model_vis,log_inv_cov,mask)
 
             lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=jnp.iscomplexobj(oo.vis.val))
-        
+
         lh_with_model = lh.amend(model)
         lh_with_model._domain = jft.Vector(lh_with_model._domain)
 
@@ -202,12 +204,16 @@ def ImagingLikelihood(
         flagged_data = jnp.asarray(oo.vis.val)[mask]
 
         if log_inv_cov is None:
+            if label is None:
+                print(f"| Imaging Likelihood {ii} |")
+            else:
+                print(f"| Imaging Likelihood {ii} | {label} |")
+
             model = ModelImagingLikelihoodFixedCovariance(R,sky_operator,mask,cop,cfld)
 
             flagged_inv_cov = jnp.asarray(oo.weight.val)[mask]
         
             lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
-
         else:
             model = ModelImagingLikelihoodVariableCovariance(R,sky_operator,log_inv_cov,mask,cop,cfld)
 
-- 
GitLab