Skip to content
Snippets Groups Projects

Work on vi visualized

Merged Philipp Frank requested to merge work_on_vi_visualized into more_samplers
1 unresolved thread
@@ -19,13 +19,9 @@
@@ -19,13 +19,9 @@
###############################################################################
###############################################################################
# Variational Inference (VI)
# Variational Inference (VI)
#
#
# This script demonstrates how MGVI and GeoVI work for an inference problem
# This script demonstrates how MGVI, GeoVI, MeanfieldVI and FullCovarianceVI
# with only two real quantities of interest. This enables us to plot the
# work for an inference problem with only two real quantities of interest. This
# posterior probability density as two-dimensional plot. The approximate
# enables us to plot the posterior probability density as two-dimensional plot.
# posterior samples are contrasted with the maximum-a-posterior (MAP) solution
# together with samples drawn with the Laplace method. This method uses the
# local curvature at the MAP solution as inverse covariance of a Gaussian
# probability density.
###############################################################################
###############################################################################
import numpy as np
import numpy as np
@@ -74,65 +70,93 @@ def main():
@@ -74,65 +70,93 @@ def main():
plt.pause(2.0)
plt.pause(2.0)
plt.close()
plt.close()
pos = ift.from_random(ham.domain, 'normal')
mapx = xx[z==np.max(z)]
MAP = ift.EnergyAdapter(pos, ham, want_metric=True)
mapy = yy[z==np.max(z)]
minimizer = ift.NewtonCG(
meanx = (xx*z).sum()/z.sum()
ift.GradientNormController(iteration_limit=20, name='Mini'))
meany = (yy*z).sum()/z.sum()
MAP, _ = minimizer(MAP)
map_xs, map_ys = [], []
for ii in range(10):
samp = (MAP.metric.draw_sample(from_inverse=True) + MAP.position).val
map_xs.append(samp['a'])
map_ys.append(samp['b'])
 
n_samples = 100
minimizer = ift.NewtonCG(
minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=2, name='Mini'))
ift.GradientNormController(iteration_limit=3, name='Mini'))
pos = pos1 = ift.from_random(ham.domain, 'normal')
IC = ift.StochasticAbsDeltaEnergyController(0.5, iteration_limit=20,
fig, axs = plt.subplots(2, 1, figsize=[12, 8])
name='advi')
for ii in range(15):
stochastic_minimizer_mf = ift.ADVIOptimizer(IC, eta = 0.3)
if ii % 3 == 0:
stochastic_minimizer_fc = ift.ADVIOptimizer(IC, eta = 0.3)
# Resample
posmg = posgeo = posmf = posfc = ift.from_random(ham.domain, 'normal')
mgkl = ift.MetricGaussianKL(pos, ham, 100, False)
fc = ift.FullCovarianceVI(posfc, ham, 10, False, initial_sig=0.01)
mini_samp = ift.NewtonCG(ift.GradientNormController(iteration_limit=5))
mf = ift.MeanFieldVI(posmf, ham, 10, False, initial_sig=0.01)
geokl = ift.GeoMetricKL(pos1, ham, 100, mini_samp, False)
fig, axs = plt.subplots(2, 2, figsize=[12, 8])
for axx in axs:
axs = axs.flatten()
 
 
def update_plot(runs):
 
for axx, (nn, kl, pp, sam) in zip(axs,runs):
axx.clear()
axx.clear()
im = axx.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
axx.imshow(z.T, origin='lower', cmap='gist_earth_r',
cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
if ii == 0:
extent=x_limits_scaled + y_limits)
cbar = plt.colorbar(im, ax=axx)
cbar.ax.set_ylabel('pdf')
for jj, nn, kl, pp in ((0, "MGVI", mgkl, pos), (1, "GeoVI", geokl, pos1)):
xs, ys = [], []
xs, ys = [], []
for samp in kl.samples:
if sam:
samp = (samp + pp).val
samples = (samp + pp for samp in kl.samples)
xs.append(samp['a'])
else:
ys.append(samp['b'])
samples = (kl.draw_sample() for _ in range(n_samples))
axs[jj].scatter(np.array(xs)*scale, np.array(ys), label=f'{nn} samples')
mx, my = 0., 0.
axs[jj].scatter(pp.val['a']*scale, pp.val['b'], label=f'{nn} latent mean')
for samp in samples:
axs[jj].set_title(nn)
a = samp.val['a']
xs.append(a)
for axx in axs:
mx += a
axx.scatter(np.array(map_xs)*scale, np.array(map_ys),
b = samp.val['b']
label='Laplace samples')
ys.append(b)
axx.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
my += b
label='Maximum a posterior solution')
mx /= n_samples
 
my /= n_samples
 
axx.scatter(np.array(xs)*scale, np.array(ys),
 
label = f'{nn} samples')
 
axx.scatter(mx*scale, my, label = f'{nn} mean')
 
axx.scatter(mapx*scale, mapy, label = 'MAP')
 
axx.scatter(meanx*scale, meany, label = 'Posterior mean')
 
axx.set_title(nn)
axx.set_xlim(x_limits_scaled)
axx.set_xlim(x_limits_scaled)
axx.set_ylim(y_limits)
axx.set_ylim(y_limits)
axx.set_ylabel('y')
axx.legend(loc='lower right')
axx.legend(loc='lower right')
axs[0].xaxis.set_visible(False)
axs[0].xaxis.set_visible(False)
axs[1].set_xlabel('x')
axs[1].xaxis.set_visible(False)
 
axs[1].yaxis.set_visible(False)
 
axs[2].set_xlabel('x')
 
axs[2].set_ylabel('y')
 
axs[3].yaxis.set_visible(False)
 
axs[3].set_xlabel('x')
plt.tight_layout()
plt.tight_layout()
plt.draw()
plt.draw()
plt.pause(1.0)
plt.pause(2.0)
 
 
for ii in range(20):
 
if ii % 2 == 0:
 
# Resample GeoVI and MGVI
 
mgkl = ift.MetricGaussianKL(posmg, ham, n_samples, False)
 
mini_samp = ift.NewtonCG(
 
ift.AbsDeltaEnergyController(1E-8, iteration_limit=5))
 
geokl = ift.GeoMetricKL(posgeo, ham, n_samples, mini_samp, False)
 
 
runs = (("MGVI", mgkl, posmg, True),
 
("GeoVI", geokl, posgeo, True),
 
("MeanfieldVI", mf, posmf, False),
 
("FullCovarianceVI", fc, posfc, False))
 
update_plot(runs)
mgkl, _ = minimizer(mgkl)
mgkl, _ = minimizer(mgkl)
geokl, _ = minimizer(geokl)
geokl, _ = minimizer(geokl)
pos = mgkl.position
mf.minimize(stochastic_minimizer_mf)
pos1 = geokl.position
fc.minimize(stochastic_minimizer_fc)
 
posmg = mgkl.position
 
posgeo = geokl.position
 
posmf = mf.mean
 
posfc = fc.mean
 
runs = (("MGVI", mgkl, posmg, True),
 
("GeoVI", geokl, posgeo, True),
 
("MeanfieldVI", mf, posmf, False),
 
("FullCovarianceVI", fc, posfc, False))
 
update_plot(runs)
ift.logger.info('Finished')
ift.logger.info('Finished')
# Uncomment the following line in order to leave the plots open
# Uncomment the following line in order to leave the plots open
# plt.show()
# plt.show()
Loading