diff --git a/main.ipynb b/main.ipynb
index b9df2e6a4579c56b1f50a1b09c7d9b29023361c4..d16ec63a2b21fec2cbdfa6a127b03c7a3bb43710 100644
--- a/main.ipynb
+++ b/main.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -15,7 +15,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -30,7 +30,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 3,
    "metadata": {
     "tags": [
      "parameters"
@@ -50,7 +50,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -60,7 +60,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -76,9 +76,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 6,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 2.59 s, sys: 818 ms, total: 3.41 s\n",
+      "Wall time: 1.1 s\n"
+     ]
+    }
+   ],
    "source": [
     "%%time\n",
     "DLC_social_1 = project(path=path,#Path where to find the required files\n",
@@ -106,7 +115,17 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loading trajectories...\n",
+      "Smoothing trajectories...\n",
+      "Computing distances...\n"
+     ]
+    }
+   ],
    "source": [
     "%%time\n",
     "DLC_social_1_coords = DLC_social_1.run(verbose=True)\n",
@@ -336,11 +355,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "k.backend.clear_session()\n",
     "encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,\n",
     "                                                                   loss='ELBO+MMD',\n",
     "                                                                   kl_warmup_epochs=10,\n",
     "                                                                   mmd_warmup_epochs=10).build()\n",
-    "vae.build(pttest.shape)"
+    "#vae.build(pttest.shape)"
    ]
   },
   {
@@ -400,17 +420,6 @@
     "plot_model(gmvaep, show_shapes=True)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "scrolled": false
-   },
-   "outputs": [],
-   "source": [
-    "?plot_model"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -419,7 +428,8 @@
    "source": [
     "#np.random.shuffle(pttest)\n",
     "pttrain = pttest[:-15000]\n",
-    "pttest  = pttest[-15000:]"
+    "pttest  = pttest[-15000:]\n",
+    "pttrain = pttrain[:15000]"
    ]
   },
   {
@@ -439,7 +449,7 @@
    "outputs": [],
    "source": [
     "# tf.config.experimental_run_functions_eagerly(False)\n",
-    "history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=500, batch_size=512, verbose=1,\n",
+    "history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=2, batch_size=512, verbose=1,\n",
     "                  validation_data=(pttest[:-1], pttest[:-1]),\n",
     "                  callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
    ]
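
For context, the notebook edits above boil down to the usage pattern sketched here. This is a minimal sketch only; `pttest`, `pttrain` and `tensorboard_callback` are defined in earlier cells that this diff does not touch, and the `k` alias is assumed to be tensorflow.keras.

    from tensorflow import keras as k          # assumed alias used by the notebook
    from source.models import SEQ_2_SEQ_VAE

    k.backend.clear_session()                  # drop any graph left over from previous builds

    encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(
        pttest.shape,
        loss="ELBO+MMD",
        kl_warmup_epochs=10,
        mmd_warmup_epochs=10,
    ).build()

    history = vae.fit(
        x=pttrain[:-1],
        y=pttrain[:-1],
        epochs=2,                              # short debugging run, as in the cell above
        batch_size=512,
        validation_data=(pttest[:-1], pttest[:-1]),
        callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback],
    )
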
diff --git a/source/model_utils.py b/source/model_utils.py
index 9114e067dcf230c1783b8061c1994ec47f68668f..32f3112ff5d87b5e6ed25070c1e78a2c59f564b0 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -7,20 +7,9 @@ import tensorflow as tf
 import tensorflow_probability as tfp
 
 tfd = tfp.distributions
+tfpl = tfp.layers
 
 # Helper functions
-def sampling(args, epsilon_std=1.0, number_of_components=1, categorical=None):
-    z_mean, z_log_sigma = args
-
-    if number_of_components == 1:
-        epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.0, stddev=epsilon_std)
-        return z_mean + K.exp(z_log_sigma) * epsilon
-
-    else:
-        # Implement mixture of gaussians encoding and sampling
-        pass
-
-
 def compute_kernel(x, y):
     x_size = K.shape(x)[0]
     y_size = K.shape(y)[0]
@@ -120,35 +109,20 @@ class UncorrelatedFeaturesConstraint(Constraint):
         return self.weightage * self.uncorrelated_feature(x)
 
 
-class KLDivergenceLayer(Layer):
-
-    """ Identity transform layer that adds KL divergence
-    to the final model loss.
-    """
-
-    def __init__(self, beta=1.0, *args, **kwargs):
+class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
+    def __init__(self, *args, **kwargs):
         self.is_placeholder = True
-        self.beta = beta
         super(KLDivergenceLayer, self).__init__(*args, **kwargs)
 
-    def get_config(self):
-        config = super().get_config().copy()
-        config.update({"beta": self.beta})
-        return config
-
-    def call(self, inputs, **kwargs):
-        mu, log_var = inputs
-        KL_batch = (
-            -0.5
-            * self.beta
-            * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
+    def call(self, distribution_a):
+        kl_batch = self._regularizer(distribution_a)
+        self.add_loss(kl_batch, inputs=[distribution_a])
+        self.add_metric(
+            kl_batch, aggregation="mean", name="kl_divergence",
         )
+        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")
 
-        self.add_loss(K.mean(KL_batch), inputs=inputs)
-        self.add_metric(KL_batch, aggregation="mean", name="kl_divergence")
-        self.add_metric(self.beta, aggregation="mean", name="kl_rate")
-
-        return inputs
+        return distribution_a
 
 
 class MMDiscrepancyLayer(Layer):
@@ -156,20 +130,21 @@ class MMDiscrepancyLayer(Layer):
     to the final model loss.
     """
 
-    def __init__(self, beta=1.0, *args, **kwargs):
+    def __init__(self, prior, beta=1.0, *args, **kwargs):
         self.is_placeholder = True
         self.beta = beta
+        self.prior = prior
         super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
 
     def get_config(self):
         config = super().get_config().copy()
         config.update({"beta": self.beta})
+        config.update({"prior": self.prior})
         return config
 
     def call(self, z, **kwargs):
-        true_samples = K.random_normal(K.shape(z))
+        true_samples = self.prior.sample(K.shape(z)[0])  # draw one prior sample per batch element
         mmd_batch = self.beta * compute_mmd(true_samples, z)
-
         self.add_loss(K.mean(mmd_batch), inputs=z)
         self.add_metric(mmd_batch, aggregation="mean", name="mmd")
         self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
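
For orientation, both refactored layers are meant to sit directly behind a tfpl.MultivariateNormalTriL head, which is how models.py wires them below. The following is a minimal sketch of that wiring on a toy encoder; the layer sizes, variable names and the K.variable weights are illustrative and not taken from this diff.

    import tensorflow as tf
    import tensorflow_probability as tfp
    from tensorflow.keras import Input, Model
    from tensorflow.keras import backend as K
    from tensorflow.keras.layers import Dense

    from source.model_utils import KLDivergenceLayer, MMDiscrepancyLayer

    tfd = tfp.distributions
    tfpl = tfp.layers

    ENCODING = 6                                       # illustrative latent size

    # Standard-normal prior, matching the one now built in models.py
    prior = tfd.Independent(
        tfd.Normal(loc=tf.zeros(ENCODING), scale=1),
        reinterpreted_batch_ndims=1,
    )

    kl_beta = K.variable(1.0, name="kl_beta")          # warmup-controlled weights
    mmd_beta = K.variable(1.0, name="mmd_beta")

    x = Input(shape=(32,))
    h = Dense(64, activation="relu")(x)
    h = Dense(tfpl.MultivariateNormalTriL.params_size(ENCODING), activation=None)(h)
    z = tfpl.MultivariateNormalTriL(ENCODING)(h)

    # KLDivergenceLayer now subclasses tfpl.KLDivergenceAddLoss: it takes the prior
    # (plus an optional weight), adds the KL term to the model loss, logs it as a
    # metric, and returns its input distribution unchanged.
    z = KLDivergenceLayer(prior, weight=kl_beta)(z)

    # MMDiscrepancyLayer draws its reference samples from the same prior instead of
    # calling K.random_normal directly.
    z = MMDiscrepancyLayer(prior=prior, beta=mmd_beta)(z)

    toy_encoder = Model(x, z)
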
diff --git a/source/models.py b/source/models.py
index a2c89fe1d5367a25fc12ed0b9ec4b17b86105bd6..9ae61fec31600f08cdb7d9333b24e15977f49cae 100644
--- a/source/models.py
+++ b/source/models.py
@@ -281,9 +281,6 @@ class SEQ_2_SEQ_VAE:
         encoder = BatchNormalization()(encoder)
         encoder = Model_E5(encoder)
 
-        # z_mean = Dense(self.ENCODING)(encoder)
-        # z_log_sigma = Dense(self.ENCODING)(encoder)
-
         encoder = Dense(
             tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
         )(encoder)
@@ -302,17 +299,10 @@ class SEQ_2_SEQ_VAE:
                     )
                 )
 
-            # z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])
+        z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
 
-        # z = Lambda(sampling)([z_mean, z_log_sigma])
-        z = tfpl.MultivariateNormalTriL(
-            self.ENCODING,
-            activity_regularizer=(
-                tfpl.KLDivergenceRegularizer(self.prior, weight=kl_beta)
-                if "ELBO" in self.loss
-                else None
-            ),
-        )(encoder)
+        if "ELBO" in self.loss:
+            z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
 
         mmd_warmup_callback = False
         if "MMD" in self.loss:
@@ -327,7 +317,7 @@ class SEQ_2_SEQ_VAE:
                     )
                 )
 
-            z = MMDiscrepancyLayer(beta=mmd_beta)(z)
+            z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
 
         # Define and instantiate generator
         generator = Model_D0(z)
@@ -388,6 +378,7 @@ class SEQ_2_SEQ_VAEP:
         loss="ELBO+MMD",
         kl_warmup_epochs=0,
         mmd_warmup_epochs=0,
+        prior="standard_normal",
     ):
         self.input_shape = input_shape
         self.CONV_filters = CONV_filters
@@ -399,9 +390,16 @@ class SEQ_2_SEQ_VAEP:
         self.ENCODING = ENCODING
         self.learn_rate = learn_rate
         self.loss = loss
+        self.prior = prior
         self.kl_warmup = kl_warmup_epochs
         self.mmd_warmup = mmd_warmup_epochs
 
+        if self.prior == "standard_normal":
+            self.prior = tfd.Independent(
+                tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1),
+                reinterpreted_batch_ndims=1,
+            )
+
         assert (
             "ELBO" in self.loss or "MMD" in self.loss
         ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"
@@ -496,8 +494,9 @@ class SEQ_2_SEQ_VAEP:
         encoder = BatchNormalization()(encoder)
         encoder = Model_E5(encoder)
 
-        z_mean = Dense(self.ENCODING)(encoder)
-        z_log_sigma = Dense(self.ENCODING)(encoder)
+        encoder = Dense(
+            tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
+        )(encoder)
 
         # Define and control custom loss functions
         kl_warmup_callback = False
@@ -512,9 +511,10 @@ class SEQ_2_SEQ_VAEP:
                     )
                 )
 
-            z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])
+        z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
 
-        z = Lambda(sampling)([z_mean, z_log_sigma])
+        if "ELBO" in self.loss:
+            z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
 
         mmd_warmup_callback = False
         if "MMD" in self.loss:
@@ -583,7 +583,7 @@ class SEQ_2_SEQ_VAEP:
         )(predictor)
 
         # end-to-end autoencoder
-        encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder")
+        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
         vaep = Model(
             inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAE"
         )
@@ -629,6 +629,7 @@ class SEQ_2_SEQ_MMVAEP:
         loss="ELBO+MMD",
         kl_warmup_epochs=0,
         mmd_warmup_epochs=0,
+        prior="standard_normal",
         number_of_components=1,
     ):
         self.input_shape = input_shape
@@ -641,13 +642,16 @@ class SEQ_2_SEQ_MMVAEP:
         self.ENCODING = ENCODING
         self.learn_rate = learn_rate
         self.loss = loss
+        self.prior = prior
         self.kl_warmup = kl_warmup_epochs
         self.mmd_warmup = mmd_warmup_epochs
         self.number_of_components = number_of_components
 
-        assert (
-            self.number_of_components > 0
-        ), "The number of components must be an integer greater than zero"
+        if self.prior == "standard_normal":
+            self.prior = tfd.Independent(
+                tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1),
+                reinterpreted_batch_ndims=1,
+            )
 
         assert (
             "ELBO" in self.loss or "MMD" in self.loss
@@ -743,19 +747,9 @@ class SEQ_2_SEQ_MMVAEP:
         encoder = BatchNormalization()(encoder)
         encoder = Model_E5(encoder)
 
-        # Categorical prior on mixture of Gaussians
-        categories = Dense(self.number_of_components, activation="softmax")
-
-        # Define mean and log_sigma as lists of vectors with an item per prior component
-        z_mean = []
-        z_log_sigma = []
-        for i in range(self.number_of_components):
-            z_mean.append(
-                Dense(self.ENCODING, name="{}_gaussian_mean".format(i + 1))(encoder)
-            )
-            z_log_sigma.append(
-                Dense(self.ENCODING, name="{}_gaussian_sigma".format(i + 1))(encoder)
-            )
+        encoder = Dense(
+            tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
+        )(encoder)
 
         # Define and control custom loss functions
         kl_warmup_callback = False
@@ -770,11 +764,10 @@ class SEQ_2_SEQ_MMVAEP:
                     )
                 )
 
-            z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)(
-                [z_mean[0], z_log_sigma[0]]
-            )
+        z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
 
-        z = Lambda(sampling)([z_mean, z_log_sigma])
+        if "ELBO" in self.loss:
+            z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
 
         mmd_warmup_callback = False
         if "MMD" in self.loss:
@@ -803,7 +796,7 @@ class SEQ_2_SEQ_MMVAEP:
         generator = Model_D5(generator)
         generator = Model_B5(generator)
         x_decoded_mean = TimeDistributed(
-            Dense(self.input_shape[2]), name="gmvaep_reconstruction"
+            Dense(self.input_shape[2]), name="vaep_reconstruction"
         )(generator)
 
         # Define and instantiate predictor
@@ -839,11 +832,11 @@ class SEQ_2_SEQ_MMVAEP:
         )(predictor)
         predictor = BatchNormalization()(predictor)
         x_predicted_mean = TimeDistributed(
-            Dense(self.input_shape[2]), name="gmvaep_prediction"
+            Dense(self.input_shape[2]), name="vaep_prediction"
         )(predictor)
 
         # end-to-end autoencoder
-        encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder")
+        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
         gmvaep = Model(
             inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAE"
         )
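
After these changes all three classes share the same probabilistic latent head (a Dense parameter layer, a MultivariateNormalTriL sampling layer, then the optional KLDivergenceLayer and MMDiscrepancyLayer). A usage sketch for the mixture variant, assuming its build() returns the same kind of tuple the notebook unpacks for SEQ_2_SEQ_VAE (encoder, generator, end-to-end model, and the two warmup callbacks), with `pttest` coming from the notebook's earlier cells:

    from source.models import SEQ_2_SEQ_MMVAEP

    encoder, generator, gmvaep, kl_cb, mmd_cb = SEQ_2_SEQ_MMVAEP(
        pttest.shape,
        loss="ELBO+MMD",
        kl_warmup_epochs=10,
        mmd_warmup_epochs=10,
    ).build()

    gmvaep.summary()          # two heads: vaep_reconstruction and vaep_prediction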