Commit f572787f authored by Matteo.Guardiani
Browse files

andrija: Refactored age -> x, ll -> y. Made more covid-agnostic.

parent 8c8c5aca
...@@ -25,9 +25,9 @@ import matplotlib.colors as colors ...@@ -25,9 +25,9 @@ import matplotlib.colors as colors
import nifty7 as ift import nifty7 as ift
from const import npix_age, npix_ll from const import npix_age, npix_ll
from covid_matern_model import MaternCausalModel
from data import Data from data import Data
from data_utilities import save_kl_position, save_kl_sample from data_utilities import save_kl_position, save_kl_sample
from matern_causal_model import MaternCausalModel
from utilities import get_op_post_mean from utilities import get_op_post_mean
# Parser Setup # Parser Setup
......
...@@ -41,7 +41,7 @@ from data_utilities import save_kl_sample, save_kl_position ...@@ -41,7 +41,7 @@ from data_utilities import save_kl_sample, save_kl_position
from utilities import get_op_post_mean from utilities import get_op_post_mean
from const import npix_age, npix_ll from const import npix_age, npix_ll
from data import Data from data import Data
from covid_matern_model import MaternCausalModel from matern_causal_model import MaternCausalModel
# from evidence_g import get_evidence # from evidence_g import get_evidence
import matplotlib.colors as colors import matplotlib.colors as colors
......
...@@ -32,34 +32,34 @@ class ClassLoader(type): ...@@ -32,34 +32,34 @@ class ClassLoader(type):
class Data(metaclass=ClassLoader): class Data(metaclass=ClassLoader):
def __init__(self, npix_age, npix_ll, ll_threshold, reshuffle_seed, age, ll): def __init__(self, npix_x, npix_y, y_threshold, reshuffle_seed, x, y):
if not isinstance(npix_age, int): if not isinstance(npix_x, int):
raise TypeError("Number of pixels argument needs to be of type int.") raise TypeError("Number of pixels argument needs to be of type int.")
if not isinstance(npix_ll, int): if not isinstance(npix_y, int):
raise TypeError("Number of pixels argument needs to be of type int.") raise TypeError("Number of pixels argument needs to be of type int.")
if not isinstance(ll_threshold, float): if not isinstance(y_threshold, float):
raise TypeError("Log load threshold value argument needs to be of type float.") raise TypeError("Log load threshold value argument needs to be of type float.")
if not isinstance(reshuffle_seed, int): if not isinstance(reshuffle_seed, int):
raise TypeError("Reshuffle iterator argument needs to be of type int.") raise TypeError("Reshuffle iterator argument needs to be of type int.")
if not isinstance(age, np.ndarray): if not isinstance(x, np.ndarray):
raise TypeError("The age dataset argument needs to be of type np.ndarray.") raise TypeError("The x dataset argument needs to be of type np.ndarray.")
if not isinstance(ll, np.ndarray): if not isinstance(y, np.ndarray):
raise TypeError("The ll dataset argument needs to be of type np.ndarray.") raise TypeError("The y dataset argument needs to be of type np.ndarray.")
if not age.size == ll.size: if not x.size == y.size:
raise TypeError("The dataset has to contain pairs of age and log load data.") raise TypeError("The dataset has to contain pairs of x and log load data.")
self.npix_age, self.npix_ll = npix_age, npix_ll self.npix_x, self.npix_y = npix_x, npix_y
self.ll_threshold = ll_threshold self.y_threshold = y_threshold
self.reshuffle_seed = reshuffle_seed self.reshuffle_seed = reshuffle_seed
self.age = age, self.ll = ll self.x = x, self.y = y
self.age, self.ll = self.filter() self.x, self.y = self.filter()
self.data = None self.data = None
self.edges = None self.edges = None
self.filename = 'plots/data.pdf' self.filename = 'plots/data.pdf'
...@@ -70,30 +70,30 @@ class Data(metaclass=ClassLoader): ...@@ -70,30 +70,30 @@ class Data(metaclass=ClassLoader):
def filter(self): def filter(self):
# Loads, filters and reshuffles data # Loads, filters and reshuffles data
self.ll, self.age = self.data_filter_x(self.ll_threshold, self.ll, self.age) self.y, self.x = self.data_filter_x(self.y_threshold, self.y, self.x)
if not self.reshuffle_seed == 0: if not self.reshuffle_seed == 0:
self.__reshuffle_data(self.ll, self.reshuffle_seed) self.__reshuffle_data(self.y, self.reshuffle_seed)
return self.age, self.ll return self.x, self.y
def zero_pad(self): def zero_pad(self):
ext_npix_age = 2 * self.npix_age ext_npix_x = 2 * self.npix_x
ext_npix_ll = 2 * self.npix_ll ext_npix_y = 2 * self.npix_y
return self.__create_spaces(ext_npix_age, ext_npix_ll) return self.__create_spaces(ext_npix_x, ext_npix_y)
def __bin(self): def __bin(self):
data, age_edges, ll_edges = bin_2D(self.age, self.ll, self.npix_age, self.npix_ll) data, x_edges, y_edges = bin_2D(self.x, self.y, self.npix_x, self.npix_y)
self.data = np.array(data, dtype=np.int64) self.data = np.array(data, dtype=np.int64)
return data, age_edges, ll_edges return data, x_edges, y_edges
def coordinates(self): def coordinates(self):
age_coordinates = self.__obtain_coordinates(self.age, self.npix_age) x_coordinates = self.__obtain_coordinates(self.x, self.npix_x)
ll_coordinates = self.__obtain_coordinates(self.ll, self.npix_ll) y_coordinates = self.__obtain_coordinates(self.y, self.npix_y)
return age_coordinates, ll_coordinates return x_coordinates, y_coordinates
def plot(self): def plot(self):
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -137,7 +137,7 @@ class Data(metaclass=ClassLoader): ...@@ -137,7 +137,7 @@ class Data(metaclass=ClassLoader):
class InvertedData(Data): class InvertedData(Data):
def __init__(self, data): def __init__(self, data):
super().__init__(data.npix_age, data.npix_ll, data.ll_threshold, data.reshuffle_seed, data.csv_dataset_path) super().__init__(data.npix_x, data.npix_y, data.y_threshold, data.reshuffle_seed, data.csv_dataset_path)
self.age, self.ll = self.ll, self.age self.x, self.y = self.y, self.x
self.npix_age, self.npix_ll = self.npix_ll, self.npix_age self.npix_x, self.npix_y = self.npix_y, self.npix_x
self.filename = 'plots/inverted_data.pdf' self.filename = 'plots/inverted_data.pdf'
...@@ -40,10 +40,10 @@ class MaternCausalModel: ...@@ -40,10 +40,10 @@ class MaternCausalModel:
self.plot = plot self.plot = plot
self.alphas = alphas self.alphas = alphas
self.lambda_joint = None self.lambda_joint = None
self.lambda_age = None self.lambda_x = None
self.lambda_ll = None self.lambda_y = None
self.lambda_age_full = None self.lambda_x_full = None
self.lambda_ll_full = None self.lambda_y_full = None
self.lambda_full = None self.lambda_full = None
self.amplitudes = None self.amplitudes = None
self.position_space = None self.position_space = None
...@@ -54,7 +54,7 @@ class MaternCausalModel: ...@@ -54,7 +54,7 @@ class MaternCausalModel:
def create_model(self): def create_model(self):
self.lambda_joint, self.lambda_full = self.build_joint_component() self.lambda_joint, self.lambda_full = self.build_joint_component()
self.lambda_age, self.lambda_ll, self.lambda_age_full, self.lambda_ll_full, self.amplitudes = \ self.lambda_x, self.lambda_y, self.lambda_x_full, self.lambda_y_full, self.amplitudes = \
self.initialize_independent_components() self.initialize_independent_components()
# Dimensionality adjustment for the independent component # Dimensionality adjustment for the independent component
...@@ -63,51 +63,51 @@ class MaternCausalModel: ...@@ -63,51 +63,51 @@ class MaternCausalModel:
domain_break_op = DomainBreak2D(self.target_space) domain_break_op = DomainBreak2D(self.target_space)
lambda_joint_placeholder = ift.FieldAdapter(self.lambda_joint.target, 'lambdajoint') lambda_joint_placeholder = ift.FieldAdapter(self.lambda_joint.target, 'lambdajoint')
lambda_ll_placeholder = ift.FieldAdapter(self.lambda_ll.target, 'lambdall') lambda_y_placeholder = ift.FieldAdapter(self.lambda_y.target, 'lambday')
lambda_age_placeholder = ift.FieldAdapter(self.lambda_age.target, 'lambdaage') lambda_x_placeholder = ift.FieldAdapter(self.lambda_x.target, 'lambdax')
x_marginalizer_op = domain_break_op(lambda_joint_placeholder.ptw('exp')).sum( x_marginalizer_op = domain_break_op(lambda_joint_placeholder.ptw('exp')).sum(
0) # Field exponentiation and marginalization along the x direction, hence has 'length' y 0) # Field exponentiation and marginalization along the x direction, hence has 'length' y
age_unit_field = ift.full(self.lambda_age.target, 1) x_unit_field = ift.full(self.lambda_x.target, 1)
dimensionality_operator = ift.OuterProduct(self.lambda_ll.target, age_unit_field) dimensionality_operator = ift.OuterProduct(self.lambda_y.target, x_unit_field)
lambda_ll_2d = domain_break_op.adjoint @ dimensionality_operator @ lambda_ll_placeholder lambda_y_2d = domain_break_op.adjoint @ dimensionality_operator @ lambda_y_placeholder
ll_unit_field = ift.full(self.lambda_ll.target, 1) y_unit_field = ift.full(self.lambda_y.target, 1)
dimensionality_operator_2 = ift.OuterProduct(self.lambda_age.target, ll_unit_field) dimensionality_operator_2 = ift.OuterProduct(self.lambda_x.target, y_unit_field)
transposition_operator = ift.LinearEinsum(dimensionality_operator_2(lambda_age_placeholder).target, transposition_operator = ift.LinearEinsum(dimensionality_operator_2(lambda_x_placeholder).target,
ift.MultiField.from_dict({}), "xy->yx") ift.MultiField.from_dict({}), "xy->yx")
dimensionality_operator_2 = transposition_operator @ dimensionality_operator_2 dimensionality_operator_2 = transposition_operator @ dimensionality_operator_2
lambda_age_2d = domain_break_op.adjoint @ dimensionality_operator_2 @ lambda_age_placeholder lambda_x_2d = domain_break_op.adjoint @ dimensionality_operator_2 @ lambda_x_placeholder
joint_component = lambda_ll_2d + lambda_joint_placeholder joint_component = lambda_y_2d + lambda_joint_placeholder
cond_density = joint_component.ptw('exp') * domain_break_op.adjoint( cond_density = joint_component.ptw('exp') * domain_break_op.adjoint(
dimensionality_operator(x_marginalizer_op.ptw('reciprocal'))) dimensionality_operator(x_marginalizer_op.ptw('reciprocal')))
normalization = domain_break_op(cond_density).sum(1) normalization = domain_break_op(cond_density).sum(1)
log_lambda_combined = lambda_age_2d + joint_component - domain_break_op.adjoint( log_lambda_combined = lambda_x_2d + joint_component - domain_break_op.adjoint(
dimensionality_operator(x_marginalizer_op.ptw('log'))) - domain_break_op.adjoint( dimensionality_operator(x_marginalizer_op.ptw('log'))) - domain_break_op.adjoint(
dimensionality_operator_2(normalization.ptw('log'))) dimensionality_operator_2(normalization.ptw('log')))
log_lambda_combined = log_lambda_combined @ ( log_lambda_combined = log_lambda_combined @ (
self.lambda_joint.ducktape_left('lambdajoint') + self.lambda_ll.ducktape_left( self.lambda_joint.ducktape_left('lambdajoint') + self.lambda_y.ducktape_left(
'lambdall') + self.lambda_age.ducktape_left('lambdaage')) 'lambday') + self.lambda_x.ducktape_left('lambdax'))
lambda_combined = log_lambda_combined.ptw('exp') lambda_combined = log_lambda_combined.ptw('exp')
conditional_probability = cond_density * domain_break_op.adjoint(dimensionality_operator_2(normalization)).ptw( conditional_probability = cond_density * domain_break_op.adjoint(dimensionality_operator_2(normalization)).ptw(
'reciprocal') 'reciprocal')
conditional_probability = conditional_probability @ ( conditional_probability = conditional_probability @ (
self.lambda_joint.ducktape_left('lambdajoint') + self.lambda_ll.ducktape_left('lambdall')) self.lambda_joint.ducktape_left('lambdajoint') + self.lambda_y.ducktape_left('lambday'))
# Normalize the probability on the given logload interval # Normalize the probability on the given logload interval
boundaries = [min(self.dataset.coordinates()[0]), max(self.dataset.coordinates()[0]), boundaries = [min(self.dataset.coordinates()[0]), max(self.dataset.coordinates()[0]),
min(self.dataset.coordinates()[1]), max(self.dataset.coordinates()[1])] min(self.dataset.coordinates()[1]), max(self.dataset.coordinates()[1])]
inv_norm = self.dataset.npix_ll / (boundaries[3] - boundaries[2]) inv_norm = self.dataset.npix_y / (boundaries[3] - boundaries[2])
conditional_probability = conditional_probability * inv_norm conditional_probability = conditional_probability * inv_norm
return lambda_combined, conditional_probability return lambda_combined, conditional_probability
def build_joint_component(self): def build_joint_component(self):
npix_age = self.dataset.npix_age npix_x = self.dataset.npix_x
npix_ll = self.dataset.npix_ll npix_y = self.dataset.npix_y
self.position_space, sp1, sp2 = self.dataset.zero_pad() self.position_space, sp1, sp2 = self.dataset.zero_pad()
# Set up signal model # Set up signal model
...@@ -116,27 +116,27 @@ class MaternCausalModel: ...@@ -116,27 +116,27 @@ class MaternCausalModel:
offset_std = joint_offset['offset_std'] offset_std = joint_offset['offset_std']
joint_prefix = joint_offset['prefix'] joint_prefix = joint_offset['prefix']
joint_setup_ll = self.setup['joint']['log_load'] joint_setup_y = self.setup['joint']['log_load']
ll_scale = joint_setup_ll['scale'] y_scale = joint_setup_y['scale']
ll_cutoff = joint_setup_ll['cutoff'] y_cutoff = joint_setup_y['cutoff']
ll_loglogslope = joint_setup_ll['loglogslope'] y_loglogslope = joint_setup_y['loglogslope']
ll_prefix = joint_setup_ll['prefix'] y_prefix = joint_setup_y['prefix']
joint_setup_age = self.setup['joint']['age'] joint_setup_x = self.setup['joint']['x']
age_scale = joint_setup_age['scale'] x_scale = joint_setup_x['scale']
age_cutoff = joint_setup_age['cutoff'] x_cutoff = joint_setup_x['cutoff']
age_loglogslope = joint_setup_age['loglogslope'] x_loglogslope = joint_setup_x['loglogslope']
age_prefix = joint_setup_age['prefix'] x_prefix = joint_setup_x['prefix']
correlated_field_maker = ift.CorrelatedFieldMaker(joint_prefix) correlated_field_maker = ift.CorrelatedFieldMaker(joint_prefix)
correlated_field_maker.set_amplitude_total_offset(offset_mean, offset_std) correlated_field_maker.set_amplitude_total_offset(offset_mean, offset_std)
correlated_field_maker.add_fluctuations_matern(sp1, age_scale, age_cutoff, age_loglogslope, age_prefix) correlated_field_maker.add_fluctuations_matern(sp1, x_scale, x_cutoff, x_loglogslope, x_prefix)
correlated_field_maker.add_fluctuations_matern(sp2, ll_scale, ll_cutoff, ll_loglogslope, ll_prefix) correlated_field_maker.add_fluctuations_matern(sp2, y_scale, y_cutoff, y_loglogslope, y_prefix)
lambda_full = correlated_field_maker.finalize() lambda_full = correlated_field_maker.finalize()
# For the joint model unmasked regions # For the joint model unmasked regions
tgt = ift.RGSpace((npix_age, npix_ll), tgt = ift.RGSpace((npix_x, npix_y),
distances=(lambda_full.target[0].distances[0], lambda_full.target[1].distances[0])) distances=(lambda_full.target[0].distances[0], lambda_full.target[1].distances[0]))
GMO = GeomMaskOperator(lambda_full.target, tgt) GMO = GeomMaskOperator(lambda_full.target, tgt)
...@@ -147,55 +147,55 @@ class MaternCausalModel: ...@@ -147,55 +147,55 @@ class MaternCausalModel:
return lambda_joint, lambda_full return lambda_joint, lambda_full
def build_independent_components(self, lambda_ag_full, lambda_ll_full, amplitudes): def build_independent_components(self, lambda_x_full, lambda_y_full, amplitudes):
# Split the center # Split the center
# Age # X
_dist = lambda_ag_full.target[0].distances _dist = lambda_x_full.target[0].distances
tgt_age = ift.RGSpace(self.dataset.npix_age, distances=_dist) tgt_x = ift.RGSpace(self.dataset.npix_x, distances=_dist)
GMO_age = GeomMaskOperator(lambda_ag_full.target, tgt_age) GMO_x = GeomMaskOperator(lambda_x_full.target, tgt_x)
lambda_age = GMO_age(lambda_ag_full.clip(-30, 30)) lambda_x = GMO_x(lambda_x_full.clip(-30, 30))
# Viral load # Viral load
_dist = lambda_ll_full.target[0].distances _dist = lambda_y_full.target[0].distances
tgt_ll = ift.RGSpace(self.dataset.npix_ll, distances=_dist) tgt_y = ift.RGSpace(self.dataset.npix_y, distances=_dist)
GMO_ll = GeomMaskOperator(lambda_ll_full.target, tgt_ll) GMO_y = GeomMaskOperator(lambda_y_full.target, tgt_y)
lambda_ll = GMO_ll(lambda_ll_full.clip(-30, 30)) lambda_y = GMO_y(lambda_y_full.clip(-30, 30))
return lambda_age, lambda_ll, lambda_ag_full, lambda_ll_full, amplitudes return lambda_x, lambda_y, lambda_x_full, lambda_y_full, amplitudes
def initialize_independent_components(self): def initialize_independent_components(self):
_, sp1, sp2 = self.dataset.zero_pad() _, sp1, sp2 = self.dataset.zero_pad()
# Set up signal model # Set up signal model
# Age Parameters # X Parameters
age_dictionary = self.setup['indep']['age'] x_dictionary = self.setup['indep']['x']
age_offset_mean = age_dictionary['offset_dict']['offset_mean'] x_offset_mean = x_dictionary['offset_dict']['offset_mean']
age_offset_std = age_dictionary['offset_dict']['offset_std'] x_offset_std = x_dictionary['offset_dict']['offset_std']
# Log Load Parameters # Log Load Parameters
ll_dictionary = self.setup['indep']['log_load'] y_dictionary = self.setup['indep']['log_load']
ll_offset_mean = ll_dictionary['offset_dict']['offset_mean'] y_offset_mean = y_dictionary['offset_dict']['offset_mean']
ll_offset_std = ll_dictionary['offset_dict']['offset_std'] y_offset_std = y_dictionary['offset_dict']['offset_std']
indep_ll_prefix = ll_dictionary['offset_dict']['prefix'] indep_y_prefix = y_dictionary['offset_dict']['prefix']
# Create the age axis with the density estimator # Create the x axis with the density estimator
signal_response, ops = density_estimator(sp1, cf_fluctuations=age_dictionary['params'], signal_response, ops = density_estimator(sp1, cf_fluctuations=x_dictionary['params'],
cf_azm_uniform=age_offset_std, azm_offset_mean=age_offset_mean, pad=0) cf_azm_uniform=x_offset_std, azm_offset_mean=x_offset_mean, pad=0)
lambda_ag_full = ops["correlated_field"] lambda_x_full = ops["correlated_field"]
age_amplitude = ops["amplitude"] x_amplitude = ops["amplitude"]
zero_mode = ops["amplitude_total_offset"] zero_mode = ops["amplitude_total_offset"]
# response = ops["exposure"] # response = ops["exposure"]
# Create the viral load axis with the Matérn-kernel correlated field # Create the viral load axis with the Matérn-kernel correlated field
correlated_field_maker = ift.CorrelatedFieldMaker(indep_ll_prefix) correlated_field_maker = ift.CorrelatedFieldMaker(indep_y_prefix)
correlated_field_maker.set_amplitude_total_offset(ll_offset_mean, ll_offset_std) correlated_field_maker.set_amplitude_total_offset(y_offset_mean, y_offset_std)
correlated_field_maker.add_fluctuations_matern(sp2, **ll_dictionary['params']) correlated_field_maker.add_fluctuations_matern(sp2, **y_dictionary['params'])
lambda_ll_full = correlated_field_maker.finalize() lambda_y_full = correlated_field_maker.finalize()
ll_amplitude = correlated_field_maker.amplitude y_amplitude = correlated_field_maker.amplitude
amplitudes = [age_amplitude, ll_amplitude] amplitudes = [x_amplitude, y_amplitude]
return self.build_independent_components(lambda_ag_full, lambda_ll_full, amplitudes) return self.build_independent_components(lambda_x_full, lambda_y_full, amplitudes)
def plot_prior_samples(self, n_samples): def plot_prior_samples(self, n_samples):
plot = ift.Plot() plot = ift.Plot()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment