Commit 04f4192a authored by lucas_miranda's avatar lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent 8d3cb4c3
...@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor): ...@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor):
@tf.function @tf.function
def get_k_nearest_neighbors(tensor, k, index): def get_k_nearest_neighbors(tensor, k, index):
"""Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index""" """Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
query = tensor[index] query = tf.slice(tensor, index, 1)
distances = tf.norm(tensor - query, axis=1) distances = tf.norm(tensor - query, axis=1)
max_distance = tf.sort(distances)[k] max_distance = tf.sort(distances)[k]
neighbourhood_mask = distances < max_distance neighbourhood_mask = distances < max_distance
...@@ -583,7 +583,7 @@ class ClusterOverlap(Layer): ...@@ -583,7 +583,7 @@ class ClusterOverlap(Layer):
config.update({"samples": self.samples}) config.update({"samples": self.samples})
return config return config
#@tf.function @tf.function
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
"""Updates Layer's call method""" """Updates Layer's call method"""
...@@ -594,8 +594,12 @@ class ClusterOverlap(Layer): ...@@ -594,8 +594,12 @@ class ClusterOverlap(Layer):
# Iterate over samples and compute purity across neighbourhood # Iterate over samples and compute purity across neighbourhood
self.samples = tf.reduce_min([self.samples, tf.shape(encodings)[0]]) self.samples = tf.reduce_min([self.samples, tf.shape(encodings)[0]])
random_idxs = range(encodings.shape[0]) random_idxs = tf.range(tf.shape(encodings)[0])
random_idxs = np.random.choice(random_idxs, self.samples) random_idxs = tf.random.categorical(
tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0),
self.samples,
dtype=tf.dtypes.int32,
)
@tf.function @tf.function
def get_local_neighbourhood_entropy(index): def get_local_neighbourhood_entropy(index):
...@@ -604,7 +608,7 @@ class ClusterOverlap(Layer): ...@@ -604,7 +608,7 @@ class ClusterOverlap(Layer):
) )
purity_vector = tf.map_fn( purity_vector = tf.map_fn(
get_local_neighbourhood_entropy, tf.constant(random_idxs), dtype=tf.dtypes.float32 get_local_neighbourhood_entropy, random_idxs, dtype=tf.dtypes.float32
) )
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ### ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment