Commit 6170ba00 by lucas_miranda

### Updated cluster entropy regularization - added custom graph-ready knn and Shannon's entropy implementations

`Updated cluster entropy regularization - added custom graph-ready knn and Shannon's entropy implementations`
parent d56babf5
 ... ... @@ -15,8 +15,6 @@ import matplotlib.pyplot as plt import numpy as np import tensorflow as tf import tensorflow_probability as tfp from scipy.stats import entropy from sklearn.neighbors import NearestNeighbors from tensorflow.keras import backend as K from tensorflow.keras.constraints import Constraint from tensorflow.keras.layers import Layer ... ... @@ -31,8 +29,21 @@ def compute_shannon_entropy(tensor): """Computes Shannon entropy for a given tensor""" tensor = tf.cast(tensor, tf.dtypes.int32) bins = tf.math.bincount(tensor, dtype=tf.dtypes.float32) / tf.cast(tensor.shape[0], tf.float32) return -tf.reduce_sum( bins * tf.math.log(bins + 1e-5)) bins = tf.math.bincount(tensor, dtype=tf.dtypes.float32) / tf.cast( tensor.shape[0], tf.float32 ) return -tf.reduce_sum(bins * tf.math.log(bins + 1e-5)) @tf.function def get_k_nearest_neighbors(tensor, k, index): """Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index""" query = tensor[index] distances = tf.norm(tensor - query, axis=1) max_distance = tf.sort(distances)[k] neighbourhood_mask = distances < max_distance return tf.transpose(tf.where(neighbourhood_mask)) class exponential_learning_rate(tf.keras.callbacks.Callback): ... ... @@ -267,9 +278,6 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback): hard_groups = groups.argmax(axis=1) max_groups = groups.max(axis=1) # compute pairwise distances on latent space knn = NearestNeighbors().fit(encoding) # Iterate over samples and compute purity across neighbourhood self.samples = np.min([self.samples, encoding.shape[0]]) random_idxs = np.random.choice( ... ... 
@@ -278,9 +286,8 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback): purity_vector = np.zeros(self.samples) for i, sample in enumerate(random_idxs): neighborhood = knn.kneighbors( encoding[sample][np.newaxis, :], self.k, return_distance=False ).flatten() neighborhood = get_k_nearest_neighbors(encodings, sample, self.k) neighborhood = tf.reshape(neighborhood, neighborhood.shape[1]) z = hard_groups[neighborhood] ... ... @@ -578,8 +585,6 @@ class ClusterOverlap(Layer): hard_groups = categorical.argmax(axis=1) max_groups = categorical.max(axis=1) knn = NearestNeighbors().fit(encodings) # Iterate over samples and compute purity across neighbourhood self.samples = tf.reduce_min([self.samples, encodings.shape[0]]) random_idxs = range(encoding.shape[0]) ... ... @@ -589,9 +594,8 @@ class ClusterOverlap(Layer): purity_vector = tf.zeros(self.samples) for i, sample in enumerate(random_idxs): neighborhood = knn.kneighbors( tf.expand_dims(encoding[sample], 1), self.k, return_distance=False ).flatten() neighborhood = get_k_nearest_neighbors(encodings, sample, self.k) neighborhood = tf.reshape(neighborhood, neighborhood.shape[1]) z = hard_groups[neighborhood] ... ...
