# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2024 Max-Planck-Society
# Author: Matteo Guardiani
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
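"""Matérn-kernel causal model.

Combines a joint 2D correlated field with independent per-axis components
into a conditional density model; see `MaternCausalModel`.
"""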
import json

import nifty7 as ift

import data
from tools import DomainBreak2D, density_estimator
from utilities import GeomMaskOperator


class MaternCausalModel:
    def __init__(self, setup, dataset, plot, alphas=None):
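        """Initialize the Matérn causal model.

        `setup` is the configuration dictionary (with 'joint' and 'indep'
        sections), `dataset` a `data.Data` instance providing coordinates,
        pixelization and zero-padding, `plot` toggles plotting, and `alphas`
        is an optional extra parameter that is only stored.
        """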
        if not isinstance(setup, dict):
            raise TypeError("The setup argument needs to be of type dict.")

        if not isinstance(dataset, data.Data):
            raise TypeError("The dataset argument needs to be of type data.Data.")

        if not isinstance(plot, bool):
            raise TypeError("The plot argument needs to be of type bool.")

        self.setup = setup
        self.dataset = dataset
        self.plot = plot
        self.alphas = alphas
        self.lambda_joint = None
        self.lambda_x = None
        self.lambda_y = None
        self.lambda_x_full = None
        self.lambda_y_full = None
        self.lambda_full = None
        self.amplitudes = None
        self.position_space = None
        self.target_space = None
        self.lambda_combined, self.conditional_probability = self.create_model()
        print('Matérn kernel model initialized.')

    def create_model(self):
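        """Assemble the combined model.

        Builds the joint component and the independent x- and y-components,
        broadcasts them onto the common 2D target domain and returns the
        combined lambda operator together with the conditional probability
        operator.
        """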
        self.lambda_joint, self.lambda_full = self.build_joint_component()

        self.lambda_x, self.lambda_y, self.lambda_x_full, self.lambda_y_full, self.amplitudes = \
            self.initialize_independent_components()

        # Dimensionality adjustment for the independent components
        self.target_space = self.lambda_joint.target

        domain_break_op = DomainBreak2D(self.target_space)

        lambda_joint_placeholder = ift.FieldAdapter(self.lambda_joint.target, 'lambdajoint')
        lambda_y_placeholder = ift.FieldAdapter(self.lambda_y.target, 'lambday')
        lambda_x_placeholder = ift.FieldAdapter(self.lambda_x.target, 'lambdax')

        # Exponentiate the joint field and marginalize along the x direction,
        # so the result lives on the y axis.
        x_marginalizer_op = domain_break_op(lambda_joint_placeholder.ptw('exp')).sum(0)

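        # Broadcast the 1D independent components onto the 2D target domain via
        # outer products with unit fields (with a transposition for the x axis).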
        x_unit_field = ift.full(self.lambda_x.target, 1)
        dimensionality_operator = ift.OuterProduct(self.lambda_y.target, x_unit_field)
        lambda_y_2d = domain_break_op.adjoint @ dimensionality_operator @ lambda_y_placeholder
        y_unit_field = ift.full(self.lambda_y.target, 1)
        dimensionality_operator_2 = ift.OuterProduct(self.lambda_x.target, y_unit_field)
        transposition_operator = ift.LinearEinsum(dimensionality_operator_2(lambda_x_placeholder).target,
            ift.MultiField.from_dict({}), "xy->yx")
        dimensionality_operator_2 = transposition_operator @ dimensionality_operator_2
        lambda_x_2d = domain_break_op.adjoint @ dimensionality_operator_2 @ lambda_x_placeholder

        joint_component = lambda_y_2d + lambda_joint_placeholder
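        # Unnormalized conditional density: the exponentiated joint component is
        # weighted by the reciprocal of its x-marginal; `normalization` sums the
        # result along the y direction for the subsequent normalization.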
        cond_density = joint_component.ptw('exp') * domain_break_op.adjoint(
            dimensionality_operator(x_marginalizer_op.ptw('reciprocal')))
        normalization = domain_break_op(cond_density).sum(1)
        log_lambda_combined = lambda_x_2d + joint_component - domain_break_op.adjoint(
            dimensionality_operator(x_marginalizer_op.ptw('log'))) - domain_break_op.adjoint(
            dimensionality_operator_2(normalization.ptw('log')))

        log_lambda_combined = log_lambda_combined @ (
                self.lambda_joint.ducktape_left('lambdajoint')
                + self.lambda_y.ducktape_left('lambday')
                + self.lambda_x.ducktape_left('lambdax'))
        lambda_combined = log_lambda_combined.ptw('exp')

        conditional_probability = cond_density * domain_break_op.adjoint(
            dimensionality_operator_2(normalization)).ptw('reciprocal')
        conditional_probability = conditional_probability @ (
                self.lambda_joint.ducktape_left('lambdajoint') + self.lambda_y.ducktape_left('lambday'))

        # Normalize the probability on the given log-load interval
        boundaries = [min(self.dataset.coordinates()[0]), max(self.dataset.coordinates()[0]),
                      min(self.dataset.coordinates()[1]), max(self.dataset.coordinates()[1])]
        inv_norm = self.dataset.npix_y / (boundaries[3] - boundaries[2])
        conditional_probability = conditional_probability * inv_norm

        return lambda_combined, conditional_probability

    def build_joint_component(self):
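        """Build the joint 2D correlated field with Matérn fluctuations on both axes.

        Returns the field masked down to the unpadded region (`lambda_joint`)
        and the field on the full zero-padded domain (`lambda_full`).
        """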
        npix_x = self.dataset.npix_x
        npix_y = self.dataset.npix_y
        self.position_space, sp1, sp2 = self.dataset.zero_pad()

        # Set up signal model
        joint_offset = self.setup['joint']['offset_dict']
        offset_mean = joint_offset['offset_mean']
        offset_std = joint_offset['offset_std']
        joint_prefix = joint_offset['prefix']

        joint_setup_y = self.setup['joint']['log_load']
        y_scale = joint_setup_y['scale']
        y_cutoff = joint_setup_y['cutoff']
        y_loglogslope = joint_setup_y['loglogslope']
        y_prefix = joint_setup_y['prefix']

        joint_setup_x = self.setup['joint']['x']
        x_scale = joint_setup_x['scale']
        x_cutoff = joint_setup_x['cutoff']
        x_loglogslope = joint_setup_x['loglogslope']
        x_prefix = joint_setup_x['prefix']

        correlated_field_maker = ift.CorrelatedFieldMaker(joint_prefix)
        correlated_field_maker.set_amplitude_total_offset(offset_mean, offset_std)

        correlated_field_maker.add_fluctuations_matern(sp1, x_scale, x_cutoff, x_loglogslope, x_prefix)
        correlated_field_maker.add_fluctuations_matern(sp2, y_scale, y_cutoff, y_loglogslope, y_prefix)
        lambda_full = correlated_field_maker.finalize()

        # Restrict the joint model to the unmasked (not zero-padded) region
        tgt = ift.RGSpace((npix_x, npix_y),
            distances=(lambda_full.target[0].distances[0], lambda_full.target[1].distances[0]))

        GMO = GeomMaskOperator(lambda_full.target, tgt)
        lambda_joint = GMO(lambda_full.clip(-30, 30))

        full_domain_break_operator = DomainBreak2D(self.position_space)
        lambda_full = full_domain_break_operator.adjoint @ lambda_full

        return lambda_joint, lambda_full

    def build_independent_components(self, lambda_x_full, lambda_y_full, amplitudes):
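        """Crop the independent 1D fields to the central data region.

        Returns the cropped fields, the full (zero-padded) fields and the
        amplitude operators.
        """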
        # Crop the zero-padded 1D fields down to the central data region
        # X
        _dist = lambda_x_full.target[0].distances
        tgt_x = ift.RGSpace(self.dataset.npix_x, distances=_dist)
        GMO_x = GeomMaskOperator(lambda_x_full.target, tgt_x)
        lambda_x = GMO_x(lambda_x_full.clip(-30, 30))

        # Viral load
        _dist = lambda_y_full.target[0].distances
        tgt_y = ift.RGSpace(self.dataset.npix_y, distances=_dist)
        GMO_y = GeomMaskOperator(lambda_y_full.target, tgt_y)
        lambda_y = GMO_y(lambda_y_full.clip(-30, 30))

        return lambda_x, lambda_y, lambda_x_full, lambda_y_full, amplitudes

    def initialize_independent_components(self):
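        """Set up the independent x- and y-components.

        The x axis uses the density estimator, the y (log-load) axis a
        Matérn-kernel correlated field; both are then cropped via
        `build_independent_components`.
        """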
        _, sp1, sp2 = self.dataset.zero_pad()

        # Set up signal model
        # X Parameters
        x_dictionary = self.setup['indep']['x']
        x_offset_mean = x_dictionary['offset_dict']['offset_mean']
        x_offset_std = x_dictionary['offset_dict']['offset_std']

        # Log Load Parameters
        y_dictionary = self.setup['indep']['log_load']
        y_offset_mean = y_dictionary['offset_dict']['offset_mean']
        y_offset_std = y_dictionary['offset_dict']['offset_std']
        indep_y_prefix = y_dictionary['offset_dict']['prefix']

        # Create the x axis with the density estimator
        signal_response, ops = density_estimator(sp1, cf_fluctuations=x_dictionary['params'],
            cf_azm_uniform=x_offset_std, azm_offset_mean=x_offset_mean, pad=0)
        lambda_x_full = ops["correlated_field"]
        x_amplitude = ops["amplitude"]
        zero_mode = ops["amplitude_total_offset"]
        # response = ops["exposure"]

        # Create the viral load axis with the Matérn-kernel correlated field
        correlated_field_maker = ift.CorrelatedFieldMaker(indep_y_prefix)
        correlated_field_maker.set_amplitude_total_offset(y_offset_mean, y_offset_std)
        correlated_field_maker.add_fluctuations_matern(sp2, **y_dictionary['params'])
        lambda_y_full = correlated_field_maker.finalize()
        y_amplitude = correlated_field_maker.amplitude

        amplitudes = [x_amplitude, y_amplitude]

        return self.build_independent_components(lambda_x_full, lambda_y_full, amplitudes)

    def plot_prior_samples(self, n_samples):
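        """Draw `n_samples` prior samples of the combined model and save them to 'priors.pdf'."""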
        plot = ift.Plot()

        for i in range(n_samples):
            mock_position = ift.from_random(self.lambda_combined.domain, 'normal')
            plot.add(self.lambda_combined(mock_position))

        filename = "priors.pdf"
        ny = max(1, n_samples // 3)  # avoid a zero row count for small n_samples
        plot.output(nx=-(-n_samples // ny), ny=ny, xsize=10, ysize=10, name=filename)

        print("Prior samples saved as", filename)


if __name__ == '__main__':
    config_file = 'config/config.json'
    with open(config_file, "r") as file_setup:
        setup = json.load(file_setup)

    dataset = data.Data(90, 128, setup["threshold"], 0, 'Cobas_Dataset.csv')
    inverted_dataset = data.InvertedData(dataset)
    # The model expects the parsed setup dictionary, not the config file path.
    model = MaternCausalModel(setup, dataset, False)
    inverted_model = MaternCausalModel(setup, inverted_dataset, False)
    model.plot_prior_samples(8)
    inverted_model.plot_prior_samples(8)