Commit c37c8d52 authored by lucas_miranda

Updated cluster entropy regularization

parent 233e85d5
@@ -26,6 +26,15 @@ tfpl = tfp.layers
 # Helper functions and classes
+@tf.function
+def compute_shannon_entropy(tensor):
+    """Computes Shannon entropy for a given tensor"""
+    tensor = tf.cast(tensor, tf.dtypes.int32)
+    bins = tf.math.bincount(tensor, dtype=tf.dtypes.float32) / tf.cast(tensor.shape[0], tf.float32)
+    return -tf.reduce_sum(bins * tf.math.log(bins + 1e-5))
 class exponential_learning_rate(tf.keras.callbacks.Callback):
     """Simple class that allows the learning rate to grow exponentially during training"""
@@ -276,7 +285,7 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
             z = hard_groups[neighborhood]
             # Compute Shannon entropy across samples
-            neigh_entropy = entropy(np.bincount(z))
+            neigh_entropy = compute_shannon_entropy(z)
             # Add result to pre-allocated array
             purity_vector[i] = neigh_entropy
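The hunk above sits inside a loop that scores neighbourhood purity in latent space: fit a KNN index on the encodings, look up the hard cluster assignments of each sampled point's neighbours, and score the neighbourhood with the entropy helper. A sketch of that loop follows; the names `encoding`, `hard_groups`, `random_idxs` and all sizes are assumptions based on the visible context, not the repository's actual code.

```python
import numpy as np
import tensorflow as tf
from sklearn.neighbors import NearestNeighbors

# Assumed sizes: 500 samples, 6-D latent space, 10 clusters, k = 100 neighbours
encoding = np.random.normal(size=(500, 6)).astype("float32")
hard_groups = np.random.randint(0, 10, size=500)
k, n_samples = 100, 50

knn = NearestNeighbors().fit(encoding)
random_idxs = np.random.choice(len(encoding), n_samples, replace=False)
purity_vector = np.zeros(n_samples)

for i, sample in enumerate(random_idxs):
    # Indices of the k nearest neighbours of the sampled point in latent space
    neighborhood = knn.kneighbors(
        encoding[sample][np.newaxis, :], k, return_distance=False
    ).flatten()
    # Hard cluster assignments of those neighbours
    z = hard_groups[neighborhood]
    # Low entropy -> the neighbourhood is dominated by a single cluster
    # (compute_shannon_entropy is the helper added at the top of this commit)
    purity_vector[i] = compute_shannon_entropy(tf.constant(z))
```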
@@ -581,13 +590,13 @@ class ClusterOverlap(Layer):
         for i, sample in enumerate(random_idxs):
             neighborhood = knn.kneighbors(
-                encoding[sample][np.newaxis, :], self.k, return_distance=False
+                tf.expand_dims(encoding[sample], 1), self.k, return_distance=False
             ).flatten()
             z = hard_groups[neighborhood]
             # Compute Shannon entropy across samples
-            neigh_entropy = entropy(np.bincount(z))
+            neigh_entropy = compute_shannon_entropy(z)
             # Add result to pre-allocated array
             purity_vector[i] = neigh_entropy
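As a reference for the query-shape change in the `kneighbors` call above, a small check (not from the commit) of what each form produces. scikit-learn expects a 2-D query of shape `(1, n_features)`, and the second argument of `tf.expand_dims` selects where the new axis is inserted; the encoding sizes below are hypothetical.

```python
import numpy as np
import tensorflow as tf

encoding = np.random.normal(size=(100, 6)).astype("float32")  # hypothetical latent codes
sample = 0

print(encoding[sample][np.newaxis, :].shape)      # (1, 6): row vector, the shape kneighbors expects
print(tf.expand_dims(encoding[sample], 0).shape)  # (1, 6): TensorFlow equivalent with axis=0
print(tf.expand_dims(encoding[sample], 1).shape)  # (6, 1): axis=1 puts the new axis after the features
```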