{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# deepOF model evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given a dataset and a trained model, this notebook allows the user to \n", "\n", "* Load and inspect the different models (encoder, decoder, grouper, gmvaep)\n", "* Visualize reconstruction quality for a given model\n", "* Visualize a static latent space\n", "* Visualize trajectories on the latent space for a given video\n", "* sample from the latent space distributions and generate video clips showcasing generated data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.chdir(os.path.dirname(\"../\"))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import deepof.data\n", "import deepof.utils\n", "import numpy as np\n", "import pandas as pd\n", "import tensorflow as tf\n", "from collections import Counter\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "from sklearn.manifold import TSNE\n", "from sklearn.decomposition import PCA\n", "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", "import umap\n", "\n", "from ipywidgets import interact\n", "from IPython import display\n", "from matplotlib.animation import FuncAnimation\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "from ipywidgets import interact" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Define and run project" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "path = os.path.join(\"..\", \"..\", \"Desktop\", \"deepoftesttemp\")\n", "trained_network = os.path.join(\"..\", \"..\", \"Desktop\")\n", "exclude_bodyparts = [\"Tail_1\", \"Tail_2\", \"Tail_tip\", \"Tail_base\"]\n", "window_size = 11" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 318 ms, sys: 27.5 ms, total: 345 ms\n", "Wall time: 320 ms\n" ] } ], "source": [ "%%time\n", "proj = deepof.data.project(\n", " path=path, smooth_alpha=0.99, exclude_bodyparts=exclude_bodyparts, arena_dims=[380],\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading trajectories...\n", "Smoothing trajectories...\n", "Interpolating outliers...\n", "Iterative imputation of ocluded bodyparts...\n", "Computing distances...\n", "Computing angles...\n", "Done!\n", "deepof analysis of 2 videos\n", "CPU times: user 2.7 s, sys: 111 ms, total: 2.81 s\n", "Wall time: 723 ms\n" ] } ], "source": [ "%%time\n", "proj = proj.run(verbose=True)\n", "print(proj)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. 
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Set model parameters\n", "encoding = 6\n", "loss = \"ELBO\"\n", "k = 25\n", "pheno = 0\n", "predictor = 0" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['GMVAE_loss=ELBO_encoding=6_k=25_latreg=none_20210312-084005_final_weights.h5',\n", " 'GMVAE_loss=ELBO_encoding=6_k=25_latreg=variance_20210312-090508_final_weights.h5',\n", " 'GMVAE_loss=ELBO_encoding=6_k=25_latreg=categorical+variance_20210312-085926_final_weights.h5',\n", " 'GMVAE_loss=ELBO_encoding=6_k=25_latreg=categorical_20210312-093339_final_weights.h5']" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[i for i in os.listdir(trained_network) if i.endswith(\"h5\")]" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "coords = proj.get_coords(center=\"Center\", align=\"Spine_1\", align_inplace=True)\n", "coords = coords.preprocess(\n", "    test_videos=0, window_step=1, window_size=window_size, shuffle=True\n", ")[0]" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "encoder, decoder, grouper, gmvaep = deepof.models.SEQ_2_SEQ_GMVAE(\n", "    loss=loss,\n", "    number_of_components=k,\n", "    compile_model=True,\n", "    encoding=encoding,\n", "    predictor=predictor,\n", "    phenotype_prediction=pheno,\n", ").build(coords.shape)[:4]\n", "\n", "# Load the first weights file found above; change the index to select\n", "# a different latent regularization variant\n", "gmvaep.load_weights(\n", "    os.path.join(\n", "        trained_network, [i for i in os.listdir(trained_network) if i.endswith(\"h5\")][0]\n", "    )\n", ")" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Uncomment to see model summaries\n", "# encoder.summary()\n", "# decoder.summary()\n", "# grouper.summary()\n", "# gmvaep.summary()" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Uncomment the calls below to plot model structure\n", "def plot_model(model, name):\n", "    tf.keras.utils.plot_model(\n", "        model,\n", "        to_file=os.path.join(\n", "            path,\n", "            \"deepof_{}_{}.png\".format(name, datetime.now().strftime(\"%Y%m%d-%H%M%S\")),\n", "        ),\n", "        show_shapes=True,\n", "        show_dtype=False,\n", "        show_layer_names=True,\n", "        rankdir=\"TB\",\n", "        expand_nested=True,\n", "        dpi=200,\n", "    )\n", "\n", "\n", "# plot_model(encoder, \"encoder\")\n", "# plot_model(decoder, \"decoder\")\n", "# plot_model(grouper, \"grouper\")\n", "# plot_model(gmvaep, \"gmvaep\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Evaluate reconstruction (to be incorporated into deepof.evaluate)" ] },
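{ "cell_type": "markdown", "metadata": {}, "source": [ "Before inspecting reconstructions visually, a quick quantitative check can be useful. The cell below is a minimal sketch, not part of the original pipeline: it runs the preprocessed windows (the `coords` array computed above) through the full autoencoder and reports the mean squared reconstruction error in the scaled coordinate space. The `reconstructions` and `mse` names are ours." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: mean squared reconstruction error over the preprocessed windows.\n", "# gmvaep reconstructs each input window, so the comparison is elementwise\n", "# in the scaled coordinate space.\n", "reconstructions = gmvaep.predict(coords)\n", "mse = np.mean((coords - reconstructions) ** 2)\n", "print(\"Mean squared reconstruction error: {:.4f}\".format(mse))" ] },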
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Auxiliary animation functions\n", "\n", "\n", "def plot_mouse_graph(instant_x, instant_y, instant_rec_x, instant_rec_y, ax, edges):\n", "    \"\"\"Generates a graph plot of the mouse\"\"\"\n", "    plots = []\n", "    rec_plots = []\n", "    for edge in edges:\n", "        (temp_plot,) = ax.plot(\n", "            [float(instant_x[edge[0]]), float(instant_x[edge[1]])],\n", "            [float(instant_y[edge[0]]), float(instant_y[edge[1]])],\n", "            color=\"#006699\",\n", "            linewidth=2.0,\n", "        )\n", "        (temp_rec_plot,) = ax.plot(\n", "            [float(instant_rec_x[edge[0]]), float(instant_rec_x[edge[1]])],\n", "            [float(instant_rec_y[edge[0]]), float(instant_rec_y[edge[1]])],\n", "            color=\"red\",\n", "            linewidth=2.0,\n", "        )\n", "        plots.append(temp_plot)\n", "        rec_plots.append(temp_rec_plot)\n", "    return plots, rec_plots\n", "\n", "\n", "def update_mouse_graph(x, y, rec_x, rec_y, plots, rec_plots, edges):\n", "    \"\"\"Updates the graph plot to enable animation\"\"\"\n", "\n", "    for plot, edge in zip(plots, edges):\n", "        plot.set_data(\n", "            [float(x[edge[0]]), float(x[edge[1]])],\n", "            [float(y[edge[0]]), float(y[edge[1]])],\n", "        )\n", "    for plot, edge in zip(rec_plots, edges):\n", "        plot.set_data(\n", "            [float(rec_x[edge[0]]), float(rec_x[edge[1]])],\n", "            [float(rec_y[edge[0]]), float(rec_y[edge[1]])],\n", "        )" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test 1_s11\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Display a video with the original data superimposed with the reconstructions\n", "\n", "coords = proj.get_coords(center=\"Center\", align=\"Spine_1\", align_inplace=True)\n", "random_exp = np.random.choice(list(coords.keys()), 1)[0]\n", "print(random_exp)\n", "\n", "\n", "def animate_mice_across_time(random_exp):\n", "\n", "    # Define canvas\n", "    fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n", "\n", "    # Retrieve body graph\n", "    edges = deepof.utils.connect_mouse_topview()\n", "\n", "    for bpart in exclude_bodyparts:\n", "        edges.remove_node(bpart)\n", "\n", "    for limb in [\"Left_fhip\", \"Right_fhip\", \"Left_bhip\", \"Right_bhip\"]:\n", "        edges.remove_edge(\"Center\", limb)\n", "\n", "    edges = edges.edges()\n", "\n", "    # Compute observed and predicted data to plot\n", "    data = coords[random_exp]\n", "    coords_rec = coords.filter_videos([random_exp])\n", "    data_prep = coords_rec.preprocess(\n", "        test_videos=0, window_step=1, window_size=window_size, shuffle=False\n", "    )[0]\n", "\n", "    # Keep a single frame (index 6, near the middle of each window)\n", "    # and map it back to the original coordinate space\n", "    data_rec = gmvaep.predict(data_prep)\n", "    data_rec = pd.DataFrame(coords_rec._scaler.inverse_transform(data_rec[:, 6, :]))\n", "    data_rec.columns = data.columns\n", "    data = pd.DataFrame(coords_rec._scaler.inverse_transform(data_prep[:, 6, :]))\n", "    data.columns = data_rec.columns\n", "\n", "    # Add the Center coordinate back; it was lost during alignment\n", "    data[\"Center\", \"x\"] = 0\n", "    data[\"Center\", \"y\"] = 0\n", "    data_rec[\"Center\", \"x\"] = 0\n", "    data_rec[\"Center\", \"y\"] = 0\n", "\n", "    # Plot!\n", "    init_x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[0, :]\n", "    init_y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[0, :]\n", "    init_rec_x = data_rec.xs(\"x\", level=1, axis=1, drop_level=False).iloc[0, :]\n", "    init_rec_y = data_rec.xs(\"y\", level=1, axis=1, drop_level=False).iloc[0, :]\n",
"\n", "    plots, rec_plots = plot_mouse_graph(\n", "        init_x, init_y, init_rec_x, init_rec_y, ax, edges\n", "    )\n", "    scatter = ax.scatter(\n", "        x=np.array(init_x), y=np.array(init_y), color=\"#006699\", label=\"Original\"\n", "    )\n", "    rec_scatter = ax.scatter(\n", "        x=np.array(init_rec_x),\n", "        y=np.array(init_rec_y),\n", "        color=\"red\",\n", "        label=\"Reconstruction\",\n", "    )\n", "\n", "    # Update data in main plot\n", "    def animation_frame(i):\n", "        # Update scatter plot\n", "        x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[i, :]\n", "        y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[i, :]\n", "        rec_x = data_rec.xs(\"x\", level=1, axis=1, drop_level=False).iloc[i, :]\n", "        rec_y = data_rec.xs(\"y\", level=1, axis=1, drop_level=False).iloc[i, :]\n", "\n", "        scatter.set_offsets(np.c_[np.array(x), np.array(y)])\n", "        rec_scatter.set_offsets(np.c_[np.array(rec_x), np.array(rec_y)])\n", "        update_mouse_graph(x, y, rec_x, rec_y, plots, rec_plots, edges)\n", "\n", "        return scatter\n", "\n", "    animation = FuncAnimation(fig, func=animation_frame, frames=250, interval=50)\n", "\n", "    ax.set_title(\"Original versus reconstructed data\")\n", "    ax.set_ylim(-100, 60)\n", "    ax.set_xlim(-60, 60)\n", "    ax.set_xlabel(\"x\")\n", "    ax.set_ylabel(\"y\")\n", "    plt.legend()\n", "\n", "    video = animation.to_html5_video()\n", "    html = display.HTML(video)\n", "    display.display(html)\n", "    plt.close()\n", "\n", "\n", "animate_mice_across_time(random_exp)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Evaluate latent space (to be incorporated into deepof.evaluate)" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Get encodings and cluster assignments for the whole preprocessed dataset\n", "data_prep = coords.preprocess(\n", "    test_videos=0, window_step=1, window_size=window_size, shuffle=False\n", ")[0]\n", "\n", "encodings = encoder.predict(data_prep)\n", "groupings = grouper.predict(data_prep)\n", "hard_groups = np.argmax(groupings, axis=1)" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d8cd7c45d5bc4bad95fe95648d882daf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(FloatSlider(value=0.5, description='minimum_confidence', max=1.0, step=0.01), Output()),…" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "@interact(minimum_confidence=(0.0, 1.0, 0.01))\n", "def plot_cluster_confidence(minimum_confidence):\n", "    plt.figure(figsize=(12, 8))\n", "\n", "    groups = hard_groups[np.max(groupings, axis=1) > minimum_confidence].flatten()\n", "    # Append one dummy instance per cluster so that all k clusters appear\n", "    # on the axis (note that this inflates every count by one)\n", "    groups = np.concatenate([groups, np.arange(k)])\n", "    sns.countplot(groups)\n", "    plt.xlabel(\"Cluster\")\n", "    plt.title(\"Training instances per cluster\")\n", "    plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The slider above sets the minimum confidence with which the model must assign a training instance to a cluster for that instance to be included in the plot." ] },
{ "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e69ff9850be74d97af09fd1aa32b53fc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(IntSlider(value=5500, description='samples', max=10000, min=1000, step=500), FloatSlider…" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot real data in the latent space\n", "\n", "\n", "@interact(\n", "    samples=(1000, 10000, 500),\n", "    minimum_confidence=(0.0, 0.99, 0.01),\n", "    dim_red=[\"LDA\", \"PCA\", \"UMAP\", \"tSNE\"],\n", ")\n", "def plot_latent_space(samples, minimum_confidence, dim_red):\n", "    if dim_red == \"UMAP\":\n", "        reducer = umap.UMAP(n_components=2)\n", "    elif dim_red == \"LDA\":\n", "        reducer = LinearDiscriminantAnalysis(n_components=2)\n", "    elif dim_red == \"PCA\":\n", "        reducer = PCA(n_components=2)\n", "    else:\n", "        reducer = TSNE(n_components=2)\n", "\n", "    encods = encodings[np.max(groupings, axis=1) > minimum_confidence]\n", "    groups = groupings[np.max(groupings, axis=1) > minimum_confidence]\n", "    hgroups = hard_groups[np.max(groupings, axis=1) > minimum_confidence].flatten()\n", "\n", "    # Subsample instances (with replacement) to keep the embedding tractable\n", "    samples = np.random.choice(range(encods.shape[0]), samples)\n", "    sample_enc = encods[samples, :]\n", "    sample_grp = groups[samples, :]\n", "    sample_hgr = hgroups[samples]\n", "\n", "    # LDA is supervised, so it also requires the hard cluster assignments\n", "    if dim_red != \"LDA\":\n", "        enc = reducer.fit_transform(sample_enc)\n", "    else:\n", "        enc = reducer.fit_transform(sample_enc, sample_hgr)\n", "\n", "    plt.figure(figsize=(12, 8))\n", "\n", "    sns.scatterplot(x=enc[:, 0], y=enc[:, 1], hue=sample_hgr, palette=\"jet\")\n", "    plt.xlabel(\"{} 1\".format(dim_red))\n", "    plt.ylabel(\"{} 2\".format(dim_red))\n", "    plt.show()" ] },
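{ "cell_type": "markdown", "metadata": {}, "source": [ "Finally, the cell below sketches the remaining item from the introduction: visualizing the trajectory of a single video through the latent space. It is a minimal sketch rather than part of the original pipeline: it encodes one randomly chosen experiment frame by frame, reduces the encodings to two dimensions with PCA, and colors the resulting path by frame index. The `traj_*` variable names are ours; everything else reuses objects defined above." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Plot trajectory of a video in latent space\n", "# Sketch: encode one experiment (shuffle=False keeps temporal order),\n", "# project the encodings to 2D with PCA, and color the path by frame index\n", "traj_coords = proj.get_coords(center=\"Center\", align=\"Spine_1\", align_inplace=True)\n", "traj_exp = np.random.choice(list(traj_coords.keys()), 1)[0]\n", "traj_prep = traj_coords.filter_videos([traj_exp]).preprocess(\n", "    test_videos=0, window_step=1, window_size=window_size, shuffle=False\n", ")[0]\n", "\n", "traj_enc = encoder.predict(traj_prep)\n", "traj_2d = PCA(n_components=2).fit_transform(traj_enc)\n", "\n", "plt.figure(figsize=(12, 8))\n", "plt.plot(traj_2d[:, 0], traj_2d[:, 1], color=\"lightgray\", linewidth=0.5, zorder=1)\n", "plt.scatter(\n", "    traj_2d[:, 0], traj_2d[:, 1], c=np.arange(traj_2d.shape[0]), cmap=\"viridis\", s=5, zorder=2\n", ")\n", "plt.colorbar(label=\"Frame index\")\n", "plt.xlabel(\"PCA 1\")\n", "plt.ylabel(\"PCA 2\")\n", "plt.title(\"Latent space trajectory: {}\".format(traj_exp))\n", "plt.show()" ] }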
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e69ff9850be74d97af09fd1aa32b53fc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(IntSlider(value=5500, description='samples', max=10000, min=1000, step=500), FloatSlider…" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot real data in the latent space\n", "\n", "\n", "@interact(\n", " samples=(1000, 10000, 500),\n", " minimum_confidence=(0.0, 0.99, 0.01),\n", " dim_red=[\"LDA\", \"PCA\", \"umap\", \"tSNE\"],\n", ")\n", "def plot_cluster_confidence(samples, minimum_confidence, dim_red):\n", " if dim_red == \"umap\":\n", " reducer = umap.UMAP(n_components=2)\n", " elif dim_red == \"LDA\":\n", " reducer = LinearDiscriminantAnalysis(n_components=2)\n", " elif dim_red == \"PCA\":\n", " reducer = PCA(n_components=2)\n", " else:\n", " reducer = TSNE(n_components=2)\n", "\n", " encods = encodings[np.max(groupings, axis=1) > minimum_confidence]\n", " groups = groupings[np.max(groupings, axis=1) > minimum_confidence]\n", " hgroups = hard_groups[np.max(groupings, axis=1) > minimum_confidence].flatten()\n", "\n", " samples = np.random.choice(range(encods.shape[0]), samples)\n", " sample_enc = encods[samples, :]\n", " sample_grp = groups[samples, :]\n", " sample_hgr = hgroups[samples]\n", "\n", " if dim_red != \"LDA\":\n", " enc = reducer.fit_transform(sample_enc)\n", " else:\n", " enc = reducer.fit_transform(sample_enc, sample_hgr)\n", "\n", " plt.figure(figsize=(12, 8))\n", "\n", " sns.scatterplot(x=enc[:, 0], y=enc[:, 1], hue=sample_hgr, palette=\"jet\")\n", " plt.xlabel(\"{} 1\".format(dim_red))\n", " plt.ylabel(\"{} 2\".format(dim_red))\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Plot trajectory of a video in latent space" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "nbformat": 4, "nbformat_minor": 4 }