diff --git a/deepof/model_utils.py b/deepof/model_utils.py index 064d4cb2fbc9ee7c371ce71fc13da60f27b91cc4..585623e9b3259b5cad761129aa927f25ba1de2d3 100644 --- a/deepof/model_utils.py +++ b/deepof/model_utils.py @@ -612,19 +612,25 @@ class ClusterOverlap(Layer): ) ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ### - neighbourhood_entropy = purity_vector # * max_groups[random_idxs] - - # self.add_metric( - # len(set(hard_groups[max_groups >= self.min_confidence])), - # aggregation="mean", - # name="number_of_populated_clusters", - # ) - # - # self.add_metric( - # max_groups, - # aggregation="mean", - # name="average_confidence_in_selected_cluster", - # ) + neighbourhood_entropy = purity_vector * tf.gather( + max_groups, tf.constant(random_idxs) + ) + + self.add_metric( + tf.shape( + tf.unique( + tf.squeeze(tf.gather(hard_groups, tf.where(max_groups >= self.min_confidence))) + ) + ), + aggregation="mean", + name="number_of_populated_clusters", + ) + + self.add_metric( + max_groups, + aggregation="mean", + name="average_confidence_in_selected_cluster", + ) self.add_metric( neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy" diff --git a/deepof/models.py b/deepof/models.py index c8a61fc35a18885e3ae506dd7ff498407ba4c93e..10a187090cb0b301f4dae9599aea322b8c0fffb8 100644 --- a/deepof/models.py +++ b/deepof/models.py @@ -473,13 +473,12 @@ class GMVAE: # Dummy layer with no parameters, to retrieve the previous tensor z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z) - if self.overlap_loss: - z = deepof.model_utils.ClusterOverlap( - self.batch_size, - self.ENCODING, - self.number_of_components, - loss_weight=self.overlap_loss, - )([z, z_cat]) + z = deepof.model_utils.ClusterOverlap( + self.batch_size, + self.ENCODING, + self.number_of_components, + loss_weight=self.overlap_loss, + )([z, z_cat]) # Define and instantiate generator g = Input(shape=self.ENCODING)