# covid_combined_matern_mpi.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2022 Max-Planck-Society
# Author: Matteo Guardiani
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import argparse
import json
import os
import sys

import nifty7 as ift
import numpy as np

# Optional MPI parallelisation: with mpi4py available the world communicator is
# used; without it the script runs as a single task with rank 0.
try:
    from mpi4py import MPI
except ImportError:
    comm, n_task, rank = None, 1, 0
else:
    comm = MPI.COMM_WORLD
    n_task, rank = comm.Get_size(), comm.Get_rank()

# Rank 0 is the master task; the script uses this flag to gate saving and plotting.
master = rank == 0

40
from data_utilities import save_kl_sample, save_kl_position
41
42
43
from utilities import get_op_post_mean
from const import npix_age, npix_ll
from data import Data
44
from matern_causal_model import MaternCausalModel
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# from evidence_g import get_evidence
import matplotlib.colors as colors

# Parser Setup: every argument is mandatory and documented for ``--help``.
parser = argparse.ArgumentParser(
    description="Run the Matern causal model reconstruction (MPI-parallel when mpi4py is available).")
parser.add_argument('--json_file', type=str, required=True,
                    help="JSON configuration file with the model setup (e.g. 'threshold', 'seed', 'mock').")
parser.add_argument('--csv_file', type=str, required=True,
                    help="CSV data file passed to Data().")
parser.add_argument('--reshuffle_parameter', type=int, required=True,
                    help="Reshuffle index passed to Data(); also names the results sub-directory.")
args = parser.parse_args()

json_file = args.json_file
csv_file = args.csv_file
reshuffle_iterator = args.reshuffle_parameter

def _load_json(path):
    """Load and return the JSON document stored at ``path``."""
    with open(path, "r", encoding="utf-8") as json_fp:
        return json.load(json_fp)


def _poisson_mock(lamb, position):
    """Draw synthetic Poisson counts from an intensity operator.

    Parameters
    ----------
    lamb : ift.Operator
        Intensity operator of the model that generates the mock data.
    position : ift.MultiField
        Latent position in ``lamb.domain`` at which the intensity is evaluated.

    Returns
    -------
    numpy.ndarray
        Poisson samples (one per pixel of ``lamb.target``) drawn from the
        current NIFTy random generator.
    """
    # Same geometry removal as the reconstruction's response, but built on the
    # generating operator so that its domain matches ``position``.
    expected = ift.GeometryRemover(lamb.target)(lamb)(position)
    return ift.random.current_rng().poisson(expected.val.astype(np.float64))


def _newton_minimizer(iteration_limit):
    """Return a logging NewtonCG minimizer together with its energy controller.

    The controller uses ``deltaE=1e-5`` and ``convergence_level=3`` and stops
    after at most ``iteration_limit`` Newton iterations.
    """
    controller = ift.AbsDeltaEnergyController(deltaE=1e-5, iteration_limit=iteration_limit, name='newton',
                                              convergence_level=3)
    controller.enable_logging()
    return ift.NewtonCG(controller, enable_logging=True), controller


def _n_samples(step):
    """Return the number of KL samples for global iteration ``step``.

    5 samples for steps 0-29, 20 for steps 30-32 and 500 afterwards; the
    samples of the last step feed the posterior statistics below.
    """
    if step < 30:
        return 5
    if step < 33:
        return 20
    return 500


if __name__ == '__main__':
    # Pipeline: read config -> build Data and model -> (optionally) draw mock
    # data -> iterative MGVI (MetricGaussianKL + NewtonCG) -> save and plot.

    # Read in the configuration file
    setup = _load_json(json_file)

    # Folder for live plots. ``exist_ok`` avoids the FileExistsError race that an
    # exists()/mkdir() pair has when several MPI tasks reach this line together.
    os.makedirs('./plots', exist_ok=True)
    filename = "plots/covid_combined_matern_{}.png"

    # Results output folder: one sub-directory per config, data file and reshuffle index.
    json_filename = os.path.basename(json_file)
    csv_filename = os.path.basename(csv_file)
    results_path = os.path.normpath(os.path.join('./Automized_Results_Matern', os.path.splitext(json_filename)[0],
                                                 os.path.splitext(csv_filename)[0], str(reshuffle_iterator)))
    os.makedirs(results_path, exist_ok=True)

    # Load the data and the model. ``data`` keeps referring to the Data instance
    # for the whole run (its ``.data`` and ``.coordinates()`` are used below).
    data = Data(npix_age, npix_ll, setup['threshold'], reshuffle_iterator, csv_file)
    model = MaternCausalModel(setup, data, False)

    # Setup the response (geometry removal) & define the amplitudes
    R = ift.GeometryRemover(model.lambda_combined.target)
    R_lamb = R(model.lambda_combined)
    data_space = R_lamb.target

    A1 = model.amplitudes[0]
    A2 = model.amplitudes[1]

    # Seed NIFTy's random generator (mock data, initial position, KL samples).
    ift.random.push_sseq_from_seed(setup['seed'])

    if setup['mock']:
        # Unless configured otherwise, the fitted model also generates the mock data.
        truth_model = model
        independent_tag = '_indep'
        if not setup['same data'] and independent_tag in json_file:
            print("\nUsing syinthetic data generated from joint model on independent model")
            joint_json_file = json_file.replace(independent_tag, '')
            truth_model = MaternCausalModel(_load_json(joint_json_file), data, None)
        elif not setup['same data']:
            print("\nUsing syinthetic data generated from independent model on joint model")
            json_root, json_ext = os.path.splitext(json_file)
            independent_json_file = json_root + independent_tag + json_ext
            truth_model = MaternCausalModel(_load_json(independent_json_file), data, None)

        # Ground truth comes from the generating model, which may differ from the
        # fitted one; its own operators must be used with its own position.
        truth_lamb = truth_model.lambda_combined
        truth_A1 = truth_model.amplitudes[0]
        truth_A2 = truth_model.amplitudes[1]

        mock_position = ift.from_random(truth_lamb.domain, 'normal')
        data_field = ift.makeField(data_space, _poisson_mock(truth_lamb, mock_position))

        if master:  # one task writes the setup plot; all tasks would target the same file
            plot = ift.Plot()
            plot.add(truth_lamb(mock_position), title='Full Field')
            plot.add(R.adjoint(data_field), title='Data')
            plot.add([truth_A1.force(mock_position)], title='Power Spectrum 1')
            plot.add([truth_A2.force(mock_position)], title='Power Spectrum 2')
            plot.output(ny=3, nx=2, xsize=10, ysize=10, name=filename.format("setup"))
    else:
        data_field = ift.makeField(data_space, data.data)

    # Sampling controller (the Newton controller is rebuilt every global iteration).
    ic_sampling = ift.AbsDeltaEnergyController(deltaE=1e-5, iteration_limit=250, convergence_level=250)
    ic_sampling.enable_logging()

    # Set up likelihood and information Hamiltonian
    likelihood = ift.PoissonianEnergy(data_field) @ R_lamb
    H = ift.StandardHamiltonian(likelihood, ic_sampling)

    # Loop invariants: plot extent and output locations.
    coords = data.coordinates()
    boundaries = [min(coords[0]), max(coords[0]), min(coords[1]), max(coords[1])]
    pos_path = os.path.join(results_path, "KL_position")
    sam_path = os.path.join(results_path, "samples")
    if master:
        os.makedirs(sam_path, exist_ok=True)

    # Begin minimization
    mean = ift.from_random(H.domain, 'normal') * 0.1

    N_steps = 35
    for i in range(N_steps):
        minimizer, ic_newton = _newton_minimizer(10 if i < 27 else 20)
        N_samples = _n_samples(i)

        # Draw new samples and minimize KL. All tasks take part (``comm``), and
        # all tasks read ``KL.samples`` before the master-only section.
        KL = ift.MetricGaussianKL(mean, H, N_samples, comm=comm, mirror_samples=True, nanisinf=True)
        KL, _ = minimizer(KL)
        samples = tuple(KL.samples)
        mean = KL.position

        if not master:
            continue

        save_kl_position(mean, pos_path)
        print("KL position saved", file=sys.stderr)

        for idx, sample in enumerate(samples):
            save_kl_sample(sample, os.path.join(sam_path, "KL_sample_{}".format(idx)))
        print("KL samples saved", file=sys.stderr)

        # Minisanity check. FIXME: check the noise implementation passed to minisanity.
        ift.extra.minisanity(data_field, lambda x: ift.makeOp(R_lamb(x).ptw('reciprocal')), R_lamb, mean,
                             samples)

        # Plot current reconstruction
        plot = ift.Plot()
        if setup['mock']:
            plot.add([truth_lamb(mock_position)], title="ground truth")
            plot.add(R.adjoint(data_field), title='Data')
            plot.add([model.lambda_combined(mean)], title="reconstruction")
            plot.add([model.lambda_joint.force(mean)], title="Joint component")
            plot.add([A1.force(mean), truth_A1.force(mock_position)], title="power1")
            plot.add([A2.force(mean), truth_A2.force(mock_position)], title="power2")
            plot.add([ic_newton.history, ic_sampling.history, minimizer.inversion_history],
                     label=['KL', 'Sampling', 'Newton inversion'], title='Cumulative energies', s=[None, None, 1],
                     alpha=[None, 0.2, None])
        else:
            plot.add([model.lambda_combined(mean)], title="Reconstruction", norm=colors.SymLogNorm(linthresh=10e-1),
                     extent=boundaries, aspect="auto")
            plot.add([model.conditional_probability.force(mean)], title="Conditional Probability Reconstruction",
                     norm=colors.SymLogNorm(linthresh=6 * 10e-4), extent=boundaries, aspect="auto")
            plot.add([A1.force(mean)], title="power1 independent")
            plot.add([A2.force(mean)], title="power2 independent")

        plot.output(nx=3, ny=3, ysize=10, xsize=15, name=filename.format("loop_{:02d}".format(i)))
        print('Lamb combined check:', model.lambda_combined(mean).val.sum(), '\n', file=sys.stderr)

    if master:
        # Posterior statistics from the final mean and samples.
        lamb_comb_mean, lamb_comb_var = get_op_post_mean(model.lambda_combined, mean, samples)
        cond_prob_mean, cond_prob_var = get_op_post_mean(model.conditional_probability, mean, samples)
        lamb_full_mean, lamb_full_var = get_op_post_mean(model.lambda_full.exp(), mean, samples)

        # Final Plots
        filename_res = os.path.join(results_path, "Results.png")
        plot = ift.Plot()
        plot.add(lamb_comb_mean, title="Posterior Mean", norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries,
                 aspect="auto")
        plot.add(lamb_comb_var.sqrt(), title="Posterior Standard Deviation", norm=colors.SymLogNorm(linthresh=10e-1),
                 extent=boundaries, aspect="auto")
        plot.add([model.conditional_probability.force(mean)], title="Conditional Probability Reconstruction",
                 norm=colors.SymLogNorm(linthresh=6 * 10e-4), extent=boundaries, aspect="auto")
        plot.add([A1.force(mean)], title="Age Independent Power Spectrum (log[S(k^2)])")
        plot.add([A2.force(mean)], title="Log load Independent Power Spectrum (log[S(k^2)])")
        plot.add(model.lambda_age_full.force(mean), title="Age Reconstruction (full)",
                 norm=colors.SymLogNorm(linthresh=10e-1), aspect="auto")
        plot.add(model.lambda_ll_full.force(mean), title="Log load Reconstruction (full)",
                 norm=colors.SymLogNorm(linthresh=10e-1), aspect="auto")
        plot.add([model.lambda_full.exp().force(mean)], title="Joint Component Reconstruction (full)",
                 norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries, aspect="auto")
        plot.output(ny=3, nx=3, xsize=20, ysize=15, name=filename_res)
        print("Saved results as", filename_res, file=sys.stderr)

        # Error Plots (relative uncertainty = posterior std / posterior mean)
        filename_ers = os.path.join(results_path, "Errors.png")
        plot = ift.Plot()
        plot.add(lamb_comb_mean, title="Posterior Mean", norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries,
                 aspect="auto")
        plot.add(lamb_comb_var.sqrt() * lamb_comb_mean.ptw('reciprocal'), title="Relative Uncertainty",
                 norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries, aspect="auto")
        plot.add(cond_prob_mean, title="Conditional Probability Reconstruction Mean",
                 norm=colors.SymLogNorm(linthresh=6 * 10e-4), extent=boundaries, aspect="auto")
        plot.add(cond_prob_var.sqrt() * cond_prob_mean.ptw('reciprocal'),
                 title="Relative Uncertainty on Conditional Probability Reconstruction",
                 norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries, aspect="auto")
        plot.add(lamb_full_mean, title="Joint Component Reconstruction Mean", norm=colors.SymLogNorm(linthresh=10e-2),
                 aspect="auto")
        plot.add(lamb_full_var.sqrt() * lamb_full_mean.ptw('reciprocal'),
                 title="Relative Uncertainty on Joint Component Reconstruction", norm=colors.SymLogNorm(linthresh=10e-2),
                 aspect="auto")
        plot.output(ny=3, nx=2, xsize=15, ysize=15, name=filename_ers)
        print("Saved results as", filename_ers, file=sys.stderr)