# mgvi_visualized.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Authors: Reimar Leike, Philipp Arras
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

###############################################################################
# Metric Gaussian Variational Inference (MGVI)
#
# This script demonstrates how MGVI works for an inference problem with only
# two real quantities of interest. This enables us to plot the posterior
# probability density as a two-dimensional plot. The posterior samples
# generated by MGVI are contrasted with the maximum a posteriori (MAP)
# solution together with samples drawn with the Laplace method. This method
# uses the local curvature at the MAP solution as the inverse covariance of a
# Gaussian probability density.
###############################################################################

import numpy as np
import pylab as plt
from matplotlib.colors import LogNorm

import nifty6 as ift

if __name__ == '__main__':
    dom = ift.UnstructuredDomain(1)
    # Stretch factor of the first latent coordinate: the quantity shown on the
    # plots' horizontal axis is ``scale * a``.
    scale = 10

    # Two scalar latent parameters. ``a`` (times ``scale``) is passed under the
    # key 'a' and exp(-2.7*b) under the key 'b' to the variable-covariance
    # Gaussian energy; ``np_ham`` below spells out the resulting expression.
    a = ift.FieldAdapter(dom, 'a')
    b = ift.FieldAdapter(dom, 'b')
    lh = (a.adjoint @ a).scale(scale) + (b.adjoint @ b).scale(-1.35*2).exp()
    lh = ift.VariableCovarianceGaussianEnergy(dom, 'a', 'b', np.float64) @ lh
    icsamp = ift.AbsDeltaEnergyController(deltaE=0.1, iteration_limit=2)
    ham = ift.StandardHamiltonian(lh, icsamp)

    # Grid for the brute-force reference posterior. ``x`` is in latent units of
    # ``a``; ``x_limits_scaled`` is the same range after multiplying by scale.
    x_limits = [-8/scale, 8/scale]
    x_limits_scaled = [-8, 8]
    y_limits = [-4, 4]
    x = np.linspace(*x_limits, num=401)
    y = np.linspace(*y_limits, num=401)
    xx, yy = np.meshgrid(x, y, indexing='ij')

    def np_ham(x, y):
        """Evaluate the Hamiltonian with NumPy at latent point (a, b) = (x, y).

        This must agree with ``ham`` above so that the background image shows
        the posterior that MGVI is actually approximating.
        """
        # Standard-normal prior on both latent coordinates, written with the
        # same 0.5 Gaussian convention as the likelihood term below
        # (StandardHamiltonian's prior term is 0.5*|xi|^2).
        prior = 0.5*(x**2 + y**2)
        mean = x*scale
        lcov = 1.35*2*y
        # 0.5*(r^2 * exp(-lcov) + lcov): Gaussian negative log-likelihood with
        # residual r = scale*x and log-variance lcov = 2.7*y.
        lh = .5*(mean**2*np.exp(-lcov) + lcov)
        return lh + prior

    # Unnormalized posterior density on the grid; only relative values are
    # used (log-scaled image and marginal shapes).
    z = np.exp(-np_ham(xx, yy))

    def show_marginal(coords, density, label):
        """Briefly display one 1D marginal density, then close its figure."""
        plt.plot(coords, density)
        plt.xlabel(label)
        plt.ylabel('pdf')
        plt.title('Marginal density')
        plt.pause(2.0)
        plt.close()

    # With indexing='ij', axis 0 of ``z`` runs over x and axis 1 over y, so
    # summing over one axis marginalizes out that coordinate.
    show_marginal(y, np.sum(z, axis=0), 'y')
    show_marginal(x*scale, np.sum(z, axis=1), 'x')

    # Maximum a posteriori estimate, starting from a random latent position.
    pos = ift.from_random('normal', ham.domain)
    MAP = ift.EnergyAdapter(pos, ham, want_metric=True)
    minimizer = ift.NewtonCG(
        ift.GradientNormController(iteration_limit=20, name='Mini'))
    MAP, _ = minimizer(MAP)

    # Laplace samples: draws from the inverse metric at the MAP position,
    # shifted to that position.
    map_xs, map_ys = [], []
    for _ in range(10):
        samp = (MAP.metric.draw_sample(from_inverse=True) + MAP.position).val
        map_xs.append(samp['a'])
        map_ys.append(samp['b'])

    # MGVI: a few Newton-CG steps per iteration on the sampled KL.
    minimizer = ift.NewtonCG(
        ift.GradientNormController(iteration_limit=2, name='Mini'))
    pos = ift.from_random('normal', ham.domain)
    plt.figure(figsize=[12, 8])
    for ii in range(15):
        # Draw a fresh set of 40 samples around the current latent mean every
        # third iteration; in between, the KL object from the last
        # minimization step is reused.
        if ii % 3 == 0:
            mgkl = ift.MetricGaussianKL(pos, ham, 40)

        plt.cla()
        # vmin/vmax are given to the norm itself: newer matplotlib versions
        # reject passing them alongside ``norm``.
        plt.imshow(z.T, origin='lower',
                   norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
                   cmap='gist_earth_r',
                   extent=x_limits_scaled + y_limits)
        if ii == 0:
            # Created once; plt.cla() clears the image axes, not the colorbar
            # axes, so the colorbar and its label stay in place.
            cbar = plt.colorbar()
            cbar.ax.set_ylabel('pdf')

        # MGVI samples are stored relative to the latent mean.
        xs, ys = [], []
        for residual in mgkl.samples:
            sample = (residual + pos).val
            xs.append(sample['a'])
            ys.append(sample['b'])

        plt.scatter(np.array(xs)*scale, np.array(ys), label='MGVI samples')
        plt.scatter(pos.val['a']*scale, pos.val['b'], label='MGVI latent mean')
        plt.scatter(np.array(map_xs)*scale, np.array(map_ys),
                    label='Laplace samples')
        plt.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
                    label='Maximum a posteriori solution')
        plt.legend()
        plt.draw()
        plt.pause(1.0)

        mgkl, _ = minimizer(mgkl)
        pos = mgkl.position
    ift.logger.info('Finished')
    # Uncomment the following line in order to leave the plots open
    # plt.show()