Commit 55b560a6 authored by Philipp Frank's avatar Philipp Frank
Browse files

include mfvi and fcvi into vi visualized demo

parent 98078146
......@@ -74,65 +74,91 @@ def main():
plt.pause(2.0)
plt.close()
pos = ift.from_random(ham.domain, 'normal')
MAP = ift.EnergyAdapter(pos, ham, want_metric=True)
minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=20, name='Mini'))
MAP, _ = minimizer(MAP)
map_xs, map_ys = [], []
for ii in range(10):
samp = (MAP.metric.draw_sample(from_inverse=True) + MAP.position).val
map_xs.append(samp['a'])
map_ys.append(samp['b'])
mapx = xx[z==np.max(z)]
mapy = yy[z==np.max(z)]
meanx = (xx*z).sum()/z.sum()
meany = (yy*z).sum()/z.sum()
n_samples = 100
minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=2, name='Mini'))
pos = pos1 = ift.from_random(ham.domain, 'normal')
fig, axs = plt.subplots(2, 1, figsize=[12, 8])
for ii in range(15):
if ii % 3 == 0:
# Resample
mgkl = ift.MetricGaussianKL(pos, ham, 100, False)
mini_samp = ift.NewtonCG(ift.GradientNormController(iteration_limit=5))
geokl = ift.GeoMetricKL(pos1, ham, 100, mini_samp, False)
for axx in axs:
IC = ift.StochasticAbsDeltaEnergyController(0.1, iteration_limit=20,
name='advi')
stochastic_minimizer_mf = ift.ADVIOptimizer(IC, eta = 0.5)
stochastic_minimizer_fc = ift.ADVIOptimizer(IC, eta = 0.5)
posmg = posgeo = posmf = posfc = ift.from_random(ham.domain, 'normal')
fc = ift.FullCovarianceVI(posfc, ham, 10, False, initial_sig=0.01)
mf = ift.MeanFieldVI(posmf, ham, 10, False, initial_sig=0.01)
fig, axs = plt.subplots(2, 2, figsize=[12, 8])
axs = axs.flatten()
def update_plot(runs):
for axx, (nn, kl, pp, sam) in zip(axs,runs):
axx.clear()
im = axx.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
if ii == 0:
cbar = plt.colorbar(im, ax=axx)
cbar.ax.set_ylabel('pdf')
for jj, nn, kl, pp in ((0, "MGVI", mgkl, pos), (1, "GeoVI", geokl, pos1)):
axx.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
xs, ys = [], []
for samp in kl.samples:
samp = (samp + pp).val
xs.append(samp['a'])
ys.append(samp['b'])
axs[jj].scatter(np.array(xs)*scale, np.array(ys), label=f'{nn} samples')
axs[jj].scatter(pp.val['a']*scale, pp.val['b'], label=f'{nn} latent mean')
axs[jj].set_title(nn)
for axx in axs:
axx.scatter(np.array(map_xs)*scale, np.array(map_ys),
label='Laplace samples')
axx.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
label='Maximum a posterior solution')
if sam:
samples = (samp + pp for samp in kl.samples)
else:
samples = (kl.draw_sample() for _ in range(n_samples))
mx, my = 0., 0.
for samp in samples:
a = samp.val['a']
xs.append(a)
mx += a
b = samp.val['b']
ys.append(b)
my += b
mx /= n_samples
my /= n_samples
axx.scatter(np.array(xs)*scale, np.array(ys), label=f'{nn} samples')
axx.scatter(mx*scale, my, label=f'{nn} mean')
axx.scatter(mapx*scale, mapy, label = 'MAP')
axx.scatter(meanx*scale, meany, label = 'Posterior mean')
axx.set_title(nn)
axx.set_xlim(x_limits_scaled)
axx.set_ylim(y_limits)
axx.set_ylabel('y')
axx.legend(loc='lower right')
axs[0].xaxis.set_visible(False)
axs[1].set_xlabel('x')
axs[1].xaxis.set_visible(False)
axs[1].yaxis.set_visible(False)
axs[2].set_xlabel('x')
axs[2].set_ylabel('y')
axs[3].yaxis.set_visible(False)
axs[3].set_xlabel('x')
plt.tight_layout()
plt.draw()
plt.pause(1.0)
plt.pause(2.0)
for ii in range(15):
if ii % 2 == 0:
# Resample
mgkl = ift.MetricGaussianKL(posmg, ham, n_samples, False)
mini_samp = ift.NewtonCG(ift.AbsDeltaEnergyController(1E-8,
iteration_limit=5))
geokl = ift.GeoMetricKL(posgeo, ham, n_samples, mini_samp, False)
runs = (("MGVI", mgkl, posmg, True),
("GeoVI", geokl, posgeo, True),
("MeanfieldVI", mf, posmf, False),
("FullCovarianceVI", fc, posfc, False))
update_plot(runs)
mgkl, _ = minimizer(mgkl)
geokl, _ = minimizer(geokl)
pos = mgkl.position
pos1 = geokl.position
mf.minimize(stochastic_minimizer_mf)
fc.minimize(stochastic_minimizer_fc)
posmg = mgkl.position
posgeo = geokl.position
posmf = mf.mean
posfc = fc.mean
runs = (("MGVI", mgkl, posmg, True),
("GeoVI", geokl, posgeo, True),
("MeanfieldVI", mf, posmf, False),
("FullCovarianceVI", fc, posfc, False))
update_plot(runs)
ift.logger.info('Finished')
# Uncomment the following line in order to leave the plots open
# plt.show()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment