Commit f572787f authored by Matteo.Guardiani
Browse files

andrija: Refactored age -> x, ll -> y. Made more covid-agnostic.

parent 8c8c5aca
......@@ -25,9 +25,9 @@ import matplotlib.colors as colors
import nifty7 as ift
from const import npix_age, npix_ll
from covid_matern_model import MaternCausalModel
from data import Data
from data_utilities import save_kl_position, save_kl_sample
from matern_causal_model import MaternCausalModel
from utilities import get_op_post_mean
# Parser Setup
......
......@@ -41,7 +41,7 @@ from data_utilities import save_kl_sample, save_kl_position
from utilities import get_op_post_mean
from const import npix_age, npix_ll
from data import Data
from covid_matern_model import MaternCausalModel
from matern_causal_model import MaternCausalModel
# from evidence_g import get_evidence
import matplotlib.colors as colors
......
......@@ -32,34 +32,34 @@ class ClassLoader(type):
class Data(metaclass=ClassLoader):
def __init__(self, npix_age, npix_ll, ll_threshold, reshuffle_seed, age, ll):
if not isinstance(npix_age, int):
def __init__(self, npix_x, npix_y, y_threshold, reshuffle_seed, x, y):
if not isinstance(npix_x, int):
raise TypeError("Number of pixels argument needs to be of type int.")
if not isinstance(npix_ll, int):
if not isinstance(npix_y, int):
raise TypeError("Number of pixels argument needs to be of type int.")
if not isinstance(ll_threshold, float):
if not isinstance(y_threshold, float):
raise TypeError("Log load threshold value argument needs to be of type float.")
if not isinstance(reshuffle_seed, int):
raise TypeError("Reshuffle iterator argument needs to be of type int.")
if not isinstance(age, np.ndarray):
raise TypeError("The age dataset argument needs to be of type np.ndarray.")
if not isinstance(x, np.ndarray):
raise TypeError("The x dataset argument needs to be of type np.ndarray.")
if not isinstance(ll, np.ndarray):
raise TypeError("The ll dataset argument needs to be of type np.ndarray.")
if not isinstance(y, np.ndarray):
raise TypeError("The y dataset argument needs to be of type np.ndarray.")
if not age.size == ll.size:
raise TypeError("The dataset has to contain pairs of age and log load data.")
if not x.size == y.size:
raise TypeError("The dataset has to contain pairs of x and log load data.")
self.npix_age, self.npix_ll = npix_age, npix_ll
self.ll_threshold = ll_threshold
self.npix_x, self.npix_y = npix_x, npix_y
self.y_threshold = y_threshold
self.reshuffle_seed = reshuffle_seed
self.age = age, self.ll = ll
self.x = x, self.y = y
self.age, self.ll = self.filter()
self.x, self.y = self.filter()
self.data = None
self.edges = None
self.filename = 'plots/data.pdf'
......@@ -70,30 +70,30 @@ class Data(metaclass=ClassLoader):
def filter(self):
# Loads, filters and reshuffles data
self.ll, self.age = self.data_filter_x(self.ll_threshold, self.ll, self.age)
self.y, self.x = self.data_filter_x(self.y_threshold, self.y, self.x)
if not self.reshuffle_seed == 0:
self.__reshuffle_data(self.ll, self.reshuffle_seed)
self.__reshuffle_data(self.y, self.reshuffle_seed)
return self.age, self.ll
return self.x, self.y
def zero_pad(self):
ext_npix_age = 2 * self.npix_age
ext_npix_ll = 2 * self.npix_ll
ext_npix_x = 2 * self.npix_x
ext_npix_y = 2 * self.npix_y
return self.__create_spaces(ext_npix_age, ext_npix_ll)
return self.__create_spaces(ext_npix_x, ext_npix_y)
def __bin(self):
data, age_edges, ll_edges = bin_2D(self.age, self.ll, self.npix_age, self.npix_ll)
data, x_edges, y_edges = bin_2D(self.x, self.y, self.npix_x, self.npix_y)
self.data = np.array(data, dtype=np.int64)
return data, age_edges, ll_edges
return data, x_edges, y_edges
def coordinates(self):
age_coordinates = self.__obtain_coordinates(self.age, self.npix_age)
ll_coordinates = self.__obtain_coordinates(self.ll, self.npix_ll)
x_coordinates = self.__obtain_coordinates(self.x, self.npix_x)
y_coordinates = self.__obtain_coordinates(self.y, self.npix_y)
return age_coordinates, ll_coordinates
return x_coordinates, y_coordinates
def plot(self):
import matplotlib.pyplot as plt
......@@ -137,7 +137,7 @@ class Data(metaclass=ClassLoader):
class InvertedData(Data):
def __init__(self, data):
super().__init__(data.npix_age, data.npix_ll, data.ll_threshold, data.reshuffle_seed, data.csv_dataset_path)
self.age, self.ll = self.ll, self.age
self.npix_age, self.npix_ll = self.npix_ll, self.npix_age
super().__init__(data.npix_x, data.npix_y, data.y_threshold, data.reshuffle_seed, data.csv_dataset_path)
self.x, self.y = self.y, self.x
self.npix_x, self.npix_y = self.npix_y, self.npix_x
self.filename = 'plots/inverted_data.pdf'
......@@ -40,10 +40,10 @@ class MaternCausalModel:
self.plot = plot
self.alphas = alphas
self.lambda_joint = None
self.lambda_age = None
self.lambda_ll = None
self.lambda_age_full = None
self.lambda_ll_full = None
self.lambda_x = None
self.lambda_y = None
self.lambda_x_full = None
self.lambda_y_full = None
self.lambda_full = None
self.amplitudes = None
self.position_space = None
......@@ -54,7 +54,7 @@ class MaternCausalModel:
def create_model(self):
self.lambda_joint, self.lambda_full = self.build_joint_component()
self.lambda_age, self.lambda_ll, self.lambda_age_full, self.lambda_ll_full, self.amplitudes = \
self.lambda_x, self.lambda_y, self.lambda_x_full, self.lambda_y_full, self.amplitudes = \
self.initialize_independent_components()
# Dimensionality adjustment for the independent component
......@@ -63,51 +63,51 @@ class MaternCausalModel:
domain_break_op = DomainBreak2D(self.target_space)
lambda_joint_placeholder = ift.FieldAdapter(self.lambda_joint.target, 'lambdajoint')
lambda_ll_placeholder = ift.FieldAdapter(self.lambda_ll.target, 'lambdall')
lambda_age_placeholder = ift.FieldAdapter(self.lambda_age.target, 'lambdaage')
lambda_y_placeholder = ift.FieldAdapter(self.lambda_y.target, 'lambday')
lambda_x_placeholder = ift.FieldAdapter(self.lambda_x.target, 'lambdax')
x_marginalizer_op = domain_break_op(lambda_joint_placeholder.ptw('exp')).sum(
0) # Field exponentiation and marginalization along the x direction, hence has 'length' y
age_unit_field = ift.full(self.lambda_age.target, 1)
dimensionality_operator = ift.OuterProduct(self.lambda_ll.target, age_unit_field)
lambda_ll_2d = domain_break_op.adjoint @ dimensionality_operator @ lambda_ll_placeholder
ll_unit_field = ift.full(self.lambda_ll.target, 1)
dimensionality_operator_2 = ift.OuterProduct(self.lambda_age.target, ll_unit_field)
transposition_operator = ift.LinearEinsum(dimensionality_operator_2(lambda_age_placeholder).target,
x_unit_field = ift.full(self.lambda_x.target, 1)
dimensionality_operator = ift.OuterProduct(self.lambda_y.target, x_unit_field)
lambda_y_2d = domain_break_op.adjoint @ dimensionality_operator @ lambda_y_placeholder
y_unit_field = ift.full(self.lambda_y.target, 1)
dimensionality_operator_2 = ift.OuterProduct(self.lambda_x.target, y_unit_field)
transposition_operator = ift.LinearEinsum(dimensionality_operator_2(lambda_x_placeholder).target,
ift.MultiField.from_dict({}), "xy->yx")
dimensionality_operator_2 = transposition_operator @ dimensionality_operator_2
lambda_age_2d = domain_break_op.adjoint @ dimensionality_operator_2 @ lambda_age_placeholder
lambda_x_2d = domain_break_op.adjoint @ dimensionality_operator_2 @ lambda_x_placeholder
joint_component = lambda_ll_2d + lambda_joint_placeholder
joint_component = lambda_y_2d + lambda_joint_placeholder
cond_density = joint_component.ptw('exp') * domain_break_op.adjoint(
dimensionality_operator(x_marginalizer_op.ptw('reciprocal')))
normalization = domain_break_op(cond_density).sum(1)
log_lambda_combined = lambda_age_2d + joint_component - domain_break_op.adjoint(
log_lambda_combined = lambda_x_2d + joint_component - domain_break_op.adjoint(
dimensionality_operator(x_marginalizer_op.ptw('log'))) - domain_break_op.adjoint(
dimensionality_operator_2(normalization.ptw('log')))
log_lambda_combined = log_lambda_combined @ (
self.lambda_joint.ducktape_left('lambdajoint') + self.lambda_ll.ducktape_left(
'lambdall') + self.lambda_age.ducktape_left('lambdaage'))
self.lambda_joint.ducktape_left('lambdajoint') + self.lambda_y.ducktape_left(
'lambday') + self.lambda_x.ducktape_left('lambdax'))
lambda_combined = log_lambda_combined.ptw('exp')
conditional_probability = cond_density * domain_break_op.adjoint(dimensionality_operator_2(normalization)).ptw(
'reciprocal')
conditional_probability = conditional_probability @ (
self.lambda_joint.ducktape_left('lambdajoint') + self.lambda_ll.ducktape_left('lambdall'))
self.lambda_joint.ducktape_left('lambdajoint') + self.lambda_y.ducktape_left('lambday'))
# Normalize the probability on the given logload interval
boundaries = [min(self.dataset.coordinates()[0]), max(self.dataset.coordinates()[0]),
min(self.dataset.coordinates()[1]), max(self.dataset.coordinates()[1])]
inv_norm = self.dataset.npix_ll / (boundaries[3] - boundaries[2])
inv_norm = self.dataset.npix_y / (boundaries[3] - boundaries[2])
conditional_probability = conditional_probability * inv_norm
return lambda_combined, conditional_probability
def build_joint_component(self):
npix_age = self.dataset.npix_age
npix_ll = self.dataset.npix_ll
npix_x = self.dataset.npix_x
npix_y = self.dataset.npix_y
self.position_space, sp1, sp2 = self.dataset.zero_pad()
# Set up signal model
......@@ -116,27 +116,27 @@ class MaternCausalModel:
offset_std = joint_offset['offset_std']
joint_prefix = joint_offset['prefix']
joint_setup_ll = self.setup['joint']['log_load']
ll_scale = joint_setup_ll['scale']
ll_cutoff = joint_setup_ll['cutoff']
ll_loglogslope = joint_setup_ll['loglogslope']
ll_prefix = joint_setup_ll['prefix']
joint_setup_y = self.setup['joint']['log_load']
y_scale = joint_setup_y['scale']
y_cutoff = joint_setup_y['cutoff']
y_loglogslope = joint_setup_y['loglogslope']
y_prefix = joint_setup_y['prefix']
joint_setup_age = self.setup['joint']['age']
age_scale = joint_setup_age['scale']
age_cutoff = joint_setup_age['cutoff']
age_loglogslope = joint_setup_age['loglogslope']
age_prefix = joint_setup_age['prefix']
joint_setup_x = self.setup['joint']['x']
x_scale = joint_setup_x['scale']
x_cutoff = joint_setup_x['cutoff']
x_loglogslope = joint_setup_x['loglogslope']
x_prefix = joint_setup_x['prefix']
correlated_field_maker = ift.CorrelatedFieldMaker(joint_prefix)
correlated_field_maker.set_amplitude_total_offset(offset_mean, offset_std)
correlated_field_maker.add_fluctuations_matern(sp1, age_scale, age_cutoff, age_loglogslope, age_prefix)
correlated_field_maker.add_fluctuations_matern(sp2, ll_scale, ll_cutoff, ll_loglogslope, ll_prefix)
correlated_field_maker.add_fluctuations_matern(sp1, x_scale, x_cutoff, x_loglogslope, x_prefix)
correlated_field_maker.add_fluctuations_matern(sp2, y_scale, y_cutoff, y_loglogslope, y_prefix)
lambda_full = correlated_field_maker.finalize()
# For the joint model unmasked regions
tgt = ift.RGSpace((npix_age, npix_ll),
tgt = ift.RGSpace((npix_x, npix_y),
distances=(lambda_full.target[0].distances[0], lambda_full.target[1].distances[0]))
GMO = GeomMaskOperator(lambda_full.target, tgt)
......@@ -147,55 +147,55 @@ class MaternCausalModel:
return lambda_joint, lambda_full
def build_independent_components(self, lambda_ag_full, lambda_ll_full, amplitudes):
def build_independent_components(self, lambda_x_full, lambda_y_full, amplitudes):
# Split the center
# Age
_dist = lambda_ag_full.target[0].distances
tgt_age = ift.RGSpace(self.dataset.npix_age, distances=_dist)
GMO_age = GeomMaskOperator(lambda_ag_full.target, tgt_age)
lambda_age = GMO_age(lambda_ag_full.clip(-30, 30))
# X
_dist = lambda_x_full.target[0].distances
tgt_x = ift.RGSpace(self.dataset.npix_x, distances=_dist)
GMO_x = GeomMaskOperator(lambda_x_full.target, tgt_x)
lambda_x = GMO_x(lambda_x_full.clip(-30, 30))
# Viral load
_dist = lambda_ll_full.target[0].distances
tgt_ll = ift.RGSpace(self.dataset.npix_ll, distances=_dist)
GMO_ll = GeomMaskOperator(lambda_ll_full.target, tgt_ll)
lambda_ll = GMO_ll(lambda_ll_full.clip(-30, 30))
_dist = lambda_y_full.target[0].distances
tgt_y = ift.RGSpace(self.dataset.npix_y, distances=_dist)
GMO_y = GeomMaskOperator(lambda_y_full.target, tgt_y)
lambda_y = GMO_y(lambda_y_full.clip(-30, 30))
return lambda_age, lambda_ll, lambda_ag_full, lambda_ll_full, amplitudes
return lambda_x, lambda_y, lambda_x_full, lambda_y_full, amplitudes
def initialize_independent_components(self):
_, sp1, sp2 = self.dataset.zero_pad()
# Set up signal model
# Age Parameters
age_dictionary = self.setup['indep']['age']
age_offset_mean = age_dictionary['offset_dict']['offset_mean']
age_offset_std = age_dictionary['offset_dict']['offset_std']
# X Parameters
x_dictionary = self.setup['indep']['x']
x_offset_mean = x_dictionary['offset_dict']['offset_mean']
x_offset_std = x_dictionary['offset_dict']['offset_std']
# Log Load Parameters
ll_dictionary = self.setup['indep']['log_load']
ll_offset_mean = ll_dictionary['offset_dict']['offset_mean']
ll_offset_std = ll_dictionary['offset_dict']['offset_std']
indep_ll_prefix = ll_dictionary['offset_dict']['prefix']
# Create the age axis with the density estimator
signal_response, ops = density_estimator(sp1, cf_fluctuations=age_dictionary['params'],
cf_azm_uniform=age_offset_std, azm_offset_mean=age_offset_mean, pad=0)
lambda_ag_full = ops["correlated_field"]
age_amplitude = ops["amplitude"]
y_dictionary = self.setup['indep']['log_load']
y_offset_mean = y_dictionary['offset_dict']['offset_mean']
y_offset_std = y_dictionary['offset_dict']['offset_std']
indep_y_prefix = y_dictionary['offset_dict']['prefix']
# Create the x axis with the density estimator
signal_response, ops = density_estimator(sp1, cf_fluctuations=x_dictionary['params'],
cf_azm_uniform=x_offset_std, azm_offset_mean=x_offset_mean, pad=0)
lambda_x_full = ops["correlated_field"]
x_amplitude = ops["amplitude"]
zero_mode = ops["amplitude_total_offset"]
# response = ops["exposure"]
# Create the viral load axis with the Matérn-kernel correlated field
correlated_field_maker = ift.CorrelatedFieldMaker(indep_ll_prefix)
correlated_field_maker.set_amplitude_total_offset(ll_offset_mean, ll_offset_std)
correlated_field_maker.add_fluctuations_matern(sp2, **ll_dictionary['params'])
lambda_ll_full = correlated_field_maker.finalize()
ll_amplitude = correlated_field_maker.amplitude
correlated_field_maker = ift.CorrelatedFieldMaker(indep_y_prefix)
correlated_field_maker.set_amplitude_total_offset(y_offset_mean, y_offset_std)
correlated_field_maker.add_fluctuations_matern(sp2, **y_dictionary['params'])
lambda_y_full = correlated_field_maker.finalize()
y_amplitude = correlated_field_maker.amplitude
amplitudes = [age_amplitude, ll_amplitude]
amplitudes = [x_amplitude, y_amplitude]
return self.build_independent_components(lambda_ag_full, lambda_ll_full, amplitudes)
return self.build_independent_components(lambda_x_full, lambda_y_full, amplitudes)
def plot_prior_samples(self, n_samples):
plot = ift.Plot()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment