diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index 8321e758f72c97cc2afec7aea97cdb7ace4ffe9f..d355acdbcff2c4ef90a810b241ed662857a6f987 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -10,7 +10,7 @@ Functions and general utilities for the deepof tensorflow models. See documentat
 
 from itertools import combinations
 from typing import Any, Tuple
-from sklearn.neighbors import NearestNeighbors
+from sklearn.metrics import pairwise_distances
 from tensorflow.keras import backend as K
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
@@ -205,17 +205,17 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
 class knn_cluster_purity(tf.keras.callbacks.Callback):
     """
 
-    Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space
+    Cluster entropy callback. Computes assignment local entropy over a neighborhood of radius r in the latent space
 
     """
 
     def __init__(
-        self, variational=True, validation_data=None, k=100, samples=10000, log_dir="."
+        self, variational=True, validation_data=None, r=100, samples=10000, log_dir="."
     ):
         super().__init__()
         self.variational = variational
         self.validation_data = validation_data
-        self.k = k
+        self.r = r
         self.samples = samples
         self.log_dir = log_dir
 
@@ -253,21 +253,22 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
             )
             hard_groups = groups.argmax(axis=1)
 
-            # Fit KNN model
-            knn = NearestNeighbors().fit(encoding)
+            # compute pairwise distances on latent space
+            pdist = pairwise_distances(encoding)
 
-            # Iterate over samples and compute purity over k neighbours
+            # Iterate over samples and compute purity across neighbourhood
             self.samples = np.min([self.samples, encoding.shape[0]])
             random_idxs = np.random.choice(
                 range(encoding.shape[0]), self.samples, replace=False
             )
             purity_vector = np.zeros(self.samples)
+
             for i, sample in enumerate(random_idxs):
-                indexes = knn.kneighbors(
-                    encoding[sample][np.newaxis, :], self.k, return_distance=False
-                )
+
+                neighborhood = pdist[sample] < self.r
+
                 purity_vector[i] = (
-                    np.sum(hard_groups[indexes] == hard_groups[sample])
+                    np.sum(hard_groups[neighborhood] == hard_groups[sample])
                     / self.k
                     * np.max(groups[sample])
                 )
@@ -275,7 +276,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
             writer = tf.summary.create_file_writer(self.log_dir)
             with writer.as_default():
                 tf.summary.scalar(
-                    "knn_cluster_purity",
+                    "neighborhood_cluster_purity",
                     data=purity_vector.mean(),
                     step=epoch,
                 )