diff --git a/deepof/model_utils.py b/deepof/model_utils.py index e753bf4330cfc787de667400fcf591396e020846..211dc070128ebed42d84e97004339f5db363bade 100644 --- a/deepof/model_utils.py +++ b/deepof/model_utils.py @@ -619,21 +619,23 @@ class ClusterOverlap(Layer): max_groups, tf.constant(random_idxs) ) - self.add_metric( - tf.cast( - tf.shape( - tf.unique( - tf.reshape( - tf.gather( - tf.cast(hard_groups, tf.dtypes.float32), - tf.where(max_groups >= self.min_confidence), - ), - [-1], + number_of_clusters = tf.cast( + tf.shape( + tf.unique( + tf.reshape( + tf.gather( + tf.cast(hard_groups, tf.dtypes.float32), + tf.where(max_groups >= self.min_confidence), ), - )[0], + [-1], + ), )[0], - tf.dtypes.float32, - ), + )[0], + tf.dtypes.float32, + ) + + self.add_metric( + number_of_clusters, name="number_of_populated_clusters", ) @@ -648,6 +650,9 @@ class ClusterOverlap(Layer): ) if self.loss_weight: + # minimize local entropy self.add_loss(self.loss_weight * tf.reduce_mean(neighbourhood_entropy)) + # maximize number of clusters + self.add_loss(-self.loss_weight * tf.reduce_mean(number_of_clusters)) return encodings diff --git a/deepof_experiments.smk b/deepof_experiments.smk index 4a1490fac810e68e5274b3016748cde3c5c4c2b6..cf1e7b4c3455d97ca64f843d00e6ee2a8d9091f3 100644 --- a/deepof_experiments.smk +++ b/deepof_experiments.smk @@ -28,7 +28,7 @@ phenotype_pred_weights = [0.0] rule_based_pred_weights = [0.0] window_lengths = [22] # range(11,56,11) input_types = ["coords"] -run = list(range(1, 11)) +run = list(range(1, 4)) rule deepof_experiments: