Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
public
causal_age_viral_load_model
Commits
f572787f
Commit
f572787f
authored
Jan 25, 2022
by
Matteo.Guardiani
Browse files
andrija: Refactored age -> x, ll -> y. Made more covid-agnostic.
parent
8c8c5aca
Changes
4
Hide whitespace changes
Inline
Side-by-side
covid_combined_matern_essential.py
View file @
f572787f
...
...
@@ -25,9 +25,9 @@ import matplotlib.colors as colors
import
nifty7
as
ift
from
const
import
npix_age
,
npix_ll
from
covid_matern_model
import
MaternCausalModel
from
data
import
Data
from
data_utilities
import
save_kl_position
,
save_kl_sample
from
matern_causal_model
import
MaternCausalModel
from
utilities
import
get_op_post_mean
# Parser Setup
...
...
covid_combined_matern_mpi.py
View file @
f572787f
...
...
@@ -41,7 +41,7 @@ from data_utilities import save_kl_sample, save_kl_position
from
utilities
import
get_op_post_mean
from
const
import
npix_age
,
npix_ll
from
data
import
Data
from
covid_
matern_model
import
MaternCausalModel
from
matern
_causal
_model
import
MaternCausalModel
# from evidence_g import get_evidence
import
matplotlib.colors
as
colors
...
...
data.py
View file @
f572787f
...
...
@@ -32,34 +32,34 @@ class ClassLoader(type):
class
Data
(
metaclass
=
ClassLoader
):
def
__init__
(
self
,
npix_
age
,
npix_
ll
,
ll
_threshold
,
reshuffle_seed
,
age
,
ll
):
if
not
isinstance
(
npix_
age
,
int
):
def
__init__
(
self
,
npix_
x
,
npix_
y
,
y
_threshold
,
reshuffle_seed
,
x
,
y
):
if
not
isinstance
(
npix_
x
,
int
):
raise
TypeError
(
"Number of pixels argument needs to be of type int."
)
if
not
isinstance
(
npix_
ll
,
int
):
if
not
isinstance
(
npix_
y
,
int
):
raise
TypeError
(
"Number of pixels argument needs to be of type int."
)
if
not
isinstance
(
ll
_threshold
,
float
):
if
not
isinstance
(
y
_threshold
,
float
):
raise
TypeError
(
"Log load threshold value argument needs to be of type float."
)
if
not
isinstance
(
reshuffle_seed
,
int
):
raise
TypeError
(
"Reshuffle iterator argument needs to be of type int."
)
if
not
isinstance
(
age
,
np
.
ndarray
):
raise
TypeError
(
"The
age
dataset argument needs to be of type np.ndarray."
)
if
not
isinstance
(
x
,
np
.
ndarray
):
raise
TypeError
(
"The
x
dataset argument needs to be of type np.ndarray."
)
if
not
isinstance
(
ll
,
np
.
ndarray
):
raise
TypeError
(
"The
ll
dataset argument needs to be of type np.ndarray."
)
if
not
isinstance
(
y
,
np
.
ndarray
):
raise
TypeError
(
"The
y
dataset argument needs to be of type np.ndarray."
)
if
not
age
.
size
==
ll
.
size
:
raise
TypeError
(
"The dataset has to contain pairs of
age
and log load data."
)
if
not
x
.
size
==
y
.
size
:
raise
TypeError
(
"The dataset has to contain pairs of
x
and log load data."
)
self
.
npix_
age
,
self
.
npix_
ll
=
npix_
age
,
npix_
ll
self
.
ll
_threshold
=
ll
_threshold
self
.
npix_
x
,
self
.
npix_
y
=
npix_
x
,
npix_
y
self
.
y
_threshold
=
y
_threshold
self
.
reshuffle_seed
=
reshuffle_seed
self
.
age
=
age
,
self
.
ll
=
ll
self
.
x
=
x
,
self
.
y
=
y
self
.
age
,
self
.
ll
=
self
.
filter
()
self
.
x
,
self
.
y
=
self
.
filter
()
self
.
data
=
None
self
.
edges
=
None
self
.
filename
=
'plots/data.pdf'
...
...
@@ -70,30 +70,30 @@ class Data(metaclass=ClassLoader):
def
filter
(
self
):
# Loads, filters and reshuffles data
self
.
ll
,
self
.
age
=
self
.
data_filter_x
(
self
.
ll
_threshold
,
self
.
ll
,
self
.
age
)
self
.
y
,
self
.
x
=
self
.
data_filter_x
(
self
.
y
_threshold
,
self
.
y
,
self
.
x
)
if
not
self
.
reshuffle_seed
==
0
:
self
.
__reshuffle_data
(
self
.
ll
,
self
.
reshuffle_seed
)
self
.
__reshuffle_data
(
self
.
y
,
self
.
reshuffle_seed
)
return
self
.
age
,
self
.
ll
return
self
.
x
,
self
.
y
def
zero_pad
(
self
):
ext_npix_
age
=
2
*
self
.
npix_
age
ext_npix_
ll
=
2
*
self
.
npix_
ll
ext_npix_
x
=
2
*
self
.
npix_
x
ext_npix_
y
=
2
*
self
.
npix_
y
return
self
.
__create_spaces
(
ext_npix_
age
,
ext_npix_
ll
)
return
self
.
__create_spaces
(
ext_npix_
x
,
ext_npix_
y
)
def
__bin
(
self
):
data
,
age
_edges
,
ll
_edges
=
bin_2D
(
self
.
age
,
self
.
ll
,
self
.
npix_
age
,
self
.
npix_
ll
)
data
,
x
_edges
,
y
_edges
=
bin_2D
(
self
.
x
,
self
.
y
,
self
.
npix_
x
,
self
.
npix_
y
)
self
.
data
=
np
.
array
(
data
,
dtype
=
np
.
int64
)
return
data
,
age
_edges
,
ll
_edges
return
data
,
x
_edges
,
y
_edges
def
coordinates
(
self
):
age
_coordinates
=
self
.
__obtain_coordinates
(
self
.
age
,
self
.
npix_
age
)
ll
_coordinates
=
self
.
__obtain_coordinates
(
self
.
ll
,
self
.
npix_
ll
)
x
_coordinates
=
self
.
__obtain_coordinates
(
self
.
x
,
self
.
npix_
x
)
y
_coordinates
=
self
.
__obtain_coordinates
(
self
.
y
,
self
.
npix_
y
)
return
age
_coordinates
,
ll
_coordinates
return
x
_coordinates
,
y
_coordinates
def
plot
(
self
):
import
matplotlib.pyplot
as
plt
...
...
@@ -137,7 +137,7 @@ class Data(metaclass=ClassLoader):
class
InvertedData
(
Data
):
def
__init__
(
self
,
data
):
super
().
__init__
(
data
.
npix_
age
,
data
.
npix_
ll
,
data
.
ll
_threshold
,
data
.
reshuffle_seed
,
data
.
csv_dataset_path
)
self
.
age
,
self
.
ll
=
self
.
ll
,
self
.
age
self
.
npix_
age
,
self
.
npix_
ll
=
self
.
npix_
ll
,
self
.
npix_
age
super
().
__init__
(
data
.
npix_
x
,
data
.
npix_
y
,
data
.
y
_threshold
,
data
.
reshuffle_seed
,
data
.
csv_dataset_path
)
self
.
x
,
self
.
y
=
self
.
y
,
self
.
x
self
.
npix_
x
,
self
.
npix_
y
=
self
.
npix_
y
,
self
.
npix_
x
self
.
filename
=
'plots/inverted_data.pdf'
covid_
matern_model.py
→
matern
_causal
_model.py
View file @
f572787f
...
...
@@ -40,10 +40,10 @@ class MaternCausalModel:
self
.
plot
=
plot
self
.
alphas
=
alphas
self
.
lambda_joint
=
None
self
.
lambda_
age
=
None
self
.
lambda_
ll
=
None
self
.
lambda_
age
_full
=
None
self
.
lambda_
ll
_full
=
None
self
.
lambda_
x
=
None
self
.
lambda_
y
=
None
self
.
lambda_
x
_full
=
None
self
.
lambda_
y
_full
=
None
self
.
lambda_full
=
None
self
.
amplitudes
=
None
self
.
position_space
=
None
...
...
@@ -54,7 +54,7 @@ class MaternCausalModel:
def
create_model
(
self
):
self
.
lambda_joint
,
self
.
lambda_full
=
self
.
build_joint_component
()
self
.
lambda_
age
,
self
.
lambda_
ll
,
self
.
lambda_
age
_full
,
self
.
lambda_
ll
_full
,
self
.
amplitudes
=
\
self
.
lambda_
x
,
self
.
lambda_
y
,
self
.
lambda_
x
_full
,
self
.
lambda_
y
_full
,
self
.
amplitudes
=
\
self
.
initialize_independent_components
()
# Dimensionality adjustment for the independent component
...
...
@@ -63,51 +63,51 @@ class MaternCausalModel:
domain_break_op
=
DomainBreak2D
(
self
.
target_space
)
lambda_joint_placeholder
=
ift
.
FieldAdapter
(
self
.
lambda_joint
.
target
,
'lambdajoint'
)
lambda_
ll
_placeholder
=
ift
.
FieldAdapter
(
self
.
lambda_
ll
.
target
,
'lambda
ll
'
)
lambda_
age
_placeholder
=
ift
.
FieldAdapter
(
self
.
lambda_
age
.
target
,
'lambda
age
'
)
lambda_
y
_placeholder
=
ift
.
FieldAdapter
(
self
.
lambda_
y
.
target
,
'lambda
y
'
)
lambda_
x
_placeholder
=
ift
.
FieldAdapter
(
self
.
lambda_
x
.
target
,
'lambda
x
'
)
x_marginalizer_op
=
domain_break_op
(
lambda_joint_placeholder
.
ptw
(
'exp'
)).
sum
(
0
)
# Field exponentiation and marginalization along the x direction, hence has 'length' y
age
_unit_field
=
ift
.
full
(
self
.
lambda_
age
.
target
,
1
)
dimensionality_operator
=
ift
.
OuterProduct
(
self
.
lambda_
ll
.
target
,
age
_unit_field
)
lambda_
ll
_2d
=
domain_break_op
.
adjoint
@
dimensionality_operator
@
lambda_
ll
_placeholder
ll
_unit_field
=
ift
.
full
(
self
.
lambda_
ll
.
target
,
1
)
dimensionality_operator_2
=
ift
.
OuterProduct
(
self
.
lambda_
age
.
target
,
ll
_unit_field
)
transposition_operator
=
ift
.
LinearEinsum
(
dimensionality_operator_2
(
lambda_
age
_placeholder
).
target
,
x
_unit_field
=
ift
.
full
(
self
.
lambda_
x
.
target
,
1
)
dimensionality_operator
=
ift
.
OuterProduct
(
self
.
lambda_
y
.
target
,
x
_unit_field
)
lambda_
y
_2d
=
domain_break_op
.
adjoint
@
dimensionality_operator
@
lambda_
y
_placeholder
y
_unit_field
=
ift
.
full
(
self
.
lambda_
y
.
target
,
1
)
dimensionality_operator_2
=
ift
.
OuterProduct
(
self
.
lambda_
x
.
target
,
y
_unit_field
)
transposition_operator
=
ift
.
LinearEinsum
(
dimensionality_operator_2
(
lambda_
x
_placeholder
).
target
,
ift
.
MultiField
.
from_dict
({}),
"xy->yx"
)
dimensionality_operator_2
=
transposition_operator
@
dimensionality_operator_2
lambda_
age
_2d
=
domain_break_op
.
adjoint
@
dimensionality_operator_2
@
lambda_
age
_placeholder
lambda_
x
_2d
=
domain_break_op
.
adjoint
@
dimensionality_operator_2
@
lambda_
x
_placeholder
joint_component
=
lambda_
ll
_2d
+
lambda_joint_placeholder
joint_component
=
lambda_
y
_2d
+
lambda_joint_placeholder
cond_density
=
joint_component
.
ptw
(
'exp'
)
*
domain_break_op
.
adjoint
(
dimensionality_operator
(
x_marginalizer_op
.
ptw
(
'reciprocal'
)))
normalization
=
domain_break_op
(
cond_density
).
sum
(
1
)
log_lambda_combined
=
lambda_
age
_2d
+
joint_component
-
domain_break_op
.
adjoint
(
log_lambda_combined
=
lambda_
x
_2d
+
joint_component
-
domain_break_op
.
adjoint
(
dimensionality_operator
(
x_marginalizer_op
.
ptw
(
'log'
)))
-
domain_break_op
.
adjoint
(
dimensionality_operator_2
(
normalization
.
ptw
(
'log'
)))
log_lambda_combined
=
log_lambda_combined
@
(
self
.
lambda_joint
.
ducktape_left
(
'lambdajoint'
)
+
self
.
lambda_
ll
.
ducktape_left
(
'lambda
ll
'
)
+
self
.
lambda_
age
.
ducktape_left
(
'lambda
age
'
))
self
.
lambda_joint
.
ducktape_left
(
'lambdajoint'
)
+
self
.
lambda_
y
.
ducktape_left
(
'lambda
y
'
)
+
self
.
lambda_
x
.
ducktape_left
(
'lambda
x
'
))
lambda_combined
=
log_lambda_combined
.
ptw
(
'exp'
)
conditional_probability
=
cond_density
*
domain_break_op
.
adjoint
(
dimensionality_operator_2
(
normalization
)).
ptw
(
'reciprocal'
)
conditional_probability
=
conditional_probability
@
(
self
.
lambda_joint
.
ducktape_left
(
'lambdajoint'
)
+
self
.
lambda_
ll
.
ducktape_left
(
'lambda
ll
'
))
self
.
lambda_joint
.
ducktape_left
(
'lambdajoint'
)
+
self
.
lambda_
y
.
ducktape_left
(
'lambda
y
'
))
# Normalize the probability on the given logload interval
boundaries
=
[
min
(
self
.
dataset
.
coordinates
()[
0
]),
max
(
self
.
dataset
.
coordinates
()[
0
]),
min
(
self
.
dataset
.
coordinates
()[
1
]),
max
(
self
.
dataset
.
coordinates
()[
1
])]
inv_norm
=
self
.
dataset
.
npix_
ll
/
(
boundaries
[
3
]
-
boundaries
[
2
])
inv_norm
=
self
.
dataset
.
npix_
y
/
(
boundaries
[
3
]
-
boundaries
[
2
])
conditional_probability
=
conditional_probability
*
inv_norm
return
lambda_combined
,
conditional_probability
def
build_joint_component
(
self
):
npix_
age
=
self
.
dataset
.
npix_
age
npix_
ll
=
self
.
dataset
.
npix_
ll
npix_
x
=
self
.
dataset
.
npix_
x
npix_
y
=
self
.
dataset
.
npix_
y
self
.
position_space
,
sp1
,
sp2
=
self
.
dataset
.
zero_pad
()
# Set up signal model
...
...
@@ -116,27 +116,27 @@ class MaternCausalModel:
offset_std
=
joint_offset
[
'offset_std'
]
joint_prefix
=
joint_offset
[
'prefix'
]
joint_setup_
ll
=
self
.
setup
[
'joint'
][
'log_load'
]
ll
_scale
=
joint_setup_
ll
[
'scale'
]
ll
_cutoff
=
joint_setup_
ll
[
'cutoff'
]
ll
_loglogslope
=
joint_setup_
ll
[
'loglogslope'
]
ll
_prefix
=
joint_setup_
ll
[
'prefix'
]
joint_setup_
y
=
self
.
setup
[
'joint'
][
'log_load'
]
y
_scale
=
joint_setup_
y
[
'scale'
]
y
_cutoff
=
joint_setup_
y
[
'cutoff'
]
y
_loglogslope
=
joint_setup_
y
[
'loglogslope'
]
y
_prefix
=
joint_setup_
y
[
'prefix'
]
joint_setup_
age
=
self
.
setup
[
'joint'
][
'
age
'
]
age
_scale
=
joint_setup_
age
[
'scale'
]
age
_cutoff
=
joint_setup_
age
[
'cutoff'
]
age
_loglogslope
=
joint_setup_
age
[
'loglogslope'
]
age
_prefix
=
joint_setup_
age
[
'prefix'
]
joint_setup_
x
=
self
.
setup
[
'joint'
][
'
x
'
]
x
_scale
=
joint_setup_
x
[
'scale'
]
x
_cutoff
=
joint_setup_
x
[
'cutoff'
]
x
_loglogslope
=
joint_setup_
x
[
'loglogslope'
]
x
_prefix
=
joint_setup_
x
[
'prefix'
]
correlated_field_maker
=
ift
.
CorrelatedFieldMaker
(
joint_prefix
)
correlated_field_maker
.
set_amplitude_total_offset
(
offset_mean
,
offset_std
)
correlated_field_maker
.
add_fluctuations_matern
(
sp1
,
age
_scale
,
age
_cutoff
,
age
_loglogslope
,
age
_prefix
)
correlated_field_maker
.
add_fluctuations_matern
(
sp2
,
ll
_scale
,
ll
_cutoff
,
ll
_loglogslope
,
ll
_prefix
)
correlated_field_maker
.
add_fluctuations_matern
(
sp1
,
x
_scale
,
x
_cutoff
,
x
_loglogslope
,
x
_prefix
)
correlated_field_maker
.
add_fluctuations_matern
(
sp2
,
y
_scale
,
y
_cutoff
,
y
_loglogslope
,
y
_prefix
)
lambda_full
=
correlated_field_maker
.
finalize
()
# For the joint model unmasked regions
tgt
=
ift
.
RGSpace
((
npix_
age
,
npix_
ll
),
tgt
=
ift
.
RGSpace
((
npix_
x
,
npix_
y
),
distances
=
(
lambda_full
.
target
[
0
].
distances
[
0
],
lambda_full
.
target
[
1
].
distances
[
0
]))
GMO
=
GeomMaskOperator
(
lambda_full
.
target
,
tgt
)
...
...
@@ -147,55 +147,55 @@ class MaternCausalModel:
return
lambda_joint
,
lambda_full
def
build_independent_components
(
self
,
lambda_
ag
_full
,
lambda_
ll
_full
,
amplitudes
):
def
build_independent_components
(
self
,
lambda_
x
_full
,
lambda_
y
_full
,
amplitudes
):
# Split the center
#
Age
_dist
=
lambda_
ag
_full
.
target
[
0
].
distances
tgt_
age
=
ift
.
RGSpace
(
self
.
dataset
.
npix_
age
,
distances
=
_dist
)
GMO_
age
=
GeomMaskOperator
(
lambda_
ag
_full
.
target
,
tgt_
age
)
lambda_
age
=
GMO_
age
(
lambda_
ag
_full
.
clip
(
-
30
,
30
))
#
X
_dist
=
lambda_
x
_full
.
target
[
0
].
distances
tgt_
x
=
ift
.
RGSpace
(
self
.
dataset
.
npix_
x
,
distances
=
_dist
)
GMO_
x
=
GeomMaskOperator
(
lambda_
x
_full
.
target
,
tgt_
x
)
lambda_
x
=
GMO_
x
(
lambda_
x
_full
.
clip
(
-
30
,
30
))
# Viral load
_dist
=
lambda_
ll
_full
.
target
[
0
].
distances
tgt_
ll
=
ift
.
RGSpace
(
self
.
dataset
.
npix_
ll
,
distances
=
_dist
)
GMO_
ll
=
GeomMaskOperator
(
lambda_
ll
_full
.
target
,
tgt_
ll
)
lambda_
ll
=
GMO_
ll
(
lambda_
ll
_full
.
clip
(
-
30
,
30
))
_dist
=
lambda_
y
_full
.
target
[
0
].
distances
tgt_
y
=
ift
.
RGSpace
(
self
.
dataset
.
npix_
y
,
distances
=
_dist
)
GMO_
y
=
GeomMaskOperator
(
lambda_
y
_full
.
target
,
tgt_
y
)
lambda_
y
=
GMO_
y
(
lambda_
y
_full
.
clip
(
-
30
,
30
))
return
lambda_
age
,
lambda_
ll
,
lambda_
ag
_full
,
lambda_
ll
_full
,
amplitudes
return
lambda_
x
,
lambda_
y
,
lambda_
x
_full
,
lambda_
y
_full
,
amplitudes
def
initialize_independent_components
(
self
):
_
,
sp1
,
sp2
=
self
.
dataset
.
zero_pad
()
# Set up signal model
#
Age
Parameters
age
_dictionary
=
self
.
setup
[
'indep'
][
'
age
'
]
age
_offset_mean
=
age
_dictionary
[
'offset_dict'
][
'offset_mean'
]
age
_offset_std
=
age
_dictionary
[
'offset_dict'
][
'offset_std'
]
#
X
Parameters
x
_dictionary
=
self
.
setup
[
'indep'
][
'
x
'
]
x
_offset_mean
=
x
_dictionary
[
'offset_dict'
][
'offset_mean'
]
x
_offset_std
=
x
_dictionary
[
'offset_dict'
][
'offset_std'
]
# Log Load Parameters
ll
_dictionary
=
self
.
setup
[
'indep'
][
'log_load'
]
ll
_offset_mean
=
ll
_dictionary
[
'offset_dict'
][
'offset_mean'
]
ll
_offset_std
=
ll
_dictionary
[
'offset_dict'
][
'offset_std'
]
indep_
ll
_prefix
=
ll
_dictionary
[
'offset_dict'
][
'prefix'
]
# Create the
age
axis with the density estimator
signal_response
,
ops
=
density_estimator
(
sp1
,
cf_fluctuations
=
age
_dictionary
[
'params'
],
cf_azm_uniform
=
age
_offset_std
,
azm_offset_mean
=
age
_offset_mean
,
pad
=
0
)
lambda_
ag
_full
=
ops
[
"correlated_field"
]
age
_amplitude
=
ops
[
"amplitude"
]
y
_dictionary
=
self
.
setup
[
'indep'
][
'log_load'
]
y
_offset_mean
=
y
_dictionary
[
'offset_dict'
][
'offset_mean'
]
y
_offset_std
=
y
_dictionary
[
'offset_dict'
][
'offset_std'
]
indep_
y
_prefix
=
y
_dictionary
[
'offset_dict'
][
'prefix'
]
# Create the
x
axis with the density estimator
signal_response
,
ops
=
density_estimator
(
sp1
,
cf_fluctuations
=
x
_dictionary
[
'params'
],
cf_azm_uniform
=
x
_offset_std
,
azm_offset_mean
=
x
_offset_mean
,
pad
=
0
)
lambda_
x
_full
=
ops
[
"correlated_field"
]
x
_amplitude
=
ops
[
"amplitude"
]
zero_mode
=
ops
[
"amplitude_total_offset"
]
# response = ops["exposure"]
# Create the viral load axis with the Matérn-kernel correlated field
correlated_field_maker
=
ift
.
CorrelatedFieldMaker
(
indep_
ll
_prefix
)
correlated_field_maker
.
set_amplitude_total_offset
(
ll
_offset_mean
,
ll
_offset_std
)
correlated_field_maker
.
add_fluctuations_matern
(
sp2
,
**
ll
_dictionary
[
'params'
])
lambda_
ll
_full
=
correlated_field_maker
.
finalize
()
ll
_amplitude
=
correlated_field_maker
.
amplitude
correlated_field_maker
=
ift
.
CorrelatedFieldMaker
(
indep_
y
_prefix
)
correlated_field_maker
.
set_amplitude_total_offset
(
y
_offset_mean
,
y
_offset_std
)
correlated_field_maker
.
add_fluctuations_matern
(
sp2
,
**
y
_dictionary
[
'params'
])
lambda_
y
_full
=
correlated_field_maker
.
finalize
()
y
_amplitude
=
correlated_field_maker
.
amplitude
amplitudes
=
[
age
_amplitude
,
ll
_amplitude
]
amplitudes
=
[
x
_amplitude
,
y
_amplitude
]
return
self
.
build_independent_components
(
lambda_
ag
_full
,
lambda_
ll
_full
,
amplitudes
)
return
self
.
build_independent_components
(
lambda_
x
_full
,
lambda_
y
_full
,
amplitudes
)
def
plot_prior_samples
(
self
,
n_samples
):
plot
=
ift
.
Plot
()
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment