Commit 1127311d authored by lucas_miranda's avatar lucas_miranda
Browse files

Modified cluster purity computation. Instead of KNN, we now look at...

Modified cluster purity computation. Instead of KNN, we now look at neighborhoods of a predefined radius
parent aa6dfb9f
...@@ -10,7 +10,7 @@ Functions and general utilities for the deepof tensorflow models. See documentat ...@@ -10,7 +10,7 @@ Functions and general utilities for the deepof tensorflow models. See documentat
from itertools import combinations from itertools import combinations
from typing import Any, Tuple from typing import Any, Tuple
from sklearn.neighbors import NearestNeighbors from sklearn.metrics import pairwise_distances
from tensorflow.keras import backend as K from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer from tensorflow.keras.layers import Layer
...@@ -205,17 +205,17 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback): ...@@ -205,17 +205,17 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
class knn_cluster_purity(tf.keras.callbacks.Callback): class knn_cluster_purity(tf.keras.callbacks.Callback):
""" """
Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space Cluster entropy callback. Computes assignment local entropy over a neighborhood of radius r in the latent space
""" """
def __init__( def __init__(
self, variational=True, validation_data=None, k=100, samples=10000, log_dir="." self, variational=True, validation_data=None, r=100, samples=10000, log_dir="."
): ):
super().__init__() super().__init__()
self.variational = variational self.variational = variational
self.validation_data = validation_data self.validation_data = validation_data
self.k = k self.r = r
self.samples = samples self.samples = samples
self.log_dir = log_dir self.log_dir = log_dir
...@@ -253,21 +253,22 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): ...@@ -253,21 +253,22 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
) )
hard_groups = groups.argmax(axis=1) hard_groups = groups.argmax(axis=1)
# Fit KNN model # compute pairwise distances on latent space
knn = NearestNeighbors().fit(encoding) pdist = pairwise_distances(encoding)
# Iterate over samples and compute purity over k neighbours # Iterate over samples and compute purity across neighbourhood
self.samples = np.min([self.samples, encoding.shape[0]]) self.samples = np.min([self.samples, encoding.shape[0]])
random_idxs = np.random.choice( random_idxs = np.random.choice(
range(encoding.shape[0]), self.samples, replace=False range(encoding.shape[0]), self.samples, replace=False
) )
purity_vector = np.zeros(self.samples) purity_vector = np.zeros(self.samples)
for i, sample in enumerate(random_idxs): for i, sample in enumerate(random_idxs):
indexes = knn.kneighbors(
encoding[sample][np.newaxis, :], self.k, return_distance=False neighborhood = pdist[sample] < self.r
)
purity_vector[i] = ( purity_vector[i] = (
np.sum(hard_groups[indexes] == hard_groups[sample]) np.sum(hard_groups[neighborhood] == hard_groups[sample])
/ self.k / self.k
* np.max(groups[sample]) * np.max(groups[sample])
) )
...@@ -275,7 +276,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): ...@@ -275,7 +276,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
writer = tf.summary.create_file_writer(self.log_dir) writer = tf.summary.create_file_writer(self.log_dir)
with writer.as_default(): with writer.as_default():
tf.summary.scalar( tf.summary.scalar(
"knn_cluster_purity", "neighborhood_cluster_purity",
data=purity_vector.mean(), data=purity_vector.mean(),
step=epoch, step=epoch,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment