Commit 2f627d63 authored by lucas_miranda

Replaced for loop with vectorised mapping in the ClusterOverlap regularization layer

parent ef5f9a98
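Illustration (not part of the commit): a minimal, self-contained sketch of the pattern this commit adopts in ClusterOverlap.call, replacing a per-sample loop/closure with functools.partial plus tf.map_fn over a vector of sampled indices. The helper toy_neighbourhood_entropy below is a hypothetical stand-in for deepof's get_neighbourhood_entropy, which is not shown in this diff; shapes and constants are made up for the example.

from functools import partial

import numpy as np
import tensorflow as tf


def toy_neighbourhood_entropy(index, tensor, clusters, k):
    # Shannon entropy of the cluster labels among the k nearest neighbours of one point
    query = tf.gather(tensor, index)
    distances = tf.norm(tensor - query, axis=1)
    neighbour_idx = tf.argsort(distances)[: k + 1]  # includes the query point itself
    neighbour_clusters = tf.gather(clusters, neighbour_idx)
    _, _, counts = tf.unique_with_counts(neighbour_clusters)
    probs = tf.cast(counts, tf.float32) / tf.cast(k + 1, tf.float32)
    return -tf.reduce_sum(probs * tf.math.log(probs) / tf.math.log(2.0))


encodings = tf.random.normal([512, 6])                         # toy latent encodings
hard_groups = tf.random.uniform([512], 0, 10, dtype=tf.int32)  # toy cluster labels
random_idxs = np.random.choice(range(512), 50)                 # sample 50 points

# Vectorised mapping in place of an explicit Python loop over random_idxs
get_local_neighbourhood_entropy = partial(
    toy_neighbourhood_entropy, tensor=encodings, clusters=hard_groups, k=100
)
purity_vector = tf.map_fn(
    get_local_neighbourhood_entropy, tf.constant(random_idxs), dtype=tf.dtypes.float32
)
print(purity_vector.shape)  # (50,)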
@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor):
 @tf.function
 def get_k_nearest_neighbors(tensor, k, index):
     """Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
-    query = tensor[index]
+    query = tf.gather(tensor, index)
     distances = tf.norm(tensor - query, axis=1)
     max_distance = tf.sort(distances)[k]
     neighbourhood_mask = distances < max_distance
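Context for the change above (illustration only, not part of the diff): tf.gather selects the query row just like direct indexing, and it also generalises to gathering several indices at once:

import tensorflow as tf

tensor = tf.reshape(tf.range(12, dtype=tf.float32), (4, 3))
index = tf.constant(2)

query = tf.gather(tensor, index)             # same row as tensor[2]
several = tf.gather(tensor, [0, 2])          # gather also accepts a vector of indices
distances = tf.norm(tensor - query, axis=1)  # per-row distances to the query, as in the diff
print(query.numpy(), several.shape, distances.numpy())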
@@ -558,13 +558,15 @@ class ClusterOverlap(Layer):
     def __init__(
         self,
         batch_size: int,
+        encoding_dim: int,
         k: int = 100,
         loss_weight: float = 0.0,
-        samples: int = 512,
+        samples: int = 50,
         *args,
         **kwargs
     ):
         self.batch_size = batch_size
+        self.enc = encoding_dim
         self.k = k
         self.loss_weight = loss_weight
@@ -576,6 +578,7 @@ class ClusterOverlap(Layer):
         """Updates Constraint metadata"""
         config = super().get_config().copy()
         config.update({"batch_size": self.batch_size})
+        config.update({"enc": self.enc})
         config.update({"k": self.k})
         config.update({"loss_weight": self.loss_weight})
@@ -583,7 +586,6 @@ class ClusterOverlap(Layer):
         config.update({"samples": self.samples})
         return config

-    @tf.function
     def call(self, inputs, **kwargs):
         """Updates Layer's call method"""
@@ -593,46 +595,42 @@ class ClusterOverlap(Layer):
         max_groups = tf.reduce_max(categorical, axis=1)

         # Iterate over samples and compute purity across neighbourhood
-        self.samples = tf.reduce_min([self.samples, tf.shape(encodings)[0]])
-        random_idxs = tf.range(tf.shape(encodings)[0])
-        random_idxs = tf.random.categorical(
-            tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0),
-            self.samples,
-            dtype=tf.dtypes.int32,
-        )
+        self.samples = np.min([self.samples, self.batch_size])  # convert to numpy
+        random_idxs = range(self.batch_size)  # convert to batch size
+        random_idxs = np.random.choice(random_idxs, self.samples)

-        @tf.function
-        def get_local_neighbourhood_entropy(index):
-            return get_neighbourhood_entropy(
-                index, tensor=encodings, clusters=hard_groups, k=self.k
-            )
+        get_local_neighbourhood_entropy = partial(
+            get_neighbourhood_entropy, tensor=encodings, clusters=hard_groups, k=self.k
+        )

         purity_vector = tf.map_fn(
-            get_local_neighbourhood_entropy, random_idxs, dtype=tf.dtypes.float32
+            get_local_neighbourhood_entropy,
+            tf.constant(random_idxs),
+            dtype=tf.dtypes.float32,
         )

-        ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
-        neighbourhood_entropy = purity_vector * max_groups[random_idxs]
-
-        self.add_metric(
-            len(set(hard_groups[max_groups >= self.min_confidence])),
-            aggregation="mean",
-            name="number_of_populated_clusters",
-        )
-
-        self.add_metric(
-            max_groups,
-            aggregation="mean",
-            name="average_confidence_in_selected_cluster",
-        )
+        neighbourhood_entropy = purity_vector  # * max_groups[random_idxs]
+
+        # self.add_metric(
+        #     len(set(hard_groups[max_groups >= self.min_confidence])),
+        #     aggregation="mean",
+        #     name="number_of_populated_clusters",
+        # )
+        #
+        # self.add_metric(
+        #     max_groups,
+        #     aggregation="mean",
+        #     name="average_confidence_in_selected_cluster",
+        # )

         self.add_metric(
             neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy"
         )

-        if self.loss_weight:
-            self.add_loss(
-                self.loss_weight * neighbourhood_entropy, inputs=[target, categorical]
-            )
+        # if self.loss_weight:
+        #     self.add_loss(
+        #         self.loss_weight * neighbourhood_entropy, inputs=inputs
+        #     )

         return encodings
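Illustration (not deepof code) of how the remaining add_metric call surfaces during training: a tensor logged with aggregation="mean" inside a layer's call is averaged and reported in the fit history under its name; the layer below is a hypothetical minimal stand-in:

import tensorflow as tf


class EntropyMetricLayer(tf.keras.layers.Layer):  # hypothetical minimal layer
    def call(self, inputs):
        per_sample_value = tf.reduce_sum(tf.square(inputs), axis=1)
        self.add_metric(per_sample_value, aggregation="mean", name="neighbourhood_entropy")
        return inputs


inputs = tf.keras.Input(shape=(6,))
outputs = EntropyMetricLayer()(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")
history = model.fit(tf.random.normal([32, 6]), tf.random.normal([32, 6]), epochs=1, verbose=0)
print(history.history["neighbourhood_entropy"])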
@@ -475,6 +475,7 @@ class GMVAE:
         if self.overlap_loss:
             z = deepof.model_utils.ClusterOverlap(
                 self.batch_size,
+                self.ENCODING,
                 self.number_of_components,
                 loss_weight=self.overlap_loss,
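Hedged usage sketch (argument values are illustrative, and any constructor arguments not visible in this diff are assumed to keep their defaults): with this change the encoding dimension is passed right after the batch size when the layer is built, mirroring the GMVAE call above:

import deepof.model_utils

overlap_layer = deepof.model_utils.ClusterOverlap(
    64,   # batch_size
    6,    # encoding_dim, the new argument introduced by this commit
    10,   # number of components, passed as the next positional argument by GMVAE
    loss_weight=0.0,
)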