Commit e8e91836 authored by lucas_miranda's avatar lucas_miranda

Updated cluster entropy regularization - added custom graph-ready knn and Shannon's entropy implementations
parent c37c8d52
@@ -15,8 +15,6 @@ import matplotlib.pyplot as plt
 import numpy as np
 import tensorflow as tf
 import tensorflow_probability as tfp
-from scipy.stats import entropy
-from sklearn.neighbors import NearestNeighbors
 from tensorflow.keras import backend as K
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
@@ -31,8 +29,21 @@ def compute_shannon_entropy(tensor):
"""Computes Shannon entropy for a given tensor"""
tensor = tf.cast(tensor, tf.dtypes.int32)
bins = tf.math.bincount(tensor, dtype=tf.dtypes.float32) / tf.cast(tensor.shape[0], tf.float32)
return -tf.reduce_sum( bins * tf.math.log(bins + 1e-5))
bins = tf.math.bincount(tensor, dtype=tf.dtypes.float32) / tf.cast(
tensor.shape[0], tf.float32
)
return -tf.reduce_sum(bins * tf.math.log(bins + 1e-5))
+@tf.function
+def get_k_nearest_neighbors(tensor, k, index):
+    """Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
+    query = tensor[index]
+    distances = tf.norm(tensor - query, axis=1)
+    max_distance = tf.sort(distances)[k]
+    neighbourhood_mask = distances < max_distance
+    return tf.transpose(tf.where(neighbourhood_mask))
 class exponential_learning_rate(tf.keras.callbacks.Callback):
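For orientation (not part of the commit), a minimal eager-mode sketch of what the two new graph-ready helpers return, using made-up toy values:

import tensorflow as tf

# Toy latent codes: five 2-D points (illustrative values only)
latent = tf.constant(
    [[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [1.0, 1.0], [1.1, 1.0]], dtype=tf.float32
)

# k-nearest-neighbour indices of point 0, returned with shape (1, k).
# The query point itself is always included, because its zero distance
# lies below the k-th sorted distance used as the cut-off.
neighbors = get_k_nearest_neighbors(latent, 3, 0)
print(neighbors)  # [[0 1 2]], shape (1, 3)

# Shannon entropy of a vector of hard cluster labels: close to 0 for a
# pure neighbourhood, larger when several clusters are mixed together
print(compute_shannon_entropy(tf.constant([0, 0, 0, 1, 1])))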
@@ -267,9 +278,6 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
         hard_groups = groups.argmax(axis=1)
         max_groups = groups.max(axis=1)
-        # compute pairwise distances on latent space
-        knn = NearestNeighbors().fit(encoding)
         # Iterate over samples and compute purity across neighbourhood
         self.samples = np.min([self.samples, encoding.shape[0]])
         random_idxs = np.random.choice(
@@ -278,9 +286,8 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
         purity_vector = np.zeros(self.samples)
         for i, sample in enumerate(random_idxs):
-            neighborhood = knn.kneighbors(
-                encoding[sample][np.newaxis, :], self.k, return_distance=False
-            ).flatten()
+            neighborhood = get_k_nearest_neighbors(encoding, self.k, sample)
+            neighborhood = tf.reshape(neighborhood, [neighborhood.shape[1]])
             z = hard_groups[neighborhood]
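A self-contained sketch of the neighbourhood-purity loop above on toy data (variable names mirror the diff); the final confidence-weighted aggregation is not visible in this hunk and is only an assumption:

import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
encoding = tf.constant(rng.normal(size=(100, 4)), dtype=tf.float32)  # toy latent codes
groups = rng.dirichlet(np.ones(10), size=100)  # toy soft cluster assignments
hard_groups = groups.argmax(axis=1)
max_groups = groups.max(axis=1)

k, n_samples = 15, 20
random_idxs = rng.choice(encoding.shape[0], n_samples)

purity_vector = np.zeros(n_samples)
for i, sample in enumerate(random_idxs):
    neighborhood = get_k_nearest_neighbors(encoding, k, int(sample))
    neighborhood = tf.reshape(neighborhood, [neighborhood.shape[1]])
    # Shannon entropy of the hard cluster labels found around this sample
    z = hard_groups[neighborhood.numpy()]
    purity_vector[i] = compute_shannon_entropy(z).numpy()

# Assumed aggregation: weight each neighbourhood's entropy by the sample's
# cluster-assignment confidence (max_groups)
neighborhood_entropy = np.average(purity_vector, weights=max_groups[random_idxs])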
@@ -578,8 +585,6 @@ class ClusterOverlap(Layer):
         hard_groups = categorical.argmax(axis=1)
         max_groups = categorical.max(axis=1)
-        knn = NearestNeighbors().fit(encodings)
         # Iterate over samples and compute purity across neighbourhood
         self.samples = tf.reduce_min([self.samples, encodings.shape[0]])
         random_idxs = range(encoding.shape[0])
@@ -589,9 +594,8 @@ class ClusterOverlap(Layer):
         purity_vector = tf.zeros(self.samples)
         for i, sample in enumerate(random_idxs):
-            neighborhood = knn.kneighbors(
-                tf.expand_dims(encoding[sample], 1), self.k, return_distance=False
-            ).flatten()
+            neighborhood = get_k_nearest_neighbors(encodings, self.k, sample)
+            neighborhood = tf.reshape(neighborhood, [neighborhood.shape[1]])
             z = hard_groups[neighborhood]
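ClusterOverlap is a Keras Layer, which is why the sklearn/scipy calls had to give way to graph-compatible TensorFlow ops. As a purely illustrative stand-in (not the layer's actual implementation), this is roughly how such a neighbourhood-entropy term can be attached to a model through the standard add_loss API:

import tensorflow as tf

class NeighbourhoodEntropy(tf.keras.layers.Layer):
    """Illustrative stand-in, not ClusterOverlap itself: registers the Shannon
    entropy of cluster labels around a single query point as an auxiliary loss."""

    def __init__(self, k=15, loss_weight=1e-3, **kwargs):
        super().__init__(**kwargs)
        self.k = k
        self.loss_weight = loss_weight

    def call(self, encodings, categorical):
        hard_groups = tf.argmax(categorical, axis=1)
        # Neighbourhood of the first sample only, to keep the sketch short;
        # the real layer iterates over (a subset of) samples
        neighborhood = get_k_nearest_neighbors(encodings, self.k, 0)
        z = tf.gather(hard_groups, tf.reshape(neighborhood, [self.k]))
        # Sign and weighting of the regularizer are illustrative choices
        self.add_loss(self.loss_weight * compute_shannon_entropy(z))
        return encodings

Called eagerly on concrete tensors, the added term shows up in layer.losses; inside a model it is picked up automatically during training.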