Commit b4590ea7 authored by Philipp Frank

update optimize_kl

parent 8482f198
1 merge request: !3 Update to latest nifty and add intro files
%% Cell type:markdown id: tags:
Nifty tutorial for radio interferometric imaging
================================================
%% Cell type:markdown id: tags:
Setup
-----
%% Cell type:code id: tags:
``` python
import numpy as np
import matplotlib.pyplot as plt
import nifty8 as ift
import resolve as rve
```
%% Cell type:markdown id: tags:
Load mock data with EHT measurement configuration and visualize uv-plane
%% Cell type:code id: tags:
``` python
mock_observation = rve.Observation.load('data/mock_data_eht_config_SR1_M87_2017_096_hi.npz')
plt.plot(mock_observation.uvw[:,0], -mock_observation.uvw[:,1], 'b.')
plt.xlabel('u')
plt.ylabel('v')
plt.xlim([-10e6,10e6])
plt.ylim([-10e6,10e6])
plt.show()
```
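%% Cell type:markdown id: tags:
The extent of the uv-coverage sets the finest angular scale the array can resolve. The following is a rough sketch of the diffraction-limit estimate $\theta \approx \lambda / B$, not part of `resolve`. It assumes that the `uvw` coordinates are stored in metres, that the longest baseline is about `1e7` m (the plot limits above), and that the observing wavelength is about 1.3 mm; the helper `nominal_resolution_muas` is a hypothetical name.
%% Cell type:code id: tags:
```python
import numpy as np

RAD_TO_MUAS = 180.0 / np.pi * 3600.0 * 1e6  # radians -> microarcseconds


def nominal_resolution_muas(max_baseline_m, wavelength_m):
    """Diffraction-limited angular scale lambda / B in microarcseconds."""
    return wavelength_m / max_baseline_m * RAD_TO_MUAS


# Assumed values: 1e7 m baseline (plot limits above), 1.3 mm wavelength
theta = nominal_resolution_muas(1e7, 1.3e-3)
print(f"nominal resolution: {theta:.1f} muas")
```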
%% Cell type:markdown id: tags:
Set up the FOV and number of pixels of the `nifty` space which defines the image plane.
%% Cell type:code id: tags:
``` python
x_fov, y_fov = 300, 300 # Field of view in x and y direction in µas
x_npix, y_npix = 256, 256 # Number of pixels
dx = rve.str2rad(f'{x_fov}muas') / x_npix
dy = rve.str2rad(f'{y_fov}muas') / y_npix
space = ift.RGSpace((x_npix, y_npix), (dx, dy))
```
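%% Cell type:markdown id: tags:
To get a feeling for the numbers involved, the pixel size can also be computed by hand. The sketch below uses plain NumPy and a hypothetical helper `muas_to_rad` instead of `rve.str2rad`; it also prints the highest spatial frequency that a grid with this pixel size can represent (its Nyquist frequency, in units of inverse radians).
%% Cell type:code id: tags:
```python
import numpy as np


def muas_to_rad(angle_muas):
    """Convert an angle from microarcseconds to radians."""
    return angle_muas * 1e-6 / 3600.0 * np.pi / 180.0


fov_muas, npix = 300, 256                  # same values as in the cell above
pixel_size = muas_to_rad(fov_muas) / npix  # radians per pixel
nyquist = 1.0 / (2.0 * pixel_size)         # highest frequency on the grid, in 1/rad
print(f"pixel size: {pixel_size:.3e} rad, Nyquist frequency: {nyquist:.3e} 1/rad")
```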
%% Cell type:markdown id: tags:
Model of the sky brightness distribution
----------------------------------------
The sky model is going to be a log-normal random process (the log-brightness is Gaussian distributed). As the prior correlation structure of the log-brightness is unknown, it will be generated using a `CorrelatedField` model and the power-spectrum is inferred along with the realization.
%% Cell type:code id: tags:
``` python
# Hyperparameters (mean and std pairs) for prior models of parameters
args = {
    # Overall offset from zero of Gaussian process
    'offset_mean': -np.log(space.scalar_dvol) - 10.,
    # Variability of inferred offset amplitude
    'offset_std': (3., 1.),
    # Amplitude of field fluctuations
    'fluctuations': (1.5, 0.5),
    # Exponent of power law power spectrum component
    'loglogavgslope': (-4., 0.5),
    # Amplitude of integrated Wiener process power spectrum component
    'flexibility': (0.3, .1),
    # How ragged the integrated Wiener process component is
    'asperity': None  # Passing 'None' disables this part of the model
}
log_signal = ift.SimpleCorrelatedField(space, **args)
sky_model = ift.exp(log_signal)  # Exponentiate to obtain a log-normally distributed sky model
pspec = log_signal.power_spectrum # Save the model of the power-spectrum for visualization
```
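%% Cell type:markdown id: tags:
The idea behind a log-normal process with a power-law spectrum can be illustrated without `nifty`. The following is a much simplified NumPy sketch, not the `SimpleCorrelatedField` implementation: it fixes the slope and fluctuation amplitude instead of inferring them, and it ignores the offset, flexibility, and asperity parameters. The function name `lognormal_powerlaw_field` is hypothetical.
%% Cell type:code id: tags:
```python
import numpy as np


def lognormal_powerlaw_field(npix, slope=-4.0, fluctuations=1.5, seed=0):
    """Draw exp(g), where g is a Gaussian field with power spectrum ~ |k|^slope."""
    rng = np.random.default_rng(seed)
    kx = np.fft.fftfreq(npix)
    k = np.sqrt(kx[:, None]**2 + kx[None, :]**2)
    amplitude = np.zeros_like(k)
    amplitude[k > 0] = k[k > 0] ** (slope / 2.0)  # square root of the power spectrum
    xi = rng.standard_normal((npix, npix))        # white-noise excitations
    # Colour the white noise in harmonic space; the zero mode is removed
    g = np.fft.ifft2(amplitude * np.fft.fft2(xi)).real
    g *= fluctuations / g.std()                   # fix the fluctuation amplitude
    return np.exp(g)


sky = lognormal_powerlaw_field(64)
print(sky.min(), sky.max())
```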
%% Cell type:markdown id: tags:
Prior samples
-------------
To get a feeling for the prior variability of the sky model we generate several random realizations of the process and visualize them. The power-spectra, generated for each process realization, are depicted in the last panel.
%% Cell type:code id: tags:
``` python
pl = ift.Plot()
pspecs = []
for _ in range(8):
    mock = ift.from_random(sky_model.domain)  # Random latent parameters
    pl.add(sky_model(mock), cmap='afmhot')
    pspecs.append(pspec.force(mock))
pl.add(pspecs)  # All power spectra in one panel
pl.output()
```
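%% Cell type:markdown id: tags:
The spread of the power spectra in the last panel partly reflects the prior on the spectral slope. As a rough illustration (an assumption-laden sketch that only varies the slope and ignores all other parameters of the correlated field model), one can draw slopes from a Gaussian with the mean and standard deviation of `loglogavgslope` and look at the resulting pure power laws:
%% Cell type:code id: tags:
```python
import numpy as np

rng = np.random.default_rng(42)
k = np.logspace(0, 2, 50)                  # arbitrary illustrative |k| range
slopes = rng.normal(-4.0, 0.5, size=8)     # mean and std of 'loglogavgslope'
spectra = np.array([k**s for s in slopes]) # one pure power law per draw
print(np.round(slopes, 2))
```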
%% Cell type:markdown id: tags:
Alternatively, one can define a log-normal process with a fixed, user-specified spectrum. Note that in case the spectrum does not match the true underlying data-generating process, this may yield suboptimal results!
%% Cell type:code id: tags:
``` python
use_fixed_spectrum = False
def fixed_spectrum_model(space):
    hp = space.get_default_codomain()  # Harmonic partner of the image space
    a, k0 = 1e-36, 1E9  # Amplitude and knee of the power spectrum
    ker = ift.PS_field(ift.PowerSpace(hp), (lambda k: a / (1.+ (k/k0)**6)))
    amplitude = ift.PowerDistributor(hp, power_space=ker.domain[0])(ker.sqrt())
    HT = ift.HarmonicTransformOperator(hp, target=space)
    A = ift.DiagonalOperator(amplitude)
    signal = ift.FieldAdapter(hp, 'xi')
    signal = ift.exp(HT @ A @ signal)
    return (1e-4/signal.target.scalar_weight()) * signal

if use_fixed_spectrum:
    sky_model = fixed_spectrum_model(space)
    pspec = None
```
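%% Cell type:markdown id: tags:
The fixed spectrum used above is flat for $|k| \ll k_0$ and falls off as $|k|^{-6}$ for $|k| \gg k_0$. The short NumPy check below evaluates the same formula outside of `nifty`; the function name `fixed_power_spectrum` is hypothetical.
%% Cell type:code id: tags:
```python
import numpy as np


def fixed_power_spectrum(k, a=1e-36, k0=1e9):
    """Same functional form as the lambda in fixed_spectrum_model."""
    return a / (1.0 + (k / k0) ** 6)


k = np.array([0.0, 1e9, 1e11, 2e11])
p = fixed_power_spectrum(k)
amplitude = np.sqrt(p)  # the amplitude operator uses the square root
# Logarithmic slope between the two largest k values (approaches -6)
high_k_slope = np.log(p[3] / p[2]) / np.log(k[3] / k[2])
print(p, high_k_slope)
```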
%% Cell type:markdown id: tags:
Set up a mock VLBI imaging task using the `InterferometryResponse` of `resolve`. The `mock_observation` contains all information relevant to set up the likelihood, including visibility data, uv-coordinates, and the noise levels of each measurement. The `measurement_sky` contains all relevant information regarding the prior model of the sky brightness distribution. The additional parameters passed to the `InterferometryResponse` control the accuracy and behaviour of the `wgridder` used within `resolve` which defines the response function.
%% Cell type:code id: tags:
``` python
measurement_domain = (rve.PolarizationSpace('I'),) + (rve.IRGSpace(np.zeros(1)),)*2 + (space,)
measurement_sky = ift.DomainChangerAndReshaper(space, measurement_domain) @ sky_model
# Set up response
R = rve.InterferometryResponse(mock_observation, measurement_sky.target,
                               do_wgridding=False, epsilon=1e-6)
# Extract data and noise level from observation
data = mock_observation.vis
inverse_covariance = ift.DiagonalOperator(mock_observation.weight, sampling_dtype=data.dtype)
# Build Gaussian likelihood and apply it to the sky model
likelihood = ift.GaussianEnergy(data=data, inverse_covariance=inverse_covariance)
likelihood = likelihood @ R @ measurement_sky
```
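%% Cell type:markdown id: tags:
Conceptually, the response maps a sky image to complex visibilities and the Gaussian likelihood penalizes the noise-weighted squared distance between data and prediction. The toy example below sketches this with a direct Fourier transform in NumPy. It is not how `resolve` computes the response (which relies on the `wgridder`), it omits pixel-area factors, and the sign convention of the phase is an assumption; `dft_response` and `gaussian_energy` are hypothetical names.
%% Cell type:code id: tags:
```python
import numpy as np


def dft_response(image, u, v, x, y):
    """Direct Fourier transform of `image` evaluated at the points (u, v)."""
    phase = -2j * np.pi * (u[:, None] * x.ravel()[None, :]
                           + v[:, None] * y.ravel()[None, :])
    return np.exp(phase) @ image.ravel()


def gaussian_energy(data, model, weight):
    """0.5 * sum(w * |d - m|^2): negative log-likelihood up to a constant."""
    r = data - model
    return 0.5 * np.sum(weight * (r.real**2 + r.imag**2))


rng = np.random.default_rng(1)
npix, n_vis, sigma = 8, 20, 0.1
x, y = np.meshgrid(np.arange(npix), np.arange(npix), indexing="ij")
sky = np.exp(rng.standard_normal((npix, npix)))  # toy log-normal sky
u, v = rng.uniform(-0.5, 0.5, (2, n_vis))        # toy uv-coordinates
vis = dft_response(sky, u, v, x, y)
noise = sigma * (rng.standard_normal(n_vis) + 1j * rng.standard_normal(n_vis))
data = vis + noise
weight = np.full(n_vis, 1.0 / sigma**2)          # inverse noise variance

energy_true = gaussian_energy(data, vis, weight)
energy_zero = gaussian_energy(data, np.zeros(n_vis), weight)
print(energy_true, energy_zero)
```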
%% Cell type:markdown id: tags:
Solve the inference problem
---------------------------
The `likelihood` together with the `sky_model` fully specify a Bayesian inverse problem and imply a posterior probability distribution over the degrees of freedom (DOF) of the model. This distribution is in general a high-dimensional (number of pixels + DOF of power spectrum) and non-Gaussian distribution which prohibits analytical integration. To access its information and compute posterior expectation values, numerical approximations have to be made.
`nifty` provides multiple ways of posterior approximation, with Variational Inference (VI) being by far the most frequently used method. In VI the posterior distribution is approximated with another distribution by minimizing their respective forward Kullback-Leibler divergence (KL). In the following, the Geometric VI method is employed which utilizes concepts of differential geometry to provide a local estimate of the distribution function.
Its numerical implementation (`ift.optimize_kl`) repeatedly re-approximates the VI objective function (the KL) via a stochastic estimate. This estimate is generated using the best approximation of the posterior available at the time, and the KL is then minimized to improve that approximation further. The resulting algorithm alternates between drawing new samples for the estimate and optimizing it, until convergence is reached.
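%% Cell type:markdown id: tags:
The notion of a stochastic KL estimate can be made concrete with a one-dimensional toy example. The sketch below is not the `ift.optimize_kl` implementation; it assumes two Gaussians, an approximation `q` and a target `p`, and estimates the KL as a sample average of $\log q(x) - \log p(x)$ over samples drawn from `q`. For Gaussians the exact value is known, so the estimate can be checked.
%% Cell type:code id: tags:
```python
import numpy as np


def log_gauss(x, mean, std):
    """Log-density of a one-dimensional Gaussian."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)


def kl_estimate(samples, q, p):
    """Sample average of log q(x) - log p(x) for x drawn from q."""
    return np.mean(log_gauss(samples, *q) - log_gauss(samples, *p))


rng = np.random.default_rng(0)
q, p = (0.0, 1.0), (1.0, 2.0)  # (mean, std) of approximation and target
samples = rng.normal(q[0], q[1], size=200_000)
kl_mc = kl_estimate(samples, q, p)
# Closed-form KL between two one-dimensional Gaussians, for comparison
kl_exact = np.log(p[1] / q[1]) + (q[1]**2 + (q[0] - p[0])**2) / (2 * p[1]**2) - 0.5
print(kl_mc, kl_exact)
```
%% Cell type:markdown id: tags: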
The internal steps of `ift.optimize_kl` invoke the approximate solution of multiple interdependent optimization problems:
- For sample generation, a linear system of equations is approximated using the `ConjugateGradient` (CG) method
- Furthermore, the sample generation invokes a non-linear optimization problem approximated using the `NewtonCG` method
- Finally, the approximative distribution is optimized by minimizing the KL between the true posterior and the approximation. This again invokes a non-linear optimization problem approximated with `NewtonCG`.
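%% Cell type:markdown id: tags:
To illustrate the first of these steps, here is a textbook conjugate gradient solver in NumPy. It is a minimal sketch of the general method and not `nifty`'s `ConjugateGradient` class; like that class it only needs the action of the (symmetric, positive definite) operator on a vector, never the matrix itself. The fixed tolerance and iteration limit stand in for what an iteration controller would decide.
%% Cell type:code id: tags:
```python
import numpy as np


def conjugate_gradient(apply_A, b, tol=1e-10, max_iter=100):
    """Solve A x = b for symmetric positive definite A given only x -> A x."""
    x = np.zeros_like(b)
    r = b - apply_A(x)  # residual
    d = r.copy()        # search direction
    rs = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs) < tol:
            break
        Ad = apply_A(d)
        alpha = rs / (d @ Ad)      # optimal step length along d
        x += alpha * d
        r -= alpha * Ad
        rs_new = r @ r
        d = r + (rs_new / rs) * d  # next A-conjugate direction
        rs = rs_new
    return x


rng = np.random.default_rng(0)
M = rng.standard_normal((20, 20))
A = M @ M.T + 20 * np.eye(20)  # symmetric positive definite test matrix
b = rng.standard_normal(20)
x = conjugate_gradient(lambda vec: A @ vec, b)
print(np.abs(A @ x - b).max())
```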
%% Cell type:markdown id: tags:
Posterior visualization
-----------------------
Before we run the minimization routine, we set up a `plotting_callback` function for visualization. Note that additional information and plots regarding the reconstruction are generated during an `ift.optimize_kl` run and stored in the folder passed to the `output_directory` argument of `ift.optimize_kl`.
The final output of `ift.optimize_kl` is a collection of approximate posterior samples and is provided via an instance of `ift.ResidualSampleList`. A `SampleList` provides a variety of convenience functions such as:
- `average`: to compute sample averages
- `sample_stat`: to get the approximate mean and variance of a model
- `iterator`: a python iterator over all samples
- ...
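%% Cell type:markdown id: tags:
What such sample statistics amount to can be sketched in plain NumPy: apply a model to every sample and take the mean and variance over the sample axis. This is an analogue for illustration only, not the `SampleList` implementation; in particular, the use of the unbiased variance (`ddof=1`) is an assumption, and `sample_mean_and_variance` is a hypothetical name.
%% Cell type:code id: tags:
```python
import numpy as np


def sample_mean_and_variance(samples, func=lambda s: s):
    """Apply `func` to each sample, then average over the sample axis."""
    values = np.array([func(s) for s in samples])
    # ddof=1 (unbiased variance) is an assumption made for this sketch
    return values.mean(axis=0), values.var(axis=0, ddof=1)


rng = np.random.default_rng(3)
latent_samples = [rng.standard_normal((4, 4)) for _ in range(5)]
# np.exp plays the role of the sky model applied to each sample
toy_mean, toy_var = sample_mean_and_variance(latent_samples, np.exp)
print(toy_mean.shape, toy_var.shape)
```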
%% Cell type:code id: tags:
``` python
from IPython.display import clear_output
def _imshow(figure, field, ax, title, vmin=0, vmax=None, cmap='afmhot'):
    im0 = ax.imshow(field.val.T, origin='lower', extent=[0, x_fov, 0, y_fov], cmap=cmap,
                    vmin=vmin, vmax=vmax)
    figure.colorbar(im0, ax=ax)
    ax.set_xlabel(r'x $\left(\mu as\right)$')
    ax.set_ylabel(r'y $\left(\mu as\right)$')
    ax.set_title(title)


def _plot_histogram(nodes, hist, ax, title):
    nodes = 0.5*(nodes[1:] + nodes[:-1])  # Bin edges -> bin centers
    ax.bar(nodes, hist, width=nodes[1] - nodes[0])  # Bar width equal to the bin width
    rs = np.arange(nodes[0], nodes[-1], 0.1)
    gauss = np.exp(-0.5*rs**2)/np.sqrt(2*np.pi)
    ax.plot(rs, gauss, 'k--', label=r'standard Gauss')
    ax.set_xlabel(r'$r$')
    ax.set_ylabel(r'$P(r)$')
    ax.set_title(title)
    ax.legend()
    ax.set_xlim([nodes[0], nodes[-1]])


def plotting_callback(samples):
    clear_output(wait=True)
    # Posterior statistics of the sky and the power spectrum
    sky_mean, sky_var = samples.sample_stat(sky_model)
    pspec_mean = samples.average(pspec.log().force).exp()
    pspecs = samples.iterator(pspec.force)

    # Noise-weighted residual between data and model visibilities
    residual = ift.Adder(mock_observation.vis, neg=True) @ R @ measurement_sky
    residual = ift.makeOp(mock_observation.weight.sqrt()) @ residual

    # Histogram of the residuals, averaged over all samples
    nbins = 50
    hist = np.zeros(nbins)
    for rr in samples.iterator(residual):
        rr = rr.val.flatten()
        rr = np.concatenate((rr.real, rr.imag))
        wgt, nodes = np.histogram(rr, nbins, range=[-5, 5])
        hist += wgt/wgt.sum()/(nodes[1]-nodes[0])
    hist /= samples.n_samples

    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15, 13))
    axs = axs.flatten()
    _imshow(fig, sky_mean, axs[0], 'Sky brightness mean', vmax=2e19)
    _imshow(fig, sky_var.sqrt(), axs[1], 'Sky brightness std', cmap='viridis')
    axs[1].yaxis.set_visible(False)

    k_lengths = pspec_mean.domain[0].k_lengths[1:]
    lbl = 'samples'
    for ps in pspecs:
        axs[2].plot(k_lengths, ps.val[1:], 'k-', alpha=0.5, label=lbl)
        lbl = None  # Label only the first sample in the legend
    axs[2].plot(k_lengths, pspec_mean.val[1:], 'r-', label='mean')
    axs[2].set_xlim([k_lengths[0], k_lengths[-1]])
    axs[2].set_xscale('log')
    axs[2].set_yscale('log')
    axs[2].set_xlabel(r'$|k|$')
    axs[2].set_ylabel(r'$P_s\left(|k|\right)$')
    axs[2].set_title(r'Power-spectrum of log-sky brightness')
    axs[2].legend()

    _plot_histogram(nodes, hist, axs[3],
                    r'Inverse noise weighted data residual ($r$) distribution ($P(r)$)')
    fig.tight_layout()
    plt.show()
```
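%% Cell type:markdown id: tags:
The residual histogram is a consistency check: if the model explains the data down to the noise, the real and imaginary parts of the noise-weighted residuals should follow a standard Gaussian. The NumPy sketch below demonstrates this with synthetic data. It assumes that the weights are inverse noise variances per real and imaginary component, which is how the plot above treats them; all numbers are made up.
%% Cell type:code id: tags:
```python
import numpy as np

rng = np.random.default_rng(7)
n_vis = 5000
sigma = rng.uniform(0.5, 2.0, n_vis)  # made-up noise level per visibility
weight = 1.0 / sigma**2               # assumed: inverse variance per component
model_vis = rng.standard_normal(n_vis) + 1j * rng.standard_normal(n_vis)
noise = sigma * (rng.standard_normal(n_vis) + 1j * rng.standard_normal(n_vis))
data = model_vis + noise

# Noise-weighted residuals, real and imaginary parts stacked
residual = np.sqrt(weight) * (data - model_vis)
r = np.concatenate((residual.real, residual.imag))
# density=True normalizes like wgt/wgt.sum()/(bin width) in the callback
hist, edges = np.histogram(r, bins=50, range=(-5, 5), density=True)
print(r.mean(), r.std())
```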
%% Cell type:markdown id: tags:
WARNING: The entire reconstruction takes a few minutes to run.
%% Cell type:code id: tags:
``` python
# Minimization parameters and minimizers for the VI algorithm
ic_sampling = ift.AbsDeltaEnergyController(deltaE=0.05, iteration_limit=50)
ic_sampling_nl = ift.AbsDeltaEnergyController(deltaE=0.5,
                                              iteration_limit=5,
                                              convergence_level=1)
ic_newton_early = ift.AbsDeltaEnergyController(name='Newton',
                                               deltaE=0.1,
                                               iteration_limit=5,
                                               convergence_level=2)
ic_newton_late = ift.AbsDeltaEnergyController(name='Newton',
                                              deltaE=0.1,
                                              iteration_limit=10,
                                              convergence_level=2)
minimizer_early = ift.NewtonCG(ic_newton_early)
minimizer_late = ift.NewtonCG(ic_newton_late)
# Minimize KL between true posterior and approximation. Each iteration includes sample
# generation and optimization.
n_iterations = 15 # Total number of iterations.
n_samples = (lambda iiter: 2 if iiter < 10 else 5) # Number of samples used for KL approximation
minimizer = (lambda iiter: minimizer_early if iiter < 10 else minimizer_late) # When to use which optimizer
minimizer_sampling = (lambda iiter: None if iiter < 10 else ift.NewtonCG(ic_sampling_nl)) # initially MGVI,
# later geoVI
samples = ift.optimize_kl(likelihood, n_iterations, n_samples, minimizer,
                          ic_sampling, minimizer_sampling,
                          plottable_operators={"signal": sky_model,
                                               "power spectrum": pspec},
                          inspect_callback=plotting_callback,
                          output_directory="mock_inference")
```
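%% Cell type:markdown id: tags:
The iteration-dependent settings above are plain functions of the iteration index, so the resulting schedule can be tabulated without running the inference. The sketch below assumes that the index starts at 0; the string labels are only for display and are not `nifty` objects.
%% Cell type:code id: tags:
```python
n_iterations = 15
n_samples = lambda iiter: 2 if iiter < 10 else 5
# Display labels only: None -> linear sampling (MGVI), NewtonCG -> geoVI
sampling_mode = lambda iiter: "linear (MGVI)" if iiter < 10 else "nonlinear (geoVI)"

# Assumes the iteration index starts at 0
schedule = [(i, n_samples(i), sampling_mode(i)) for i in range(n_iterations)]
for row in schedule:
    print(row)
```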
%% Cell type:markdown id: tags:
Comparison of posterior to ground truth
---------------------------------------
%% Cell type:code id: tags:
``` python
mock_sky = np.load('data/mock_signal.npz')['signal']
mock_sky = ift.makeField(space, mock_sky)
sky_mean, sky_var = samples.sample_stat(sky_model)
sky_samples = list(samples.iterator(sky_model))
fig, axs = plt.subplots(nrows=3, ncols=2, figsize = (15,18))
_imshow(fig, mock_sky, axs[0,0], 'Sky brightness ground truth', vmax = 2e19)
_imshow(fig, sky_mean, axs[0,1], 'Sky brightness mean', vmax = 2e19)
_imshow(fig, sky_var.sqrt()/sky_mean, axs[1,0], 'Sky brightness relative uncertainty')
_imshow(fig, sky_samples[0], axs[1,1], 'Sky brightness posterior sample (1)', vmax = 2e19)
_imshow(fig, sky_samples[1], axs[2,0], 'Sky brightness posterior sample (2)', vmax = 2e19)
_imshow(fig, sky_samples[2], axs[2,1], 'Sky brightness posterior sample (3)', vmax = 2e19)
fig.tight_layout()
```
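%% Cell type:markdown id: tags:
Besides the visual comparison, simple summary numbers can help to judge a reconstruction. The sketch below computes two such numbers on synthetic arrays: the relative root-mean-square error of the posterior mean and the fraction of pixels where the truth lies within one posterior standard deviation of the mean. These metrics and the function name `comparison_metrics` are illustrative choices, not part of the tutorial's analysis, and the data are made up.
%% Cell type:code id: tags:
```python
import numpy as np


def comparison_metrics(truth, mean, std):
    """Relative RMS error and fraction of pixels within one posterior std."""
    rel_rmse = np.sqrt(np.mean((mean - truth) ** 2)) / np.sqrt(np.mean(truth ** 2))
    within_1sigma = np.mean(np.abs(mean - truth) <= std)
    return rel_rmse, within_1sigma


rng = np.random.default_rng(11)
truth = np.exp(rng.standard_normal((32, 32)))          # made-up ground truth
std = 0.1 * truth                                      # made-up 10% uncertainty
mean = truth + std * rng.standard_normal(truth.shape)  # consistent "reconstruction"
rel_rmse, within = comparison_metrics(truth, mean, std)
print(rel_rmse, within)
```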