From b8089d68837587c6262b5f7f0f919cad07dca78b Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Wed, 19 May 2021 01:49:03 +0200
Subject: [PATCH] Replaced for loop with vectorised mapping on ClusterOverlap
 regularization layer

---
 deepof/model_utils.py | 32 +++++++++++++++++++-------------
 deepof/models.py      | 13 ++++++-------
 2 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index 064d4cb2..585623e9 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -612,19 +612,25 @@ class ClusterOverlap(Layer):
         )
 
         ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
-        neighbourhood_entropy = purity_vector  # * max_groups[random_idxs]
-
-        # self.add_metric(
-        #     len(set(hard_groups[max_groups >= self.min_confidence])),
-        #     aggregation="mean",
-        #     name="number_of_populated_clusters",
-        # )
-        #
-        # self.add_metric(
-        #     max_groups,
-        #     aggregation="mean",
-        #     name="average_confidence_in_selected_cluster",
-        # )
+        neighbourhood_entropy = purity_vector * tf.gather(
+            max_groups, tf.constant(random_idxs)
+        )
+
+        self.add_metric(
+            tf.shape(
+                tf.unique(
+                    tf.squeeze(tf.gather(hard_groups, tf.where(max_groups >= self.min_confidence)))
+                )
+            ),
+            aggregation="mean",
+            name="number_of_populated_clusters",
+        )
+
+        self.add_metric(
+            max_groups,
+            aggregation="mean",
+            name="average_confidence_in_selected_cluster",
+        )
 
         self.add_metric(
             neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy"
diff --git a/deepof/models.py b/deepof/models.py
index c8a61fc3..10a18709 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -473,13 +473,12 @@ class GMVAE:
         # Dummy layer with no parameters, to retrieve the previous tensor
         z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z)
 
-        if self.overlap_loss:
-            z = deepof.model_utils.ClusterOverlap(
-                self.batch_size,
-                self.ENCODING,
-                self.number_of_components,
-                loss_weight=self.overlap_loss,
-            )([z, z_cat])
+        z = deepof.model_utils.ClusterOverlap(
+            self.batch_size,
+            self.ENCODING,
+            self.number_of_components,
+            loss_weight=self.overlap_loss,
+        )([z, z_cat])
 
         # Define and instantiate generator
         g = Input(shape=self.ENCODING)
-- 
GitLab