Skip to content
Snippets Groups Projects
Commit 0958d797 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent abc4fd8f
No related branches found
No related tags found
No related merge requests found
......@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor):
@tf.function
def get_k_nearest_neighbors(tensor, k, index):
"""Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
query = tensor[index]
query = tf.slice(tensor, index, 1)
distances = tf.norm(tensor - query, axis=1)
max_distance = tf.sort(distances)[k]
neighbourhood_mask = distances < max_distance
......@@ -583,7 +583,7 @@ class ClusterOverlap(Layer):
config.update({"samples": self.samples})
return config
#@tf.function
@tf.function
def call(self, inputs, **kwargs):
"""Updates Layer's call method"""
......@@ -594,8 +594,12 @@ class ClusterOverlap(Layer):
# Iterate over samples and compute purity across neighbourhood
self.samples = tf.reduce_min([self.samples, tf.shape(encodings)[0]])
random_idxs = range(encodings.shape[0])
random_idxs = np.random.choice(random_idxs, self.samples)
random_idxs = tf.range(tf.shape(encodings)[0])
random_idxs = tf.random.categorical(
tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0),
self.samples,
dtype=tf.dtypes.int32,
)
@tf.function
def get_local_neighbourhood_entropy(index):
......@@ -604,7 +608,7 @@ class ClusterOverlap(Layer):
)
purity_vector = tf.map_fn(
get_local_neighbourhood_entropy, tf.constant(random_idxs), dtype=tf.dtypes.float32
get_local_neighbourhood_entropy, random_idxs, dtype=tf.dtypes.float32
)
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment