From 8017c6d29991279f51098185d1ac143e9947f71e Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Thu, 15 Apr 2021 12:55:34 +0200 Subject: [PATCH] Added extra branch to main autoencoder for rule_based prediction --- deepof/data.py | 2 ++ deepof/train_model.py | 12 +++++++----- deepof/train_utils.py | 20 +++++++++++++------- deepof_experiments.smk | 2 +- tests/test_train_utils.py | 6 ++++-- 5 files changed, 27 insertions(+), 15 deletions(-) diff --git a/deepof/data.py b/deepof/data.py index d942f710..27622f9e 100644 --- a/deepof/data.py +++ b/deepof/data.py @@ -903,6 +903,7 @@ class coordinates: reg_cluster_variance: bool = False, entropy_samples: int = 10000, entropy_knn: int = 100, + input_type: str = False, run: int = 0, ) -> Tuple: """ @@ -967,6 +968,7 @@ class coordinates: reg_cluster_variance=reg_cluster_variance, entropy_samples=entropy_samples, entropy_knn=entropy_knn, + input_type=input_type, run=run, ) diff --git a/deepof/train_model.py b/deepof/train_model.py index fab73c23..ad6adbe6 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -403,6 +403,7 @@ if not tune: reg_cluster_variance=("variance" in latent_reg), entropy_samples=entropy_samples, entropy_knn=entropy_knn, + input_type=input_type, run=run, ) @@ -413,16 +414,17 @@ else: run_ID, tensorboard_callback, entropy, onecycle = get_callbacks( X_train=X_train, - X_val=(X_val if X_val.shape != (0,) else None), batch_size=batch_size, - cp=False, variational=variational, - entropy_samples=entropy_samples, - entropy_knn=entropy_knn, - next_sequence_prediction=next_sequence_prediction, phenotype_prediction=phenotype_prediction, + next_sequence_prediction=next_sequence_prediction, rule_based_prediction=rule_base_prediction, loss=loss, + X_val=(X_val if X_val.shape != (0,) else None), + input_type=input_type, + cp=False, + entropy_samples=entropy_samples, + entropy_knn=entropy_knn, logparam=logparam, outpath=output_path, run=run, diff --git a/deepof/train_utils.py b/deepof/train_utils.py index c7842339..4057fbf8 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -74,6 +74,7 @@ def get_callbacks( rule_based_prediction: float, loss: str, X_val: np.array = None, + input_type: str = False, cp: bool = False, reg_cat_clusters: bool = False, reg_cluster_variance: bool = False, @@ -86,8 +87,10 @@ def get_callbacks( """Generates callbacks for model training, including: - run_ID: run name, with coarse parameter details; - tensorboard_callback: for real-time visualization; - - cp_callback: for checkpoint saving, - - onecycle: for learning rate scheduling""" + - cp_callback: for checkpoint saving; + - onecycle: for learning rate scheduling; + - entropy: neighborhood entropy in the latent space; + """ latreg = "none" if reg_cat_clusters and not reg_cluster_variance: @@ -99,6 +102,7 @@ def get_callbacks( run_ID = "{}{}{}{}{}{}{}_{}".format( ("GMVAE" if variational else "AE"), + ("_input_type={}".format(input_type) if input_type else "coords"), ("_NextSeqPred={}".format(next_sequence_prediction) if variational else ""), ("_PhenoPred={}".format(phenotype_prediction) if variational else ""), ("_RuleBasedPred={}".format(rule_based_prediction) if variational else ""), @@ -293,6 +297,7 @@ def autoencoder_fitting( reg_cluster_variance: bool, entropy_samples: int, entropy_knn: int, + input_type: str, run: int = 0, ): """Implementation function for deepof.data.coordinates.deep_unsupervised_embedding""" @@ -315,18 +320,19 @@ def autoencoder_fitting( # Load callbacks run_ID, *cbacks = get_callbacks( X_train=X_train, - X_val=(X_val if X_val.shape != (0,) else None), batch_size=batch_size, - cp=save_checkpoints, variational=variational, - next_sequence_prediction=next_sequence_prediction, phenotype_prediction=phenotype_prediction, + next_sequence_prediction=next_sequence_prediction, rule_based_prediction=rule_based_prediction, loss=loss, - entropy_samples=entropy_samples, - entropy_knn=entropy_knn, + input_type=input_type, + X_val=(X_val if X_val.shape != (0,) else None), + cp=save_checkpoints, reg_cat_clusters=reg_cat_clusters, reg_cluster_variance=reg_cluster_variance, + entropy_samples=entropy_samples, + entropy_knn=entropy_knn, logparam=logparam, outpath=output_path, run=run, diff --git a/deepof_experiments.smk b/deepof_experiments.smk index 0666e506..d39d5260 100644 --- a/deepof_experiments.smk +++ b/deepof_experiments.smk @@ -15,7 +15,7 @@ import os outpath = "/psycl/g/mpsstatgen/lucas/DLC/DLC_autoencoders/DeepOF/deepof/logs/" -losses = ["ELBO"] #, "MMD", "ELBO+MMD"] +losses = ["ELBO"] # , "MMD", "ELBO+MMD"] encodings = [6] # [2, 4, 6, 8, 10, 12, 14, 16] cluster_numbers = [25] # [1, 5, 10, 15, 20, 25] latent_reg = ["none"] # ["none", "categorical", "variance", "categorical+variance"] diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py index 24919511..abdfd750 100644 --- a/tests/test_train_utils.py +++ b/tests/test_train_utils.py @@ -60,11 +60,12 @@ def test_get_callbacks( X_train=X_train, batch_size=batch_size, variational=variational, - next_sequence_prediction=next_sequence_prediction, phenotype_prediction=phenotype_prediction, + next_sequence_prediction=next_sequence_prediction, rule_based_prediction=rule_based_prediction, loss=loss, X_val=X_train, + input_type=False, cp=True, reg_cat_clusters=False, reg_cluster_variance=False, @@ -179,11 +180,12 @@ def test_tune_search( X_train=X_train, batch_size=batch_size, variational=(hypermodel == "S2SGMVAE"), - next_sequence_prediction=next_sequence_prediction, phenotype_prediction=phenotype_prediction, + next_sequence_prediction=next_sequence_prediction, rule_based_prediction=rule_based_prediction, loss=loss, X_val=X_train, + input_type=False, cp=False, reg_cat_clusters=True, reg_cluster_variance=True, -- GitLab