Commit a7080a5c authored by lucas_miranda's avatar lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent c36f55b7
......@@ -612,19 +612,25 @@ class ClusterOverlap(Layer):
)
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
neighbourhood_entropy = purity_vector # * max_groups[random_idxs]
# self.add_metric(
# len(set(hard_groups[max_groups >= self.min_confidence])),
# aggregation="mean",
# name="number_of_populated_clusters",
# )
#
# self.add_metric(
# max_groups,
# aggregation="mean",
# name="average_confidence_in_selected_cluster",
# )
neighbourhood_entropy = purity_vector * tf.gather(
max_groups, tf.constant(random_idxs)
)
self.add_metric(
tf.shape(
tf.unique(
tf.squeeze(tf.gather(hard_groups, tf.where(max_groups >= self.min_confidence)))
)
),
aggregation="mean",
name="number_of_populated_clusters",
)
self.add_metric(
max_groups,
aggregation="mean",
name="average_confidence_in_selected_cluster",
)
self.add_metric(
neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy"
......
......@@ -473,13 +473,12 @@ class GMVAE:
# Dummy layer with no parameters, to retrieve the previous tensor
z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z)
if self.overlap_loss:
z = deepof.model_utils.ClusterOverlap(
self.batch_size,
self.ENCODING,
self.number_of_components,
loss_weight=self.overlap_loss,
)([z, z_cat])
z = deepof.model_utils.ClusterOverlap(
self.batch_size,
self.ENCODING,
self.number_of_components,
loss_weight=self.overlap_loss,
)([z, z_cat])
# Define and instantiate generator
g = Input(shape=self.ENCODING)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment