Commit 14c672d7 authored by Matteo.Guardiani's avatar Matteo.Guardiani
Browse files

cleanup: put MaternKernelModel into specific class. Separated component...

cleanup: put MaternKernelModel into specific class. Separated component creation into individual methods.
parent 19e8a379
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
# Copyright(C) 2013-2022 Max-Planck-Society
# Author: Matteo Guardiani
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import argparse
import json
import os
import sys
import nifty7 as ift
import numpy as np
from mpi4py import MPI
n_task = comm.Get_size()
rank = comm.Get_rank()
except ImportError:
comm = None
n_task = 1
rank = 0
master = (rank == 0)
from covid_matern_model import covid_matern_model_maker
from data_utilities import save_KL_sample, save_KL_position
from utilities import get_op_post_mean
from const import npix_age, npix_ll
from data import Data
from covid_matern_model2 import MaternCausalModel
# from evidence_g import get_evidence
import matplotlib.colors as colors
# Parser Setup
parser = argparse.ArgumentParser()
parser.add_argument('--json_file', type=str, required=True) # FIXME: Add help --help
parser.add_argument('--csv_file', type=str, required=True)
parser.add_argument('--reshuffle_parameter', type=int, required=True)
args = parser.parse_args()
json_file = args.json_file
csv_file = args.csv_file
reshuffle_iterator = args.reshuffle_parameter
if __name__ == '__main__':
# Read in the configuration file
current_path = os.path.abspath('.')
inversion_parameter = False
if 'inv' in json_file:
inversion_parameter = True
file_setup = open(json_file, "r")
setup = json.load(file_setup)
# Preparing the filename string and plots folder to store live results
if not os.path.exists('./plots'):
filename = "plots/covid_combined_matern_{}.png"
# Results Output Folders
path_j = os.path.basename(json_file)
path_c = os.path.basename(csv_file)
results_path = os.path.join('./Automized_Results_Matern', os.path.splitext(path_j)[0], os.path.splitext(path_c)[0],
results_path = os.path.normpath(results_path)
os.makedirs(results_path, exist_ok=True)
# Load the model
data = Data(npix_age, npix_ll, json_file, reshuffle_iterator, inversion_parameter, csv_file)
model = MaternCausalModel(data, False)
# Setup the response & define the amplitudes
R = ift.GeometryRemover(
R_lamb = R(model.lambda_combined)
A1 = model.amplitudes[0]
A2 = model.amplitudes[1]
# Specify data space
data_space =
# Generate mock signal and data
seed = setup['seed']
if setup['mock']:
# data
mock_position = ift.from_random(model.lambda_combined.domain, 'normal')
data = R_lamb(mock_position)
data = ift.random.current_rng().poisson(data.val.astype(np.float64))
indep_tag = '_indep'
if not setup['same data'] and indep_tag in json_file:
print("\nUsing syinthetic data generated from joint model on independent model")
joint_json_file = json_file.replace(indep_tag, '')
file_setup = open(joint_json_file, "r")
joint_setup = json.load(file_setup)
joint_model = MaternCausalModel(data, False)
joint_model = covid_matern_model_maker(npix_age, npix_ll, joint_setup, csv_file, reshuffle_iterator, False,
joint_lamb_comb = joint_model['combined lambda']
mock_position = ift.from_random(joint_lamb_comb.domain, 'normal')
data = R_lamb(mock_position)
data = ift.random.current_rng().poisson(data.val.astype(np.float64))
if not setup['same data'] and not indep_tag in json_file:
print("\nUsing syinthetic data generated from independent model on joint model")
indep_json_file = os.path.splitext(json_file)[0] + '_indep' + os.path.splitext(json_file)[1]
file_setup = open(indep_json_file, "r")
indep_setup = json.load(file_setup)
indep_model = covid_matern_model_maker(npix_age, npix_ll, indep_setup, csv_file, reshuffle_iterator, False,
indep_lamb_comb = indep_model['combined lambda']
mock_position = ift.from_random(indep_lamb_comb.domain, 'normal')
data = R_lamb(mock_position)
data = ift.random.current_rng().poisson(data.val.astype(np.float64))
data = ift.makeField(data_space, data)
if setup['mock']:
plot = ift.Plot()
plot.add(lamb_comb(mock_position), title='Full Field')
plot.add(R.adjoint(data), title='Data')
plot.add([A1.force(mock_position)], title='Power Spectrum 1')
plot.add([A2.force(mock_position)], title='Power Spectrum 2')
plot.output(ny=3, nx=2, xsize=10, ysize=10, name=filename.format("setup"))
# Minimization parameters
ic_sampling = ift.AbsDeltaEnergyController(deltaE=1e-5, iteration_limit=250, convergence_level=250)
ic_newton = ift.AbsDeltaEnergyController(deltaE=1e-5, iteration_limit=5, name='newton', convergence_level=3)
minimizer = ift.NewtonCG(ic_newton, enable_logging=True)
# Set up likelihood and information Hamiltonian
likelihood = ift.PoissonianEnergy(data) @ R_lamb
H = ift.StandardHamiltonian(likelihood, ic_sampling)
# Begin minimization
initial_mean = ift.from_random(H.domain, 'normal') * 0.1
mean = initial_mean
N_steps = 35 # 34
for i in range(N_steps):
if i < 27:
ic_newton = ift.AbsDeltaEnergyController(deltaE=1e-5, iteration_limit=10, name='newton',
ic_newton = ift.AbsDeltaEnergyController(deltaE=1e-5, iteration_limit=20, name='newton',
minimizer = ift.NewtonCG(ic_newton, enable_logging=True)
if i < 30:
N_samples = 5
elif i < 33:
N_samples = 20
N_samples = 500
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL(mean, H, N_samples, comm=comm, mirror_samples=True, nanisinf=True)
KL, convergence = minimizer(KL)
samples = tuple(KL.samples)
mean = KL.position
if master:
it = 0
pos_path = os.path.join(results_path, "KL_position")
save_KL_position(mean, pos_path)
print("KL position saved", file=sys.stderr)
sam_path = os.path.join(results_path, "samples")
os.makedirs(sam_path, exist_ok=True)
for sample in samples:
save_KL_sample(sample, os.path.join(sam_path, "KL_sample_{}".format(it)))
it += 1
print("KL samples saved", file=sys.stderr)
# Minisanity check
ift.extra.minisanity(data, lambda x: ift.makeOp(R_lamb(x).ptw('reciprocal')), R_lamb, mean,
samples) # Fix Me: Check noise implementation in minisanity
# Plot current reconstruction
plot = ift.Plot()
if setup['mock']:
plot.add([lamb_comb(mock_position)], title="ground truth")
plot.add(R.adjoint(data), title='Data')
plot.add([lamb_comb(mean)], title="reconstruction")
plot.add([lamb_joint.force(mean)], title="Joint component")
plot.add([A1.force(mean), A1.force(mock_position)], title="power1")
plot.add([A2.force(mean), A2.force(mock_position)], title="power2")
plot.add([ic_newton.history, ic_sampling.history, minimizer.inversion_history],
label=['KL', 'Sampling', 'Newton inversion'], title='Cumulative energies', s=[None, None, 1],
alpha=[None, 0.2, None])
plot.add([lamb_comb(mean)], title="Reconstruction", norm=colors.SymLogNorm(linthresh=10e-1),
extent=boundaries, aspect="auto")
plot.add([lamb_full.force(mean)], title="Joint Component Reconstruction",
norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries, aspect="auto")
plot.add([cond_prob.force(mean)], title="Conditional Probability Reconstruction",
norm=colors.SymLogNorm(linthresh=6 * 10e-4), extent=boundaries, aspect="auto")
# plot.add([Aj.force(mean)], title="power1 joint") # FIX ME: MAYBE ACCOUNT FOR THE MARGINALIZATION ??
plot.add([A1.force(mean)], title="power1 independent")
plot.add([A2.force(mean)], title="power2 independent")
plot.add(lamb_ag_full.force(mean), title="Age Reconstruction (full)", aspect="auto")
plot.add(lamb_ll_full.force(mean), title="Log load Reconstruction (full)", aspect="auto")
plot.output(nx=3, ny=3, ysize=10, xsize=15, name=filename.format("loop_{:02d}".format(i)))
print('Lamb combined check:', lamb_comb(KL.position).val.sum(), '\n', file=sys.stderr)
print('Zm Xi:', zm.force(KL.position).val, '\n', file=sys.stderr)
if master:
lamb_comb_mean, lamb_comb_var = get_op_post_mean(lamb_comb, mean, samples)
cond_prob_mean, cond_prob_var = get_op_post_mean(cond_prob, mean, samples)
lamb_full_mean, lamb_full_var = get_op_post_mean(lamb_full.exp(), mean, samples)
powers1 = []
powers2 = []
for sample in samples:
p1 = A1.force(sample + mean)
p2 = A2.force(sample + mean)
# Final Plots
filename_res = "Results.png"
filename_res = os.path.join(results_path, filename_res)
plot = ift.Plot()
plot.add(lamb_comb_mean, title="Posterior Mean", norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries,
plot.add(lamb_comb_var.sqrt(), title="Posterior Standard Deviation", norm=colors.SymLogNorm(linthresh=10e-1),
extent=boundaries, aspect="auto")
plot.add([cond_prob.force(mean)], title="Conditional Probability Reconstruction",
norm=colors.SymLogNorm(linthresh=6 * 10e-4), extent=boundaries, aspect="auto")
plot.add([A1.force(mean)], title="Age Independent Power Spectrum (log[S(k^2)])")
plot.add([A2.force(mean)], title="Log load Independent Power Spectrum (log[S(k^2)])")
plot.add(lamb_ag_full.force(mean), title="Age Reconstruction (full)", norm=colors.SymLogNorm(linthresh=10e-1),
plot.add(lamb_ll_full.force(mean), title="Log load Reconstruction (full)",
norm=colors.SymLogNorm(linthresh=10e-1), aspect="auto")
plot.add([lamb_full.exp().force(mean)], title="Joint Component Reconstruction (full)",
norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries, aspect="auto")
plot.output(ny=3, nx=3, xsize=20, ysize=15, name=filename_res)
print("Saved results as", filename_res, file=sys.stderr)
# Error Plots
filename_ers = "Errors.png"
filename_ers = os.path.join(results_path, filename_ers)
plot = ift.Plot()
plot.add(lamb_comb_mean, title="Posterior Mean", norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries,
plot.add(lamb_comb_var.sqrt() * lamb_comb_mean.ptw('reciprocal'), title="Relative Uncertainty",
norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries, aspect="auto")
plot.add(cond_prob_mean, title="Conditional Probability Reconstruction Mean",
norm=colors.SymLogNorm(linthresh=6 * 10e-4), extent=boundaries, aspect="auto")
plot.add(cond_prob_var.sqrt() * cond_prob_mean.ptw('reciprocal'),
title="Relative Uncertainty on Conditional Probability Reconstruction",
norm=colors.SymLogNorm(linthresh=10e-1), extent=boundaries, aspect="auto")
plot.add(lamb_full_mean, title="Joint Component Reconstruction Mean", norm=colors.SymLogNorm(linthresh=10e-2),
plot.add(lamb_full_var.sqrt() * lamb_full_mean.ptw('reciprocal'),
title="Relative Uncertainty on Joint Component Reconstruction", norm=colors.SymLogNorm(linthresh=10e-2),
plot.output(ny=3, nx=2, xsize=15, ysize=15, name=filename_ers)
print("Saved results as", filename_ers, file=sys.stderr)
# Get the evidence # Uncomment the following to compute the evidence directly from this script. # evidence =
# get_evidence(KL, data=data) # print("EVIDENCE", file=sys.stderr) # print(evidence, file=sys.stderr) # #
# print('\n', file=sys.stderr)
# # if dataset_file: # # shutil.copy(dataset_file, results_path)
# if json_file: # shutil.copy(json_file, results_path)
# with open(os.path.join(results_path,'evidence.txt'), 'wb') as file: # # file.write('EVIDENCE CALCULATION
# \n\n') # pickle.dump(evidence, file)
# with open('./Automized_Results_Matern/Automized_Evidences.txt', 'a') as file: # evidences = # #
# os.path.basename(json_file) + ', ' + os.path.basename(csv_file) + '_' + str(resh_it) + ', Evidence Mean: '
# + str(evidence['estimate']) + '\n\n' # file.write(evidences)
......@@ -17,310 +17,257 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty7 as ift
import numpy as np
from binner import Bin1D, Bin2D, data_filter_x
from data_utilities import fitted_infectivity
from data_utilities import read_in
from tools import density_estimator, y_averager
from tools import domainbreak_2D
import data
from tools import density_estimator, domainbreak_2D
from utilities import GeomMaskOperator
class MaternCausalModel:
def __init__(self, npix_age, npix_ll, setup, csv_file, reshuffle_iterator, plot, inv_par, alphas=None):
self.npix_age = npix_age
self.npix_ll = npix_ll
self.setup = setup
self.csv_file = csv_file
self.reshuffle_iterator = reshuffle_iterator
def __init__(self, dataset, plot, alphas=None):
if not isinstance(dataset, data.Data):
raise TypeError("The dataset needs to be of type Data.")
if not isinstance(plot, bool):
raise TypeError("The dataset needs to be of type Data.")
self.dataset = dataset
self.plot = plot
self.inv_par = inv_par
self.alphas = alphas
self.lambda_joint = None
self.lambda_age = None
self.lambda_ll = None
self.amplitudes = None
self.position_space = None
self.target_space = None
self.lambda_combined = self.create_model()[0]
self.conditional_probability = self.create_model()[1]
def create_model(self):
self.lambda_joint = self.build_joint_component()
self.lambda_age = self.build_independent_components()[0]
self.lambda_ll = self.build_independent_components()[1]
self.amplitudes = self.build_independent_components()[2]
# Dimensionality adjustment for the independent component
self.target_space =
domain_break_op = domainbreak_2D(self.target_space)
full_domain_break_op = domainbreak_2D(self.position_space)
lambda_joint_placeholder = ift.FieldAdapter(, 'lambdajoint')
lambda_ll_placeholder = ift.FieldAdapter(, 'lambdall')
lambda_age_placeholder = ift.FieldAdapter(, 'lambdaage')
x_marginalizer_op = domain_break_op(lambda_joint_placeholder.ptw('exp')).sum(
0) # Field exponentiation and marginalization along the x direction, hence has 'length' y
age_unit_field = ift.full(, 1)
dimensionality_operator = ift.OuterProduct(, age_unit_field)
lambda_ll_2d = domain_break_op.adjoint @ dimensionality_operator @ lambda_ll_placeholder
ll_unit_field = ift.full(, 1)
dimensionality_operator_2 = ift.OuterProduct(, ll_unit_field)
transposition_operator = ift.LinearEinsum(dimensionality_operator_2(lambda_age_placeholder).target,
ift.MultiField.from_dict({}), "xy->yx")
dimensionality_operator_2 = transposition_operator @ dimensionality_operator_2
lambda_age_2d = domain_break_op.adjoint @ dimensionality_operator_2 @ lambda_age_placeholder
joint_component = lambda_ll_2d + lambda_joint_placeholder
cond_density = joint_component.ptw('exp') * domain_break_op.adjoint(
normalization = domain_break_op(cond_density).sum(1)
log_lambda_combined = lambda_age_2d + joint_component - domain_break_op.adjoint(
dimensionality_operator(x_marginalizer_op.ptw('log'))) - domain_break_op.adjoint(
log_lambda_combined = log_lambda_combined @ (
self.lambda_joint.ducktape_left('lambdajoint') + self.lambda_ll.ducktape_left(
'lambdall') + self.lambda_age.ducktape_left('lambdaag'))
lambda_combined = log_lambda_combined.ptw('exp')
conditional_probability = cond_density * domain_break_op.adjoint(dimensionality_operator_2(normalization)).ptw(
conditional_probability = conditional_probability @ (
self.lambda_joint.ducktape_left('lambdajoint') + self.lambda_ll.ducktape_left('lambdall'))
# Normalize the probability on the given logload interval
boundaries = [min(self.dataset.coordinates()[0]), max(self.dataset.coordinates()[0]),
min(self.dataset.coordinates()[1]), max(self.dataset.coordinates()[1])]
inv_norm = self.dataset.npix_ll / (boundaries[3] - boundaries[2])
conditional_probability = conditional_probability * inv_norm
return lambda_combined, conditional_probability
def build_joint_component(self):
setup = self.dataset.setup
npix_age = self.dataset.npix_age
npix_ll = self.dataset.npix_ll
self.position_space, sp1, sp2 = self.dataset.zero_pad()
# Set up signal model
joint_offset = setup['joint']['offset_dict']
offset_mean = joint_offset['offset_mean']
offset_std = joint_offset['offset_std']
joint_prefix = joint_offset['prefix']
joint_setup_ll = setup['joint']['log_load']
ll_scale = joint_setup_ll['scale']
ll_cutoff = joint_setup_ll['cutoff']
ll_loglogslope = joint_setup_ll['loglogslope']
ll_prefix = joint_setup_ll['prefix']
joint_setup_age = setup['joint']['age']
age_scale = joint_setup_age['scale']
age_cutoff = joint_setup_age['cutoff']
age_loglogslope = joint_setup_age['loglogslope']
age_prefix = joint_setup_age['prefix']
# .make(offset_mean, offset_std,
correlated_field_maker = ift.CorrelatedFieldMaker(joint_prefix)
correlated_field_maker.set_amplitude_total_offset(offset_mean, offset_std, )
if self.dataset.inversion_par:
correlated_field_maker.add_fluctuations_matern(sp1, ll_scale, ll_cutoff, ll_loglogslope, ll_prefix)
correlated_field_maker.add_fluctuations_matern(sp2, age_scale, age_cutoff, age_loglogslope, age_prefix)
correlated_field_maker.add_fluctuations_matern(sp1, age_scale, age_cutoff, age_loglogslope, age_prefix)
correlated_field_maker.add_fluctuations_matern(sp2, ll_scale, ll_cutoff, ll_loglogslope, ll_prefix)
lambda_full = correlated_field_maker.finalize()
# For the joint model unmasked regions
tgt = ift.RGSpace((npix_age, npix_ll),
if self.dataset.inversion_par:
tgt = ift.RGSpace((npix_ll, npix_age),
GMO = GeomMaskOperator(, tgt)
lambda_joint = GMO(lambda_full.clip(-30, 30))
return lambda_joint
def build_independent_components(self):
setup = self.dataset.setup
npix_age = self.dataset.npix_age
npix_ll = self.dataset.npix_ll
_, sp1, sp2 = self.dataset.zero_pad()
# Set up signal model
# Age Parameters
age_dict = setup['indep']['age']
age_offset_mean = age_dict['offset_dict']['offset_mean']
age_offset_std = age_dict['offset_dict']['offset_std']
indep_age_prefix = age_dict['offset_dict']['prefix']
age_prefix = age_dict['params']['prefix']
# Log Load Parameters
ll_dict = setup['indep']['log_load']
ll_offset_mean = ll_dict['offset_dict']['offset_mean']
ll_offset_std = ll_dict['offset_dict']['offset_std']
indep_ll_prefix = ll_dict['offset_dict']['prefix']
if self.dataset.inversion_par:
signal_response, ops = density_estimator(sp1, cf_fluctuations=ll_dict['params'],
cf_azm_uniform=ll_offset_std, azm_offset_mean=ll_offset_mean, pad=0)
correlated_field_maker = ift.CorrelatedFieldMaker(indep_age_prefix)
correlated_field_maker.set_amplitude_total_offset(age_offset_mean, age_offset_std)
correlated_field_maker.add_fluctuations_matern(sp2, **age_dict['params'])
lambda_ll_full = correlated_field_maker.finalize()
ll_amplitude = correlated_field_maker.amplitude
signal_response, ops = density_estimator(sp1, cf_fluctuations=age_dict['params'],
cf_azm_uniform=age_offset_std, azm_offset_mean=age_offset_mean, pad=0)
correlated_field_maker = ift.CorrelatedFieldMaker(indep_ll_prefix)
correlated_field_maker.set_amplitude_total_offset(ll_offset_mean, ll_offset_std)
correlated_field_maker.add_fluctuations_matern(sp2, **ll_dict['params'])
lambda_ll_full = correlated_field_maker.finalize()
ll_amplitude = correlated_field_maker.amplitude
lambda_ag_full = ops["correlated_field"]
age_amplitude = ops["amplitude"]
zero_mode = ops["amplitude_total_offset"]
# response = ops["exposure"]
# Split the center
# Age
_dist =[0].distances
tgt_age = ift.RGSpace(npix_age, distances=_dist)
if self.dataset.inversion_par:
tgt_age = ift.RGSpace(npix_ll, distances=_dist)
GMO_age = GeomMaskOperator(, tgt_age)
lambda_age = GMO_age(lambda_ag_full.clip(-30, 30))
# Viral load
_dist =[0].distances
tgt_ll = ift.RGSpace(npix_ll, distances=_dist)
if self.dataset.inversion_par:
tgt_ll = ift.RGSpace(npix_age, distances=_dist)
GMO_ll = GeomMaskOperator(, tgt_ll)
lambda_ll = GMO_ll(lambda_ll_full.clip(-30, 30))
amplitudes = [age_amplitude, ll_amplitude]
return lambda_age, lambda_ll, amplitudes
def covid_matern_model_maker(npix_age, npix_ll, setup, csv_file, reshuffle_iterator, plot, inv_par, alphas=None):
''' Implements a causal model for covid data analysis in Nifty.
Required arguments:
- npix_age, npix_ll: int, number of bins (pixels) per age, log load axis
- setup: .json file with radom seed and parameters needed by the Matérn Kernel correlated field class.
- csv_file: .csv file with age and log load of a Covid positive patient
- reshuffle iterator: random seed for data reshuffling along the log load (y) axis. If == 0, the dataset is not
- plot: bool, if True returns extra-fields which might be needed for plots.
- inv_par: bool, if True exchanges x and y data and axis for opposite causal analysis.
- alphas: alphas for different alpha-power averages in the plots
- Binned data and necessary fields for analysis e.g. cond_prob (the conditional probability of log load given the
age p(y|x))
if not isinstance(npix_age, int):
raise TypeError("Pixel argument needs to be int")
if not isinstance(npix_ll, int):
raise TypeError("Pixel argument needs to be int")
if not isinstance(plot, bool):
raise TypeError("Plot argument needs to be bool")
# Make a bigger domain to avoid boundary effects
ext_npix_age = 2 * npix_age
ext_npix_ll = 2 * npix_ll
if inv_par:
position_space = ift.RGSpace((ext_npix_ll, ext_npix_age))
sp1 = ift.RGSpace(ext_npix_ll)
sp2 = ift.RGSpace(ext_npix_age)
position_space = ift.RGSpace((ext_npix_age, ext_npix_ll))
sp1 = ift.RGSpace(ext_npix_age)
sp2 = ift.RGSpace(ext_npix_ll)
# Load and bin data
threshold = setup["threshold"]
age, ll = read_in(csv_file)
ll, age = data_filter_x(threshold, ll,
age) # For Cobas data filter was set at 3.3 for log load FIXXXXX!!! NOW 3.85 (other param 5.4)
if not reshuffle_iterator == 0:
from sklearn.utils import shuffle
ll = shuffle(ll, random_state=reshuffle_iterator)
if inv_par:
data, ll_edges, age_edges = Bin2D(ll, age, npix_ll, npix_age)
data, age_edges, ll_edges = Bin2D(age, ll, npix_age, npix_ll)
data = np.array(data, dtype=np.int64)
# Create proper coordinates
binned_age, edges_a = Bin1D(age, npix_age)
age_coord = 0.5 * (edges_a[:-1] + edges_a[1:])