{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# deepOF model evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given a dataset and a trained model, this notebook allows the user to \n", "\n", "* Load and inspect the different models (encoder, decoder, grouper, gmvaep)\n", "* Visualize reconstruction quality for a given model\n", "* Visualize a static latent space\n", "* Visualize trajectories on the latent space for a given video\n", "* sample from the latent space distributions and generate video clips showcasing generated data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.chdir(os.path.dirname(\"../\"))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import deepof.data\n", "import deepof.utils\n", "import numpy as np\n", "import pandas as pd\n", "import re\n", "import tensorflow as tf\n", "from collections import Counter\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "from sklearn.manifold import TSNE\n", "from sklearn.decomposition import PCA\n", "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", "import umap\n", "\n", "from ipywidgets import interactive, interact, HBox, Layout, VBox\n", "from IPython import display\n", "from matplotlib.animation import FuncAnimation\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "from ipywidgets import interact" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Define and run project" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "path = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof-data\", \"deepof_single_topview\")\n", "trained_network = os.path.join(\"..\", \"..\", \"Desktop\")\n", "exclude_bodyparts = tuple([\"\"])\n", "window_size = 24" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 41.7 s, sys: 3.12 s, total: 44.8 s\n", "Wall time: 37.5 s\n" ] } ], "source": [ "%%time\n", "proj = deepof.data.project(\n", " path=path, smooth_alpha=0.999, exclude_bodyparts=exclude_bodyparts, arena_dims=[380],\n", ")" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading trajectories...\n", "Smoothing trajectories...\n", "Interpolating outliers...\n", "Iterative imputation of ocluded bodyparts...\n", "Computing distances...\n", "Computing angles...\n", "Done!\n", "deepof analysis of 166 videos\n", "CPU times: user 12min 42s, sys: 25.4 s, total: 13min 8s\n", "Wall time: 3min 4s\n" ] } ], "source": [ "%%time\n", "proj = proj.run(verbose=True)\n", "print(proj)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. 
Load pretrained deepof model" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "coords = proj.get_coords(center=\"Center\", align=\"Spine_1\", align_inplace=True)\n", "data_prep = coords.preprocess(\n", "    test_videos=0, window_step=1, window_size=window_size, shuffle=False\n", ")[0]" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=3_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=2_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=9_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=3_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=2_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=9_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=10_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=10_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=8_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=1_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=5_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=7_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=1_final_weights.h5',\n", " 'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=5_final_weights.h5']" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[i for i in os.listdir(trained_network) if i.endswith(\"h5\")]" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'GMVAE_input_type=coords_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_encoding=6_k=15_latreg=none_entknn=100_run=8_final_weights.h5'" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "deepof_weights = [i for i in os.listdir(trained_network) if i.endswith(\"h5\")][8]\n", "deepof_weights" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "# Set model parameters\n", "encoding = int(re.findall(r\"encoding=(\\d+)_\", deepof_weights)[0])\n", "k = int(re.findall(r\"k=(\\d+)_\", deepof_weights)[0])\n", "loss = re.findall(r\"loss=(.+?)_\", deepof_weights)[0]\n", "pheno = 0\n", "predictor = 0" ] }, { 
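"cell_type": "markdown", "metadata": {}, "source": [ "The weights filename encodes every hyperparameter, so all fields can also be parsed in a single pass instead of one regular expression per value. The helper below is a minimal sketch that assumes the `key=value_` naming scheme visible in the listing above; note that keys containing underscores (e.g. `input_type`) are truncated to their last token." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: parse every key=value pair of a weights filename at once.\n", "# Assumes the key=value_ scheme shown in the listing above; keys that\n", "# contain underscores (e.g. input_type) keep only their last token.\n", "def parse_hyperparams(filename):\n", "    return dict(re.findall(r\"([A-Za-z]+)=([^_]+)\", filename))\n", "\n", "\n", "parse_hyperparams(deepof_weights)" ] }, { 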
"cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "(\n", " encode_to_vector,\n", " decoder,\n", " grouper,\n", " gmvaep,\n", " prior,\n", " posterior,\n", ") = deepof.models.SEQ_2_SEQ_GMVAE(\n", " loss=loss,\n", " number_of_components=k,\n", " compile_model=True,\n", " encoding=encoding,\n", " next_sequence_prediction=predictor,\n", " phenotype_prediction=pheno,\n", ").build(\n", " data_prep.shape\n", ")\n", "\n", "gmvaep.load_weights(os.path.join(trained_network, deepof_weights))" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Uncomment to see model summaries\n", "# encoder.summary()\n", "# decoder.summary()\n", "# grouper.summary()\n", "# gmvaep.summary()" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "# Uncomment to plot model structure\n", "def plot_model(model, name):\n", " tf.keras.utils.plot_model(\n", " model,\n", " to_file=os.path.join(\n", " path,\n", " \"deepof_{}_{}.png\".format(name, datetime.now().strftime(\"%Y%m%d-%H%M%S\")),\n", " ),\n", " show_shapes=True,\n", " show_dtype=False,\n", " show_layer_names=True,\n", " rankdir=\"TB\",\n", " expand_nested=True,\n", " dpi=200,\n", " )\n", "\n", "\n", "# plot_model(encoder, \"encoder\")\n", "# plot_model(decoder, \"decoder\")\n", "# plot_model(grouper, \"grouper\")\n", "# plot_model(gmvaep, \"gmvaep\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Evaluate reconstruction (to be incorporated into deepof.evaluate)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "# Auxiliary animation functions\n", "\n", "\n", "def plot_mouse_graph(instant_x, instant_y, instant_rec_x, instant_rec_y, ax, edges):\n", " \"\"\"Generates a graph plot of the mouse\"\"\"\n", " plots = []\n", " rec_plots = []\n", " for edge in edges:\n", " (temp_plot,) = ax.plot(\n", " [float(instant_x[edge[0]]), float(instant_x[edge[1]])],\n", " [float(instant_y[edge[0]]), float(instant_y[edge[1]])],\n", " color=\"#006699\",\n", " linewidth=2.0,\n", " )\n", " (temp_rec_plot,) = ax.plot(\n", " [float(instant_rec_x[edge[0]]), float(instant_rec_x[edge[1]])],\n", " [float(instant_rec_y[edge[0]]), float(instant_rec_y[edge[1]])],\n", " color=\"red\",\n", " linewidth=2.0,\n", " )\n", " plots.append(temp_plot)\n", " rec_plots.append(temp_rec_plot)\n", " return plots, rec_plots\n", "\n", "\n", "def update_mouse_graph(x, y, rec_x, rec_y, plots, rec_plots, edges):\n", " \"\"\"Updates the graph plot to enable animation\"\"\"\n", "\n", " for plot, edge in zip(plots, edges):\n", " plot.set_data(\n", " [float(x[edge[0]]), float(x[edge[1]])],\n", " [float(y[edge[0]]), float(y[edge[1]])],\n", " )\n", " for plot, edge in zip(rec_plots, edges):\n", " plot.set_data(\n", " [float(rec_x[edge[0]]), float(rec_x[edge[1]])],\n", " [float(rec_y[edge[0]]), float(rec_y[edge[1]])],\n", " )" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test 57_s41\n" ] }, { "data": { "text/html": [ "