From b0b65f835e411fce82a5bb6e76abce0002024551 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Thu, 3 Dec 2020 03:32:52 +0100
Subject: [PATCH] Added encoding size as a CL parameter in train_model.py

---
 deepof/hypermodels.py |  2 +-
 deepof/models.py      |  4 ++--
 deepof/train_utils.py | 13 ++++++++-----
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py
index e16690f9..ca024943 100644
--- a/deepof/hypermodels.py
+++ b/deepof/hypermodels.py
@@ -170,7 +170,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
 
         gmvaep, kl_warmup_callback, mmd_warmup_callback = deepof.models.SEQ_2_SEQ_GMVAE(
             architecture_hparams={
-                "bidirectional_merge": "concat",
+                "bidirectional_merge": "ave",
                 "clipvalue": clipvalue,
                 "dense_activation": dense_activation,
                 "dense_layers_per_branch": dense_layers_per_branch,
diff --git a/deepof/models.py b/deepof/models.py
index 9953ed42..5d553a3a 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -248,9 +248,9 @@ class SEQ_2_SEQ_GMVAE:
         entropy_reg_weight: float = 0.0,
         huber_delta: float = 1.0,
         initialiser_iters: int = int(1),
-        kl_warmup_epochs: int = 0,
+        kl_warmup_epochs: int = 20,
         loss: str = "ELBO+MMD",
-        mmd_warmup_epochs: int = 0,
+        mmd_warmup_epochs: int = 20,
         number_of_components: int = 1,
         overlap_loss: float = False,
         phenotype_prediction: float = 0.0,
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 393ace1f..68f5a880 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -32,12 +32,15 @@ def load_hparams(hparams):
             hparams = pickle.load(handle)
     else:
         hparams = {
-            "units_conv": 256,
-            "units_lstm": 256,
-            "units_dense2": 64,
-            "dropout_rate": 0.25,
-            "encoding": 16,
+            "bidirectional_merge": "ave",
+            "clipvalue": 1.0,
+            "dense_activation": "relu",
+            "dense_layers_per_branch": 1,
+            "dropout_rate": 1e-3,
             "learning_rate": 1e-3,
+            "units_conv": 160,
+            "units_dense2": 120,
+            "units_lstm": 300,
         }
     return hparams
 
-- 
GitLab