Commit 2f627d63 authored by lucas_miranda's avatar lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent ef5f9a98
...@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor): ...@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor):
@tf.function @tf.function
def get_k_nearest_neighbors(tensor, k, index): def get_k_nearest_neighbors(tensor, k, index):
"""Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index""" """Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
query = tensor[index] query = tf.gather(tensor, index)
distances = tf.norm(tensor - query, axis=1) distances = tf.norm(tensor - query, axis=1)
max_distance = tf.sort(distances)[k] max_distance = tf.sort(distances)[k]
neighbourhood_mask = distances < max_distance neighbourhood_mask = distances < max_distance
...@@ -558,13 +558,15 @@ class ClusterOverlap(Layer): ...@@ -558,13 +558,15 @@ class ClusterOverlap(Layer):
def __init__( def __init__(
self, self,
batch_size: int,
encoding_dim: int, encoding_dim: int,
k: int = 100, k: int = 100,
loss_weight: float = 0.0, loss_weight: float = 0.0,
samples: int = 512, samples: int = 50,
*args, *args,
**kwargs **kwargs
): ):
self.batch_size = batch_size
self.enc = encoding_dim self.enc = encoding_dim
self.k = k self.k = k
self.loss_weight = loss_weight self.loss_weight = loss_weight
...@@ -576,6 +578,7 @@ class ClusterOverlap(Layer): ...@@ -576,6 +578,7 @@ class ClusterOverlap(Layer):
"""Updates Constraint metadata""" """Updates Constraint metadata"""
config = super().get_config().copy() config = super().get_config().copy()
config.update({"batch_size": self.batch_size})
config.update({"enc": self.enc}) config.update({"enc": self.enc})
config.update({"k": self.k}) config.update({"k": self.k})
config.update({"loss_weight": self.loss_weight}) config.update({"loss_weight": self.loss_weight})
...@@ -583,7 +586,6 @@ class ClusterOverlap(Layer): ...@@ -583,7 +586,6 @@ class ClusterOverlap(Layer):
config.update({"samples": self.samples}) config.update({"samples": self.samples})
return config return config
@tf.function
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
"""Updates Layer's call method""" """Updates Layer's call method"""
...@@ -593,46 +595,42 @@ class ClusterOverlap(Layer): ...@@ -593,46 +595,42 @@ class ClusterOverlap(Layer):
max_groups = tf.reduce_max(categorical, axis=1) max_groups = tf.reduce_max(categorical, axis=1)
# Iterate over samples and compute purity across neighbourhood # Iterate over samples and compute purity across neighbourhood
self.samples = tf.reduce_min([self.samples, tf.shape(encodings)[0]]) self.samples = np.min([self.samples, self.batch_size]) # convert to numpy
random_idxs = tf.range(tf.shape(encodings)[0]) random_idxs = range(self.batch_size) # convert to batch size
random_idxs = tf.random.categorical( random_idxs = np.random.choice(random_idxs, self.samples)
tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0),
self.samples,
dtype=tf.dtypes.int32,
)
@tf.function get_local_neighbourhood_entropy = partial(
def get_local_neighbourhood_entropy(index): get_neighbourhood_entropy, tensor=encodings, clusters=hard_groups, k=self.k
return get_neighbourhood_entropy( )
index, tensor=encodings, clusters=hard_groups, k=self.k
)
purity_vector = tf.map_fn( purity_vector = tf.map_fn(
get_local_neighbourhood_entropy, random_idxs, dtype=tf.dtypes.float32 get_local_neighbourhood_entropy,
tf.constant(random_idxs),
dtype=tf.dtypes.float32,
) )
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ### ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
neighbourhood_entropy = purity_vector * max_groups[random_idxs] neighbourhood_entropy = purity_vector # * max_groups[random_idxs]
self.add_metric( # self.add_metric(
len(set(hard_groups[max_groups >= self.min_confidence])), # len(set(hard_groups[max_groups >= self.min_confidence])),
aggregation="mean", # aggregation="mean",
name="number_of_populated_clusters", # name="number_of_populated_clusters",
) # )
#
self.add_metric( # self.add_metric(
max_groups, # max_groups,
aggregation="mean", # aggregation="mean",
name="average_confidence_in_selected_cluster", # name="average_confidence_in_selected_cluster",
) # )
self.add_metric( self.add_metric(
neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy" neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy"
) )
if self.loss_weight: # if self.loss_weight:
self.add_loss( # self.add_loss(
self.loss_weight * neighbourhood_entropy, inputs=[target, categorical] # self.loss_weight * neighbourhood_entropy, inputs=inputs
) # )
return encodings return encodings
...@@ -475,6 +475,7 @@ class GMVAE: ...@@ -475,6 +475,7 @@ class GMVAE:
if self.overlap_loss: if self.overlap_loss:
z = deepof.model_utils.ClusterOverlap( z = deepof.model_utils.ClusterOverlap(
self.batch_size,
self.ENCODING, self.ENCODING,
self.number_of_components, self.number_of_components,
loss_weight=self.overlap_loss, loss_weight=self.overlap_loss,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment