{ "cells": [ { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# deepOF model evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given a dataset and a trained model, this notebook allows the user to:\n", "\n", "* Load and inspect the different models (encoder, decoder, grouper, gmvaep)\n", "* Visualize reconstruction quality for a given model\n", "* Visualize a static latent space\n", "* Visualize trajectories on the latent space for a given video\n", "* Sample from the latent space distributions and generate video clips showcasing generated data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# Move to the parent of the starting directory so the relative paths used below resolve.\n", "# The starting directory is remembered, so re-running this cell does not climb further up.\n", "if \"NOTEBOOK_DIR\" not in globals():\n", "    NOTEBOOK_DIR = os.getcwd()\n", "os.chdir(os.path.join(NOTEBOOK_DIR, \"..\"))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import deepof.data\n", "import deepof.models\n", "import deepof.utils\n", "import numpy as np\n", "import pandas as pd\n", "import re\n", "import tensorflow as tf\n", "from collections import Counter\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "from sklearn.manifold import TSNE\n", "from sklearn.decomposition import PCA\n", "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", "import umap\n", "\n", "from ipywidgets import interactive, interact, HBox, Layout, VBox\n", "from IPython import display\n", "from matplotlib.animation import FuncAnimation\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns" ] }, { "cell_type": "markdown", "metadata": 
{}, "source": [ "### 1. Define and run project" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Folder holding both the dataset and the trained weights.\n", "# Override it with the DEEPOF_DATA_ROOT environment variable instead of editing the notebook.\n", "data_root = os.environ.get(\"DEEPOF_DATA_ROOT\", os.path.join(\"..\", \"..\", \"Desktop\"))\n", "path = os.path.join(data_root, \"deepof-data\", \"deepof_single_topview\")\n", "trained_network = os.path.join(data_root, \"deepof_trained_weights\")\n", "exclude_bodyparts = (\"\",)\n", "arena_dims = [380]  # forwarded to deepof.data.project(arena_dims=...)\n", "window_size = 24" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 42.6 s, sys: 3.06 s, total: 45.6 s\n", "Wall time: 38.1 s\n" ] } ], "source": [ "%%time\n", "deepof_project = deepof.data.project(\n", "    path=path,\n", "    smooth_alpha=0.999,\n", "    exclude_bodyparts=exclude_bodyparts,\n", "    arena_dims=arena_dims,\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading trajectories...\n", "Smoothing trajectories...\n", "Interpolating outliers...\n", "Iterative imputation of ocluded bodyparts...\n", "Computing distances...\n", "Computing angles...\n", "Done!\n", "deepof analysis of 166 videos\n", "CPU times: user 11min 12s, sys: 12.8 s, total: 11min 25s\n", "Wall time: 2min 19s\n" ] } ], "source": [ "%%time\n", "# The project object keeps its own name, so re-running this cell does not call .run() on its own result\n", "proj = deepof_project.run(verbose=True)\n", "print(proj)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. 
Load pretrained deepof model" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Center the coordinates on \"Center\" and align them on \"Spine_1\"\n", "coords = proj.get_coords(\n", "    center=\"Center\",\n", "    align=\"Spine_1\",\n", "    align_inplace=True,\n", ")\n", "\n", "# Only the first element returned by preprocess() is kept; its shape is used to build the model below\n", "# (presumably the training windows, since test_videos=0 - TODO confirm)\n", "data_prep = coords.preprocess(\n", "    test_videos=0,\n", "    window_step=1,\n", "    window_size=window_size,\n", "    shuffle=False,\n", ")[0]" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=3_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=2_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=8_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=1_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=9_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=5_final_weights.h5',\n", 
'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=6_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=10_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=7_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=3_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=2_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=8_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=9_final_weights.h5',\n", " 
'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=4_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=6_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=7_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=3_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=10_final_weights.h5',\n", " 
'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=10_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=8_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=1_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=5_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=4_final_weights.h5',\n", " 
'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=7_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=2_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=9_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=1_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=5_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=4_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n", " 
'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=6_final_weights.h5']" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[i for i in os.listdir(trained_network) if i.endswith(\"h5\")]" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=2_final_weights.h5'" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# os.listdir() returns files in arbitrary order, so select the weights by name rather than by list position\n", "deepof_weights = \"GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=2_final_weights.h5\"\n", "assert os.path.isfile(os.path.join(trained_network, deepof_weights)), deepof_weights\n", "deepof_weights" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "# Set model parameters, parsed from the key=value pairs in the weights file name.\n", "# Raw strings keep the regex escapes intact (\"\\d\" is an invalid escape in a normal string).\n", "encoding = int(re.findall(r\"encoding=(\\d+)_\", deepof_weights)[0])\n", "k = int(re.findall(r\"_k=(\\d+)_\", deepof_weights)[0])\n", "loss = re.findall(r\"loss=(.+?)_\", deepof_weights)[0]\n", "NextSeqPred = float(re.findall(r\"NextSeqPred=(.+?)_\", deepof_weights)[0])\n", "PhenoPred = float(re.findall(r\"PhenoPred=(.+?)_\", deepof_weights)[0])\n", "RuleBasedPred = float(re.findall(r\"RuleBasedPred=(.+?)_\", deepof_weights)[0])" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "# Build the model for the shape of the preprocessed data, then load the selected weights into it\n", "(\n", "    encode_to_vector,\n", "    decoder,\n", "    grouper,\n", "    gmvaep,\n", "    prior,\n", "    posterior,\n", ") = deepof.models.SEQ_2_SEQ_GMVAE(\n", "    loss=loss,\n", "    number_of_components=k,\n", "    compile_model=True,\n", "    encoding=encoding,\n", "    next_sequence_prediction=NextSeqPred,\n", "    phenotype_prediction=PhenoPred,\n", "    rule_based_prediction=RuleBasedPred,\n", ").build(data_prep.shape)\n", "\n", 
"gmvaep.load_weights(os.path.join(trained_network, deepof_weights))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Set to True to print the model summaries\n", "SHOW_MODEL_SUMMARIES = False\n", "\n", "if SHOW_MODEL_SUMMARIES:\n", "    for summarized_model in (encode_to_vector, decoder, grouper, gmvaep):\n", "        summarized_model.summary()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from datetime import datetime\n", "\n", "\n", "def plot_model(model, name):\n", "    \"\"\"Save a diagram of `model` as a timestamped PNG inside the folder `path`.\n", "\n", "    Args:\n", "        model: model passed to tf.keras.utils.plot_model.\n", "        name: label inserted in the file name (deepof_{name}_{YYYYmmdd-HHMMSS}.png).\n", "    \"\"\"\n", "    tf.keras.utils.plot_model(\n", "        model,\n", "        to_file=os.path.join(\n", "            path,\n", "            \"deepof_{}_{}.png\".format(name, datetime.now().strftime(\"%Y%m%d-%H%M%S\")),\n", "        ),\n", "        show_shapes=True,\n", "        show_dtype=False,\n", "        show_layer_names=True,\n", "        rankdir=\"TB\",\n", "        expand_nested=True,\n", "        dpi=200,\n", "    )\n", "\n", "\n", "# Set to True to write the model structure diagrams to disk\n", "PLOT_MODEL_STRUCTURE = False\n", "\n", "if PLOT_MODEL_STRUCTURE:\n", "    for plotted_name, plotted_model in [\n", "        (\"encoder\", encode_to_vector),\n", "        (\"decoder\", decoder),\n", "        (\"grouper\", grouper),\n", "        (\"gmvaep\", gmvaep),\n", "    ]:\n", "        plot_model(plotted_model, plotted_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. 
Evaluate reconstruction (to be incorporated into deepof.evaluate)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# Auxiliary animation functions\n", "\n", "\n", "def _edge_segment(xs, ys, edge):\n", "    \"\"\"Return the ([x0, x1], [y0, y1]) coordinates of the segment joining the two nodes of `edge`.\"\"\"\n", "    start, end = edge[0], edge[1]\n", "    return [float(xs[start]), float(xs[end])], [float(ys[start]), float(ys[end])]\n", "\n", "\n", "def plot_mouse_graph(instant_x, instant_y, instant_rec_x, instant_rec_y, ax, edges):\n", "    \"\"\"Draw two overlaid skeletons on `ax`, one line per edge.\n", "\n", "    The skeleton given by instant_x/instant_y is drawn in #006699 and the one given by\n", "    instant_rec_x/instant_rec_y in red; for every edge the first is drawn before the second.\n", "\n", "    Returns:\n", "        (plots, rec_plots): lists of the line objects returned by ax.plot, ordered like `edges`.\n", "    \"\"\"\n", "    plots, rec_plots = [], []\n", "    for edge in edges:\n", "        (line,) = ax.plot(\n", "            *_edge_segment(instant_x, instant_y, edge), color=\"#006699\", linewidth=2.0\n", "        )\n", "        (rec_line,) = ax.plot(\n", "            *_edge_segment(instant_rec_x, instant_rec_y, edge), color=\"red\", linewidth=2.0\n", "        )\n", "        plots.append(line)\n", "        rec_plots.append(rec_line)\n", "    return plots, rec_plots\n", "\n", "\n", "def update_mouse_graph(x, y, rec_x, rec_y, plots, rec_plots, edges):\n", "    \"\"\"Move the lines created by plot_mouse_graph to new coordinates (one animation frame).\n", "\n", "    Line objects are paired with `edges` by position.\n", "    \"\"\"\n", "    for line, edge in zip(plots, edges):\n", "        line.set_data(*_edge_segment(x, y, edge))\n", "    for line, edge in zip(rec_plots, edges):\n", "        line.set_data(*_edge_segment(rec_x, rec_y, edge))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test 43_s12\n" ] }, { "data": { "text/html": [ "