# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

############################################################
# Non-linear tomography
#
# The signal is a sigmoid-normal distributed field.
# The data is the field integrated along lines of sight that are
# distributed either randomly (mode=0) or radially (mode=1); the mode is
# selected by the first command-line argument and defaults to 0.
#
# The demo takes a while to compute.
#############################################################

Philipp Arras's avatar
Philipp Arras committed
28 29
import sys

Jakob Knollmueller's avatar
Jakob Knollmueller committed
30 31
import numpy as np

Philipp Arras's avatar
Philipp Arras committed
32 33
import nifty5 as ift

Jakob Knollmueller's avatar
Jakob Knollmueller committed
34

Philipp Arras's avatar
Philipp Arras committed
35
def random_los(n_los):
    """Draw `n_los` line-of-sight segments with uniformly random endpoints.

    Both start and end points are drawn uniformly from the unit square
    [0, 1) x [0, 1) using NumPy's global random state (so the result is
    reproducible after ``np.random.seed``).

    Parameters
    ----------
    n_los : int
        Number of lines of sight.

    Returns
    -------
    starts, ends : list of numpy.ndarray
        Each is a list of two arrays of length `n_los`: the first array
        holds the first coordinate of every point, the second array the
        second coordinate. This per-axis layout is what the demo passes to
        ``ift.LOSResponse`` as ``starts``/``ends``.
    """
    # Draw an (n_los, 2) block and transpose it, so that unpacking with
    # list() yields one array per spatial axis.
    starts = list(np.random.uniform(0, 1, (n_los, 2)).T)
    ends = list(np.random.uniform(0, 1, (n_los, 2)).T)
    return starts, ends


def radial_los(n_los):
    """Draw `n_los` lines of sight that all end at the centre (0.5, 0.5).

    Start points are drawn uniformly from the unit square
    [0, 1) x [0, 1) using NumPy's global random state; every end point is
    the centre of the square, so the lines run radially towards it.

    Parameters
    ----------
    n_los : int
        Number of lines of sight.

    Returns
    -------
    starts, ends : list of numpy.ndarray
        Each is a list of two arrays of length `n_los` (one array per
        spatial axis), the layout the demo passes to ``ift.LOSResponse``.
        Every entry of `ends` is 0.5.
    """
    starts = list(np.random.uniform(0, 1, (n_los, 2)).T)
    # This draw is deliberately discarded. It only advances NumPy's global
    # RNG by the same amount as the former ``0*np.random.uniform(...)``
    # formulation did, so every later seeded draw in the demo is unchanged.
    np.random.uniform(0, 1, (n_los, 2))
    ends = list(np.full((2, n_los), 0.5))
    return starts, ends

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
46

Jakob Knollmueller's avatar
Jakob Knollmueller committed
47
if __name__ == '__main__':
    np.random.seed(420)

    # Choose between random line-of-sight response (mode=0) and radial lines
    # of sight (mode=1); the mode is read from the first command-line
    # argument and defaults to 0.
    if len(sys.argv) == 2:
        mode = int(sys.argv[1])
    else:
        mode = 0
    # Reject unsupported modes instead of silently treating every non-zero
    # value as "radial".
    if mode not in (0, 1):
        raise ValueError(
            "mode must be 0 (random LOS) or 1 (radial LOS), got {}".format(
                mode))
    filename = "getting_started_3_mode_{}_".format(mode) + "{}.png"

    position_space = ift.RGSpace([128, 128])
    harmonic_space = position_space.get_default_codomain()
    ht = ift.HarmonicTransformOperator(harmonic_space, position_space)
    power_space = ift.PowerSpace(harmonic_space)

    # Set up an amplitude operator for the field
    dct = {
        'target': power_space,
        'n_pix': 64,  # 64 spectral bins

        # Spectral smoothness (affects Gaussian process part)
        'a': 3,  # relatively high variance of spectral curvature
        'k0': .4,  # quefrency mode below which cepstrum flattens

        # Power-law part of spectrum:
        'sm': -5,  # preferred power-law slope
        'sv': .5,  # low variance of power-law slope
        'im':  0,  # y-intercept mean, in-/decrease for more/less contrast
        'iv': .3   # y-intercept variance
    }
    A = ift.SLAmplitude(**dct)

    # Build the operator for a correlated signal
    power_distributor = ift.PowerDistributor(harmonic_space, power_space)
    vol = harmonic_space.scalar_dvol**-0.5
    xi = ift.ducktape(harmonic_space, None, 'xi')
    correlated_field = ht(vol*power_distributor(A)*xi)
    # Alternatively, one can use:
    # correlated_field = ift.CorrelatedField(position_space, A)

    # Apply a nonlinearity
    signal = ift.sigmoid(correlated_field)

    # Build the line-of-sight response and define signal response
    n_los = 100  # number of lines of sight
    if mode == 0:
        LOS_starts, LOS_ends = random_los(n_los)
    else:
        LOS_starts, LOS_ends = radial_los(n_los)
    R = ift.LOSResponse(position_space, starts=LOS_starts, ends=LOS_ends)
    signal_response = R(signal)

    # Specify noise
    data_space = R.target
    noise = .001
    N = ift.ScalingOperator(noise, data_space)

    # Generate mock signal and data
    mock_position = ift.from_random('normal', signal_response.domain)
    data = signal_response(mock_position) + N.draw_sample()

    # Minimization parameters
    ic_sampling = ift.GradientNormController(iteration_limit=100)
    ic_newton = ift.GradInfNormController(
        name='Newton', tol=1e-7, iteration_limit=35)
    minimizer = ift.NewtonCG(ic_newton)

    # Set up likelihood and information Hamiltonian
    likelihood = ift.GaussianEnergy(mean=data, covariance=N)(signal_response)
    H = ift.StandardHamiltonian(likelihood, ic_sampling)

    # Start the minimization from a position whose entries (over all of
    # H's domain) are zero.
    mean = ift.MultiField.full(H.domain, 0.)

    # Plot ground truth, back-projected data (adjoint of R applied to the
    # data) and the true power spectrum before reconstructing.
    plot = ift.Plot()
    plot.add(signal(mock_position), title='Ground Truth')
    plot.add(R.adjoint_times(data), title='Data')
    plot.add([A.force(mock_position)], title='Power Spectrum')
    plot.output(ny=1, nx=3, xsize=24, ysize=6, name=filename.format("setup"))

    # number of samples used to estimate the KL
    N_samples = 20

    # Draw new samples to approximate the KL five times
    for i in range(5):
        # Draw new samples and minimize KL
        KL = ift.MetricGaussianKL(mean, H, N_samples)
        # The minimizer's second return value (convergence status) is not
        # used by this demo.
        KL, _ = minimizer(KL)
        mean = KL.position

        # Plot current reconstruction
        plot = ift.Plot()
        plot.add(signal(KL.position), title="reconstruction")
        plot.add([A.force(KL.position), A.force(mock_position)], title="power")
        plot.output(ny=1, ysize=6, xsize=16,
                    name=filename.format("loop_{:02d}".format(i)))

    # Draw posterior samples and accumulate mean/variance of the signal
    KL = ift.MetricGaussianKL(mean, H, N_samples)
    sc = ift.StatCalculator()
    for sample in KL.samples:
        sc.add(signal(sample + KL.position))

    # Plotting
    filename_res = filename.format("results")
    plot = ift.Plot()
    plot.add(sc.mean, title="Posterior Mean")
    plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation")

    # Thin lines: per-sample power spectra; thick lines: posterior-mean
    # position and ground truth.
    powers = [A.force(s + KL.position) for s in KL.samples]
    plot.add(
        powers + [A.force(KL.position),
                  A.force(mock_position)],
        title="Sampled Posterior Power Spectrum",
        linewidth=[1.]*len(powers) + [3., 3.])
    plot.output(ny=1, nx=3, xsize=24, ysize=6, name=filename_res)
    print("Saved results as '{}'.".format(filename_res))