diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index b5763c28bdf4c90876f9ddfbf394435d477ef067..37356da85d05ccbe0496958e7d8e4e4384d3986e 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor):
 @tf.function
 def get_k_nearest_neighbors(tensor, k, index):
     """Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
-    query = tensor[index]
+    query = tf.slice(tensor, index, 1)
     distances = tf.norm(tensor - query, axis=1)
     max_distance = tf.sort(distances)[k]
     neighbourhood_mask = distances < max_distance
@@ -583,7 +583,7 @@ class ClusterOverlap(Layer):
         config.update({"samples": self.samples})
         return config
 
-    #@tf.function
+    @tf.function
     def call(self, inputs, **kwargs):
         """Updates Layer's call method"""
 
@@ -594,8 +594,12 @@ class ClusterOverlap(Layer):
 
         # Iterate over samples and compute purity across neighbourhood
         self.samples = tf.reduce_min([self.samples, tf.shape(encodings)[0]])
-        random_idxs = range(encodings.shape[0])
-        random_idxs = np.random.choice(random_idxs, self.samples)
+        random_idxs = tf.range(tf.shape(encodings)[0])
+        random_idxs = tf.random.categorical(
+            tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0),
+            self.samples,
+            dtype=tf.dtypes.int32,
+        )
 
         @tf.function
         def get_local_neighbourhood_entropy(index):
@@ -604,7 +608,7 @@ class ClusterOverlap(Layer):
             )
 
         purity_vector = tf.map_fn(
-            get_local_neighbourhood_entropy, tf.constant(random_idxs), dtype=tf.dtypes.float32
+            get_local_neighbourhood_entropy, random_idxs, dtype=tf.dtypes.float32
         )
 
         ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###