From b3941ccfc9cbe6bc688c21c405cebdd35fc4d639 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Wed, 10 Jun 2020 10:52:50 +0200
Subject: [PATCH] Replaced the full covariance model with a diagonal model in
 all variational implementations in models.py

---
 main.ipynb            |  2 +-
 source/model_utils.py | 36 ++++++++++++++++++++++++++++++------
 source/models.py      | 11 +++++------
 3 files changed, 36 insertions(+), 13 deletions(-)

diff --git a/main.ipynb b/main.ipynb
index 9197697d..687ad9ae 100644
--- a/main.ipynb
+++ b/main.ipynb
@@ -371,7 +371,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# vaep.summary()"
+    "vaep.summary()"
    ]
   },
   {
diff --git a/source/model_utils.py b/source/model_utils.py
index 93a23534..5dfd7058 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -109,13 +109,37 @@ class UncorrelatedFeaturesConstraint(Constraint):
         return self.weightage * self.uncorrelated_feature(x)
 
 
-class GaussianMixtureLayer(Layer):
-    def __init(self, *args, **kwargs):
-        self.is_placeholder = True
-        super(GaussianMixtureLayer, self).__init__(*args, **kwargs)
+class MultivariateNormalDiag(tfpl.DistributionLambda):
+    def __init__(
+        self,
+        event_size,
+        convert_to_tensor_fn=tfd.Distribution.sample,
+        validate_args=False,
+        **kwargs
+    ):
+
+        super(MultivariateNormalDiag, self).__init__(
+            lambda t: MultivariateNormalDiag.new(t, event_size, validate_args),
+            convert_to_tensor_fn,
+            **kwargs
+        )
 
-    def call(self, inputs, **kwargs):
-        pass
+    @staticmethod
+    def new(params, event_size, validate_args=False, name=None):
+        """Create the distribution instance from a `params` vector."""
+        with tf.name_scope(name or "MultivariateNormalDiag"):
+            params = tf.convert_to_tensor(params, name="params")
+        return tfd.mvn_diag.MultivariateNormalDiag(
+            loc=params[..., :event_size],
+            scale_diag=params[..., event_size:],
+            validate_args=validate_args,
+        )
+
+    @staticmethod
+    def params_size(event_size, name=None):
+        """The number of `params` needed to create a single distribution."""
+        with tf.name_scope(name or "MultivariateNormalDiag_params_size"):
+            return 2 * event_size
 
 
 class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
diff --git a/source/models.py b/source/models.py
index a2bda982..691098cb 100644
--- a/source/models.py
+++ b/source/models.py
@@ -282,7 +282,7 @@ class SEQ_2_SEQ_VAE:
         encoder = Model_E5(encoder)
 
         encoder = Dense(
-            tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
+            MultivariateNormalDiag.params_size(self.ENCODING), activation=None
         )(encoder)
 
         # Define and control custom loss functions
@@ -299,7 +299,7 @@ class SEQ_2_SEQ_VAE:
                     )
                 )
 
-        z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
+        z = MultivariateNormalDiag(self.ENCODING)(encoder)
 
         if "ELBO" in self.loss:
             z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
@@ -495,7 +495,7 @@ class SEQ_2_SEQ_VAEP:
         encoder = Model_E5(encoder)
 
         encoder = Dense(
-            tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
+            MultivariateNormalDiag.params_size(self.ENCODING), activation=None
         )(encoder)
 
         # Define and control custom loss functions
@@ -511,7 +511,7 @@ class SEQ_2_SEQ_VAEP:
                     )
                 )
 
-        z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
+        z = MultivariateNormalDiag(self.ENCODING)(encoder)
 
         if "ELBO" in self.loss:
             z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
@@ -585,7 +585,7 @@ class SEQ_2_SEQ_VAEP:
         # end-to-end autoencoder
         encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
         vaep = Model(
-            inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAE"
+            inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAEP"
         )
 
         # Build generator as a separate entity
@@ -883,7 +883,6 @@ class SEQ_2_SEQ_MMVAEP:
 
 
 # TODO:
-#       - Try sample, mean and mode for MMDiscrepancyLayer
 #       - Gaussian Mixture + Categorical priors -> Deep Clustering
 #           - prior of equal gaussians
 #           - prior of equal gaussians + gaussian noise on the means (not exactly the same init)
-- 
GitLab