Skip to content
Snippets Groups Projects
Commit bdd70d17 authored by Philipp Frank's avatar Philipp Frank
Browse files

vi more visualize

parent edd87d5f
Branches
No related tags found
No related merge requests found
Pipeline #147581 passed
......@@ -35,6 +35,8 @@ import nifty8 as ift
def main():
dom = ift.UnstructuredDomain(1)
n_samples = 20
show_geo = True
scale = 10.
def transformation(x,y):
......@@ -108,8 +110,6 @@ def main():
res -= np.log(det)
return np.exp(-0.5*res).reshape(shp)
a = ift.FieldAdapter(dom, 'a')
b = ift.FieldAdapter(dom, 'b')
model = transformation(a, b)
......@@ -145,19 +145,23 @@ def main():
fig, axs = plt.subplots(1, 2, figsize=[12, 8])
axs = axs.flatten()
def update_plot(runs):
def update_plot(runs, start = False):
if not show_geo:
runs = runs[:1]
for axx, (nn, kl, m), prob in zip(axs, runs, pdfs):
axx.clear()
axx.imshow(z.T, origin='lower', cmap='gist_earth_r',
norm=LogNorm(vmin=1e-4, vmax=np.max(z)),
extent=x_limits + y_limits)
extent=x_limits + y_limits, zorder=0)
mx, my = m['a'].val[0], m['b'].val[0]
mm = kl.position
ax, ay = mm['a'].val[0], mm['b'].val[0]
p = prob(mx, my, ax, ay)
p[p == np.nan] = 0.
axx.contour(xx, yy, p, levels=np.linspace(0,np.max(p),11))
p[np.isnan(p)] = 0.
p[np.isinf(p)] = 0.
axx.contour(xx, yy, p, levels=np.linspace(1E-4,np.max(p),15),
zorder=1)
samples = kl.samples.iterator()
samples = [[s.val['a'][0], s.val['b'][0]] for s in samples]
......@@ -166,10 +170,11 @@ def main():
mmy = np.sum(yy*p)/np.sum(p)
axx.scatter(samples[:,0], samples[:,1],
label=f'{nn} samples')
axx.scatter(mmx, mmy, label=f'{nn} mean')
axx.scatter(mapx, mapy, label='MAP')
axx.scatter(meanx, meany, label='Posterior mean')
label=f'{nn} samples', zorder = 2)
axx.scatter(mmx, mmy, label=f'{nn} mean', zorder = 2)
axx.scatter(mapx, mapy, label='MAP', zorder = 2)
axx.scatter(meanx, meany, label='Posterior mean', zorder = 2)
axx.scatter(mx, my, label='Sampling geometry mean', zorder = 2)
axx.set_title(nn)
axx.set_xlim(x_limits)
axx.set_ylim(y_limits)
......@@ -180,10 +185,9 @@ def main():
axs[1].set_xlabel('x')
fig.tight_layout()
plt.draw()
plt.pause(2.)
plt.pause(2. if not start else 5.)
n_samples = 20
minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=1, name='Mini'))
posmg = ift.full(ham.domain, -5.)
......@@ -200,7 +204,7 @@ def main():
mg_m = mgkl.position
geo_m = geokl.position
runs = (("MGVI", mgkl, mg_m), ("GeoVI", geokl, geo_m))
update_plot(runs)
update_plot(runs, start = ii == 0)
mgkl, _ = minimizer(mgkl)
geokl, _ = minimizer(geokl)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment