Commit 5199c2c8 authored by lucas_miranda
Browse files

Added latent regularization control to deepof.data.coordinates.deep_unsupervised_embedding()

parent 86f6b157
Pipeline #93296 passed with stage
in 50 minutes and 56 seconds
......@@ -821,6 +821,8 @@ class coordinates:
save_checkpoints: bool = False,
save_weights: bool = True,
variational: bool = True,
reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False,
) -> Tuple:
"""
Annotates coordinates using an unsupervised autoencoder.
......@@ -879,6 +881,8 @@ class coordinates:
save_checkpoints=save_checkpoints,
save_weights=save_weights,
variational=variational,
reg_cat_clusters=reg_cat_clusters,
reg_cluster_variance=reg_cluster_variance,
)
# returns a list of trained tensorflow models
......
......@@ -108,6 +108,15 @@ parser.add_argument(
default=10,
type=int,
)
# CLI flag selecting the latent-space regularization strategy; the value is
# consumed downstream as substring checks ("categorical" in latent_reg /
# "variance" in latent_reg) to derive the reg_cat_clusters and
# reg_cluster_variance booleans passed to the model.
parser.add_argument(
    "--latent-reg",
    "-lreg",
    help="Sets the strategy to regularize the latent mixture of Gaussians. "
    # NOTE: a trailing space was missing between the concatenated parts below,
    # which rendered as "distribution),variance" in --help output.
    "It has to be one of none, categorical (an elastic net penalty is applied to the categorical distribution), "
    "variance (l2 penalty to the variance of the clusters) or categorical+variance. Defaults to none.",
    default="none",
    # Enforce the closed set of values the help text already promises;
    # matches the latent_reg list used by the Snakemake pipeline.
    choices=["none", "categorical", "variance", "categorical+variance"],
    type=str,
)
parser.add_argument(
"--loss",
"-l",
......@@ -225,6 +234,7 @@ hparams = args.hyperparameters
# Unpack parsed CLI arguments into module-level names used by the
# training script below (args comes from parser.parse_args() above).
input_type = args.input_type
k = args.components
kl_wu = args.kl_warmup
latent_reg = args.latent_reg  # one of: none / categorical / variance / categorical+variance
loss = args.loss
mmd_wu = args.mmd_warmup
mc_kl = args.montecarlo_kl
......@@ -365,6 +375,8 @@ if not tune:
save_checkpoints=False,
save_weights=True,
variational=variational,
reg_cat_clusters=("categorical" in latent_reg),
reg_cluster_variance=("variance" in latent_reg),
)
else:
......
......@@ -261,6 +261,8 @@ def autoencoder_fitting(
save_checkpoints: bool,
save_weights: bool,
variational: bool,
reg_cat_clusters: bool,
reg_cluster_variance: bool,
):
"""Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
......@@ -336,6 +338,8 @@ def autoencoder_fitting(
overlap_loss=False,
phenotype_prediction=phenotype_class,
predictor=predictor,
reg_cat_clusters=reg_cat_clusters,
reg_cluster_variance=reg_cluster_variance,
).build(
X_train.shape
)
......
......@@ -15,25 +15,26 @@ import os
# Grid of hyperparameter values swept by the Snakemake rules below.
outpath = "/psycl/g/mpsstatgen/lucas/DLC/DLC_autoencoders/DeepOF/deepof/logs/"
losses = ["ELBO"] # , "MMD", "ELBO+MMD"]
# NOTE(review): the next two lines appear to be stale pre-diff values (the
# scrape dropped the -/+ diff markers); they are immediately overwritten by
# the reassignments two lines below, which take effect.
encodings = [4, 6, 8] # [2, 4, 6, 8, 10, 12, 14, 16]
cluster_numbers = [10, 15, 20] # [1, 5, 10, 15, 20]
encodings = [4, 8] # [2, 4, 6, 8, 10, 12, 14, 16]
cluster_numbers = [25] # [1, 5, 10, 15, 20, 25]
# Regularization strategies forwarded to train_model via --latent-reg.
latent_reg = ["none", "categorical", "variance", "categorical+variance"]
pheno_weights = [0.01, 0.1, 0.25, 0.5, 1.0, 2.0, 4.0, 10.0, 100.0]
rule deepof_experiments:
input:
expand( "/psycl/g/mpsstatgen/lucas/DLC/DLC_autoencoders/DeepOF/deepof/logs/hyperparameter_tuning/trained_weights/"
"GMVAE_loss={loss}_encoding=2_run_1_final_weights.h5",
loss=losses,
# expand( "/psycl/g/mpsstatgen/lucas/DLC/DLC_autoencoders/DeepOF/deepof/logs/hyperparameter_tuning/trained_weights/"
# "GMVAE_loss={loss}_encoding=2_run_1_final_weights.h5",
# loss=losses,
# )
expand(
"/psycl/g/mpsstatgen/lucas/DLC/DLC_autoencoders/DeepOF/deepof/logs/dimension_and_loss_experiments/trained_weights/"
"GMVAE_loss={loss}_encoding={encs}_k={k}_latreg={latreg}_final_weights.h5",
loss=losses,
encs=encodings,
k=cluster_numbers,
latreg=latent_reg,
)
# expand(
# "/psycl/g/mpsstatgen/lucas/DLC/DLC_autoencoders/DeepOF/deepof/logs/dimension_and_loss_experiments/trained_weights/"
# "GMVAE_loss={loss}_encoding={encs}_k={k}_run_1_final_weights.h5",
# loss=losses,
# encs=encodings,
# k=cluster_numbers,
# ),
# expand(
# "/psycl/g/mpsstatgen/lucas/DLC/DLC_autoencoders/DeepOF/deepof/logs/pheno_classification_experiments/trained_weights/"
# "GMVAE_loss={loss}_encoding={encs}_k={k}_pheno={phenos}_run_1_final_weights.h5",
......@@ -44,63 +45,64 @@ rule deepof_experiments:
# ),
# Runs a hyperband search (via deepof.train_model --hyperparameter-tuning)
# on the single-top-view dataset, with a fixed encoding size of 2 and 15
# mixture components; one job per loss in `losses`.
rule coarse_hyperparameter_tuning:
    input:
        data_path="/psycl/g/mpsstatgen/lucas/DLC/DLC_models/deepof_single_topview/",
    output:
        # Final weights written by train_model; the loss wildcard is the only
        # swept dimension in this rule.
        trained_models=os.path.join(
            outpath,
            "hyperparameter_tuning/trained_weights/GMVAE_loss={loss}_encoding=2_run_1_final_weights.h5",
        ),
    shell:
        "pipenv run python -m deepof.train_model "
        "--train-path {input.data_path} "
        "--val-num 25 "
        "--components 15 "
        "--input-type coords "
        "--predictor 0 "
        "--phenotype-classifier 0 "
        "--variational True "
        "--loss {wildcards.loss} "
        "--kl-warmup 20 "
        "--mmd-warmup 0 "
        "--encoding-size 2 "
        "--batch-size 256 "
        "--window-size 24 "
        "--window-step 12 "
        "--output-path {outpath}coarse_hyperparameter_tuning "
        "--hyperparameter-tuning hyperband "
        "--hpt-trials 3"
# rule explore_encoding_dimension_and_loss_function:
# rule coarse_hyperparameter_tuning:
# input:
# data_path=ancient("/psycl/g/mpsstatgen/lucas/DLC/DLC_models/deepof_single_topview/"),
# data_path="/psycl/g/mpsstatgen/lucas/DLC/DLC_models/deepof_single_topview/",
# output:
# trained_models=os.path.join(
# outpath,
# "dimension_and_loss_experiments/trained_weights/GMVAE_loss={loss}_encoding={encs}_k={k}_run_1_final_weights.h5",
# "hyperparameter_tuning/trained_weights/GMVAE_loss={loss}_encoding=2_run_1_final_weights.h5",
# ),
# shell:
# "pipenv run python -m deepof.train_model "
# "--train-path {input.data_path} "
# "--val-num 5 "
# "--components {wildcards.k} "
# "--val-num 25 "
# "--components 15 "
# "--input-type coords "
# "--predictor 0 "
# "--phenotype-classifier 0 "
# "--variational True "
# "--loss {wildcards.loss} "
# "--kl-warmup 20 "
# "--mmd-warmup 20 "
# "--montecarlo-kl 10 "
# "--encoding-size {wildcards.encs} "
# "--mmd-warmup 0 "
# "--encoding-size 2 "
# "--batch-size 256 "
# "--window-size 24 "
# "--window-step 6 "
# "--exclude-bodyparts Tail_base,Tail_1,Tail_2,Tail_tip "
# "--stability-check 3 "
# "--output-path {outpath}dimension_and_loss_experiments"
# "--window-step 12 "
# "--output-path {outpath}coarse_hyperparameter_tuning "
# "--hyperparameter-tuning hyperband "
# "--hpt-trials 3"
# Trains one GMVAE per (loss, encoding size, cluster number, latent
# regularization) combination, passing the regularization strategy through
# the new --latent-reg flag of deepof.train_model.
# NOTE(review): the deepof_experiments target above expands
# "dimension_and_loss_experiments/trained_weights/..._latreg={latreg}_final_weights.h5",
# but this rule declares its output under "latent_regularization_experiments/"
# — one of the two directories looks wrong; confirm which path train_model
# actually writes to.
rule latent_regularization_experiments:
    input:
        # ancient(): ignore the input's mtime so reruns aren't triggered by
        # touching the dataset directory.
        data_path=ancient("/psycl/g/mpsstatgen/lucas/DLC/DLC_models/deepof_single_topview/"),
    output:
        trained_models=os.path.join(
            outpath,
            "latent_regularization_experiments/trained_weights/GMVAE_loss={loss}_encoding={encs}_k={k}_latreg={latreg}_final_weights.h5",
        ),
    shell:
        "pipenv run python -m deepof.train_model "
        "--train-path {input.data_path} "
        "--val-num 5 "
        "--components {wildcards.k} "
        "--input-type coords "
        "--predictor 0 "
        "--phenotype-classifier 0 "
        "--variational True "
        "--latent-reg {wildcards.latreg} "
        "--loss {wildcards.loss} "
        "--kl-warmup 20 "
        "--mmd-warmup 20 "
        "--montecarlo-kl 10 "
        "--encoding-size {wildcards.encs} "
        "--batch-size 256 "
        "--window-size 24 "
        "--window-step 6 "
        "--exclude-bodyparts Tail_base,Tail_1,Tail_2,Tail_tip "
        "--stability-check 3 "
        "--output-path {outpath}latent_regularization_experiments"
#
#
# rule explore_phenotype_classification:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment