From a3e05700cf4bd854351889a8f48c515b1ed5899c Mon Sep 17 00:00:00 2001
From: Philipp Frank <philipp@mpa-garching.mpg.de>
Date: Mon, 7 Nov 2022 09:01:44 -0500
Subject: [PATCH] vi update params

---
 variational_inference_visualized.py | 217 ++++++++++++++++++++++++++++
 1 file changed, 217 insertions(+)
 create mode 100644 variational_inference_visualized.py

diff --git a/variational_inference_visualized.py b/variational_inference_visualized.py
new file mode 100644
index 0000000..14a3286
--- /dev/null
+++ b/variational_inference_visualized.py
@@ -0,0 +1,217 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Copyright(C) 2013-2021 Max-Planck-Society
+# Authors: Reimar Leike, Philipp Arras, Philipp Frank
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
+
+###############################################################################
+# Variational Inference (VI)
+#
+# This script demonstrates how MGVI, GeoVI, MeanfieldVI and FullCovarianceVI
+# work for an inference problem with only two real quantities of interest. This
+# enables us to plot the posterior probability density as two-dimensional plot.
+###############################################################################
+
+import numpy as np
+import matplotlib.pyplot as plt
+from functools import partial
+from matplotlib.colors import LogNorm
+
+import nifty8 as ift
+
+
+def main():
+    dom = ift.UnstructuredDomain(1)
+
+    scale = 10.
+
+    def transformation(x,y):
+        e = x.exp() if isinstance(x, ift.Operator) else np.exp(x)
+        return scale * e * y
+
+    def jac_transformation(x,y):
+        d = scale * np.exp(x)
+        return np.stack((d*y, d), axis = -1)
+
+    def metric(x,y):
+        jac = jac_transformation(x,y)
+        met = np.einsum('ij,ik -> ijk', jac, jac)
+        met += np.multiply.outer(np.ones(x.shape), np.eye(2))
+        return met
+
+    def metric_function(met, fun):
+        v, U = np.linalg.eigh(met)
+        fv = fun(v)
+        return np.einsum('ijk, ik, ilk -> ijl', U, fv, U)
+
+    def geo_transformation(x,y, x0, y0):
+        t = transformation(x, y)
+        t0 = transformation(x0, y0)
+        j0 = jac_transformation(x0, y0)
+        m0 = metric(x0, y0)
+        s = np.stack((x, y), axis = -1)
+        s0 = np.stack((x0, y0), axis = -1)
+        g = s - s0 + (j0.T * (t - t0)).T
+        inv_sq = metric_function(m0, lambda k: 1./np.sqrt(k))
+        return np.einsum('ijk, ik -> ij', inv_sq, g)
+
+    def jac_geo_trafo(x,y,x0,y0):
+        j = jac_transformation(x,y)
+        j0 = jac_transformation(x0,y0)
+        m0 = metric(x0,y0)
+        inv_sq = metric_function(m0, lambda k: 1./np.sqrt(k))
+        jg = np.multiply.outer(np.ones_like(x), np.eye(2))
+        jg += np.einsum('ij, ik -> ijk', j0, j)
+        return np.einsum('ijk, ikl -> ijl', inv_sq, jg)
+
+    def mg_prob(x,y, x0,y0, a0,b0):
+        shp = x.shape
+        x = x.flatten()
+        y = y.flatten()
+        x0 = np.ones_like(x) * x0
+        y0 = np.ones_like(y) * y0
+        a0 = np.ones_like(x) * a0
+        b0 = np.ones_like(y) * b0
+        metric0 = metric(x0, y0)
+        s = np.stack((x, y), axis = -1)
+        s = s - np.stack((a0, b0), axis = -1)
+        res = np.einsum('ij, ijk, ik -> i', s, metric0, s)
+        return np.exp(-0.5*res).reshape(shp)
+
+    def geo_prob(x,y,x0,y0,a0,b0):
+        shp = x.shape
+        x = x.flatten()
+        y = y.flatten()
+        x0 = np.ones_like(x) * x0
+        y0 = np.ones_like(y) * y0
+        a0 = np.ones_like(x) * a0
+        b0 = np.ones_like(y) * b0
+        x = x - a0 + x0
+        y = y - b0 + y0
+        g = geo_transformation(x,y,x0,y0)
+        jg = jac_geo_trafo(x,y,x0,y0)
+        res = np.einsum('ij,ij -> i', g, g)
+        mymet = np.einsum('ikj, ikl -> ijl', jg, jg)
+        det = np.linalg.det(mymet)
+        res -= np.log(det)
+        return np.exp(-0.5*res).reshape(shp)
+
+        
+
+    a = ift.FieldAdapter(dom, 'a')
+    b = ift.FieldAdapter(dom, 'b')
+    model = transformation(a, b)
+    data = ift.full(dom, 2.)
+    lh = ift.GaussianEnergy(data=data) @ model
+    icsamp = ift.AbsDeltaEnergyController(deltaE=0.1, iteration_limit=2)
+    ham = ift.StandardHamiltonian(lh, icsamp)
+
+    x_limits = [-6, 6]
+    y_limits = [-6, 6]
+    x = np.linspace(*x_limits, num=401)
+    y = np.linspace(*y_limits, num=401)
+    xx, yy = np.meshgrid(x, y, indexing='ij')
+
+    pdfs = [mg_prob, geo_prob]
+    pdfs = [partial(p, xx, yy) for p in pdfs]
+
+    def np_ham(x, y):
+        prior = x**2 + y**2
+        mean = transformation(x, y)
+        d = data.val[0]
+        lh = .5*(d - mean)**2
+        return lh + prior
+
+    z = np.exp(-1.*np_ham(xx, yy))
+    z /= np.max(z)
+
+    mapx = xx[z == np.max(z)]
+    mapy = yy[z == np.max(z)]
+    meanx = (xx*z).sum()/z.sum()
+    meany = (yy*z).sum()/z.sum()
+
+    fig, axs = plt.subplots(1, 2, figsize=[12, 8])
+    axs = axs.flatten()
+
+    def update_plot(runs):
+        for axx, (nn, kl, m), prob in zip(axs, runs, pdfs):
+            axx.clear()
+            axx.imshow(z.T, origin='lower',  cmap='gist_earth_r',
+                       norm=LogNorm(vmin=1e-4, vmax=np.max(z)),
+                       extent=x_limits + y_limits)
+
+            mx, my = m['a'].val[0], m['b'].val[0]
+            mm = kl.position
+            ax, ay = mm['a'].val[0], mm['b'].val[0]
+            p = prob(mx, my, ax, ay)
+            p[p == np.nan] = 0.
+            axx.contour(xx, yy, p, levels=np.linspace(0,np.max(p),11))
+
+            samples = kl.samples.iterator()
+            samples = [[s.val['a'][0], s.val['b'][0]] for s in samples]
+            samples = np.array(samples)
+            mmx = np.sum(xx*p)/np.sum(p)
+            mmy = np.sum(yy*p)/np.sum(p)
+
+            axx.scatter(samples[:,0], samples[:,1],
+                        label=f'{nn} samples')
+            axx.scatter(mmx, mmy, label=f'{nn} mean')
+            axx.scatter(mapx, mapy, label='MAP')
+            axx.scatter(meanx, meany, label='Posterior mean')
+            axx.set_title(nn)
+            axx.set_xlim(x_limits)
+            axx.set_ylim(y_limits)
+            axx.legend(loc='lower right')
+        axs[1].yaxis.set_visible(False)
+        axs[0].set_xlabel('x')
+        axs[0].set_ylabel('y')
+        axs[1].set_xlabel('x')
+        fig.tight_layout()
+        plt.draw()
+        plt.pause(2.)
+
+
+    n_samples = 20
+    minimizer = ift.NewtonCG(
+        ift.GradientNormController(iteration_limit=1, name='Mini'))
+    posmg = ift.full(ham.domain, -5.)
+    posgeo = ift.full(ham.domain, -5.)
+
+    for ii in range(30):
+        if ii % 3 == 0:
+            # Resample GeoVI and MGVI
+            mgkl = ift.SampledKLEnergy(posmg, ham, n_samples, None, True)
+            mini_samp = ift.NewtonCG(
+                    ift.AbsDeltaEnergyController(1E-8, iteration_limit=5))
+            geokl = ift.SampledKLEnergy(posgeo, ham, n_samples, mini_samp, 
+                                        True)
+            mg_m = mgkl.position
+            geo_m = geokl.position
+            runs = (("MGVI", mgkl, mg_m), ("GeoVI", geokl, geo_m))
+            update_plot(runs)
+
+        mgkl, _ = minimizer(mgkl)
+        geokl, _ = minimizer(geokl)
+        posmg = mgkl.position
+        posgeo = geokl.position
+        runs = (("MGVI", mgkl, mg_m), ("GeoVI", geokl, geo_m))
+        update_plot(runs)
+    ift.logger.info('Finished')
+    # Uncomment the following line in order to leave the plots open
+    plt.show()
+
+
+if __name__ == '__main__':
+    main()
-- 
GitLab