# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

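###############################################################################
# Variational inference demo: compare a full-covariance and a mean-field
# Gaussian approximation of the posterior, both optimized with ADVI, for a
# one-dimensional log-normal sky (10 * exp of a Gaussian field) observed
# through Poisson-distributed mock counts.
###############################################################################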

import numpy as np

import nifty7 as ift
from matplotlib import pyplot as plt


if __name__ == "__main__":
    # Demo: fit a 1-D log-normal sky to Poisson mock data with two variational
    # approximations (full covariance vs. mean field) and plot both fits
    # against the ground truth after every optimization round.

    # One-dimensional regular grid with 100 pixels. No exposure map is
    # applied in this demo.
    position_space = ift.RGSpace([100])

    # Define harmonic space and harmonic transform
    harmonic_space = position_space.get_default_codomain()
    HT = ift.HarmonicTransformOperator(harmonic_space, position_space)

    # Define amplitude (square root of power spectrum)
    def sqrtpspec(k):
        return 1.0 / (1.0 + k ** 2)

    p_space = ift.PowerSpace(harmonic_space)
    pd = ift.PowerDistributor(harmonic_space, p_space)
    a = ift.PS_field(p_space, sqrtpspec)
    A = pd(a)

    # Sky model: 10 * exp(HT(A * xi)). The excitations live in the harmonic
    # domain under the key "xi" (via ducktape).
    sky = 10 * ift.exp(HT(ift.makeOp(A))).ducktape("xi")

    # Instrumental response. It only removes the geometry, so the sky maps
    # onto a plain data space; an exposure operator could be composed in here.
    R = ift.GeometryRemover(position_space)

    # Generate mock data and define likelihood operator
    d_space = R.target[0]
    lamb = R(sky)
    mock_position = ift.from_random(sky.domain, "normal")
    data = lamb(mock_position)
    # Draw Poisson counts with the mock sky as rate.
    data = ift.random.current_rng().poisson(data.val.astype(np.float64))
    data = ift.Field.from_raw(d_space, data)
    likelihood = ift.PoissonianEnergy(data) @ lamb

    H = ift.StandardHamiltonian(likelihood)

    # Starting positions for the two approximations. The mean-field start is
    # a random draw scaled by zero, i.e. an all-zero position. The draw is
    # kept so the random stream matches the original demo.
    position_fc = ift.from_random(H.domain) * 0.1
    position_mf = ift.from_random(H.domain) * 0.0

    # Full-covariance and mean-field Gaussian approximations, each with its
    # own ADVI optimizer.
    # NOTE(review): the positional arguments (3, True) are presumably the
    # number of samples and sample mirroring, and ADVIOptimizer(10) the number
    # of steps per call -- confirm against the nifty7 signatures.
    fc = ift.FullCovariance(position_fc, H, 3, True, initial_sig=0.01)
    mf = ift.MeanField(position_mf, H, 3, True, initial_sig=0.0001)
    minimizer_fc = ift.ADVIOptimizer(10)
    minimizer_mf = ift.ADVIOptimizer(10)

    # Loop-invariant plot quantities, computed once.
    ground_truth = sky(mock_position).val
    y_max = data.val.max() + 10

    for i in range(25):
        # Iteration 0 only plots the initial state; later ones optimize first.
        if i != 0:
            fc.minimize(minimizer_fc)
            mf.minimize(minimizer_mf)

        plt.figure("result")
        plt.cla()
        plt.plot(sky(fc.position).val, "b-", label="Full covariance")
        plt.plot(sky(mf.position).val, "r-", label="Mean field")
        plt.plot(data.val, "kx")
        plt.plot(ground_truth, "k-", label="Ground truth")
        plt.legend()
        plt.ylim(0, y_max)
        # Let the GUI event loop redraw the figure between iterations.
        plt.pause(0.001)