Commit 1089c048 authored by lucas_miranda
Browse files

Updated cluster entropy regularization - added custom graph-ready neighbourhood entropy function

parent 28e1fa13
......@@ -27,10 +27,10 @@ tfpl = tfp.layers
@tf.function
def compute_shannon_entropy(tensor):
    """Computes Shannon entropy for a given tensor.

    The input is cast to int32 and its values are counted with
    ``tf.math.bincount``; the counts are divided by the size of the first
    dimension of the input to turn them into frequencies.

    Args:
        tensor: values to count; cast to int32 before binning.
            NOTE(review): presumably a 1-D vector of non-negative cluster
            labels -- confirm against callers.

    Returns:
        Scalar tensor equal to ``-sum(p * log(p + 1e-5))`` over the bin
        frequencies ``p`` (natural logarithm).
    """
    tensor = tf.cast(tensor, tf.dtypes.int32)
    # Use the dynamic shape (tf.shape) instead of the static tensor.shape[0],
    # so the element count is available as a tensor inside the traced graph.
    n_elements = tf.cast(tf.shape(tensor), tf.dtypes.float32)[0]
    bins = tf.math.bincount(tensor, dtype=tf.dtypes.float32) / n_elements
    # The 1e-5 offset keeps the logarithm finite for empty bins.
    return -tf.reduce_sum(bins * tf.math.log(bins + 1e-5))
......@@ -38,7 +38,6 @@ def compute_shannon_entropy(tensor):
@tf.function
def get_k_nearest_neighbors(tensor, k, index):
"""Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
query = tensor[index]
distances = tf.norm(tensor - query, axis=1)
max_distance = tf.sort(distances)[k]
......@@ -48,11 +47,9 @@ def get_k_nearest_neighbors(tensor, k, index):
@tf.function
def get_neighbourhood_entropy(tensor, clusters, k, index):
    """Computes the Shannon entropy of the clusters found around one sample.

    Looks up the neighbours of ``tensor[index]`` with
    ``get_k_nearest_neighbors``, gathers the corresponding entries of
    ``clusters`` and returns their Shannon entropy.

    Args:
        tensor: vectors in which neighbours are searched.
        clusters: values indexed by the neighbour indices.
            NOTE(review): presumably one cluster assignment per row of
            ``tensor`` -- confirm against callers.
        k: number of neighbours to retrieve.
        index: position in ``tensor`` of the query vector.

    Returns:
        The value of ``compute_shannon_entropy`` on the gathered entries.
    """
    neighborhood = get_k_nearest_neighbors(tensor, k, index)
    # tf.gather replaces Python-style indexing (clusters[neighborhood]) so the
    # lookup by an index tensor works inside the traced graph.
    cluster_z = tf.gather(clusters, neighborhood)
    neigh_entropy = compute_shannon_entropy(cluster_z)
    return neigh_entropy
......@@ -295,7 +292,12 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
)
# Add result to pre allocated array
purity_vector = get_neighbourhood_entropy(encodings, random_idxs)
purity_vector = np.zeros(self.samples)
for i, sample in enumerate(random_idxs):
purity_vector[i] = get_neighbourhood_entropy(
encodings, hard_groups, self.k, sample
)
writer = tf.summary.create_file_writer(self.log_dir)
with writer.as_default():
......@@ -591,19 +593,13 @@ class ClusterOverlap(Layer):
random_idxs = tf.random.categorical(
tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0), self.samples
)
purity_vector = tf.zeros(self.samples)
for i, sample in enumerate(random_idxs):
neighborhood = get_k_nearest_neighbors(encodings, sample, self.k)
neighborhood = tf.reshape(neighborhood, neighborhood.shape[1])
z = hard_groups[neighborhood]
# Compute Shannon entropy across samples
neigh_entropy = compute_shannon_entropy(z)
purity_vector = tf.map_fn(get_neighbourhood_entropy, random_idxs)
# Add result to pre allocated array
purity_vector[i] = neigh_entropy
for i, sample in enumerate(random_idxs):
purity_vector[i] = get_neighbourhood_entropy(
encodings, hard_groups, self.k, sample
)
neighbourhood_entropy = purity_vector * max_groups[random_idxs]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment