Commit 5a4f59a7 authored by Philipp Arras's avatar Philipp Arras
Browse files

Integrate GeoKL into visualized demo

parent 57d88b83
...@@ -143,7 +143,7 @@ run_curve_fitting: ...@@ -143,7 +143,7 @@ run_curve_fitting:
paths: paths:
- '*.png' - '*.png'
run_visual_mgvi: run_visual_vi:
stage: demo_runs stage: demo_runs
script: script:
- python3 demos/mgvi_visualized.py - python3 demos/variational_inference_visualized.py
...@@ -11,21 +11,21 @@ ...@@ -11,21 +11,21 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2020 Max-Planck-Society # Copyright(C) 2013-2021 Max-Planck-Society
# Authors: Reimar Leike, Philipp Arras # Authors: Reimar Leike, Philipp Arras, Philipp Frank
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
############################################################################### ###############################################################################
# Metric Gaussian Variational Inference (MGVI) # Variational Inference (VI)
# #
# This script demonstrates how MGVI works for an inference problem with only # This script demonstrates how MGVI and GeoVI work for an inference problem
# two real quantities of interest. This enables us to plot the posterior # with only two real quantities of interest. This enables us to plot the
# probability density as two-dimensional plot. The posterior samples generated # posterior probability density as two-dimensional plot. The approximate
# by MGVI are contrasted with the maximum-a-posterior (MAP) solution together # posterior samples are contrasted with the maximum-a-posterior (MAP) solution
# with samples drawn with the Laplace method. This method uses the local # together with samples drawn with the Laplace method. This method uses the
# curvature at the MAP solution as inverse covariance of a Gaussian probability # local curvature at the MAP solution as inverse covariance of a Gaussian
# density. # probability density.
############################################################################### ###############################################################################
import numpy as np import numpy as np
...@@ -36,8 +36,6 @@ import nifty7 as ift ...@@ -36,8 +36,6 @@ import nifty7 as ift
def main(): def main():
use_geo = False
name = 'GEO' if use_geo else 'MGVI'
dom = ift.UnstructuredDomain(1) dom = ift.UnstructuredDomain(1)
scale = 10 scale = 10
...@@ -89,42 +87,52 @@ def main(): ...@@ -89,42 +87,52 @@ def main():
minimizer = ift.NewtonCG( minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=2, name='Mini')) ift.GradientNormController(iteration_limit=2, name='Mini'))
pos = ift.from_random(ham.domain, 'normal') pos = pos1 = ift.from_random(ham.domain, 'normal')
plt.figure(figsize=[12, 8]) fig, axs = plt.subplots(2, 1, figsize=[12, 8])
for ii in range(15): for ii in range(15):
if ii % 3 == 0: if ii % 3 == 0:
if use_geo: # Resample
mini_samp = ift.NewtonCG( mgkl = ift.MetricGaussianKL(pos, ham, 100, False)
ift.GradientNormController(iteration_limit=5)) mini_samp = ift.NewtonCG(ift.GradientNormController(iteration_limit=5))
mgkl = ift.GeoMetricKL(pos, ham, 100, mini_samp, False) geokl = ift.GeoMetricKL(pos1, ham, 100, mini_samp, False)
else:
mgkl = ift.MetricGaussianKL(pos, ham, 100, False) for axx in axs:
axx.clear()
plt.cla() im = axx.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
plt.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)), cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
cmap='gist_earth_r', extent=x_limits_scaled + y_limits) if ii == 0:
if ii == 0: cbar = plt.colorbar(im, ax=axx)
cbar = plt.colorbar() cbar.ax.set_ylabel('pdf')
cbar.ax.set_ylabel('pdf')
xs, ys = [], [] for jj, nn, kl, pp in ((0, "MGVI", mgkl, pos), (1, "GeoVI", geokl, pos1)):
for samp in mgkl.samples: xs, ys = [], []
samp = (samp + pos).val for samp in kl.samples:
xs.append(samp['a']) samp = (samp + pp).val
ys.append(samp['b']) xs.append(samp['a'])
plt.scatter(np.array(map_xs)*scale, np.array(map_ys), ys.append(samp['b'])
label='Laplace samples') axs[jj].scatter(np.array(xs)*scale, np.array(ys), label=f'{nn} samples')
plt.scatter(np.array(xs)*scale, np.array(ys), label=name+' samples') axs[jj].scatter(pp.val['a']*scale, pp.val['b'], label=f'{nn} latent mean')
plt.scatter(pos.val['a']*scale, pos.val['b'], label=name+' latent mean') axs[jj].set_title(nn)
plt.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
label='Maximum a posterior solution') for axx in axs:
plt.xlim(x_limits_scaled) axx.scatter(np.array(map_xs)*scale, np.array(map_ys),
plt.ylim(y_limits) label='Laplace samples')
plt.legend() axx.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
label='Maximum a posterior solution')
axx.set_xlim(x_limits_scaled)
axx.set_ylim(y_limits)
axx.set_ylabel('y')
axx.legend(loc='lower right')
axs[0].xaxis.set_visible(False)
axs[1].set_xlabel('x')
plt.tight_layout()
plt.draw() plt.draw()
plt.pause(1.0) plt.pause(1.0)
mgkl, _ = minimizer(mgkl) mgkl, _ = minimizer(mgkl)
geokl, _ = minimizer(geokl)
pos = mgkl.position pos = mgkl.position
pos1 = geokl.position
ift.logger.info('Finished') ift.logger.info('Finished')
# Uncomment the following line in order to leave the plots open # Uncomment the following line in order to leave the plots open
# plt.show() # plt.show()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment