From 8017c6d29991279f51098185d1ac143e9947f71e Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Thu, 15 Apr 2021 12:55:34 +0200
Subject: [PATCH] Added extra branch to main autoencoder for rule_based
 prediction

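Thread the new input_type argument from
coordinates.deep_unsupervised_embedding (deepof/data.py) through
deepof/train_model.py into train_utils.autoencoder_fitting and
train_utils.get_callbacks, where it is appended to the generated run_ID
so the run name records which input representation was used. Keyword
arguments in the get_callbacks calls are reordered to match the updated
signature, and the train_utils tests now pass the new parameter
explicitly.

A rough sketch of the resulting run naming (simplified; the real
get_callbacks appends further fields such as loss and latent
regularization, and "dists" is only a hypothetical value here):

    variational = True
    input_type = "dists"  # falsy values fall back to the literal "coords"
    run_id_prefix = "{}{}".format(
        ("GMVAE" if variational else "AE"),
        ("_input_type={}".format(input_type) if input_type else "coords"),
    )
    # run_id_prefix == "GMVAE_input_type=dists"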
---
 deepof/data.py            |  2 ++
 deepof/train_model.py     | 12 +++++++-----
 deepof/train_utils.py     | 20 +++++++++++++-------
 deepof_experiments.smk    |  2 +-
 tests/test_train_utils.py |  6 ++++--
 5 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/deepof/data.py b/deepof/data.py
index d942f710..27622f9e 100644
--- a/deepof/data.py
+++ b/deepof/data.py
@@ -903,6 +903,7 @@ class coordinates:
         reg_cluster_variance: bool = False,
         entropy_samples: int = 10000,
         entropy_knn: int = 100,
+        input_type: str = None,
         run: int = 0,
     ) -> Tuple:
         """
@@ -967,6 +968,7 @@ class coordinates:
             reg_cluster_variance=reg_cluster_variance,
             entropy_samples=entropy_samples,
             entropy_knn=entropy_knn,
+            input_type=input_type,
             run=run,
         )
 
diff --git a/deepof/train_model.py b/deepof/train_model.py
index fab73c23..ad6adbe6 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -403,6 +403,7 @@ if not tune:
         reg_cluster_variance=("variance" in latent_reg),
         entropy_samples=entropy_samples,
         entropy_knn=entropy_knn,
+        input_type=input_type,
         run=run,
     )
 
@@ -413,16 +414,17 @@ else:
 
     run_ID, tensorboard_callback, entropy, onecycle = get_callbacks(
         X_train=X_train,
-        X_val=(X_val if X_val.shape != (0,) else None),
         batch_size=batch_size,
-        cp=False,
         variational=variational,
-        entropy_samples=entropy_samples,
-        entropy_knn=entropy_knn,
-        next_sequence_prediction=next_sequence_prediction,
         phenotype_prediction=phenotype_prediction,
+        next_sequence_prediction=next_sequence_prediction,
         rule_based_prediction=rule_base_prediction,
         loss=loss,
+        X_val=(X_val if X_val.shape != (0,) else None),
+        input_type=input_type,
+        cp=False,
+        entropy_samples=entropy_samples,
+        entropy_knn=entropy_knn,
         logparam=logparam,
         outpath=output_path,
         run=run,
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index c7842339..4057fbf8 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -74,6 +74,7 @@ def get_callbacks(
     rule_based_prediction: float,
     loss: str,
     X_val: np.array = None,
+    input_type: str = None,
     cp: bool = False,
     reg_cat_clusters: bool = False,
     reg_cluster_variance: bool = False,
@@ -86,8 +87,10 @@ def get_callbacks(
     """Generates callbacks for model training, including:
     - run_ID: run name, with coarse parameter details;
     - tensorboard_callback: for real-time visualization;
-    - cp_callback: for checkpoint saving,
-    - onecycle: for learning rate scheduling"""
+    - cp_callback: for checkpoint saving;
+    - onecycle: for learning rate scheduling;
+    - entropy: neighborhood entropy in the latent space;
+    """
 
     latreg = "none"
     if reg_cat_clusters and not reg_cluster_variance:
@@ -99,6 +102,7 @@ def get_callbacks(
 
     run_ID = "{}{}{}{}{}{}{}_{}".format(
         ("GMVAE" if variational else "AE"),
+        ("_input_type={}".format(input_type) if input_type else "coords"),
         ("_NextSeqPred={}".format(next_sequence_prediction) if variational else ""),
         ("_PhenoPred={}".format(phenotype_prediction) if variational else ""),
         ("_RuleBasedPred={}".format(rule_based_prediction) if variational else ""),
@@ -293,6 +297,7 @@ def autoencoder_fitting(
     reg_cluster_variance: bool,
     entropy_samples: int,
     entropy_knn: int,
+    input_type: str,
     run: int = 0,
 ):
     """Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
@@ -315,18 +320,19 @@ def autoencoder_fitting(
     # Load callbacks
     run_ID, *cbacks = get_callbacks(
         X_train=X_train,
-        X_val=(X_val if X_val.shape != (0,) else None),
         batch_size=batch_size,
-        cp=save_checkpoints,
         variational=variational,
-        next_sequence_prediction=next_sequence_prediction,
         phenotype_prediction=phenotype_prediction,
+        next_sequence_prediction=next_sequence_prediction,
         rule_based_prediction=rule_based_prediction,
         loss=loss,
-        entropy_samples=entropy_samples,
-        entropy_knn=entropy_knn,
+        input_type=input_type,
+        X_val=(X_val if X_val.shape != (0,) else None),
+        cp=save_checkpoints,
         reg_cat_clusters=reg_cat_clusters,
         reg_cluster_variance=reg_cluster_variance,
+        entropy_samples=entropy_samples,
+        entropy_knn=entropy_knn,
         logparam=logparam,
         outpath=output_path,
         run=run,
diff --git a/deepof_experiments.smk b/deepof_experiments.smk
index 0666e506..d39d5260 100644
--- a/deepof_experiments.smk
+++ b/deepof_experiments.smk
@@ -15,7 +15,7 @@ import os
 
 outpath = "/psycl/g/mpsstatgen/lucas/DLC/DLC_autoencoders/DeepOF/deepof/logs/"
 
-losses = ["ELBO"] #, "MMD", "ELBO+MMD"]
+losses = ["ELBO"]  # , "MMD", "ELBO+MMD"]
 encodings = [6]  # [2, 4, 6, 8, 10, 12, 14, 16]
 cluster_numbers = [25]  # [1, 5, 10, 15, 20, 25]
 latent_reg = ["none"]  # ["none", "categorical", "variance", "categorical+variance"]
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index 24919511..abdfd750 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -60,11 +60,12 @@ def test_get_callbacks(
         X_train=X_train,
         batch_size=batch_size,
         variational=variational,
-        next_sequence_prediction=next_sequence_prediction,
         phenotype_prediction=phenotype_prediction,
+        next_sequence_prediction=next_sequence_prediction,
         rule_based_prediction=rule_based_prediction,
         loss=loss,
         X_val=X_train,
+        input_type=None,
         cp=True,
         reg_cat_clusters=False,
         reg_cluster_variance=False,
@@ -179,11 +180,12 @@ def test_tune_search(
             X_train=X_train,
             batch_size=batch_size,
             variational=(hypermodel == "S2SGMVAE"),
-            next_sequence_prediction=next_sequence_prediction,
             phenotype_prediction=phenotype_prediction,
+            next_sequence_prediction=next_sequence_prediction,
             rule_based_prediction=rule_based_prediction,
             loss=loss,
             X_val=X_train,
+            input_type=None,
             cp=False,
             reg_cat_clusters=True,
             reg_cluster_variance=True,
-- 
GitLab