diff --git a/source/model_utils.py b/source/model_utils.py index 21ede4f7de96160c265f0ae5bfeaf26c3ed5e1ba..49cc589caaeb4ee30587222b95330268c528dbfc 100644 --- a/source/model_utils.py +++ b/source/model_utils.py @@ -275,7 +275,7 @@ class Entropy_regulariser(Layer): Identity layer that adds cluster weight entropy to the loss function """ - def __init__(self, weight=1., *args, **kwargs): + def __init__(self, weight=0., *args, **kwargs): self.weight = weight super(Entropy_regulariser, self).__init__(*args, **kwargs) @@ -290,7 +290,7 @@ class Entropy_regulariser(Layer): ) # Adds metric that monitors dead neurons in the latent space - self.add_metric(-entropy, aggregation="mean", name="weight_entropy") + self.add_metric(entropy, aggregation="mean", name="-weight_entropy") self.add_loss(self.weight * K.sum(entropy), inputs=[z])