Commit 6b98ae93 authored by Philipp Frank's avatar Philipp Frank
Browse files

faster params

parent 36535068
......@@ -82,10 +82,10 @@ def main():
n_samples = 100
minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=3, name='Mini'))
IC = ift.StochasticAbsDeltaEnergyController(0.1, iteration_limit=20,
IC = ift.StochasticAbsDeltaEnergyController(0.5, iteration_limit=20,
name='advi')
stochastic_minimizer_mf = ift.ADVIOptimizer(IC, eta = 0.5)
stochastic_minimizer_fc = ift.ADVIOptimizer(IC, eta = 0.5)
stochastic_minimizer_mf = ift.ADVIOptimizer(IC, eta = 0.3)
stochastic_minimizer_fc = ift.ADVIOptimizer(IC, eta = 0.3)
posmg = posgeo = posmf = posfc = ift.from_random(ham.domain, 'normal')
fc = ift.FullCovarianceVI(posfc, ham, 10, False, initial_sig=0.01)
mf = ift.MeanFieldVI(posmf, ham, 10, False, initial_sig=0.01)
......@@ -96,8 +96,9 @@ def main():
def update_plot(runs):
for axx, (nn, kl, pp, sam) in zip(axs,runs):
axx.clear()
axx.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
axx.imshow(z.T, origin='lower', cmap='gist_earth_r',
norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
extent=x_limits_scaled + y_limits)
xs, ys = [], []
if sam:
samples = (samp + pp for samp in kl.samples)
......@@ -113,8 +114,9 @@ def main():
my += b
mx /= n_samples
my /= n_samples
axx.scatter(np.array(xs)*scale, np.array(ys), label=f'{nn} samples')
axx.scatter(mx*scale, my, label=f'{nn} mean')
axx.scatter(np.array(xs)*scale, np.array(ys),
label = f'{nn} samples')
axx.scatter(mx*scale, my, label = f'{nn} mean')
axx.scatter(mapx*scale, mapy, label = 'MAP')
axx.scatter(meanx*scale, meany, label = 'Posterior mean')
axx.set_title(nn)
......@@ -131,13 +133,13 @@ def main():
plt.tight_layout()
plt.draw()
plt.pause(2.0)
for ii in range(20):
if ii % 2 == 0:
# Resample
# Resample GeoVI and MGVI
mgkl = ift.MetricGaussianKL(posmg, ham, n_samples, False)
mini_samp = ift.NewtonCG(ift.AbsDeltaEnergyController(1E-8,
iteration_limit=5))
mini_samp = ift.NewtonCG(
ift.AbsDeltaEnergyController(1E-8, iteration_limit=5))
geokl = ift.GeoMetricKL(posgeo, ham, n_samples, mini_samp, False)
runs = (("MGVI", mgkl, posmg, True),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment