From 7b86681defc3ca2182d8bbdd7ce6d14f982c747b Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Sun, 7 Jun 2020 13:02:12 +0200
Subject: [PATCH] Instanciated a mean and a variance for each component, and
 the categorical prior as a Dense layer with a softmax activation

---
 source/models.py | 56 +++++++++++++++++++++++++-----------------------
 1 file changed, 29 insertions(+), 27 deletions(-)

diff --git a/source/models.py b/source/models.py
index c21bbcfc..cfb9d626 100644
--- a/source/models.py
+++ b/source/models.py
@@ -344,9 +344,7 @@ class SEQ_2_SEQ_VAE:
             return self.input_shape[1:] * huber(x_, x_decoded_mean_)
 
         vae.compile(
-            loss=huber_loss,
-            optimizer=Adam(lr=self.learn_rate,),
-            metrics=["mae"],
+            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
         )
 
         return encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback
@@ -587,9 +585,7 @@ class SEQ_2_SEQ_VAEP:
             return self.input_shape[1:] * huber(x_, x_decoded_mean_)
 
         vaep.compile(
-            loss=huber_loss,
-            optimizer=Adam(lr=self.learn_rate,),
-            metrics=["mae"],
+            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
         )
 
         return encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback
@@ -597,19 +593,19 @@ class SEQ_2_SEQ_VAEP:
 
 class SEQ_2_SEQ_MMVAEP:
     def __init__(
-            self,
-            input_shape,
-            CONV_filters=256,
-            LSTM_units_1=256,
-            LSTM_units_2=64,
-            DENSE_2=64,
-            DROPOUT_RATE=0.25,
-            ENCODING=32,
-            learn_rate=1e-3,
-            loss="ELBO+MMD",
-            kl_warmup_epochs=0,
-            mmd_warmup_epochs=0,
-            number_of_components=1,
+        self,
+        input_shape,
+        CONV_filters=256,
+        LSTM_units_1=256,
+        LSTM_units_2=64,
+        DENSE_2=64,
+        DROPOUT_RATE=0.25,
+        ENCODING=32,
+        learn_rate=1e-3,
+        loss="ELBO+MMD",
+        kl_warmup_epochs=0,
+        mmd_warmup_epochs=0,
+        number_of_components=1,
     ):
         self.input_shape = input_shape
         self.CONV_filters = CONV_filters
@@ -626,7 +622,11 @@ class SEQ_2_SEQ_MMVAEP:
         self.number_of_components = number_of_components
 
         assert (
-                "ELBO" in self.loss or "MMD" in self.loss
+            self.number_of_components > 0
+        ), "The number of components must be an integer greater than zero"
+
+        assert (
+            "ELBO" in self.loss or "MMD" in self.loss
         ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"
 
     def build(self):
@@ -684,8 +684,8 @@ class SEQ_2_SEQ_MMVAEP:
         Model_D0 = DenseTranspose(
             Model_E5, activation="relu", output_dim=self.ENCODING,
         )
-        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2, )
-        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1, )
+        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
+        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
         Model_D3 = RepeatVector(self.input_shape[1])
         Model_D4 = Bidirectional(
             LSTM(
@@ -719,8 +719,12 @@ class SEQ_2_SEQ_MMVAEP:
         encoder = BatchNormalization()(encoder)
         encoder = Model_E5(encoder)
 
-        z_mean = Dense(self.ENCODING)(encoder)
-        z_log_sigma = Dense(self.ENCODING)(encoder)
+        # Categorical prior on mixture of Gaussians
+        categories = Dense(self.number_of_components, activation="softmax")
+
+        for _ in range(self.number_of_components):
+            z_mean = Dense(self.ENCODING)(encoder)
+            z_log_sigma = Dense(self.ENCODING)(encoder)
 
         # Define and control custom loss functions
         kl_warmup_callback = False
@@ -832,9 +836,7 @@ class SEQ_2_SEQ_MMVAEP:
             return self.input_shape[1:] * huber(x_, x_decoded_mean_)
 
         gmvaep.compile(
-            loss=huber_loss,
-            optimizer=Adam(lr=self.learn_rate, ),
-            metrics=["mae"],
+            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
         )
 
         return encoder, generator, gmvaep, kl_warmup_callback, mmd_warmup_callback
-- 
GitLab