From 4cf9fb3f64a29b843cd8b7741d640d080db1256e Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Thu, 27 May 2021 17:44:58 +0200
Subject: [PATCH] Replaced for loop with vectorised mapping on ClusterOverlap
 regularization layer

---
 deepof/model_utils.py  | 31 ++++++++++++++++++-------------
 deepof_experiments.smk |  2 +-
 2 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index e753bf43..211dc070 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -619,21 +619,23 @@ class ClusterOverlap(Layer):
             max_groups, tf.constant(random_idxs)
         )
 
-        self.add_metric(
-            tf.cast(
-                tf.shape(
-                    tf.unique(
-                        tf.reshape(
-                            tf.gather(
-                                tf.cast(hard_groups, tf.dtypes.float32),
-                                tf.where(max_groups >= self.min_confidence),
-                            ),
-                            [-1],
+        number_of_clusters = tf.cast(
+            tf.shape(
+                tf.unique(
+                    tf.reshape(
+                        tf.gather(
+                            tf.cast(hard_groups, tf.dtypes.float32),
+                            tf.where(max_groups >= self.min_confidence),
                         ),
-                    )[0],
+                        [-1],
+                    ),
                 )[0],
-                tf.dtypes.float32,
-            ),
+            )[0],
+            tf.dtypes.float32,
+        )
+
+        self.add_metric(
+            number_of_clusters,
             name="number_of_populated_clusters",
         )
 
@@ -648,6 +650,9 @@ class ClusterOverlap(Layer):
         )
 
         if self.loss_weight:
+            # minimize local entropy
             self.add_loss(self.loss_weight * tf.reduce_mean(neighbourhood_entropy))
+            # maximize number of clusters
+            self.add_loss(-self.loss_weight * tf.reduce_mean(number_of_clusters))
 
         return encodings
diff --git a/deepof_experiments.smk b/deepof_experiments.smk
index 4a1490fa..cf1e7b4c 100644
--- a/deepof_experiments.smk
+++ b/deepof_experiments.smk
@@ -28,7 +28,7 @@ phenotype_pred_weights = [0.0]
 rule_based_pred_weights = [0.0]
 window_lengths = [22]  # range(11,56,11)
 input_types = ["coords"]
-run = list(range(1, 11))
+run = list(range(1, 4))
 
 
 rule deepof_experiments:
-- 
GitLab