Commit 19e8a379 authored by Matteo.Guardiani's avatar Matteo.Guardiani
Browse files

cleanup: created new data class to process the data and keep data prep and...

cleanup: created new data class to process the data and keep data prep and model separate. Cosmetics.
parent 8db30b64
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
# Copyright(C) 2013-2022 Max-Planck-Society
# Author: Matteo Guardiani
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -24,8 +24,10 @@ import json
import shutil
import pickle
import sys
try:
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
n_task = comm.Get_size()
rank = comm.Get_rank()
......@@ -40,12 +42,12 @@ from covid_matern_model import covid_matern_model_maker
from data_utilities import read_in, export_synthetic, save_random_state, save_KL_sample, save_KL_position
from utilities import get_op_post_mean
from const import npix_age, npix_ll
from evidence_g import get_evidence
# from evidence_g import get_evidence
import matplotlib.colors as colors
# Parser Setup
parser = argparse.ArgumentParser()
parser.add_argument('--json_file', type=str, required=True) # FIX ME: Add help --help
parser.add_argument('--json_file', type=str, required=True) # FIX ME: Add help --help
parser.add_argument('--csv_file', type=str, required=True)
parser.add_argument('--reshuffle_parameter', type=int, required=True)
args = parser.parse_args()
......@@ -62,24 +64,28 @@ if __name__ == '__main__':
inv_par = False
if inv in json_file: inv_par = True
file_setup = open(json_file, "r")
file_setup = open(json_file, "r")
setup = json.load(file_setup)
file_setup.close()
# Preparing the filename string for store live results
# Preparing the filename string and plots folder to store live results
if not os.path.exists('./plots'):
os.mkdir('./plots')
filename = "plots/covid_combined_matern_{}.png"
# Results Output Folders
path_j = os.path.basename(json_file)
path_c = os.path.basename(csv_file)
results_path = os.path.join('./Automized_Results_Matern', os.path.splitext(path_j)[0], os.path.splitext(path_c)[0], str(resh_it))
results_path = os.path.join('./Automized_Results_Matern', os.path.splitext(path_j)[0], os.path.splitext(path_c)[0],
str(resh_it))
results_path = os.path.normpath(results_path)
os.makedirs(results_path, exist_ok=True)
# Load the model
model = covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, resh_it, False, inv_par)
model = covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, resh_it, False, inv_par)
data, boundaries, lamb_full, lamb_ag_full, lamb_ll_full, amps, lamb_joint, lamb_ag, lamb_ll, lamb_ag_2d, lamb_comb, cond_prob, zm = model.values()
# Setup the response & define the amplitudes
......@@ -91,11 +97,11 @@ if __name__ == '__main__':
# Specify data space
data_space = R_lamb.target
# Generate mock signal and data
seed = setup['seed']
ift.random.push_sseq_from_seed(seed)
if setup['mock']:
# data
......@@ -106,11 +112,11 @@ if __name__ == '__main__':
if not setup['same data'] and indep_tag in json_file:
print("\nUsing syinthetic data generated from joint model on independent model")
joint_json_file = json_file.replace(indep_tag, '')
file_setup = open(joint_json_file, "r")
joint_json_file = json_file.replace(indep_tag, '')
file_setup = open(joint_json_file, "r")
joint_setup = json.load(file_setup)
file_setup.close()
joint_model = covid_matern_model_maker(npix_age, npix_ll, joint_setup, csv_file, resh_it, False, inv_par)
joint_model = covid_matern_model_maker(npix_age, npix_ll, joint_setup, csv_file, resh_it, False, inv_par)
joint_lamb_comb = joint_model['combined lambda']
mock_position = ift.from_random(joint_lamb_comb.domain, 'normal')
data = R_lamb(mock_position)
......@@ -119,10 +125,10 @@ if __name__ == '__main__':
if not setup['same data'] and not indep_tag in json_file:
print("\nUsing syinthetic data generated from independent model on joint model")
indep_json_file = os.path.splitext(json_file)[0] + '_indep' + os.path.splitext(json_file)[1]
file_setup = open(indep_json_file, "r")
file_setup = open(indep_json_file, "r")
indep_setup = json.load(file_setup)
file_setup.close()
indep_model = covid_matern_model_maker(npix_age, npix_ll, indep_setup, csv_file, resh_it, False, inv_par)
indep_model = covid_matern_model_maker(npix_age, npix_ll, indep_setup, csv_file, resh_it, False, inv_par)
indep_lamb_comb = indep_model['combined lambda']
mock_position = ift.from_random(indep_lamb_comb.domain, 'normal')
data = R_lamb(mock_position)
......@@ -130,7 +136,7 @@ if __name__ == '__main__':
data = ift.makeField(data_space, data)
if setup['mock']:
if setup['mock']:
plot = ift.Plot()
plot.add(lamb_comb(mock_position), title='Full Field')
plot.add(R.adjoint(data), title='Data')
......@@ -151,30 +157,32 @@ if __name__ == '__main__':
H = ift.StandardHamiltonian(likelihood, ic_sampling)
# Begin minimization
initial_mean = ift.from_random(H.domain, 'normal')*0.1
initial_mean = ift.from_random(H.domain, 'normal') * 0.1
mean = initial_mean
N_steps = 35 # 34
N_steps = 35 # 34
for i in range(N_steps):
if i < 27:
ic_newton = ift.AbsDeltaEnergyController(deltaE=1e-5, iteration_limit=10, name='newton', convergence_level=3)
if i < 27:
ic_newton = ift.AbsDeltaEnergyController(deltaE=1e-5, iteration_limit=10, name='newton',
convergence_level=3)
ic_newton.enable_logging()
else:
ic_newton = ift.AbsDeltaEnergyController(deltaE=1e-5, iteration_limit=20, name='newton', convergence_level=3)
ic_newton = ift.AbsDeltaEnergyController(deltaE=1e-5, iteration_limit=20, name='newton',
convergence_level=3)
ic_newton.enable_logging()
minimizer = ift.NewtonCG(ic_newton, enable_logging=True)
if i<30:
if i < 30:
N_samples = 5
elif i<33:
elif i < 33:
N_samples = 20
else:
N_samples = 500
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL.make(mean, H, N_samples, comm=comm, mirror_samples=True, nanisinf=True)
KL = ift.MetricGaussianKL(mean, H, N_samples, comm=comm, mirror_samples=True, nanisinf=True)
KL, convergence = minimizer(KL)
samples = tuple(KL.samples)
mean = KL.position
......@@ -185,7 +193,7 @@ if __name__ == '__main__':
save_KL_position(mean, pos_path)
print("KL position saved", file=sys.stderr)
sam_path = os.path.join(results_path, "samples")
sam_path = os.path.join(results_path, "samples")
os.makedirs(sam_path, exist_ok=True)
for sample in samples:
......@@ -194,7 +202,8 @@ if __name__ == '__main__':
print("KL samples saved", file=sys.stderr)
# Minisanity check
ift.extra.minisanity(data, lambda x:ift.makeOp(R_lamb(x).ptw('reciprocal')), R_lamb, mean, samples) # Fix Me: Check noise implementation in minisanity
ift.extra.minisanity(data, lambda x: ift.makeOp(R_lamb(x).ptw('reciprocal')), R_lamb, mean,
samples) # Fix Me: Check noise implementation in minisanity
# Plot current reconstruction
plot = ift.Plot()
......@@ -204,20 +213,24 @@ if __name__ == '__main__':
plot.add(R.adjoint(data), title='Data')
plot.add([lamb_comb(mean)], title="reconstruction")
plot.add([lamb_joint.force(mean)], title="Joint component")
plot.add([A1.force(mean), A1.force(mock_position)], title="power1")
plot.add([A2.force(mean), A2.force(mock_position)], title="power2")
plot.add([ic_newton.history, ic_sampling.history, minimizer.inversion_history], label=['KL', 'Sampling', 'Newton inversion'], title='Cumulative energies', s=[None, None, 1], alpha=[None, 0.2, None])
plot.add([ic_newton.history, ic_sampling.history, minimizer.inversion_history],
label=['KL', 'Sampling', 'Newton inversion'], title='Cumulative energies', s=[None, None, 1],
alpha=[None, 0.2, None])
else:
plot.add([lamb_comb(mean)], title="Reconstruction", norm=colors.SymLogNorm(linthresh=10e-1), extent = boundaries, aspect = "auto")
plot.add([lamb_full.force(mean)], title="Joint Component Reconstruction", norm=colors.SymLogNorm(linthresh=10e-1), extent = boundaries, aspect = "auto")
plot.add([cond_prob.force(mean)], title="Conditional Probability Reconstruction", norm=colors.SymLogNorm(linthresh=6*10e-4), extent = boundaries, aspect = "auto")
plot.add([lamb_comb(mean)], title="Reconstruction", norm=colors.SymLogNorm(linthresh=10e-1),
extent=boundaries, aspect="auto")
plot.add([lamb_full.force(mean)], title="Joint Component Reconstruction",
norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries, aspect="auto")
plot.add([cond_prob.force(mean)], title="Conditional Probability Reconstruction",
norm=colors.SymLogNorm(linthresh=6 * 10e-4), extent=boundaries, aspect="auto")
# plot.add([Aj.force(mean)], title="power1 joint") # FIX ME: MAYBE ACCOUNT FOR THE MARGINALIZATION ??
plot.add([A1.force(mean)], title="power1 independent")
plot.add([A2.force(mean)], title="power2 independent")
plot.add(lamb_ag_full.force(mean), title="Age Reconstruction (full)", aspect = "auto")
plot.add(lamb_ll_full.force(mean), title="Log load Reconstruction (full)", aspect = "auto")
plot.add([A2.force(mean)], title="power2 independent")
plot.add(lamb_ag_full.force(mean), title="Age Reconstruction (full)", aspect="auto")
plot.add(lamb_ll_full.force(mean), title="Log load Reconstruction (full)", aspect="auto")
plot.output(nx=3, ny=3, ysize=10, xsize=15, name=filename.format("loop_{:02d}".format(i)))
print('Lamb combined check:', lamb_comb(KL.position).val.sum(), '\n', file=sys.stderr)
......@@ -232,28 +245,31 @@ if __name__ == '__main__':
powers1 = []
powers2 = []
for sample in samples:
p1 = A1.force(sample + mean)
p2 = A2.force(sample + mean)
powers1.append(p1)
powers2.append(p2)
# Final Plots
filename_res = "Results.png"
filename_res = os.path.join(results_path, filename_res)
plot = ift.Plot()
plot.add(lamb_comb_mean, title="Posterior Mean", norm=colors.SymLogNorm(linthresh=10e-1),extent = boundaries, aspect = "auto")
plot.add(lamb_comb_var.sqrt(), title="Posterior Standard Deviation", norm=colors.SymLogNorm(linthresh=10e-1), extent = boundaries, aspect = "auto")
plot.add([cond_prob.force(mean)], title="Conditional Probability Reconstruction", norm=colors.SymLogNorm(linthresh=6*10e-4), extent = boundaries, aspect = "auto")
plot.add(lamb_comb_mean, title="Posterior Mean", norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries,
aspect="auto")
plot.add(lamb_comb_var.sqrt(), title="Posterior Standard Deviation", norm=colors.SymLogNorm(linthresh=10e-1),
extent=boundaries, aspect="auto")
plot.add([cond_prob.force(mean)], title="Conditional Probability Reconstruction",
norm=colors.SymLogNorm(linthresh=6 * 10e-4), extent=boundaries, aspect="auto")
plot.add([A1.force(mean)], title="Age Independent Power Spectrum (log[S(k^2)])")
plot.add([A2.force(mean)], title="Log load Independent Power Spectrum (log[S(k^2)])")
plot.add(lamb_ag_full.force(mean), title="Age Reconstruction (full)", norm=colors.SymLogNorm(linthresh=10e-1), aspect = "auto")
plot.add(lamb_ll_full.force(mean), title="Log load Reconstruction (full)", norm=colors.SymLogNorm(linthresh=10e-1), aspect = "auto")
plot.add([lamb_full.exp().force(mean)], title="Joint Component Reconstruction (full)", norm=colors.SymLogNorm(linthresh=10e-1), extent = boundaries, aspect = "auto")
plot.add(lamb_ag_full.force(mean), title="Age Reconstruction (full)", norm=colors.SymLogNorm(linthresh=10e-1),
aspect="auto")
plot.add(lamb_ll_full.force(mean), title="Log load Reconstruction (full)",
norm=colors.SymLogNorm(linthresh=10e-1), aspect="auto")
plot.add([lamb_full.exp().force(mean)], title="Joint Component Reconstruction (full)",
norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries, aspect="auto")
plot.output(ny=3, nx=3, xsize=20, ysize=15, name=filename_res)
print("Saved results as", filename_res, file=sys.stderr)
......@@ -262,12 +278,20 @@ if __name__ == '__main__':
filename_ers = "Errors.png"
filename_ers = os.path.join(results_path, filename_ers)
plot = ift.Plot()
plot.add(lamb_comb_mean, title="Posterior Mean", norm=colors.SymLogNorm(linthresh=10e-1), extent = boundaries, aspect = "auto")
plot.add(lamb_comb_var.sqrt()*lamb_comb_mean.ptw('reciprocal'), title="Relative Uncertainty", norm=colors.SymLogNorm(linthresh=10e-1), extent = boundaries, aspect = "auto")
plot.add(cond_prob_mean, title="Conditional Probability Reconstruction Mean", norm=colors.SymLogNorm(linthresh=6*10e-4), extent = boundaries, aspect = "auto")
plot.add(cond_prob_var.sqrt()*cond_prob_mean.ptw('reciprocal'), title="Relative Uncertainty on Conditional Probability Reconstruction", norm=colors.SymLogNorm(linthresh=10e-1), extent = boundaries, aspect = "auto")
plot.add(lamb_full_mean, title="Joint Component Reconstruction Mean", norm=colors.SymLogNorm(linthresh=10e-2), aspect = "auto")
plot.add(lamb_full_var.sqrt()*lamb_full_mean.ptw('reciprocal'), title="Relative Uncertainty on Joint Component Reconstruction", norm=colors.SymLogNorm(linthresh=10e-2), aspect = "auto")
plot.add(lamb_comb_mean, title="Posterior Mean", norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries,
aspect="auto")
plot.add(lamb_comb_var.sqrt() * lamb_comb_mean.ptw('reciprocal'), title="Relative Uncertainty",
norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries, aspect="auto")
plot.add(cond_prob_mean, title="Conditional Probability Reconstruction Mean",
norm=colors.SymLogNorm(linthresh=6 * 10e-4), extent=boundaries, aspect="auto")
plot.add(cond_prob_var.sqrt() * cond_prob_mean.ptw('reciprocal'),
title="Relative Uncertainty on Conditional Probability Reconstruction",
norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries, aspect="auto")
plot.add(lamb_full_mean, title="Joint Component Reconstruction Mean", norm=colors.SymLogNorm(linthresh=10e-2),
aspect="auto")
plot.add(lamb_full_var.sqrt() * lamb_full_mean.ptw('reciprocal'),
title="Relative Uncertainty on Joint Component Reconstruction",
norm=colors.SymLogNorm(linthresh=10e-2), aspect="auto")
plot.output(ny=3, nx=2, xsize=15, ysize=15, name=filename_ers)
print("Saved results as", filename_ers, file=sys.stderr)
......@@ -290,11 +314,4 @@ if __name__ == '__main__':
# with open('./Automized_Results_Matern/Automized_Evidences.txt', 'a') as file:
# evidences = os.path.basename(json_file) + ', ' + os.path.basename(csv_file) + '_' + str(resh_it) + ', Evidence Mean: ' + str(evidence['estimate']) + '\n\n'
# file.write(evidences)
# file.write(evidences)
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
# Copyright(C) 2013-2022 Max-Planck-Society
# Author: Matteo Guardiani
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -26,6 +26,7 @@ from tools import y_averager, x_averager, density_estimator
from data_utilities import fitted_infectivity
from utilities import GeomMaskOperator
def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_iterator, plot, inv_par, alphas=None):
''' Implements a causal model for covid data analysis in Nifty.
Required arguments:
......@@ -43,48 +44,51 @@ def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_itera
raise TypeError("Pixel argument needs to be int")
if not isinstance(npix_ll, int):
raise TypeError("Pixel argument needs to be int")
raise TypeError("Pixel argument needs to be int")
if not isinstance(plot, bool):
raise TypeError("Plot argument needs to be bool")
# DATA BINNING
# Make a bigger domain to avoid boundary effects
ext_npix_age = 2*npix_age
ext_npix_ll = 2*npix_ll
if inv_par:
position_space = ift.RGSpace([ext_npix_ll, ext_npix_age])
ext_npix_age = 2 * npix_age
ext_npix_ll = 2 * npix_ll
if inv_par:
position_space = ift.RGSpace((ext_npix_ll, ext_npix_age))
sp1 = ift.RGSpace(ext_npix_ll)
sp2 = ift.RGSpace(ext_npix_age)
else:
position_space = ift.RGSpace([ext_npix_age, ext_npix_ll])
else:
position_space = ift.RGSpace((ext_npix_age, ext_npix_ll))
sp1 = ift.RGSpace(ext_npix_age)
sp2 = ift.RGSpace(ext_npix_ll)
# Load and bin data
threshold = setup["threshold"]
age, ll = read_in(csv_file)
ll, age = data_filter_x(threshold, ll, age) # For Cobas data filter was set at 3.3 for log load FIXXXXX!!! NOW 3.85 (other param 5.4)
ll, age = data_filter_x(threshold, ll,
age) # For Cobas data filter was set at 3.3 for log load FIXXXXX!!! NOW 3.85 (other param 5.4)
if not reshuffle_iterator == 0:
from sklearn.utils import shuffle
ll = shuffle(ll, random_state = reshuffle_iterator)
if inv_par: data, ll_edges, age_edges = Bin2D(ll, age, npix_ll, npix_age)
else: data, age_edges, ll_edges = Bin2D(age, ll, npix_age, npix_ll)
ll = shuffle(ll, random_state=reshuffle_iterator)
if inv_par:
data, ll_edges, age_edges = Bin2D(ll, age, npix_ll, npix_age)
else:
data, age_edges, ll_edges = Bin2D(age, ll, npix_age, npix_ll)
data = np.array(data, dtype=np.int64)
# Create proper coordinates
binned_age, edges_a = Bin1D(age, npix_age)
age_coord = 0.5*(edges_a[:-1] + edges_a[1:])
age_coord = 0.5 * (edges_a[:-1] + edges_a[1:])
binned_ll, edges_ll = Bin1D(ll, npix_ll)
ll_coord = 0.5*(edges_ll[:-1] + edges_ll[1:])
ll_coord = 0.5 * (edges_ll[:-1] + edges_ll[1:])
if inv_par:
binned_age, edges_a = Bin1D(ll, npix_ll)
age_coord = 0.5*(edges_a[:-1] + edges_a[1:])
age_coord = 0.5 * (edges_a[:-1] + edges_a[1:])
binned_ll, edges_ll = Bin1D(age, npix_age)
ll_coord = 0.5*(edges_ll[:-1] + edges_ll[1:])
ll_coord = 0.5 * (edges_ll[:-1] + edges_ll[1:])
# --------------------
# --------------------
# JOINT MODEL PART
......@@ -108,9 +112,9 @@ def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_itera
# .make(offset_mean, offset_std,
cfmaker = ift.CorrelatedFieldMaker(joint_prefix)
cfmaker.set_amplitude_total_offset(offset_mean, offset_std,)
cfmaker.set_amplitude_total_offset(offset_mean, offset_std, )
if inv_par:
cfmaker.add_fluctuations_matern(sp1, ll_scale, ll_cutoff, ll_loglogslope, ll_prefix)
cfmaker.add_fluctuations_matern(sp1, ll_scale, ll_cutoff, ll_loglogslope, ll_prefix)
cfmaker.add_fluctuations_matern(sp2, age_scale, age_cutoff, age_loglogslope, age_prefix)
else:
cfmaker.add_fluctuations_matern(sp1, age_scale, age_cutoff, age_loglogslope, age_prefix)
......@@ -123,14 +127,16 @@ def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_itera
flag_arr = np.ones(lamb_full.target[0].shape)
# For the joint model unmasked regions
tgt = ift.RGSpace((npix_age,npix_ll), distances=(lamb_full.target[0].distances[0], lamb_full.target[1].distances[0]))
tgt = ift.RGSpace((npix_age, npix_ll),
distances=(lamb_full.target[0].distances[0], lamb_full.target[1].distances[0]))
if inv_par:
tgt = ift.RGSpace((npix_ll, npix_age), distances=(lamb_full.target[0].distances[0], lamb_full.target[1].distances[0]))
tgt = ift.RGSpace((npix_ll, npix_age),
distances=(lamb_full.target[0].distances[0], lamb_full.target[1].distances[0]))
GMO = GeomMaskOperator(lamb_full.target, tgt)
lamb_joint = GMO(lamb_full.clip(-30,30))
lamb_joint = GMO(lamb_full.clip(-30, 30))
# -------
# -------
# INDEPENDENT COMPONENT SETUP
# Set up signal model
......@@ -154,17 +160,18 @@ def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_itera
ll_offset_std = ll_dict['offset_dict']['offset_std']
indep_ll_prefix = ll_dict['offset_dict']['prefix']
if inv_par:
signal_response, ops = density_estimator(sp1, cf_fluctuations=ll_dict['params'], cf_azm_uniform=ll_offset_std, azm_offset_mean=ll_offset_mean, pad=0)
signal_response, ops = density_estimator(sp1, cf_fluctuations=ll_dict['params'], cf_azm_uniform=ll_offset_std,
azm_offset_mean=ll_offset_mean, pad=0)
cfmaker = ift.CorrelatedFieldMaker(indep_age_prefix)
cfmaker.set_amplitude_total_offset(age_offset_mean, age_offset_std)
cfmaker.add_fluctuations_matern(sp2, **age_dict['params'])
lamb_ll_full = cfmaker.finalize()
A_ll = cfmaker.amplitude
else:
signal_response, ops = density_estimator(sp1, cf_fluctuations=age_dict['params'], cf_azm_uniform=age_offset_std, azm_offset_mean=age_offset_mean, pad=0)
else:
signal_response, ops = density_estimator(sp1, cf_fluctuations=age_dict['params'], cf_azm_uniform=age_offset_std,
azm_offset_mean=age_offset_mean, pad=0)
cfmaker = ift.CorrelatedFieldMaker(indep_ll_prefix)
cfmaker.set_amplitude_total_offset(ll_offset_mean, ll_offset_std)
cfmaker.add_fluctuations_matern(sp2, **ll_dict['params'])
......@@ -176,7 +183,6 @@ def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_itera
zm = ops["amplitude_total_offset"]
# response = ops["exposure"]
# Split the center
# Age
_dist = lamb_ag_full.target[0].distances
......@@ -185,7 +191,7 @@ def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_itera
tgt_age = ift.RGSpace(npix_ll, distances=_dist)
GMO_age = GeomMaskOperator(lamb_ag_full.target, tgt_age)
lamb_ag = GMO_age(lamb_ag_full.clip(-30,30))
lamb_ag = GMO_age(lamb_ag_full.clip(-30, 30))
# Viral load
_dist = lamb_ll_full.target[0].distances
......@@ -194,7 +200,7 @@ def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_itera
tgt_ll = ift.RGSpace(npix_age, distances=_dist)
GMO_ll = GeomMaskOperator(lamb_ll_full.target, tgt_ll)
lamb_ll = GMO_ll(lamb_ll_full.clip(-30,30))
lamb_ll = GMO_ll(lamb_ll_full.clip(-30, 30))
amps = []
amps.append(A_age)
......@@ -206,7 +212,8 @@ def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_itera
lamb_joint_placeholder = ift.FieldAdapter(lamb_joint.target, 'lambjoint')
lamb_ll_placeholder = ift.FieldAdapter(lamb_ll.target, 'lambll')
lamb_ag_placeholder = ift.FieldAdapter(lamb_ag.target, 'lambag')
margx = dombr(lamb_joint_placeholder.ptw('exp')).sum(0) # Field exponentiation and marginalization along the x direction, hence has 'length' y
margx = dombr(lamb_joint_placeholder.ptw('exp')).sum(
0) # Field exponentiation and marginalization along the x direction, hence has 'length' y
# print(lamb_full.target)
# print(dombr_full.target)
......@@ -217,44 +224,49 @@ def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_itera
dimop = ift.OuterProduct(lamb_ll.target, unitfield_ag)
lamb_ll_2d = dombr.adjoint @ dimop @ lamb_ll_placeholder
unitfield_ll = ift.full(lamb_ll.target, 1)
unitfield_ll = ift.full(lamb_ll.target, 1)
dimop2 = ift.OuterProduct(lamb_ag.target, unitfield_ll)
transp = ift.LinearEinsum(dimop2(lamb_ag_placeholder).target, ift.MultiField.from_dict({}), "xy->yx")
dimop2 = transp @ dimop2
lamb_ag_2d = dombr.adjoint @ dimop2 @ lamb_ag_placeholder
lamb_ag_2d = dombr.adjoint @ dimop2 @ lamb_ag_placeholder
#--------------
# --------------
# MODEL BUILDING (Combined Model)
joint_comp = lamb_ll_2d + lamb_joint_placeholder
cond_density = joint_comp.ptw('exp') * dombr.adjoint(dimop(margx.ptw('reciprocal')))
normalization = dombr(cond_density).sum(1)
llamb_comb = lamb_ag_2d + joint_comp - dombr.adjoint(dimop(margx.ptw('log'))) - dombr.adjoint(dimop2(normalization.ptw('log')))
llamb_comb = lamb_ag_2d + joint_comp - dombr.adjoint(dimop(margx.ptw('log'))) - dombr.adjoint(
dimop2(normalization.ptw('log')))
llamb_comb = llamb_comb @ (lamb_joint.ducktape_left('lambjoint') + lamb_ll.ducktape_left('lambll') + lamb_ag.ducktape_left('lambag'))
llamb_comb = llamb_comb @ (
lamb_joint.ducktape_left('lambjoint') + lamb_ll.ducktape_left('lambll') + lamb_ag.ducktape_left(
'lambag'))
lamb_comb = llamb_comb.ptw('exp')
cond_prob = cond_density * dombr.adjoint(dimop2(normalization)).ptw('reciprocal')
cond_prob = cond_prob @ (lamb_joint.ducktape_left('lambjoint') + lamb_ll.ducktape_left('lambll'))
# Normalize the probability on the given logload interval
boundaries = [min(age_coord), max(age_coord), min(ll_coord), max(ll_coord)]
inv_norm = npix_ll/(boundaries[3]-boundaries[2])
inv_norm = npix_ll / (boundaries[3] - boundaries[2])
cond_prob = cond_prob * inv_norm
results = {'data' : data, 'boundaries' : boundaries, 'lambda full' : dombr_full.adjoint(lamb_full), 'lambda age full' : lamb_ag_full, 'lambda lload full' : lamb_ll_full, 'amplitudes' : amps,\
'lambda joint' : lamb_joint, 'lambda age' : lamb_ag, 'lambda lload' : lamb_ll, '2d lambda age': lamb_ag_2d, 'combined lambda' : lamb_comb, 'p(y|x)' : cond_prob, 'zero mode': zm}
results = {'data': data, 'boundaries': boundaries, 'lambda full': dombr_full.adjoint(lamb_full),
'lambda age full': lamb_ag_full, 'lambda lload full': lamb_ll_full, 'amplitudes': amps, \
'lambda joint': lamb_joint, 'lambda age': lamb_ag, 'lambda lload': lamb_ll, '2d lambda age': lamb_ag_2d,
'combined lambda': lamb_comb, 'p(y|x)': cond_prob, 'zero mode': zm}
if plot == True:
# Additional fields for plotting (indep)
lamb_ll_2d = lamb_ll_2d @ lamb_ll.ducktape_left('lambll')
lamb_ag_2d = lamb_ag_2d @ lamb_ag.ducktape_left('lambag')
lamb_indep = lamb_ag_2d + lamb_ll_2d
lamb_indep = lamb_ag_2d + lamb_ll_2d
lamb_indep = lamb_indep.exp()
# True Joint
margy = dombr(lamb_joint.exp()).sum(1)
margy = dombr(lamb_joint.exp()).sum(1)
margx = dombr(lamb_joint.exp()).sum(0)
lamb_true = dombr.adjoint @ (dimop(margx) + dimop2(margy))
......@@ -265,10 +277,14 @@ def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_itera
averages = []
for alpha in alphas:
if inv_par: break
if inv_par:
break
# if inv_par: globals()['y_average_%s' % alpha] = x_averager(cond_prob.target, ift.Field.from_raw(lamb_ag.target, age_coord), 1/inv_norm, alpha)
else: globals()['y_average_%s' % alpha] = y_averager(cond_prob.target, ift.Field.from_raw(lamb_ll.target, ll_coord), 1/inv_norm, alpha)
globals()['load_average_%s' % alpha] = globals()['y_average_%s' % alpha](cond_prob)
else:
globals()['y_average_%s' % alpha] = y_averager(cond_prob.target,
ift.Field.from_raw(lamb_ll.target, ll_coord),
1 / inv_norm, alpha)
globals()['load_average_%s' % alpha] = globals()['y_average_%s' % alpha](cond_prob)
averages.append(globals()['load_average_%s' % alpha])
# Infectivity
......@@ -277,15 +293,17 @@ def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_itera
for it, inf in enumerate(infec):
if inv_par: break
globals()['infectivity_%s' % inf] = ift.Field.from_raw(lamb_ll.target, fitted_infectivity(ll_coord+(it-1)))
globals()['infectivity_averager_%s' % inf] = y_averager(cond_prob.target, globals()['infectivity_%s' % inf], 1/inv_norm, 0)
globals()['infectivity_%s' % inf] = ift.Field.from_raw(lamb_ll.target,
fitted_infectivity(ll_coord + (it - 1)))
globals()['infectivity_averager_%s' % inf] = y_averager(cond_prob.target, globals()['infectivity_%s' % inf],
1 / inv_norm, 0)
infectivity.append(globals()['infectivity_averager_%s' % inf](cond_prob))
results['independent lambda'] = lamb_indep
results['age coordinates'] = age_coord
results['lload coordinates'] = ll_coord
results['lambda true'] = lamb_true
results['averages'] = averages
results['averages'] = averages
results['infectivity'] = infectivity
results['domain breaker op'] = dombr
......@@ -293,10 +311,3 @@ def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_itera
elif plot == False:
return results
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2024 Max-Planck-Society