diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index 8321e758f72c97cc2afec7aea97cdb7ace4ffe9f..d355acdbcff2c4ef90a810b241ed662857a6f987 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -10,7 +10,7 @@ Functions and general utilities for the deepof tensorflow models. See documentat
 from itertools import combinations
 from typing import Any, Tuple
 
-from sklearn.neighbors import NearestNeighbors
+from sklearn.metrics import pairwise_distances
 from tensorflow.keras import backend as K
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
@@ -205,17 +205,17 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
 
 class knn_cluster_purity(tf.keras.callbacks.Callback):
     """
-    Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space
+    Cluster purity callback. Computes assignment purity over a neighborhood of radius r in the latent space
 
     """
 
     def __init__(
-        self, variational=True, validation_data=None, k=100, samples=10000, log_dir="."
+        self, variational=True, validation_data=None, r=100, samples=10000, log_dir="."
     ):
         super().__init__()
         self.variational = variational
         self.validation_data = validation_data
-        self.k = k
+        self.r = r
         self.samples = samples
         self.log_dir = log_dir
 
@@ -253,21 +253,22 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
         )
         hard_groups = groups.argmax(axis=1)
 
-        # Fit KNN model
-        knn = NearestNeighbors().fit(encoding)
+        # compute pairwise distances on latent space
+        pdist = pairwise_distances(encoding)
 
-        # Iterate over samples and compute purity over k neighbours
+        # Iterate over samples and compute purity across neighbourhood
         self.samples = np.min([self.samples, encoding.shape[0]])
         random_idxs = np.random.choice(
             range(encoding.shape[0]), self.samples, replace=False
         )
         purity_vector = np.zeros(self.samples)
+
         for i, sample in enumerate(random_idxs):
-            indexes = knn.kneighbors(
-                encoding[sample][np.newaxis, :], self.k, return_distance=False
-            )
+
+            neighborhood = pdist[sample] < self.r
+
             purity_vector[i] = (
-                np.sum(hard_groups[indexes] == hard_groups[sample])
+                np.sum(hard_groups[neighborhood] == hard_groups[sample])
-                / self.k
+                / np.sum(neighborhood)
                 * np.max(groups[sample])
             )
 
@@ -275,7 +276,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
         writer = tf.summary.create_file_writer(self.log_dir)
         with writer.as_default():
             tf.summary.scalar(
-                "knn_cluster_purity",
+                "neighborhood_cluster_purity",
                 data=purity_vector.mean(),
                 step=epoch,
             )