From 416706636d196b8ffbc00af76fae75e2f49bb0a4 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Wed, 25 Nov 2020 18:41:59 +0100
Subject: [PATCH] Added a hyperparameter controlling the number of layers near
 the latent space

---
 deepof/hypermodels.py | 10 ++++----
 deepof/models.py      | 53 +++++++++++++++++++++++++------------------
 2 files changed, 36 insertions(+), 27 deletions(-)

diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py
index a4809fa4..ba413ef0 100644
--- a/deepof/hypermodels.py
+++ b/deepof/hypermodels.py
@@ -107,11 +107,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
 
         # Architectural hyperparameters
         clipvalue = hp.Float(
-            "clipvalue",
-            min_value=0.0,
-            max_value=1.0,
-            default=0.5,
-            sampling="Linear"
+            "clipvalue", min_value=0.0, max_value=1.0, default=0.5, sampling="Linear"
         )
         conv_filters = hp.Int(
             "units_conv", min_value=128, max_value=160, step=16, default=128,
@@ -120,6 +116,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             "units_dense2", min_value=120, max_value=180, step=10, default=150,
         )
         dense_activation = hp.Choice("dense_activation", values=["elu", "relu"])
+        dense_layers_per_branch = hp.Int("dense_layers_per_branch", min_value=1, max_value=3, default=1)
         dropout_rate = hp.Float(
             "dropout_rate",
             min_value=0.0,
@@ -138,6 +135,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             conv_filters,
             dense_2,
             dense_activation,
+            dense_layers_per_branch,
             dropout_rate,
             encoding,
             k,
@@ -153,6 +151,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             conv_filters,
             dense_2,
             dense_activation,
+            dense_layers_per_branch,
             dropout_rate,
             encoding,
             k,
@@ -164,6 +163,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
                 "clipvalue": clipvalue,
                 "dense_activation": dense_activation,
                 "dropout_rate": dropout_rate,
+                "dense_layers_per_branch": dense_layers_per_branch,
                 "encoding": encoding,
                 "units_conv": conv_filters,
                 "units_dense_2": dense_2,
diff --git a/deepof/models.py b/deepof/models.py
index 5aae169b..d00cc824 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -253,6 +253,7 @@ class SEQ_2_SEQ_GMVAE:
         self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
         self.clipvalue = self.hparams["clipvalue"]
         self.dense_activation = self.hparams["dense_activation"]
+        self.dense_layers_per_branch = self.hparams["dense_layers_per_branch"]
         self.learn_rate = self.hparams["learning_rate"]
         self.lstm_unroll = True
         self.compile = compile_model
@@ -313,14 +314,15 @@ class SEQ_2_SEQ_GMVAE:
         """Sets the default parameters for the model. Overwritable with a dictionary"""
 
         defaults = {
-            "units_conv": 256,
-            "units_lstm": 256,
-            "units_dense2": 64,
+            "clipvalue": 0.5,
+            "dense_activation": "elu",
+            "dense_layers_per_branch": 1,
             "dropout_rate": 0.15,
             "encoding": 16,
             "learning_rate": 1e-3,
-            "clipvalue": 0.5,
-            "dense_activation": "elu",
+            "units_conv": 256,
+            "units_dense2": 64,
+            "units_lstm": 256,
         }
 
         for k, v in params.items():
@@ -370,25 +372,32 @@ class SEQ_2_SEQ_GMVAE:
             kernel_initializer=he_uniform(),
             use_bias=True,
         )
-        Model_E4 = Dense(
-            self.DENSE_2,
-            activation=self.dense_activation,
-            # kernel_constraint=UnitNorm(axis=0),
-            kernel_initializer=he_uniform(),
-            use_bias=True,
-        )
+
+        Model_E4 = [
+            Dense(
+                self.DENSE_2,
+                activation=self.dense_activation,
+                # kernel_constraint=UnitNorm(axis=0),
+                kernel_initializer=he_uniform(),
+                use_bias=True,
+            )
+            for _ in self.dense_layers_per_branch
+        ]
 
         # Decoder layers
         Model_B1 = BatchNormalization()
         Model_B2 = BatchNormalization()
         Model_B3 = BatchNormalization()
         Model_B4 = BatchNormalization()
-        Model_D1 = Dense(
-            self.DENSE_2,
-            activation=self.dense_activation,
-            kernel_initializer=he_uniform(),
-            use_bias=True,
-        )
+        Model_D1 = [
+            Dense(
+                self.DENSE_2,
+                activation=self.dense_activation,
+                kernel_initializer=he_uniform(),
+                use_bias=True,
+            )
+            for _ in self.dense_layers_per_branch
+        ]
         Model_D2 = Dense(
             self.DENSE_1,
             activation=self.dense_activation,
@@ -516,7 +525,7 @@ class SEQ_2_SEQ_GMVAE:
         encoder = Model_E3(encoder)
         encoder = BatchNormalization()(encoder)
         encoder = Dropout(self.DROPOUT_RATE)(encoder)
-        encoder = Model_E4(encoder)
+        encoder = Sequential(Model_E4)(encoder)
         encoder = BatchNormalization()(encoder)
 
         # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
@@ -546,7 +555,7 @@ class SEQ_2_SEQ_GMVAE:
                     tfd.Independent(
                         tfd.Normal(
                             loc=gauss[1][..., : self.ENCODING, k],
-                            scale=softplus(gauss[1][..., self.ENCODING :, k]),
+                            scale=softplus(gauss[1][..., self.ENCODING:, k]),
                         ),
                         reinterpreted_batch_ndims=1,
                     )
@@ -588,7 +597,7 @@ class SEQ_2_SEQ_GMVAE:
             )(z)
 
         # Define and instantiate generator
-        generator = Model_D1(z)
+        generator = Sequential(Model_D1)(z)
         generator = Model_B1(generator)
         generator = Model_D2(generator)
         generator = Model_B2(generator)
@@ -650,7 +659,7 @@ class SEQ_2_SEQ_GMVAE:
 
         # Build generator as a separate entity
         g = Input(shape=self.ENCODING)
-        _generator = Model_D1(g)
+        _generator = Sequential(Model_D1)(g)
         _generator = Model_B1(_generator)
         _generator = Model_D2(_generator)
         _generator = Model_B2(_generator)
-- 
GitLab