Commit 6b98ae93 authored by Philipp Frank's avatar Philipp Frank
Browse files

faster params

parent 36535068
...@@ -82,10 +82,10 @@ def main(): ...@@ -82,10 +82,10 @@ def main():
n_samples = 100 n_samples = 100
minimizer = ift.NewtonCG( minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=3, name='Mini')) ift.GradientNormController(iteration_limit=3, name='Mini'))
IC = ift.StochasticAbsDeltaEnergyController(0.1, iteration_limit=20, IC = ift.StochasticAbsDeltaEnergyController(0.5, iteration_limit=20,
name='advi') name='advi')
stochastic_minimizer_mf = ift.ADVIOptimizer(IC, eta = 0.5) stochastic_minimizer_mf = ift.ADVIOptimizer(IC, eta = 0.3)
stochastic_minimizer_fc = ift.ADVIOptimizer(IC, eta = 0.5) stochastic_minimizer_fc = ift.ADVIOptimizer(IC, eta = 0.3)
posmg = posgeo = posmf = posfc = ift.from_random(ham.domain, 'normal') posmg = posgeo = posmf = posfc = ift.from_random(ham.domain, 'normal')
fc = ift.FullCovarianceVI(posfc, ham, 10, False, initial_sig=0.01) fc = ift.FullCovarianceVI(posfc, ham, 10, False, initial_sig=0.01)
mf = ift.MeanFieldVI(posmf, ham, 10, False, initial_sig=0.01) mf = ift.MeanFieldVI(posmf, ham, 10, False, initial_sig=0.01)
...@@ -96,8 +96,9 @@ def main(): ...@@ -96,8 +96,9 @@ def main():
def update_plot(runs): def update_plot(runs):
for axx, (nn, kl, pp, sam) in zip(axs,runs): for axx, (nn, kl, pp, sam) in zip(axs,runs):
axx.clear() axx.clear()
axx.imshow(z.T, origin='lower', norm=LogNorm(vmin=1e-3, vmax=np.max(z)), axx.imshow(z.T, origin='lower', cmap='gist_earth_r',
cmap='gist_earth_r', extent=x_limits_scaled + y_limits) norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
extent=x_limits_scaled + y_limits)
xs, ys = [], [] xs, ys = [], []
if sam: if sam:
samples = (samp + pp for samp in kl.samples) samples = (samp + pp for samp in kl.samples)
...@@ -113,8 +114,9 @@ def main(): ...@@ -113,8 +114,9 @@ def main():
my += b my += b
mx /= n_samples mx /= n_samples
my /= n_samples my /= n_samples
axx.scatter(np.array(xs)*scale, np.array(ys), label=f'{nn} samples') axx.scatter(np.array(xs)*scale, np.array(ys),
axx.scatter(mx*scale, my, label=f'{nn} mean') label = f'{nn} samples')
axx.scatter(mx*scale, my, label = f'{nn} mean')
axx.scatter(mapx*scale, mapy, label = 'MAP') axx.scatter(mapx*scale, mapy, label = 'MAP')
axx.scatter(meanx*scale, meany, label = 'Posterior mean') axx.scatter(meanx*scale, meany, label = 'Posterior mean')
axx.set_title(nn) axx.set_title(nn)
...@@ -131,13 +133,13 @@ def main(): ...@@ -131,13 +133,13 @@ def main():
plt.tight_layout() plt.tight_layout()
plt.draw() plt.draw()
plt.pause(2.0) plt.pause(2.0)
for ii in range(20): for ii in range(20):
if ii % 2 == 0: if ii % 2 == 0:
# Resample # Resample GeoVI and MGVI
mgkl = ift.MetricGaussianKL(posmg, ham, n_samples, False) mgkl = ift.MetricGaussianKL(posmg, ham, n_samples, False)
mini_samp = ift.NewtonCG(ift.AbsDeltaEnergyController(1E-8, mini_samp = ift.NewtonCG(
iteration_limit=5)) ift.AbsDeltaEnergyController(1E-8, iteration_limit=5))
geokl = ift.GeoMetricKL(posgeo, ham, n_samples, mini_samp, False) geokl = ift.GeoMetricKL(posgeo, ham, n_samples, mini_samp, False)
runs = (("MGVI", mgkl, posmg, True), runs = (("MGVI", mgkl, posmg, True),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment