From b8089d68837587c6262b5f7f0f919cad07dca78b Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Wed, 19 May 2021 01:49:03 +0200 Subject: [PATCH] Replaced for loop with vectorised mapping on ClusterOverlap regularization layer --- deepof/model_utils.py | 32 +++++++++++++++++++------------- deepof/models.py | 13 ++++++------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/deepof/model_utils.py b/deepof/model_utils.py index 064d4cb2..585623e9 100644 --- a/deepof/model_utils.py +++ b/deepof/model_utils.py @@ -612,19 +612,25 @@ class ClusterOverlap(Layer): ) ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ### - neighbourhood_entropy = purity_vector # * max_groups[random_idxs] - - # self.add_metric( - # len(set(hard_groups[max_groups >= self.min_confidence])), - # aggregation="mean", - # name="number_of_populated_clusters", - # ) - # - # self.add_metric( - # max_groups, - # aggregation="mean", - # name="average_confidence_in_selected_cluster", - # ) + neighbourhood_entropy = purity_vector * tf.gather( + max_groups, tf.constant(random_idxs) + ) + + self.add_metric( + tf.shape( + tf.unique( + tf.squeeze(tf.gather(hard_groups, tf.where(max_groups >= self.min_confidence))) + ) + ), + aggregation="mean", + name="number_of_populated_clusters", + ) + + self.add_metric( + max_groups, + aggregation="mean", + name="average_confidence_in_selected_cluster", + ) self.add_metric( neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy" diff --git a/deepof/models.py b/deepof/models.py index c8a61fc3..10a18709 100644 --- a/deepof/models.py +++ b/deepof/models.py @@ -473,13 +473,12 @@ class GMVAE: # Dummy layer with no parameters, to retrieve the previous tensor z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z) - if self.overlap_loss: - z = deepof.model_utils.ClusterOverlap( - self.batch_size, - self.ENCODING, - self.number_of_components, - loss_weight=self.overlap_loss, - )([z, z_cat]) + z = deepof.model_utils.ClusterOverlap( + self.batch_size, + self.ENCODING, + self.number_of_components, + loss_weight=self.overlap_loss, + )([z, z_cat]) # Define and instantiate generator g = Input(shape=self.ENCODING) -- GitLab