# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

###############################################################################
# Mean-field and full-covariance variational inference
#
# The signal is a 1-D log-normally distributed field.
# The data follows a Poisson likelihood.
# The posterior distribution is approximated both with a diagonal (mean-field)
# and with a full-covariance Gaussian distribution. This is achieved by
# minimizing a stochastic estimate of the KL divergence.
#
# Note that the full-covariance approximation scales quadratically with the
# number of parameters.
###############################################################################

import numpy as np

import nifty7 as ift
from matplotlib import pyplot as plt

# Fix the seed of NIFTy's random number generator so that the mock data, the
# initial positions and the stochastic optimization are reproducible between
# runs.
ift.random.push_sseq_from_seed(27)

if __name__ == "__main__":
    # Space and model setup
    position_space = ift.RGSpace([100])
    harmonic_space = position_space.get_default_codomain()
    HT = ift.HarmonicTransformOperator(harmonic_space, position_space)
    p_space = ift.PowerSpace(harmonic_space)

    # Fixed spectrum a(k) = 1 / (1 + k^2), spread onto the harmonic space and
    # applied multiplicatively to the excitations "xi" below.
    pd = ift.PowerDistributor(harmonic_space, p_space)
    a = ift.PS_field(p_space, lambda k: 1.0 / (1.0 + k ** 2))
    A = pd(a)

    # sky = 10 * exp(HT(A * xi)), i.e. a log-normal signal field.
    sky = 10 * ift.exp(HT(ift.makeOp(A))).ducktape("xi")

    # Response: a MaskOperator whose flags mark the central third of the
    # pixels, so the reconstruction there is not constrained by data.
    mask = np.zeros(position_space.shape)
    mask[mask.shape[0]//3:2*mask.shape[0]//3] = 1
    mask = ift.Field.from_raw(position_space, mask)
    R = ift.MaskOperator(mask)

    d_space = R.target[0]
    lamb = R(sky)

    # Generate simulated signal and data and build likelihood.
    mock_position = ift.from_random(sky.domain, "normal")
    data = ift.random.current_rng().poisson(lamb(mock_position).val)
    data = ift.makeField(d_space, data)
    likelihood = ift.PoissonianEnergy(data) @ lamb
    H = ift.StandardHamiltonian(likelihood)

    # Settings for minimization. Both optimizers share one iteration
    # controller; they are only ever run one after the other.
    IC = ift.StochasticAbsDeltaEnergyController(5, iteration_limit=200,
                                                name='advi')
    minimizer_fc = ift.ADVIOptimizer(IC, eta=0.1)
    minimizer_mf = ift.ADVIOptimizer(IC)

    # Initial positions
    position_fc = ift.from_random(H.domain)*0.1
    position_mf = ift.from_random(H.domain)*0.1

    # Setup of the variational models. The positional arguments (3, True)
    # are presumably the number of samples per KL estimate and whether the
    # samples are mirrored -- confirm against the NIFTy7 API.
    fc = ift.FullCovarianceVI(position_fc, H, 3, True, initial_sig=0.01)
    mf = ift.MeanFieldVI(position_mf, H, 3, True, initial_sig=0.01)

    # Number of minimization rounds; niter + 1 figures are written
    # (initial state plus the state after every round).
    niter = 10

    def plot_state(ii):
        """Save a figure of both approximations after `ii` minimization rounds.

        The figure shows the sky at the posterior means and at five posterior
        samples of the full-covariance (blue) and mean-field (red)
        approximation, together with the data and the ground truth.
        It is written to ``meanfield_{ii:03d}.png``.
        """
        plt.plot(sky(fc.mean).val, "b-", label="Full covariance")
        plt.plot(sky(mf.mean).val, "r-", label="Mean field")
        for _ in range(5):
            plt.plot(sky(fc.draw_sample()).val, "b-", alpha=0.3)
            plt.plot(sky(mf.draw_sample()).val, "r-", alpha=0.3)
        plt.plot(R.adjoint(data).val, "kx")
        plt.plot(sky(mock_position).val, "k-", label="Ground truth")
        plt.legend()
        # NOTE(review): the lower limit of 0.1 presumably hides the zeros that
        # R.adjoint places on the unobserved (masked) pixels.
        plt.ylim(0.1, data.val.max() + 10)
        fname = f"meanfield_{ii:03d}.png"
        plt.savefig(fname)
        print(f"Saved results as '{fname}' ({ii}/{niter}).")
        plt.close()

    # Plot the initialization, then plot after every minimization round so
    # that the result of the final round is saved as well (previously the
    # last round's result was computed but never plotted).
    plot_state(0)
    for ii in range(1, niter + 1):
        # Run minimization
        fc.minimize(minimizer_fc)
        mf.minimize(minimizer_mf)
        plot_state(ii)