Commit 5a4f59a7 authored by Philipp Arras's avatar Philipp Arras
Browse files

Integrate GeoKL into visualized demo

parent 57d88b83
......@@ -143,7 +143,7 @@ run_curve_fitting:
paths:
- '*.png'
run_visual_mgvi:
run_visual_vi:
stage: demo_runs
script:
- python3 demos/mgvi_visualized.py
- python3 demos/variational_inference_visualized.py
......@@ -11,21 +11,21 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Authors: Reimar Leike, Philipp Arras
# Copyright(C) 2013-2021 Max-Planck-Society
# Authors: Reimar Leike, Philipp Arras, Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
###############################################################################
# Metric Gaussian Variational Inference (MGVI)
# Variational Inference (VI)
#
# This script demonstrates how MGVI works for an inference problem with only
# two real quantities of interest. This enables us to plot the posterior
# probability density as two-dimensional plot. The posterior samples generated
# by MGVI are contrasted with the maximum-a-posterior (MAP) solution together
# with samples drawn with the Laplace method. This method uses the local
# curvature at the MAP solution as inverse covariance of a Gaussian probability
# density.
# This script demonstrates how MGVI and GeoVI work for an inference problem
# with only two real quantities of interest. This enables us to plot the
# posterior probability density as two-dimensional plot. The approximate
# posterior samples are contrasted with the maximum-a-posterior (MAP) solution
# together with samples drawn with the Laplace method. This method uses the
# local curvature at the MAP solution as inverse covariance of a Gaussian
# probability density.
###############################################################################
import numpy as np
......@@ -36,8 +36,6 @@ import nifty7 as ift
def main():
use_geo = False
name = 'GEO' if use_geo else 'MGVI'
dom = ift.UnstructuredDomain(1)
scale = 10
......@@ -89,42 +87,52 @@ def main():
minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=2, name='Mini'))
pos = ift.from_random(ham.domain, 'normal')
plt.figure(figsize=[12, 8])
pos = pos1 = ift.from_random(ham.domain, 'normal')
fig, axs = plt.subplots(2, 1, figsize=[12, 8])
for ii in range(15):
if ii % 3 == 0:
if use_geo:
mini_samp = ift.NewtonCG(
ift.GradientNormController(iteration_limit=5))
mgkl = ift.GeoMetricKL(pos, ham, 100, mini_samp, False)
else:
mgkl = ift.MetricGaussianKL(pos, ham, 100, False)
plt.cla()
plt.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
if ii == 0:
cbar = plt.colorbar()
cbar.ax.set_ylabel('pdf')
xs, ys = [], []
for samp in mgkl.samples:
samp = (samp + pos).val
xs.append(samp['a'])
ys.append(samp['b'])
plt.scatter(np.array(map_xs)*scale, np.array(map_ys),
label='Laplace samples')
plt.scatter(np.array(xs)*scale, np.array(ys), label=name+' samples')
plt.scatter(pos.val['a']*scale, pos.val['b'], label=name+' latent mean')
plt.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
label='Maximum a posterior solution')
plt.xlim(x_limits_scaled)
plt.ylim(y_limits)
plt.legend()
# Resample
mgkl = ift.MetricGaussianKL(pos, ham, 100, False)
mini_samp = ift.NewtonCG(ift.GradientNormController(iteration_limit=5))
geokl = ift.GeoMetricKL(pos1, ham, 100, mini_samp, False)
for axx in axs:
axx.clear()
im = axx.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
if ii == 0:
cbar = plt.colorbar(im, ax=axx)
cbar.ax.set_ylabel('pdf')
for jj, nn, kl, pp in ((0, "MGVI", mgkl, pos), (1, "GeoVI", geokl, pos1)):
xs, ys = [], []
for samp in kl.samples:
samp = (samp + pp).val
xs.append(samp['a'])
ys.append(samp['b'])
axs[jj].scatter(np.array(xs)*scale, np.array(ys), label=f'{nn} samples')
axs[jj].scatter(pp.val['a']*scale, pp.val['b'], label=f'{nn} latent mean')
axs[jj].set_title(nn)
for axx in axs:
axx.scatter(np.array(map_xs)*scale, np.array(map_ys),
label='Laplace samples')
axx.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
label='Maximum a posterior solution')
axx.set_xlim(x_limits_scaled)
axx.set_ylim(y_limits)
axx.set_ylabel('y')
axx.legend(loc='lower right')
axs[0].xaxis.set_visible(False)
axs[1].set_xlabel('x')
plt.tight_layout()
plt.draw()
plt.pause(1.0)
mgkl, _ = minimizer(mgkl)
geokl, _ = minimizer(geokl)
pos = mgkl.position
pos1 = geokl.position
ift.logger.info('Finished')
# Uncomment the following line in order to leave the plots open
# plt.show()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment