Skip to content
Snippets Groups Projects
Commit fd6c6ade authored by Reimar Leike's avatar Reimar Leike
Browse files

Use NewtonCG, Show marginal densities

parent ed5c4f67
Branches
No related tags found
1 merge request!468Visual demo
Pipeline #75232 passed
......@@ -32,8 +32,8 @@ if __name__ == '__main__':
icsamp = ift.AbsDeltaEnergyController(deltaE=0.1, iteration_limit=2)
ham = ift.StandardHamiltonian(lh, icsamp)
x_limits = [-2/uninformative_scaling, 2/uninformative_scaling]
y_limits = [-4, 2]
x_limits = [-5/uninformative_scaling, 5/uninformative_scaling]
y_limits = [-4, 4]
x = np.linspace(*x_limits, num=101)
y = np.linspace(*y_limits, num=101)
z = np.empty((x.size, y.size))
......@@ -41,10 +41,20 @@ if __name__ == '__main__':
for jj, yy in enumerate(y):
inp = ift.MultiField.from_raw(lh.domain, {'a': xx, 'b': yy})
z[ii, jj] = np.exp(-1*ham(inp).val)
plt.plot(y, np.sum(z, axis=0))
plt.xlabel('y')
plt.ylabel('pdf')
plt.title('marginal density')
plt.show()
plt.plot(x*uninformative_scaling, np.sum(z, axis=1))
plt.xlabel('x')
plt.ylabel('pdf')
plt.title('marginal density')
plt.show()
pos = ift.from_random('normal', ham.domain)
MAP = ift.EnergyAdapter(pos, ham, want_metric=True)
minimizer = ift.SteepestDescent(ift.GradientNormController(iteration_limit=20,
minimizer = ift.NewtonCG(ift.GradientNormController(iteration_limit=20,
name='Mini'))
MAP, _ = minimizer(MAP)
map_xs, map_ys = [], []
......@@ -53,15 +63,19 @@ if __name__ == '__main__':
map_xs.append(samp['a'])
map_ys.append(samp['b'])
minimizer = ift.SteepestDescent(
ift.GradientNormController(iteration_limit=1, name='Mini'))
minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=2, name='Mini'))
pos = ift.from_random('normal', ham.domain)
for ii in range(10):
if ii % 2 == 0:
for ii in range(15):
if ii % 3 == 0:
mgkl = ift.MetricGaussianKL(pos, ham, 40)
plt.cla()
plt.contour(x*uninformative_scaling, y, z.T)
plt.imshow(z.T, origin='lower',
extent=(x_limits[0]*uninformative_scaling,
x_limits[1]*uninformative_scaling)+tuple(y_limits), vmin=0., vmax=4)
if ii==0:
plt.colorbar()
xs, ys = [], []
for samp in mgkl.samples:
samp = (samp + pos).val
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment