#!/usr/bin/env python3

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

############################################################
# Density estimation
#
# Compute a density estimate for a log-normal process measured by a
# Poissonian likelihood.
#
# Demo takes a while to compute
#############################################################

import numpy as np
import nifty7 as ift


def density_estimator(domain, pad=1.0, cf_fluctuations=None, cf_azm_uniform=None):
    """Build a log-normal density model on a zero-padded regular grid.

    The log-density is a correlated field (Matern-kernel fluctuations added
    per sub-domain) defined on a padded copy of `domain`. It is clipped to
    [-10, 10], sliced back to the shape of `domain` and exponentiated.

    Parameters
    ----------
    domain : domain-like
        Anything accepted by `ift.DomainTuple.make`. Every sub-domain must be
        a non-harmonic `ift.RGSpace`.
    pad : float or sequence of float
        Relative padding. Each sub-domain's shape is multiplied by `1 + pad`
        and truncated to int. A scalar is broadcast to all sub-domains.
    cf_fluctuations : dict or list/tuple of dict, optional
        Keyword arguments for `CorrelatedFieldMaker.add_fluctuations_matern`.
        A single dict is reused for every sub-domain; a list/tuple supplies
        one dict per sub-domain. Defaults to
        ``{"scale": (0.5, 0.3), "cutoff": (4.0, 3.0),
        "loglogslope": (-6.0, 3.0)}``.
    cf_azm_uniform : sequence of two floats, optional
        Extra positional arguments passed to `ift.UniformOperator` (after the
        scalar domain) to build the prior of the amplitude total offset.
        Defaults to ``(1e-4, 1.0)``.

    Returns
    -------
    signal : operator
        ``exp`` of the clipped correlated field, sliced to the shape of
        `domain`.
    model_operators : dict
        Intermediate operators under the keys ``"correlated_field"``
        (padded, clipped field), ``"select_subset"`` (the slice operator),
        ``"amplitude_total_offset"`` and ``"normalized_amplitudes"``.

    Raises
    ------
    TypeError
        If a sub-domain of `domain` is not a non-harmonic `ift.RGSpace`.
    """
    cf_azm_uniform_sane_default = (1e-4, 1.0)
    cf_fluctuations_sane_default = {
        "scale": (0.5, 0.3),
        "cutoff": (4.0, 3.0),
        "loglogslope": (-6.0, 3.0)
    }

    domain = ift.DomainTuple.make(domain)
    # One scaling factor per sub-domain; a scalar `pad` is broadcast.
    dom_scaling = 1. + np.broadcast_to(pad, (len(domain.axes), ))

    if cf_fluctuations is None:
        cf_fluctuations = cf_fluctuations_sane_default
    # Bind `cf_azm_uni` on both paths so a caller-supplied prior is honoured
    # instead of leaving the name undefined.
    if cf_azm_uniform is None:
        cf_azm_uni = cf_azm_uniform_sane_default
    else:
        cf_azm_uni = cf_azm_uniform

    domain_padded = []
    for d_scl, d in zip(dom_scaling, domain):
        if not isinstance(d, ift.RGSpace) or d.harmonic:
            te = (
                f"unexpected domain encountered in `domain`: {d}\n"
                "expected a non-harmonic `ift.RGSpace`"
            )
            raise TypeError(te)
        shape_padded = tuple((d_scl * np.array(d.shape)).astype(int))
        # Keep the original pixel distances, so padding extends the volume.
        domain_padded.append(ift.RGSpace(shape_padded, distances=d.distances))
    domain_padded = ift.DomainTuple.make(domain_padded)

    # Set up the signal model
    prefix = "de_"  # density estimator
    azm_offset_mean = 0.  # The zero-mode should be inferred only from the data

    cfmaker = ift.CorrelatedFieldMaker(prefix)
    for i, d in enumerate(domain_padded):
        if isinstance(cf_fluctuations, (list, tuple)):
            cf_fl = cf_fluctuations[i]
        else:
            cf_fl = cf_fluctuations
        cfmaker.add_fluctuations_matern(d, **cf_fl, prefix=f"ax{i}")
    scalar_domain = ift.DomainTuple.scalar_domain()
    uniform = ift.UniformOperator(scalar_domain, *cf_azm_uni)
    azm = uniform.ducktape("zeromode")
    cfmaker.set_amplitude_total_offset(azm_offset_mean, azm)
    # Clip the log-density before exponentiation (bounds hard-coded here).
    correlated_field = cfmaker.finalize(0).clip(-10., 10.)
    normalized_amplitudes = cfmaker.get_normalized_amplitudes()

    # Cut the padded field back to the shape of the requested domain.
    domain_shape = tuple(d.shape for d in domain)
    slc = ift.SliceOperator(correlated_field.target, domain_shape)
    signal = ift.exp(slc @ correlated_field)

    model_operators = {
        "correlated_field": correlated_field,
        "select_subset": slc,
        "amplitude_total_offset": azm,
        "normalized_amplitudes": normalized_amplitudes
    }

    return signal, model_operators


if __name__ == "__main__":
    # Demo: fit the density model to Poisson counts drawn from a random
    # realization of that same model, then plot truth, data and posterior.
    filename = "getting_started_density_{}.png"
    ift.random.push_sseq_from_seed(42)

    # Set up signal domain
    npix1 = 128
    npix2 = 128
    sp1 = ift.RGSpace(npix1)
    sp2 = ift.RGSpace(npix2)
    position_space = ift.DomainTuple.make((sp1, sp2))

    signal, ops = density_estimator(position_space)
    correlated_field = ops["correlated_field"]

    data_space = signal.target

    # Generate mock signal and data
    rng = ift.random.current_rng()
    mock_position = ift.from_random(signal.domain, "normal")
    # Evaluate the truth once; it is reused for the data and for every plot.
    ground_truth = signal(mock_position)
    ground_truth_unsliced = ift.exp(correlated_field(mock_position))
    data = ift.Field.from_raw(data_space, rng.poisson(ground_truth.val))

    # Rejoin domains for plotting. The shapes are read from the operators so
    # they follow whatever padding `density_estimator` applied.
    plotting_domain = ift.DomainTuple.make(ift.RGSpace(signal.target.shape))
    plotting_domain_expanded = ift.DomainTuple.make(
        ift.RGSpace(correlated_field.target.shape)
    )

    plot = ift.Plot()
    plot.add(
        ift.Field.from_raw(plotting_domain_expanded, ground_truth_unsliced.val),
        title="Pre-Slicing Truth",
    )
    plot.add(
        ift.Field.from_raw(plotting_domain, ground_truth.val),
        title="Ground Truth",
    )
    plot.add(ift.Field.from_raw(plotting_domain, data.val), title="Data")
    plot.output(ny=1, nx=3, xsize=10, ysize=10, name=filename.format("setup"))
    print("Setup saved as", filename.format("setup"))

    # Minimization parameters
    ic_sampling = ift.AbsDeltaEnergyController(
        name="Sampling", deltaE=0.01, iteration_limit=100
    )
    ic_newton = ift.AbsDeltaEnergyController(
        name="Newton", deltaE=0.01, iteration_limit=35
    )
    ic_sampling.enable_logging()
    ic_newton.enable_logging()
    minimizer = ift.NewtonCG(ic_newton, enable_logging=True)

    # Number of samples used to estimate the KL
    n_samples = 5
    # Number of global iterations (draw samples, then minimize the KL)
    n_iterations = 5

    # Set up likelihood and information Hamiltonian
    likelihood = ift.PoissonianEnergy(data) @ signal
    ham = ift.StandardHamiltonian(likelihood, ic_sampling)

    # Start minimization from the all-zero latent position
    mean = ift.MultiField.full(ham.domain, 0.)

    for i in range(n_iterations):
        # Draw new samples and minimize KL
        kl = ift.MetricGaussianKL.make(mean, ham, n_samples, True)
        kl, _ = minimizer(kl)
        mean = kl.position

        # Plot current reconstruction
        plot = ift.Plot()
        plot.add(
            ift.Field.from_raw(
                plotting_domain_expanded, ground_truth_unsliced.val
            ),
            title="Pre-Slicing Truth",
        )
        plot.add(
            ift.Field.from_raw(plotting_domain, ground_truth.val),
            title="Ground truth",
        )
        plot.add(
            ift.Field.from_raw(plotting_domain, signal(kl.position).val),
            title="Reconstruction",
        )
        plot.add(
            (ic_newton.history, ic_sampling.history, minimizer.inversion_history),
            label=["kl", "Sampling", "Newton inversion"],
            title="Cumulative energies",
            s=[None, None, 1],
            alpha=[None, 0.2, None],
        )
        plot.output(
            nx=3, ny=2, ysize=10, xsize=15, name=filename.format(f"loop_{i:02d}")
        )

    # Done, draw posterior samples
    sc = ift.StatCalculator()
    sc_unsliced = ift.StatCalculator()
    for sample in kl.samples:
        posterior_position = sample + kl.position
        sc.add(signal(posterior_position))
        sc_unsliced.add(ift.exp(correlated_field(posterior_position)))

    # Plotting
    plot = ift.Plot()
    plot.add(ift.Field.from_raw(plotting_domain, sc.mean.val), title="Posterior Mean")
    plot.add(
        ift.Field.from_raw(plotting_domain, ift.sqrt(sc.var).val),
        title="Posterior Standard Deviation",
    )
    plot.add(
        ift.Field.from_raw(plotting_domain_expanded, sc_unsliced.mean.val),
        title="Posterior Unsliced Mean",
    )
    plot.add(
        ift.Field.from_raw(plotting_domain_expanded, ift.sqrt(sc_unsliced.var).val),
        title="Posterior Unsliced Standard Deviation",
    )
    filename_res = filename.format("results")
    plot.output(ny=2, nx=2, xsize=15, ysize=15, name=filename_res)
    print(f"Saved results as '{filename_res}'.")