Skip to content
Snippets Groups Projects
Commit beda3815 authored by Jakob Roth's avatar Jakob Roth
Browse files

update radio demo jax

parent a12bfff3
No related tags found
No related merge requests found
Pipeline #227569 failed
%% Cell type:markdown id: tags:
Nifty tutorial for radio interferometric imaging
================================================
%% Cell type:markdown id: tags:
Setup
-----
%% Cell type:code id: tags:
``` python
import numpy as np
import jax
from jax import numpy as jnp
from jax import random
import matplotlib.pyplot as plt
# import nifty8 as ift
import nifty8.re as jft
import resolve as rve
import resolve.re as jre
```
%% Cell type:code id: tags:
``` python
jax.config.update("jax_enable_x64", True)
seed = 42
key = random.PRNGKey(seed)
```
%% Cell type:markdown id: tags:
Load mock data with EHT measurement configuration and visualize uv-plane
%% Cell type:code id: tags:
``` python
mock_observation = rve.Observation.load('data/mock_data_eht_config_SR1_M87_2017_096_hi.npz')
plt.plot(mock_observation.uvw[:,0], -mock_observation.uvw[:,1], 'b.')
plt.xlabel('u')
plt.ylabel('v')
plt.xlim([-12e6,12e6])
plt.ylim([-12e6,12e6])
plt.show()
```
%% Cell type:markdown id: tags:
Set up the FOV and number of pixels of the space which defines the image plane.
%% Cell type:code id: tags:
``` python
x_fov, y_fov = 300, 300 # Field of view in x and y directon in myas
x_npix, y_npix = 256, 256 # Number of pixels
dx = rve.str2rad(f'{x_fov}muas') / x_npix
dy = rve.str2rad(f'{y_fov}muas') / y_npix
```
%% Cell type:markdown id: tags:
Model of the sky brightness distribution
----------------------------------------
The sky model is going to be a log-normal random process (the log-brightness is Gaussian distributed). As the prior correlation structure of the log-brightness is unknown, it will be generated using a `CorrelatedField` model and the power-spectrum is inferred along with the realization.
%% Cell type:code id: tags:
``` python
# Hyperparameters (mean and std pairs) for prior models of parameters
args_zm = {# Overall offset from zero of Gaussian process
'offset_mean': -np.log(dx*dy) - 10.,
# Variability of inferred offset amplitude
'offset_std': (3., 1.)}
args_fl = {# Amplitude of field fluctuations
'fluctuations': (1.5, 0.5),
# Exponent of power law power spectrum component
'loglogavgslope': (-4., .5),
# Amplitude of integrated Wiener process power spectrum component
'flexibility': (.3, .1),
# How ragged the integrated Wiener process component is
'asperity': None # Passing 'None' disables this part of the model
}
cfm = jft.CorrelatedFieldMaker("cf")
cfm.set_amplitude_total_offset(**args_zm)
cfm.add_fluctuations(
(x_npix, y_npix),
distances=(dx, dy),
**args_fl,
prefix="ax1",
non_parametric_kind="power"
)
log_sky = cfm.finalize()
class Sky_model(jft.Model):
def __init__(self, log_sky):
self.log_sky = log_sky
super().__init__(init=self.log_sky.init)
def __call__(self, x):
return jnp.exp(self.log_sky(x))
sky = Sky_model(log_sky)
sky = lambda x: jnp.exp(log_sky(x))
```
%% Cell type:markdown id: tags:
Prior samples
-------------
To get a feeling for the prior variability of the sky model we generate several random realizations of the process and visualize them. The power-spectra, generated for each process realization, are depicted in the last panel.
%% Cell type:code id: tags:
``` python
fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(15,13))
axs = list(axs.ravel())
pspecs = []
for _ in range(8):
key, subkey = random.split(key)
pos_random = jft.Vector(jft.random_like(subkey, sky.domain))
pos_random = jft.Vector(jft.random_like(subkey, log_sky.domain))
ax = axs.pop(0)
ax.imshow(sky(pos_random).T, origin='lower', cmap='afmhot')
pspecs.append(cfm.amplitude(pos_random)[1:])
ax = axs.pop(0)
for p in pspecs:
ax.loglog(p)
plt.show()
```
%% Cell type:markdown id: tags:
Setup a mock VLBI imaging task using the `InterferometryResponse` of `resolve`. The `mock_observation` contains all information relevant to set up the likelihood, including visibility data, uv-coordinates, and the noise levels of each measurement. The `measurement_sky` contains all relevant information regarding the prior model of the sky brightness distribution. The additional parameters passed to the `InterferometryResponse` control the accuracy and behaviour of the `wgridder` used within `resolve` which defines the response function.
Setup a mock VLBI imaging task using the `InterferometryResponse` of `resolve`. The `mock_observation` contains all information relevant to set up the likelihood, including visibility data, uv-coordinates, and the noise levels of each measurement. The `sky` contains all relevant information regarding the prior model of the sky brightness distribution. The additional parameters passed to the `InterferometryResponse` control the accuracy and behaviour of the `wgridder` used within `resolve` which defines the response function.
%% Cell type:code id: tags:
``` python
# Set up response
R = jre.InterferometryResponseDucc(mock_observation, x_npix, y_npix, dx, dy,
do_wgridding=False, epsilon=1e-6)
signal_response = lambda x: R(sky(x))
likelihood = jft.Gaussian(mock_observation.vis.val, mock_observation.weight.val).amend(signal_response)
```
%% Cell type:markdown id: tags:
Solve the inference problem
---------------------------
The `likelihood` together with the `sky_model` fully specify a Bayesian inverse problem and imply a posterior probability distribution over the degrees of freedom (DOF) of the model. This distribution is, in general, a high-dimensional (number of pixels + DOF of power spectrum) and non-Gaussian distribution, which prohibits analytical integration. To access its information and compute posterior expectation values, numerical approximations have to be made.
The `likelihood` together with the `sky` fully specify a Bayesian inverse problem and imply a posterior probability distribution over the degrees of freedom (DOF) of the model. This distribution is, in general, a high-dimensional (number of pixels + DOF of power spectrum) and non-Gaussian distribution, which prohibits analytical integration. To access its information and compute posterior expectation values, numerical approximations have to be made.
`nifty` provides multiple ways of posterior approximation, with Variational Inference (VI) being by far the most frequently used method. In VI the posterior distribution is approximated with another distribution by minimizing their respective forward Kullbach-Leibler divergence (KL). In the following, the Geometric VI method is employed which utilizes concepts of differential geometry to provide a local estimate of the distribution function.
Its numerical implementation (`jft.optimize_kl`) consists of a repeated and successive re-approximation of the VI objective function (the KL) via a stochastic estimate. This estimate is generated using the at the time best available approximation of the posterior, and then the KL gets minimized to further improve it. The resulting algorithm consists of a repeated re-generation of new samples for the estimate and a successing optimization thereof until convergence is reached.
The internal steps of `jft.optimize_kl` invoke the approximate solution of multiple interdependent optimization problems:
- For sample generation, a linear system of equations is approximated using the `ConjugateGradient` (CG) method
- Furthermore, the sample generation invokes a non-linear optimization problem approximated using the `NewtonCG` method
- Finally, the approximative distribution is optimized by minimizing the KL between the true posterior and the approximation. This again invokes a non-linear optimization problem approximated with `NewtonCG`.
%% Cell type:markdown id: tags:
Posterior visualization
-----------------------
Before we run the minimization routine, we set up a `plotting_callback` function for visualization. Note that additional information and plots regarding the reconstruction are generated during an `ift.optimize_kl` run and stored in the folder passed to the `output_directory` argument of `ift.optimize_kl`
The final output of `ift.optimize_kl` is a collection of approximate posterior samples and is provided via an instance of `ift.ResidualSampleList`. A `SampleList` provides a variety of convenience functions such as:
- `average`: to compute sample averages
- `sample_stat`: to get the approximate mean and variance of a model
- `iterator`: a python iterator over all samples
- ...
Before we run the minimization routine, we set up a `plotting_callback` function for visualization. Note that additional information and plots regarding the reconstruction are generated during an `jft.optimize_kl` run and stored in the folder passed to the `output_directory` argument of `jft.optimize_kl`
The final output of `jft.optimize_kl` is a collection of approximate posterior samples and is provided via an instance of `jft.Samples`. You can iterate over these samples samples. With `jft.mean_and_std` you can get the mean and the standard deviation.
%% Cell type:code id: tags:
``` python
from IPython.display import clear_output
def _imshow(figure, field, ax, title, vmin = 0, vmax = None, cmap='afmhot'):
im0 = ax.imshow(field.T, origin = 'lower', extent = [0, x_fov, 0, y_fov], cmap=cmap,
vmin = vmin, vmax = vmax)
figure.colorbar(im0, ax=ax)
ax.set_xlabel(r'x $\left(\mu as\right)$')
ax.set_ylabel(r'y $\left(\mu as\right)$')
ax.set_title(title)
def _plot_histogram(nodes, hist, ax, title, ):
nodes = 0.5*(nodes[1:] + nodes[:-1])
ax.bar(nodes, hist)
rs = np.arange(nodes[0], nodes[-1], 0.1)
gauss = np.exp(-0.5*rs**2)/np.sqrt(2*np.pi)
ax.plot(rs, gauss, 'k--', label = r'standard Gauss')
ax.set_xlabel(r'$r$')
ax.set_ylabel(r'$P(r)$')
ax.set_title(title)
ax.legend()
ax.set_xlim([nodes[0], nodes[-1]])
def plotting_callback(samples, opt_state):
clear_output(wait=True)
sky_mean, sky_std = jft.mean_and_std(tuple(sky(s) for s in samples))
pspec_mean = jft.mean(tuple(cfm.amplitude(s)[1:] for s in samples))
pspecs = tuple(cfm.amplitude(s)[1:] for s in samples)
residual = lambda x: (signal_response(x) - mock_observation.vis.val)*mock_observation.weight.sqrt().val
nbins = 50
hist = np.zeros(nbins)
for s in samples:
rr = residual(s)
rr = rr.flatten()
rr = np.concatenate((rr.real, rr.imag))
wgt, nodes = np.histogram(rr, nbins, range=[-5, 5])
hist += wgt/wgt.sum()/(nodes[1]-nodes[0])
hist /= len(samples)
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15,13))
axs = axs.flatten()
_imshow(fig, sky_mean, axs[0], 'Sky brightness mean', vmax = 2e19)
_imshow(fig, sky_std, axs[1], 'Sky brightness std', cmap = 'viridis')
axs[1].yaxis.set_visible(False)
# k_lengths = pspec_mean.domain[0].k_lengths[1:]
lbl = 'samples'
for ps in pspecs:
axs[2].loglog(ps, 'k-', alpha = 0.5, label = lbl)
lbl = None
axs[2].loglog(pspec_mean, 'r-', label = 'mean')
# axs[2].set_xlim([k_lengths[0], k_lengths[-1]])
# axs[2].set_xscale('log')
# axs[2].set_yscale('log')
axs[2].set_xlabel(r'$|k|$')
axs[2].set_ylabel(r'$P_s\left(|k|\right)$')
axs[2].set_title(r'Power-spectrum of log-sky brightness')
axs[2].legend()
_plot_histogram(nodes, hist, axs[3],
r'Inverse noise weighted data residual ($r$) distribution ($P(r)$)')
fig.tight_layout()
plt.show();
```
%% Cell type:markdown id: tags:
WARNING: The entire reconstruction takes a few minutes to run.
%% Cell type:code id: tags:
``` python
# Minimization parameters and minimizers for the VI algorithm
def sample_mode_update(i):
return "nonlinear_sample"
def draw_linear_kwargs(i):
return dict(cg_name="SL", cg_kwargs=dict(absdelta=absdelta / 10.0, maxiter=50))
def kl_kwargs(i):
return dict(
minimize_kwargs=dict(
name="M", absdelta=absdelta, xtol=1e-8, cg_kwargs=dict(name="MCG"), maxiter=10
)
)
n_vi_iterations = 15
delta = 1e-8
absdelta = delta * x_npix*y_npix
n_samples = (lambda iiter: 2 if iiter < 10 else 5) # Number of samples used for KL approximation
key, subkey = random.split(key)
pos_init = jft.Vector(jft.random_like(subkey, sky.domain))
pos_init = jft.Vector(jft.random_like(subkey, log_sky.domain))
samples, state = jft.optimize_kl(
likelihood,
pos_init,
n_total_iterations=n_vi_iterations,
n_samples=n_samples,
key=key,
draw_linear_kwargs=draw_linear_kwargs,
nonlinearly_update_kwargs=dict(
minimize_kwargs=dict(
name="SN",
xtol=delta,
cg_kwargs=dict(name=None),
maxiter=10,
)
),
kl_kwargs=kl_kwargs,
sample_mode=sample_mode_update,
odir=f"results",
resume=False,
callback=plotting_callback,
)
```
%% Cell type:markdown id: tags:
Comparison of posterior to ground truth
=======================================
%% Cell type:code id: tags:
``` python
mock_sky = np.load('data/mock_signal.npz')['signal']
sky_mean, sky_std = jft.mean_and_std(tuple(sky(s) for s in samples))
sky_samples = tuple(sky(s) for s in samples)
fig, axs = plt.subplots(nrows=3, ncols=2, figsize = (15,18))
_imshow(fig, mock_sky, axs[0,0], 'Sky brightness ground truth', vmax = 2e19)
_imshow(fig, sky_mean, axs[0,1], 'Sky brightness mean', vmax = 2e19)
_imshow(fig, sky_std/sky_mean, axs[1,0], 'Sky brightness relative uncertainty')
_imshow(fig, sky_samples[0], axs[1,1], 'Sky brightness posterior sample (1)', vmax = 2e19)
_imshow(fig, sky_samples[1], axs[2,0], 'Sky brightness posterior sample (2)', vmax = 2e19)
_imshow(fig, sky_samples[2], axs[2,1], 'Sky brightness posterior sample (3)', vmax = 2e19)
fig.tight_layout()
```
%% Cell type:markdown id: tags:
UV - Uncertainty map
--------------------
We can generate and study the posterior uncertainty maps for any kind of quantity we are interested in (i.e., not only for the sky brightness). In particular, we can also take a look at the uncertainty of the sky brightness in the UV plane. Comparing this to the measured UV-tracks is insghtfull.
%% Cell type:code id: tags:
``` python
lim_uv = 1.2E7
u = np.linspace(-lim_uv, lim_uv, num=257)
v = np.linspace(-lim_uv, lim_uv, num=257)
uu, vv = np.meshgrid(u,v)
ww = np.zeros_like(uu)
uvw = np.stack(tuple(a.flatten() for a in (uu,vv,ww)), axis = -1)
grid_antennas = rve.AntennaPositions(uvw)
vis = np.zeros(uu.size, dtype=np.complex128).reshape((1,-1,1))
wgts = np.ones_like(vis)
grid_obs = rve.Observation(grid_antennas, vis, wgts, rve.Polarization.trivial(),
mock_observation.freq)
grid_R = jre.InterferometryResponseDucc(grid_obs, x_npix, y_npix, dx, dy,
do_wgridding=False, epsilon=1e-10)
def intensity_uv(inp):
r = grid_R(inp)
r = jnp.sqrt((r.real**2 + r.imag**2))
return r.reshape(uu.shape)
uv_mean, uv_std = jft.mean_and_std(tuple(intensity_uv(sky(s)) for s in samples))
uv_var = uv_std**2
uv_gt = intensity_uv(mock_sky)
def uv_plot(inp, pre = ""):
f, ax = plt.subplots(ncols = 2, figsize = (15,10))
ax[0].imshow(inp, origin='lower', extent=(-lim_uv, lim_uv, -lim_uv, lim_uv))
ax[1].imshow(inp, origin='lower', extent=(-lim_uv, lim_uv, -lim_uv, lim_uv))
ax[1].scatter(mock_observation.uvw[:,0], mock_observation.uvw[:,1],
c='r', marker='.', s = 12, alpha = 0.5, label = 'measurements')
ax[1].scatter(-mock_observation.uvw[:,0], -mock_observation.uvw[:,1],
c='c', marker='.', s = 10, alpha = 0.5, label = 'measurements (c.c.)')
ax[1].legend()
for a in ax:
a.set_xlabel('u')
a.set_ylabel('v')
ax[0].set_title(pre+'UV posterior uncertainty')
ax[1].set_title(pre+'UV posterior uncertainty & measured UV-tracks')
fig.tight_layout()
plt.show()
uv_plot(np.log(uv_var))
f, ax = plt.subplots(ncols = 2, figsize = (15,10))
ax[0].imshow(np.log(uv_gt), origin='lower',
extent=(-lim_uv, lim_uv, -lim_uv, lim_uv))
ax[1].imshow(np.log(uv_mean), origin='lower',
extent=(-lim_uv, lim_uv, -lim_uv, lim_uv))
for a in ax:
a.set_xlabel('u')
a.set_ylabel('v')
ax[0].set_title('true UV intensity')
ax[1].set_title('rec. UV intensity')
fig.tight_layout()
plt.show()
uv_plot(np.log(uv_var / uv_gt), pre = 'relative ')
```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment