diff --git a/main.ipynb b/main.ipynb
index 9197697d6ebdfc603af81226cb8fff70fd038a91..687ad9aeed2228a4a547f89e27ed6f6e49f4912a 100644
--- a/main.ipynb
+++ b/main.ipynb
@@ -371,7 +371,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# vaep.summary()"
+    "vaep.summary()"
    ]
   },
   {
diff --git a/source/model_utils.py b/source/model_utils.py
index 93a23534feccf84ecd46862e4683419caa02f673..5dfd7058c7a94323cee086e71f8b870fe18e24cc 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -109,13 +109,37 @@ class UncorrelatedFeaturesConstraint(Constraint):
         return self.weightage * self.uncorrelated_feature(x)
 
 
-class GaussianMixtureLayer(Layer):
-    def __init(self, *args, **kwargs):
-        self.is_placeholder = True
-        super(GaussianMixtureLayer, self).__init__(*args, **kwargs)
+class MultivariateNormalDiag(tfpl.DistributionLambda):
+    def __init__(
+        self,
+        event_size,
+        convert_to_tensor_fn=tfd.Distribution.sample,
+        validate_args=False,
+        **kwargs
+    ):
+        super(MultivariateNormalDiag, self).__init__(
+            lambda t: MultivariateNormalDiag.new(t, event_size, validate_args),
+            convert_to_tensor_fn,
+            **kwargs
+        )
 
-    def call(self, inputs, **kwargs):
-        pass
+    @staticmethod
+    def new(params, event_size, validate_args=False, name=None):
+        """Create the distribution instance from a `params` vector."""
+        with tf.name_scope(name or "MultivariateNormalDiag"):
+            params = tf.convert_to_tensor(params, name="params")
+            # Softplus plus a small shift keeps scale_diag strictly positive
+            return tfd.MultivariateNormalDiag(
+                loc=params[..., :event_size],
+                scale_diag=tf.math.softplus(params[..., event_size:]) + 1e-5,
+                validate_args=validate_args,
+            )
+
+    @staticmethod
+    def params_size(event_size, name=None):
+        """The number of `params` needed to create a single distribution."""
+        with tf.name_scope(name or "MultivariateNormalDiag_params_size"):
+            return 2 * event_size
 
 
 class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
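
Aside (not part of the patch): a minimal sketch of how the new MultivariateNormalDiag head is wired, mirroring the encoder changes in models.py below. The input shape and ENCODING value are illustrative assumptions only.

import tensorflow as tf
from source.model_utils import MultivariateNormalDiag

# Toy encoder head: a Dense layer emits 2 * ENCODING parameters
# (ENCODING means followed by ENCODING raw scales), and the
# distribution layer turns them into a diagonal-covariance Gaussian.
ENCODING = 6
inputs = tf.keras.Input(shape=(32,))
params = tf.keras.layers.Dense(
    MultivariateNormalDiag.params_size(ENCODING), activation=None
)(inputs)
z = MultivariateNormalDiag(ENCODING)(params)  # behaves as a tensor, is a tfd.Distribution

The practical gain over tfpl.MultivariateNormalTriL is head size: the diagonal posterior needs 2 * ENCODING parameters, while the full-covariance TriL head needs ENCODING + ENCODING * (ENCODING + 1) / 2 (12 vs. 27 for ENCODING = 6).
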
diff --git a/source/models.py b/source/models.py
index a2bda98228b422ce62c7dc6a64db57454f8e50d1..691098cb5e58c7361193657d2a76d47b959d1a8d 100644
--- a/source/models.py
+++ b/source/models.py
@@ -282,7 +282,7 @@ class SEQ_2_SEQ_VAE:
         encoder = Model_E5(encoder)
 
         encoder = Dense(
-            tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
+            MultivariateNormalDiag.params_size(self.ENCODING), activation=None
         )(encoder)
 
         # Define and control custom loss functions
@@ -299,7 +299,7 @@ class SEQ_2_SEQ_VAE:
                     )
                 )
 
-        z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
+        z = MultivariateNormalDiag(self.ENCODING)(encoder)
 
         if "ELBO" in self.loss:
             z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
@@ -495,7 +495,7 @@ class SEQ_2_SEQ_VAEP:
         encoder = Model_E5(encoder)
 
         encoder = Dense(
-            tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
+            MultivariateNormalDiag.params_size(self.ENCODING), activation=None
         )(encoder)
 
         # Define and control custom loss functions
@@ -511,7 +511,7 @@ class SEQ_2_SEQ_VAEP:
                     )
                 )
 
-        z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
+        z = MultivariateNormalDiag(self.ENCODING)(encoder)
 
         if "ELBO" in self.loss:
             z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
@@ -585,7 +585,7 @@ class SEQ_2_SEQ_VAEP:
         # end-to-end autoencoder
         encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
         vaep = Model(
-            inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAE"
+            inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAEP"
         )
 
         # Build generator as a separate entity
@@ -883,7 +883,6 @@ class SEQ_2_SEQ_MMVAEP:
 
 
 # TODO:
-#       - Try sample, mean and mode for MMDiscrepancyLayer
 #       - Gaussian Mixture + Categorical priors -> Deep Clustering
 #           - prior of equal gaussians
 #           - prior of equal gaussians + gaussian noise on the means (not exactly the same init)
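
Aside (not part of the patch): the "prior of equal gaussians" items in the TODO could look roughly like the sketch below; k, dim, and the noise scale are illustrative assumptions, not values from the project.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

k, dim = 10, 6  # number of mixture components, latent dimensionality (assumed)

# Uniform categorical mixture over k identical standard Gaussians
prior = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=tf.zeros(k)),
    components_distribution=tfd.MultivariateNormalDiag(
        loc=tf.zeros([k, dim]), scale_diag=tf.ones([k, dim])
    ),
)

# Variant with Gaussian noise on the means, so components start
# near-identical but not exactly the same init
noisy_prior = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=tf.zeros(k)),
    components_distribution=tfd.MultivariateNormalDiag(
        loc=tf.random.normal([k, dim], stddev=0.01),
        scale_diag=tf.ones([k, dim]),
    ),
)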