# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

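###############################################################################
# Variational inference demo: fits a full-covariance and a mean-field Gaussian
# approximation (optimized with ADVI) to a 1D log-normal Poisson toy problem
# and saves one plot per iteration comparing both approximations with the
# synthetic data and the ground truth.
###############################################################################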

import numpy as np

import nifty7 as ift
from matplotlib import pyplot as plt


if __name__ == "__main__":
    # --- Signal model -------------------------------------------------------
    # The sky is 10 * exp(s). Here s is the input "xi", multiplied mode-by-mode
    # in harmonic space by the factor 1 / (1 + k**2), then transformed back to
    # the 100-pixel 1D position space.
    position_space = ift.RGSpace([100])
    harmonic_space = position_space.get_default_codomain()
    HT = ift.HarmonicTransformOperator(harmonic_space, position_space)
    p_space = ift.PowerSpace(harmonic_space)
    pd = ift.PowerDistributor(harmonic_space, p_space)
    a = ift.PS_field(p_space, lambda k: 1.0 / (1.0 + k ** 2))
    A = pd(a)
    sky = 10 * ift.exp(HT(ift.makeOp(A))).ducktape("xi")

    # --- Measurement --------------------------------------------------------
    # The Poisson rate is the sky with its geometry removed; the data space is
    # the target of that GeometryRemover.
    R = ift.GeometryRemover(position_space)
    d_space = R.target[0]
    lamb = R(sky)

    # Synthetic counts drawn from the rate at a random "ground truth" input.
    mock_position = ift.from_random(sky.domain, "normal")
    data = ift.random.current_rng().poisson(lamb(mock_position).val)
    likelihood = ift.PoissonianEnergy(ift.makeField(d_space, data)) @ lamb

    H = ift.StandardHamiltonian(likelihood)

    # --- Variational approximations -----------------------------------------
    # Both models start from small random positions and share the same
    # sampling settings. They are passed positionally as the
    # (n_samples, mirror_samples) arguments of NIFTy's VI classes.
    n_samples = 3
    mirror_samples = True
    position_fc = ift.from_random(H.domain) * 0.1
    position_mf = ift.from_random(H.domain) * 0.1
    fc = ift.library.variational_models.FullCovarianceVI(
        position_fc, H, n_samples, mirror_samples, initial_sig=0.01
    )
    mf = ift.library.variational_models.MeanFieldVI(
        position_mf, H, n_samples, mirror_samples, initial_sig=0.01
    )

    # ADVI settings per minimize() call. The first positional argument is the
    # step count; the full-covariance model additionally uses eta=0.1.
    minimizer_fc = ift.ADVIOptimizer(20, eta=0.1)
    minimizer_mf = ift.ADVIOptimizer(10)

    # --- Optimization loop --------------------------------------------------
    niter = 25
    for ii in range(niter):
        # Plot the current means (solid) and five samples (faint) of both
        # approximations, together with the data and the ground truth.
        plt.plot(sky(fc.mean).val, "b-", label="Full covariance")
        plt.plot(sky(mf.mean).val, "r-", label="Mean field")
        for _ in range(5):
            plt.plot(sky(fc.draw_sample()).val, "b-", alpha=0.3)
            plt.plot(sky(mf.draw_sample()).val, "r-", alpha=0.3)
        plt.plot(data, "kx")
        plt.plot(sky(mock_position).val, "k-", label="Ground truth")
        plt.legend()
        plt.ylim(0, data.max() + 10)
        fname = f"meanfield_{ii:03d}.png"
        plt.savefig(fname)
        print(f"Saved results as '{fname}' ({ii}/{niter-1}).")
        plt.close()

        # Each plot shows the state *before* this iteration's optimization.
        # A step taken after the last plot would never be seen, so skip it.
        if ii < niter - 1:
            fc.minimize(minimizer_fc)
            mf.minimize(minimizer_mf)