Commit 28e1fa13 authored by lucas_miranda's avatar lucas_miranda
Browse files

Updated cluster entropy regularization - added custom graph-ready neighbourhood entropy function

parent e8e91836
......@@ -43,7 +43,17 @@ def get_k_nearest_neighbors(tensor, k, index):
distances = tf.norm(tensor - query, axis=1)
max_distance = tf.sort(distances)[k]
neighbourhood_mask = distances < max_distance
return tf.transpose(tf.where(neighbourhood_mask))
return tf.squeeze(tf.where(neighbourhood_mask))
@tf.function
def get_neighbourhood_entropy(tensor, clusters, k, index):
neighborhood = get_k_nearest_neighbors(tensor, k, index)
cluster_z = clusters[neighborhood]
neigh_entropy = compute_shannon_entropy(cluster_z)
return neigh_entropy
class exponential_learning_rate(tf.keras.callbacks.Callback):
......@@ -283,19 +293,9 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
random_idxs = np.random.choice(
range(encoding.shape[0]), self.samples, replace=False
)
purity_vector = np.zeros(self.samples)
for i, sample in enumerate(random_idxs):
neighborhood = get_k_nearest_neighbors(encodings, sample, self.k)
neighborhood = tf.reshape(neighborhood, neighborhood.shape[1])
z = hard_groups[neighborhood]
# Compute Shannon entropy across samples
neigh_entropy = compute_shannon_entropy(z)
# Add result to pre allocated array
purity_vector[i] = neigh_entropy
# Add result to pre allocated array
purity_vector = get_neighbourhood_entropy(encodings, random_idxs)
writer = tf.summary.create_file_writer(self.log_dir)
with writer.as_default():
......@@ -582,14 +582,14 @@ class ClusterOverlap(Layer):
def call(self, encodings, categorical, **kwargs):
"""Updates Layer's call method"""
hard_groups = categorical.argmax(axis=1)
max_groups = categorical.max(axis=1)
hard_groups = tf.math.argmax(categorical, axis=1)
max_groups = tf.reduce_max(categorical, axis=1)
# Iterate over samples and compute purity across neighbourhood
self.samples = tf.reduce_min([self.samples, encodings.shape[0]])
random_idxs = range(encoding.shape[0])
random_idxs = tf.random.categorical(
tf.math.log([[i / sum(random_idxs) for i in random_idxs]]), 2
tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0), self.samples
)
purity_vector = tf.zeros(self.samples)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment