diff --git a/deepof/train_model.py b/deepof/train_model.py
index 613d9574987df8c2f22e9922ca12b10cf2453c0b..0b233b97b57eea2a1c439cd5112e8a286de0cfbd 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -293,6 +293,7 @@ coords = project_coords.get_coords(
     center=animal_id + undercond + "Center",
     align=animal_id + undercond + "Spine_1",
     align_inplace=True,
+    propagate_labels=(pheno_class > 0)
 )
 distances = project_coords.get_distances()
 angles = project_coords.get_angles()
diff --git a/deepof_experiments.smk b/deepof_experiments.smk
index 6bff17828b3cd58c910df742c0463addbc3fafcf..e16646885f8e0ef3bce2cd21fec3709adc22a3ab 100644
--- a/deepof_experiments.smk
+++ b/deepof_experiments.smk
@@ -15,8 +15,9 @@ import os
 
 outpath = "/u/lucasmir/DLC/DLC_autoencoders/DeepOF/deepof/logs/"
 losses = ["ELBO"]#, "MMD", "ELBO+MMD"]
-encodings = [2, 4, 6, 8, 10, 12, 14, 16]
-cluster_numbers = [1, 5, 10, 15, 20]
+encodings = [4, 6, 8]#[2, 4, 6, 8, 10, 12, 14, 16]
+cluster_numbers = [10, 15]#[1, 5, 10, 15, 20]
+pheno_weights = [0.01, 0.1, 0.25, 0.5, 1, 2, 4, 10, 100]
 
 rule deepof_experiments:
     input:
@@ -27,6 +28,14 @@ rule deepof_experiments:
             encs=encodings,
             k=cluster_numbers,
         ),
+        expand(
+            "/u/lucasmir/DLC/DLC_autoencoders/DeepOF/deepof/logs/pheno_classification_experiments/trained_weights/"
+            "GMVAE_loss={loss}_encoding={encs}_k={k}_pheno={phenos}_run_1_final_weights.h5",
+            loss=losses,
+            encs=encodings,
+            k=cluster_numbers,
+            pheno=pheno_weights,
+        )
 
 
 # rule coarse_hyperparameter_tuning:
@@ -71,6 +80,36 @@ rule explore_encoding_dimension_and_loss_function:
         "--components {wildcards.k} "
         "--input-type coords "
         "--predictor 0 "
+        "--phenotype-classifier 0"
+        "--variational True "
+        "--loss {wildcards.loss} "
+        "--kl-warmup 20 "
+        "--mmd-warmup 20 "
+        "--montecarlo-kl 10 "
+        "--encoding-size {wildcards.encs} "
+        "--batch-size 256 "
+        "--window-size 11 "
+        "--window-step 11 "
+        "--stability-check 3  "
+        "--output-path {outpath}dimension_and_loss_experiments"
+
+
+rule explore_phenotype_classification:
+    input:
+        data_path="/u/lucasmir/DLC/DLC_models/deepof_single_topview/",
+    output:
+        trained_models=os.path.join(
+            outpath,
+            "pheno_classification_experiments/trained_weights/GMVAE_loss={loss}_encoding={encs}_k={k}_pheno={phenos}_run_1_final_weights.h5",
+        ),
+    shell:
+        "pipenv run python -m deepof.train_model "
+        "--train-path {input.data_path} "
+        "--val-num 5 "
+        "--components {wildcards.k} "
+        "--input-type coords "
+        "--predictor 0 "
+        "--phenotype-classifier {wildcards.phenos}"
         "--variational True "
         "--loss {wildcards.loss} "
         "--kl-warmup 20 "
diff --git a/examples/main.ipynb b/examples/main.ipynb
index 0579d8bbecf2e4bd9a3d5a199e4720e3e02754f2..3d1eded62f8ebf2b7db51b054c5b02b54140a213 100644
--- a/examples/main.ipynb
+++ b/examples/main.ipynb
@@ -48,7 +48,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 91,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -57,7 +57,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 92,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -71,8 +71,8 @@
     "dset11 = pd.read_excel(dset11, \"Tabelle2\")\n",
     "dset12 = pd.read_excel(dset12, \"Tabelle2\")\n",
     "\n",
-    "dset11.Test = dset11.Test.apply(lambda x: \"Test {}_s1.1\".format(x))\n",
-    "dset12.Test = dset12.Test.apply(lambda x: \"Test {}_s1.2\".format(x))\n",
+    "dset11.Test = dset11.Test.apply(lambda x: \"Test {}_s11\".format(x))\n",
+    "dset12.Test = dset12.Test.apply(lambda x: \"Test {}_s12\".format(x))\n",
     "\n",
     "dset1 = {\"CSDS\":list(dset11.loc[dset11.Treatment.isin([\"CTR+CSDS\",\"NatCre+CSDS\"]), \"Test\"]) + \n",
     "                list(dset12.loc[dset12.Treatment.isin([\"CTR+CSDS\",\"NatCre+CSDS\"]), \"Test\"]),\n",
@@ -91,7 +91,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 93,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -102,8 +102,8 @@
     "dset22 = pd.read_excel(\n",
     "    \"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_2/Part2/2_Single/OpenFieldvideos-part2.xlsx\"\n",
     ")\n",
-    "dset21.Test = dset21.Test.apply(lambda x: \"Test {}_s2.1\".format(x))\n",
-    "dset22.Test = dset22.Test.apply(lambda x: \"Test {}_s2.2\".format(x))\n",
+    "dset21.Test = dset21.Test.apply(lambda x: \"Test {}_s21\".format(x))\n",
+    "dset22.Test = dset22.Test.apply(lambda x: \"Test {}_s22\".format(x))\n",
     "\n",
     "dset2 = {\"CSDS\":list(dset21.loc[dset21.Treatment == \"Stress\", \"Test\"]) + \n",
     "                list(dset22.loc[dset22.Treatment == \"Stressed\", \"Test\"]),\n",
@@ -122,20 +122,22 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 94,
    "metadata": {},
    "outputs": [],
    "source": [
     "# Load third batch\n",
     "\n",
     "dset31 = pd.read_excel(\n",
-    "    \"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/1.Day2OF-SIpart1/JB05 2Female-ELS-OF-SIpart1.xlsx\"\n",
+    "    \"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/1.Day2OF-SIpart1/JB05 2Female-ELS-OF-SIpart1.xlsx\",\n",
+    "    sheet_name=1\n",
     ")\n",
     "dset32 = pd.read_excel(\n",
-    "    \"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/2.Day3OF-SIpart2/JB05 2FEMALE-ELS-OF-SIpart2.xlsx\"\n",
+    "    \"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/2.Day3OF-SIpart2/JB05 2FEMALE-ELS-OF-SIpart2.xlsx\",\n",
+    "    sheet_name=1\n",
     ")\n",
-    "dset31.Test = dset31.Test.apply(lambda x: \"Test {}_s3.1\".format(x))\n",
-    "dset32.Test = dset32.Test.apply(lambda x: \"Test {}_s3.2\".format(x))\n",
+    "dset31.Test = dset31.Test.apply(lambda x: \"Test {}_s31\".format(x))\n",
+    "dset32.Test = dset32.Test.apply(lambda x: \"Test {}_s32\".format(x))\n",
     "\n",
     "dset3 = {\"CSDS\":[],\n",
     "         \"NS\":  list(dset31.loc[:, \"Test\"]) +\n",
@@ -153,7 +155,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 95,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -164,7 +166,7 @@
     "dset41 = [vid for vid in dset41 if \"52\" not in vid]\n",
     "\n",
     "dset4 = {\"CSDS\":[],\n",
-    "         \"NS\":  [i[:-4]+\"_s4\" for i in dset41]}\n",
+    "         \"NS\":  [i[:-4]+\"_s41\" for i in dset41]}\n",
     "\n",
     "dset4inv = {}\n",
     "for i in flatten(list(dset4.values())):\n",
@@ -178,7 +180,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 96,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -188,7 +190,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 97,
    "metadata": {},
    "outputs": [
     {
@@ -206,6 +208,18 @@
     "print(115+52)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 98,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Save aggregated dataset to disk\n",
+    "import pickle\n",
+    "with open(\"../../Desktop/deepof-data/deepof_single_topview/deepof_exp_conditions.pkl\", \"wb\") as handle:\n",
+    "    pickle.dump(aggregated_dset, handle, protocol=pickle.HIGHEST_PROTOCOL)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -215,15 +229,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 99,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "CPU times: user 28.1 s, sys: 4.99 s, total: 33.1 s\n",
-      "Wall time: 7.5 s\n"
+      "CPU times: user 28 s, sys: 4.9 s, total: 32.9 s\n",
+      "Wall time: 7.4 s\n"
      ]
     }
    ],
@@ -233,13 +247,13 @@
     "                                  smooth_alpha=0.99,                                     \n",
     "                                  arena_dims=[380],\n",
     "                                  #exclude_bodyparts=[\"Tail_1\", \"Tail_2\", \"Tail_tip\", \"Tail_base\", \"Spine_2\"]\n",
-    "                                  #exp_conditions=dset2inv\n",
+    "                                  exp_conditions=aggregated_dset\n",
     "                                 )"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 100,
    "metadata": {},
    "outputs": [
     {
@@ -251,9 +265,9 @@
       "Computing distances...\n",
       "Computing angles...\n",
       "Done!\n",
-      "deepof analysis of 167 videos\n",
-      "CPU times: user 46.6 s, sys: 5.09 s, total: 51.7 s\n",
-      "Wall time: 53 s\n"
+      "Coordinates of 167 videos across 2 conditions\n",
+      "CPU times: user 46.3 s, sys: 3.75 s, total: 50 s\n",
+      "Wall time: 52.4 s\n"
      ]
     }
    ],
@@ -265,7 +279,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 101,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -274,7 +288,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 102,
    "metadata": {},
    "outputs": [
     {
@@ -298,13 +312,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 103,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "9a3430d72cb64caba2c1679487fdb94b",
+       "model_id": "62c6f88ba7ff4ca5965fefb1c204f460",
        "version_major": 2,
        "version_minor": 0
       },
@@ -339,7 +353,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 113,
    "metadata": {
     "scrolled": true
    },
@@ -348,14 +362,14 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "CPU times: user 54.3 s, sys: 475 ms, total: 54.8 s\n",
-      "Wall time: 54.8 s\n"
+      "CPU times: user 1min 2s, sys: 2.29 s, total: 1min 4s\n",
+      "Wall time: 1min 7s\n"
      ]
     }
    ],
    "source": [
     "%%time\n",
-    "deepof_coords = deepof_main.get_coords(center=\"Center\", polar=False, speed=0, align=\"Spine_1\", align_inplace=True, propagate_labels=False)\n",
+    "deepof_coords = deepof_main.get_coords(center=\"Center\", polar=False, speed=0, align=\"Spine_1\", align_inplace=True, propagate_labels=True)\n",
     "#deepof_dists  = deepof_main.get_distances(propagate_labels=False)\n",
     "#deepof_angles = deepof_main.get_angles(propagate_labels=False)"
    ]
@@ -369,21 +383,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Preprocessing training set...\n",
-      "Loading pre-trained model...\n",
-      "(227686, 11, 26)\n",
-      "CPU times: user 3.35 s, sys: 656 ms, total: 4.01 s\n",
-      "Wall time: 4.07 s\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "%%time\n",
     "\n",
@@ -417,7 +419,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 46,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -449,51 +451,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 47,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "<tf.Tensor: shape=(), dtype=float64, numpy=0.8289839462658529>"
-      ]
-     },
-     "execution_count": 47,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "tf.reduce_mean( tf.keras.losses.mean_absolute_error(video_pred, video_input[:, 6, :]))"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 48,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "c442eba8214949e19aee3eea4d99a736",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "# Draft: function to produce a video with the animal in motion using cv2\n",
     "import cv2\n",