diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py index e16690f993c51ef6fc95eca9e9db89e1093bfdec..ca024943189386a0215a07099de247128624341b 100644 --- a/deepof/hypermodels.py +++ b/deepof/hypermodels.py @@ -170,7 +170,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): gmvaep, kl_warmup_callback, mmd_warmup_callback = deepof.models.SEQ_2_SEQ_GMVAE( architecture_hparams={ - "bidirectional_merge": "concat", + "bidirectional_merge": "ave", "clipvalue": clipvalue, "dense_activation": dense_activation, "dense_layers_per_branch": dense_layers_per_branch, diff --git a/deepof/models.py b/deepof/models.py index 9953ed42478abad118315f0dbca8f874fc4f600b..5d553a3a88010ec6dec6f7dbd2d2cc5a77ad480b 100644 --- a/deepof/models.py +++ b/deepof/models.py @@ -248,9 +248,9 @@ class SEQ_2_SEQ_GMVAE: entropy_reg_weight: float = 0.0, huber_delta: float = 1.0, initialiser_iters: int = int(1), - kl_warmup_epochs: int = 0, + kl_warmup_epochs: int = 20, loss: str = "ELBO+MMD", - mmd_warmup_epochs: int = 0, + mmd_warmup_epochs: int = 20, number_of_components: int = 1, overlap_loss: float = False, phenotype_prediction: float = 0.0, diff --git a/deepof/train_utils.py b/deepof/train_utils.py index 393ace1f23357499d50a9b4b307d14976b3eb7e6..68f5a880d51d2f9ec782fade34db2f9551a9160c 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -32,12 +32,15 @@ def load_hparams(hparams): hparams = pickle.load(handle) else: hparams = { - "units_conv": 256, - "units_lstm": 256, - "units_dense2": 64, - "dropout_rate": 0.25, - "encoding": 16, + "bidirectional_merge": "ave", + "clipvalue": 1.0, + "dense_activation": "relu", + "dense_layers_per_branch": 1, + "dropout_rate": 1e-3, "learning_rate": 1e-3, + "units_conv": 160, + "units_dense2": 120, + "units_lstm": 300, } return hparams