#!/usr/bin/env python3

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

############################################################
# Density estimation
#
# Compute a density estimate for a log-normal process measured by a
# Poissonian likelihood.
#
# Demo takes a while to compute
#############################################################

import numpy as np

import nifty7 as ift


def density_estimator(
    domain, pad=1., cf_fluctuations=None, cf_azm_uniform=None
):
    """Build a log-normal density model on a zero-padded regular grid.

    A correlated field is set up on an enlarged ("padded") copy of
    `domain`, clipped, sliced back to the original shape and
    exponentiated.

    Parameters
    ----------
    domain : anything accepted by `ift.DomainTuple.make`
        Every sub-domain must be a non-harmonic `ift.RGSpace`.
    pad : float or sequence of float
        Broadcast to one value per sub-domain. The shape of each
        sub-domain is multiplied by `1 + pad` and truncated to integers;
        the pixel distances are kept.
    cf_fluctuations : dict or list/tuple of dict, optional
        Keyword arguments forwarded to
        `CorrelatedFieldMaker.add_fluctuations_matern`. A list or tuple is
        indexed by sub-domain position; a single dict is used for every
        sub-domain. Defaults to `scale=(0.5, 0.3)`, `cutoff=(4.0, 3.0)`,
        `loglogslope=(-6.0, 3.0)`.
    cf_azm_uniform : sequence, optional
        Positional arguments passed to `ift.UniformOperator` after the
        scalar domain (presumably `(loc, scale)` -- confirm against the
        NIFTy API). Defaults to `(1e-4, 1.0)`.

    Returns
    -------
    signal : operator
        `exp` of the sliced, clipped correlated field.
    model_operators : dict
        The intermediate operators under the keys `"correlated_field"`,
        `"select_subset"`, `"amplitude_total_offset"` and
        `"normalized_amplitudes"`.

    Raises
    ------
    TypeError
        If a sub-domain is not a non-harmonic `ift.RGSpace`.
    """
    cf_azm_uniform_sane_default = (1e-4, 1.0)
    cf_fluctuations_sane_default = {
        "scale": (0.5, 0.3),
        "cutoff": (4.0, 3.0),
        "loglogslope": (-6.0, 3.0)
    }

    domain = ift.DomainTuple.make(domain)
    dom_scaling = 1. + np.broadcast_to(pad, (len(domain.axes), ))
    if cf_fluctuations is None:
        cf_fluctuations = cf_fluctuations_sane_default
    # Normalise the argument itself so that a user-supplied value is honoured
    # as well; binding a separate name only in the `None` branch left it
    # undefined whenever `cf_azm_uniform` was passed explicitly.
    if cf_azm_uniform is None:
        cf_azm_uniform = cf_azm_uniform_sane_default

    domain_padded = []
    for d_scl, d in zip(dom_scaling, domain):
        if not isinstance(d, ift.RGSpace) or d.harmonic:
            te = (
                f"unexpected domain encountered in `domain`: {domain}\n"
                "expected a non-harmonic `ift.RGSpace`"
            )
            raise TypeError(te)
        shape_padded = tuple((d_scl * np.array(d.shape)).astype(int))
        domain_padded.append(ift.RGSpace(shape_padded, distances=d.distances))
    domain_padded = ift.DomainTuple.make(domain_padded)

    # Set up the signal model
    prefix = "de_"  # density estimator
    azm_offset_mean = 0.  # The zero-mode should be inferred only from the data
    cfmaker = ift.CorrelatedFieldMaker(prefix)
    for i, d in enumerate(domain_padded):
        # Per-axis settings if a sequence was given, shared settings otherwise
        if isinstance(cf_fluctuations, (list, tuple)):
            cf_fl = cf_fluctuations[i]
        else:
            cf_fl = cf_fluctuations
        cfmaker.add_fluctuations_matern(d, **cf_fl, prefix=f"ax{i}")
    scalar_domain = ift.DomainTuple.scalar_domain()
    uniform = ift.UniformOperator(scalar_domain, *cf_azm_uniform)
    azm = uniform.ducktape("zeromode")
    cfmaker.set_amplitude_total_offset(azm_offset_mean, azm)
    # The field is clipped to [-10, 10] before it is exponentiated below
    correlated_field = cfmaker.finalize(0).clip(-10., 10.)
    normalized_amplitudes = cfmaker.get_normalized_amplitudes()

    # Cut the padded field back to the shape of the requested domain
    domain_shape = tuple(d.shape for d in domain)
    slc = ift.SliceOperator(correlated_field.target, domain_shape)
    signal = ift.exp(slc @ correlated_field)

    model_operators = {
        "correlated_field": correlated_field,
        "select_subset": slc,
        "amplitude_total_offset": azm,
        "normalized_amplitudes": normalized_amplitudes
    }

    return signal, model_operators
92
93
94
95
96
97


if __name__ == "__main__":
    # Demo driver: draw a mock density from the model, generate Poissonian
    # counts from it, reconstruct the density with a few rounds of
    # MetricGaussianKL minimization and write diagnostic plots to PNG files.

    # Preparing the filename string for store results
    filename = "getting_started_density_{}.png"

    # Set the random seed
    ift.random.push_sseq_from_seed(42)

    # Set up signal domain
    npix1 = 128
    npix2 = 128
    sp1 = ift.RGSpace(npix1)
    sp2 = ift.RGSpace(npix2)
    position_space = ift.DomainTuple.make((sp1, sp2))

    signal, ops = density_estimator(position_space)
    correlated_field = ops["correlated_field"]

    data_space = signal.target

    # Generate mock signal and data
    rng = ift.random.current_rng()
    mock_position = ift.from_random(signal.domain, 'normal')
    mock_signal = signal(mock_position)
    data = ift.Field.from_raw(data_space, rng.poisson(mock_signal.val))

    # Rejoining domains for ift plotting routine
    plotting_domain = ift.DomainTuple.make(ift.RGSpace((npix1, npix2)))
    # NOTE: the factor 2 matches the default `pad=1.` of `density_estimator`
    plotting_domain_expanded = ift.DomainTuple.make(
        ift.RGSpace((2 * npix1, 2 * npix2))
    )

    # The ground truth does not change during the inference, so it is
    # evaluated once and reused for every plot.
    truth_unsliced = ift.Field.from_raw(
        plotting_domain_expanded,
        ift.exp(correlated_field(mock_position)).val
    )
    truth = ift.Field.from_raw(plotting_domain, mock_signal.val)

    plot = ift.Plot()
    plot.add(truth_unsliced, title='Pre-Slicing Truth')
    plot.add(truth, title='Ground Truth')
    plot.add(ift.Field.from_raw(plotting_domain, data.val), title='Data')
    plot.output(ny=1, nx=3, xsize=10, ysize=10, name=filename.format("setup"))
    print("Setup saved as", filename.format("setup"))

    # Minimization parameters
    ic_sampling = ift.AbsDeltaEnergyController(
        name='Sampling', deltaE=0.01, iteration_limit=100
    )
    ic_newton = ift.AbsDeltaEnergyController(
        name='Newton', deltaE=0.01, iteration_limit=35
    )
    ic_sampling.enable_logging()
    ic_newton.enable_logging()
    minimizer = ift.NewtonCG(ic_newton, enable_logging=True)

    # number of samples used to estimate the KL
    n_samples = 5

    # Set up likelihood and information Hamiltonian
    likelihood = ift.PoissonianEnergy(data) @ signal
    ham = ift.StandardHamiltonian(likelihood, ic_sampling)

    # Begin minimization
    mean = ift.MultiField.full(ham.domain, 0.)

    for i in range(5):
        # Draw new samples and minimize KL
        kl = ift.MetricGaussianKL.make(mean, ham, n_samples, True)
        kl, _ = minimizer(kl)
        mean = kl.position

        # Plot current reconstruction
        plot = ift.Plot()
        # Titled like the setup plot so that the unsliced and the sliced
        # truth panels can be told apart.
        plot.add(truth_unsliced, title="Pre-Slicing Truth")
        plot.add(truth, title="Ground truth")
        plot.add(
            ift.Field.from_raw(plotting_domain, signal(kl.position).val),
            title="Reconstruction"
        )
        plot.add(
            (
                ic_newton.history, ic_sampling.history,
                minimizer.inversion_history
            ),
            label=['kl', 'Sampling', 'Newton inversion'],
            title='Cumulative energies',
            s=[None, None, 1],
            alpha=[None, 0.2, None]
        )
        plot.output(
            nx=3,
            ny=2,
            ysize=10,
            xsize=15,
            name=filename.format(f"loop_{i:02d}")
        )

    # Done, draw posterior samples
    sc = ift.StatCalculator()
    sc_unsliced = ift.StatCalculator()
    for sample in kl.samples:
        posterior_position = sample + kl.position
        sc.add(signal(posterior_position))
        sc_unsliced.add(ift.exp(correlated_field(posterior_position)))

    # Plotting
    filename_res = filename.format("results")
    plot = ift.Plot()
    plot.add(
        ift.Field.from_raw(plotting_domain, sc.mean.val),
        title="Posterior Mean"
    )
    plot.add(
        ift.Field.from_raw(plotting_domain, ift.sqrt(sc.var).val),
        title="Posterior Standard Deviation"
    )
    plot.add(
        ift.Field.from_raw(plotting_domain_expanded, sc_unsliced.mean.val),
        title="Posterior Unsliced Mean"
    )
    plot.add(
        ift.Field.from_raw(
            plotting_domain_expanded,
            ift.sqrt(sc_unsliced.var).val
        ),
        title="Posterior Unsliced Standard Deviation"
    )

    plot.output(ny=2, nx=2, xsize=15, ysize=15, name=filename_res)
    print(f"Saved results as '{filename_res}'.")