diff --git a/deepof/train_model.py b/deepof/train_model.py index 613d9574987df8c2f22e9922ca12b10cf2453c0b..0b233b97b57eea2a1c439cd5112e8a286de0cfbd 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -293,6 +293,7 @@ coords = project_coords.get_coords( center=animal_id + undercond + "Center", align=animal_id + undercond + "Spine_1", align_inplace=True, + propagate_labels=(pheno_class > 0) ) distances = project_coords.get_distances() angles = project_coords.get_angles() diff --git a/deepof_experiments.smk b/deepof_experiments.smk index 6bff17828b3cd58c910df742c0463addbc3fafcf..e16646885f8e0ef3bce2cd21fec3709adc22a3ab 100644 --- a/deepof_experiments.smk +++ b/deepof_experiments.smk @@ -15,8 +15,9 @@ import os outpath = "/u/lucasmir/DLC/DLC_autoencoders/DeepOF/deepof/logs/" losses = ["ELBO"]#, "MMD", "ELBO+MMD"] -encodings = [2, 4, 6, 8, 10, 12, 14, 16] -cluster_numbers = [1, 5, 10, 15, 20] +encodings = [4, 6, 8]#[2, 4, 6, 8, 10, 12, 14, 16] +cluster_numbers = [10, 15]#[1, 5, 10, 15, 20] +pheno_weights = [0.01, 0.1, 0.25, 0.5, 1, 2, 4, 10, 100]  # values passed as --phenotype-classifier rule deepof_experiments: input: @@ -27,6 +28,14 @@ rule deepof_experiments: encs=encodings, k=cluster_numbers, ), + expand( + "/u/lucasmir/DLC/DLC_autoencoders/DeepOF/deepof/logs/pheno_classification_experiments/trained_weights/" + "GMVAE_loss={loss}_encoding={encs}_k={k}_pheno={phenos}_run_1_final_weights.h5", + loss=losses, + encs=encodings, + k=cluster_numbers, + phenos=pheno_weights,  # keyword must match the {phenos} wildcard in the target path + ) # rule coarse_hyperparameter_tuning: @@ -71,6 +80,36 @@ rule explore_encoding_dimension_and_loss_function: "--components {wildcards.k} " "--input-type coords " "--predictor 0 " + "--phenotype-classifier 0 " + "--variational True " + "--loss {wildcards.loss} " + "--kl-warmup 20 " + "--mmd-warmup 20 " + "--montecarlo-kl 10 " + "--encoding-size {wildcards.encs} " + "--batch-size 256 " + "--window-size 11 " + "--window-step 11 " + "--stability-check 3 " + "--output-path {outpath}dimension_and_loss_experiments" + + +rule 
explore_phenotype_classification: + input: + data_path="/u/lucasmir/DLC/DLC_models/deepof_single_topview/", + output: + trained_models=os.path.join( + outpath, + "pheno_classification_experiments/trained_weights/GMVAE_loss={loss}_encoding={encs}_k={k}_pheno={phenos}_run_1_final_weights.h5", + ), + shell: + "pipenv run python -m deepof.train_model " + "--train-path {input.data_path} " + "--val-num 5 " + "--components {wildcards.k} " + "--input-type coords " + "--predictor 0 " + "--phenotype-classifier {wildcards.phenos} " "--variational True " "--loss {wildcards.loss} " "--kl-warmup 20 " diff --git a/examples/main.ipynb b/examples/main.ipynb index 0579d8bbecf2e4bd9a3d5a199e4720e3e02754f2..3d1eded62f8ebf2b7db51b054c5b02b54140a213 100644 --- a/examples/main.ipynb +++ b/examples/main.ipynb @@ -48,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 91, "metadata": {}, "outputs": [], "source": [ @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 92, "metadata": {}, "outputs": [], "source": [ @@ -71,8 +71,8 @@ "dset11 = pd.read_excel(dset11, \"Tabelle2\")\n", "dset12 = pd.read_excel(dset12, \"Tabelle2\")\n", "\n", - "dset11.Test = dset11.Test.apply(lambda x: \"Test {}_s1.1\".format(x))\n", - "dset12.Test = dset12.Test.apply(lambda x: \"Test {}_s1.2\".format(x))\n", + "dset11.Test = dset11.Test.apply(lambda x: \"Test {}_s11\".format(x))\n", + "dset12.Test = dset12.Test.apply(lambda x: \"Test {}_s12\".format(x))\n", "\n", "dset1 = {\"CSDS\":list(dset11.loc[dset11.Treatment.isin([\"CTR+CSDS\",\"NatCre+CSDS\"]), \"Test\"]) + \n", " list(dset12.loc[dset12.Treatment.isin([\"CTR+CSDS\",\"NatCre+CSDS\"]), \"Test\"]),\n", @@ -91,7 +91,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 93, "metadata": {}, "outputs": [], "source": [ @@ -102,8 +102,8 @@ "dset22 = pd.read_excel(\n", " 
\"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_2/Part2/2_Single/OpenFieldvideos-part2.xlsx\"\n", ")\n", - "dset21.Test = dset21.Test.apply(lambda x: \"Test {}_s2.1\".format(x))\n", - "dset22.Test = dset22.Test.apply(lambda x: \"Test {}_s2.2\".format(x))\n", + "dset21.Test = dset21.Test.apply(lambda x: \"Test {}_s21\".format(x))\n", + "dset22.Test = dset22.Test.apply(lambda x: \"Test {}_s22\".format(x))\n", "\n", "dset2 = {\"CSDS\":list(dset21.loc[dset21.Treatment == \"Stress\", \"Test\"]) + \n", " list(dset22.loc[dset22.Treatment == \"Stressed\", \"Test\"]),\n", @@ -122,20 +122,22 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 94, "metadata": {}, "outputs": [], "source": [ "# Load third batch\n", "\n", "dset31 = pd.read_excel(\n", - " \"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/1.Day2OF-SIpart1/JB05 2Female-ELS-OF-SIpart1.xlsx\"\n", + " \"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/1.Day2OF-SIpart1/JB05 2Female-ELS-OF-SIpart1.xlsx\",\n", + " sheet_name=1\n", ")\n", "dset32 = pd.read_excel(\n", - " \"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/2.Day3OF-SIpart2/JB05 2FEMALE-ELS-OF-SIpart2.xlsx\"\n", + " \"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/2.Day3OF-SIpart2/JB05 2FEMALE-ELS-OF-SIpart2.xlsx\",\n", + " sheet_name=1\n", ")\n", - "dset31.Test = dset31.Test.apply(lambda x: \"Test {}_s3.1\".format(x))\n", - "dset32.Test = dset32.Test.apply(lambda x: \"Test {}_s3.2\".format(x))\n", + "dset31.Test = dset31.Test.apply(lambda x: \"Test {}_s31\".format(x))\n", + "dset32.Test = dset32.Test.apply(lambda x: \"Test {}_s32\".format(x))\n", "\n", "dset3 = {\"CSDS\":[],\n", " \"NS\": list(dset31.loc[:, \"Test\"]) +\n", @@ -153,7 +155,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 95, "metadata": {}, "outputs": [], "source": [ @@ -164,7 +166,7 @@ "dset41 = [vid for vid in dset41 if 
\"52\" not in vid]\n", "\n", "dset4 = {\"CSDS\":[],\n", - " \"NS\": [i[:-4]+\"_s4\" for i in dset41]}\n", + " \"NS\": [i[:-4]+\"_s41\" for i in dset41]}\n", "\n", "dset4inv = {}\n", "for i in flatten(list(dset4.values())):\n", @@ -178,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 96, "metadata": {}, "outputs": [], "source": [ @@ -188,7 +190,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 97, "metadata": {}, "outputs": [ { @@ -206,6 +208,18 @@ "print(115+52)" ] }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [], + "source": [ + "# Save aggregated dataset to disk\n", + "import pickle\n", + "with open(\"../../Desktop/deepof-data/deepof_single_topview/deepof_exp_conditions.pkl\", \"wb\") as handle:\n", + " pickle.dump(aggregated_dset, handle, protocol=pickle.HIGHEST_PROTOCOL)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -215,15 +229,15 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 99, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 28.1 s, sys: 4.99 s, total: 33.1 s\n", - "Wall time: 7.5 s\n" + "CPU times: user 28 s, sys: 4.9 s, total: 32.9 s\n", + "Wall time: 7.4 s\n" ] } ], @@ -233,13 +247,13 @@ " smooth_alpha=0.99, \n", " arena_dims=[380],\n", " #exclude_bodyparts=[\"Tail_1\", \"Tail_2\", \"Tail_tip\", \"Tail_base\", \"Spine_2\"]\n", - " #exp_conditions=dset2inv\n", + " exp_conditions=aggregated_dset\n", " )" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 100, "metadata": {}, "outputs": [ { @@ -251,9 +265,9 @@ "Computing distances...\n", "Computing angles...\n", "Done!\n", - "deepof analysis of 167 videos\n", - "CPU times: user 46.6 s, sys: 5.09 s, total: 51.7 s\n", - "Wall time: 53 s\n" + "Coordinates of 167 videos across 2 conditions\n", + "CPU times: user 46.3 s, sys: 3.75 s, total: 50 s\n", + "Wall time: 52.4 s\n" ] } ], @@ -265,7 +279,7 @@ 
}, { "cell_type": "code", - "execution_count": 22, + "execution_count": 101, "metadata": {}, "outputs": [], "source": [ @@ -274,7 +288,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 102, "metadata": {}, "outputs": [ { @@ -298,13 +312,13 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 103, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9a3430d72cb64caba2c1679487fdb94b", + "model_id": "62c6f88ba7ff4ca5965fefb1c204f460", "version_major": 2, "version_minor": 0 }, @@ -339,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 113, "metadata": { "scrolled": true }, @@ -348,14 +362,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 54.3 s, sys: 475 ms, total: 54.8 s\n", - "Wall time: 54.8 s\n" + "CPU times: user 1min 2s, sys: 2.29 s, total: 1min 4s\n", + "Wall time: 1min 7s\n" ] } ], "source": [ "%%time\n", - "deepof_coords = deepof_main.get_coords(center=\"Center\", polar=False, speed=0, align=\"Spine_1\", align_inplace=True, propagate_labels=False)\n", + "deepof_coords = deepof_main.get_coords(center=\"Center\", polar=False, speed=0, align=\"Spine_1\", align_inplace=True, propagate_labels=True)\n", "#deepof_dists = deepof_main.get_distances(propagate_labels=False)\n", "#deepof_angles = deepof_main.get_angles(propagate_labels=False)" ] @@ -369,21 +383,9 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Preprocessing training set...\n", - "Loading pre-trained model...\n", - "(227686, 11, 26)\n", - "CPU times: user 3.35 s, sys: 656 ms, total: 4.01 s\n", - "Wall time: 4.07 s\n" - ] - } - ], + "outputs": [], "source": [ "%%time\n", "\n", @@ -417,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ 
-449,51 +451,18 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "<tf.Tensor: shape=(), dtype=float64, numpy=0.8289839462658529>" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "tf.reduce_mean( tf.keras.losses.mean_absolute_error(video_pred, video_input[:, 6, :]))" ] }, { "cell_type": "code", - "execution_count": 48, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c442eba8214949e19aee3eea4d99a736", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], + "outputs": [], "source": [ "# Draft: function to produce a video with the animal in motion using cv2\n", "import cv2\n",