# covid_matern_model2.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2024 Max-Planck-Society
# Author: Matteo Guardiani
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import nifty7 as ift

21
22
import data
from tools import density_estimator, domainbreak_2D
23
24
25
26
from utilities import GeomMaskOperator


class MaternCausalModel:
    """Joint + independent Matern correlated-field model on a 2D (age, log-load) grid.

    On construction, NIFTy operators are built for:

    * a joint 2D log-intensity component (``lambda_joint``), and
    * two 1D independent components (``lambda_age`` on the first grid axis,
      ``lambda_ll`` on the second).

    These are combined into ``lambda_combined`` and ``conditional_probability``.

    When ``dataset.inversion_par`` is true, the two axes are swapped. The first
    axis then has ``npix_ll`` pixels and is driven by the log-load setup; the
    second has ``npix_age`` pixels and is driven by the age setup. The attribute
    names ``lambda_age`` / ``lambda_ll`` always refer to the first / second
    axis component, respectively.
    """

    # MultiField keys that wire the sub-models into the combined operators.
    # Each placeholder (FieldAdapter) and the ``ducktape_left`` feeding it must
    # use the same key, otherwise the ``@`` chain's domains do not match.
    _JOINT_KEY = 'lambdajoint'
    _LL_KEY = 'lambdall'
    _AGE_KEY = 'lambdaage'

    # Symmetric bound applied to every log-field before masking.
    # NOTE(review): presumably there to keep later exp() calls finite — confirm.
    _LOG_CLIP = 30

    def __init__(self, dataset, plot, alphas=None):
        """Validate the inputs and build all model operators.

        Parameters
        ----------
        dataset : data.Data
            Source of ``setup``, ``npix_age``, ``npix_ll``, ``inversion_par``,
            ``zero_pad()`` and ``coordinates()``.
        plot : bool
            Stored as ``self.plot``; not used by the model construction itself.
        alphas : optional
            Stored as ``self.alphas``; not used by the model construction itself.

        Raises
        ------
        TypeError
            If ``dataset`` is not a ``data.Data`` or ``plot`` is not a ``bool``.
        """
        if not isinstance(dataset, data.Data):
            raise TypeError("The dataset needs to be of type Data.")

        if not isinstance(plot, bool):
            raise TypeError("The plot flag needs to be of type bool.")

        self.dataset = dataset
        self.plot = plot
        self.alphas = alphas

        # Filled in by create_model() / build_joint_component().
        self.lambda_joint = None
        self.lambda_age = None
        self.lambda_ll = None
        self.amplitudes = None
        self.position_space = None
        self.target_space = None

        # Build the operator graph a single time and unpack both results.
        self.lambda_combined, self.conditional_probability = self.create_model()

    def create_model(self):
        """Assemble the combined intensity and the conditional probability operators.

        Side effects: sets ``lambda_joint``, ``lambda_age``, ``lambda_ll``,
        ``amplitudes``, ``position_space`` and ``target_space`` on ``self``.

        Returns
        -------
        tuple
            ``(lambda_combined, conditional_probability)``. The input domain of
            ``lambda_combined`` is the union of the domains of the joint and
            both independent components. The input domain of
            ``conditional_probability`` is that of the joint and second-axis
            components.

        Raises
        ------
        ValueError
            If the second-coordinate interval used for normalization is empty
            or has zero width.
        """
        self.lambda_joint = self.build_joint_component()
        self.lambda_age, self.lambda_ll, self.amplitudes = self.build_independent_components()

        # Dimensionality adjustment for the independent component
        self.target_space = self.lambda_joint.target
        domain_break_op = domainbreak_2D(self.target_space)

        joint_in = ift.FieldAdapter(self.lambda_joint.target, self._JOINT_KEY)
        ll_in = ift.FieldAdapter(self.lambda_ll.target, self._LL_KEY)
        age_in = ift.FieldAdapter(self.lambda_age.target, self._AGE_KEY)

        # Field exponentiation and marginalization along the x direction, hence has 'length' y
        x_marginal = domain_break_op(joint_in.ptw('exp')).sum(0)

        # Repeats a second-axis field along the first axis (outer product with ones).
        age_unit_field = ift.full(self.lambda_age.target, 1)
        expand_second = ift.OuterProduct(self.lambda_ll.target, age_unit_field)
        lambda_ll_2d = domain_break_op.adjoint @ expand_second @ ll_in

        # Repeats a first-axis field along the second axis. The outer product
        # puts the axes in the opposite order, so it is followed by a transpose.
        ll_unit_field = ift.full(self.lambda_ll.target, 1)
        expand_first = ift.OuterProduct(self.lambda_age.target, ll_unit_field)
        transpose = ift.LinearEinsum(expand_first.target, ift.MultiField.from_dict({}), "xy->yx")
        expand_first = transpose @ expand_first
        lambda_age_2d = domain_break_op.adjoint @ expand_first @ age_in

        joint_component = lambda_ll_2d + joint_in
        cond_density = joint_component.ptw('exp') * domain_break_op.adjoint(
            expand_second(x_marginal.ptw('reciprocal')))
        normalization = domain_break_op(cond_density).sum(1)
        log_lambda_combined = (
            lambda_age_2d
            + joint_component
            - domain_break_op.adjoint(expand_second(x_marginal.ptw('log')))
            - domain_break_op.adjoint(expand_first(normalization.ptw('log')))
        )

        # The ducktape keys must equal the FieldAdapter keys used above.
        log_lambda_combined = log_lambda_combined @ (
            self.lambda_joint.ducktape_left(self._JOINT_KEY)
            + self.lambda_ll.ducktape_left(self._LL_KEY)
            + self.lambda_age.ducktape_left(self._AGE_KEY))
        lambda_combined = log_lambda_combined.ptw('exp')

        conditional_probability = cond_density * domain_break_op.adjoint(
            expand_first(normalization)).ptw('reciprocal')
        conditional_probability = conditional_probability @ (
            self.lambda_joint.ducktape_left(self._JOINT_KEY)
            + self.lambda_ll.ducktape_left(self._LL_KEY))

        # Normalize the probability on the given logload interval
        # (second entry of dataset.coordinates()).
        inv_norm = self._inverse_bin_width(self.dataset.npix_ll, self.dataset.coordinates()[1])
        conditional_probability = conditional_probability * inv_norm

        return lambda_combined, conditional_probability

    @staticmethod
    def _inverse_bin_width(npix, coords):
        """Return ``npix / (max(coords) - min(coords))``.

        Raises
        ------
        ValueError
            If ``coords`` is empty (raised by ``min``) or all coordinates are
            equal, in which case the interval width is zero.
        """
        lo, hi = min(coords), max(coords)
        width = hi - lo
        if width == 0:
            raise ValueError("Cannot normalize over a zero-width coordinate interval.")
        return npix / width

    def build_joint_component(self):
        """Build the joint 2D Matern correlated-field log-intensity operator.

        Reads ``setup['joint']`` with the sub-dicts ``offset_dict``
        (``offset_mean``, ``offset_std``, ``prefix``), ``log_load`` and ``age``
        (each with ``scale``, ``cutoff``, ``loglogslope``, ``prefix``).

        Side effect: stores the first value returned by ``dataset.zero_pad()``
        in ``self.position_space``.

        Returns
        -------
        Operator
            The correlated field, clipped to ``[-_LOG_CLIP, _LOG_CLIP]`` and
            passed through ``GeomMaskOperator`` onto an RGSpace of shape
            ``(npix_age, npix_ll)``, or ``(npix_ll, npix_age)`` when
            ``dataset.inversion_par`` is set.
        """
        joint_setup = self.dataset.setup['joint']
        npix_age = self.dataset.npix_age
        npix_ll = self.dataset.npix_ll
        self.position_space, sp1, sp2 = self.dataset.zero_pad()

        offset = joint_setup['offset_dict']
        ll_params = joint_setup['log_load']
        age_params = joint_setup['age']

        # sp1 / sp2 carry whichever quantity sits on the first / second grid axis.
        if self.dataset.inversion_par:
            first, second, shape = ll_params, age_params, (npix_ll, npix_age)
        else:
            first, second, shape = age_params, ll_params, (npix_age, npix_ll)

        maker = ift.CorrelatedFieldMaker(offset['prefix'])
        maker.set_amplitude_total_offset(offset['offset_mean'], offset['offset_std'])
        for subspace, params in ((sp1, first), (sp2, second)):
            maker.add_fluctuations_matern(subspace, params['scale'], params['cutoff'],
                                          params['loglogslope'], params['prefix'])
        lambda_full = maker.finalize()

        # For the joint model unmasked regions
        distances = (lambda_full.target[0].distances[0], lambda_full.target[1].distances[0])
        tgt = ift.RGSpace(shape, distances=distances)
        mask = GeomMaskOperator(lambda_full.target, tgt)
        return mask(lambda_full.clip(-self._LOG_CLIP, self._LOG_CLIP))

    def build_independent_components(self):
        """Build the two independent 1D log-intensity components.

        The first-axis component comes from ``density_estimator`` on the first
        padded subspace. The second-axis component is a Matern correlated field
        on the second padded subspace. With ``dataset.inversion_par`` set,
        ``setup['indep']['log_load']`` drives the first axis and
        ``setup['indep']['age']`` the second; otherwise the reverse.

        Returns
        -------
        tuple
            ``(lambda_age, lambda_ll, amplitudes)``:

            * ``lambda_age`` is the first-axis component and ``lambda_ll`` the
              second-axis component, both clipped and masked to the data grid;
            * ``amplitudes`` is ``[first_axis_amplitude, second_axis_amplitude]``.
        """
        indep_setup = self.dataset.setup['indep']
        npix_age = self.dataset.npix_age
        npix_ll = self.dataset.npix_ll
        _, sp1, sp2 = self.dataset.zero_pad()

        age_dict = indep_setup['age']
        ll_dict = indep_setup['log_load']
        if self.dataset.inversion_par:
            first, second = ll_dict, age_dict
            first_npix, second_npix = npix_ll, npix_age
        else:
            first, second = age_dict, ll_dict
            first_npix, second_npix = npix_age, npix_ll

        # First axis: density estimator.
        first_offset = first['offset_dict']
        _, ops = density_estimator(sp1, cf_fluctuations=first['params'],
                                   cf_azm_uniform=first_offset['offset_std'],
                                   azm_offset_mean=first_offset['offset_mean'], pad=0)
        lambda_first_full = ops["correlated_field"]
        first_amplitude = ops["amplitude"]

        # Second axis: Matern correlated field; its 'params' dict is passed as
        # keyword arguments to add_fluctuations_matern.
        second_offset = second['offset_dict']
        maker = ift.CorrelatedFieldMaker(second_offset['prefix'])
        maker.set_amplitude_total_offset(second_offset['offset_mean'], second_offset['offset_std'])
        maker.add_fluctuations_matern(sp2, **second['params'])
        lambda_second_full = maker.finalize()
        second_amplitude = maker.amplitude

        lambda_age = self._clip_and_mask(lambda_first_full, first_npix)
        lambda_ll = self._clip_and_mask(lambda_second_full, second_npix)

        return lambda_age, lambda_ll, [first_amplitude, second_amplitude]

    def _clip_and_mask(self, field_op, npix):
        """Clip a 1D log-field operator and map it onto an RGSpace of ``npix`` pixels.

        The target space keeps the pixel distances of ``field_op.target[0]``.
        NOTE(review): GeomMaskOperator presumably drops the zero-padding — confirm.
        """
        target = ift.RGSpace(npix, distances=field_op.target[0].distances)
        mask = GeomMaskOperator(field_op.target, target)
        return mask(field_op.clip(-self._LOG_CLIP, self._LOG_CLIP))

    # NOTE(review): disabled draft of a plotting helper, kept for reference only.
    # It refers to names (plot, lambda_ag, dombr, dimop, y_averager, results, ...)
    # that are not defined in this class and would not run as written.
    '''
    def produce_plots(self):
        if plot == True:
            # Additional fields for plotting (indep)
            lambda_ll_2d = lambda_ll_2d @ lambda_ll.ducktape_left('lambdall')
            lambda_ag_2d = lambda_ag_2d @ lambda_ag.ducktape_left('lambdaag')
            lambda_indep = lambda_ag_2d + lambda_ll_2d
            lambda_indep = lambda_indep.exp()

            # True Joint
            margy = dombr(lambda_joint.exp()).sum(1)
            margx = dombr(lambda_joint.exp()).sum(0)

            lambda_true = dombr.adjoint @ (dimop(margx) + dimop2(margy))
            lambda_true = lambda_joint - lambda_true
            lambda_true = lambda_true.exp()

            # Averages
            averages = []

            for alpha in alphas:
                if inv_par:
                    break
                # if inv_par: globals()['y_average_%s' % alpha] = x_averager(cond_prob.target, ift.Field.from_raw(
                # lambda_ag.target, age_coord), 1/inv_norm, alpha)
                else:
                    globals()['y_average_%s' % alpha] = y_averager(cond_prob.target,
                        ift.Field.from_raw(lambda_ll.target, ll_coord), 1 / inv_norm, alpha)
                globals()['load_average_%s' % alpha] = globals()['y_average_%s' % alpha](cond_prob)
                averages.append(globals()['load_average_%s' % alpha])

            # Infectivity
            infec = ['dw', 'mid', 'up']
            infectivity = []

            for it, inf in enumerate(infec):
                if inv_par: break
                globals()['infectivity_%s' % inf] = ift.Field.from_raw(lambda_ll.target,
                    fitted_infectivity(ll_coord + (it - 1)))
                globals()['infectivity_averager_%s' % inf] = y_averager(cond_prob.target,
                    globals()['infectivity_%s' % inf], 1 / inv_norm, 0)
                infectivity.append(globals()['infectivity_averager_%s' % inf](cond_prob))

            results['independent lambdada'] = lambda_indep
            results['age coordinates'] = age_coord
            results['lload coordinates'] = ll_coord
            results['lambdada true'] = lambda_true
            results['averages'] = averages
            results['infectivity'] = infectivity
            results['domain breaker op'] = dombr

            return results

        elif plot == False:
            return results
        '''