diff --git a/demo/cygnusa_2ghz_fast_resolve.cfg b/demo/cygnusa_2ghz_fast_resolve.cfg new file mode 100644 index 0000000000000000000000000000000000000000..9e4f0f24711cb57834b97770ef4c0c3e6e0d17ab --- /dev/null +++ b/demo/cygnusa_2ghz_fast_resolve.cfg @@ -0,0 +1,34 @@ +[setup] +data=CYG-ALL-2052-2MHZ_RESOLVE_float64.npz +psf pixels=300 +cache_noise_kernel=noise_kernel_cygnusa_2ghz_sm +cache_response_kernel=response_kernel_cygnusa_2ghz_sm +noise_scaling=True +varcov=False + + + +[sky] +freq mode = single +polarization=I +space npix x = 1024 +space npix y = 512 +space fov x = 0.05deg +space fov y = 0.025deg + +stokesI diffuse space i0 zero mode offset = 18 +stokesI diffuse space i0 zero mode mean = 1 +stokesI diffuse space i0 zero mode stddev = 0.1 +stokesI diffuse space i0 fluctuations mean = 5 +stokesI diffuse space i0 fluctuations stddev = 1 +stokesI diffuse space i0 loglogavgslope mean = -2.0 +stokesI diffuse space i0 loglogavgslope stddev = 0.2 +stokesI diffuse space i0 flexibility mean = 1.2 +stokesI diffuse space i0 flexibility stddev = 0.4 +stokesI diffuse space i0 asperity mean = 0.2 +stokesI diffuse space i0 asperity stddev = 0.2 + +point sources mode = single +point sources locations = 0deg$0deg,0.35as$-0.22as +point sources alpha = 0.5 +point sources q = 0.2 diff --git a/demo/demo_fast_resolve.py b/demo/demo_fast_resolve.py new file mode 100644 index 0000000000000000000000000000000000000000..6017aa83f357c82c9b9e4ae9de45393e36f46ef1 --- /dev/null +++ b/demo/demo_fast_resolve.py @@ -0,0 +1,196 @@ +# %% +import nifty8 as ift +import nifty8.re as jft +import resolve as rve +import resolve.re as jrve +import numpy as np +import configparser +import sys +import matplotlib.pyplot as plt +from matplotlib.colors import LogNorm +import jax +import pickle + +from jax import random +import jax.numpy as jnp + + +jax.config.update("jax_enable_x64", True) + +seed = 42 + +key = random.PRNGKey(seed) +conf_name = "cygnusa_2ghz_fast_resolve.cfg" +out_dir = "demo_fast-resolve" + + +cfg = configparser.ConfigParser() +cfg.read(conf_name) + +data = cfg["setup"]["data"] +cnk = cfg["setup"]["cache_noise_kernel"] +crk = cfg["setup"]["cache_response_kernel"] +noise_scal = cfg["setup"].getboolean("noise_scaling") +varcov = cfg["setup"].getboolean("varcov") +if noise_scal and varcov: + raise ValueError() +obs = rve.Observation.load(data) +obs = obs.restrict_to_stokesi() +N_inv = ift.makeOp(obs.weight) + +sky_model, model_dict = jrve.sky_model(cfg["sky"]) +R, R_l, RNR, RNR_l = jrve.build_exact_r(obs, cfg["sky"], cfg["setup"]) +# %% +my_sky_model_func = lambda x: sky_model(x)[0, 0, 0, :, :] +my_sky_model = jft.Model(my_sky_model_func, domain=sky_model.domain) +mode = 1 +mean = 1.1 +a = 2 / (mean / mode - 1) + 1 +scale = mode * (a + 1) +inv_gamma = jft.InvGammaPrior( + a=a, + scale=scale, + name="noise", + shape=jax.ShapeDtypeStruct(R.domain.shape, float), + ) +noise_scaling = lambda x: 1 / (inv_gamma(x)) +noise_scaling = jft.Model(noise_scaling, domain=inv_gamma.domain) + +if noise_scal: + RNR_approx, N_inv_approx = jrve.build_approximations( + RNR, + RNR_l, + cache_noise_kernel=cnk, + cache_response_kernel=crk, + noise_scaling=noise_scaling, + ) +else: + RNR_approx, N_inv_approx = jrve.build_approximations( + RNR, RNR_l, cache_noise_kernel=cnk, cache_response_kernel=crk, varcov=varcov + ) + + +key, subkey = random.split(key) +if noise_scaling or varcov: + init_pos = jft.random_like(subkey, {**my_sky_model.domain, **noise_scaling.domain}) +else: + init_pos = jft.random_like(subkey, my_sky_model.domain) + + +def callback(pos, samp_at_pos, iter, n_major): + post_mean = jft.mean(tuple(my_sky_model(s) for s in samp_at_pos)) + plt.imshow(post_mean.T, origin="lower", norm=LogNorm()) + plt.colorbar() + plt.savefig(f"{out_dir}/major_{n_major}/sky_mean_{iter}.png", dpi=300) + plt.close() + + +# inference +d_new = R.adjoint(N_inv(obs.vis)) + +d_new = jnp.array(d_new.val) +init_pos = 1e-2 * jft.Vector(init_pos.copy()) + + +absdelta = 1e-10 + + +def get_draw_linear_kwargs(i): + kwargs = dict( + cg_name="cg", + cg_kwargs=dict( + absdelta=absdelta / 10.0, + maxiter=nstep_sampling(i), + miniter=nstep_sampling(i), + ), + ) + return kwargs + + +def get_nonlinearly_update_kwargs(i): + kwargs = dict( + minimize_kwargs=dict( + name=None, + xtol=1e-4, + cg_kwargs=dict(name=None), + maxiter=nstep_nl_sampling(i), + miniter=nstep_nl_sampling(i), + ) + ) + return kwargs + + +def get_kl_kwargs(i): + if i < 6: + min_cg = 6 + else: + min_cg = 20 + kwargs = dict( + minimize_kwargs=dict( + name="newton", + absdelta=absdelta, + cg_kwargs=dict(name="ncg", miniter=min_cg), + maxiter=nstep_newton(i), + miniter=nstep_newton(i), + energy_reduction_factor=1e-3, + ) + ) + return kwargs + + +def method(i): + return "linear_resample" + + +def nstep_sampling(ii): + if ii < 4: + return 100 + elif ii < 8: + return 500 + elif ii < 12: + return 1000 + + +def nstep_nl_sampling(ii): + return 0 if ii < 20 else 10 + + +def nstep_newton(ii): + if ii < 4: + return 5 + elif ii < 8: + return 10 + elif ii < 12: + return 20 + elif ii < 14: + return 30 + +draw_linear_kwargs = get_draw_linear_kwargs +nonlinearly_update_kwargs = get_nonlinearly_update_kwargs +kl_kwargs = get_kl_kwargs + +optimize_kwargs = dict( + R=RNR, + R_approx=RNR_approx, + sky=my_sky_model, + N_inv_sqrt=N_inv_approx, + data=d_new, + pos=init_pos, + draw_linear_kwargs=draw_linear_kwargs, + nonlinearly_update_kwargs=nonlinearly_update_kwargs, + kl_kwargs=kl_kwargs, + n_samples=2, + n_iter=2, + n_major_step=20, + key=key, + callback=callback, + out_dir=out_dir, + resume=True, + init_samples=None, + noise_scaling=noise_scal, + varcov=varcov, + varcov_op=noise_scaling, + method=method, +) +pos, samples, key = jrve.optimize(**optimize_kwargs) +# %% diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py index ce2ba46e2dd79aa995eb542f9f9c8aad80c18a68..c96632e67661b85427538689ecc7360899fd7a42 100644 --- a/resolve/re/__init__.py +++ b/resolve/re/__init__.py @@ -1,3 +1,8 @@ from .sky_model import sky_model_diffuse, sky_model_points, sky_model -from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc \ No newline at end of file +from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc +from .radio_response import build_exact_r, build_approximations +from .optimize import optimize +from .calibration import CalibrationInterpolator, CalibrationDistribution +from .likelihood import CalibrationLikelihood, ImagingLikelihood +from .sugar import Bulk_CF_AntennaTimeDomain, CalibrationAssembler, SkyAssembler \ No newline at end of file diff --git a/resolve/re/calibration.py b/resolve/re/calibration.py new file mode 100644 index 0000000000000000000000000000000000000000..f7b3cb6597b1ab6e9edf6597b22c7f306b37b7e6 --- /dev/null +++ b/resolve/re/calibration.py @@ -0,0 +1,97 @@ +import nifty8.re as jft +import jax.scipy as jsc +import jax.numpy as jnp + +from jax.tree_util import Partial +from jax import vmap + +from ..data.observation import Observation + +class CalibrationDistribution(jft.Model): + """ + Computes the calibration operator from given observation data. + + Parameters + ---------- + observation: Observation + Observation object from which are the antenna and temporal information corresponding to + the visibilites are extracted. + phase_fields: jft.Model or Bulk_CF_AntennaTimeDomain (preferred) + Correlated fields on antenna-time space for phases of calibration solutions. + log_amplitude_fields: jft.Model or Bulk_CF_AntennaTimeDomain (preferred) + Correlated fields on antenna-time space for log amplitude of calibration solutions. + dt: float + Distances between time points on time axis. Has to be the same distance of time points, + which is used for phase_fields and log_amplitude fields. + + Note + ---- + Currently, only uniformly spaced time axis are supported. + """ + def __init__( + self, + observation: Observation, + phase_fields: jft.Model, + log_amplitude_fields: jft.Model, + dt: float + ): + ap = observation.antenna_positions + target_shape = observation.vis.shape + self._cop1 = CalibrationInterpolator(jnp.asarray(ap.ant1), jnp.asarray(ap.time), dt, target_shape) + self._cop2 = CalibrationInterpolator(jnp.asarray(ap.ant2), jnp.asarray(ap.time), dt, target_shape) + + self._phases = phase_fields + self._logamps = log_amplitude_fields + + super().__init__(init=self._phases.init | self._logamps.init) + + def __call__(self,x): + res_logamp = jnp.real(self._cop1(self._logamps(x)) + self._cop2(self._logamps(x))) + res_phase = jnp.real(self._cop1(self._phases(x)) - self._cop2(self._phases(x)))*1j + + return jnp.exp(res_logamp + res_phase) + +class CalibrationInterpolator(): + """ + Interpolates visibilites for a specific sequence of antenna-time pairs given the visibilities + on an evenly spaced antenna-time grid. + + Parameters + ---------- + ant_col: jnp.ndarray + Antenna points to which one wants to interpolate + time_col: jnp.ndarry + Time points to which one wants to interpolate + dt: float + Distances between time points on time axis. + target_shape: tuple + Shape of output when calling this class. First element should encode number of + polarization directions and last element should encode number of frequencies + + Note + ---- + Currently, only uniformly spaced time axis are supported. + """ + def __init__( + self, + ant_col: jnp.ndarray, + time_col: jnp.ndarray, + dt: float, + target_shape: tuple + ): + + coords = [ant_col,time_col/dt] + + self._li = Partial(jsc.ndimage.map_coordinates,coordinates=coords,order=1) + + self._n_pol,_,self._n_freq = target_shape + + def __call__(self,x): + res = vmap( + vmap( + lambda pol, freq: self._li(x[pol, :, :, freq]) + ,in_axes=(None,0), out_axes=1 + ),in_axes=(0,None) + ) + + return res(jnp.arange(self._n_pol), jnp.arange(self._n_freq)) \ No newline at end of file diff --git a/resolve/re/fast-resolve_cal.py b/resolve/re/fast-resolve_cal.py new file mode 100644 index 0000000000000000000000000000000000000000..3cac6b2e41f537cbffb2e651d84fb5bb4cb1ae5f --- /dev/null +++ b/resolve/re/fast-resolve_cal.py @@ -0,0 +1,244 @@ +import nifty8 as ift +import nifty8.re as jft + +import resolve.re as jrve + +import jax.numpy as jnp +import jax.random as jr + +def calibration_step( + cal_assemblers, + config_minimizer, + key, + assembler_keys = None, + resume = False, + odir = None + ): + + k_i, k_0 = jr.split(key,2) + + lh_param_keys = ("observation","calibration_operator","model_visibilities","log_inverse_covariance_operator","likelihood_label") + cal_assembler_attr_keys = ("obs","scaled_cop","model_vis","log_inv_cov","lh_label") + assembler_keys = cal_assemblers.keys() if assembler_keys is None else assembler_keys + + lh_components = {i: [getattr(cal_assemblers[k],j) for k in assembler_keys] for i,j in zip(lh_param_keys,cal_assembler_attr_keys)} + lh = jrve.CalibrationLikelihood(**lh_components) + + optimize_kl_kwargs = dict( + likelihood = lh, + position_or_samples = jft.random_like(k_0,lh.domain), + key = k_i, + resume = resume, + odir = odir, + **config_minimizer + ) + + return jft.optimize_kl(**optimize_kl_kwargs) + +def imaging_step( + sky_assembler, + calibration_field, + config_response, + config_minimizer, + key, + init_samples = None, + cache_noise_kernel="None", + cache_response_kernel="None", + resume = False, + odir = None + ): + + obs = sky_assembler.obs + sky_model = sky_assembler.sky + noise_scaling_op = sky_assembler.noise_scaling + varcov_op = sky_assembler.varcov + + noise_scaling = False if noise_scaling_op is None else True + varcov = False if varcov_op is None else True + + k_i, k_0 = jr.split(key,2) + + R, _, RNR, RNR_l = jrve.build_exact_r(obs,calibration_field=calibration_field,**config_response) + + N_inv = ift.makeOp(obs.weight) + d_0 = jnp.array(R.adjoint(N_inv(obs.vis)).val) + + + if noise_scaling: + RNR_approx, N_inv_approx = jrve.build_approximations(RNR,RNR_l,cache_noise_kernel=cache_noise_kernel,cache_response_kernel=cache_response_kernel,noise_scaling=noise_scaling_op) + init_pos = jft.random_like(k_0,{**sky_model.domain, **noise_scaling_op.domain}) + else: + RNR_approx, N_inv_approx = jrve.build_approximations(RNR,RNR_l,cache_noise_kernel=cache_noise_kernel,cache_response_kernel=cache_response_kernel,varcov=varcov) + init_pos = jft.random_like(k_0,{**sky_model.domain, **varcov_op.domain}) if varcov else jft.random_like(k_0,sky_model.domain) + + optimize_kl_kwargs = dict( + R = RNR, + R_approx = RNR_approx, + sky = sky_model, + N_inv_sqrt = N_inv_approx, + data = d_0, + pos = 1e-2*jft.Vector(init_pos.copy()), + resume = resume, + init_samples = init_samples, + noise_scaling = noise_scaling, + varcov = varcov, + varcov_op = varcov_op, + key = k_i, + out_dir = odir, + **config_minimizer + ) + + pos, samples, _ = jrve.optimize(**optimize_kl_kwargs) + + post_sky_mean = ift.makeField(R.domain, np.array(jft.mean(tuple(sky_model(s) for s in samples)))) + new_sci_model_vis = jnp.array(R(post_sky_mean).val) + + return samples, pos, new_sci_model_vis + +def resume_preparation( + obs, + sky_model, + odir, + config +): + config_response = config["response"] + save_all = config["imaging"]["save_all"] + + lfile = f"{odir}/last_started_major" + if path.isfile(lfile): + with open(lfile) as f: + last_started_major = int(f.read()) + lfile = f"{odir}/major_{last_started_major}/last_finished_iteration" + if path.isfile(lfile): + with open(lfile) as f: + last_finished_index = int(f.read()) + else: + raise FileNotFoundError(f"{lfile} could not be found") + else: + raise FileNotFoundError(f"{lfile} could not be found") + + if save_all: + samples = pickle.load( + open( + f"{odir}/major_{last_started_major}/samples_{last_finished_index}.p", + "rb", + ) + ) + else: + samples = pickle.load( + open( + f"{odir}/last_samples.p", + "rb", + ) + ) + + sp_sky_dom =rve.sky_model._spatial_dom(config_response["conf_sky"]) + sky_dom = rve.default_sky_domain(sdom=sp_sky_dom) + R = rve.InterferometryResponse(obs, sky_dom, config_response["do_wgridding"], config_response["epsilon"], config_response["verbosity"], config_response["nthreads"]) + dch = ift.DomainChangerAndReshaper(R.domain[3], R.domain) + R = R @ dch + + post_sky_mean = ift.makeField(R.domain, np.array(jft.mean(tuple(sky_model(s) for s in samples)))) + return jnp.array(R(post_sky_mean).val) + + +def single_cal_fr_run( + cal_assemblers, + sky_assembler, + config, + key, + n_vi_runs=(1,1), + resume = False, + odirs=(None,None) + ): + key_cal, key_img = jr.split(key,2) + config_response = config["response"] + config_cal = config["calibration"] + config_img = config["imaging"] + assembler_keys = list(cal_assemblers.keys()) + + if cal_assemblers["sci"].model_vis is None: + assembler_keys.remove("sci") + + config_cal["n_total_iterations"] += n_vi_runs[0] + + print("-"*5,"Calibration step","-"*57) + + cal_samples, cal_state = calibration_step(cal_assemblers,config_cal,key_cal,assembler_keys,resume,odirs[0]) + + antenna_gains = jft.mean(tuple(cal_assemblers["sci"].cop(s) for s in cal_samples)) + cfld = ift.makeField(cal_assemblers["sci"].obs.vis.domain,antenna_gains) + + config_img["n_major_step"] += n_vi_runs[1] + print("-"*5,"Imaging step","-"*61) + + img_samples, img_state, cal_assemblers["sci"].model_vis = imaging_step(sky_assembler,cfld,config_response,config_img,key_img,resume=resume,odir=odirs[1]) + + + return dict(response=config_response,calibration=config_cal,imaging=config_img),dict(cal=cal_samples,img=img_samples), dict(cal=cal_state,img=img_state) + +def fastresolve_with_calibration( + sky_assembler, + sci_cal_assembler, + config, + key, + phase_cal_assembler = None, + flux_cal_assembler = None, + n_iterations=1, + n_vi_runs_per_iteration=(1,1), + resume = False, + odir = None, + return_latest_samples=False, + return_latest_states=False, + return_latest_config_update=False + ): + if resume and (odir is None): + raise ValueError("Set folder name from where to resume the inference runs") + + odirs = (None,None) if odir is None else (f"{odir}/calibration",f"{odir}/imaging") + + new_config = deepcopy(config) + sca = dataclasses.replace(sci_cal_assembler) + + if resume: + sca.model_vis = resume_preparation(sca.obs,sky_assembler.sky,odirs[1],config) + + with open(f"{odir}/run_info.json","r") as f: + loaded_run_info = json.load(f) + + n_cal_vi_steps_done = loaded_run_info["cal_vi_iterations"] + n_img_major_steps_done = loaded_run_info["img_n_major_steps"] + + print("Resume: Overwrite given settings for n_total_iterations and n_major_step") + + config["calibration"]["n_total_iterations"] = n_cal_vi_steps_done + config["imaging"]["n_major_step"] = n_img_major_steps_done + + print(f"Loaded data of previous {n_cal_vi_steps_done} calibration steps and {n_img_major_steps_done} imaging major steps") + + + cal_assemblers = dict(sci=sca) + cal_assemblers.update(dict(phase=phase_cal_assembler) if phase_cal_assembler is not None else {}) + cal_assemblers.update(dict(flux=flux_cal_assembler) if flux_cal_assembler is not None else {}) + + + for k in range(n_iterations): + print("-"*80) + print(" "*4,f"{k+1}-th calibration and fast-resolve iteration") + print("-"*80) + + new_config, samples_dct, state_dct = single_cal_fr_run(cal_assemblers,sky_assembler,new_config,key,n_vi_runs_per_iteration,resume,odirs) + resume = True + + with open(f"{odir}/run_info.json","w") as f: + json.dump(dict( + cal_vi_iterations = new_config["calibration"]["n_total_iterations"], + img_n_major_steps = new_config["imaging"]["n_major_step"] + ),f) + + returns = {} + returns.update(dict(config=new_config) if return_latest_config_update else {}) + returns.update(dict(samples=samples_dct) if return_latest_samples else {}) + returns.update(dict(states=state_dct) if return_latest_states else {}) + + return returns \ No newline at end of file diff --git a/resolve/re/likelihood.py b/resolve/re/likelihood.py new file mode 100644 index 0000000000000000000000000000000000000000..27d9fe571a57abb556090572578a1684772dc3c8 --- /dev/null +++ b/resolve/re/likelihood.py @@ -0,0 +1,228 @@ +import jax.numpy as jnp +import nifty8.re as jft + +from typing import Union, Iterable + +from .response import InterferometryResponse +from .calibration import CalibrationDistribution +from .likelihood_models import * +# The classes from .likelihoods model are: +# - ModelCalibrationLikelihoodFixedCovariance +# - ModelCalibrationLikelihoodVariableCovariance +# - ModelImagingLikelihoodFixedCovarianceCalibrationField +# - ModelImagingLikelihoodFixedCovarianceCalibrationOperator +# - ModelImagingLikelihoodVariableCovarianceCalibrationField +# - ModelImagingLikelihoodVariableCovarianceCalibrationOperator + +from ..util import _obj2list, _duplicate +from ..data.observation import Observation + +from jaxlib.xla_extension import ArrayImpl +from nifty8.re.model import Model + +def CalibrationLikelihood( + observation: Union[Observation, Iterable[Observation]], + calibration_operator: Union[CalibrationDistribution, Iterable[CalibrationDistribution]], + model_visibilities: Union[jnp.ndarray, Iterable[jnp.ndarray]], + log_inverse_covariance_operator: Union[jft.Model,Iterable[jft.Model]] = None, + likelihood_label: Union[str,Iterable[str]] = None +): + """Versatile calibration likelihood class + + It returns an operator that computes: + + residual = calibration_operator * model_visibilities + likelihood = 0.5 * residual^dagger @ inverse_covariance @ residual + + If an inverse_covariance_operator is passed, it is inserted into the above + formulae. If it is not passed, 1/observation.weights is used as inverse + covariance. + + Parameters + ---------- + observation : Observation or Iterable of Observations + Observation object from which observation.vis and potentially + observation.weight is used for computing the likelihood. + + calibration_operator : CalibrationDistribution or Iterable of CalibrationDistribution + Target needs to be the same as observation.vis. The same amount of elements as + number of observations should be provided. + + model_visibilities jnp.ndarray or Iterable of jnp.ndarray + Known model visiblities that are used for calibration. Needs to be + defined on the same domain as `observation.vis`. The same amount of elements as + number of observations should be provided. + + log_inverse_covariance_operator : jft.Model or Iterable of jft.Model + Optional. Target needs to be the same space as observation.vis. If it is + not specified, observation.wgt is taken as covariance. If used, the same + amount of elements as number of observations should be provided. + + likelihood_label: string or Iterable of string + Optional. Append label to individual likelihood which is shown in the minisanity + for overview. If used, the same amount of elements as number of observations + should be provided. + """ + + + obs = _obj2list(observation,Observation) + cops = _obj2list(calibration_operator,CalibrationDistribution) + model_d = _obj2list(model_visibilities,ArrayImpl) + log_inv_covs = _duplicate(_obj2list(log_inverse_covariance_operator,Model),len(obs)) + labels = _duplicate(_obj2list(likelihood_label,str),len(obs)) + + if len(set([len(obs),len(cops),len(model_d)])) != 1: + raise ValueError("observation, calibration_operator and model_visibilities must have the same number of elements") + + lhs = [] + for ii, (oo, cop, model_vis, log_inv_cov,label) in enumerate(zip(obs,cops,model_d,log_inv_covs,labels)): + mask = jnp.asarray(oo.mask.val) + + flagged_data = jnp.asarray(oo.vis.val)[mask] + + if log_inv_cov is None: + model = ModelCalibrationLikelihoodFixedCovariance(cop,model_vis,mask) + flagged_inv_cov = jnp.asarray(oo.weight.val)[mask] + + lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov) + + else: + model = ModelCalibrationLikelihoodVariableCovariance(cop,model_vis,log_inv_cov,mask) + + lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=jnp.iscomplexobj(oo.vis.val)) + + lh_with_model = lh.amend(model) + lh_with_model._domain = jft.Vector(lh_with_model._domain) + + if label is not None: + lh_with_model._name = label + lh_with_model._name_n_ws = " "*(len(str(len(obs)))-len(str(ii))) + + lhs.append(lh_with_model) + + if set(labels) == {None}: + return jft.likelihood.LikelihoodSum(*lhs) + else: + return jft.likelihood.LikelihoodSum(*lhs, _key_template="lh {likelihood._name_n_ws}{index} | {likelihood._name}") + + +def ImagingLikelihood( + observation: Union[Observation, Iterable[Observation]], + sky_operator: Union[jft.Model,Iterable[jft.Model]], + sky_domain_dict: dict, + epsilon: float, + do_wgridding: bool, + log_inverse_covariance_operator: Union[jft.Model,Iterable[jft.Model]] = None, + calibration_operator: Union[CalibrationDistribution, Iterable[CalibrationDistribution]] = None, + calibration_field: Union[jnp.ndarray, Iterable[jnp.ndarray]] = None, + likelihood_label: Union[str,Iterable[str]] = None, + verbosity: int = 0, + nthreads: int = 1, + backend: str = "ducc0", +): + """Versatile likelihood class. + + If a calibration operator is passed, it returns an operator that computes: + + residual = calibration_operator * (R @ sky_operator) + likelihood = 0.5 * residual^dagger @ inverse_covariance @ residual + + Otherwise, it returns an operator that computes: + + residual = R @ sky_operator + likelihood = 0.5 * residual^dagger @ inverse_covariance @ residual + + If an inverse_covariance_operator is passed, it is inserted into the above + formulae. If it is not passed, 1/observation.weights is used as inverse + covariance. + + Parameters + ---------- + observation : Observation or Iterable of Observation + Observation objects from which vis, uvw, freq and potentially weight + are used for computing the likelihood. + + sky_operator : jft.Model + Operator that generates sky. + + sky_domain_dict: dict + A dictionary providing information about the discretization of the sky. + + epsilon : float + + do_wgridding : bool + + log_inverse_covariance_operator : jft.Model or Iterable of jft.Model + Optional. Target needs to be the same space as observation.vis. If it + is not specified, observation.wgt is taken as covariance. If used, the same + amount of elements as number of observations should be provided. + + calibration_operator : CalibrationDistribution or Iterable of CalibrationDistribution + Optional. Target needs to be the same as observation.vis. If used, the same + amount of elements as number of observations should be provided. + + calibration_field: jnp.ndarray or Iterable of jnp.ndarray + Optional. Domain needs to be the same as observation.vis. If used, the same + amount of elements as number of observations should be provided. + + likelihood_label: string or Iterable of string + Optional. Append labels to individual likelihoods which are shown in the minisanity + for overview. If used, the same amount of elements as number of observations + should be provided. + + verbosity : int + + nthreads : int + + backend: string + + Note + ---- + For each observation only either calibration_operator or calibration_field + can be set. + """ + + obs = _obj2list(observation,Observation) + cops = _duplicate(_obj2list(calibration_operator,CalibrationDistribution),len(obs)) + cflds = _duplicate(_obj2list(calibration_field,ArrayImpl),len(obs)) + log_inv_covs = _duplicate(_obj2list(log_inverse_covariance_operator,Model),len(obs)) + labels = _duplicate(_obj2list(likelihood_label,str),len(obs)) + + lhs = [] + + for ii, (oo, cop, cfld, log_inv_cov,label) in enumerate(zip(obs,cops,cflds,log_inv_covs,labels)): + if cfld is not None and cop is not None: + raise ValueError( + f"Can't set calibration operator and calibration field simultaneously at index {ii}" + ) + + R = InterferometryResponse(oo,sky_domain_dict,do_wgridding,epsilon,nthreads,verbosity,backend) + mask = jnp.asarray(oo.mask.val) + + flagged_data = jnp.asarray(oo.vis.val)[mask] + + if log_inv_cov is None: + model = ModelImagingLikelihoodFixedCovariance(R,sky_operator,mask,cop,cfld) + + flagged_inv_cov = jnp.asarray(oo.weight.val)[mask] + + lh = jft.Gaussian(data=flagged_data, noise_cov_inv=flagged_inv_cov) + + else: + model = ModelImagingLikelihoodVariableCovariance(R,sky_operator,log_inv_cov,mask,cop,cfld) + + lh = jft.VariableCovarianceGaussian(data=flagged_data,iscomplex=jnp.iscomplexobj(oo.vis.val)) + + lh_with_model = lh.amend(model) + lh_with_model._domain = jft.Vector(lh_with_model._domain) + + if label is not None: + lh_with_model._name = label + lh_with_model._name_n_ws = " "*(len(str(len(obs)))-len(str(ii))) + + lhs.append(lh_with_model) + + if set(labels) == {None}: + return jft.likelihood.LikelihoodSum(*lhs) + else: + return jft.likelihood.LikelihoodSum(*lhs, _key_template="lh {likelihood._name_n_ws}{index} | {likelihood._name}") diff --git a/resolve/re/likelihood_models.py b/resolve/re/likelihood_models.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3b3ffd1cd67484bdc8e7c7d779ccea7dba85b2 --- /dev/null +++ b/resolve/re/likelihood_models.py @@ -0,0 +1,175 @@ +import nifty8.re as jft +import jax.numpy as jnp + +from .calibration import CalibrationDistribution + +from typing import Callable + +class ModelCalibrationLikelihoodFixedCovariance(jft.Model): + """ + Provides a flagged data model for calibration + + Parameters + ---------- + cop: CalibrationDistribution + Calibration operator + model_visibilities: jnp.ndarray + Assumed visibilities of the point source. + mask: jnp.array + Mask as boolean numpy array for good visibilites + """ + def __init__( + self, + cop: CalibrationDistribution, + model_visibilities: jnp.ndarray, + mask: jnp.ndarray + ): + self._cop = cop + self._vis = model_visibilities + self._mask = mask + + super().__init__(init=self._cop.init) + + def __call__(self, x): + data_model = self._vis*self._cop(x) + flagged_data_model = data_model[self._mask] + + return flagged_data_model + +class ModelCalibrationLikelihoodVariableCovariance(jft.Model): + """ + Provides a combined flagged data model and flagged inverse covariance model for calibration + + Parameters + ---------- + cop: CalibrationDistribution + Calibration operator + model_visibilities: jnp.ndarray + Assumed visibilities of the point source. + log_inverse_covariance_model: jft.Model + Model for log inverse covariance + mask: jnp.array + Mask as boolean numpy array for good visibilites + """ + def __init__( + self, + cop: CalibrationDistribution, + model_visibilities: jnp.ndarray, + log_inverse_covariance_model: jft.Model, + mask: jnp.ndarray + ): + self._cop = cop + self._vis = model_visibilities + self._mask = mask + self._log_inv_cov = log_inverse_covariance_model + + super().__init__(init=self._cop.init | self._log_inv_cov.init) + + def __call__(self, x): + data_model = self._vis*self._cop(x) + flagged_data_model = data_model[self._mask] + + inv_std = jnp.exp(0.5*self._log_inv_cov(x)) + flagged_inv_std = inv_std[self._mask] + + return (flagged_data_model,flagged_inv_std) + + +class ModelImagingLikelihoodFixedCovariance(jft.Model): + """ + Provides a flagged data model for imaging given an optional calibration operator or + calibration field + + Parameters + ---------- + R: Callable + Response operator function + sky: jft.Model + Model for sky + mask: jnp.array + Mask as boolean numpy array for good visibilites + calibration_operator: CalibrationDistribution + Optional. Calibration operator + calibration_field: jnp.ndarray + Optional. Calibration field + """ + def __init__( + self, + R: Callable, + sky_operator: jft.Model, + mask: jnp.ndarray, + calibration_operator: CalibrationDistribution = None, + calibration_field: jnp.ndarray = None + ): + + self._mask = mask + inits = sky_operator.init + + if(calibration_operator is not None): + self._data_model = lambda x: calibration_operator(x)*R(sky_operator(x)) + inits = inits | calibration_operator.init + elif(calibration_field is not None): + self._data_model = lambda x: calibration_field*R(sky_operator(x)) + else: + self._data_model = lambda x: R(sky_operator(x)) + + super().__init__(init=inits) + + def __call__(self,x): + data_model = self._data_model(x) + flagged_data_model = data_model[self._mask] + + return flagged_data_model + +class ModelImagingLikelihoodVariableCovariance(jft.Model): + """ + Provides a combined flagged data model and flagged inverse covariance model for imaging + given an optional calibration operator or calibration field + + Parameters + ---------- + R: Callable + Response operator function + sky: jft.Model + Model for sky + log_inverse_covariance_model: jft.Model + Model for log inverse covariance + mask: jnp.array + Mask as boolean numpy array for good visibilites + calibration_operator: CalibrationDistribution + Optional. Calibration operator + calibration_field: jnp.ndarray + Optional. Calibration field + """ + def __init__( + self, + R: Callable, + sky_operator: jft.Model, + log_inverse_covariance_model: jft.Model, + mask: jnp.ndarray, + calibration_operator: CalibrationDistribution = None, + calibration_field: jnp.ndarray = None + ): + + self._mask = mask + self._log_inv_cov = log_inverse_covariance_model + inits = sky_operator.init | self._log_inv_cov.init + + if(calibration_operator is not None): + self._data_model = lambda x: calibration_operator(x)*R(sky_operator(x)) + inits = inits | calibration_operator.init + elif(calibration_field is not None): + self._data_model = lambda x: calibration_field*R(sky_operator(x)) + else: + self._data_model = lambda x: R(sky_operator(x)) + + super().__init__(init=inits) + + def __call__(self,x): + data_model = self._data_model(x) + flagged_data_model = data_model[self._mask] + + inv_std = jnp.exp(0.5*self._log_inv_cov(x)) + flagged_inv_std = inv_std[self._mask] + + return (flagged_data_model,flagged_inv_std) diff --git a/resolve/re/optimize.py b/resolve/re/optimize.py new file mode 100644 index 0000000000000000000000000000000000000000..0835e96ec4eaaa3b8b7c98c904f46be2a4287dc3 --- /dev/null +++ b/resolve/re/optimize.py @@ -0,0 +1,286 @@ +import jax +import numpy as np +import pickle +import jax.numpy as jnp +import nifty8.re as jft +import nifty8 as ift + +from functools import partial +from os import makedirs +from os.path import isfile + + +from nifty8.re.optimize_kl import _kl_vg +from nifty8.re.optimize_kl import _kl_met +from nifty8.re.optimize_kl import draw_linear_residual +from nifty8.re.optimize_kl import nonlinearly_update_residual +from nifty8.re.optimize_kl import get_status_message +# from nifty8.re.evi import _nonlinearly_update_residual_functions + + +def optimize( + R, + R_approx, + sky, + N_inv_sqrt, + data, + pos, + draw_linear_kwargs, + nonlinearly_update_kwargs, + kl_kwargs, + n_samples, + n_iter, + n_major_step, + key, + callback=None, + out_dir=None, + resume=False, + init_samples=None, + noise_scaling=False, + varcov=False, + varcov_op=None, + method="linear_sample", + save_all=True, +): + if varcov: + jax_0_data = jnp.broadcast_to(0.0 + 0j, data.shape) + else: + jax_0_data = jnp.broadcast_to(0.0, data.shape) + if not out_dir == None: + makedirs(out_dir, exist_ok=True) + lfile = f"{out_dir}/last_started_major" + last_started_major = 0 + last_finished_index = -1 + if resume and isfile(lfile): + with open(lfile) as f: + last_started_major = int(f.read()) + lfile = f"{out_dir}/major_{last_started_major}/last_finished_iteration" + if resume and isfile(lfile): + with open(lfile) as f: + last_finished_index = int(f.read()) + + def residual_signal_response(x, old_reconstruction): + return R_approx(sky(x) - old_reconstruction) + + def noise_weighted_residual(x, old_reconstruction, residual_data): + if noise_scaling: + return N_inv_sqrt( + { + "sky": residual_signal_response(x, old_reconstruction) + - residual_data, + "noise": x["noise"], + } + ) + else: + return N_inv_sqrt( + {"sky": residual_signal_response(x, old_reconstruction) - residual_data} + ) + + def noise_weighted_residual_varcov(x, old_reconstruction, residual_data): + return [ + N_inv_sqrt( + {"sky": residual_signal_response(x, old_reconstruction) - residual_data} + ), + varcov_op(x), + ] + + kl_map = jax.vmap + kl_reduce = partial(jax.tree_map, partial(jnp.mean, axis=0)) + + def get_lh(old_reconstruction, residual_data): + if varcov: + lh = jft.VariableCovarianceGaussian(jax_0_data, iscomplex=True).amend( + partial( + noise_weighted_residual_varcov, + old_reconstruction=old_reconstruction, + residual_data=residual_data, + ) + ) + else: + lh = jft.Gaussian(jax_0_data).amend( + partial( + noise_weighted_residual, + old_reconstruction=old_reconstruction, + residual_data=residual_data, + ) + ) + return lh + + @jax.jit + def my_kl_vg(primals, primals_samples, *, old_reconstruction, residual_data): + lh = get_lh(old_reconstruction, residual_data) + return _kl_vg(lh, primals, primals_samples, map=kl_map, reduce=kl_reduce) + + @jax.jit + def my_kl_metric( + primals, tangents, primals_samples, *, old_reconstruction, residual_data + ): + lh = get_lh(old_reconstruction, residual_data) + return _kl_met( + lh, primals, tangents, primals_samples, map=kl_map, reduce=kl_reduce + ) + + @jax.jit + def my_draw_linear_residual( + pos, key, *, old_reconstruction, residual_data, **kwargs + ): + lh = get_lh(old_reconstruction, residual_data) + return draw_linear_residual(lh, pos, key, **kwargs) + + def my_nonlinearly_update_residual( + pos, + residual_sample, + metric_sample_key, + metric_sample_sign, + *, + old_reconstruction, + residual_data, + **kwargs, + ): + lh = get_lh(old_reconstruction, residual_data) + return nonlinearly_update_residual( + lh, + pos, + residual_sample, + metric_sample_key, + metric_sample_sign, + # _nonlinear_update_funcs=_nonlin_funcs, + **kwargs, + ) + + def my_stat_mes(samples, state, *, old_reconstruction, residual_data, **kwargs): + lh = get_lh(old_reconstruction, residual_data) + return get_status_message( + samples, state, lh.normalized_residual, name="optVI test", **kwargs + ) + + if last_finished_index > -1: + if save_all: + opt_vi_state = pickle.load( + open( + f"{out_dir}/major_{last_started_major}/optVIstate_it_{last_finished_index}.p", + "rb", + ) + ) + samples = pickle.load( + open( + f"{out_dir}/major_{last_started_major}/samples_{last_finished_index}.p", + "rb", + ) + ) + else: + opt_vi_state = pickle.load( + open( + f"{out_dir}/last_optVIstate.p", + "rb", + ) + ) + samples = pickle.load( + open( + f"{out_dir}/last_samples.p", + "rb", + ) + ) + + sub_val = jft.mean(tuple(sky(s) for s in samples)) + post_mean = np.array(sub_val) + post_mean = ift.makeField(R.domain, post_mean) + residual_data = data - R(post_mean).val + else: + opt_vi_state = None + samples = None + if not init_samples == None: + sub_val = jft.mean(tuple(sky(s) for s in init_samples)) + post_mean = np.array(sub_val) + post_mean = ift.makeField(R.domain, post_mean) + residual_data = data - R(post_mean).val + else: + residual_data = data + sub_val = jnp.zeros(data.shape, dtype=data.dtype) + + # init optVI + opt_vi = jft.OptimizeVI( + None, + n_iter, + _kl_value_and_grad=my_kl_vg, + _kl_metric=my_kl_metric, + _draw_linear_residual=my_draw_linear_residual, + _nonlinearly_update_residual=my_nonlinearly_update_residual, + _get_status_message=my_stat_mes, + residual_map="lmap", + ) + + update_kwargs = dict( + old_reconstruction=sub_val, + residual_data=residual_data, + ) + + # call optVI.init + if opt_vi_state is None: + print("init opt vi") + opt_vi_state = opt_vi.init_state( + key, + nit=0, + n_samples=n_samples, + draw_linear_kwargs=draw_linear_kwargs, + nonlinearly_update_kwargs=nonlinearly_update_kwargs, + kl_kwargs=kl_kwargs, + sample_mode=method, + ) + if samples is None: + samples = jft.Samples(pos=pos, samples=None, keys=None) + + for n_major in range(last_started_major, n_major_step): + if not out_dir == None: + makedirs(out_dir + f"/major_{n_major}", exist_ok=True) + minor_start = opt_vi_state.nit - n_iter * n_major + print("minor_start: ", minor_start) + print("n_iter: ", n_iter) + print("n_major:", n_major) + print("opt_vi_state.nit: ", opt_vi_state.nit) + for i in range(minor_start, n_iter): + # do opt_vi update + samples, opt_vi_state = opt_vi.update( + samples, opt_vi_state, **update_kwargs + ) + msg = opt_vi.get_status_message( + samples, + opt_vi_state, + old_reconstruction=sub_val, + residual_data=residual_data, + ) + jft.logger.info(msg) + + if not callback == None: + callback(opt_vi_state.minimization_state.x, samples, i, n_major) + if not out_dir == None: + if save_all: + pickle.dump( + opt_vi_state, + open(f"{out_dir}/major_{n_major}/optVIstate_it_{i}.p", "wb"), + ) + pickle.dump( + samples, open(f"{out_dir}/major_{n_major}/samples_{i}.p", "wb") + ) + else: + pickle.dump( + opt_vi_state, + open(f"{out_dir}/last_optVIstate.p", "wb"), + ) + pickle.dump(samples, open(f"{out_dir}/last_samples.p", "wb")) + with open( + f"{out_dir}/major_{n_major}/last_finished_iteration", "w" + ) as f: + f.write(str(i)) + with open(f"{out_dir}/last_started_major", "w") as f: + f.write(str(n_major)) + + sub_val = jft.mean(tuple(sky(s) for s in samples)) + post_mean = np.array(sub_val) + post_mean = ift.makeField(R.domain, post_mean) + residual_data = data - R(post_mean).val + + update_kwargs["old_reconstruction"] = sub_val + update_kwargs["residual_data"] = residual_data + + return samples.pos, samples, key diff --git a/resolve/re/radio_response.py b/resolve/re/radio_response.py new file mode 100644 index 0000000000000000000000000000000000000000..52e57056d26ce50e8bb147459df232de6388582b --- /dev/null +++ b/resolve/re/radio_response.py @@ -0,0 +1,186 @@ +import resolve as rve +import nifty8 as ift +import numpy as np +import pickle +import jax.numpy as jnp + +from jax.lax import slice as jax_slice +from functools import partial + +def get_jax_fft(domain, target, inverse): + if inverse: + if domain.harmonic: + func = jnp.fft.fftn + fct = 1. + else: + func = jnp.fft.ifftn + fct = domain.size + fct *= target.scalar_dvol + else: + if domain.harmonic: + func = jnp.fft.ifftn + fct = domain.size + else: + func = jnp.fft.fftn + fct = 1. + fct *= domain.scalar_dvol + + def fft(x, fct, func): + return fct * func(x) if fct != 1 else func(x) + + return partial(fft, fct=fct, func=func) + +def build_exact_r(obs, conf_sky, psf_pixels, calibration_field=None, do_wgridding=True, epsilon=1e-9, verbosity=1, nthreads=8): + sp_sky_dom =rve.sky_model._spatial_dom(conf_sky) + sky_dom = rve.default_sky_domain(sdom=sp_sky_dom) + R = rve.InterferometryResponse(obs, sky_dom, do_wgridding, epsilon, verbosity, nthreads) + + full_psf0 = min(2*psf_pixels, sp_sky_dom.shape[0]) + full_psf1 = min(2*psf_pixels, sp_sky_dom.shape[1]) + sp_sky_dom_l = (sp_sky_dom.shape[0] + full_psf0, sp_sky_dom.shape[1] + full_psf1) + sp_sky_dom_l = ift.RGSpace(sp_sky_dom_l, distances=sp_sky_dom.distances) + sky_dom_l = rve.default_sky_domain(sdom=sp_sky_dom_l) + R_l = rve.InterferometryResponse(obs, sky_dom_l, do_wgridding, epsilon, verbosity, nthreads) + + dch_l = ift.DomainChangerAndReshaper(R_l.domain[3], R_l.domain) + R_l = R_l @ dch_l + dch = ift.DomainChangerAndReshaper(R.domain[3], R.domain) + R = R @ dch + + if calibration_field is not None: + C = ift.makeOp(calibration_field) + + R = C @ R + R_l = C @ R_l + + N_inv = ift.DiagonalOperator(obs.weight) + RNR = R.adjoint @ N_inv @ R + RNR_l = R_l.adjoint @ N_inv @ R_l + + return R, R_l, RNR, RNR_l + + +def compute_PSF(new_R, n_pix0, n_pix1): + dom = new_R.domain + shp = dom.shape + FFT = ift.FFTOperator(new_R.domain) + + delta = np.zeros(shp) + delta[shp[0]//2, shp[1]//2] = 1 / dom.scalar_weight() + delta = ift.makeField(dom, delta) + kernel = new_R(delta) + + # zero kernel + sh0 = shp[0]//2 + sh1 = shp[1]//2 + z_kern = np.zeros_like(kernel.val) + z_kern[sh0-n_pix0:sh0+n_pix0,sh1-n_pix1:sh1+n_pix1] = kernel.val[sh0 - n_pix0:sh0+n_pix0,sh1-n_pix1:sh1+n_pix1] + + pr_kern = np.roll(z_kern, -shp[0]//2, axis=0) + pr_kern = np.roll(pr_kern, -shp[1]//2, axis=1) + pr_kern = ift.makeField(FFT.domain, pr_kern) + from matplotlib.colors import LogNorm + ift.single_plot(pr_kern, norm=LogNorm()) + fourier_kern = FFT(pr_kern) + return fourier_kern.val + +def compute_approx_noise_kern(new_R, relativ_min_val=0.): + dom = new_R.domain + shp = dom.shape + FFT = ift.FFTOperator(new_R.domain) + + delta = np.zeros(shp) + delta[shp[0]//2, shp[1]//2] = 1 / dom.scalar_weight() + delta = ift.makeField(dom, delta) + kernel = new_R(delta).val + kernel = np.roll(kernel, -shp[0]//2, axis=0) + kernel = np.roll(kernel, -shp[1]//2, axis=1) + kernel = ift.makeField(new_R.target, kernel) + FFT = ift.FFTOperator(new_R.domain) + max_val = np.max(FFT(kernel).abs().val) + min_val = relativ_min_val * max_val + min_val = ift.full(FFT.target, min_val) + min_val_adder = ift.Adder(min_val) + + pos_eig_val = ift.Operator.identity_operator(FFT.target).exp() + pos_eig_val = min_val_adder @ pos_eig_val + rls1 = ift.Realizer(pos_eig_val.target) + rls2 = ift.Realizer(FFT.domain) + + kernel_pos = rls2 @ FFT.inverse @ rls1.adjoint @ pos_eig_val + + cov = ift.ScalingOperator(kernel_pos.target, 1e-2*max_val) + lh = ift.GaussianEnergy(data=kernel, inverse_covariance=cov.inverse) @ kernel_pos + init_pos = (FFT(kernel) - min_val).abs().log() + energy = ift.EnergyAdapter(position=init_pos, op=lh, want_metric=True) + + ic_newton = ift.DeltaEnergyController(name='Newton', iteration_limit=80, tol_rel_deltaE=0) + #minimizer = ift.NewtonCG(ic_newton, max_cg_iterations=400, energy_reduction_factor=1e-3) + minimizer = ift.NewtonCG(ic_newton) + res = minimizer(energy)[0].position + return pos_eig_val(res).val + +def build_approximations(RNR, RNR_l, noise_scaling=None, varcov=False, cache_noise_kernel='None', cache_response_kernel='None'): + shp = RNR.domain.shape + shp_l = RNR_l.domain.shape + # assert(shp[0] == shp[1]) + # assert(shp_l[0] == shp_l[1]) + FFT = ift.FFTOperator(RNR_l.domain) + fft_jax = get_jax_fft(FFT.domain[0], FFT.target[0], False) + fft_jax_inv = get_jax_fft(FFT.domain[0], FFT.target[0], True) + FFT_s = ift.FFTOperator(RNR.domain) + fft_jax_s = get_jax_fft(FFT_s.domain[0], FFT_s.target[0], False) + fft_jax_inv_s = get_jax_fft(FFT_s.domain[0], FFT_s.target[0], True) + + + # build approximate response + n_psf_pix0 = (shp_l[0] - shp[0]) + n_psf_pix1 = (shp_l[1] - shp[1]) + n_padding0 = n_psf_pix0 // 2 + n_padding1 = n_psf_pix1 // 2 + if not cache_response_kernel == 'None': + try: + psf_kernel = pickle.load(open(f"{cache_response_kernel}.p", "rb")) + except: + psf_kernel = compute_PSF(RNR_l, n_psf_pix0, n_psf_pix1) + pickle.dump(psf_kernel, open(f"{cache_response_kernel}.p", "wb")) + else: + psf_kernel = compute_PSF(RNR_l, n_psf_pix0, n_psf_pix1) + psf_kernel = jnp.array(psf_kernel) + apply_psf_kern = lambda x: fft_jax_inv(psf_kernel * fft_jax(x)).real + + slicer = lambda x: jax_slice( + x, (n_padding0, n_padding1), (n_padding0+ shp[0], n_padding1+ shp[1]) + ) + padder = lambda x: jnp.pad(x, ((n_padding0, n_padding0), (n_padding1, n_padding1))) + + RNR_approx = lambda x: slicer(apply_psf_kern(padder(x))) + + # build approximate N inverse + if not cache_noise_kernel == 'None': + try: + noise_kernel = pickle.load(open(f"{cache_noise_kernel}.p", "rb")) + except: + noise_kernel = compute_approx_noise_kern(RNR, 1e-3) + pickle.dump(noise_kernel, open(f"{cache_noise_kernel}.p", "wb")) + else: + noise_kernel = compute_approx_noise_kern(RNR, 1e-3) + noise_kernel_inv_sqrt = 1. / np.sqrt(noise_kernel) + noise_kernel_inv_sqrt = jnp.array(noise_kernel_inv_sqrt) + if varcov: + fl = ift.full(FFT_s.target, 1.) + vol = FFT_s(FFT_s.adjoint(fl)).real.mean().val + fac = np.sqrt(1/vol) + apply_n_sqinv_kern = lambda x: fac*noise_kernel_inv_sqrt * fft_jax_s(x['sky']) + elif not noise_scaling is None: + apply_n_sqinv_kern = lambda x: fft_jax_inv_s( + noise_scaling(x) * noise_kernel_inv_sqrt * fft_jax_s(x['sky']) + ).real + else: + apply_n_sqinv_kern = lambda x: fft_jax_inv_s( + noise_kernel_inv_sqrt * fft_jax_s(x['sky']) + ).real + + + return RNR_approx, apply_n_sqinv_kern + diff --git a/resolve/re/sky_model.py b/resolve/re/sky_model.py index 790bb9589e9d41f4a8410630667263cbee997a20..f07b58b7e0ba0afcbb54ce327e41e87ddd1de8e8 100644 --- a/resolve/re/sky_model.py +++ b/resolve/re/sky_model.py @@ -1,14 +1,15 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause # Author: Jakob Roth -from numpy import full import nifty8.re as jft import jax from jax import numpy as jnp +import numpy as np from ..sky_model import _spatial_dom -from ..sky_model import sky_model_points as rve_sky_pts +from ..constants import str2rad + def build_cf(prefix, conf, shape, dist): zmo = conf.getfloat(f"{prefix} zero mode offset") @@ -34,7 +35,7 @@ def build_cf(prefix, conf, shape, dist): cfm = jft.CorrelatedFieldMaker(prefix) cfm.set_amplitude_total_offset(**cf_zm) cfm.add_fluctuations( - shape, distances=dist, **cf_fl, prefix="", non_parametric_kind='power' + shape, distances=dist, **cf_fl, prefix="", non_parametric_kind="power" ) amps = cfm.get_normalized_amplitudes() cfm = cfm.finalize() @@ -42,8 +43,6 @@ def build_cf(prefix, conf, shape, dist): return cfm, additional - - def sky_model_diffuse(cfg): if not cfg["freq mode"] == "single": raise NotImplementedError("FIXME: only implemented for single frequency") @@ -52,63 +51,69 @@ def sky_model_diffuse(cfg): sky_dom = _spatial_dom(cfg) bg_shape = sky_dom.shape bg_distances = sky_dom.distances - bg_log_diff, additional = build_cf('stokesI diffuse space i0', cfg, bg_shape, bg_distances) - full_shape = (1,1,1)+bg_shape - + bg_log_diff, additional = build_cf( + "stokesI diffuse space i0", cfg, bg_shape, bg_distances + ) + full_shape = (1, 1, 1) + bg_shape def bg_diffuse(x): return jnp.broadcast_to( - jnp.exp(bg_log_diff(x['stokesI diffuse space i0'])), full_shape + jnp.exp(bg_log_diff(x["stokesI diffuse space i0"])), full_shape ) - bg_diffuse_model = jft.Model( - bg_diffuse, - domain={'stokesI diffuse space i0': bg_log_diff.domain} - ) + bg_diffuse, domain={"stokesI diffuse space i0": bg_log_diff.domain} + ) return bg_diffuse_model, additional -def get_pts_postions(cfg): - rsp, _ = rve_sky_pts(cfg) - return rsp._ops[0]._inds[0], rsp._ops[0]._inds[1] - -def get_inv_gamma(cfg): - rsp, _ = rve_sky_pts(cfg) - return rsp._ops[1].jax_expr def jax_insert(x, ptsx, ptsy, bg): bg = bg.at[:, :, :, ptsx, ptsy].set(x) return bg -def sky_model_points(cfg, bg=None): - if not cfg["freq mode"] == "single": - raise NotImplementedError("FIXME: only implemented for single frequency") - if not cfg["polarization"] == "I": - raise NotImplementedError("FIXME: only implemented for stokes I") - - if not cfg["point sources mode"] == "single": - raise NotImplementedError( - "FIXME: point sources only implemented for mode single" - ) - else: - sky_dom = _spatial_dom(cfg) - shp = sky_dom.shape - full_shp = (1,1,1)+shp - ptsx, ptsy = get_pts_postions(cfg) - inv_gamma = get_inv_gamma(cfg) - pts_shp = (len(ptsx),) - dom = {'points': jax.ShapeDtypeStruct(pts_shp, dtype=jnp.float64)} - if bg is None: - bg = jnp.zeros(full_shp) - pts_func = lambda x: jax_insert(inv_gamma(x['points']), ptsx=ptsx, ptsy=ptsy, bg=bg) - else: - pts_func = lambda x: jax_insert(inv_gamma(x['points']), ptsx=ptsx, ptsy=ptsy, bg=bg(x)) - dom = {**dom, **bg.domain} - pts_model = jft.Model(pts_func, domain=dom) - additional = {} - return pts_model, additional +def sky_model_points(cfg, bg=None): + if cfg["freq mode"] == "single": + if cfg["polarization"] == "I": + if cfg["point sources mode"] == "single": + ppos = [] + sky_dom = _spatial_dom(cfg) + s = cfg["point sources locations"] + for xy in s.split(","): + x, y = xy.split("$") + ppos.append((str2rad(x), str2rad(y))) + ppos = np.array(ppos) + dx = np.array(sky_dom.distances) + center = np.array(sky_dom.shape) // 2 + inds = np.unique(np.round(ppos / dx + center).astype(int).T, axis=1) + indsx, indsy = inds + alpha = cfg.getfloat("point sources alpha") + q = cfg.getfloat("point sources q") + sky_dom = _spatial_dom(cfg) + shp = sky_dom.shape + full_shp = (1, 1, 1) + shp + inv_gamma = jft.InvGammaPrior( + a=alpha, + scale=q, + name="points", + shape=jax.ShapeDtypeStruct((len(indsx),), float), + ) + if bg is None: + bg = jnp.zeros(full_shp) + pts_func = lambda x: jax_insert( + inv_gamma(x), ptsx=indsx, ptsy=indsy, bg=bg + ) + dom = inv_gamma.domain + else: + pts_func = lambda x: jax_insert( + inv_gamma(x), ptsx=indsx, ptsy=indsy, bg=bg(x) + ) + dom = {**inv_gamma.domain, **bg.domain} + pts_model = jft.Model(pts_func, domain=dom) + additional = {} + return pts_model, additional + raise NotImplementedError("FIXME: Not Implemented for selected mode.") def sky_model(cfg): diff --git a/resolve/re/sugar.py b/resolve/re/sugar.py new file mode 100644 index 0000000000000000000000000000000000000000..dfb549524a5886de919ddf649d0ac74a40fd684a --- /dev/null +++ b/resolve/re/sugar.py @@ -0,0 +1,230 @@ +import jax.numpy as jnp +import nifty8.re as jft +import matplotlib.pyplot as plt + +from .calibration import CalibrationDistribution +from .likelihood_models import ModelCalibrationLikelihoodFixedCovariance + +from ..data.observation import Observation + +from typing import Union, Iterable, Dict, Tuple, Optional +from dataclasses import dataclass, field + +class Bulk_CF_AntennaTimeDomain(jft.Model): + """ + Creates multiple independant correlated fields from the same offset and powerspectra + parameters. Number of correlated fields is the product of the number of polarization + directions, antennas and frequencies. + + Parameters + ---------- + dct_offset: dictionary + Dictionary containing information about the offset parameters. + dct_ps: dictionary + Dictionary containing information about the temporal powerspectrum parameters + prefix: string + Prefix to the names of the parameters of a correlated field. + polarizations: Iterable of int or str + Labels for polarization directions + frequencies: Iterable of int or str + Labels for frequencies: + antennas: Iterable of int or str + Labels for antennas + """ + def __init__( + self, + dct_offset: dict, + dct_ps: dict, + prefix: str, + polarizations: Iterable[Union[int, str]], + frequencies: Iterable[Union[int, str]], + antennas: Iterable[Union[int, str]] + ): + self._pol = list(polarizations) + self._ant = list(antennas) + self._freq = list(frequencies) + self._prefix = str(prefix) + + self._output_shape = (len(self._pol),len(self._ant),len(self._freq),dct_ps["shape"][0]) + + cfm = jft.CorrelatedFieldMaker(self._prefix + "_") + cfm.set_amplitude_total_offset(**dct_offset) + cfm.add_fluctuations(**dct_ps) + + n_total = len(self._pol)*len(self._ant)*len(self._freq) + + self._fields = jft.VModel(cfm.finalize(), axis_size=n_total) + self._powerspectrum = cfm.power_spectrum + + super().__init__(init=self._fields.init) + + def __call__(self,x): + return jnp.swapaxes(jnp.reshape(self._fields(x),self._output_shape),2,3) + + @property + def name(self): + return self._prefix + + def get_powerspectrum(self): + return self._powerspectrum + +@dataclass +class CalibrationAssembler: + """ + Collect and automatically creates quantities necessary for reconstruction of the calibration solution + of radio-interferometric data. + + Parameters + ---------- + obs: Observation + Observation object providing multiple quantities for construction of claibration + operator and likelihood. + phase_field: jft.Model or Bulk_CF_AntennaTimeDomain (preferred) + Correlated fields on antenna-time space for phases of calibration solutions. + logflux_field: jft.Model or Bulk_CF_AntennaTimeDomain (preferred) + Correlated fields on antenna-time space for log amplitude of calibration solutions. + dt: float + Distances between time points on time axis. Has to be the same distance of time points, + which is used for phase_fields and log_amplitude fields. + model_vis: jnp.ndarray + Optional. Assumed visibilities of the the source. + scaling_op: jft.Model + Optional. Model for scalar scaling of model_vis. + log_inv_cov: jft.Model + Optional. Model for log inverse covariance. + lh_label: str + Optional. Label for likelihood. If not set the default is the name of + the observation. + """ + obs: Observation + phase_field: jft.Model + logflux_field: jft.Model + dt: float + model_vis: Optional[jnp.ndarray] = None + scaling_op: Optional[jft.WrappedCall] = None + log_inv_cov: Optional[jft.Model] = None + lh_label: Optional[str] = None + cop: CalibrationDistribution = field(init=False) + scaled_cop: Optional[jft.Model] = field(init=False) + + def __post_init__(self): + self.lh_label = self.obs.source_name if self.lh_label is None else self.lh_label + self.cop = CalibrationDistribution(self.obs,self.phase_field,self.logflux_field,self.dt) + self.scaled_cop = self.cop if self.scaling_op is None else jft.Model( + call = lambda x: self.scaling_op(x)*self.cop(x), + domain = {**self.cop.domain,**self.scaling_op.domain}, + init = self.cop.init | self.scaling_op.init + ) + + @classmethod + def make( + cls, + observation: Observation, + phase_field: jft.Model, + logflux_field: jft.Model, + dt: float, + model_vis_flux_amplitude: Optional[complex] = None, + scaling_parameters: Optional[Tuple[str,Dict]] = None, + log_inv_cov: Optional[jft.Model] = None, + lh_label: Optional[str] = None + ): + """ + Constructs CalibrationAssembler with constant field for model_vis and an optional scaling_op. + + Parameters + ---------- + obs: Observation + Observation object providing multiple quantities for construction of claibration + operator and likelihood. + phase_field: jft.Model or Bulk_CF_AntennaTimeDomain (preferred) + Correlated fields on antenna-time space for phases of calibration solutions. + logflux_field: jft.Model or Bulk_CF_AntennaTimeDomain (preferred) + Correlated fields on antenna-time space for log amplitude of calibration solutions. + dt: float + Distances between time points on time axis. Has to be the same distance of time points, + which is used for phase_fields and log_amplitude fields. + model_vis_flux_amplitude: complex + Optional. Assumed magnitude of field for model_vis. + scaling_parameters: Tuple[str,Dict] + Optional. Selects prior model for scaling operator through string and dictionary contains + parameters of corresponding model. Currently only available for the following prior models + (see NIFTy.re package for arguments): + - "inverse_gamma" + - "log_normal" + log_inv_cov: jft.Model + Optional. Model for log inverse covariance. + lh_label: str + Optional. Label for likelihood. If not set the default is the name of + the observation. + """ + init_flux_field = None if model_vis_flux_amplitude is None else model_vis_flux_amplitude*jnp.ones(observation.vis.shape) + if scaling_parameters is not None: + match scaling_parameters[0]: + case "inv_gamma": + scaling_op = jft.InvGammaPrior(**scaling_parameters[1]) + case "log_normal": + scaling_op = jft.LogNormalPrior(**scaling_parameters[1]) + case _: + raise NotImplementedError("Constructor currently only supports 'log_normal' and 'inv_gamma' for scaling operator construction") + else: + scaling_op = None + + return cls(observation,phase_field,logflux_field,dt,init_flux_field,scaling_op,log_inv_cov,lh_label) + + def prior_realization(self,rng_key): + data_model = ModelCalibrationLikelihoodFixedCovariance(self.scaled_cop,self.model_vis,jnp.asarray(self.obs.mask.val)) + data_model_realization = data_model(data_model.init(rng_key)) + + plt.scatter(self.obs.vis.val.imag,self.obs.vis.val.real,label="Data") + plt.scatter(data_model_realization.imag,data_model_realization.real, alpha=0.01, label="Prior sample") + plt.legend() + plt.show() + + def get_lh_domain(self): + dom = self.scaled_cop.domain if self.log_inv_cov is None else self.scaled_cop.domain | self.log_inv_cov.domain + return dom + + def __repr__(self): + return( + f"Obervation: {self.obs.source_name}\n" + f"Phase field: {self.phase_field.name}\n" + f"Logflux field: {self.logflux_field.name}\n" + f"dt: {self.dt}\n" + f"Model visibilities field set: {False if self.model_vis is None else True}\n" + f"Scaling operator set: {False if self.scaling_op is None else True}\n" + f"Log inverse covariance set: {False if self.log_inv_cov is None else True}\n" + f"Likelihood label: {self.lh_label}\n" + ) + +@dataclass +class SkyAssembler: + """ + Data class for storage of quantities for the sky reconstruction with fast resolve + + Parameters + ---------- + obs: Observation + Observation object providing multiple quantities for construction of claibration + operator and likelihood. + sky: jft.Model + Sky model for reconstruction + noise_scaling: jft.Model + Optional. Model for noise scaling. + varcov: jft.Model + Optional. Model for variable covariance + """ + obs: Observation + sky: jft.Model + noise_scaling: Optional[jft.Model] = None + varcov: Optional[jft.Model] = None + + def __post_init__(self): + if (self.noise_scaling is not None) and (self.varcov is not None): + raise ValueError("Either noise_scaling or varcov can be set.") + + def __repr__(self): + return( + f"Obervation: {self.obs.source_name}\n" + f"Noise scaling model set: {False if self.noise_scaling is None else True}\n" + f"Variable covariance model set: {False if self.varcov is None else True}\n" + ) \ No newline at end of file diff --git a/test/test_jax_ports/observation_generator.py b/test/test_jax_ports/observation_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7cd25ad89300359fe38545a2659cc443d758e9 --- /dev/null +++ b/test/test_jax_ports/observation_generator.py @@ -0,0 +1,41 @@ +import numpy as np + +import resolve as rve + +def RandomAntennaPositions(n_row,uvw_max,n_antenna_max,time_max,rng_generator=np.random.default_rng(42)): + ant1,ant2 = [],[] + + for k in range(n_row): + x = rng_generator.integers(0,n_antenna_max-1) + y = rng_generator.integers(1,n_antenna_max) + + if(x==y): + while(x==y): + y = np.random.randint(1,n_antenna_max) + + ant1.append(x) + ant2.append(y) + + ant1 = np.array(ant1).astype(np.int32) + ant2 = np.array(ant2).astype(np.int32) + time = rng_generator.uniform(0,time_max,n_row) + uvw = rng_generator.uniform(-uvw_max,uvw_max,(n_row,3)) + + return rve.data.antenna_positions.AntennaPositions(uvw,ant1,ant2,time) + +def RandomObservation(n_baselines,pol_indices,freq_channels,uvw_max,n_antenna_max,time_max,abs_vis_max,weight_min,weight_max,rng_generator=np.random.default_rng(42)): + antenna_pos = RandomAntennaPositions(n_baselines,uvw_max,n_antenna_max,time_max,rng_generator=rng_generator) + + n_pol = len(pol_indices) + n_freq = freq_channels.size + vis_shape = (n_pol,n_baselines,n_freq) + + pol = rve.data.polarization.Polarization(pol_indices) + + vis_magnitude = rng_generator.uniform(0,abs_vis_max,vis_shape) + vis_phase = rng_generator.uniform(0,2*np.pi,vis_shape) + vis = vis_magnitude*np.exp(1.0j*vis_phase) + + weights = rng_generator.uniform(weight_min,weight_max,vis_shape) + + return rve.data.observation.Observation(antenna_pos,vis,weights,pol,freq_channels) \ No newline at end of file diff --git a/test/test_jax_ports/test_01_BulkCF.py b/test/test_jax_ports/test_01_BulkCF.py new file mode 100644 index 0000000000000000000000000000000000000000..aac1716f152bad99685134611ee797e9e5424f52 --- /dev/null +++ b/test/test_jax_ports/test_01_BulkCF.py @@ -0,0 +1,94 @@ +import os +os.environ["JAX_PLATFORM_NAME"] = "cpu" + +import numpy as np +import nifty8 as ift +import ducc0 +import resolve as rve +import resolve.re as jrve + +from numpy.testing import assert_allclose + +import jax +jax.config.update("jax_enable_x64", True) + +np.random.seed(42) + +from observation_generator import RandomObservation +from to_re_helpers import rve_cf_dict_to_jrve_cf_dict + +def prepare_operators(): + param = { + "n_baselines": 10000, + "pol_indices": (8,5), + "freq_channels": np.array([1.5e9]), + "uvw_max": 200, + "n_antenna_max": 10, + "time_max": 5000, + "abs_vis_max": 0.5, + "weight_min":0, + "weight_max": 10000, + } + + obs = RandomObservation(**param) + + uantennas = rve.unique_antennas(obs) + antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)} + total_N = len(param["pol_indices"]) * len(uantennas) + + dt = 20 # s + zero_padding_factor = 2 + tmin, tmax = rve.tmin_tmax(obs) + time_domain = ift.RGSpace( + ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt + ) + + polarizations = obs.polarization.to_str_list() + frequencies = [1] + antennas = uantennas + + dofdex = np.arange(total_N) + rve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5), "dofdex": dofdex} + rve_dct_ps = { + "fluctuations": (.2, .1), + "loglogavgslope": (-4.0, 1), + "flexibility": (1, .3), + "asperity": None, + "dofdex": dofdex + } + + jrve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5),} + jrve_dct_ps = { + "shape": time_domain.shape, + "distances": time_domain.distances[0], + "fluctuations": (.2, .1), + "loglogavgslope": (-4.0, 1), + "flexibility": (1, .3), + "asperity": None, + "non_parametric_kind": "power" + } + + cfmph = ift.CorrelatedFieldMaker(prefix='bulk_', total_N=total_N) + cfmph.add_fluctuations(time_domain, **rve_dct_ps) + cfmph.set_amplitude_total_offset(**rve_dct_offset) + rve_bulkCF = cfmph.finalize(0) + + pdom, _, fdom = obs.vis.domain + reshape = ift.DomainChangerAndReshaper(rve_bulkCF.target, [pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom]) + + rve_op = reshape @ rve_bulkCF + + jrve_op = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"bulk",polarizations,frequencies,antennas) + + return rve_op, jrve_op + +def test_consistency_BulkCF(): + rve_op, jrve_op = prepare_operators() + + rve_inp = ift.from_random(rve_op.domain) + jrve_inp = rve_cf_dict_to_jrve_cf_dict(rve_inp.val) + + rve_field = rve_op(rve_inp).val + jrve_field = jrve_op(jrve_inp) + + assert_allclose(jrve_field,rve_field) diff --git a/test/test_jax_ports/test_02_CalibrationInterpolator.py b/test/test_jax_ports/test_02_CalibrationInterpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..63fc7bee37ba2466dfc77603f6ed1eaf31f192e8 --- /dev/null +++ b/test/test_jax_ports/test_02_CalibrationInterpolator.py @@ -0,0 +1,64 @@ +import os +os.environ["JAX_PLATFORM_NAME"] = "cpu" + +import numpy as np +import jax.numpy as jnp +import nifty8 as ift +import ducc0 +import resolve as rve +import resolve.re as jrve + +from numpy.testing import assert_allclose + +import jax +jax.config.update("jax_enable_x64", True) + +np.random.seed(42) + +from observation_generator import RandomObservation +from to_re_helpers import rve_cf_dict_to_jrve_cf_dict + +def prepare_operators(): + param = { + "n_baselines": 10000, + "pol_indices": (8,5), + "freq_channels": np.array([1.5e9]), + "uvw_max": 200, + "n_antenna_max": 10, + "time_max": 5000, + "abs_vis_max": 0.5, + "weight_min":0, + "weight_max": 10000, + } + + obs = RandomObservation(**param) + + uantennas = rve.unique_antennas(obs) + antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)} + total_N = len(param["pol_indices"]) * len(uantennas) + + dt = 20 # s + zero_padding_factor = 2 + tmin, tmax = rve.tmin_tmax(obs) + time_domain = ift.RGSpace( + ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt + ) + + pdom, _, fdom = obs.vis.domain + dom = ift.DomainTuple.make([pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom]) + + jrve_op = jrve.CalibrationInterpolator(obs.ant1,obs.time,time_domain.distances[0],obs.vis.val.shape) + + rve_op = rve.CalibrationDistributor(dom,obs.vis.domain,obs.ant1,obs.time,antenna_dct,None) + + return rve_op, jrve_op + +def test_consistency_CalibrationInterpolator(): + rve_op, jrve_op = prepare_operators() + + inp = ift.from_random(rve_op.domain) + + rve_field = rve_op(inp).val + jrve_field = jrve_op(jnp.asarray(inp.val)) + + assert_allclose(jrve_field,rve_field) \ No newline at end of file diff --git a/test/test_jax_ports/test_03_CalibrationDistributor.py b/test/test_jax_ports/test_03_CalibrationDistributor.py new file mode 100644 index 0000000000000000000000000000000000000000..b755e699f0d0bd3638ac49a5bd99a54ee208e7dd --- /dev/null +++ b/test/test_jax_ports/test_03_CalibrationDistributor.py @@ -0,0 +1,104 @@ +import os +os.environ["JAX_PLATFORM_NAME"] = "cpu" + +import numpy as np +import nifty8 as ift +import ducc0 +import resolve as rve +import resolve.re as jrve + +from numpy.testing import assert_allclose + +import jax +jax.config.update("jax_enable_x64", True) + +np.random.seed(42) + +from observation_generator import RandomObservation +from to_re_helpers import rve_cf_dict_to_jrve_cf_dict + +def prepare_operators(): + param = { + "n_baselines": 10000, + "pol_indices": (8,5), + "freq_channels": np.array([1.5e9]), + "uvw_max": 200, + "n_antenna_max": 10, + "time_max": 5000, + "abs_vis_max": 0.5, + "weight_min":0, + "weight_max": 10000, + } + + obs = RandomObservation(**param) + + uantennas = rve.unique_antennas(obs) + antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)} + total_N = len(param["pol_indices"]) * len(uantennas) + + dt = 20 # s + zero_padding_factor = 2 + tmin, tmax = rve.tmin_tmax(obs) + time_domain = ift.RGSpace( + ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt + ) + + polarizations = obs.polarization.to_str_list() + frequencies = [1] + antennas = uantennas + + dofdex = np.arange(total_N) + rve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5), "dofdex": dofdex} + rve_dct_ps = { + "fluctuations": (.2, .1), + "loglogavgslope": (-4.0, 1), + "flexibility": (1, .3), + "asperity": None, + "dofdex": dofdex + } + + jrve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5),} + jrve_dct_ps = { + "shape": time_domain.shape, + "distances": time_domain.distances[0], + "fluctuations": (.2, .1), + "loglogavgslope": (-4.0, 1), + "flexibility": (1, .3), + "asperity": None, + "non_parametric_kind": "power" + } + + cfmph = ift.CorrelatedFieldMaker(prefix='phase_', total_N=total_N) + cfmph.add_fluctuations(time_domain, **rve_dct_ps) + cfmph.set_amplitude_total_offset(**rve_dct_offset) + rve_phase = cfmph.finalize(0) + + cfmph = ift.CorrelatedFieldMaker(prefix='logamp_', total_N=total_N) + cfmph.add_fluctuations(time_domain, **rve_dct_ps) + cfmph.set_amplitude_total_offset(**rve_dct_offset) + rve_logamp = cfmph.finalize(0) + + pdom, _, fdom = obs.vis.domain + reshape = ift.DomainChangerAndReshaper(rve_phase.target, [pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom]) + + rve_phase = reshape @ rve_phase + rve_logamp = reshape @ rve_logamp + + jrve_phase = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"phase",polarizations,frequencies,antennas) + jrve_logamp = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"logamp",polarizations,frequencies,antennas) + + rve_op = rve.calibration_distribution(obs,rve_phase,rve_logamp,antenna_dct,None) + jrve_op = jrve.CalibrationDistribution(obs,jrve_phase,jrve_logamp,time_domain.distances[0]) + + return rve_op, jrve_op + +def test_consistency_CalibrationDistributor(): + rve_op, jrve_op = prepare_operators() + + rve_inp = ift.from_random(rve_op.domain) + jrve_inp = rve_cf_dict_to_jrve_cf_dict(rve_inp.val) + + rve_field = rve_op(rve_inp).val + jrve_field = jrve_op(jrve_inp) + + assert_allclose(jrve_field,rve_field) \ No newline at end of file diff --git a/test/test_jax_ports/test_04_CalibrationLikelihood.py b/test/test_jax_ports/test_04_CalibrationLikelihood.py new file mode 100644 index 0000000000000000000000000000000000000000..0a469d8217b173ec34063c8c3813a22be4ed9e3a --- /dev/null +++ b/test/test_jax_ports/test_04_CalibrationLikelihood.py @@ -0,0 +1,151 @@ +import os +os.environ["JAX_PLATFORM_NAME"] = "cpu" + +import numpy as np +import jax.numpy as jnp +import nifty8 as ift +import nifty8.re as jft +import ducc0 +import resolve as rve +import resolve.re as jrve +from functools import partial + +import pytest +from numpy.testing import assert_allclose + +import jax +jax.config.update("jax_enable_x64", True) + +np.random.seed(42) + +from observation_generator import RandomObservation +from to_re_helpers import rve_cf_dict_to_jrve_cf_dict, mixed_rve_dct_to_mixed_jrve + +pmp = pytest.mark.parametrize + +def prepare_partial_operators(): + param = { + "n_baselines": 10000, + "pol_indices": (8,5), + "freq_channels": np.array([1.5e9]), + "uvw_max": 200, + "n_antenna_max": 10, + "time_max": 5000, + "abs_vis_max": 0.5, + "weight_min":0, + "weight_max": 10000, + } + + obs = RandomObservation(**param) + + uantennas = rve.unique_antennas(obs) + antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)} + total_N = len(param["pol_indices"]) * len(uantennas) + + dt = 20 # s + zero_padding_factor = 2 + tmin, tmax = rve.tmin_tmax(obs) + time_domain = ift.RGSpace( + ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt + ) + + polarizations = obs.polarization.to_str_list() + frequencies = [1] + antennas = uantennas + + dofdex = np.arange(total_N) + rve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5), "dofdex": dofdex} + rve_dct_ps = { + "fluctuations": (.2, .1), + "loglogavgslope": (-4.0, 1), + "flexibility": (1, .3), + "asperity": None, + "dofdex": dofdex + } + + jrve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5),} + jrve_dct_ps = { + "shape": time_domain.shape, + "distances": time_domain.distances[0], + "fluctuations": (.2, .1), + "loglogavgslope": (-4.0, 1), + "flexibility": (1, .3), + "asperity": None, + "non_parametric_kind": "power" + } + + cfmph = ift.CorrelatedFieldMaker(prefix='phase_', total_N=total_N) + cfmph.add_fluctuations(time_domain, **rve_dct_ps) + cfmph.set_amplitude_total_offset(**rve_dct_offset) + rve_phase = cfmph.finalize(0) + + cfmph = ift.CorrelatedFieldMaker(prefix='logamp_', total_N=total_N) + cfmph.add_fluctuations(time_domain, **rve_dct_ps) + cfmph.set_amplitude_total_offset(**rve_dct_offset) + rve_logamp = cfmph.finalize(0) + + pdom, _, fdom = obs.vis.domain + reshape = ift.DomainChangerAndReshaper(rve_phase.target, [pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom]) + + rve_phase = reshape @ rve_phase + rve_logamp = reshape @ rve_logamp + + jrve_phase = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"phase",polarizations,frequencies,antennas) + jrve_logamp = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"logamp",polarizations,frequencies,antennas) + + rve_cop = rve.calibration_distribution(obs,rve_phase,rve_logamp,antenna_dct,None) + jrve_cop = jrve.CalibrationDistribution(obs,jrve_phase,jrve_logamp,time_domain.distances[0]) + + rve_model_vis = ift.full(obs.vis.domain,1 + 0.0j) + jrve_model_vis = jnp.array(rve_model_vis.val) + + rve_op = partial( + rve.CalibrationLikelihood, + observation = obs, + calibration_operator = rve_cop, + model_visibilities = rve_model_vis + ) + + jrve_op = partial( + jrve.CalibrationLikelihood, + observation = obs, + calibration_operator = jrve_cop, + model_visibilities = jrve_model_vis + ) + + return rve_op, jrve_op, obs + +def prepare_operators(mode): + rve_op_partial, jrve_op_partial, obs = prepare_partial_operators() + + if mode == "model": + rve_log_inv_cov = ift.NormalTransform(0,1,"log_inv_cov",obs.weight.val.size) + reshape = ift.DomainChangerAndReshaper(rve_log_inv_cov.target,obs.weight.domain) + rve_log_inv_cov = reshape @ rve_log_inv_cov + + jrve_log_inv_cov = jft.NormalPrior(0,1,name="log_inv_cov",shape=obs.weight.val.shape) + else: + rve_log_inv_cov = None + jrve_log_inv_cov = None + + rve_op = rve_op_partial(log_inverse_covariance_operator=rve_log_inv_cov) + jrve_op = jrve_op_partial(log_inverse_covariance_operator=jrve_log_inv_cov) + + return rve_op, jrve_op, obs + + +@pmp("log_inv_cov_mode",[None, "model"]) +def test_consistency_CalibrationLikelihood(log_inv_cov_mode): + rve_op, jrve_op, obs = prepare_operators(log_inv_cov_mode) + + rve_inp = ift.from_random(rve_op.domain) + + if log_inv_cov_mode == "model": + jrve_inp = mixed_rve_dct_to_mixed_jrve(rve_inp.val,("log_inv_cov",),(obs.weight.val.shape,)) + else: + jrve_inp = rve_cf_dict_to_jrve_cf_dict(rve_inp.val) + + rve_field = rve_op(rve_inp).val + jrve_field = jrve_op(jrve_inp) + + assert_allclose(jrve_field,rve_field) \ No newline at end of file diff --git a/test/test_jax_ports/test_05_ImagingLikelihood.py b/test/test_jax_ports/test_05_ImagingLikelihood.py new file mode 100644 index 0000000000000000000000000000000000000000..57169fdaa43ceced01cd18ea671a31938c0f1398 --- /dev/null +++ b/test/test_jax_ports/test_05_ImagingLikelihood.py @@ -0,0 +1,240 @@ +import os +os.environ["JAX_PLATFORM_NAME"] = "cpu" + +import numpy as np +import jax.numpy as jnp +import nifty8 as ift +import nifty8.re as jft +import ducc0 +import resolve as rve +import resolve.re as jrve +from functools import partial + +import pytest +from numpy.testing import assert_allclose + +import jax +jax.config.update("jax_enable_x64", True) + +np.random.seed(42) + +from observation_generator import RandomObservation +from to_re_helpers import rve_cf_dict_to_jrve_cf_dict_simple, mixed_rve_dct_to_mixed_jrve + +pmp = pytest.mark.parametrize + +class JRVE_Sky(jft.Model): + def __init__(self,field): + self._field = field + + super().__init__(init=self._field.init) + + def __call__(self,x): + return jnp.array([[[jnp.exp(self._field(x))]]]) + +def prepare_partial_operators(): + param = { + "n_baselines": 10000, + "pol_indices": (8,5), + "freq_channels": np.array([1.5e9]), + "uvw_max": 200, + "n_antenna_max": 10, + "time_max": 5000, + "abs_vis_max": 0.5, + "weight_min":0, + "weight_max": 10000, + } + + obs = RandomObservation(**param) + + fov = np.array([1, 1]) * rve.DEG2RAD + npix = np.array([64, 64]) + skydom = ift.RGSpace(npix, fov / npix) + + sky_domain_dict = dict(npix_x=skydom.shape[0], + npix_y=skydom.shape[1], + pixsize_x=skydom.distances[0], + pixsize_y=skydom.distances[1], + pol_labels=['I'], + times=[0.], + freqs=[0.]) + + rve_sky_dct = { + "target": skydom, + "offset_mean": 0, + "offset_std": (1, 0.5), + "prefix": "logdiffuse_", + "fluctuations": (.2, .1), + "loglogavgslope": (-4.0, 1), + "flexibility": (1, .3), + "asperity": None, + } + rve_logsky = ift.SimpleCorrelatedField(**rve_sky_dct) + rve_sky = rve_logsky.exp() + + rve_sky_dom = rve.default_sky_domain(sdom = skydom) + reshape_sky = ift.DomainChangerAndReshaper(rve_sky.target, rve_sky_dom) + rve_sky = reshape_sky @ rve_sky + + jrve_dct_sky_offset = {"offset_mean": 0, "offset_std": (1, 0.5),} + jrve_dct_sky_ps = { + "shape": skydom.shape, + "distances": skydom.distances[0], + "fluctuations": (.2, .1), + "loglogavgslope": (-4.0, 1), + "flexibility": (1, .3), + "asperity": None, + "non_parametric_kind": "power" + } + cfm = jft.CorrelatedFieldMaker("logdiffuse_") + cfm.set_amplitude_total_offset(**jrve_dct_sky_offset) + cfm.add_fluctuations(**jrve_dct_sky_ps) + jrve_logsky = cfm.finalize() + jrve_sky = JRVE_Sky(jrve_logsky) + + rve_op = partial( + rve.ImagingLikelihood, + observation = obs, + sky_operator = rve_sky, + epsilon = 1e-5, + do_wgridding = True + ) + + jrve_op = partial( + jrve.ImagingLikelihood, + observation = obs, + sky_operator = jrve_sky, + sky_domain_dict = sky_domain_dict, + epsilon = 1e-5, + do_wgridding = True + ) + + return rve_op, jrve_op, obs, rve_sky.domain + +def prepare_cop(obs): + uantennas = rve.unique_antennas(obs) + antenna_dct = {int(aa): ii for ii, aa in enumerate(uantennas)} + total_N = len(obs.polarization) * len(uantennas) + + dt = 20 # s + zero_padding_factor = 2 + tmin, tmax = rve.tmin_tmax(obs) + time_domain = ift.RGSpace( + ducc0.fft.good_size(int(zero_padding_factor * (tmax - tmin) / dt)), dt + ) + + polarizations = obs.polarization.to_str_list() + frequencies = [1] + antennas = uantennas + + dofdex = np.arange(total_N) + rve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5), "dofdex": dofdex} + rve_dct_ps = { + "fluctuations": (.2, .1), + "loglogavgslope": (-4.0, 1), + "flexibility": (1, .3), + "asperity": None, + "dofdex": dofdex + } + + jrve_dct_offset = {"offset_mean": 0, "offset_std": (1, 0.5),} + jrve_dct_ps = { + "shape": time_domain.shape, + "distances": time_domain.distances[0], + "fluctuations": (.2, .1), + "loglogavgslope": (-4.0, 1), + "flexibility": (1, .3), + "asperity": None, + "non_parametric_kind": "power" + } + + cfmph = ift.CorrelatedFieldMaker(prefix='phase_', total_N=total_N) + cfmph.add_fluctuations(time_domain, **rve_dct_ps) + cfmph.set_amplitude_total_offset(**rve_dct_offset) + rve_phase = cfmph.finalize(0) + + cfmph = ift.CorrelatedFieldMaker(prefix='logamp_', total_N=total_N) + cfmph.add_fluctuations(time_domain, **rve_dct_ps) + cfmph.set_amplitude_total_offset(**rve_dct_offset) + rve_logamp = cfmph.finalize(0) + + pdom, _, fdom = obs.vis.domain + reshape = ift.DomainChangerAndReshaper(rve_phase.target, [pdom, ift.UnstructuredDomain(len(uantennas)), time_domain, fdom]) + + rve_phase = reshape @ rve_phase + rve_logamp = reshape @ rve_logamp + + jrve_phase = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"phase",polarizations,frequencies,antennas) + jrve_logamp = jrve.Bulk_CF_AntennaTimeDomain(jrve_dct_offset,jrve_dct_ps,"logamp",polarizations,frequencies,antennas) + + rve_cop = rve.calibration_distribution(obs,rve_phase,rve_logamp,antenna_dct,None) + jrve_cop = jrve.CalibrationDistribution(obs,jrve_phase,jrve_logamp,time_domain.distances[0]) + + return rve_cop, jrve_cop + + +def prepare_operators(cov_mode, cal_mode): + rve_op_partial, jrve_op_partial, obs, rve_sky_dom = prepare_partial_operators() + + if cov_mode == "model": + rve_log_inv_cov = ift.NormalTransform(0,1,"log_inv_cov",obs.weight.val.size) + reshape = ift.DomainChangerAndReshaper(rve_log_inv_cov.target,obs.weight.domain) + rve_log_inv_cov = reshape @ rve_log_inv_cov + + jrve_log_inv_cov = jft.NormalPrior(0,1,name="log_inv_cov",shape=obs.weight.val.shape) + else: + rve_log_inv_cov = None + jrve_log_inv_cov = None + + rve_op_partial = partial(rve_op_partial,log_inverse_covariance_operator=rve_log_inv_cov) + jrve_op_partial = partial(jrve_op_partial,log_inverse_covariance_operator=jrve_log_inv_cov) + + if cal_mode is not None: + rve_cop, jrve_cop = prepare_cop(obs) + + if cal_mode == "cop": + rve_op = rve_op_partial(calibration_operator=rve_cop,calibration_field=None) + jrve_op = jrve_op_partial(calibration_operator=jrve_cop,calibration_field=None) + elif cal_mode == "cfld": + rve_cfld = ift.from_random(rve_cop.target,dtype="complex128") + jrve_cfld = jnp.array(rve_cfld.val) + + rve_op = rve_op_partial(calibration_operator=None,calibration_field=rve_cfld) + jrve_op = jrve_op_partial(calibration_operator=None,calibration_field=jrve_cfld) + else: + rve_op = rve_op_partial(calibration_operator=None,calibration_field=None) + jrve_op = jrve_op_partial(calibration_operator=None,calibration_field=None) + + return rve_op, jrve_op, obs, rve_sky_dom + + +@pmp("log_inv_cov_mode",[None, "model"]) +@pmp("cal_op_mode", [None, "cop", "cfld"]) +def test_consistency_CalibrationLikelihood(log_inv_cov_mode,cal_op_mode): + rve_op, jrve_op, obs, rve_sky_dom = prepare_operators(log_inv_cov_mode, cal_op_mode) + + rve_inp = ift.from_random(rve_op.domain) + + if log_inv_cov_mode == "model": + jrve_inp = mixed_rve_dct_to_mixed_jrve(rve_inp.val,("log_inv_cov",),(obs.weight.val.shape,),rve_sky_dom.keys()) + else: + if (cal_op_mode == "cop"): + jrve_inp = mixed_rve_dct_to_mixed_jrve(rve_inp.val,simple_cf_keys=rve_sky_dom.keys()) + else: + jrve_inp = rve_cf_dict_to_jrve_cf_dict_simple(rve_inp.val) + + rve_field = rve_op(rve_inp).val + jrve_field = jrve_op(jrve_inp) + + assert_allclose(jrve_field,rve_field) + + + + + #s_jrve_fc_img_lh_dct_translated = rve_cf_dict_to_jrve_cf_dict_simple(s_rve_fc_img_lh_dct.val) + #s_jrve_fc_img_cfld_lh_dct_translated = rve_cf_dict_to_jrve_cf_dict_simple(s_rve_fc_img_cfld_lh_dct.val) + + #s_jrve_fc_img_cop_lh_dct_translated = mixed_rve_dct_to_mixed_jrve(s_rve_fc_img_cop_lh_dct.val,simple_cf_keys=ift.from_random(rve_sky.domain).val.keys()) + #s_jrve_vc_img_lh_dct_translated = mixed_rve_dct_to_mixed_jrve(s_rve_vc_img_lh_dct.val,("log_inv_cov",),(obs.weight.val.shape,),simple_cf_keys=ift.from_random(rve_sky.domain).val.keys()) + #s_jrve_vc_img_cfld_lh_dct_translated = mixed_rve_dct_to_mixed_jrve(s_rve_vc_img_cfld_lh_dct.val,("log_inv_cov",),(obs.weight.val.shape,),simple_cf_keys=ift.from_random(rve_sky.domain).val.keys()) + #s_jrve_vc_img_cop_lh_dct_translated = mixed_rve_dct_to_mixed_jrve(s_rve_vc_img_cop_lh_dct.val,("log_inv_cov",),(obs.weight.val.shape,),simple_cf_keys=ift.from_random(rve_sky.domain).val.keys()) \ No newline at end of file diff --git a/test/test_jax_ports/to_re_helpers.py b/test/test_jax_ports/to_re_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..6fcae11587f0c36932779f87fd60f890d2250eec --- /dev/null +++ b/test/test_jax_ports/to_re_helpers.py @@ -0,0 +1,63 @@ +import numpy as np +import jax.numpy as jnp + +def rve_cf_dict_to_jrve_cf_dict(rve_cf_dct): + res = rve_cf_dct + for h in res.keys(): + h_split = h.split("_") + + if(h_split[1] == "spectrum"): + res[h] = np.swapaxes(res[h],1,2) + return res + +def rve_cf_dict_to_jrve_cf_dict_simple(rve_cf_dct): + res = rve_cf_dct + + for h in rve_cf_dct.keys(): + h_split = h.split("_") + + if (h_split[1] == "spectrum"): + res[h] = res[h].T + return res + +def single_rve_dct_to_single_jrve_dct(rve_dct,shape): + if len(rve_dct) != 1: + raise ValueError("Can only convert dict with single entry/element") + + key, rve_value = rve_dct.popitem() + + jrve_value = np.reshape(rve_value,shape) + + return {key: jnp.array(jrve_value)} + +def mixed_rve_dct_to_mixed_jrve(rve_dct,non_cf_keys=None,non_cf_shapes=None,simple_cf_keys=None): + if (non_cf_keys is not None) and (non_cf_shapes is not None): + if (len(non_cf_keys) != len(non_cf_shapes)) and (non_cf_keys is not None): + raise ValueError("Each non cf dict key should have an shape for the associate dict value") + + if ((non_cf_keys is not None) and (simple_cf_keys is None)): + non_cf_rve_dcts = [({key: rve_dct[key]},non_cf_shapes[k]) for k,key in enumerate(non_cf_keys)] + cf_rve_dct = {key: rve_dct[key] for key in rve_dct.keys() if (key not in non_cf_keys)} + simple_cf_rve_dcts = {} + elif ((non_cf_keys is None) and (simple_cf_keys is not None)): + simple_cf_rve_dcts = {key: rve_dct[key] for key in simple_cf_keys} + cf_rve_dct = {key: rve_dct[key] for key in rve_dct.keys() if (key not in simple_cf_keys)} + non_cf_rve_dcts = [] + else: + non_cf_rve_dcts = [({key: rve_dct[key]},non_cf_shapes[k]) for k,key in enumerate(non_cf_keys)] + simple_cf_rve_dcts = {key: rve_dct[key] for key in simple_cf_keys} + cf_rve_dct = {key: rve_dct[key] for key in rve_dct.keys() if ((key not in non_cf_keys) and (key not in simple_cf_keys))} + + jrve_dct = {} + + if len(cf_rve_dct) > 0: + jrve_dct = rve_cf_dict_to_jrve_cf_dict(cf_rve_dct) + + if len(non_cf_rve_dcts) > 0: + for x in non_cf_rve_dcts: + jrve_dct.update(single_rve_dct_to_single_jrve_dct(x[0],x[1])) + + if len(simple_cf_rve_dcts) > 0: + jrve_dct.update(rve_cf_dict_to_jrve_cf_dict_simple(simple_cf_rve_dcts)) + + return jrve_dct \ No newline at end of file