From 81c668cd6718ad9423673d8620ae6e83b498f53b Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Wed, 19 May 2021 16:48:34 +0200
Subject: [PATCH] Replaced for loop with vectorised mapping on ClusterOverlap
 regularization layer

---
 deepof/model_utils.py     |  5 ++++-
 deepof_experiments.smk    | 13 +++++++++----
 tests/test_train_utils.py |  4 +++-
 3 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index 5f957fdf..783d0857 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -602,7 +602,10 @@ class ClusterOverlap(Layer):
         random_idxs = np.random.choice(random_idxs, self.samples)
 
         get_local_neighbourhood_entropy = partial(
-            get_neighbourhood_entropy, tensor=encodings, clusters=hard_groups, k=self.k
+            get_neighbourhood_entropy,
+            tensor=encodings,
+            clusters=hard_groups,
+            k=self.k,
         )
 
         purity_vector = tf.map_fn(
diff --git a/deepof_experiments.smk b/deepof_experiments.smk
index f34458aa..d2e82454 100644
--- a/deepof_experiments.smk
+++ b/deepof_experiments.smk
@@ -18,10 +18,11 @@ outpath = "/psycl/g/mpsstatgen/lucas/DLC/DLC_autoencoders/DeepOF/deepof/logs/"
 warmup_epochs = [15]
 warmup_mode = ["sigmoid"]
 losses = ["ELBO"]  # , "MMD", "ELBO+MMD"]
+overlap_loss = [0.1, 0.2, 0.5, 0.75, 1.]
 encodings = [6]  # [2, 4, 6, 8, 10, 12, 14, 16]
 cluster_numbers = [15]  # [1, 5, 10, 15, 20, 25]
 latent_reg = ["variance"]  # ["none", "categorical", "variance", "categorical+variance"]
-entropy_knn = [100]
+entropy_knn = [10]
 next_sequence_pred_weights = [0.15]
 phenotype_pred_weights = [0.0]
 rule_based_pred_weights = [0.0]
@@ -51,10 +52,11 @@ rule deepof_experiments:
             outpath + "train_models/trained_weights/"
             "GMVAE_input_type={input_type}_"
             "window_size={window_size}_"
-            "NextSeqPred={nspredweight}_"
-            "PhenoPred={phenpredweight}_"
-            "RuleBasedPred={rulesweight}_"
+            "NSPred={nspredweight}_"
+            "PPred={phenpredweight}_"
+            "RBPred={rulesweight}_"
             "loss={loss}_"
+            "overlap_loss={overlap_loss}_"
             "loss_warmup={warmup}_"
             "warmup_mode={warmup_mode}_"
             "encoding={encs}_"
@@ -66,6 +68,7 @@ rule deepof_experiments:
             input_type=input_types,
             window_size=window_lengths,
             loss=losses,
+            overlap_loss=overlap_loss,
             warmup=warmup_epochs,
             warmup_mode=warmup_mode,
             encs=encodings,
@@ -134,6 +137,7 @@ rule train_models:
         "PhenoPred={phenpredweight}_"
         "RuleBasedPred={rulesweight}_"
         "loss={loss}_"
+        "overlap_loss={overlap_loss}_"                           
         "loss_warmup={warmup}_"
         "warmup_mode={warmup_mode}_"
         "encoding={encs}_"
@@ -153,6 +157,7 @@ rule train_models:
         "--rule-based-prediction {wildcards.rulesweight} "
         "--latent-reg {wildcards.latreg} "
         "--loss {wildcards.loss} "
+        "--overlap_loss {wildcards.overlap_loss} "
         "--kl-annealing-mode {wildcards.warmup_mode} "
         "--kl-warmup {wildcards.warmup} "
         "--mmd-annealing-mode {wildcards.warmup_mode} "
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index 07f21b26..8fb9450e 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -96,8 +96,10 @@ def test_autoencoder_fitting(
     phenotype_prediction,
     rule_based_prediction,
 ):
+
     X_train = np.random.uniform(-1, 1, [20, 5, 6])
     y_train = np.round(np.random.uniform(0, 1, [20, 1]))
+
     if rule_based_prediction:
         y_train = np.concatenate(
             [y_train, np.round(np.random.uniform(0, 1, [20, 6]), 1)], axis=1
@@ -117,7 +119,7 @@ def test_autoencoder_fitting(
 
     prun.deep_unsupervised_embedding(
         preprocessed_data,
-        batch_size=100,
+        batch_size=10,
         encoding_size=2,
         epochs=1,
         kl_warmup=1,
-- 
GitLab