Commit 6170ba00 authored by lucas_miranda's avatar lucas_miranda
Browse files

Updated cluster entropy regularization - added custom graph-ready knn and...

Updated cluster entropy regularization - added custom graph-ready knn and Shannon's entropy implementations
parent d56babf5
...@@ -15,8 +15,6 @@ import matplotlib.pyplot as plt ...@@ -15,8 +15,6 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import tensorflow_probability as tfp import tensorflow_probability as tfp
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras import backend as K from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer from tensorflow.keras.layers import Layer
@tf.function
def compute_shannon_entropy(tensor):
    """Return the Shannon entropy (natural log) of the values in *tensor*.

    The input is cast to int32 and treated as categorical labels; a
    normalized histogram is built with tf.math.bincount and folded into
    -sum(p * log(p + 1e-5)). The small epsilon keeps log() finite for
    empty bins (p == 0 contributes ~0 to the sum).

    NOTE(review): reconstructed from the new side of this diff; entropy is
    in nats, not bits — confirm that is intended for "Shannon" entropy.
    """
    labels = tf.cast(tensor, tf.dtypes.int32)
    # Normalized bin frequencies; tensor.shape[0] must be static here.
    counts = tf.math.bincount(labels, dtype=tf.dtypes.float32)
    probs = counts / tf.cast(tensor.shape[0], tf.float32)
    return -tf.reduce_sum(probs * tf.math.log(probs + 1e-5))
@tf.function
def get_k_nearest_neighbors(tensor, k, index):
    """Return indices of the k nearest neighbours of ``tensor[index]``.

    Computes Euclidean distances from the row at ``index`` to every row of
    ``tensor``, takes the (k+1)-th smallest distance as a cutoff, and
    returns the indices of all rows strictly closer than that cutoff,
    transposed to a row vector (shape ``(1, n_hits)``).

    NOTE(review): the query point itself (distance 0) passes the strict
    ``<`` cutoff, so it is counted among the returned neighbours — confirm
    that is intended. With ties at the cutoff, fewer than k indices may be
    returned.

    NOTE(review): call sites elsewhere in this commit invoke this as
    ``get_k_nearest_neighbors(encodings, sample, self.k)``, i.e. with
    ``index`` and ``k`` swapped relative to this signature — verify the
    argument order at the callers.
    """
    anchor = tensor[index]
    # Row-wise Euclidean distance from the anchor to every sample.
    dists = tf.norm(tensor - anchor, axis=1)
    # Distance of the (k+1)-th closest sample acts as the neighbourhood radius.
    cutoff = tf.sort(dists)[k]
    within_radius = dists < cutoff
    return tf.transpose(tf.where(within_radius))
class exponential_learning_rate(tf.keras.callbacks.Callback): class exponential_learning_rate(tf.keras.callbacks.Callback):
...@@ -267,9 +278,6 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback): ...@@ -267,9 +278,6 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
hard_groups = groups.argmax(axis=1) hard_groups = groups.argmax(axis=1)
max_groups = groups.max(axis=1) max_groups = groups.max(axis=1)
# compute pairwise distances on latent space
knn = NearestNeighbors().fit(encoding)
# Iterate over samples and compute purity across neighbourhood # Iterate over samples and compute purity across neighbourhood
self.samples = np.min([self.samples, encoding.shape[0]]) self.samples = np.min([self.samples, encoding.shape[0]])
random_idxs = np.random.choice( random_idxs = np.random.choice(
...@@ -278,9 +286,8 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback): ...@@ -278,9 +286,8 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
purity_vector = np.zeros(self.samples) purity_vector = np.zeros(self.samples)
for i, sample in enumerate(random_idxs): for i, sample in enumerate(random_idxs):
neighborhood = knn.kneighbors( neighborhood = get_k_nearest_neighbors(encodings, sample, self.k)
encoding[sample][np.newaxis, :], self.k, return_distance=False neighborhood = tf.reshape(neighborhood, neighborhood.shape[1])
).flatten()
z = hard_groups[neighborhood] z = hard_groups[neighborhood]
...@@ -578,8 +585,6 @@ class ClusterOverlap(Layer): ...@@ -578,8 +585,6 @@ class ClusterOverlap(Layer):
hard_groups = categorical.argmax(axis=1) hard_groups = categorical.argmax(axis=1)
max_groups = categorical.max(axis=1) max_groups = categorical.max(axis=1)
knn = NearestNeighbors().fit(encodings)
# Iterate over samples and compute purity across neighbourhood # Iterate over samples and compute purity across neighbourhood
self.samples = tf.reduce_min([self.samples, encodings.shape[0]]) self.samples = tf.reduce_min([self.samples, encodings.shape[0]])
random_idxs = range(encoding.shape[0]) random_idxs = range(encoding.shape[0])
...@@ -589,9 +594,8 @@ class ClusterOverlap(Layer): ...@@ -589,9 +594,8 @@ class ClusterOverlap(Layer):
purity_vector = tf.zeros(self.samples) purity_vector = tf.zeros(self.samples)
for i, sample in enumerate(random_idxs): for i, sample in enumerate(random_idxs):
neighborhood = knn.kneighbors( neighborhood = get_k_nearest_neighbors(encodings, sample, self.k)
tf.expand_dims(encoding[sample], 1), self.k, return_distance=False neighborhood = tf.reshape(neighborhood, neighborhood.shape[1])
).flatten()
z = hard_groups[neighborhood] z = hard_groups[neighborhood]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment