diff --git a/deepof/models.py b/deepof/models.py
index 3ce3e5b6a62a383cc010f8646050b5e6fd1cd3c8..29e73854e5aa93cccc49912ed48c58cb4388664a 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -624,7 +624,7 @@ class SEQ_2_SEQ_GMVAE:
             name="encoding_distribution",
         )([z_cat, z_gauss])
 
-        encode_to_distribution = Model(x, z, name="SEQ_2_SEQ_trained_distribution")
+        posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution")
 
         # Define and control custom loss functions
         if "ELBO" in self.loss:
@@ -679,7 +679,7 @@ class SEQ_2_SEQ_GMVAE:
         )(generator)
 
         # define individual branches as models
-        encode_to_vector = Model(x, z, name="SEQ_2_SEQ_VEncoder")
+        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
         generator = Model(g, x_decoded_mean, name="vae_reconstruction")
 
         def log_loss(x_true, p_x_q_given_z):
@@ -687,7 +687,7 @@ class SEQ_2_SEQ_GMVAE:
             the output distribution"""
             return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))
 
-        model_outs = [generator(encode_to_vector.outputs)]
+        model_outs = [generator(encoder.outputs)]
         model_losses = [log_loss]
         model_metrics = {"vae_reconstruction": ["mae", "mse"]}
         loss_weights = [1.0]
@@ -736,9 +736,9 @@ class SEQ_2_SEQ_GMVAE:
             loss_weights.append(self.phenotype_prediction)
 
         # define grouper and end-to-end autoencoder model
-        grouper = Model(encode_to_vector.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
+        grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
         gmvaep = Model(
-            inputs=encode_to_vector.inputs,
+            inputs=encoder.inputs,
             outputs=model_outs,
             name="SEQ_2_SEQ_GMVAE",
         )
@@ -754,11 +754,12 @@ class SEQ_2_SEQ_GMVAE:
         gmvaep.build(input_shape)
 
         return (
-            encode_to_vector,
-            encode_to_distribution,
+            encoder,
             generator,
             grouper,
             gmvaep,
+            self.prior,
+            posterior,
         )
 
     @prior.setter
diff --git a/supplementary_notebooks/deepof_model_evaluation.ipynb b/supplementary_notebooks/deepof_model_evaluation.ipynb
index de62ca5557c95aa9440697eb9cb3304a7a28dea3..8e1bbe4075a4b6afa2a5190d43aab72a0b75de46 100644
--- a/supplementary_notebooks/deepof_model_evaluation.ipynb
+++ b/supplementary_notebooks/deepof_model_evaluation.ipynb
@@ -15219,165 +15219,24 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Retrieve latent distribution parameters and sample from posterior\n",
-    "def get_median_params(component, categories, cluster, param):\n",
-    "    # means = [np.median(component.mean().numpy(), axis=0) for component in mix_components]\n",
-    "    # stddevs = [np.median(component.stddev().numpy(), axis=0) for component in mix_components]\n",
-    "    if param == \"mean\":\n",
-    "        component = component.mean().numpy()\n",
-    "    elif param == \"stddev\":\n",
-    "        component = component.stddev().numpy()\n",
-    "\n",
-    "    cluster_select = np.argmax(categories, axis=1) == cluster\n",
-    "    if np.sum(cluster_select) == 0:\n",
-    "        return None\n",
-    "    component = component[cluster_select]\n",
-    "    return np.median(component, axis=0)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 19,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def retrieve_latent_parameters(\n",
-    "    distribution, reduce=False, plot=False, categories=None, filt=0, save=True\n",
-    "):\n",
-    "    mix_components = distribution.components\n",
-    "\n",
-    "    # The main problem is here! We need to select only those training instances in which a given cluster was selected.\n",
-    "    # Then compute the median for those only\n",
-    "\n",
-    "    means = [\n",
-    "        get_median_params(component, categories, i, \"mean\")\n",
-    "        for i, component in enumerate(mix_components)\n",
-    "    ]\n",
-    "    stddevs = [\n",
-    "        get_median_params(component, categories, i, \"stddev\")\n",
-    "        for i, component in enumerate(mix_components)\n",
-    "    ]\n",
-    "    means = [i for i in means if i is not None]\n",
-    "    stddevs = [i for i in stddevs if i is not None]\n",
-    "\n",
-    "    filts = np.max(categories, axis=0) > filt\n",
-    "    means = [i for i, j in zip(means, filts) if j]\n",
-    "    stddevs = [i for i, j in zip(stddevs, filts) if j]\n",
-    "\n",
-    "    return means, stddevs\n",
-    "\n",
-    "\n",
-    "def sample_from_posterior(\n",
-    "    decoder, parameters, component, enable_variance=False, samples=1\n",
-    "):\n",
-    "    means, stddevs = parameters\n",
-    "    sample = np.random.normal(\n",
-    "        size=[samples, len(means[component])],\n",
-    "        loc=means[component],\n",
-    "        scale=(stddevs[component] if enable_variance else 0),\n",
-    "    )\n",
-    "    reconstruction = decoder(sample).mean()\n",
-    "\n",
-    "    return reconstruction"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 105,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "samples = np.random.choice(range(data_prep.shape[0]), 10000)\n",
-    "latent_distribution = encode_to_distribution(data_prep[samples])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 101,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "means, stddevs = retrieve_latent_parameters(\n",
-    "    latent_distribution,\n",
-    "    categories=groupings[samples],\n",
-    "    reduce=False,\n",
-    "    plot=False,\n",
-    "    filt=0.,\n",
-    "    save=False,\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 104,
+   "execution_count": 153,
    "metadata": {},
    "outputs": [
     {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "fd7e515df218432b88e6fa7a6a75f9de",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "interactive(children=(Dropdown(description='dim_red', options=('PCA', 'LDA', 'umap', 'tSNE'), value='PCA'), In…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
+     "ename": "AttributeError",
+     "evalue": "'Functional' object has no attribute 'prior'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-153-0d3135aace65>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m# Get prior distribution\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mgmvaep\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprior\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;31mAttributeError\u001b[0m: 'Functional' object has no attribute 'prior'"
+     ]
     }
    ],
    "source": [
-    "# Plot sampled data in the latent space\n",
-    "\n",
-    "\n",
-    "@interact(dim_red=[\"PCA\", \"LDA\", \"umap\", \"tSNE\"], samples=(5000, 15000))\n",
-    "def plot_static_latent_space(dim_red, samples):\n",
-    "    if dim_red == \"umap\":\n",
-    "        reducer = umap.UMAP(n_components=2)\n",
-    "    elif dim_red == \"LDA\":\n",
-    "        reducer = LinearDiscriminantAnalysis(n_components=2)\n",
-    "    elif dim_red == \"PCA\":\n",
-    "        reducer = PCA(n_components=2)\n",
-    "    else:\n",
-    "        reducer = TSNE(n_components=2)\n",
-    "\n",
-    "    categories = latent_distribution.cat.sample(samples).numpy().flatten()\n",
-    "    mixture_sample = np.squeeze(\n",
-    "        np.concatenate(\n",
-    "            [latent_distribution.components[i].sample(1) for i in categories]\n",
-    "        )\n",
-    "    )\n",
-    "\n",
-    "    print(mixture_sample.shape)\n",
-    "\n",
-    "    if dim_red != \"LDA\":\n",
-    "        enc = reducer.fit_transform(mixture_sample)\n",
-    "    else:\n",
-    "        enc = reducer.fit_transform(\n",
-    "            mixture_sample,\n",
-    "            np.repeat(range(len(latent_distribution.components)), categories),\n",
-    "        )\n",
-    "\n",
-    "    plt.figure(figsize=(12, 8))\n",
-    "\n",
-    "    sns.scatterplot(enc[:, 0], enc[:, 1], hue=categories, palette=\"muted\")\n",
-    "    plt.title(\n",
-    "        \"Mean representation of latent space - K={}/{} - L={}\".format(\n",
-    "            len(means),\n",
-    "            len(latent_distribution.components),\n",
-    "            len(latent_distribution.components),\n",
-    "        )\n",
-    "    )\n",
-    "\n",
-    "    plt.xlabel(\"{} 1\".format(dim_red))\n",
-    "    plt.ylabel(\"{} 2\".format(dim_red))\n",
-    "    plt.suptitle(\"Static view of trained latent space\")\n",
-    "    plt.show()"
+    "# Get prior distribution\n",
+    "gmvaep.prior"
    ]
   },
   {