diff --git a/source/model_utils.py b/source/model_utils.py index fa953e8e05d348e34b42c2faa9f303aed5fe1ddb..270c909818c8dec7649746690d11d4e92fcd4674 100644 --- a/source/model_utils.py +++ b/source/model_utils.py @@ -275,7 +275,7 @@ class Entropy_regulariser(Layer): Identity layer that adds cluster weight entropy to the loss function """ - def __init__(self, weight=1., *args, **kwargs): + def __init__(self, weight=1.0, *args, **kwargs): self.weight = weight super(Entropy_regulariser, self).__init__(*args, **kwargs) diff --git a/source/models.py b/source/models.py index d3c0c9e008ed119026110baacaf4bd7a6f21f855..a16c562b8769336e4fe35e5457f91cda967d4381 100644 --- a/source/models.py +++ b/source/models.py @@ -173,7 +173,7 @@ class SEQ_2_SEQ_GMVAE: number_of_components=1, predictor=True, overlap_loss=False, - entropy_reg_weight=0., + entropy_reg_weight=1.0, ): self.input_shape = input_shape self.batch_size = batch_size