diff --git a/demo/cygnusa_2ghz_fast_resolve.cfg b/demo/cygnusa_2ghz_fast_resolve.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..9e4f0f24711cb57834b97770ef4c0c3e6e0d17ab
--- /dev/null
+++ b/demo/cygnusa_2ghz_fast_resolve.cfg
@@ -0,0 +1,34 @@
+[setup]
+data=CYG-ALL-2052-2MHZ_RESOLVE_float64.npz
+psf pixels=300
+cache_noise_kernel=noise_kernel_cygnusa_2ghz_sm
+cache_response_kernel=response_kernel_cygnusa_2ghz_sm
+noise_scaling=True
+varcov=False
+
+
+
+[sky]
+freq mode = single
+polarization=I
+space npix x = 1024
+space npix y = 512
+space fov x = 0.05deg
+space fov y = 0.025deg
+
+stokesI diffuse space i0 zero mode offset = 18
+stokesI diffuse space i0 zero mode mean = 1
+stokesI diffuse space i0 zero mode stddev = 0.1
+stokesI diffuse space i0 fluctuations mean = 5
+stokesI diffuse space i0 fluctuations stddev = 1
+stokesI diffuse space i0 loglogavgslope mean = -2.0
+stokesI diffuse space i0 loglogavgslope stddev = 0.2
+stokesI diffuse space i0 flexibility mean =   1.2
+stokesI diffuse space i0 flexibility stddev = 0.4
+stokesI diffuse space i0 asperity mean =  0.2
+stokesI diffuse space i0 asperity stddev = 0.2
+
+point sources mode = single
+point sources locations = 0deg$0deg,0.35as$-0.22as
+point sources alpha = 0.5
+point sources q = 0.2
diff --git a/demo/demo_fast_resolve.py b/demo/demo_fast_resolve.py
new file mode 100644
index 0000000000000000000000000000000000000000..6017aa83f357c82c9b9e4ae9de45393e36f46ef1
--- /dev/null
+++ b/demo/demo_fast_resolve.py
@@ -0,0 +1,196 @@
+# %%
+import nifty8 as ift
+import nifty8.re as jft
+import resolve as rve
+import resolve.re as jrve
+import numpy as np
+import configparser
+import sys
+import matplotlib.pyplot as plt
+from matplotlib.colors import LogNorm
+import jax
+import pickle
+
+from jax import random
+import jax.numpy as jnp
+
+
+jax.config.update("jax_enable_x64", True)
+
+seed = 42
+
+key = random.PRNGKey(seed)
+conf_name = "cygnusa_2ghz_fast_resolve.cfg"
+out_dir = "demo_fast-resolve"
+
+
+cfg = configparser.ConfigParser()
+cfg.read(conf_name)
+
+data = cfg["setup"]["data"]
+cnk = cfg["setup"]["cache_noise_kernel"]
+crk = cfg["setup"]["cache_response_kernel"]
+noise_scal = cfg["setup"].getboolean("noise_scaling")
+varcov = cfg["setup"].getboolean("varcov")
+if noise_scal and varcov:
+    raise ValueError()
+obs = rve.Observation.load(data)
+obs = obs.restrict_to_stokesi()
+N_inv = ift.makeOp(obs.weight)
+
+sky_model, model_dict = jrve.sky_model(cfg["sky"])
+R, R_l, RNR, RNR_l = jrve.build_exact_r(obs, cfg["sky"], cfg["setup"])
+# %%
+my_sky_model_func = lambda x: sky_model(x)[0, 0, 0, :, :]
+my_sky_model = jft.Model(my_sky_model_func, domain=sky_model.domain)
+mode = 1
+mean = 1.1
+a = 2 / (mean / mode - 1) + 1
+scale = mode * (a + 1)
+inv_gamma = jft.InvGammaPrior(
+                    a=a,
+                    scale=scale,
+                    name="noise",
+                    shape=jax.ShapeDtypeStruct(R.domain.shape, float),
+                )
+noise_scaling = lambda x: 1 / (inv_gamma(x))
+noise_scaling = jft.Model(noise_scaling, domain=inv_gamma.domain)
+
+if noise_scal:
+    RNR_approx, N_inv_approx = jrve.build_approximations(
+        RNR,
+        RNR_l,
+        cache_noise_kernel=cnk,
+        cache_response_kernel=crk,
+        noise_scaling=noise_scaling,
+    )
+else:
+    RNR_approx, N_inv_approx = jrve.build_approximations(
+        RNR, RNR_l, cache_noise_kernel=cnk, cache_response_kernel=crk, varcov=varcov
+    )
+
+
+key, subkey = random.split(key)
+if noise_scaling or varcov:
+    init_pos = jft.random_like(subkey, {**my_sky_model.domain, **noise_scaling.domain})
+else:
+    init_pos = jft.random_like(subkey, my_sky_model.domain)
+
+
+def callback(pos, samp_at_pos, iter, n_major):
+    post_mean = jft.mean(tuple(my_sky_model(s) for s in samp_at_pos))
+    plt.imshow(post_mean.T, origin="lower", norm=LogNorm())
+    plt.colorbar()
+    plt.savefig(f"{out_dir}/major_{n_major}/sky_mean_{iter}.png", dpi=300)
+    plt.close()
+
+
+# inference
+d_new = R.adjoint(N_inv(obs.vis))
+
+d_new = jnp.array(d_new.val)
+init_pos = 1e-2 * jft.Vector(init_pos.copy())
+
+
+absdelta = 1e-10
+
+
+def get_draw_linear_kwargs(i):
+    kwargs = dict(
+        cg_name="cg",
+        cg_kwargs=dict(
+            absdelta=absdelta / 10.0,
+            maxiter=nstep_sampling(i),
+            miniter=nstep_sampling(i),
+        ),
+    )
+    return kwargs
+
+
+def get_nonlinearly_update_kwargs(i):
+    kwargs = dict(
+        minimize_kwargs=dict(
+            name=None,
+            xtol=1e-4,
+            cg_kwargs=dict(name=None),
+            maxiter=nstep_nl_sampling(i),
+            miniter=nstep_nl_sampling(i),
+        )
+    )
+    return kwargs
+
+
+def get_kl_kwargs(i):
+    if i < 6:
+        min_cg = 6
+    else:
+        min_cg = 20
+    kwargs = dict(
+        minimize_kwargs=dict(
+            name="newton",
+            absdelta=absdelta,
+            cg_kwargs=dict(name="ncg", miniter=min_cg),
+            maxiter=nstep_newton(i),
+            miniter=nstep_newton(i),
+            energy_reduction_factor=1e-3,
+        )
+    )
+    return kwargs
+
+
+def method(i):
+    return "linear_resample"
+
+
+def nstep_sampling(ii):
+    if ii < 4:
+        return 100
+    elif ii < 8:
+        return 500
+    elif ii < 12:
+        return 1000
+
+
+def nstep_nl_sampling(ii):
+    return 0 if ii < 20 else 10
+
+
+def nstep_newton(ii):
+    if ii < 4:
+        return 5
+    elif ii < 8:
+        return 10
+    elif ii < 12:
+        return 20
+    elif ii < 14:
+        return 30
+
+draw_linear_kwargs = get_draw_linear_kwargs
+nonlinearly_update_kwargs = get_nonlinearly_update_kwargs
+kl_kwargs = get_kl_kwargs
+
+optimize_kwargs = dict(
+    R=RNR,
+    R_approx=RNR_approx,
+    sky=my_sky_model,
+    N_inv_sqrt=N_inv_approx,
+    data=d_new,
+    pos=init_pos,
+    draw_linear_kwargs=draw_linear_kwargs,
+    nonlinearly_update_kwargs=nonlinearly_update_kwargs,
+    kl_kwargs=kl_kwargs,
+    n_samples=2,
+    n_iter=2,
+    n_major_step=20,
+    key=key,
+    callback=callback,
+    out_dir=out_dir,
+    resume=True,
+    init_samples=None,
+    noise_scaling=noise_scal,
+    varcov=varcov,
+    varcov_op=noise_scaling,
+    method=method,
+)
+pos, samples, key = jrve.optimize(**optimize_kwargs)
+# %%
diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py
index ce2ba46e2dd79aa995eb542f9f9c8aad80c18a68..c96632e67661b85427538689ecc7360899fd7a42 100644
--- a/resolve/re/__init__.py
+++ b/resolve/re/__init__.py
@@ -1,3 +1,8 @@
 
 from .sky_model import sky_model_diffuse, sky_model_points, sky_model
-from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc
\ No newline at end of file
+from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc
+from .radio_response import build_exact_r, build_approximations
+from .optimize import optimize
+from .calibration import CalibrationInterpolator, CalibrationDistribution
+from .likelihood import CalibrationLikelihood, ImagingLikelihood
+from .sugar import Bulk_CF_AntennaTimeDomain, CalibrationAssembler, SkyAssembler
\ No newline at end of file
diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7b3cb6597b1ab6e9edf6597b22c7f306b37b7e6
--- /dev/null
+++ b/resolve/re/calibration.py
@@ -0,0 +1,97 @@
+import nifty8.re as jft
+import jax.scipy as jsc
+import jax.numpy as jnp
+
+from jax.tree_util import Partial
+from jax import vmap
+
+from ..data.observation import Observation
+
+class CalibrationDistribution(jft.Model):
+    """
+    Computes the calibration operator from given observation data.
+
+    Parameters
+    ----------
+    observation: Observation
+        Observation object from which are the antenna and temporal information corresponding to 
+        the visibilites are extracted.
+    phase_fields: jft.Model or Bulk_CF_AntennaTimeDomain (preferred)
+        Correlated fields on antenna-time space for phases of calibration solutions.
+    log_amplitude_fields: jft.Model or Bulk_CF_AntennaTimeDomain (preferred)
+        Correlated fields on antenna-time space for log amplitude of calibration solutions.
+    dt: float
+        Distances between time points on time axis. Has to be the same distance of time points,
+        which is used for phase_fields and log_amplitude fields.
+
+    Note
+    ----
+    Currently, only uniformly spaced time axis are supported.
+    """
+    def __init__(
+            self,
+            observation: Observation, 
+            phase_fields: jft.Model, 
+            log_amplitude_fields: jft.Model,
+            dt: float
+            ):
+        ap = observation.antenna_positions
+        target_shape = observation.vis.shape
+        self._cop1 = CalibrationInterpolator(jnp.asarray(ap.ant1), jnp.asarray(ap.time), dt, target_shape)
+        self._cop2 = CalibrationInterpolator(jnp.asarray(ap.ant2), jnp.asarray(ap.time), dt, target_shape)
+
+        self._phases = phase_fields
+        self._logamps = log_amplitude_fields
+
+        super().__init__(init=self._phases.init | self._logamps.init)
+
+    def __call__(self,x):
+        res_logamp = jnp.real(self._cop1(self._logamps(x)) + self._cop2(self._logamps(x)))
+        res_phase = jnp.real(self._cop1(self._phases(x)) - self._cop2(self._phases(x)))*1j
+
+        return jnp.exp(res_logamp + res_phase)
+
+class CalibrationInterpolator():
+    """
+    Interpolates visibilites for a specific sequence of antenna-time pairs given the visibilities
+    on an evenly spaced antenna-time grid.
+
+    Parameters
+    ----------
+    ant_col: jnp.ndarray
+        Antenna points to which one wants to interpolate
+    time_col: jnp.ndarry
+        Time points to which one wants to interpolate
+    dt: float
+        Distances between time points on time axis.
+    target_shape: tuple
+        Shape of output when calling this class. First element should encode number of 
+        polarization directions and last element should encode number of frequencies
+    
+    Note
+    ----
+    Currently, only uniformly spaced time axis are supported.
+    """
+    def __init__(
+            self,
+            ant_col: jnp.ndarray, 
+            time_col: jnp.ndarray,
+            dt: float,
+            target_shape: tuple
+            ):
+
+        coords = [ant_col,time_col/dt]
+
+        self._li = Partial(jsc.ndimage.map_coordinates,coordinates=coords,order=1)
+
+        self._n_pol,_,self._n_freq = target_shape
+
+    def __call__(self,x):
+        res = vmap(
+            vmap(
+                lambda pol, freq: self._li(x[pol, :, :, freq])
+                ,in_axes=(None,0), out_axes=1
+            ),in_axes=(0,None)
+        )
+        
+        return res(jnp.arange(self._n_pol), jnp.arange(self._n_freq))
\ No newline at end of file
diff --git a/resolve/re/fast-resolve_cal.py b/resolve/re/fast-resolve_cal.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cac6b2e41f537cbffb2e651d84fb5bb4cb1ae5f
--- /dev/null
+++ b/resolve/re/fast-resolve_cal.py
@@ -0,0 +1,244 @@
+import nifty8 as ift
+import nifty8.re as jft
+
+import resolve.re as jrve
+
+import jax.numpy as jnp
+import jax.random as jr
+
+def calibration_step(
+        cal_assemblers,
+        config_minimizer,
+        key,
+        assembler_keys = None,
+        resume = False,
+        odir = None
+        ):
+    
+    k_i, k_0 = jr.split(key,2)
+    
+    lh_param_keys = ("observation","calibration_operator","model_visibilities","log_inverse_covariance_operator","likelihood_label")
+    cal_assembler_attr_keys = ("obs","scaled_cop","model_vis","log_inv_cov","lh_label")
+    assembler_keys = cal_assemblers.keys() if assembler_keys is None else assembler_keys
+
+    lh_components = {i: [getattr(cal_assemblers[k],j) for k in assembler_keys] for i,j in zip(lh_param_keys,cal_assembler_attr_keys)}
+    lh = jrve.CalibrationLikelihood(**lh_components)
+
+    optimize_kl_kwargs = dict(
+        likelihood = lh,
+        position_or_samples = jft.random_like(k_0,lh.domain),
+        key = k_i,
+        resume = resume,
+        odir = odir,
+        **config_minimizer
+    )
+    
+    return jft.optimize_kl(**optimize_kl_kwargs)
+
+def imaging_step(
+        sky_assembler,
+        calibration_field,
+        config_response,
+        config_minimizer,
+        key,
+        init_samples = None,
+        cache_noise_kernel="None",
+        cache_response_kernel="None",
+        resume = False,
+        odir = None
+        ):
+    
+    obs = sky_assembler.obs
+    sky_model = sky_assembler.sky
+    noise_scaling_op = sky_assembler.noise_scaling
+    varcov_op = sky_assembler.varcov
+    
+    noise_scaling = False if noise_scaling_op is None else True
+    varcov = False if varcov_op is None else True
+
+    k_i, k_0 = jr.split(key,2)
+
+    R, _, RNR, RNR_l = jrve.build_exact_r(obs,calibration_field=calibration_field,**config_response)
+
+    N_inv = ift.makeOp(obs.weight)
+    d_0 = jnp.array(R.adjoint(N_inv(obs.vis)).val)
+    
+
+    if noise_scaling:
+        RNR_approx, N_inv_approx = jrve.build_approximations(RNR,RNR_l,cache_noise_kernel=cache_noise_kernel,cache_response_kernel=cache_response_kernel,noise_scaling=noise_scaling_op)
+        init_pos = jft.random_like(k_0,{**sky_model.domain, **noise_scaling_op.domain})
+    else:
+        RNR_approx, N_inv_approx = jrve.build_approximations(RNR,RNR_l,cache_noise_kernel=cache_noise_kernel,cache_response_kernel=cache_response_kernel,varcov=varcov)
+        init_pos = jft.random_like(k_0,{**sky_model.domain, **varcov_op.domain}) if varcov else jft.random_like(k_0,sky_model.domain)
+
+    optimize_kl_kwargs = dict(
+        R = RNR,
+        R_approx = RNR_approx,
+        sky = sky_model,
+        N_inv_sqrt = N_inv_approx,
+        data = d_0,
+        pos = 1e-2*jft.Vector(init_pos.copy()),
+        resume = resume,
+        init_samples = init_samples,
+        noise_scaling = noise_scaling,
+        varcov = varcov,
+        varcov_op = varcov_op,
+        key = k_i,
+        out_dir = odir,
+        **config_minimizer
+    )
+
+    pos, samples, _ = jrve.optimize(**optimize_kl_kwargs)
+
+    post_sky_mean = ift.makeField(R.domain, np.array(jft.mean(tuple(sky_model(s) for s in samples))))
+    new_sci_model_vis = jnp.array(R(post_sky_mean).val)
+
+    return samples, pos, new_sci_model_vis
+
+def resume_preparation(
+    obs,
+    sky_model,
+    odir,
+    config 
+):
+    config_response = config["response"]
+    save_all = config["imaging"]["save_all"]
+
+    lfile = f"{odir}/last_started_major"
+    if path.isfile(lfile):
+        with open(lfile) as f:
+            last_started_major = int(f.read())
+        lfile = f"{odir}/major_{last_started_major}/last_finished_iteration"
+        if path.isfile(lfile):
+            with open(lfile) as f:
+                last_finished_index = int(f.read())
+        else:
+            raise FileNotFoundError(f"{lfile} could not be found")
+    else:
+        raise FileNotFoundError(f"{lfile} could not be found")
+    
+    if save_all:
+        samples = pickle.load(
+            open(
+                f"{odir}/major_{last_started_major}/samples_{last_finished_index}.p",
+                "rb",
+            )
+        )
+    else:
+        samples = pickle.load(
+            open(
+                f"{odir}/last_samples.p",
+                "rb",
+            )
+        )
+
+    sp_sky_dom =rve.sky_model._spatial_dom(config_response["conf_sky"])
+    sky_dom = rve.default_sky_domain(sdom=sp_sky_dom)
+    R = rve.InterferometryResponse(obs, sky_dom, config_response["do_wgridding"], config_response["epsilon"], config_response["verbosity"], config_response["nthreads"])
+    dch = ift.DomainChangerAndReshaper(R.domain[3], R.domain)
+    R = R @ dch
+    
+    post_sky_mean = ift.makeField(R.domain, np.array(jft.mean(tuple(sky_model(s) for s in samples))))
+    return jnp.array(R(post_sky_mean).val)
+    
+
+def single_cal_fr_run(
+        cal_assemblers,
+        sky_assembler,
+        config,
+        key,
+        n_vi_runs=(1,1),
+        resume = False,
+        odirs=(None,None)
+        ):
+    key_cal, key_img = jr.split(key,2)
+    config_response = config["response"]
+    config_cal = config["calibration"]
+    config_img = config["imaging"]
+    assembler_keys = list(cal_assemblers.keys())
+
+    if cal_assemblers["sci"].model_vis is None:
+        assembler_keys.remove("sci")
+
+    config_cal["n_total_iterations"] += n_vi_runs[0]
+
+    print("-"*5,"Calibration step","-"*57)
+    
+    cal_samples, cal_state = calibration_step(cal_assemblers,config_cal,key_cal,assembler_keys,resume,odirs[0])
+
+    antenna_gains = jft.mean(tuple(cal_assemblers["sci"].cop(s) for s in cal_samples))
+    cfld = ift.makeField(cal_assemblers["sci"].obs.vis.domain,antenna_gains)
+
+    config_img["n_major_step"] += n_vi_runs[1]
+    print("-"*5,"Imaging step","-"*61)
+
+    img_samples, img_state, cal_assemblers["sci"].model_vis = imaging_step(sky_assembler,cfld,config_response,config_img,key_img,resume=resume,odir=odirs[1])
+
+
+    return dict(response=config_response,calibration=config_cal,imaging=config_img),dict(cal=cal_samples,img=img_samples), dict(cal=cal_state,img=img_state)
+
+def fastresolve_with_calibration(
+        sky_assembler,
+        sci_cal_assembler,
+        config,
+        key,
+        phase_cal_assembler = None,
+        flux_cal_assembler = None,
+        n_iterations=1,
+        n_vi_runs_per_iteration=(1,1),
+        resume = False,
+        odir = None,
+        return_latest_samples=False,
+        return_latest_states=False,
+        return_latest_config_update=False
+        ):
+    if resume and (odir is None):
+        raise ValueError("Set folder name from where to resume the inference runs")
+    
+    odirs = (None,None) if odir is None else (f"{odir}/calibration",f"{odir}/imaging")
+    
+    new_config = deepcopy(config)
+    sca = dataclasses.replace(sci_cal_assembler)
+
+    if resume:
+        sca.model_vis = resume_preparation(sca.obs,sky_assembler.sky,odirs[1],config)
+
+        with open(f"{odir}/run_info.json","r") as f:
+            loaded_run_info = json.load(f)
+
+        n_cal_vi_steps_done = loaded_run_info["cal_vi_iterations"]
+        n_img_major_steps_done = loaded_run_info["img_n_major_steps"]
+
+        print("Resume: Overwrite given settings for n_total_iterations and n_major_step")
+
+        config["calibration"]["n_total_iterations"] = n_cal_vi_steps_done
+        config["imaging"]["n_major_step"] = n_img_major_steps_done
+
+        print(f"Loaded data of previous {n_cal_vi_steps_done} calibration steps and {n_img_major_steps_done} imaging major steps")
+
+    
+    cal_assemblers = dict(sci=sca)
+    cal_assemblers.update(dict(phase=phase_cal_assembler) if phase_cal_assembler is not None else {})
+    cal_assemblers.update(dict(flux=flux_cal_assembler) if flux_cal_assembler is not None else {})
+
+
+    for k in range(n_iterations):
+        print("-"*80)
+        print(" "*4,f"{k+1}-th calibration and fast-resolve iteration")
+        print("-"*80)
+
+        new_config, samples_dct, state_dct = single_cal_fr_run(cal_assemblers,sky_assembler,new_config,key,n_vi_runs_per_iteration,resume,odirs)
+        resume = True
+
+        with open(f"{odir}/run_info.json","w") as f:
+            json.dump(dict(
+                cal_vi_iterations = new_config["calibration"]["n_total_iterations"],
+                img_n_major_steps = new_config["imaging"]["n_major_step"]
+                ),f)
+    
+    returns = {}
+    returns.update(dict(config=new_config) if return_latest_config_update else {})
+    returns.update(dict(samples=samples_dct) if return_latest_samples else {})
+    returns.update(dict(states=state_dct) if return_latest_states else {})
+
+    return returns
\ No newline at end of file
diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py
new file mode 100644
index 0000000000000000000000000000000000000000..27d9fe571a57abb556090572578a1684772dc3c8
--- /dev/null
+++ b/resolve/re/likelihood.py
@@ -0,0 +1,228 @@
+import jax.numpy as jnp
+import nifty8.re as jft
+
+from typing import Union, Iterable
+
+from .response import InterferometryResponse
+from .calibration import CalibrationDistribution
+from .likelihood_models import *
+# The classes from .likelihoods model are:
+#  - ModelCalibrationLikelihoodFixedCovariance
+#  - ModelCalibrationLikelihoodVariableCovariance
+#  - ModelImagingLikelihoodFixedCovarianceCalibrationField
+#  - ModelImagingLikelihoodFixedCovarianceCalibrationOperator
+#  - ModelImagingLikelihoodVariableCovarianceCalibrationField
+#  - ModelImagingLikelihoodVariableCovarianceCalibrationOperator
+
+from ..util import _obj2list, _duplicate
+from ..data.observation import Observation
+
+from jaxlib.xla_extension import ArrayImpl
+from nifty8.re.model import Model
+
+def CalibrationLikelihood(
+    observation: Union[Observation, Iterable[Observation]],
+    calibration_operator: Union[CalibrationDistribution, Iterable[CalibrationDistribution]],
+    model_visibilities: Union[jnp.ndarray, Iterable[jnp.ndarray]],
+    log_inverse_covariance_operator: Union[jft.Model,Iterable[jft.Model]] = None,
+    likelihood_label: Union[str,Iterable[str]] = None
+):
+    """Versatile calibration likelihood class
+
+    It returns an operator that computes:
+
+    residual = calibration_operator * model_visibilities
+    likelihood = 0.5 * residual^dagger @ inverse_covariance @ residual
+
+    If an inverse_covariance_operator is passed, it is inserted into the above
+    formulae. If it is not passed, 1/observation.weights is used as inverse
+    covariance.
+
+    Parameters
+    ----------
+    observation : Observation or Iterable of Observations
+        Observation object from which observation.vis and potentially
+        observation.weight is used for computing the likelihood.
+
+    calibration_operator : CalibrationDistribution or Iterable of CalibrationDistribution
+        Target needs to be the same as observation.vis. The same amount of elements as 
+        number of observations should be provided.
+
+    model_visibilities jnp.ndarray or Iterable of jnp.ndarray
+        Known model visiblities that are used for calibration. Needs to be
+        defined on the same domain as `observation.vis`. The same amount of elements as 
+        number of observations should be provided.
+
+    log_inverse_covariance_operator : jft.Model or Iterable of jft.Model
+        Optional. Target needs to be the same space as observation.vis. If it is
+        not specified, observation.wgt is taken as covariance. If used, the same
+        amount of elements as number of observations should be provided.
+
+    likelihood_label: string or Iterable of string
+        Optional. Append label to individual likelihood which is shown in the minisanity
+        for overview. If used, the same amount of elements as number of observations 
+        should be provided.
+    """
+    
+    
+    obs = _obj2list(observation,Observation)
+    cops = _obj2list(calibration_operator,CalibrationDistribution)
+    model_d = _obj2list(model_visibilities,ArrayImpl)
+    log_inv_covs = _duplicate(_obj2list(log_inverse_covariance_operator,Model),len(obs))
+    labels = _duplicate(_obj2list(likelihood_label,str),len(obs))
+
+    if len(set([len(obs),len(cops),len(model_d)])) != 1:
+        raise ValueError("observation, calibration_operator and model_visibilities must have the same number of elements")
+
+    lhs = []
+    for ii, (oo, cop, model_vis, log_inv_cov,label) in enumerate(zip(obs,cops,model_d,log_inv_covs,labels)):    
+        mask = jnp.asarray(oo.mask.val)
+
+        flagged_data = jnp.asarray(oo.vis.val)[mask]
+
+        if log_inv_cov is None:
+            model = ModelCalibrationLikelihoodFixedCovariance(cop,model_vis,mask)
+            flagged_inv_cov = jnp.asarray(oo.weight.val)[mask]
+            
+            lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
+        
+        else:
+            model = ModelCalibrationLikelihoodVariableCovariance(cop,model_vis,log_inv_cov,mask)
+
+            lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=jnp.iscomplexobj(oo.vis.val))
+        
+        lh_with_model = lh.amend(model)
+        lh_with_model._domain = jft.Vector(lh_with_model._domain)
+
+        if label is not None:
+            lh_with_model._name = label
+            lh_with_model._name_n_ws = " "*(len(str(len(obs)))-len(str(ii)))
+
+        lhs.append(lh_with_model)
+
+    if set(labels) == {None}:
+        return jft.likelihood.LikelihoodSum(*lhs)
+    else:
+        return jft.likelihood.LikelihoodSum(*lhs, _key_template="lh {likelihood._name_n_ws}{index} | {likelihood._name}")
+
+
+def ImagingLikelihood(
+    observation: Union[Observation, Iterable[Observation]],
+    sky_operator: Union[jft.Model,Iterable[jft.Model]],
+    sky_domain_dict: dict,
+    epsilon: float,
+    do_wgridding: bool,
+    log_inverse_covariance_operator: Union[jft.Model,Iterable[jft.Model]] = None,
+    calibration_operator: Union[CalibrationDistribution, Iterable[CalibrationDistribution]] = None,
+    calibration_field: Union[jnp.ndarray, Iterable[jnp.ndarray]] = None,
+    likelihood_label: Union[str,Iterable[str]] = None,
+    verbosity: int = 0,
+    nthreads: int = 1,
+    backend: str = "ducc0",
+):
+    """Versatile likelihood class.
+
+    If a calibration operator is passed, it returns an operator that computes:
+
+    residual = calibration_operator * (R @ sky_operator)
+    likelihood = 0.5 * residual^dagger @ inverse_covariance @ residual
+
+    Otherwise, it returns an operator that computes:
+
+    residual = R @ sky_operator
+    likelihood = 0.5 * residual^dagger @ inverse_covariance @ residual
+
+    If an inverse_covariance_operator is passed, it is inserted into the above
+    formulae. If it is not passed, 1/observation.weights is used as inverse
+    covariance.
+
+    Parameters
+    ----------
+    observation : Observation or Iterable of Observation
+        Observation objects from which vis, uvw, freq and potentially weight
+        are used for computing the likelihood.
+
+    sky_operator : jft.Model
+        Operator that generates sky.
+
+    sky_domain_dict: dict
+        A dictionary providing information about the discretization of the sky.
+
+    epsilon : float
+
+    do_wgridding : bool
+
+    log_inverse_covariance_operator : jft.Model or Iterable of jft.Model
+        Optional. Target needs to be the same space as observation.vis. If it
+        is not specified, observation.wgt is taken as covariance. If used, the same
+        amount of elements as number of observations should be provided.
+
+    calibration_operator : CalibrationDistribution or Iterable of CalibrationDistribution
+        Optional. Target needs to be the same as observation.vis. If used, the same
+        amount of elements as number of observations should be provided.
+
+    calibration_field: jnp.ndarray or Iterable of jnp.ndarray
+        Optional. Domain needs to be the same as observation.vis. If used, the same
+        amount of elements as number of observations should be provided.
+    
+    likelihood_label: string or Iterable of string
+        Optional. Append labels to individual likelihoods which are shown in the minisanity
+        for overview. If used, the same amount of elements as number of observations 
+        should be provided.
+
+    verbosity : int
+
+    nthreads : int
+
+    backend: string
+
+    Note
+    ----
+    For each observation only either calibration_operator or calibration_field
+    can be set.
+    """
+
+    obs = _obj2list(observation,Observation)
+    cops = _duplicate(_obj2list(calibration_operator,CalibrationDistribution),len(obs))
+    cflds = _duplicate(_obj2list(calibration_field,ArrayImpl),len(obs))
+    log_inv_covs = _duplicate(_obj2list(log_inverse_covariance_operator,Model),len(obs))
+    labels = _duplicate(_obj2list(likelihood_label,str),len(obs))
+
+    lhs = []
+    
+    for ii, (oo, cop, cfld, log_inv_cov,label) in enumerate(zip(obs,cops,cflds,log_inv_covs,labels)):
+        if cfld is not None and cop is not None:
+            raise ValueError(
+                f"Can't set calibration operator and calibration field simultaneously at index {ii}"
+            )
+
+        R = InterferometryResponse(oo,sky_domain_dict,do_wgridding,epsilon,nthreads,verbosity,backend)
+        mask = jnp.asarray(oo.mask.val)
+
+        flagged_data = jnp.asarray(oo.vis.val)[mask]
+
+        if log_inv_cov is None:
+            model = ModelImagingLikelihoodFixedCovariance(R,sky_operator,mask,cop,cfld)
+
+            flagged_inv_cov = jnp.asarray(oo.weight.val)[mask]
+        
+            lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov)
+
+        else:
+            model = ModelImagingLikelihoodVariableCovariance(R,sky_operator,log_inv_cov,mask,cop,cfld)
+
+            lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=jnp.iscomplexobj(oo.vis.val))
+
+        lh_with_model = lh.amend(model)
+        lh_with_model._domain = jft.Vector(lh_with_model._domain)
+
+        if label is not None:
+            lh_with_model._name = label
+            lh_with_model._name_n_ws = " "*(len(str(len(obs)))-len(str(ii)))
+
+        lhs.append(lh_with_model)
+    
+    if set(labels) == {None}:
+        return jft.likelihood.LikelihoodSum(*lhs)
+    else:
+        return jft.likelihood.LikelihoodSum(*lhs, _key_template="lh {likelihood._name_n_ws}{index} | {likelihood._name}")
diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee3b3ffd1cd67484bdc8e7c7d779ccea7dba85b2
--- /dev/null
+++ b/resolve/re/likelihood_models.py
@@ -0,0 +1,175 @@
+import nifty8.re as jft
+import jax.numpy as jnp
+
+from .calibration import CalibrationDistribution
+
+from typing import Callable
+
+class ModelCalibrationLikelihoodFixedCovariance(jft.Model):
+    """
+    Provides a flagged data model for calibration
+
+    Parameters
+    ----------
+    cop: CalibrationDistribution
+        Calibration operator
+    model_visibilities: jnp.ndarray
+        Assumed visibilities of the point source.
+    mask: jnp.array
+        Mask as boolean numpy array for good visibilites
+    """
+    def __init__(
+            self, 
+            cop: CalibrationDistribution, 
+            model_visibilities: jnp.ndarray, 
+            mask: jnp.ndarray
+            ):
+        self._cop = cop
+        self._vis = model_visibilities
+        self._mask = mask
+
+        super().__init__(init=self._cop.init)
+
+    def __call__(self, x):
+        data_model = self._vis*self._cop(x)
+        flagged_data_model = data_model[self._mask]
+
+        return flagged_data_model
+    
+class ModelCalibrationLikelihoodVariableCovariance(jft.Model):
+    """
+    Provides a combined flagged data model and flagged inverse covariance model for calibration
+
+    Parameters
+    ----------
+    cop: CalibrationDistribution
+        Calibration operator
+    model_visibilities: jnp.ndarray
+        Assumed visibilities of the point source.
+    log_inverse_covariance_model: jft.Model
+        Model for log inverse covariance
+    mask: jnp.array
+        Mask as boolean numpy array for good visibilites
+    """
+    def __init__(
+            self, 
+            cop: CalibrationDistribution, 
+            model_visibilities: jnp.ndarray, 
+            log_inverse_covariance_model: jft.Model, 
+            mask: jnp.ndarray
+            ):
+        self._cop = cop
+        self._vis = model_visibilities
+        self._mask = mask
+        self._log_inv_cov = log_inverse_covariance_model
+
+        super().__init__(init=self._cop.init | self._log_inv_cov.init)
+    
+    def __call__(self, x):
+        data_model = self._vis*self._cop(x)
+        flagged_data_model = data_model[self._mask]
+
+        inv_std = jnp.exp(0.5*self._log_inv_cov(x))
+        flagged_inv_std = inv_std[self._mask]
+        
+        return (flagged_data_model,flagged_inv_std)
+    
+
+class ModelImagingLikelihoodFixedCovariance(jft.Model):
+    """
+    Provides a flagged data model for imaging given an optional calibration operator or
+    calibration field
+
+    Parameters
+    ----------
+    R: Callable
+        Response operator function
+    sky: jft.Model
+        Model for sky
+    mask: jnp.array
+        Mask as boolean numpy array for good visibilites
+    calibration_operator: CalibrationDistribution
+        Optional. Calibration operator
+    calibration_field: jnp.ndarray
+        Optional. Calibration field
+    """
+    def __init__(
+            self, 
+            R: Callable, 
+            sky_operator: jft.Model, 
+            mask: jnp.ndarray, 
+            calibration_operator: CalibrationDistribution = None,
+            calibration_field: jnp.ndarray = None
+            ):
+        
+        self._mask = mask
+        inits = sky_operator.init 
+
+        if(calibration_operator is not None):
+            self._data_model = lambda x: calibration_operator(x)*R(sky_operator(x))
+            inits = inits | calibration_operator.init     
+        elif(calibration_field is not None):
+            self._data_model = lambda x: calibration_field*R(sky_operator(x))
+        else:
+            self._data_model = lambda x: R(sky_operator(x))
+
+        super().__init__(init=inits)
+
+    def __call__(self,x):
+        data_model = self._data_model(x)
+        flagged_data_model = data_model[self._mask]
+
+        return flagged_data_model
+
+class ModelImagingLikelihoodVariableCovariance(jft.Model):
+    """
+    Provides a combined flagged data model and flagged inverse covariance model for imaging 
+    given an optional calibration operator or calibration field
+
+    Parameters
+    ----------
+    R: Callable
+        Response operator function
+    sky: jft.Model
+        Model for sky
+    log_inverse_covariance_model: jft.Model
+        Model for log inverse covariance
+    mask: jnp.array
+        Mask as boolean numpy array for good visibilites
+    calibration_operator: CalibrationDistribution
+        Optional. Calibration operator
+    calibration_field: jnp.ndarray
+        Optional. Calibration field
+    """
+    def __init__(
+            self, 
+            R: Callable, 
+            sky_operator: jft.Model, 
+            log_inverse_covariance_model: jft.Model, 
+            mask: jnp.ndarray, 
+            calibration_operator: CalibrationDistribution = None,
+            calibration_field: jnp.ndarray = None
+            ):
+        
+        self._mask = mask
+        self._log_inv_cov = log_inverse_covariance_model
+        inits = sky_operator.init | self._log_inv_cov.init
+
+        if(calibration_operator is not None):
+            self._data_model = lambda x: calibration_operator(x)*R(sky_operator(x))
+            inits = inits | calibration_operator.init
+        elif(calibration_field is not None):
+            self._data_model = lambda x: calibration_field*R(sky_operator(x))
+        else:
+            self._data_model = lambda x: R(sky_operator(x))
+
+        super().__init__(init=inits)
+
+    def __call__(self,x):
+        data_model = self._data_model(x)
+        flagged_data_model = data_model[self._mask]
+
+        inv_std = jnp.exp(0.5*self._log_inv_cov(x))
+        flagged_inv_std = inv_std[self._mask]
+        
+        return (flagged_data_model,flagged_inv_std)
diff --git a/resolve/re/optimize.py b/resolve/re/optimize.py
new file mode 100644
index 0000000000000000000000000000000000000000..0835e96ec4eaaa3b8b7c98c904f46be2a4287dc3
--- /dev/null
+++ b/resolve/re/optimize.py
@@ -0,0 +1,286 @@
+import jax
+import numpy as np
+import pickle
+import jax.numpy as jnp
+import nifty8.re as jft
+import nifty8 as ift
+
+from functools import partial
+from os import makedirs
+from os.path import isfile
+
+
+from nifty8.re.optimize_kl import _kl_vg
+from nifty8.re.optimize_kl import _kl_met
+from nifty8.re.optimize_kl import draw_linear_residual
+from nifty8.re.optimize_kl import nonlinearly_update_residual
+from nifty8.re.optimize_kl import get_status_message
+# from nifty8.re.evi import _nonlinearly_update_residual_functions
+
+
+def optimize(
+    R,
+    R_approx,
+    sky,
+    N_inv_sqrt,
+    data,
+    pos,
+    draw_linear_kwargs,
+    nonlinearly_update_kwargs,
+    kl_kwargs,
+    n_samples,
+    n_iter,
+    n_major_step,
+    key,
+    callback=None,
+    out_dir=None,
+    resume=False,
+    init_samples=None,
+    noise_scaling=False,
+    varcov=False,
+    varcov_op=None,
+    method="linear_sample",
+    save_all=True,
+):
+    if varcov:
+        jax_0_data = jnp.broadcast_to(0.0 + 0j, data.shape)
+    else:
+        jax_0_data = jnp.broadcast_to(0.0, data.shape)
+    if not out_dir == None:
+        makedirs(out_dir, exist_ok=True)
+    lfile = f"{out_dir}/last_started_major"
+    last_started_major = 0
+    last_finished_index = -1
+    if resume and isfile(lfile):
+        with open(lfile) as f:
+            last_started_major = int(f.read())
+        lfile = f"{out_dir}/major_{last_started_major}/last_finished_iteration"
+        if resume and isfile(lfile):
+            with open(lfile) as f:
+                last_finished_index = int(f.read())
+
+    def residual_signal_response(x, old_reconstruction):
+        return R_approx(sky(x) - old_reconstruction)
+
+    def noise_weighted_residual(x, old_reconstruction, residual_data):
+        if noise_scaling:
+            return N_inv_sqrt(
+                {
+                    "sky": residual_signal_response(x, old_reconstruction)
+                    - residual_data,
+                    "noise": x["noise"],
+                }
+            )
+        else:
+            return N_inv_sqrt(
+                {"sky": residual_signal_response(x, old_reconstruction) - residual_data}
+            )
+
+    def noise_weighted_residual_varcov(x, old_reconstruction, residual_data):
+        return [
+            N_inv_sqrt(
+                {"sky": residual_signal_response(x, old_reconstruction) - residual_data}
+            ),
+            varcov_op(x),
+        ]
+
+    kl_map = jax.vmap
+    kl_reduce = partial(jax.tree_map, partial(jnp.mean, axis=0))
+
+    def get_lh(old_reconstruction, residual_data):
+        if varcov:
+            lh = jft.VariableCovarianceGaussian(jax_0_data, iscomplex=True).amend(
+                partial(
+                    noise_weighted_residual_varcov,
+                    old_reconstruction=old_reconstruction,
+                    residual_data=residual_data,
+                )
+            )
+        else:
+            lh = jft.Gaussian(jax_0_data).amend(
+                partial(
+                    noise_weighted_residual,
+                    old_reconstruction=old_reconstruction,
+                    residual_data=residual_data,
+                )
+            )
+        return lh
+
+    @jax.jit
+    def my_kl_vg(primals, primals_samples, *, old_reconstruction, residual_data):
+        lh = get_lh(old_reconstruction, residual_data)
+        return _kl_vg(lh, primals, primals_samples, map=kl_map, reduce=kl_reduce)
+
+    @jax.jit
+    def my_kl_metric(
+        primals, tangents, primals_samples, *, old_reconstruction, residual_data
+    ):
+        lh = get_lh(old_reconstruction, residual_data)
+        return _kl_met(
+            lh, primals, tangents, primals_samples, map=kl_map, reduce=kl_reduce
+        )
+
+    @jax.jit
+    def my_draw_linear_residual(
+        pos, key, *, old_reconstruction, residual_data, **kwargs
+    ):
+        lh = get_lh(old_reconstruction, residual_data)
+        return draw_linear_residual(lh, pos, key, **kwargs)
+
+    def my_nonlinearly_update_residual(
+        pos,
+        residual_sample,
+        metric_sample_key,
+        metric_sample_sign,
+        *,
+        old_reconstruction,
+        residual_data,
+        **kwargs,
+    ):
+        lh = get_lh(old_reconstruction, residual_data)
+        return nonlinearly_update_residual(
+            lh,
+            pos,
+            residual_sample,
+            metric_sample_key,
+            metric_sample_sign,
+            # _nonlinear_update_funcs=_nonlin_funcs,
+            **kwargs,
+        )
+
+    def my_stat_mes(samples, state, *, old_reconstruction, residual_data, **kwargs):
+        lh = get_lh(old_reconstruction, residual_data)
+        return get_status_message(
+            samples, state, lh.normalized_residual, name="optVI test", **kwargs
+        )
+
+    if last_finished_index > -1:
+        if save_all:
+            opt_vi_state = pickle.load(
+                open(
+                    f"{out_dir}/major_{last_started_major}/optVIstate_it_{last_finished_index}.p",
+                    "rb",
+                )
+            )
+            samples = pickle.load(
+                open(
+                    f"{out_dir}/major_{last_started_major}/samples_{last_finished_index}.p",
+                    "rb",
+                )
+            )
+        else:
+            opt_vi_state = pickle.load(
+                open(
+                    f"{out_dir}/last_optVIstate.p",
+                    "rb",
+                )
+            )
+            samples = pickle.load(
+                open(
+                    f"{out_dir}/last_samples.p",
+                    "rb",
+                )
+            )
+
+        sub_val = jft.mean(tuple(sky(s) for s in samples))
+        post_mean = np.array(sub_val)
+        post_mean = ift.makeField(R.domain, post_mean)
+        residual_data = data - R(post_mean).val
+    else:
+        opt_vi_state = None
+        samples = None
+        if not init_samples == None:
+            sub_val = jft.mean(tuple(sky(s) for s in init_samples))
+            post_mean = np.array(sub_val)
+            post_mean = ift.makeField(R.domain, post_mean)
+            residual_data = data - R(post_mean).val
+        else:
+            residual_data = data
+            sub_val = jnp.zeros(data.shape, dtype=data.dtype)
+
+    # init optVI
+    opt_vi = jft.OptimizeVI(
+        None,
+        n_iter,
+        _kl_value_and_grad=my_kl_vg,
+        _kl_metric=my_kl_metric,
+        _draw_linear_residual=my_draw_linear_residual,
+        _nonlinearly_update_residual=my_nonlinearly_update_residual,
+        _get_status_message=my_stat_mes,
+        residual_map="lmap",
+    )
+
+    update_kwargs = dict(
+        old_reconstruction=sub_val,
+        residual_data=residual_data,
+    )
+
+    # call optVI.init
+    if opt_vi_state is None:
+        print("init opt vi")
+        opt_vi_state = opt_vi.init_state(
+            key,
+            nit=0,
+            n_samples=n_samples,
+            draw_linear_kwargs=draw_linear_kwargs,
+            nonlinearly_update_kwargs=nonlinearly_update_kwargs,
+            kl_kwargs=kl_kwargs,
+            sample_mode=method,
+        )
+    if samples is None:
+        samples = jft.Samples(pos=pos, samples=None, keys=None)
+
+    for n_major in range(last_started_major, n_major_step):
+        if not out_dir == None:
+            makedirs(out_dir + f"/major_{n_major}", exist_ok=True)
+        minor_start = opt_vi_state.nit - n_iter * n_major
+        print("minor_start: ", minor_start)
+        print("n_iter: ", n_iter)
+        print("n_major:", n_major)
+        print("opt_vi_state.nit: ", opt_vi_state.nit)
+        for i in range(minor_start, n_iter):
+            # do opt_vi update
+            samples, opt_vi_state = opt_vi.update(
+                samples, opt_vi_state, **update_kwargs
+            )
+            msg = opt_vi.get_status_message(
+                samples,
+                opt_vi_state,
+                old_reconstruction=sub_val,
+                residual_data=residual_data,
+            )
+            jft.logger.info(msg)
+
+            if not callback == None:
+                callback(opt_vi_state.minimization_state.x, samples, i, n_major)
+            if not out_dir == None:
+                if save_all:
+                    pickle.dump(
+                        opt_vi_state,
+                        open(f"{out_dir}/major_{n_major}/optVIstate_it_{i}.p", "wb"),
+                    )
+                    pickle.dump(
+                        samples, open(f"{out_dir}/major_{n_major}/samples_{i}.p", "wb")
+                    )
+                else:
+                    pickle.dump(
+                        opt_vi_state,
+                        open(f"{out_dir}/last_optVIstate.p", "wb"),
+                    )
+                    pickle.dump(samples, open(f"{out_dir}/last_samples.p", "wb"))
+                with open(
+                    f"{out_dir}/major_{n_major}/last_finished_iteration", "w"
+                ) as f:
+                    f.write(str(i))
+                with open(f"{out_dir}/last_started_major", "w") as f:
+                    f.write(str(n_major))
+
+        sub_val = jft.mean(tuple(sky(s) for s in samples))
+        post_mean = np.array(sub_val)
+        post_mean = ift.makeField(R.domain, post_mean)
+        residual_data = data - R(post_mean).val
+
+        update_kwargs["old_reconstruction"] = sub_val
+        update_kwargs["residual_data"] = residual_data
+
+    return samples.pos, samples, key
diff --git a/resolve/re/radio_response.py b/resolve/re/radio_response.py
new file mode 100644
index 0000000000000000000000000000000000000000..52e57056d26ce50e8bb147459df232de6388582b
--- /dev/null
+++ b/resolve/re/radio_response.py
@@ -0,0 +1,186 @@
+import resolve as rve
+import nifty8 as ift
+import numpy as np
+import pickle
+import jax.numpy as jnp
+
+from jax.lax import slice as jax_slice
+from functools import partial
+
+def get_jax_fft(domain, target, inverse):
+    if inverse:
+        if domain.harmonic:
+            func = jnp.fft.fftn
+            fct = 1.
+        else:
+            func = jnp.fft.ifftn
+            fct = domain.size
+        fct *= target.scalar_dvol
+    else:
+        if domain.harmonic:
+            func = jnp.fft.ifftn
+            fct = domain.size
+        else:
+            func = jnp.fft.fftn
+            fct = 1.
+        fct *= domain.scalar_dvol
+
+    def fft(x, fct, func):
+        return fct * func(x) if fct != 1 else func(x)
+
+    return partial(fft, fct=fct, func=func)
+
+def build_exact_r(obs, conf_sky, psf_pixels, calibration_field=None, do_wgridding=True, epsilon=1e-9, verbosity=1, nthreads=8):
+    sp_sky_dom =rve.sky_model._spatial_dom(conf_sky)
+    sky_dom = rve.default_sky_domain(sdom=sp_sky_dom)
+    R = rve.InterferometryResponse(obs, sky_dom, do_wgridding, epsilon, verbosity, nthreads)
+
+    full_psf0 = min(2*psf_pixels, sp_sky_dom.shape[0])
+    full_psf1 = min(2*psf_pixels, sp_sky_dom.shape[1])
+    sp_sky_dom_l = (sp_sky_dom.shape[0] + full_psf0, sp_sky_dom.shape[1] + full_psf1)
+    sp_sky_dom_l = ift.RGSpace(sp_sky_dom_l, distances=sp_sky_dom.distances)
+    sky_dom_l = rve.default_sky_domain(sdom=sp_sky_dom_l)
+    R_l = rve.InterferometryResponse(obs, sky_dom_l, do_wgridding, epsilon, verbosity, nthreads)
+
+    dch_l = ift.DomainChangerAndReshaper(R_l.domain[3], R_l.domain)
+    R_l = R_l @ dch_l
+    dch = ift.DomainChangerAndReshaper(R.domain[3], R.domain)
+    R = R @ dch
+
+    if calibration_field is not None:
+        C = ift.makeOp(calibration_field)
+
+        R = C @ R
+        R_l = C @ R_l
+
+    N_inv = ift.DiagonalOperator(obs.weight)
+    RNR = R.adjoint @ N_inv @ R
+    RNR_l = R_l.adjoint @ N_inv @ R_l
+
+    return R, R_l, RNR, RNR_l
+
+
+def compute_PSF(new_R, n_pix0, n_pix1):
+    dom = new_R.domain
+    shp = dom.shape
+    FFT = ift.FFTOperator(new_R.domain)
+
+    delta = np.zeros(shp)
+    delta[shp[0]//2, shp[1]//2] = 1 / dom.scalar_weight()
+    delta = ift.makeField(dom, delta)
+    kernel = new_R(delta)
+
+    # zero kernel
+    sh0 = shp[0]//2
+    sh1 = shp[1]//2
+    z_kern = np.zeros_like(kernel.val)
+    z_kern[sh0-n_pix0:sh0+n_pix0,sh1-n_pix1:sh1+n_pix1] = kernel.val[sh0 - n_pix0:sh0+n_pix0,sh1-n_pix1:sh1+n_pix1]
+
+    pr_kern = np.roll(z_kern, -shp[0]//2, axis=0)
+    pr_kern = np.roll(pr_kern, -shp[1]//2, axis=1)
+    pr_kern = ift.makeField(FFT.domain, pr_kern)
+    from matplotlib.colors import LogNorm
+    ift.single_plot(pr_kern, norm=LogNorm())
+    fourier_kern = FFT(pr_kern)
+    return fourier_kern.val
+
+def compute_approx_noise_kern(new_R, relativ_min_val=0.):
+    dom = new_R.domain
+    shp = dom.shape
+    FFT = ift.FFTOperator(new_R.domain)
+
+    delta = np.zeros(shp)
+    delta[shp[0]//2, shp[1]//2] = 1 / dom.scalar_weight()
+    delta = ift.makeField(dom, delta)
+    kernel = new_R(delta).val
+    kernel = np.roll(kernel, -shp[0]//2, axis=0)
+    kernel = np.roll(kernel, -shp[1]//2, axis=1)
+    kernel = ift.makeField(new_R.target, kernel)
+    FFT = ift.FFTOperator(new_R.domain)
+    max_val = np.max(FFT(kernel).abs().val)
+    min_val = relativ_min_val * max_val
+    min_val = ift.full(FFT.target, min_val)
+    min_val_adder = ift.Adder(min_val)
+
+    pos_eig_val = ift.Operator.identity_operator(FFT.target).exp()
+    pos_eig_val = min_val_adder @ pos_eig_val
+    rls1 = ift.Realizer(pos_eig_val.target)
+    rls2 = ift.Realizer(FFT.domain)
+
+    kernel_pos = rls2 @ FFT.inverse @ rls1.adjoint @ pos_eig_val
+
+    cov = ift.ScalingOperator(kernel_pos.target, 1e-2*max_val)
+    lh = ift.GaussianEnergy(data=kernel, inverse_covariance=cov.inverse) @ kernel_pos
+    init_pos = (FFT(kernel) - min_val).abs().log()
+    energy = ift.EnergyAdapter(position=init_pos, op=lh, want_metric=True)
+
+    ic_newton = ift.DeltaEnergyController(name='Newton', iteration_limit=80, tol_rel_deltaE=0)
+    #minimizer = ift.NewtonCG(ic_newton, max_cg_iterations=400, energy_reduction_factor=1e-3)
+    minimizer = ift.NewtonCG(ic_newton)
+    res = minimizer(energy)[0].position
+    return pos_eig_val(res).val
+
+def build_approximations(RNR, RNR_l, noise_scaling=None, varcov=False, cache_noise_kernel='None', cache_response_kernel='None'):
+    shp = RNR.domain.shape
+    shp_l = RNR_l.domain.shape
+    # assert(shp[0] == shp[1])
+    # assert(shp_l[0] == shp_l[1])
+    FFT = ift.FFTOperator(RNR_l.domain)
+    fft_jax = get_jax_fft(FFT.domain[0], FFT.target[0], False)
+    fft_jax_inv = get_jax_fft(FFT.domain[0], FFT.target[0], True)
+    FFT_s = ift.FFTOperator(RNR.domain)
+    fft_jax_s = get_jax_fft(FFT_s.domain[0], FFT_s.target[0], False)
+    fft_jax_inv_s = get_jax_fft(FFT_s.domain[0], FFT_s.target[0], True)
+
+
+    # build approximate response
+    n_psf_pix0 = (shp_l[0] - shp[0])
+    n_psf_pix1 = (shp_l[1] - shp[1])
+    n_padding0 = n_psf_pix0 // 2
+    n_padding1 = n_psf_pix1 // 2
+    if not cache_response_kernel == 'None':
+        try:
+            psf_kernel = pickle.load(open(f"{cache_response_kernel}.p", "rb"))
+        except:
+            psf_kernel = compute_PSF(RNR_l, n_psf_pix0, n_psf_pix1)
+            pickle.dump(psf_kernel, open(f"{cache_response_kernel}.p", "wb"))
+    else:
+            psf_kernel = compute_PSF(RNR_l, n_psf_pix0, n_psf_pix1)
+    psf_kernel = jnp.array(psf_kernel)
+    apply_psf_kern = lambda x: fft_jax_inv(psf_kernel * fft_jax(x)).real
+
+    slicer = lambda x: jax_slice(
+        x, (n_padding0, n_padding1), (n_padding0+ shp[0], n_padding1+ shp[1])
+    )
+    padder = lambda x: jnp.pad(x, ((n_padding0, n_padding0), (n_padding1, n_padding1)))
+
+    RNR_approx = lambda x: slicer(apply_psf_kern(padder(x)))
+
+    # build approximate N inverse
+    if not cache_noise_kernel == 'None':
+        try:
+            noise_kernel = pickle.load(open(f"{cache_noise_kernel}.p", "rb"))
+        except:
+            noise_kernel = compute_approx_noise_kern(RNR, 1e-3)
+            pickle.dump(noise_kernel, open(f"{cache_noise_kernel}.p", "wb"))
+    else:
+        noise_kernel = compute_approx_noise_kern(RNR, 1e-3)
+    noise_kernel_inv_sqrt = 1. / np.sqrt(noise_kernel)
+    noise_kernel_inv_sqrt = jnp.array(noise_kernel_inv_sqrt)
+    if varcov:
+        fl = ift.full(FFT_s.target, 1.)
+        vol = FFT_s(FFT_s.adjoint(fl)).real.mean().val
+        fac = np.sqrt(1/vol)
+        apply_n_sqinv_kern = lambda x: fac*noise_kernel_inv_sqrt * fft_jax_s(x['sky'])
+    elif not noise_scaling is None:
+        apply_n_sqinv_kern = lambda x: fft_jax_inv_s(
+            noise_scaling(x) * noise_kernel_inv_sqrt * fft_jax_s(x['sky'])
+        ).real
+    else:
+        apply_n_sqinv_kern = lambda x: fft_jax_inv_s(
+            noise_kernel_inv_sqrt * fft_jax_s(x['sky'])
+        ).real
+
+
+    return RNR_approx, apply_n_sqinv_kern
+
diff --git a/resolve/re/sky_model.py b/resolve/re/sky_model.py
index 790bb9589e9d41f4a8410630667263cbee997a20..f07b58b7e0ba0afcbb54ce327e41e87ddd1de8e8 100644
--- a/resolve/re/sky_model.py
+++ b/resolve/re/sky_model.py
@@ -1,14 +1,15 @@
 # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
 # Author: Jakob Roth
 
-from numpy import full
 import nifty8.re as jft
 import jax
 from jax import numpy as jnp
+import numpy as np
 
 
 from ..sky_model import _spatial_dom
-from ..sky_model import sky_model_points as rve_sky_pts
+from ..constants import str2rad
+
 
 def build_cf(prefix, conf, shape, dist):
     zmo = conf.getfloat(f"{prefix} zero mode offset")
@@ -34,7 +35,7 @@ def build_cf(prefix, conf, shape, dist):
     cfm = jft.CorrelatedFieldMaker(prefix)
     cfm.set_amplitude_total_offset(**cf_zm)
     cfm.add_fluctuations(
-        shape, distances=dist, **cf_fl, prefix="", non_parametric_kind='power'
+        shape, distances=dist, **cf_fl, prefix="", non_parametric_kind="power"
     )
     amps = cfm.get_normalized_amplitudes()
     cfm = cfm.finalize()
@@ -42,8 +43,6 @@ def build_cf(prefix, conf, shape, dist):
     return cfm, additional
 
 
-
-
 def sky_model_diffuse(cfg):
     if not cfg["freq mode"] == "single":
         raise NotImplementedError("FIXME: only implemented for single frequency")
@@ -52,63 +51,69 @@ def sky_model_diffuse(cfg):
     sky_dom = _spatial_dom(cfg)
     bg_shape = sky_dom.shape
     bg_distances = sky_dom.distances
-    bg_log_diff, additional = build_cf('stokesI diffuse space i0', cfg, bg_shape, bg_distances)
-    full_shape = (1,1,1)+bg_shape
-
+    bg_log_diff, additional = build_cf(
+        "stokesI diffuse space i0", cfg, bg_shape, bg_distances
+    )
+    full_shape = (1, 1, 1) + bg_shape
 
     def bg_diffuse(x):
         return jnp.broadcast_to(
-            jnp.exp(bg_log_diff(x['stokesI diffuse space i0'])), full_shape
+            jnp.exp(bg_log_diff(x["stokesI diffuse space i0"])), full_shape
         )
 
-
     bg_diffuse_model = jft.Model(
-        bg_diffuse,
-        domain={'stokesI diffuse space i0': bg_log_diff.domain}
-        )
+        bg_diffuse, domain={"stokesI diffuse space i0": bg_log_diff.domain}
+    )
 
     return bg_diffuse_model, additional
 
-def get_pts_postions(cfg):
-    rsp, _ = rve_sky_pts(cfg)
-    return rsp._ops[0]._inds[0], rsp._ops[0]._inds[1]
-
-def get_inv_gamma(cfg):
-    rsp, _ = rve_sky_pts(cfg)
-    return rsp._ops[1].jax_expr
 
 def jax_insert(x, ptsx, ptsy, bg):
     bg = bg.at[:, :, :, ptsx, ptsy].set(x)
     return bg
 
-def sky_model_points(cfg, bg=None):
-    if not cfg["freq mode"] == "single":
-        raise NotImplementedError("FIXME: only implemented for single frequency")
-    if not cfg["polarization"] == "I":
-        raise NotImplementedError("FIXME: only implemented for stokes I")
-
-    if not cfg["point sources mode"] == "single":
-        raise NotImplementedError(
-            "FIXME: point sources only implemented for mode single"
-            )
-    else:
-        sky_dom = _spatial_dom(cfg)
-        shp = sky_dom.shape
-        full_shp = (1,1,1)+shp
-        ptsx, ptsy = get_pts_postions(cfg)
-        inv_gamma = get_inv_gamma(cfg)
-        pts_shp = (len(ptsx),)
-        dom = {'points': jax.ShapeDtypeStruct(pts_shp, dtype=jnp.float64)}
-        if bg is None:
-            bg = jnp.zeros(full_shp)
-            pts_func = lambda x: jax_insert(inv_gamma(x['points']), ptsx=ptsx, ptsy=ptsy, bg=bg)
-        else:
-            pts_func = lambda x: jax_insert(inv_gamma(x['points']), ptsx=ptsx, ptsy=ptsy, bg=bg(x))
-            dom = {**dom, **bg.domain}
-    pts_model = jft.Model(pts_func, domain=dom)
-    additional = {}
-    return pts_model, additional
 
+def sky_model_points(cfg, bg=None):
+    if cfg["freq mode"] == "single":
+        if cfg["polarization"] == "I":
+            if cfg["point sources mode"] == "single":
+                ppos = []
+                sky_dom = _spatial_dom(cfg)
+                s = cfg["point sources locations"]
+                for xy in s.split(","):
+                    x, y = xy.split("$")
+                    ppos.append((str2rad(x), str2rad(y)))
+                ppos = np.array(ppos)
+                dx = np.array(sky_dom.distances)
+                center = np.array(sky_dom.shape) // 2
+                inds = np.unique(np.round(ppos / dx + center).astype(int).T, axis=1)
+                indsx, indsy = inds
+                alpha = cfg.getfloat("point sources alpha")
+                q = cfg.getfloat("point sources q")
+                sky_dom = _spatial_dom(cfg)
+                shp = sky_dom.shape
+                full_shp = (1, 1, 1) + shp
+                inv_gamma = jft.InvGammaPrior(
+                    a=alpha,
+                    scale=q,
+                    name="points",
+                    shape=jax.ShapeDtypeStruct((len(indsx),), float),
+                )
+                if bg is None:
+                    bg = jnp.zeros(full_shp)
+                    pts_func = lambda x: jax_insert(
+                        inv_gamma(x), ptsx=indsx, ptsy=indsy, bg=bg
+                    )
+                    dom = inv_gamma.domain
+                else:
+                    pts_func = lambda x: jax_insert(
+                        inv_gamma(x), ptsx=indsx, ptsy=indsy, bg=bg(x)
+                    )
+                    dom = {**inv_gamma.domain, **bg.domain}
+                pts_model = jft.Model(pts_func, domain=dom)
+                additional = {}
+                return pts_model, additional
+    raise NotImplementedError("FIXME: Not Implemented for selected mode.")
 
 
 def sky_model(cfg):
diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfb549524a5886de919ddf649d0ac74a40fd684a
--- /dev/null
+++ b/resolve/re/sugar.py
@@ -0,0 +1,230 @@
+import jax.numpy as jnp
+import nifty8.re as jft
+import matplotlib.pyplot as plt
+
+from .calibration import CalibrationDistribution
+from .likelihood_models import ModelCalibrationLikelihoodFixedCovariance
+
+from ..data.observation import Observation
+
+from typing import Union, Iterable, Dict, Tuple, Optional
+from dataclasses import dataclass, field
+
+class Bulk_CF_AntennaTimeDomain(jft.Model):
+    """
+    Creates multiple independant correlated fields from the same offset and powerspectra 
+    parameters. Number of correlated fields is the product of the number of polarization
+    directions, antennas and frequencies.
+
+    Parameters
+    ----------
+    dct_offset: dictionary
+        Dictionary containing information about the offset parameters.
+    dct_ps: dictionary
+        Dictionary containing information about the temporal powerspectrum parameters
+    prefix: string
+        Prefix to the names of the parameters of a correlated field.
+    polarizations: Iterable of int or str
+        Labels for polarization directions
+    frequencies: Iterable of int or str
+        Labels for frequencies:
+    antennas: Iterable of int or str
+        Labels for antennas
+    """
+    def __init__(
+            self,
+            dct_offset: dict,
+            dct_ps: dict,
+            prefix: str,
+            polarizations: Iterable[Union[int, str]],
+            frequencies: Iterable[Union[int, str]],
+            antennas: Iterable[Union[int, str]]
+            ):
+        self._pol = list(polarizations)
+        self._ant = list(antennas)
+        self._freq = list(frequencies)
+        self._prefix = str(prefix)
+
+        self._output_shape = (len(self._pol),len(self._ant),len(self._freq),dct_ps["shape"][0])
+        
+        cfm = jft.CorrelatedFieldMaker(self._prefix + "_")
+        cfm.set_amplitude_total_offset(**dct_offset)
+        cfm.add_fluctuations(**dct_ps)
+
+        n_total = len(self._pol)*len(self._ant)*len(self._freq)
+
+        self._fields = jft.VModel(cfm.finalize(), axis_size=n_total)
+        self._powerspectrum = cfm.power_spectrum
+
+        super().__init__(init=self._fields.init)
+
+    def __call__(self,x):
+        return jnp.swapaxes(jnp.reshape(self._fields(x),self._output_shape),2,3)
+    
+    @property
+    def name(self):
+        return self._prefix
+    
+    def get_powerspectrum(self):
+        return self._powerspectrum
+
+@dataclass    
+class CalibrationAssembler:
+    """
+    Collect and automatically creates quantities necessary for reconstruction of the calibration solution
+    of radio-interferometric data.
+
+    Parameters
+    ----------
+    obs: Observation
+        Observation object providing multiple quantities for construction of claibration
+        operator and likelihood.
+    phase_field: jft.Model or Bulk_CF_AntennaTimeDomain (preferred)
+        Correlated fields on antenna-time space for phases of calibration solutions.
+    logflux_field: jft.Model or Bulk_CF_AntennaTimeDomain (preferred)
+        Correlated fields on antenna-time space for log amplitude of calibration solutions.
+    dt: float
+        Distances between time points on time axis. Has to be the same distance of time points,
+        which is used for phase_fields and log_amplitude fields.
+    model_vis: jnp.ndarray
+        Optional. Assumed visibilities of the the source.
+    scaling_op: jft.Model
+        Optional. Model for scalar scaling of model_vis.
+    log_inv_cov: jft.Model
+        Optional. Model for log inverse covariance.
+    lh_label: str
+        Optional. Label for likelihood. If not set the default is the name of 
+        the observation.
+    """
+    obs: Observation
+    phase_field: jft.Model
+    logflux_field: jft.Model
+    dt: float
+    model_vis: Optional[jnp.ndarray] = None
+    scaling_op: Optional[jft.WrappedCall] = None
+    log_inv_cov: Optional[jft.Model] = None
+    lh_label: Optional[str] = None
+    cop: CalibrationDistribution = field(init=False)
+    scaled_cop: Optional[jft.Model] = field(init=False)
+
+    def __post_init__(self):
+        self.lh_label = self.obs.source_name if self.lh_label is None else self.lh_label
+        self.cop = CalibrationDistribution(self.obs,self.phase_field,self.logflux_field,self.dt)
+        self.scaled_cop = self.cop if self.scaling_op is None else jft.Model(
+            call = lambda x: self.scaling_op(x)*self.cop(x),
+            domain = {**self.cop.domain,**self.scaling_op.domain},
+            init =  self.cop.init | self.scaling_op.init
+        )
+
+    @classmethod
+    def make(
+        cls,
+        observation: Observation, 
+        phase_field: jft.Model, 
+        logflux_field: jft.Model, 
+        dt: float, 
+        model_vis_flux_amplitude: Optional[complex] = None, 
+        scaling_parameters: Optional[Tuple[str,Dict]] = None, 
+        log_inv_cov: Optional[jft.Model] = None, 
+        lh_label: Optional[str] = None
+        ):
+            """
+            Constructs CalibrationAssembler with constant field for model_vis and an optional scaling_op.
+
+            Parameters
+            ----------
+            obs: Observation
+                Observation object providing multiple quantities for construction of claibration
+                operator and likelihood.
+            phase_field: jft.Model or Bulk_CF_AntennaTimeDomain (preferred)
+                Correlated fields on antenna-time space for phases of calibration solutions.
+            logflux_field: jft.Model or Bulk_CF_AntennaTimeDomain (preferred)
+                Correlated fields on antenna-time space for log amplitude of calibration solutions.
+            dt: float
+                Distances between time points on time axis. Has to be the same distance of time points,
+                which is used for phase_fields and log_amplitude fields.
+            model_vis_flux_amplitude: complex
+                Optional. Assumed magnitude of field for model_vis.
+            scaling_parameters: Tuple[str,Dict]
+                Optional. Selects prior model for scaling operator through string and dictionary contains
+                parameters of corresponding model. Currently only available for the following prior models 
+                (see NIFTy.re package for arguments):
+                    - "inverse_gamma"
+                    - "log_normal"
+            log_inv_cov: jft.Model
+                Optional. Model for log inverse covariance.
+            lh_label: str
+                Optional. Label for likelihood. If not set the default is the name of 
+                the observation.
+            """
+            init_flux_field = None if model_vis_flux_amplitude is None else model_vis_flux_amplitude*jnp.ones(observation.vis.shape)
+            if scaling_parameters is not None:
+                match scaling_parameters[0]:
+                    case "inv_gamma":
+                        scaling_op = jft.InvGammaPrior(**scaling_parameters[1])
+                    case "log_normal":
+                        scaling_op = jft.LogNormalPrior(**scaling_parameters[1])
+                    case _:
+                        raise NotImplementedError("Constructor currently only supports 'log_normal' and 'inv_gamma' for scaling operator construction")
+            else:
+                scaling_op = None
+
+            return cls(observation,phase_field,logflux_field,dt,init_flux_field,scaling_op,log_inv_cov,lh_label)
+
+    def prior_realization(self,rng_key):
+        data_model = ModelCalibrationLikelihoodFixedCovariance(self.scaled_cop,self.model_vis,jnp.asarray(self.obs.mask.val))
+        data_model_realization = data_model(data_model.init(rng_key))
+
+        plt.scatter(self.obs.vis.val.imag,self.obs.vis.val.real,label="Data")
+        plt.scatter(data_model_realization.imag,data_model_realization.real, alpha=0.01, label="Prior sample")
+        plt.legend()
+        plt.show()
+
+    def get_lh_domain(self):
+        dom = self.scaled_cop.domain if self.log_inv_cov is None else self.scaled_cop.domain | self.log_inv_cov.domain
+        return dom
+
+    def __repr__(self):
+        return(
+            f"Obervation: {self.obs.source_name}\n"
+            f"Phase field: {self.phase_field.name}\n"
+            f"Logflux field: {self.logflux_field.name}\n"
+            f"dt: {self.dt}\n"
+            f"Model visibilities field set: {False if self.model_vis is None else True}\n"
+            f"Scaling operator set: {False if self.scaling_op is None else True}\n"
+            f"Log inverse covariance set: {False if self.log_inv_cov is None else True}\n"
+            f"Likelihood label: {self.lh_label}\n"
+        )
+
+@dataclass
+class SkyAssembler:
+    """
+    Data class for storage of quantities for the sky reconstruction with fast resolve
+
+    Parameters
+    ----------
+    obs: Observation
+        Observation object providing multiple quantities for construction of claibration
+        operator and likelihood.
+    sky: jft.Model
+        Sky model for reconstruction
+    noise_scaling: jft.Model
+        Optional. Model for noise scaling.
+    varcov: jft.Model
+        Optional. Model for variable covariance
+    """
+    obs: Observation
+    sky: jft.Model
+    noise_scaling: Optional[jft.Model] = None
+    varcov: Optional[jft.Model] = None
+
+    def __post_init__(self):
+        if (self.noise_scaling is not None) and (self.varcov is not None):
+            raise ValueError("Either noise_scaling or varcov can be set.")
+        
+    def __repr__(self):
+        return(
+            f"Obervation: {self.obs.source_name}\n"
+            f"Noise scaling model set: {False if self.noise_scaling is None else True}\n"
+            f"Variable covariance model set: {False if self.varcov is None else True}\n"
+        )
\ No newline at end of file
diff --git a/test/test_jax_ports/observation_generator.py b/test/test_jax_ports/observation_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e7cd25ad89300359fe38545a2659cc443d758e9
--- /dev/null
+++ b/test/test_jax_ports/observation_generator.py
@@ -0,0 +1,41 @@
+import numpy as np
+
+import resolve as rve
+
+def RandomAntennaPositions(n_row,uvw_max,n_antenna_max,time_max,rng_generator=np.random.default_rng(42)):
+    ant1,ant2 = [],[]
+
+    for k in range(n_row):
+        x = rng_generator.integers(0,n_antenna_max-1)
+        y = rng_generator.integers(1,n_antenna_max)
+
+        if(x==y):
+            while(x==y):
+                y = np.random.randint(1,n_antenna_max)
+
+        ant1.append(x)
+        ant2.append(y)
+
+    ant1 = np.array(ant1).astype(np.int32)
+    ant2 = np.array(ant2).astype(np.int32)
+    time = rng_generator.uniform(0,time_max,n_row)
+    uvw = rng_generator.uniform(-uvw_max,uvw_max,(n_row,3))
+
+    return rve.data.antenna_positions.AntennaPositions(uvw,ant1,ant2,time)
+
+def RandomObservation(n_baselines,pol_indices,freq_channels,uvw_max,n_antenna_max,time_max,abs_vis_max,weight_min,weight_max,rng_generator=np.random.default_rng(42)):
+    antenna_pos = RandomAntennaPositions(n_baselines,uvw_max,n_antenna_max,time_max,rng_generator=rng_generator)
+
+    n_pol = len(pol_indices)
+    n_freq = freq_channels.size
+    vis_shape = (n_pol,n_baselines,n_freq)
+
+    pol = rve.data.polarization.Polarization(pol_indices)
+
+    vis_magnitude = rng_generator.uniform(0,abs_vis_max,vis_shape)
+    vis_phase = rng_generator.uniform(0,2*np.pi,vis_shape)
+    vis = vis_magnitude*np.exp(1.0j*vis_phase)
+
+    weights = rng_generator.uniform(weight_min,weight_max,vis_shape)
+    
+    return rve.data.observation.Observation(antenna_pos,vis,weights,pol,freq_channels)
\ No newline at end of file
diff --git a/test/test_jax_ports/test_01_BulkCF.py b/test/test_jax_ports/test_01_BulkCF.py
new file mode 100644
index 0000000000000000000000000000000000000000..aac1716f152bad99685134611ee797e9e5424f52
--- /dev/null
+++ b/test/test_jax_ports/test_01_BulkCF.py
@@ -0,0 +1,94 @@
+import os
+os.environ["JAX_PLATFORM_NAME"] = "cpu"
+
+import numpy as np
+import nifty8 as ift
+import ducc0
+import resolve as rve
+import resolve.re as jrve
+
+from numpy.testing import assert_allclose
+
+import jax
+jax.config.update("jax_enable_x64", True)
+
+np.random.seed(42)
+
+from observation_generator import RandomObservation
+from to_re_helpers import rve_cf_dict_to_jrve_cf_dict
+
+def prepare_operators():
+    param = {
+    "n_baselines": 10000,
+    "pol_indices": (8,5),
+    "freq_channels": np.array([1.5e9]),
+    "uvw_max": 200,
+    "n_antenna_max": 10,
+    "time_max": 5000,
+    "abs_vis_max": 0.5,
+    "weight_min":0,
+    "weight_max": 10000,
+    }
+
+    obs = RandomObservation(**param)
+
+    uantennas = rve.unique_antennas(obs)
+    antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)}
+    total_N = len(param["pol_indices"]) * len(uantennas)
+
+    dt = 20  # s
+    zero_padding_factor = 2
+    tmin, tmax = rve.tmin_tmax(obs)
+    time_domain = ift.RGSpace(
+        ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt
+    )
+
+    polarizations = obs.polarization.to_str_list()
+    frequencies = [1]
+    antennas = uantennas
+
+    dofdex = np.arange(total_N)
+    rve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5), "dofdex": dofdex}
+    rve_dct_ps = {
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "dofdex": dofdex
+    }
+
+    jrve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5),}
+    jrve_dct_ps = {
+        "shape": time_domain.shape,
+        "distances": time_domain.distances[0],
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "non_parametric_kind": "power"
+    }
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='bulk_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_bulkCF = cfmph.finalize(0)
+
+    pdom, _, fdom = obs.vis.domain
+    reshape = ift.DomainChangerAndReshaper(rve_bulkCF.target, [pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom])
+
+    rve_op = reshape @ rve_bulkCF
+
+    jrve_op = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"bulk",polarizations,frequencies,antennas)
+
+    return rve_op, jrve_op
+
+def test_consistency_BulkCF():
+    rve_op, jrve_op = prepare_operators()
+
+    rve_inp = ift.from_random(rve_op.domain)
+    jrve_inp = rve_cf_dict_to_jrve_cf_dict(rve_inp.val)
+
+    rve_field = rve_op(rve_inp).val
+    jrve_field = jrve_op(jrve_inp)
+
+    assert_allclose(jrve_field,rve_field)
diff --git a/test/test_jax_ports/test_02_CalibrationInterpolator.py b/test/test_jax_ports/test_02_CalibrationInterpolator.py
new file mode 100644
index 0000000000000000000000000000000000000000..63fc7bee37ba2466dfc77603f6ed1eaf31f192e8
--- /dev/null
+++ b/test/test_jax_ports/test_02_CalibrationInterpolator.py
@@ -0,0 +1,64 @@
+import os
+os.environ["JAX_PLATFORM_NAME"] = "cpu"
+
+import numpy as np
+import jax.numpy as jnp
+import nifty8 as ift
+import ducc0
+import resolve as rve
+import resolve.re as jrve
+
+from numpy.testing import assert_allclose
+
+import jax
+jax.config.update("jax_enable_x64", True)
+
+np.random.seed(42)
+
+from observation_generator import RandomObservation
+from to_re_helpers import rve_cf_dict_to_jrve_cf_dict
+
+def prepare_operators():
+    param = {
+    "n_baselines": 10000,
+    "pol_indices": (8,5),
+    "freq_channels": np.array([1.5e9]),
+    "uvw_max": 200,
+    "n_antenna_max": 10,
+    "time_max": 5000,
+    "abs_vis_max": 0.5,
+    "weight_min":0,
+    "weight_max": 10000,
+    }
+
+    obs = RandomObservation(**param)
+
+    uantennas = rve.unique_antennas(obs)
+    antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)}
+    total_N = len(param["pol_indices"]) * len(uantennas)
+
+    dt = 20  # s
+    zero_padding_factor = 2
+    tmin, tmax = rve.tmin_tmax(obs)
+    time_domain = ift.RGSpace(
+        ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt
+    )
+
+    pdom, _, fdom = obs.vis.domain
+    dom = ift.DomainTuple.make([pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom])
+    
+    jrve_op = jrve.CalibrationInterpolator(obs.ant1,obs.time,time_domain.distances[0],obs.vis.val.shape)
+
+    rve_op = rve.CalibrationDistributor(dom,obs.vis.domain,obs.ant1,obs.time,antenna_dct,None)
+
+    return rve_op, jrve_op
+
+def test_consistency_CalibrationInterpolator():
+    rve_op, jrve_op = prepare_operators()
+
+    inp = ift.from_random(rve_op.domain)
+
+    rve_field = rve_op(inp).val
+    jrve_field = jrve_op(jnp.asarray(inp.val))
+
+    assert_allclose(jrve_field,rve_field)
\ No newline at end of file
diff --git a/test/test_jax_ports/test_03_CalibrationDistributor.py b/test/test_jax_ports/test_03_CalibrationDistributor.py
new file mode 100644
index 0000000000000000000000000000000000000000..b755e699f0d0bd3638ac49a5bd99a54ee208e7dd
--- /dev/null
+++ b/test/test_jax_ports/test_03_CalibrationDistributor.py
@@ -0,0 +1,104 @@
+import os
+os.environ["JAX_PLATFORM_NAME"] = "cpu"
+
+import numpy as np
+import nifty8 as ift
+import ducc0
+import resolve as rve
+import resolve.re as jrve
+
+from numpy.testing import assert_allclose
+
+import jax
+jax.config.update("jax_enable_x64", True)
+
+np.random.seed(42)
+
+from observation_generator import RandomObservation
+from to_re_helpers import rve_cf_dict_to_jrve_cf_dict
+
+def prepare_operators():
+    param = {
+    "n_baselines": 10000,
+    "pol_indices": (8,5),
+    "freq_channels": np.array([1.5e9]),
+    "uvw_max": 200,
+    "n_antenna_max": 10,
+    "time_max": 5000,
+    "abs_vis_max": 0.5,
+    "weight_min":0,
+    "weight_max": 10000,
+    }
+
+    obs = RandomObservation(**param)
+
+    uantennas = rve.unique_antennas(obs)
+    antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)}
+    total_N = len(param["pol_indices"]) * len(uantennas)
+
+    dt = 20  # s
+    zero_padding_factor = 2
+    tmin, tmax = rve.tmin_tmax(obs)
+    time_domain = ift.RGSpace(
+        ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt
+    )
+
+    polarizations = obs.polarization.to_str_list()
+    frequencies = [1]
+    antennas = uantennas
+
+    dofdex = np.arange(total_N)
+    rve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5), "dofdex": dofdex}
+    rve_dct_ps = {
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "dofdex": dofdex
+    }
+
+    jrve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5),}
+    jrve_dct_ps = {
+        "shape": time_domain.shape,
+        "distances": time_domain.distances[0],
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "non_parametric_kind": "power"
+    }
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='phase_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_phase = cfmph.finalize(0)
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='logamp_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_logamp = cfmph.finalize(0)
+
+    pdom, _, fdom = obs.vis.domain
+    reshape = ift.DomainChangerAndReshaper(rve_phase.target, [pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom])
+
+    rve_phase = reshape @ rve_phase
+    rve_logamp = reshape @ rve_logamp
+
+    jrve_phase = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"phase",polarizations,frequencies,antennas)
+    jrve_logamp = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"logamp",polarizations,frequencies,antennas)
+
+    rve_op = rve.calibration_distribution(obs,rve_phase,rve_logamp,antenna_dct,None)
+    jrve_op = jrve.CalibrationDistribution(obs,jrve_phase,jrve_logamp,time_domain.distances[0])
+
+    return rve_op, jrve_op
+
+def test_consistency_CalibrationDistributor():
+    rve_op, jrve_op = prepare_operators()
+
+    rve_inp = ift.from_random(rve_op.domain)
+    jrve_inp = rve_cf_dict_to_jrve_cf_dict(rve_inp.val)
+
+    rve_field = rve_op(rve_inp).val
+    jrve_field = jrve_op(jrve_inp)
+
+    assert_allclose(jrve_field,rve_field)
\ No newline at end of file
diff --git a/test/test_jax_ports/test_04_CalibrationLikelihood.py b/test/test_jax_ports/test_04_CalibrationLikelihood.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a469d8217b173ec34063c8c3813a22be4ed9e3a
--- /dev/null
+++ b/test/test_jax_ports/test_04_CalibrationLikelihood.py
@@ -0,0 +1,151 @@
+import os
+os.environ["JAX_PLATFORM_NAME"] = "cpu"
+
+import numpy as np
+import jax.numpy as jnp
+import nifty8 as ift
+import nifty8.re as jft
+import ducc0
+import resolve as rve
+import resolve.re as jrve
+from functools import partial
+
+import pytest
+from numpy.testing import assert_allclose
+
+import jax
+jax.config.update("jax_enable_x64", True)
+
+np.random.seed(42)
+
+from observation_generator import RandomObservation
+from to_re_helpers import rve_cf_dict_to_jrve_cf_dict, mixed_rve_dct_to_mixed_jrve
+
+pmp = pytest.mark.parametrize
+
+def prepare_partial_operators():
+    param = {
+    "n_baselines": 10000,
+    "pol_indices": (8,5),
+    "freq_channels": np.array([1.5e9]),
+    "uvw_max": 200,
+    "n_antenna_max": 10,
+    "time_max": 5000,
+    "abs_vis_max": 0.5,
+    "weight_min":0,
+    "weight_max": 10000,
+    }
+
+    obs = RandomObservation(**param)
+
+    uantennas = rve.unique_antennas(obs)
+    antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)}
+    total_N = len(param["pol_indices"]) * len(uantennas)
+
+    dt = 20  # s
+    zero_padding_factor = 2
+    tmin, tmax = rve.tmin_tmax(obs)
+    time_domain = ift.RGSpace(
+        ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt
+    )
+
+    polarizations = obs.polarization.to_str_list()
+    frequencies = [1]
+    antennas = uantennas
+
+    dofdex = np.arange(total_N)
+    rve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5), "dofdex": dofdex}
+    rve_dct_ps = {
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "dofdex": dofdex
+    }
+
+    jrve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5),}
+    jrve_dct_ps = {
+        "shape": time_domain.shape,
+        "distances": time_domain.distances[0],
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "non_parametric_kind": "power"
+    }
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='phase_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_phase = cfmph.finalize(0)
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='logamp_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_logamp = cfmph.finalize(0)
+
+    pdom, _, fdom = obs.vis.domain
+    reshape = ift.DomainChangerAndReshaper(rve_phase.target, [pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom])
+
+    rve_phase = reshape @ rve_phase
+    rve_logamp = reshape @ rve_logamp
+
+    jrve_phase = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"phase",polarizations,frequencies,antennas)
+    jrve_logamp = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"logamp",polarizations,frequencies,antennas)
+
+    rve_cop = rve.calibration_distribution(obs,rve_phase,rve_logamp,antenna_dct,None)
+    jrve_cop = jrve.CalibrationDistribution(obs,jrve_phase,jrve_logamp,time_domain.distances[0])
+
+    rve_model_vis = ift.full(obs.vis.domain,1 + 0.0j)
+    jrve_model_vis = jnp.array(rve_model_vis.val)
+
+    rve_op = partial(
+        rve.CalibrationLikelihood, 
+        observation = obs, 
+        calibration_operator = rve_cop,
+        model_visibilities = rve_model_vis
+        )
+    
+    jrve_op = partial(
+        jrve.CalibrationLikelihood,
+        observation = obs, 
+        calibration_operator = jrve_cop,
+        model_visibilities = jrve_model_vis
+    )
+
+    return rve_op, jrve_op, obs
+
+def prepare_operators(mode):
+    rve_op_partial, jrve_op_partial, obs = prepare_partial_operators()
+
+    if mode == "model":
+        rve_log_inv_cov = ift.NormalTransform(0,1,"log_inv_cov",obs.weight.val.size)
+        reshape = ift.DomainChangerAndReshaper(rve_log_inv_cov.target,obs.weight.domain)
+        rve_log_inv_cov = reshape @ rve_log_inv_cov
+
+        jrve_log_inv_cov = jft.NormalPrior(0,1,name="log_inv_cov",shape=obs.weight.val.shape)
+    else:
+        rve_log_inv_cov = None
+        jrve_log_inv_cov = None
+
+    rve_op = rve_op_partial(log_inverse_covariance_operator=rve_log_inv_cov)
+    jrve_op = jrve_op_partial(log_inverse_covariance_operator=jrve_log_inv_cov)
+
+    return rve_op, jrve_op, obs
+
+
+@pmp("log_inv_cov_mode",[None, "model"])
+def test_consistency_CalibrationLikelihood(log_inv_cov_mode):
+    rve_op, jrve_op, obs = prepare_operators(log_inv_cov_mode)
+
+    rve_inp = ift.from_random(rve_op.domain)
+
+    if log_inv_cov_mode == "model":
+        jrve_inp = mixed_rve_dct_to_mixed_jrve(rve_inp.val,("log_inv_cov",),(obs.weight.val.shape,))
+    else:
+        jrve_inp = rve_cf_dict_to_jrve_cf_dict(rve_inp.val)
+
+    rve_field = rve_op(rve_inp).val
+    jrve_field = jrve_op(jrve_inp)
+
+    assert_allclose(jrve_field,rve_field)
\ No newline at end of file
diff --git a/test/test_jax_ports/test_05_ImagingLikelihood.py b/test/test_jax_ports/test_05_ImagingLikelihood.py
new file mode 100644
index 0000000000000000000000000000000000000000..57169fdaa43ceced01cd18ea671a31938c0f1398
--- /dev/null
+++ b/test/test_jax_ports/test_05_ImagingLikelihood.py
@@ -0,0 +1,240 @@
+import os
+os.environ["JAX_PLATFORM_NAME"] = "cpu"
+
+import numpy as np
+import jax.numpy as jnp
+import nifty8 as ift
+import nifty8.re as jft
+import ducc0
+import resolve as rve
+import resolve.re as jrve
+from functools import partial
+
+import pytest
+from numpy.testing import assert_allclose
+
+import jax
+jax.config.update("jax_enable_x64", True)
+
+np.random.seed(42)
+
+from observation_generator import RandomObservation
+from to_re_helpers import rve_cf_dict_to_jrve_cf_dict_simple, mixed_rve_dct_to_mixed_jrve
+
+pmp = pytest.mark.parametrize
+
+class JRVE_Sky(jft.Model):
+    def __init__(self,field):
+        self._field = field
+
+        super().__init__(init=self._field.init)
+
+    def __call__(self,x):
+        return jnp.array([[[jnp.exp(self._field(x))]]])
+
+def prepare_partial_operators():
+    param = {
+    "n_baselines": 10000,
+    "pol_indices": (8,5),
+    "freq_channels": np.array([1.5e9]),
+    "uvw_max": 200,
+    "n_antenna_max": 10,
+    "time_max": 5000,
+    "abs_vis_max": 0.5,
+    "weight_min":0,
+    "weight_max": 10000,
+    }
+
+    obs = RandomObservation(**param)
+
+    fov = np.array([1, 1]) * rve.DEG2RAD
+    npix = np.array([64, 64])
+    skydom = ift.RGSpace(npix, fov / npix)
+
+    sky_domain_dict = dict(npix_x=skydom.shape[0],
+                        npix_y=skydom.shape[1],
+                        pixsize_x=skydom.distances[0],
+                        pixsize_y=skydom.distances[1],
+                        pol_labels=['I'],
+                        times=[0.],
+                        freqs=[0.])
+
+    rve_sky_dct = {
+        "target": skydom,
+        "offset_mean": 0,
+        "offset_std": (1, 0.5),
+        "prefix": "logdiffuse_",
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+    }
+    rve_logsky = ift.SimpleCorrelatedField(**rve_sky_dct)
+    rve_sky = rve_logsky.exp()
+
+    rve_sky_dom = rve.default_sky_domain(sdom = skydom)
+    reshape_sky = ift.DomainChangerAndReshaper(rve_sky.target, rve_sky_dom)
+    rve_sky = reshape_sky @ rve_sky
+
+    jrve_dct_sky_offset = {"offset_mean": 0, "offset_std": (1, 0.5),}
+    jrve_dct_sky_ps = {
+        "shape": skydom.shape,
+        "distances": skydom.distances[0],
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "non_parametric_kind": "power"
+    }
+    cfm = jft.CorrelatedFieldMaker("logdiffuse_")
+    cfm.set_amplitude_total_offset(**jrve_dct_sky_offset)
+    cfm.add_fluctuations(**jrve_dct_sky_ps)
+    jrve_logsky = cfm.finalize()
+    jrve_sky = JRVE_Sky(jrve_logsky)
+
+    rve_op = partial(
+        rve.ImagingLikelihood,
+        observation = obs,
+        sky_operator = rve_sky,
+        epsilon = 1e-5,
+        do_wgridding = True
+        )
+    
+    jrve_op = partial(
+        jrve.ImagingLikelihood,
+        observation = obs,
+        sky_operator = jrve_sky,
+        sky_domain_dict = sky_domain_dict,
+        epsilon = 1e-5,
+        do_wgridding = True
+    )
+
+    return rve_op, jrve_op, obs, rve_sky.domain
+
+def prepare_cop(obs):
+    uantennas = rve.unique_antennas(obs)
+    antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)}
+    total_N = len(obs.polarization) * len(uantennas)
+
+    dt = 20  # s
+    zero_padding_factor = 2
+    tmin, tmax = rve.tmin_tmax(obs)
+    time_domain = ift.RGSpace(
+        ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt
+    )
+
+    polarizations = obs.polarization.to_str_list()
+    frequencies = [1]
+    antennas = uantennas
+
+    dofdex = np.arange(total_N)
+    rve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5), "dofdex": dofdex}
+    rve_dct_ps = {
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "dofdex": dofdex
+    }
+
+    jrve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5),}
+    jrve_dct_ps = {
+        "shape": time_domain.shape,
+        "distances": time_domain.distances[0],
+        "fluctuations": (.2, .1),
+        "loglogavgslope": (-4.0, 1),
+        "flexibility": (1, .3),
+        "asperity": None,
+        "non_parametric_kind": "power"
+    }
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='phase_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_phase = cfmph.finalize(0)
+
+    cfmph = ift.CorrelatedFieldMaker(prefix='logamp_', total_N=total_N)
+    cfmph.add_fluctuations(time_domain, **rve_dct_ps)
+    cfmph.set_amplitude_total_offset(**rve_dct_offset)
+    rve_logamp = cfmph.finalize(0)
+
+    pdom, _, fdom = obs.vis.domain
+    reshape = ift.DomainChangerAndReshaper(rve_phase.target, [pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom])
+
+    rve_phase = reshape @ rve_phase
+    rve_logamp = reshape @ rve_logamp
+
+    jrve_phase = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"phase",polarizations,frequencies,antennas)
+    jrve_logamp = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"logamp",polarizations,frequencies,antennas)
+
+    rve_cop = rve.calibration_distribution(obs,rve_phase,rve_logamp,antenna_dct,None)
+    jrve_cop = jrve.CalibrationDistribution(obs,jrve_phase,jrve_logamp,time_domain.distances[0])
+
+    return rve_cop, jrve_cop
+
+
+def prepare_operators(cov_mode, cal_mode):
+    rve_op_partial, jrve_op_partial, obs, rve_sky_dom = prepare_partial_operators()
+
+    if cov_mode == "model":
+        rve_log_inv_cov = ift.NormalTransform(0,1,"log_inv_cov",obs.weight.val.size)
+        reshape = ift.DomainChangerAndReshaper(rve_log_inv_cov.target,obs.weight.domain)
+        rve_log_inv_cov = reshape @ rve_log_inv_cov
+
+        jrve_log_inv_cov = jft.NormalPrior(0,1,name="log_inv_cov",shape=obs.weight.val.shape)
+    else:
+        rve_log_inv_cov = None
+        jrve_log_inv_cov = None
+
+    rve_op_partial = partial(rve_op_partial,log_inverse_covariance_operator=rve_log_inv_cov)
+    jrve_op_partial = partial(jrve_op_partial,log_inverse_covariance_operator=jrve_log_inv_cov)
+
+    if cal_mode is not None:
+        rve_cop, jrve_cop = prepare_cop(obs)
+
+    if cal_mode == "cop":
+        rve_op = rve_op_partial(calibration_operator=rve_cop,calibration_field=None)
+        jrve_op = jrve_op_partial(calibration_operator=jrve_cop,calibration_field=None)
+    elif cal_mode == "cfld":
+        rve_cfld = ift.from_random(rve_cop.target,dtype="complex128")
+        jrve_cfld = jnp.array(rve_cfld.val)
+
+        rve_op = rve_op_partial(calibration_operator=None,calibration_field=rve_cfld)
+        jrve_op = jrve_op_partial(calibration_operator=None,calibration_field=jrve_cfld)
+    else:
+        rve_op = rve_op_partial(calibration_operator=None,calibration_field=None)
+        jrve_op = jrve_op_partial(calibration_operator=None,calibration_field=None)
+
+    return rve_op, jrve_op, obs, rve_sky_dom
+
+
+@pmp("log_inv_cov_mode",[None, "model"])
+@pmp("cal_op_mode", [None, "cop", "cfld"])
+def test_consistency_CalibrationLikelihood(log_inv_cov_mode,cal_op_mode):
+    rve_op, jrve_op, obs, rve_sky_dom = prepare_operators(log_inv_cov_mode, cal_op_mode)
+
+    rve_inp = ift.from_random(rve_op.domain)
+
+    if log_inv_cov_mode == "model":
+        jrve_inp = mixed_rve_dct_to_mixed_jrve(rve_inp.val,("log_inv_cov",),(obs.weight.val.shape,),rve_sky_dom.keys())
+    else:
+        if (cal_op_mode == "cop"):
+            jrve_inp = mixed_rve_dct_to_mixed_jrve(rve_inp.val,simple_cf_keys=rve_sky_dom.keys())
+        else:
+            jrve_inp = rve_cf_dict_to_jrve_cf_dict_simple(rve_inp.val)
+
+    rve_field = rve_op(rve_inp).val
+    jrve_field = jrve_op(jrve_inp)
+
+    assert_allclose(jrve_field,rve_field)
+
+
+
+
+    #s_jrve_fc_img_lh_dct_translated = rve_cf_dict_to_jrve_cf_dict_simple(s_rve_fc_img_lh_dct.val)
+    #s_jrve_fc_img_cfld_lh_dct_translated = rve_cf_dict_to_jrve_cf_dict_simple(s_rve_fc_img_cfld_lh_dct.val)
+
+    #s_jrve_fc_img_cop_lh_dct_translated = mixed_rve_dct_to_mixed_jrve(s_rve_fc_img_cop_lh_dct.val,simple_cf_keys=ift.from_random(rve_sky.domain).val.keys())
+    #s_jrve_vc_img_lh_dct_translated = mixed_rve_dct_to_mixed_jrve(s_rve_vc_img_lh_dct.val,("log_inv_cov",),(obs.weight.val.shape,),simple_cf_keys=ift.from_random(rve_sky.domain).val.keys())
+    #s_jrve_vc_img_cfld_lh_dct_translated = mixed_rve_dct_to_mixed_jrve(s_rve_vc_img_cfld_lh_dct.val,("log_inv_cov",),(obs.weight.val.shape,),simple_cf_keys=ift.from_random(rve_sky.domain).val.keys())
+    #s_jrve_vc_img_cop_lh_dct_translated = mixed_rve_dct_to_mixed_jrve(s_rve_vc_img_cop_lh_dct.val,("log_inv_cov",),(obs.weight.val.shape,),simple_cf_keys=ift.from_random(rve_sky.domain).val.keys())
\ No newline at end of file
diff --git a/test/test_jax_ports/to_re_helpers.py b/test/test_jax_ports/to_re_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fcae11587f0c36932779f87fd60f890d2250eec
--- /dev/null
+++ b/test/test_jax_ports/to_re_helpers.py
@@ -0,0 +1,63 @@
+import numpy as np
+import jax.numpy as jnp
+
+def rve_cf_dict_to_jrve_cf_dict(rve_cf_dct):
+    res = rve_cf_dct
+    for h in res.keys():
+        h_split = h.split("_")
+
+        if(h_split[1] == "spectrum"):
+            res[h] = np.swapaxes(res[h],1,2)
+    return res
+
+def rve_cf_dict_to_jrve_cf_dict_simple(rve_cf_dct):
+    res = rve_cf_dct
+
+    for h in rve_cf_dct.keys():
+        h_split = h.split("_")
+        
+        if (h_split[1] == "spectrum"):
+            res[h] = res[h].T
+    return res
+
+def single_rve_dct_to_single_jrve_dct(rve_dct,shape):
+    if len(rve_dct) != 1:
+        raise ValueError("Can only convert dict with single entry/element")
+    
+    key, rve_value = rve_dct.popitem()
+
+    jrve_value = np.reshape(rve_value,shape)
+
+    return {key: jnp.array(jrve_value)}
+
+def mixed_rve_dct_to_mixed_jrve(rve_dct,non_cf_keys=None,non_cf_shapes=None,simple_cf_keys=None):
+    if (non_cf_keys is not None) and (non_cf_shapes is not None):
+        if (len(non_cf_keys) != len(non_cf_shapes)) and (non_cf_keys is not None):
+            raise ValueError("Each non cf dict key should have an shape for the associate dict value")  
+    
+    if ((non_cf_keys is not None) and (simple_cf_keys is None)):
+        non_cf_rve_dcts = [({key: rve_dct[key]},non_cf_shapes[k]) for k,key in enumerate(non_cf_keys)]
+        cf_rve_dct = {key: rve_dct[key] for key in rve_dct.keys() if (key not in non_cf_keys)}
+        simple_cf_rve_dcts = {}
+    elif ((non_cf_keys is None) and (simple_cf_keys is not None)):
+        simple_cf_rve_dcts = {key: rve_dct[key] for key in simple_cf_keys}
+        cf_rve_dct = {key: rve_dct[key] for key in rve_dct.keys() if (key not in simple_cf_keys)}
+        non_cf_rve_dcts = []
+    else:
+        non_cf_rve_dcts = [({key: rve_dct[key]},non_cf_shapes[k]) for k,key in enumerate(non_cf_keys)]
+        simple_cf_rve_dcts = {key: rve_dct[key] for key in simple_cf_keys}
+        cf_rve_dct = {key: rve_dct[key] for key in rve_dct.keys() if ((key not in non_cf_keys) and (key not in simple_cf_keys))}
+
+    jrve_dct = {}
+
+    if len(cf_rve_dct) > 0:
+        jrve_dct = rve_cf_dict_to_jrve_cf_dict(cf_rve_dct)
+
+    if len(non_cf_rve_dcts) > 0:
+        for x in non_cf_rve_dcts:
+            jrve_dct.update(single_rve_dct_to_single_jrve_dct(x[0],x[1]))
+
+    if len(simple_cf_rve_dcts) > 0:
+        jrve_dct.update(rve_cf_dict_to_jrve_cf_dict_simple(simple_cf_rve_dcts))
+
+    return jrve_dct
\ No newline at end of file