Skip to content
Snippets Groups Projects
Commit 55b560a6 authored by Philipp Frank's avatar Philipp Frank
Browse files

include mfvi and fcvi into vi visualized demo

parent 98078146
No related branches found
No related tags found
2 merge requests!648Work on vi visualized,!604Parametric MGVI
Pipeline #103290 passed
...@@ -74,65 +74,91 @@ def main(): ...@@ -74,65 +74,91 @@ def main():
plt.pause(2.0) plt.pause(2.0)
plt.close() plt.close()
pos = ift.from_random(ham.domain, 'normal') mapx = xx[z==np.max(z)]
MAP = ift.EnergyAdapter(pos, ham, want_metric=True) mapy = yy[z==np.max(z)]
minimizer = ift.NewtonCG( meanx = (xx*z).sum()/z.sum()
ift.GradientNormController(iteration_limit=20, name='Mini')) meany = (yy*z).sum()/z.sum()
MAP, _ = minimizer(MAP)
map_xs, map_ys = [], []
for ii in range(10):
samp = (MAP.metric.draw_sample(from_inverse=True) + MAP.position).val
map_xs.append(samp['a'])
map_ys.append(samp['b'])
n_samples = 100
minimizer = ift.NewtonCG( minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=2, name='Mini')) ift.GradientNormController(iteration_limit=2, name='Mini'))
pos = pos1 = ift.from_random(ham.domain, 'normal') IC = ift.StochasticAbsDeltaEnergyController(0.1, iteration_limit=20,
fig, axs = plt.subplots(2, 1, figsize=[12, 8]) name='advi')
for ii in range(15): stochastic_minimizer_mf = ift.ADVIOptimizer(IC, eta = 0.5)
if ii % 3 == 0: stochastic_minimizer_fc = ift.ADVIOptimizer(IC, eta = 0.5)
# Resample posmg = posgeo = posmf = posfc = ift.from_random(ham.domain, 'normal')
mgkl = ift.MetricGaussianKL(pos, ham, 100, False) fc = ift.FullCovarianceVI(posfc, ham, 10, False, initial_sig=0.01)
mini_samp = ift.NewtonCG(ift.GradientNormController(iteration_limit=5)) mf = ift.MeanFieldVI(posmf, ham, 10, False, initial_sig=0.01)
geokl = ift.GeoMetricKL(pos1, ham, 100, mini_samp, False)
fig, axs = plt.subplots(2, 2, figsize=[12, 8])
for axx in axs: axs = axs.flatten()
def update_plot(runs):
for axx, (nn, kl, pp, sam) in zip(axs,runs):
axx.clear() axx.clear()
im = axx.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)), axx.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
cmap='gist_earth_r', extent=x_limits_scaled + y_limits) cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
if ii == 0:
cbar = plt.colorbar(im, ax=axx)
cbar.ax.set_ylabel('pdf')
for jj, nn, kl, pp in ((0, "MGVI", mgkl, pos), (1, "GeoVI", geokl, pos1)):
xs, ys = [], [] xs, ys = [], []
for samp in kl.samples: if sam:
samp = (samp + pp).val samples = (samp + pp for samp in kl.samples)
xs.append(samp['a']) else:
ys.append(samp['b']) samples = (kl.draw_sample() for _ in range(n_samples))
axs[jj].scatter(np.array(xs)*scale, np.array(ys), label=f'{nn} samples') mx, my = 0., 0.
axs[jj].scatter(pp.val['a']*scale, pp.val['b'], label=f'{nn} latent mean') for samp in samples:
axs[jj].set_title(nn) a = samp.val['a']
xs.append(a)
for axx in axs: mx += a
axx.scatter(np.array(map_xs)*scale, np.array(map_ys), b = samp.val['b']
label='Laplace samples') ys.append(b)
axx.scatter(MAP.position.val['a']*scale, MAP.position.val['b'], my += b
label='Maximum a posterior solution') mx /= n_samples
my /= n_samples
axx.scatter(np.array(xs)*scale, np.array(ys), label=f'{nn} samples')
axx.scatter(mx*scale, my, label=f'{nn} mean')
axx.scatter(mapx*scale, mapy, label = 'MAP')
axx.scatter(meanx*scale, meany, label = 'Posterior mean')
axx.set_title(nn)
axx.set_xlim(x_limits_scaled) axx.set_xlim(x_limits_scaled)
axx.set_ylim(y_limits) axx.set_ylim(y_limits)
axx.set_ylabel('y')
axx.legend(loc='lower right') axx.legend(loc='lower right')
axs[0].xaxis.set_visible(False) axs[0].xaxis.set_visible(False)
axs[1].set_xlabel('x') axs[1].xaxis.set_visible(False)
axs[1].yaxis.set_visible(False)
axs[2].set_xlabel('x')
axs[2].set_ylabel('y')
axs[3].yaxis.set_visible(False)
axs[3].set_xlabel('x')
plt.tight_layout() plt.tight_layout()
plt.draw() plt.draw()
plt.pause(1.0) plt.pause(2.0)
for ii in range(15):
if ii % 2 == 0:
# Resample
mgkl = ift.MetricGaussianKL(posmg, ham, n_samples, False)
mini_samp = ift.NewtonCG(ift.AbsDeltaEnergyController(1E-8,
iteration_limit=5))
geokl = ift.GeoMetricKL(posgeo, ham, n_samples, mini_samp, False)
runs = (("MGVI", mgkl, posmg, True),
("GeoVI", geokl, posgeo, True),
("MeanfieldVI", mf, posmf, False),
("FullCovarianceVI", fc, posfc, False))
update_plot(runs)
mgkl, _ = minimizer(mgkl) mgkl, _ = minimizer(mgkl)
geokl, _ = minimizer(geokl) geokl, _ = minimizer(geokl)
pos = mgkl.position mf.minimize(stochastic_minimizer_mf)
pos1 = geokl.position fc.minimize(stochastic_minimizer_fc)
posmg = mgkl.position
posgeo = geokl.position
posmf = mf.mean
posfc = fc.mean
runs = (("MGVI", mgkl, posmg, True),
("GeoVI", geokl, posgeo, True),
("MeanfieldVI", mf, posmf, False),
("FullCovarianceVI", fc, posfc, False))
update_plot(runs)
ift.logger.info('Finished') ift.logger.info('Finished')
# Uncomment the following line in order to leave the plots open # Uncomment the following line in order to leave the plots open
# plt.show() # plt.show()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment