{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# deepOF model evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given a dataset and a trained model, this notebook allows the user to:\n", "\n", "* Load and inspect the different models (encoder, decoder, grouper, gmvaep)\n", "* Visualize reconstruction quality for a given model\n", "* Visualize a static latent space\n", "* Visualize trajectories on the latent space for a given video\n", "* Sample from the latent space distributions and generate video clips showcasing generated data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# Move one directory up from the notebook's folder.\n", "# Run this cell only once per kernel: re-running it changes directory again.\n", "os.chdir(\"..\")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Standard library\n", "from datetime import datetime\n", "\n", "# Third-party\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import tensorflow as tf\n", "from IPython import display\n", "from ipywidgets import interact\n", "from matplotlib.animation import FuncAnimation\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "# deepof (deepof.models provides SEQ_2_SEQ_GMVAE, used below)\n", "import deepof.data\n", "import deepof.models\n", "import deepof.utils" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. 
Define and run project" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Paths are relative to the working directory set at the top of the notebook; adjust them to your setup\n", "DESKTOP_DIR = os.path.join(\"..\", \"..\", \"Desktop\")\n", "\n", "# Folder containing the deepof project analysed below\n", "path = os.path.join(DESKTOP_DIR, \"deepoftesttemp\")\n", "# Folder searched for the trained model's .h5 weights file\n", "trained_network = DESKTOP_DIR" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 297 ms, sys: 23.9 ms, total: 321 ms\n", "Wall time: 270 ms\n" ] } ], "source": [ "%%time\n", "# Build the project from `path`, leaving the tail body parts out of the analysis\n", "proj = deepof.data.project(\n", "    path=path,\n", "    smooth_alpha=0.99,\n", "    exclude_bodyparts=[\"Tail_1\", \"Tail_2\", \"Tail_tip\", \"Tail_base\"],\n", "    arena_dims=[380],\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading trajectories...\n", "Smoothing trajectories...\n", "Interpolating outliers...\n", "Iterative imputation of ocluded bodyparts...\n", "Computing distances...\n", "Computing angles...\n", "Done!\n", "deepof analysis of 2 videos\n", "CPU times: user 2.63 s, sys: 86.3 ms, total: 2.72 s\n", "Wall time: 683 ms\n" ] } ], "source": [ "%%time\n", "# NOTE: `proj` is rebound to the result of run(); to re-run this cell,\n", "# re-run the project cell above first.\n", "proj = proj.run(verbose=True)\n", "print(proj)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. 
Load pretrained deepof model" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "coords = proj.get_coords(center=\"Center\", align=\"Spine_1\", align_inplace=True)\n", "# Preprocess into sliding windows (window_size=24, window_step=5); only the first returned object is used here\n", "preprocessed_data, _, _, _ = coords.preprocess(test_videos=0, window_step=5, window_size=24)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Set model parameters (should match the configuration the loaded weights were trained with)\n", "encoding = 6\n", "loss = \"ELBO\"\n", "k = 25\n", "pheno = 0\n", "predictor = 0" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Build the model; only the first four returned objects are kept\n", "encoder, decoder, grouper, gmvaep = deepof.models.SEQ_2_SEQ_GMVAE(\n", "    loss=loss,\n", "    number_of_components=k,\n", "    compile_model=True,\n", "    encoding=encoding,\n", "    predictor=predictor,\n", "    phenotype_prediction=pheno,\n", ").build(preprocessed_data.shape)[:4]\n", "\n", "# Load trained weights from the first .h5 file (in sorted order) found in `trained_network`.\n", "# os.listdir order is arbitrary, so sorting makes the choice deterministic when several files exist.\n", "weight_files = sorted(f for f in os.listdir(trained_network) if f.endswith(\".h5\"))\n", "if not weight_files:\n", "    raise FileNotFoundError(\"No .h5 weights file found in {}\".format(trained_network))\n", "print(\"Loading weights from {}\".format(weight_files[0]))\n", "gmvaep.load_weights(os.path.join(trained_network, weight_files[0]))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Uncomment to see model summaries\n", "# encoder.summary()\n", "# decoder.summary()\n", "# grouper.summary()\n", "# gmvaep.summary()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def plot_model(model, name):\n", "    \"\"\"Save a diagram of `model` inside `path` as deepof_<name>_<YYYYmmdd-HHMMSS>.png.\"\"\"\n", "    # Local import so this cell does not depend on the imports cell having `datetime`\n", "    from datetime import datetime\n", "\n", "    tf.keras.utils.plot_model(\n", "        model,\n", "        to_file=os.path.join(\n", "            path,\n", "            \"deepof_{}_{}.png\".format(name, datetime.now().strftime(\"%Y%m%d-%H%M%S\")),\n", "        ),\n", "        show_shapes=True,\n", "        show_dtype=False,\n", "        show_layer_names=True,\n", "        rankdir=\"TB\",\n", "        expand_nested=True,\n", "        dpi=200,\n", "    )\n", "\n", "\n", "# Uncomment to save model-structure diagrams\n", "# plot_model(encoder, \"encoder\")\n", "# plot_model(decoder, \"decoder\")\n", "# plot_model(grouper, \"grouper\")\n", "# plot_model(gmvaep, \"gmvaep\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. 
Pass data through all models" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "encodings = encoder.predict(preprocessed_data)\n", "groupings = grouper.predict(preprocessed_data)\n", "# NOTE(review): assumes gmvaep returns only reconstructions (predictor=0, pheno=0) - confirm for other settings\n", "reconstrs = gmvaep.predict(preprocessed_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Evaluate reconstruction (to be incorporated into deepof.evaluate)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Fit a scaler to the data (preprocessed_data[:, 0, :], presumably the first frame of each window),\n", "# to back-transform reconstructions later.\n", "# NOTE(review): if preprocess() already standardizes its output, this scaler is close to the identity;\n", "# fit it on unscaled coordinates to recover the original units - TODO confirm\n", "scaler = StandardScaler().fit(preprocessed_data[:, 0, :])" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Back-transform reconstructions: flatten to (rows, features), undo the scaling, restore the original shape\n", "n_features = reconstrs.shape[-1]\n", "rescaled_reconstructions = scaler.inverse_transform(reconstrs.reshape(-1, n_features))\n", "rescaled_reconstructions = rescaled_reconstructions.reshape(reconstrs.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# TODO: display a video with the original data superimposed with the reconstructions\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. Evaluate latent space (to be incorporated into deepof.evaluate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# TODO: visualize the latent space using `encodings` and `groupings`" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "nbformat": 4, "nbformat_minor": 4 }