variational_inference_visualized.py 6.15 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
15
# Copyright(C) 2013-2021 Max-Planck-Society
# Authors: Reimar Leike, Philipp Arras, Philipp Frank
16
17
18
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

Philipp Arras's avatar
Philipp Arras committed
19
###############################################################################
20
# Variational Inference (VI)
Philipp Arras's avatar
Philipp Arras committed
21
#
22
23
24
25
26
27
28
# This script demonstrates how MGVI and GeoVI work for an inference problem
# with only two real quantities of interest. This enables us to plot the
# posterior probability density as a two-dimensional plot. The approximate
# posterior samples are contrasted with the maximum a posteriori (MAP)
# solution together with samples drawn with the Laplace method. This method
# uses the local curvature at the MAP solution as inverse covariance of a
# Gaussian probability density.
Philipp Arras's avatar
Philipp Arras committed
29
30
###############################################################################

31
32
import numpy as np
import pylab as plt
Philipp Arras's avatar
Philipp Arras committed
33
from matplotlib.colors import LogNorm
34

Martin Reinecke's avatar
merge    
Martin Reinecke committed
35
import nifty7 as ift
36

37

Philipp Arras's avatar
Philipp Arras committed
38
def main():
    """Run the two-parameter variational-inference comparison demo.

    A Hamiltonian over the two scalar parameters ``a`` and ``b`` is built.
    Alongside it, a plain-NumPy energy (``np_ham``) is evaluated on a
    401x401 grid to obtain an unnormalized density for plotting. Both
    marginals of that density are shown briefly. Afterwards four
    approximations (MGVI, GeoVI, mean-field VI, full-covariance VI) are
    optimized for 15 iterations, and their samples are redrawn on top of
    the gridded density in a 2x2 figure after every step.
    """
    dom = ift.UnstructuredDomain(1)
    scale = 10

    a = ift.FieldAdapter(dom, 'a')
    b = ift.FieldAdapter(dom, 'b')
    lh = (a.adjoint @ a).scale(scale) + (b.adjoint @ b).scale(-1.35*2).exp()
    lh = ift.VariableCovarianceGaussianEnergy(dom, 'a', 'b', np.float64) @ lh
    icsamp = ift.AbsDeltaEnergyController(deltaE=0.1, iteration_limit=2)
    ham = ift.StandardHamiltonian(lh, icsamp)

    # The grid is laid out in unscaled 'a' units; plots use a*scale on the
    # horizontal axis, hence the separate scaled limits.
    x_limits = [-8/scale, 8/scale]
    x_limits_scaled = [-8, 8]
    y_limits = [-4, 4]
    x = np.linspace(*x_limits, num=401)
    y = np.linspace(*y_limits, num=401)
    xx, yy = np.meshgrid(x, y, indexing='ij')

    def np_ham(xval, yval):
        """Evaluate the energy used for the background density with NumPy."""
        # NOTE(review): the prior term has no factor 1/2 while the likelihood
        # term does; confirm this matches the prior of StandardHamiltonian.
        prior = xval**2 + yval**2
        mean = xval*scale
        lcov = 1.35*2*yval
        likelihood = .5*(mean**2*np.exp(-lcov) + lcov)
        return likelihood + prior

    z = np.exp(-1*np_ham(xx, yy))

    def show_marginal(grid, density, axis_label):
        """Show one marginal of the gridded density for two seconds."""
        plt.plot(grid, density)
        plt.xlabel(axis_label)
        plt.ylabel('unnormalized pdf')
        plt.title('Marginal density')
        plt.pause(2.0)
        plt.close()

    # With indexing='ij', axis 0 of z runs over x and axis 1 over y.
    show_marginal(y, np.sum(z, axis=0), 'y')
    show_marginal(x*scale, np.sum(z, axis=1), 'x')

    # Grid-based MAP location(s) and posterior mean, used as reference marks.
    peak = z == np.max(z)
    mapx = xx[peak]
    mapy = yy[peak]
    meanx = (xx*z).sum()/z.sum()
    meany = (yy*z).sum()/z.sum()

    n_samples = 100
    minimizer = ift.NewtonCG(
        ift.GradientNormController(iteration_limit=2, name='Mini'))
    IC = ift.StochasticAbsDeltaEnergyController(0.1, iteration_limit=20,
                                                name='advi')
    stochastic_minimizer_mf = ift.ADVIOptimizer(IC, eta=0.5)
    stochastic_minimizer_fc = ift.ADVIOptimizer(IC, eta=0.5)
    # All four methods start from the same random position.
    posmg = posgeo = posmf = posfc = ift.from_random(ham.domain, 'normal')
    fc = ift.FullCovarianceVI(posfc, ham, 10, False, initial_sig=0.01)
    mf = ift.MeanFieldVI(posmf, ham, 10, False, initial_sig=0.01)

    fig, axs = plt.subplots(2, 2, figsize=[12, 8])
    axs = axs.flatten()

    def update_plot(runs):
        """Redraw all four panels.

        Parameters
        ----------
        runs : iterable of (name, model, position, has_stored_samples)
            One entry per panel. If ``has_stored_samples`` is true the
            samples are taken from ``model.samples`` and shifted by
            ``position``; otherwise ``n_samples`` fresh samples are drawn
            with ``model.draw_sample()``.
        """
        for axx, (nn, kl, pp, sam) in zip(axs, runs):
            axx.clear()
            axx.imshow(z.T, origin='lower',
                       norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
                       cmap='gist_earth_r', extent=x_limits_scaled + y_limits)
            if sam:
                samples = (samp + pp for samp in kl.samples)
            else:
                samples = (kl.draw_sample() for _ in range(n_samples))
            xs, ys = [], []
            for samp in samples:
                xs.append(samp.val['a'])
                ys.append(samp.val['b'])
            xs = np.array(xs)
            ys = np.array(ys)
            # Average over the samples actually collected instead of assuming
            # that exactly n_samples were produced.
            mx = np.mean(xs, axis=0)
            my = np.mean(ys, axis=0)
            axx.scatter(xs*scale, ys, label=f'{nn} samples')
            axx.scatter(mx*scale, my, label=f'{nn} mean')
            axx.scatter(mapx*scale, mapy, label='MAP')
            axx.scatter(meanx*scale, meany, label='Posterior mean')
            axx.set_title(nn)
            axx.set_xlim(x_limits_scaled)
            axx.set_ylim(y_limits)
            axx.legend(loc='lower right')
        # Only the outer edges of the 2x2 grid keep their axes and labels.
        axs[0].xaxis.set_visible(False)
        axs[1].xaxis.set_visible(False)
        axs[1].yaxis.set_visible(False)
        axs[2].set_xlabel('x')
        axs[2].set_ylabel('y')
        axs[3].yaxis.set_visible(False)
        axs[3].set_xlabel('x')
        plt.tight_layout()
        plt.draw()
        plt.pause(2.0)

    def make_runs(mgkl, geokl, posmg, posgeo, posmf, posfc):
        """Bundle the current state of the four methods for update_plot."""
        return (("MGVI", mgkl, posmg, True),
                ("GeoVI", geokl, posgeo, True),
                ("MeanfieldVI", mf, posmf, False),
                ("FullCovarianceVI", fc, posfc, False))

    for ii in range(15):
        if ii % 2 == 0:
            # Resample
            mgkl = ift.MetricGaussianKL(posmg, ham, n_samples, False)
            mini_samp = ift.NewtonCG(
                ift.AbsDeltaEnergyController(1E-8, iteration_limit=5))
            geokl = ift.GeoMetricKL(posgeo, ham, n_samples, mini_samp, False)
            update_plot(
                make_runs(mgkl, geokl, posmg, posgeo, posmf, posfc))

        mgkl, _ = minimizer(mgkl)
        geokl, _ = minimizer(geokl)
        mf.minimize(stochastic_minimizer_mf)
        fc.minimize(stochastic_minimizer_fc)
        posmg = mgkl.position
        posgeo = geokl.position
        posmf = mf.mean
        posfc = fc.mean
        update_plot(make_runs(mgkl, geokl, posmg, posgeo, posmf, posfc))
    ift.logger.info('Finished')
    # Uncomment the following line in order to leave the plots open
    # plt.show()
Philipp Arras's avatar
Philipp Arras committed
165
166
167
168


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()