#!/usr/bin/env python3

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

############################################################
# Density estimation
#
# Compute a density estimate for a log-normal process measured by a
# Poissonian likelihood.
#
# Demo takes a while to compute
#############################################################

import numpy as np
import nifty7 as ift

if __name__ == "__main__":
    # Output file pattern; "{}" is filled with "setup", "loop_XX" or "results".
    filename = "getting_started_density_{}.png"

    # Fixed seed so the mock data and the reconstruction are reproducible.
    ift.random.push_sseq_from_seed(42)

    # Set up signal domain
    npix1 = 128
    npix2 = 128
    sp1 = ift.RGSpace(npix1)
    sp2 = ift.RGSpace(npix2)
    position_space = ift.DomainTuple.make((sp1, sp2))

    # `density_estimator` returns the forward model together with a dict of
    # its sub-operators; the correlated field is used for the unsliced plots.
    signal, ops = ift.density_estimator(position_space)
    correlated_field = ops["correlated_field"]

    data_space = signal.target

    # Generate mock signal and data
    rng = ift.random.current_rng()
    mock_position = ift.from_random(signal.domain, "normal")
    mock_signal = signal(mock_position)
    data = ift.Field.from_raw(data_space, rng.poisson(mock_signal.val))

    # Rejoin domains for plotting
    plotting_domain = ift.DomainTuple.make(ift.RGSpace((npix1, npix2)))
    plotting_domain_expanded = ift.DomainTuple.make(
        ift.RGSpace((2 * npix1, 2 * npix2))
    )

    # The ground-truth fields never change, so build them once and reuse them
    # in the setup plot and in every iteration plot below.
    truth_unsliced = ift.Field.from_raw(
        plotting_domain_expanded, ift.exp(correlated_field(mock_position)).val
    )
    truth = ift.Field.from_raw(plotting_domain, mock_signal.val)

    plot = ift.Plot()
    plot.add(truth_unsliced, title="Pre-Slicing Truth")
    plot.add(truth, title="Ground Truth")
    plot.add(ift.Field.from_raw(plotting_domain, data.val), title="Data")
    setup_name = filename.format("setup")
    plot.output(ny=1, nx=3, xsize=10, ysize=3, name=setup_name)
    print("Setup saved as", setup_name)

    # Minimization parameters
    ic_sampling = ift.AbsDeltaEnergyController(
        name="Sampling", deltaE=0.01, iteration_limit=100
    )
    ic_newton = ift.AbsDeltaEnergyController(
        name="Newton", deltaE=0.01, iteration_limit=35
    )
    # Logging fills the `.history` attributes plotted in each iteration.
    ic_sampling.enable_logging()
    ic_newton.enable_logging()
    minimizer = ift.NewtonCG(ic_newton, enable_logging=True)

    # Number of samples used to estimate the KL
    n_samples = 5
    # Number of outer "draw samples, then minimize the KL" cycles
    n_iterations = 5

    # Set up likelihood and information Hamiltonian
    likelihood = ift.PoissonianEnergy(data) @ signal
    ham = ift.StandardHamiltonian(likelihood, ic_sampling)

    # Start minimization from the all-zero latent position.
    initial_mean = ift.MultiField.full(ham.domain, 0.)
    mean = initial_mean

    for i in range(n_iterations):
        # Draw new samples and minimize KL (the positional `True` enables
        # mirrored samples).
        kl = ift.MetricGaussianKL.make(mean, ham, n_samples, True)
        kl, _ = minimizer(kl)
        mean = kl.position

        # Plot current reconstruction
        plot = ift.Plot()
        plot.add(truth_unsliced, title="Pre-Slicing Truth")
        plot.add(truth, title="Ground truth")
        plot.add(
            ift.Field.from_raw(plotting_domain, signal(mean).val),
            title="Reconstruction",
        )
        plot.add(
            (ic_newton.history, ic_sampling.history, minimizer.inversion_history),
            label=["kl", "Sampling", "Newton inversion"],
            title="Cumulative energies",
            s=[None, None, 1],
            alpha=[None, 0.2, None],
        )
        plot.output(
            nx=3, ny=2, ysize=10, xsize=15, name=filename.format(f"loop_{i:02d}")
        )

    # Done, draw posterior samples. Each sample is added to the final KL
    # position before being pushed through the model.
    sc = ift.StatCalculator()
    sc_unsliced = ift.StatCalculator()
    for sample in kl.samples:
        posterior_position = sample + kl.position
        sc.add(signal(posterior_position))
        sc_unsliced.add(ift.exp(correlated_field(posterior_position)))

    # Plotting
    plot = ift.Plot()
    plot.add(ift.Field.from_raw(plotting_domain, sc.mean.val), title="Posterior Mean")
    plot.add(
        ift.Field.from_raw(plotting_domain, ift.sqrt(sc.var).val),
        title="Posterior Standard Deviation",
    )
    plot.add(
        ift.Field.from_raw(plotting_domain_expanded, sc_unsliced.mean.val),
        title="Posterior Unsliced Mean",
    )
    plot.add(
        ift.Field.from_raw(plotting_domain_expanded, ift.sqrt(sc_unsliced.var).val),
        title="Posterior Unsliced Standard Deviation",
    )
    plot.output(xsize=15, ysize=15, name=filename.format("results"))
    print("Saved results as", filename.format("results"))