{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# deepOF model evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given a dataset and a trained model, this notebook allows the user to \n", "\n", "* Load and inspect the different models (encoder, decoder, grouper, gmvaep)\n", "* Visualize reconstruction quality for a given model\n", "* Visualize a static latent space\n", "* Visualize trajectories on the latent space for a given video\n", "* sample from the latent space distributions and generate video clips showcasing generated data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.chdir(os.path.dirname(\"../\"))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import deepof.data\n", "import deepof.utils\n", "import numpy as np\n", "import pandas as pd\n", "import re\n", "import tensorflow as tf\n", "from collections import Counter\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "from sklearn.manifold import TSNE\n", "from sklearn.decomposition import PCA\n", "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", "import umap\n", "\n", "from ipywidgets import interactive, interact, HBox, Layout, VBox\n", "from IPython import display\n", "from matplotlib.animation import FuncAnimation\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "from ipywidgets import interact" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Define and run project" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "path = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof-data\", \"deepof_single_topview\")\n", "trained_network = os.path.join(\"..\", \"..\", \"Desktop\")\n", "exclude_bodyparts = tuple([\"\"])\n", "window_size = 24" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 42.2 s, sys: 3.13 s, total: 45.4 s\n", "Wall time: 37.9 s\n" ] } ], "source": [ "%%time\n", "proj = deepof.data.project(\n", " path=path, smooth_alpha=0.999, exclude_bodyparts=exclude_bodyparts, arena_dims=[380],\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading trajectories...\n", "Smoothing trajectories...\n", "Interpolating outliers...\n", "Iterative imputation of ocluded bodyparts...\n", "Computing distances...\n", "Computing angles...\n", "Done!\n", "deepof analysis of 166 videos\n", "CPU times: user 12min 28s, sys: 13.5 s, total: 12min 41s\n", "Wall time: 2min 31s\n" ] } ], "source": [ "%%time\n", "proj = proj.run(verbose=True)\n", "print(proj)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Load pretrained deepof model" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "coords = proj.get_coords(center=\"Center\", align=\"Spine_1\", align_inplace=True)\n", "data_prep = coords.preprocess(test_videos=0, window_step=1, window_size=window_size, shuffle=False)[\n", " 0\n", "]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=3_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=2_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=9_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=3_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=2_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=9_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=10_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=10_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=8_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=1_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=5_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=7_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=1_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=5_final_weights.h5']" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[i for i in os.listdir(trained_network) if i.endswith(\"h5\")]" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=9_final_weights.h5'" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "deepof_weights = [i for i in os.listdir(trained_network) if i.endswith(\"h5\")][2]\n", "deepof_weights" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Set model parameters\n", "encoding = int(re.findall(\"encoding=(\\d+)_\", deepof_weights)[0])\n", "k = int(re.findall(\"k=(\\d+)_\", deepof_weights)[0])\n", "loss = re.findall(\"loss=(.+?)_\", deepof_weights)[0]\n", "NextSeqPred = float(re.findall(\"NextSeqPred=(.+?)_\", deepof_weights)[0])\n", "PhenoPred = float(re.findall(\"PhenoPred=(.+?)_\", deepof_weights)[0])\n", "RuleBasedPred = float(re.findall(\"RuleBasedPred=(.+?)_\", deepof_weights)[0])" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "(\n", " encode_to_vector,\n", " decoder,\n", " grouper,\n", " gmvaep,\n", " prior,\n", " posterior,\n", ") = deepof.models.SEQ_2_SEQ_GMVAE(\n", " loss=loss,\n", " number_of_components=k,\n", " compile_model=True,\n", " encoding=encoding,\n", " next_sequence_prediction=NextSeqPred,\n", " phenotype_prediction=PhenoPred,\n", " rule_based_prediction=RuleBasedPred,\n", ").build(\n", " data_prep.shape\n", ")\n", "\n", "gmvaep.load_weights(os.path.join(trained_network, deepof_weights))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Uncomment to see model summaries\n", "# encoder.summary()\n", "# decoder.summary()\n", "# grouper.summary()\n", "# gmvaep.summary()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "# Uncomment to plot model structure\n", "def plot_model(model, name):\n", " tf.keras.utils.plot_model(\n", " model,\n", " to_file=os.path.join(\n", " path,\n", " \"deepof_{}_{}.png\".format(name, datetime.now().strftime(\"%Y%m%d-%H%M%S\")),\n", " ),\n", " show_shapes=True,\n", " show_dtype=False,\n", " show_layer_names=True,\n", " rankdir=\"TB\",\n", " expand_nested=True,\n", " dpi=200,\n", " )\n", "\n", "\n", "# plot_model(encoder, \"encoder\")\n", "# plot_model(decoder, \"decoder\")\n", "# plot_model(grouper, \"grouper\")\n", "# plot_model(gmvaep, \"gmvaep\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Evaluate reconstruction (to be incorporated into deepof.evaluate)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# Auxiliary animation functions\n", "\n", "\n", "def plot_mouse_graph(instant_x, instant_y, instant_rec_x, instant_rec_y, ax, edges):\n", " \"\"\"Generates a graph plot of the mouse\"\"\"\n", " plots = []\n", " rec_plots = []\n", " for edge in edges:\n", " (temp_plot,) = ax.plot(\n", " [float(instant_x[edge[0]]), float(instant_x[edge[1]])],\n", " [float(instant_y[edge[0]]), float(instant_y[edge[1]])],\n", " color=\"#006699\",\n", " linewidth=2.0,\n", " )\n", " (temp_rec_plot,) = ax.plot(\n", " [float(instant_rec_x[edge[0]]), float(instant_rec_x[edge[1]])],\n", " [float(instant_rec_y[edge[0]]), float(instant_rec_y[edge[1]])],\n", " color=\"red\",\n", " linewidth=2.0,\n", " )\n", " plots.append(temp_plot)\n", " rec_plots.append(temp_rec_plot)\n", " return plots, rec_plots\n", "\n", "\n", "def update_mouse_graph(x, y, rec_x, rec_y, plots, rec_plots, edges):\n", " \"\"\"Updates the graph plot to enable animation\"\"\"\n", "\n", " for plot, edge in zip(plots, edges):\n", " plot.set_data(\n", " [float(x[edge[0]]), float(x[edge[1]])],\n", " [float(y[edge[0]]), float(y[edge[1]])],\n", " )\n", " for plot, edge in zip(rec_plots, edges):\n", " plot.set_data(\n", " [float(rec_x[edge[0]]), float(rec_x[edge[1]])],\n", " [float(rec_y[edge[0]]), float(rec_y[edge[1]])],\n", " )" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test 49_s41\n" ] }, { "data": { "text/html": [ "