From 04f4192a081fdc2ef7f9de37994c63a6cae602c3 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Tue, 18 May 2021 19:26:03 +0200
Subject: [PATCH] Replaced for loop with vectorised mapping on ClusterOverlap
 regularization layer

---
 deepof/model_utils.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index b5763c28..37356da8 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor):
 @tf.function
 def get_k_nearest_neighbors(tensor, k, index):
     """Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
-    query = tensor[index]
+    query = tf.slice(tensor, index, 1)
     distances = tf.norm(tensor - query, axis=1)
     max_distance = tf.sort(distances)[k]
     neighbourhood_mask = distances < max_distance
@@ -583,7 +583,7 @@ class ClusterOverlap(Layer):
         config.update({"samples": self.samples})
         return config
 
-    #@tf.function
+    @tf.function
     def call(self, inputs, **kwargs):
         """Updates Layer's call method"""
 
@@ -594,8 +594,12 @@ class ClusterOverlap(Layer):
 
         # Iterate over samples and compute purity across neighbourhood
         self.samples = tf.reduce_min([self.samples, tf.shape(encodings)[0]])
-        random_idxs = range(encodings.shape[0])
-        random_idxs = np.random.choice(random_idxs, self.samples)
+        random_idxs = tf.range(tf.shape(encodings)[0])
+        random_idxs = tf.random.categorical(
+            tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0),
+            self.samples,
+            dtype=tf.dtypes.int32,
+        )
 
         @tf.function
         def get_local_neighbourhood_entropy(index):
@@ -604,7 +608,7 @@ class ClusterOverlap(Layer):
             )
 
         purity_vector = tf.map_fn(
-            get_local_neighbourhood_entropy, tf.constant(random_idxs), dtype=tf.dtypes.float32
+            get_local_neighbourhood_entropy, random_idxs, dtype=tf.dtypes.float32
         )
 
         ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
-- 
GitLab