Commit 55b560a6 authored by Philipp Frank's avatar Philipp Frank
Browse files

include mfvi and fcvi into vi visualized demo

parent 98078146
...@@ -74,65 +74,91 @@ def main(): ...@@ -74,65 +74,91 @@ def main():
plt.pause(2.0) plt.pause(2.0)
plt.close() plt.close()
pos = ift.from_random(ham.domain, 'normal') mapx = xx[z==np.max(z)]
MAP = ift.EnergyAdapter(pos, ham, want_metric=True) mapy = yy[z==np.max(z)]
minimizer = ift.NewtonCG( meanx = (xx*z).sum()/z.sum()
ift.GradientNormController(iteration_limit=20, name='Mini')) meany = (yy*z).sum()/z.sum()
MAP, _ = minimizer(MAP)
map_xs, map_ys = [], []
for ii in range(10):
samp = (MAP.metric.draw_sample(from_inverse=True) + MAP.position).val
map_xs.append(samp['a'])
map_ys.append(samp['b'])
n_samples = 100
minimizer = ift.NewtonCG( minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=2, name='Mini')) ift.GradientNormController(iteration_limit=2, name='Mini'))
pos = pos1 = ift.from_random(ham.domain, 'normal') IC = ift.StochasticAbsDeltaEnergyController(0.1, iteration_limit=20,
fig, axs = plt.subplots(2, 1, figsize=[12, 8]) name='advi')
for ii in range(15): stochastic_minimizer_mf = ift.ADVIOptimizer(IC, eta = 0.5)
if ii % 3 == 0: stochastic_minimizer_fc = ift.ADVIOptimizer(IC, eta = 0.5)
# Resample posmg = posgeo = posmf = posfc = ift.from_random(ham.domain, 'normal')
mgkl = ift.MetricGaussianKL(pos, ham, 100, False) fc = ift.FullCovarianceVI(posfc, ham, 10, False, initial_sig=0.01)
mini_samp = ift.NewtonCG(ift.GradientNormController(iteration_limit=5)) mf = ift.MeanFieldVI(posmf, ham, 10, False, initial_sig=0.01)
geokl = ift.GeoMetricKL(pos1, ham, 100, mini_samp, False)
fig, axs = plt.subplots(2, 2, figsize=[12, 8])
for axx in axs: axs = axs.flatten()
def update_plot(runs):
for axx, (nn, kl, pp, sam) in zip(axs,runs):
axx.clear() axx.clear()
im = axx.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)), axx.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
cmap='gist_earth_r', extent=x_limits_scaled + y_limits) cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
if ii == 0:
cbar = plt.colorbar(im, ax=axx)
cbar.ax.set_ylabel('pdf')
for jj, nn, kl, pp in ((0, "MGVI", mgkl, pos), (1, "GeoVI", geokl, pos1)):
xs, ys = [], [] xs, ys = [], []
for samp in kl.samples: if sam:
samp = (samp + pp).val samples = (samp + pp for samp in kl.samples)
xs.append(samp['a']) else:
ys.append(samp['b']) samples = (kl.draw_sample() for _ in range(n_samples))
axs[jj].scatter(np.array(xs)*scale, np.array(ys), label=f'{nn} samples') mx, my = 0., 0.
axs[jj].scatter(pp.val['a']*scale, pp.val['b'], label=f'{nn} latent mean') for samp in samples:
axs[jj].set_title(nn) a = samp.val['a']
xs.append(a)
for axx in axs: mx += a
axx.scatter(np.array(map_xs)*scale, np.array(map_ys), b = samp.val['b']
label='Laplace samples') ys.append(b)
axx.scatter(MAP.position.val['a']*scale, MAP.position.val['b'], my += b
label='Maximum a posterior solution') mx /= n_samples
my /= n_samples
axx.scatter(np.array(xs)*scale, np.array(ys), label=f'{nn} samples')
axx.scatter(mx*scale, my, label=f'{nn} mean')
axx.scatter(mapx*scale, mapy, label = 'MAP')
axx.scatter(meanx*scale, meany, label = 'Posterior mean')
axx.set_title(nn)
axx.set_xlim(x_limits_scaled) axx.set_xlim(x_limits_scaled)
axx.set_ylim(y_limits) axx.set_ylim(y_limits)
axx.set_ylabel('y')
axx.legend(loc='lower right') axx.legend(loc='lower right')
axs[0].xaxis.set_visible(False) axs[0].xaxis.set_visible(False)
axs[1].set_xlabel('x') axs[1].xaxis.set_visible(False)
axs[1].yaxis.set_visible(False)
axs[2].set_xlabel('x')
axs[2].set_ylabel('y')
axs[3].yaxis.set_visible(False)
axs[3].set_xlabel('x')
plt.tight_layout() plt.tight_layout()
plt.draw() plt.draw()
plt.pause(1.0) plt.pause(2.0)
for ii in range(15):
if ii % 2 == 0:
# Resample
mgkl = ift.MetricGaussianKL(posmg, ham, n_samples, False)
mini_samp = ift.NewtonCG(ift.AbsDeltaEnergyController(1E-8,
iteration_limit=5))
geokl = ift.GeoMetricKL(posgeo, ham, n_samples, mini_samp, False)
runs = (("MGVI", mgkl, posmg, True),
("GeoVI", geokl, posgeo, True),
("MeanfieldVI", mf, posmf, False),
("FullCovarianceVI", fc, posfc, False))
update_plot(runs)
mgkl, _ = minimizer(mgkl) mgkl, _ = minimizer(mgkl)
geokl, _ = minimizer(geokl) geokl, _ = minimizer(geokl)
pos = mgkl.position mf.minimize(stochastic_minimizer_mf)
pos1 = geokl.position fc.minimize(stochastic_minimizer_fc)
posmg = mgkl.position
posgeo = geokl.position
posmf = mf.mean
posfc = fc.mean
runs = (("MGVI", mgkl, posmg, True),
("GeoVI", geokl, posgeo, True),
("MeanfieldVI", mf, posmf, False),
("FullCovarianceVI", fc, posfc, False))
update_plot(runs)
ift.logger.info('Finished') ift.logger.info('Finished')
# Uncomment the following line in order to leave the plots open # Uncomment the following line in order to leave the plots open
# plt.show() # plt.show()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment