diff --git a/deepof/data.py b/deepof/data.py
index d942f710b322299e95c938de64b0c80b894591b0..27622f9ee4e08614eb46ab024e9c8a890532909f 100644
--- a/deepof/data.py
+++ b/deepof/data.py
@@ -903,6 +903,7 @@ class coordinates:
         reg_cluster_variance: bool = False,
         entropy_samples: int = 10000,
         entropy_knn: int = 100,
+        input_type: str = False,
         run: int = 0,
     ) -> Tuple:
         """
@@ -967,6 +968,7 @@ class coordinates:
             reg_cluster_variance=reg_cluster_variance,
             entropy_samples=entropy_samples,
             entropy_knn=entropy_knn,
+            input_type=input_type,
             run=run,
         )

diff --git a/deepof/train_model.py b/deepof/train_model.py
index fab73c23b4a52796a95a0ca0e717379db1985081..ad6adbe664b78f214dc52b13929651cacffc9259 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -403,6 +403,7 @@ if not tune:
         reg_cluster_variance=("variance" in latent_reg),
         entropy_samples=entropy_samples,
         entropy_knn=entropy_knn,
+        input_type=input_type,
         run=run,
     )

@@ -413,16 +414,17 @@ else:

     run_ID, tensorboard_callback, entropy, onecycle = get_callbacks(
         X_train=X_train,
-        X_val=(X_val if X_val.shape != (0,) else None),
         batch_size=batch_size,
-        cp=False,
         variational=variational,
-        entropy_samples=entropy_samples,
-        entropy_knn=entropy_knn,
-        next_sequence_prediction=next_sequence_prediction,
         phenotype_prediction=phenotype_prediction,
+        next_sequence_prediction=next_sequence_prediction,
         rule_based_prediction=rule_base_prediction,
         loss=loss,
+        X_val=(X_val if X_val.shape != (0,) else None),
+        input_type=input_type,
+        cp=False,
+        entropy_samples=entropy_samples,
+        entropy_knn=entropy_knn,
         logparam=logparam,
         outpath=output_path,
         run=run,
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index c784233927582d5466977cd70c5ee14a87e22f42..4057fbf8acbdfd1ca54d389686958037031f6afc 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -74,6 +74,7 @@ def get_callbacks(
     rule_based_prediction: float,
     loss: str,
     X_val: np.array = None,
+    input_type: str = False,
     cp: bool = False,
     reg_cat_clusters: bool = False,
     reg_cluster_variance: bool = False,
@@ -86,8 +87,10 @@ def get_callbacks(
     """Generates callbacks for model training, including:
     - run_ID: run name, with coarse parameter details;
     - tensorboard_callback: for real-time visualization;
-    - cp_callback: for checkpoint saving,
-    - onecycle: for learning rate scheduling"""
+    - cp_callback: for checkpoint saving;
+    - onecycle: for learning rate scheduling;
+    - entropy: neighborhood entropy in the latent space;
+    """

     latreg = "none"
     if reg_cat_clusters and not reg_cluster_variance:
@@ -99,6 +102,7 @@ def get_callbacks(

     run_ID = "{}{}{}{}{}{}{}_{}".format(
         ("GMVAE" if variational else "AE"),
+        ("_input_type={}".format(input_type) if input_type else "coords"),
         ("_NextSeqPred={}".format(next_sequence_prediction) if variational else ""),
         ("_PhenoPred={}".format(phenotype_prediction) if variational else ""),
         ("_RuleBasedPred={}".format(rule_based_prediction) if variational else ""),
@@ -293,6 +297,7 @@ def autoencoder_fitting(
     reg_cluster_variance: bool,
     entropy_samples: int,
     entropy_knn: int,
+    input_type: str,
     run: int = 0,
 ):
     """Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
@@ -315,18 +320,19 @@ def autoencoder_fitting(
     # Load callbacks
     run_ID, *cbacks = get_callbacks(
         X_train=X_train,
-        X_val=(X_val if X_val.shape != (0,) else None),
         batch_size=batch_size,
-        cp=save_checkpoints,
         variational=variational,
-        next_sequence_prediction=next_sequence_prediction,
         phenotype_prediction=phenotype_prediction,
+        next_sequence_prediction=next_sequence_prediction,
         rule_based_prediction=rule_based_prediction,
         loss=loss,
-        entropy_samples=entropy_samples,
-        entropy_knn=entropy_knn,
+        input_type=input_type,
+        X_val=(X_val if X_val.shape != (0,) else None),
+        cp=save_checkpoints,
         reg_cat_clusters=reg_cat_clusters,
         reg_cluster_variance=reg_cluster_variance,
+        entropy_samples=entropy_samples,
+        entropy_knn=entropy_knn,
         logparam=logparam,
         outpath=output_path,
         run=run,
diff --git a/deepof_experiments.smk b/deepof_experiments.smk
index 0666e506124335db1e3840ac94edadbf7ce32189..d39d526087334461f76d02d52158f3194ac3460a 100644
--- a/deepof_experiments.smk
+++ b/deepof_experiments.smk
@@ -15,7 +15,7 @@ import os

 outpath = "/psycl/g/mpsstatgen/lucas/DLC/DLC_autoencoders/DeepOF/deepof/logs/"

-losses = ["ELBO"] #, "MMD", "ELBO+MMD"]
+losses = ["ELBO"] # , "MMD", "ELBO+MMD"]
 encodings = [6] # [2, 4, 6, 8, 10, 12, 14, 16]
 cluster_numbers = [25] # [1, 5, 10, 15, 20, 25]
 latent_reg = ["none"] # ["none", "categorical", "variance", "categorical+variance"]
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index 249195114cdecbefd58b153f16d012646d0864d7..abdfd7509e1ca3e98c912936ccfd54051dd613a4 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -60,11 +60,12 @@ def test_get_callbacks(
         X_train=X_train,
         batch_size=batch_size,
         variational=variational,
-        next_sequence_prediction=next_sequence_prediction,
         phenotype_prediction=phenotype_prediction,
+        next_sequence_prediction=next_sequence_prediction,
         rule_based_prediction=rule_based_prediction,
         loss=loss,
         X_val=X_train,
+        input_type=False,
         cp=True,
         reg_cat_clusters=False,
         reg_cluster_variance=False,
@@ -179,11 +180,12 @@ def test_tune_search(
         X_train=X_train,
         batch_size=batch_size,
         variational=(hypermodel == "S2SGMVAE"),
-        next_sequence_prediction=next_sequence_prediction,
         phenotype_prediction=phenotype_prediction,
+        next_sequence_prediction=next_sequence_prediction,
         rule_based_prediction=rule_based_prediction,
         loss=loss,
         X_val=X_train,
+        input_type=False,
         cp=False,
         reg_cat_clusters=True,
         reg_cluster_variance=True,
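For reviewers, the snippet below is a minimal, standalone sketch of the run-naming behaviour this change introduces in `get_callbacks`. It copies only the two format expressions visible in the hunk above; the helper name `run_id_prefix` and the example value `"dists"` are illustrative assumptions, not part of deepof's API, and the real `run_ID` goes on to append the prediction, loss and regularization tags shown in the surrounding context lines.

```python
# Minimal sketch of the run_ID prefix logic added to get_callbacks.
# Illustrative only: run_id_prefix is not a deepof function, and the real
# run_ID appends further tags (predictions, loss, latent regularization, ...).


def run_id_prefix(variational: bool, input_type) -> str:
    # Mirrors the two expressions from the hunk above: the model tag,
    # followed by the input-type tag, or "coords" when a falsy value
    # (e.g. the default False) is passed.
    return "{}{}".format(
        ("GMVAE" if variational else "AE"),
        ("_input_type={}".format(input_type) if input_type else "coords"),
    )


if __name__ == "__main__":
    print(run_id_prefix(variational=True, input_type="dists"))  # GMVAE_input_type=dists
    print(run_id_prefix(variational=True, input_type=False))  # GMVAEcoords
```

As the second example shows, a falsy `input_type` falls back to the literal `"coords"` with no separating underscore, so default runs are named like `GMVAEcoords_...` rather than `GMVAE_input_type=coords_...`.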