#!/usr/bin/env python3

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

############################################################
# Density estimation
#
# Compute a density estimate for a log-normal process measured by a
# Poissonian likelihood.
#
# Note: this demo takes a while to run.
#############################################################

import numpy as np
import nifty7 as ift


def density_estimator(domain, pad=1.0, cf_fluctuations=None, cf_azm_uniform=None):
    """Build a log-normal density model on a padded regular grid.

    A correlated field with Matern-type fluctuations is constructed on a
    zero-padded copy of ``domain``, clipped to [-10, 10], sliced back to the
    original shape and exponentiated.

    Parameters
    ----------
    domain : domain-like
        Anything accepted by ``ift.DomainTuple.make``. Every subdomain must be
        a non-harmonic ``ift.RGSpace``.
    pad : float or sequence of float
        Relative padding, broadcast to one value per subdomain. Each subdomain
        shape is multiplied by ``1 + pad`` and truncated to ``int``.
    cf_fluctuations : dict or list/tuple of dict, optional
        Keyword arguments for ``CorrelatedFieldMaker.add_fluctuations_matern``:
        either one dict shared by all subdomains or one dict per subdomain.
        Defaults to ``{"scale": (0.5, 0.3), "cutoff": (4.0, 3.0),
        "loglogslope": (-6.0, 3.0)}``.
    cf_azm_uniform : tuple, optional
        Extra positional arguments forwarded to ``ift.UniformOperator`` for
        the prior on the total-offset (zero-mode) amplitude.
        Defaults to ``(1e-4, 1.0)``.

    Returns
    -------
    signal : operator
        ``exp`` of the sliced, clipped correlated field.
    model_operators : dict
        Intermediate operators under the keys ``"correlated_field"`` (clipped,
        on the padded domain), ``"select_subset"`` (the slicing operator),
        ``"amplitude_total_offset"`` and ``"normalized_amplitudes"``.

    Raises
    ------
    TypeError
        If a subdomain of ``domain`` is not a non-harmonic ``ift.RGSpace``.
    """
    cf_azm_uniform_sane_default = (1e-4, 1.0)
    cf_fluctuations_sane_default = {
        "scale": (0.5, 0.3),
        "cutoff": (4.0, 3.0),
        "loglogslope": (-6.0, 3.0)
    }

    domain = ift.DomainTuple.make(domain)
    # One scaling factor per subdomain; a scalar `pad` is broadcast to all.
    dom_scaling = 1. + np.broadcast_to(pad, (len(domain.axes), ))
    if cf_fluctuations is None:
        cf_fluctuations = cf_fluctuations_sane_default
    if cf_azm_uniform is None:
        cf_azm_uniform = cf_azm_uniform_sane_default

    # Enlarge every subdomain by its scaling factor, keeping pixel distances.
    domain_padded = []
    for d_scl, d in zip(dom_scaling, domain):
        if not isinstance(d, ift.RGSpace) or d.harmonic:
            te = [f"unexpected domain encountered in `domain`: {domain}"]
            # append, not `+=`: `list += str` would add one entry per character
            te.append("expected a non-harmonic `ift.RGSpace`")
            raise TypeError("\n".join(te))
        shape_padded = tuple((d_scl * np.array(d.shape)).astype(int))
        domain_padded.append(ift.RGSpace(shape_padded, distances=d.distances))
    domain_padded = ift.DomainTuple.make(domain_padded)

    # Set up the signal model
    azm_offset_mean = 0.0  # The zero-mode should be inferred only from the data
    cfmaker = ift.CorrelatedFieldMaker("")
    for i, d in enumerate(domain_padded):
        # Either per-subdomain settings (list/tuple) or one shared dict.
        if isinstance(cf_fluctuations, (list, tuple)):
            cf_fl = cf_fluctuations[i]
        else:
            cf_fl = cf_fluctuations
        cfmaker.add_fluctuations_matern(d, **cf_fl, prefix=f"ax{i}")
    scalar_domain = ift.DomainTuple.scalar_domain()
    uniform = ift.UniformOperator(scalar_domain, *cf_azm_uniform)
    azm = uniform.ducktape("zeromode")
    cfmaker.set_amplitude_total_offset(azm_offset_mean, azm)
    # NOTE(review): the clip presumably keeps exp() below finite; confirm.
    correlated_field = cfmaker.finalize(0).clip(-10., 10.)
    normalized_amplitudes = cfmaker.get_normalized_amplitudes()

    # Cut the padded field back to the shape of the requested domain.
    domain_shape = tuple(d.shape for d in domain)
    slc = ift.SliceOperator(correlated_field.target, domain_shape)
    signal = ift.exp(slc @ correlated_field)

    model_operators = {
        "correlated_field": correlated_field,
        "select_subset": slc,
        "amplitude_total_offset": azm,
        "normalized_amplitudes": normalized_amplitudes
    }

    return signal, model_operators


if __name__ == "__main__":
    filename = "getting_started_density_{}.png"
    ift.random.push_sseq_from_seed(42)

    # Set up signal domain
    npix1 = 128
    npix2 = 128
    sp1 = ift.RGSpace(npix1)
    sp2 = ift.RGSpace(npix2)
    position_space = ift.DomainTuple.make((sp1, sp2))

    signal, ops = density_estimator(position_space)
    correlated_field = ops["correlated_field"]

    data_space = signal.target

    # Generate mock signal and data
    rng = ift.random.current_rng()
    mock_position = ift.from_random(signal.domain, "normal")
    data = ift.Field.from_raw(data_space, rng.poisson(signal(mock_position).val))

    # Rejoin domains for plotting. The expanded (padded) shape is taken from
    # the model itself, so it stays correct for any `pad` value.
    plotting_domain = ift.DomainTuple.make(ift.RGSpace((npix1, npix2)))
    plotting_domain_expanded = ift.DomainTuple.make(
        ift.RGSpace(correlated_field.target.shape)
    )

    # Ground-truth fields do not change during minimization; compute them once.
    truth_unsliced = ift.Field.from_raw(
        plotting_domain_expanded, ift.exp(correlated_field(mock_position)).val
    )
    truth = ift.Field.from_raw(plotting_domain, signal(mock_position).val)

    plot = ift.Plot()
    plot.add(truth_unsliced, title="Pre-Slicing Truth")
    plot.add(truth, title="Ground Truth")
    plot.add(ift.Field.from_raw(plotting_domain, data.val), title="Data")
    plot.output(ny=1, nx=3, xsize=10, ysize=3, name=filename.format("setup"))
    print("Setup saved as", filename.format("setup"))

    # Minimization parameters
    ic_sampling = ift.AbsDeltaEnergyController(
        name="Sampling", deltaE=0.01, iteration_limit=100
    )
    ic_newton = ift.AbsDeltaEnergyController(
        name="Newton", deltaE=0.01, iteration_limit=35
    )
    ic_sampling.enable_logging()
    ic_newton.enable_logging()
    minimizer = ift.NewtonCG(ic_newton, enable_logging=True)

    # Number of samples used to estimate the KL
    n_samples = 5
    # Number of outer sample-and-minimize iterations
    n_iterations = 5

    # Set up likelihood and information Hamiltonian
    likelihood = ift.PoissonianEnergy(data) @ signal
    ham = ift.StandardHamiltonian(likelihood, ic_sampling)

    # Start minimization
    mean = ift.MultiField.full(ham.domain, 0.)

    for i in range(n_iterations):
        # Draw new samples and minimize KL
        kl = ift.MetricGaussianKL.make(mean, ham, n_samples, True)
        kl, _ = minimizer(kl)
        mean = kl.position

        # Plot current reconstruction
        plot = ift.Plot()
        plot.add(truth_unsliced, title="Pre-Slicing Truth")
        plot.add(truth, title="Ground truth")
        plot.add(
            ift.Field.from_raw(plotting_domain, signal(kl.position).val),
            title="Reconstruction",
        )
        plot.add(
            (ic_newton.history, ic_sampling.history, minimizer.inversion_history),
            label=["kl", "Sampling", "Newton inversion"],
            title="Cumulative energies",
            s=[None, None, 1],
            alpha=[None, 0.2, None],
        )
        plot.output(
            nx=3, ny=2, ysize=10, xsize=15, name=filename.format(f"loop_{i:02d}")
        )

    # Done, draw posterior samples
    sc = ift.StatCalculator()
    sc_unsliced = ift.StatCalculator()
    for sample in kl.samples:
        sc.add(signal(sample + kl.position))
        sc_unsliced.add(ift.exp(correlated_field(sample + kl.position)))

    # Plotting
    plot = ift.Plot()
    plot.add(ift.Field.from_raw(plotting_domain, sc.mean.val), title="Posterior Mean")
    plot.add(
        ift.Field.from_raw(plotting_domain, ift.sqrt(sc.var).val),
        title="Posterior Standard Deviation",
    )
    plot.add(
        ift.Field.from_raw(plotting_domain_expanded, sc_unsliced.mean.val),
        title="Posterior Unsliced Mean",
    )
    plot.add(
        ift.Field.from_raw(plotting_domain_expanded, ift.sqrt(sc_unsliced.var).val),
        title="Posterior Unsliced Standard Deviation",
    )
    plot.output(xsize=15, ysize=15, name=filename.format("results"))
    print("Saved results as", filename.format("results"))