mgvi_visualized.py 4.43 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Authors: Reimar Leike, Philipp Arras
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

Philipp Arras's avatar
Philipp Arras committed
19
20
21
22
23
24
25
26
27
28
29
30
###############################################################################
# Metric Gaussian Variational Inference (MGVI)
#
# This script demonstrates how MGVI works for an inference problem with only
# two real quantities of interest. This enables us to plot the posterior
# probability density as a two-dimensional plot. The posterior samples
# generated by MGVI are contrasted with the maximum-a-posteriori (MAP)
# solution together with samples drawn with the Laplace method. This method
# uses the local curvature at the MAP solution as inverse covariance of a
# Gaussian probability density.
###############################################################################

31
32
import numpy as np
import pylab as plt
Philipp Arras's avatar
Philipp Arras committed
33
from matplotlib.colors import LogNorm
34
35
36
37
38

import nifty6 as ift

if __name__ == '__main__':
    # Toy model with two scalar parameters stored under the keys 'a' and 'b',
    # each living on a one-element unstructured domain.
    dom = ift.UnstructuredDomain(1)
    scale = 10

    a = ift.FieldAdapter(dom, 'a')
    b = ift.FieldAdapter(dom, 'b')
    # 'a' is multiplied by `scale`; 'b' is multiplied by -2.7 and
    # exponentiated. Both results are handed to the energy below under the
    # same keys.
    lh = (a.adjoint @ a).scale(scale) + (b.adjoint @ b).scale(-1.35*2).exp()
    lh = ift.VariableCovarianceGaussianEnergy(dom, 'a', 'b', np.float64) @ lh
    icsamp = ift.AbsDeltaEnergyController(deltaE=0.1, iteration_limit=2)
    ham = ift.StandardHamiltonian(lh, icsamp)

    # Evaluation grid. `x` is kept in unscaled units, while every plot shows
    # x*scale; hence the two sets of x limits.
    x_limits = [-8/scale, 8/scale]
    x_limits_scaled = [-8, 8]
    y_limits = [-4, 4]
    x = np.linspace(*x_limits, num=401)
    y = np.linspace(*y_limits, num=401)
    xx, yy = np.meshgrid(x, y, indexing='ij')

    def np_ham(x, y):
        """Evaluate the plotted Hamiltonian with plain numpy.

        Parameters
        ----------
        x, y : float or numpy.ndarray
            Values of the parameters 'a' and 'b'; arrays are combined
            element-wise.

        Returns
        -------
        float or numpy.ndarray
            Sum of the likelihood and prior terms. `exp(-np_ham(x, y))` is
            the unnormalized density shown in the plots.
        """
        # NOTE(review): no factor 1/2 here, unlike the likelihood term below.
        # Confirm this matches the prior used by ift.StandardHamiltonian.
        prior = x**2 + y**2
        mean = x*scale
        lcov = 1.35*2*y
        likelihood = .5*(mean**2*np.exp(-lcov) + lcov)
        return likelihood + prior

    z = np.exp(-1*np_ham(xx, yy))

    # With indexing='ij' axis 0 of `z` runs over x and axis 1 over y, so
    # summing over axis 0 leaves a function of y and vice versa.
    plt.plot(y, np.sum(z, axis=0))
    plt.xlabel('y')
    plt.ylabel('unnormalized pdf')
    plt.title('Marginal density')
    plt.pause(2.0)
    plt.close()

    plt.plot(x*scale, np.sum(z, axis=1))
    plt.xlabel('x')
    plt.ylabel('unnormalized pdf')
    plt.title('Marginal density')
    plt.pause(2.0)
    plt.close()

    # MAP estimate: minimize the Hamiltonian from a random starting position,
    # then draw ten samples from the inverse metric around the result (the
    # 'Laplace samples' in the plot).
    pos = ift.from_random(ham.domain, 'normal')
    MAP = ift.EnergyAdapter(pos, ham, want_metric=True)
    minimizer = ift.NewtonCG(
        ift.GradientNormController(iteration_limit=20, name='Mini'))
    MAP, _ = minimizer(MAP)
    map_xs, map_ys = [], []
    for _ in range(10):
        samp = (MAP.metric.draw_sample(from_inverse=True) + MAP.position).val
        map_xs.append(samp['a'])
        map_ys.append(samp['b'])

    # MGVI: restart from a fresh random position and use only two Newton
    # steps per outer iteration so the progress can be followed in the plot.
    minimizer = ift.NewtonCG(
        ift.GradientNormController(iteration_limit=2, name='Mini'))
    pos = ift.from_random(ham.domain, 'normal')
    plt.figure(figsize=[12, 8])
    for ii in range(15):
        if ii % 3 == 0:
            # A new MetricGaussianKL is built every third iteration; in
            # between, the object returned by the minimizer is reused.
            # (40 is presumably the number of samples -- TODO confirm.)
            mgkl = ift.MetricGaussianKL(pos, ham, 40)

        plt.cla()
        # The colour limits are given to the norm itself: passing `norm`
        # together with `vmin`/`vmax` to imshow is deprecated since
        # Matplotlib 3.3 and rejected by later releases.
        plt.imshow(z.T, origin='lower',
                   norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
                   cmap='gist_earth_r',
                   extent=x_limits_scaled + y_limits)
        if ii == 0:
            # The background image never changes, so one colorbar suffices.
            cbar = plt.colorbar()
            cbar.ax.set_ylabel('pdf')

        xs, ys = [], []
        for residual in mgkl.samples:
            # Samples are added to the current position before plotting.
            shifted = (residual + pos).val
            xs.append(shifted['a'])
            ys.append(shifted['b'])
        plt.scatter(np.array(xs)*scale, np.array(ys), label='MGVI samples')
        plt.scatter(pos.val['a']*scale, pos.val['b'], label='MGVI latent mean')
        plt.scatter(np.array(map_xs)*scale, np.array(map_ys),
                    label='Laplace samples')
        plt.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
                    label='Maximum a posterior solution')
        plt.legend()
        plt.draw()
        plt.pause(1.0)

        mgkl, _ = minimizer(mgkl)
        pos = mgkl.position
    ift.logger.info('Finished')
    # Uncomment the following line in order to leave the plots open
    # plt.show()