From a3e05700cf4bd854351889a8f48c515b1ed5899c Mon Sep 17 00:00:00 2001 From: Philipp Frank <philipp@mpa-garching.mpg.de> Date: Mon, 7 Nov 2022 09:01:44 -0500 Subject: [PATCH] vi update params --- variational_inference_visualized.py | 217 ++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 variational_inference_visualized.py diff --git a/variational_inference_visualized.py b/variational_inference_visualized.py new file mode 100644 index 0000000..14a3286 --- /dev/null +++ b/variational_inference_visualized.py @@ -0,0 +1,217 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2021 Max-Planck-Society +# Authors: Reimar Leike, Philipp Arras, Philipp Frank +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. + +############################################################################### +# Variational Inference (VI) +# +# This script demonstrates how MGVI, GeoVI, MeanfieldVI and FullCovarianceVI +# work for an inference problem with only two real quantities of interest. This +# enables us to plot the posterior probability density as two-dimensional plot. +############################################################################### + +import numpy as np +import matplotlib.pyplot as plt +from functools import partial +from matplotlib.colors import LogNorm + +import nifty8 as ift + + +def main(): + dom = ift.UnstructuredDomain(1) + + scale = 10. + + def transformation(x,y): + e = x.exp() if isinstance(x, ift.Operator) else np.exp(x) + return scale * e * y + + def jac_transformation(x,y): + d = scale * np.exp(x) + return np.stack((d*y, d), axis = -1) + + def metric(x,y): + jac = jac_transformation(x,y) + met = np.einsum('ij,ik -> ijk', jac, jac) + met += np.multiply.outer(np.ones(x.shape), np.eye(2)) + return met + + def metric_function(met, fun): + v, U = np.linalg.eigh(met) + fv = fun(v) + return np.einsum('ijk, ik, ilk -> ijl', U, fv, U) + + def geo_transformation(x,y, x0, y0): + t = transformation(x, y) + t0 = transformation(x0, y0) + j0 = jac_transformation(x0, y0) + m0 = metric(x0, y0) + s = np.stack((x, y), axis = -1) + s0 = np.stack((x0, y0), axis = -1) + g = s - s0 + (j0.T * (t - t0)).T + inv_sq = metric_function(m0, lambda k: 1./np.sqrt(k)) + return np.einsum('ijk, ik -> ij', inv_sq, g) + + def jac_geo_trafo(x,y,x0,y0): + j = jac_transformation(x,y) + j0 = jac_transformation(x0,y0) + m0 = metric(x0,y0) + inv_sq = metric_function(m0, lambda k: 1./np.sqrt(k)) + jg = np.multiply.outer(np.ones_like(x), np.eye(2)) + jg += np.einsum('ij, ik -> ijk', j0, j) + return np.einsum('ijk, ikl -> ijl', inv_sq, jg) + + def mg_prob(x,y, x0,y0, a0,b0): + shp = x.shape + x = x.flatten() + y = y.flatten() + x0 = np.ones_like(x) * x0 + y0 = np.ones_like(y) * y0 + a0 = np.ones_like(x) * a0 + b0 = np.ones_like(y) * b0 + metric0 = metric(x0, y0) + s = np.stack((x, y), axis = -1) + s = s - np.stack((a0, b0), axis = -1) + res = np.einsum('ij, ijk, ik -> i', s, metric0, s) + return np.exp(-0.5*res).reshape(shp) + + def geo_prob(x,y,x0,y0,a0,b0): + shp = x.shape + x = x.flatten() + y = y.flatten() + x0 = np.ones_like(x) * x0 + y0 = np.ones_like(y) * y0 + a0 = np.ones_like(x) * a0 + b0 = np.ones_like(y) * b0 + x = x - a0 + x0 + y = y - b0 + y0 + g = geo_transformation(x,y,x0,y0) + jg = jac_geo_trafo(x,y,x0,y0) + res = np.einsum('ij,ij -> i', g, g) + mymet = np.einsum('ikj, ikl -> ijl', jg, jg) + det = np.linalg.det(mymet) + res -= np.log(det) + return np.exp(-0.5*res).reshape(shp) + + + + a = ift.FieldAdapter(dom, 'a') + b = ift.FieldAdapter(dom, 'b') + model = transformation(a, b) + data = ift.full(dom, 2.) + lh = ift.GaussianEnergy(data=data) @ model + icsamp = ift.AbsDeltaEnergyController(deltaE=0.1, iteration_limit=2) + ham = ift.StandardHamiltonian(lh, icsamp) + + x_limits = [-6, 6] + y_limits = [-6, 6] + x = np.linspace(*x_limits, num=401) + y = np.linspace(*y_limits, num=401) + xx, yy = np.meshgrid(x, y, indexing='ij') + + pdfs = [mg_prob, geo_prob] + pdfs = [partial(p, xx, yy) for p in pdfs] + + def np_ham(x, y): + prior = x**2 + y**2 + mean = transformation(x, y) + d = data.val[0] + lh = .5*(d - mean)**2 + return lh + prior + + z = np.exp(-1.*np_ham(xx, yy)) + z /= np.max(z) + + mapx = xx[z == np.max(z)] + mapy = yy[z == np.max(z)] + meanx = (xx*z).sum()/z.sum() + meany = (yy*z).sum()/z.sum() + + fig, axs = plt.subplots(1, 2, figsize=[12, 8]) + axs = axs.flatten() + + def update_plot(runs): + for axx, (nn, kl, m), prob in zip(axs, runs, pdfs): + axx.clear() + axx.imshow(z.T, origin='lower', cmap='gist_earth_r', + norm=LogNorm(vmin=1e-4, vmax=np.max(z)), + extent=x_limits + y_limits) + + mx, my = m['a'].val[0], m['b'].val[0] + mm = kl.position + ax, ay = mm['a'].val[0], mm['b'].val[0] + p = prob(mx, my, ax, ay) + p[p == np.nan] = 0. + axx.contour(xx, yy, p, levels=np.linspace(0,np.max(p),11)) + + samples = kl.samples.iterator() + samples = [[s.val['a'][0], s.val['b'][0]] for s in samples] + samples = np.array(samples) + mmx = np.sum(xx*p)/np.sum(p) + mmy = np.sum(yy*p)/np.sum(p) + + axx.scatter(samples[:,0], samples[:,1], + label=f'{nn} samples') + axx.scatter(mmx, mmy, label=f'{nn} mean') + axx.scatter(mapx, mapy, label='MAP') + axx.scatter(meanx, meany, label='Posterior mean') + axx.set_title(nn) + axx.set_xlim(x_limits) + axx.set_ylim(y_limits) + axx.legend(loc='lower right') + axs[1].yaxis.set_visible(False) + axs[0].set_xlabel('x') + axs[0].set_ylabel('y') + axs[1].set_xlabel('x') + fig.tight_layout() + plt.draw() + plt.pause(2.) + + + n_samples = 20 + minimizer = ift.NewtonCG( + ift.GradientNormController(iteration_limit=1, name='Mini')) + posmg = ift.full(ham.domain, -5.) + posgeo = ift.full(ham.domain, -5.) + + for ii in range(30): + if ii % 3 == 0: + # Resample GeoVI and MGVI + mgkl = ift.SampledKLEnergy(posmg, ham, n_samples, None, True) + mini_samp = ift.NewtonCG( + ift.AbsDeltaEnergyController(1E-8, iteration_limit=5)) + geokl = ift.SampledKLEnergy(posgeo, ham, n_samples, mini_samp, + True) + mg_m = mgkl.position + geo_m = geokl.position + runs = (("MGVI", mgkl, mg_m), ("GeoVI", geokl, geo_m)) + update_plot(runs) + + mgkl, _ = minimizer(mgkl) + geokl, _ = minimizer(geokl) + posmg = mgkl.position + posgeo = geokl.position + runs = (("MGVI", mgkl, mg_m), ("GeoVI", geokl, geo_m)) + update_plot(runs) + ift.logger.info('Finished') + # Uncomment the following line in order to leave the plots open + plt.show() + + +if __name__ == '__main__': + main() -- GitLab