Commit 0958d797 authored by lucas_miranda's avatar lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent abc4fd8f
......@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor):
@tf.function
def get_k_nearest_neighbors(tensor, k, index):
"""Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
query = tensor[index]
query = tf.slice(tensor, index, 1)
distances = tf.norm(tensor - query, axis=1)
max_distance = tf.sort(distances)[k]
neighbourhood_mask = distances < max_distance
......@@ -583,7 +583,7 @@ class ClusterOverlap(Layer):
config.update({"samples": self.samples})
return config
#@tf.function
@tf.function
def call(self, inputs, **kwargs):
"""Updates Layer's call method"""
......@@ -594,8 +594,12 @@ class ClusterOverlap(Layer):
# Iterate over samples and compute purity across neighbourhood
self.samples = tf.reduce_min([self.samples, tf.shape(encodings)[0]])
random_idxs = range(encodings.shape[0])
random_idxs = np.random.choice(random_idxs, self.samples)
random_idxs = tf.range(tf.shape(encodings)[0])
random_idxs = tf.random.categorical(
tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0),
self.samples,
dtype=tf.dtypes.int32,
)
@tf.function
def get_local_neighbourhood_entropy(index):
......@@ -604,7 +608,7 @@ class ClusterOverlap(Layer):
)
purity_vector = tf.map_fn(
get_local_neighbourhood_entropy, tf.constant(random_idxs), dtype=tf.dtypes.float32
get_local_neighbourhood_entropy, random_idxs, dtype=tf.dtypes.float32
)
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment