diff --git a/demo_radio_jax.ipynb b/demo_radio_jax.ipynb index aee5e27e0bac1cddfd710efbd37d59b72ef425ae..33b0442e6a0af59a0f623861e4e55da6a8d1d3cb 100644 --- a/demo_radio_jax.ipynb +++ b/demo_radio_jax.ipynb @@ -27,7 +27,6 @@ "from jax import numpy as jnp\n", "from jax import random\n", "import matplotlib.pyplot as plt\n", - "# import nifty8 as ift\n", "import nifty8.re as jft\n", "import resolve as rve\n", "import resolve.re as jre" @@ -133,14 +132,7 @@ "log_sky = cfm.finalize()\n", "\n", "\n", - "class Sky_model(jft.Model):\n", - " def __init__(self, log_sky):\n", - " self.log_sky = log_sky\n", - " super().__init__(init=self.log_sky.init)\n", - "\n", - " def __call__(self, x):\n", - " return jnp.exp(self.log_sky(x))\n", - "sky = Sky_model(log_sky)" + "sky = lambda x: jnp.exp(log_sky(x))" ] }, { @@ -164,7 +156,7 @@ "pspecs = []\n", "for _ in range(8):\n", " key, subkey = random.split(key)\n", - " pos_random = jft.Vector(jft.random_like(subkey, sky.domain))\n", + " pos_random = jft.Vector(jft.random_like(subkey, log_sky.domain))\n", " \n", " ax = axs.pop(0)\n", " ax.imshow(sky(pos_random).T, origin='lower', cmap='afmhot')\n", @@ -180,7 +172,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Setup a mock VLBI imaging task using the `InterferometryResponse` of `resolve`. The `mock_observation` contains all information relevant to set up the likelihood, including visibility data, uv-coordinates, and the noise levels of each measurement. The `measurement_sky` contains all relevant information regarding the prior model of the sky brightness distribution. The additional parameters passed to the `InterferometryResponse` control the accuracy and behaviour of the `wgridder` used within `resolve` which defines the response function." + "Setup a mock VLBI imaging task using the `InterferometryResponse` of `resolve`. The `mock_observation` contains all information relevant to set up the likelihood, including visibility data, uv-coordinates, and the noise levels of each measurement. The `sky` contains all relevant information regarding the prior model of the sky brightness distribution. The additional parameters passed to the `InterferometryResponse` control the accuracy and behaviour of the `wgridder` used within `resolve` which defines the response function." ] }, { @@ -204,7 +196,7 @@ "Solve the inference problem\n", "---------------------------\n", "\n", - "The `likelihood` together with the `sky_model` fully specify a Bayesian inverse problem and imply a posterior probability distribution over the degrees of freedom (DOF) of the model. This distribution is, in general, a high-dimensional (number of pixels + DOF of power spectrum) and non-Gaussian distribution, which prohibits analytical integration. To access its information and compute posterior expectation values, numerical approximations have to be made.\n", + "The `likelihood` together with the `sky` fully specify a Bayesian inverse problem and imply a posterior probability distribution over the degrees of freedom (DOF) of the model. This distribution is, in general, a high-dimensional (number of pixels + DOF of power spectrum) and non-Gaussian distribution, which prohibits analytical integration. To access its information and compute posterior expectation values, numerical approximations have to be made.\n", "\n", "`nifty` provides multiple ways of posterior approximation, with Variational Inference (VI) being by far the most frequently used method. In VI the posterior distribution is approximated with another distribution by minimizing their respective forward Kullbach-Leibler divergence (KL). In the following, the Geometric VI method is employed which utilizes concepts of differential geometry to provide a local estimate of the distribution function.\n", "\n", @@ -223,12 +215,8 @@ "Posterior visualization\n", "-----------------------\n", "\n", - "Before we run the minimization routine, we set up a `plotting_callback` function for visualization. Note that additional information and plots regarding the reconstruction are generated during an `ift.optimize_kl` run and stored in the folder passed to the `output_directory` argument of `ift.optimize_kl`\n", - "The final output of `ift.optimize_kl` is a collection of approximate posterior samples and is provided via an instance of `ift.ResidualSampleList`. A `SampleList` provides a variety of convenience functions such as: \n", - "- `average`: to compute sample averages\n", - "- `sample_stat`: to get the approximate mean and variance of a model\n", - "- `iterator`: a python iterator over all samples\n", - "- ..." + "Before we run the minimization routine, we set up a `plotting_callback` function for visualization. Note that additional information and plots regarding the reconstruction are generated during an `jft.optimize_kl` run and stored in the folder passed to the `output_directory` argument of `jft.optimize_kl`\n", + "The final output of `jft.optimize_kl` is a collection of approximate posterior samples and is provided via an instance of `jft.Samples`. You can iterate over these samples samples. With `jft.mean_and_std` you can get the mean and the standard deviation." ] }, { @@ -344,7 +332,7 @@ "n_samples = (lambda iiter: 2 if iiter < 10 else 5) # Number of samples used for KL approximation\n", "\n", "key, subkey = random.split(key)\n", - "pos_init = jft.Vector(jft.random_like(subkey, sky.domain))\n", + "pos_init = jft.Vector(jft.random_like(subkey, log_sky.domain))\n", "samples, state = jft.optimize_kl(\n", " likelihood,\n", " pos_init,\n",